add overwriting pydantic types #311

This commit is contained in:
collerek
2021-08-27 16:02:20 +02:00
parent 40fe1ad864
commit 9bb22d2ea4
6 changed files with 102 additions and 3 deletions

View File

@ -158,6 +158,37 @@ Used for data related to given model but not to be stored in the database.
Used in pydantic only. Used in pydantic only.
## overwrite_pydantic_type
By default, ormar uses predefined pydantic field types that it applies on model creation (hence the type hints are optional).
If you want to, you can apply your own type, that will be **completely** replacing the build in one.
So it's on you as a user to provide a type that is valid in the context of given ormar field type.
!!!warning
Note that by default you should use build in arguments that are passed to underlying pydantic field.
You can check what arguments are supported in field types section or in [pydantic](https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation) docs.
!!!danger
Setting a wrong type of pydantic field can break your model, so overwrite it only when you know what you are doing.
As it's easy to break functionality of ormar the `overwrite_pydantic_type` argument is not available on relation fields!
```python
# sample overwrites
class OverwriteTest(ormar.Model):
class Meta:
tablename = "overwrites"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
my_int: str = ormar.Integer(overwrite_pydantic_type=PositiveInt)
constraint_dict: Json = ormar.JSON(
overwrite_pydantic_type=Optional[Json[Dict[str, int]]])
```
## choices ## choices
`choices`: `Sequence` = `[]` `choices`: `Sequence` = `[]`

View File

@ -31,6 +31,7 @@ class BaseField(FieldInfo):
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
self.__type__: type = kwargs.pop("__type__", None) self.__type__: type = kwargs.pop("__type__", None)
self.__pydantic_type__: type = kwargs.pop("__pydantic_type__", None)
self.__sample__: type = kwargs.pop("__sample__", None) self.__sample__: type = kwargs.pop("__sample__", None)
self.related_name = kwargs.pop("related_name", None) self.related_name = kwargs.pop("related_name", None)

View File

@ -143,12 +143,14 @@ def validate_not_allowed_fields(kwargs: Dict) -> None:
encrypt_secret = kwargs.pop("encrypt_secret", None) encrypt_secret = kwargs.pop("encrypt_secret", None)
encrypt_backend = kwargs.pop("encrypt_backend", None) encrypt_backend = kwargs.pop("encrypt_backend", None)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
overwrite_pydantic_type = kwargs.pop("overwrite_pydantic_type", None)
not_supported = [ not_supported = [
default, default,
encrypt_secret, encrypt_secret,
encrypt_backend, encrypt_backend,
encrypt_custom_backend, encrypt_custom_backend,
overwrite_pydantic_type,
] ]
if any(x is not None for x in not_supported): if any(x is not None for x in not_supported):
raise ModelDefinitionError( raise ModelDefinitionError(

View File

@ -84,8 +84,13 @@ class ModelFieldFactory:
encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE) encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
overwrite_pydantic_type = kwargs.pop("overwrite_pydantic_type", None)
namespace = dict( namespace = dict(
__type__=cls._type, __type__=cls._type,
__pydantic_type__=overwrite_pydantic_type
if overwrite_pydantic_type is not None
else cls._type,
__sample__=cls._sample, __sample__=cls._sample,
alias=kwargs.pop("name", None), alias=kwargs.pop("name", None),
name=None, name=None,

View File

@ -5,8 +5,8 @@ import pydantic
from pydantic.fields import ModelField from pydantic.fields import ModelField
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
from ormar.fields import BaseField # noqa: I100, I202 from ormar.exceptions import ModelDefinitionError # noqa: I100, I202
from ormar.exceptions import ModelDefinitionError from ormar.fields import BaseField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -84,9 +84,15 @@ def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
for field_name, field in potential_fields.items(): for field_name, field in potential_fields.items():
field.name = field_name field.name = field_name
model_fields[field_name] = field model_fields[field_name] = field
attrs["__annotations__"][field_name] = ( default_type = (
field.__type__ if not field.nullable else Optional[field.__type__] field.__type__ if not field.nullable else Optional[field.__type__]
) )
overwrite_type = (
field.__pydantic_type__
if field.__type__ != field.__pydantic_type__
else None
)
attrs["__annotations__"][field_name] = overwrite_type or default_type
return attrs, model_fields return attrs, model_fields

View File

@ -0,0 +1,54 @@
from typing import Dict, Optional
import databases
import pytest
import sqlalchemy
from pydantic import Json, PositiveInt, ValidationError
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class OverwriteTest(ormar.Model):
class Meta:
tablename = "overwrites"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
my_int: str = ormar.Integer(overwrite_pydantic_type=PositiveInt)
constraint_dict: Json = ormar.JSON(
overwrite_pydantic_type=Optional[Json[Dict[str, int]]]
)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_constraints():
with pytest.raises(ValidationError) as e:
OverwriteTest(my_int=-10)
assert "ensure this value is greater than 0" in str(e.value)
with pytest.raises(ValidationError) as e:
OverwriteTest(my_int=10, constraint_dict={"aa": "ab"})
assert "value is not a valid integer" in str(e.value)
@pytest.mark.asyncio
async def test_saving():
async with database:
await OverwriteTest(my_int=5, constraint_dict={"aa": 123}).save()
test = await OverwriteTest.objects.get()
assert test.my_int == 5
assert test.constraint_dict == {"aa": 123}