From 9bb22d2ea4071e1df94e7f95dfb78b0eec74e486 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 27 Aug 2021 16:02:20 +0200 Subject: [PATCH] add overwriting pydantic types #311 --- docs/fields/common-parameters.md | 31 +++++++++++ ormar/fields/base.py | 1 + ormar/fields/foreign_key.py | 2 + ormar/fields/model_fields.py | 5 ++ ormar/models/helpers/pydantic.py | 12 +++-- .../test_overwriting_pydantic_field_type.py | 54 +++++++++++++++++++ 6 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 tests/test_model_definition/test_overwriting_pydantic_field_type.py diff --git a/docs/fields/common-parameters.md b/docs/fields/common-parameters.md index d5a4c0d..044dfe9 100644 --- a/docs/fields/common-parameters.md +++ b/docs/fields/common-parameters.md @@ -158,6 +158,37 @@ Used for data related to given model but not to be stored in the database. Used in pydantic only. +## overwrite_pydantic_type + +By default, ormar uses predefined pydantic field types that it applies on model creation (hence the type hints are optional). + +If you want to, you can apply your own type, that will be **completely** replacing the build in one. +So it's on you as a user to provide a type that is valid in the context of given ormar field type. + +!!!warning + Note that by default you should use build in arguments that are passed to underlying pydantic field. + + You can check what arguments are supported in field types section or in [pydantic](https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation) docs. + +!!!danger + Setting a wrong type of pydantic field can break your model, so overwrite it only when you know what you are doing. + + As it's easy to break functionality of ormar the `overwrite_pydantic_type` argument is not available on relation fields! + +```python +# sample overwrites +class OverwriteTest(ormar.Model): + class Meta: + tablename = "overwrites" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + my_int: str = ormar.Integer(overwrite_pydantic_type=PositiveInt) + constraint_dict: Json = ormar.JSON( + overwrite_pydantic_type=Optional[Json[Dict[str, int]]]) +``` + ## choices `choices`: `Sequence` = `[]` diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 3310d25..8eb4188 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -31,6 +31,7 @@ class BaseField(FieldInfo): def __init__(self, **kwargs: Any) -> None: self.__type__: type = kwargs.pop("__type__", None) + self.__pydantic_type__: type = kwargs.pop("__pydantic_type__", None) self.__sample__: type = kwargs.pop("__sample__", None) self.related_name = kwargs.pop("related_name", None) diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 707463f..de0774c 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -143,12 +143,14 @@ def validate_not_allowed_fields(kwargs: Dict) -> None: encrypt_secret = kwargs.pop("encrypt_secret", None) encrypt_backend = kwargs.pop("encrypt_backend", None) encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) + overwrite_pydantic_type = kwargs.pop("overwrite_pydantic_type", None) not_supported = [ default, encrypt_secret, encrypt_backend, encrypt_custom_backend, + overwrite_pydantic_type, ] if any(x is not None for x in not_supported): raise ModelDefinitionError( diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index f1230b5..8478718 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -84,8 +84,13 @@ class ModelFieldFactory: encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE) encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) + overwrite_pydantic_type = kwargs.pop("overwrite_pydantic_type", None) + namespace = dict( __type__=cls._type, + __pydantic_type__=overwrite_pydantic_type + if overwrite_pydantic_type is not None + else cls._type, __sample__=cls._sample, alias=kwargs.pop("name", None), name=None, diff --git a/ormar/models/helpers/pydantic.py b/ormar/models/helpers/pydantic.py index 87ed5d1..1423176 100644 --- a/ormar/models/helpers/pydantic.py +++ b/ormar/models/helpers/pydantic.py @@ -5,8 +5,8 @@ import pydantic from pydantic.fields import ModelField from pydantic.utils import lenient_issubclass -from ormar.fields import BaseField # noqa: I100, I202 -from ormar.exceptions import ModelDefinitionError +from ormar.exceptions import ModelDefinitionError # noqa: I100, I202 +from ormar.fields import BaseField if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -84,9 +84,15 @@ def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]: for field_name, field in potential_fields.items(): field.name = field_name model_fields[field_name] = field - attrs["__annotations__"][field_name] = ( + default_type = ( field.__type__ if not field.nullable else Optional[field.__type__] ) + overwrite_type = ( + field.__pydantic_type__ + if field.__type__ != field.__pydantic_type__ + else None + ) + attrs["__annotations__"][field_name] = overwrite_type or default_type return attrs, model_fields diff --git a/tests/test_model_definition/test_overwriting_pydantic_field_type.py b/tests/test_model_definition/test_overwriting_pydantic_field_type.py new file mode 100644 index 0000000..0b77592 --- /dev/null +++ b/tests/test_model_definition/test_overwriting_pydantic_field_type.py @@ -0,0 +1,54 @@ +from typing import Dict, Optional + +import databases +import pytest +import sqlalchemy +from pydantic import Json, PositiveInt, ValidationError + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class OverwriteTest(ormar.Model): + class Meta: + tablename = "overwrites" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + my_int: str = ormar.Integer(overwrite_pydantic_type=PositiveInt) + constraint_dict: Json = ormar.JSON( + overwrite_pydantic_type=Optional[Json[Dict[str, int]]] + ) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_constraints(): + with pytest.raises(ValidationError) as e: + OverwriteTest(my_int=-10) + assert "ensure this value is greater than 0" in str(e.value) + + with pytest.raises(ValidationError) as e: + OverwriteTest(my_int=10, constraint_dict={"aa": "ab"}) + assert "value is not a valid integer" in str(e.value) + + +@pytest.mark.asyncio +async def test_saving(): + async with database: + await OverwriteTest(my_int=5, constraint_dict={"aa": 123}).save() + + test = await OverwriteTest.objects.get() + assert test.my_int == 5 + assert test.constraint_dict == {"aa": 123}