diff --git a/docs/fields/common-parameters.md b/docs/fields/common-parameters.md index d5a4c0d..044dfe9 100644 --- a/docs/fields/common-parameters.md +++ b/docs/fields/common-parameters.md @@ -158,6 +158,37 @@ Used for data related to given model but not to be stored in the database. Used in pydantic only. +## overwrite_pydantic_type + +By default, ormar uses predefined pydantic field types that it applies on model creation (hence the type hints are optional). + +If you want to, you can apply your own type, that will be **completely** replacing the build in one. +So it's on you as a user to provide a type that is valid in the context of given ormar field type. + +!!!warning + Note that by default you should use build in arguments that are passed to underlying pydantic field. + + You can check what arguments are supported in field types section or in [pydantic](https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation) docs. + +!!!danger + Setting a wrong type of pydantic field can break your model, so overwrite it only when you know what you are doing. + + As it's easy to break functionality of ormar the `overwrite_pydantic_type` argument is not available on relation fields! + +```python +# sample overwrites +class OverwriteTest(ormar.Model): + class Meta: + tablename = "overwrites" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + my_int: str = ormar.Integer(overwrite_pydantic_type=PositiveInt) + constraint_dict: Json = ormar.JSON( + overwrite_pydantic_type=Optional[Json[Dict[str, int]]]) +``` + ## choices `choices`: `Sequence` = `[]` diff --git a/docs/releases.md b/docs/releases.md index a891e9e..c488d01 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,10 @@ +# 0.10.17 + +## ✨ Features + +* Allow overwriting the default pydantic type for model fields [#312](https://github.com/collerek/ormar/issues/285) +* Add support for `sqlalchemy` >=1.4 (requires `databases` >= 0.5.0) [#142](https://github.com/collerek/ormar/issues/142) + # 0.10.16 ## ✨ Features diff --git a/ormar/__init__.py b/ormar/__init__.py index 1542801..222d30c 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -77,7 +77,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.10.16" +__version__ = "0.10.17" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 3310d25..8eb4188 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -31,6 +31,7 @@ class BaseField(FieldInfo): def __init__(self, **kwargs: Any) -> None: self.__type__: type = kwargs.pop("__type__", None) + self.__pydantic_type__: type = kwargs.pop("__pydantic_type__", None) self.__sample__: type = kwargs.pop("__sample__", None) self.related_name = kwargs.pop("related_name", None) diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 707463f..de0774c 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -143,12 +143,14 @@ def validate_not_allowed_fields(kwargs: Dict) -> None: encrypt_secret = kwargs.pop("encrypt_secret", None) encrypt_backend = kwargs.pop("encrypt_backend", None) encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) + overwrite_pydantic_type = kwargs.pop("overwrite_pydantic_type", None) not_supported = [ default, encrypt_secret, encrypt_backend, encrypt_custom_backend, + overwrite_pydantic_type, ] if any(x is not None for x in not_supported): raise ModelDefinitionError( diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index f1230b5..8478718 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -84,8 +84,13 @@ class ModelFieldFactory: encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE) encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) + overwrite_pydantic_type = kwargs.pop("overwrite_pydantic_type", None) + namespace = dict( __type__=cls._type, + __pydantic_type__=overwrite_pydantic_type + if overwrite_pydantic_type is not None + else cls._type, __sample__=cls._sample, alias=kwargs.pop("name", None), name=None, diff --git a/ormar/models/helpers/pydantic.py b/ormar/models/helpers/pydantic.py index 87ed5d1..1423176 100644 --- a/ormar/models/helpers/pydantic.py +++ b/ormar/models/helpers/pydantic.py @@ -5,8 +5,8 @@ import pydantic from pydantic.fields import ModelField from pydantic.utils import lenient_issubclass -from ormar.fields import BaseField # noqa: I100, I202 -from ormar.exceptions import ModelDefinitionError +from ormar.exceptions import ModelDefinitionError # noqa: I100, I202 +from ormar.fields import BaseField if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -84,9 +84,15 @@ def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]: for field_name, field in potential_fields.items(): field.name = field_name model_fields[field_name] = field - attrs["__annotations__"][field_name] = ( + default_type = ( field.__type__ if not field.nullable else Optional[field.__type__] ) + overwrite_type = ( + field.__pydantic_type__ + if field.__type__ != field.__pydantic_type__ + else None + ) + attrs["__annotations__"][field_name] = overwrite_type or default_type return attrs, model_fields diff --git a/ormar/models/model_row.py b/ormar/models/model_row.py index 3c2dd1c..a9f0680 100644 --- a/ormar/models/model_row.py +++ b/ormar/models/model_row.py @@ -10,7 +10,11 @@ from typing import ( cast, ) -import sqlalchemy +try: + from sqlalchemy.engine.result import ResultProxy +except ImportError: # pragma: no cover + from sqlalchemy.engine.result import Row as ResultProxy # type: ignore + from ormar.models import NewBaseModel # noqa: I202 from ormar.models.excludable import ExcludableItems @@ -25,7 +29,7 @@ class ModelRow(NewBaseModel): @classmethod def from_row( # noqa: CFQ002 cls, - row: sqlalchemy.engine.ResultProxy, + row: ResultProxy, source_model: Type["Model"], select_related: List = None, related_models: Any = None, @@ -59,7 +63,7 @@ class ModelRow(NewBaseModel): :param source_model: model on which relation was defined :type source_model: Type[Model] :param row: raw result row from the database - :type row: sqlalchemy.engine.result.ResultProxy + :type row: ResultProxy :param select_related: list of names of related models fetched from database :type select_related: List :param related_models: list or dict of related models @@ -153,7 +157,7 @@ class ModelRow(NewBaseModel): def _populate_nested_models_from_row( # noqa: CFQ002 cls, item: dict, - row: sqlalchemy.engine.ResultProxy, + row: ResultProxy, source_model: Type["Model"], related_models: Any, excludable: ExcludableItems, @@ -183,7 +187,7 @@ class ModelRow(NewBaseModel): :param item: dictionary of already populated nested models, otherwise empty dict :type item: Dict :param row: raw result row from the database - :type row: sqlalchemy.engine.result.ResultProxy + :type row: ResultProxy :param related_models: list or dict of related models :type related_models: Union[Dict, List] :return: dictionary with keys corresponding to model fields names @@ -263,7 +267,7 @@ class ModelRow(NewBaseModel): @classmethod def _populate_through_instance( # noqa: CFQ002 cls, - row: sqlalchemy.engine.ResultProxy, + row: ResultProxy, item: Dict, related: str, excludable: ExcludableItems, @@ -275,7 +279,7 @@ class ModelRow(NewBaseModel): Normally it's child class, unless the query is from queryset. :param row: row from db result - :type row: sqlalchemy.engine.ResultProxy + :type row: ResultProxy :param item: parent item dict :type item: Dict :param related: current relation name @@ -301,7 +305,7 @@ class ModelRow(NewBaseModel): @classmethod def _create_through_instance( cls, - row: sqlalchemy.engine.ResultProxy, + row: ResultProxy, through_name: str, related: str, excludable: ExcludableItems, @@ -343,7 +347,7 @@ class ModelRow(NewBaseModel): def extract_prefixed_table_columns( cls, item: dict, - row: sqlalchemy.engine.result.ResultProxy, + row: ResultProxy, table_prefix: str, excludable: ExcludableItems, ) -> Dict: diff --git a/requirements.txt b/requirements.txt index abdd726..696c4aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -databases[sqlite]>=0.3.2,<=0.4.1 -databases[postgresql]>=0.3.2,<=0.4.1 -databases[mysql]>=0.3.2,<=0.4.1 +databases[sqlite]>=0.3.2,<0.5.1 +databases[postgresql]>=0.3.2,<0.5.1 +databases[mysql]>=0.3.2,<0.5.1 pydantic >=1.6.1,!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<=1.8.2 -sqlalchemy>=1.3.18,<=1.3.23 +sqlalchemy>=1.3.18,<=1.4.23 typing_extensions>=3.7,<=3.7.4.3 orjson cryptography diff --git a/tests/test_model_definition/test_overwriting_pydantic_field_type.py b/tests/test_model_definition/test_overwriting_pydantic_field_type.py new file mode 100644 index 0000000..0b77592 --- /dev/null +++ b/tests/test_model_definition/test_overwriting_pydantic_field_type.py @@ -0,0 +1,54 @@ +from typing import Dict, Optional + +import databases +import pytest +import sqlalchemy +from pydantic import Json, PositiveInt, ValidationError + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class OverwriteTest(ormar.Model): + class Meta: + tablename = "overwrites" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + my_int: str = ormar.Integer(overwrite_pydantic_type=PositiveInt) + constraint_dict: Json = ormar.JSON( + overwrite_pydantic_type=Optional[Json[Dict[str, int]]] + ) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_constraints(): + with pytest.raises(ValidationError) as e: + OverwriteTest(my_int=-10) + assert "ensure this value is greater than 0" in str(e.value) + + with pytest.raises(ValidationError) as e: + OverwriteTest(my_int=10, constraint_dict={"aa": "ab"}) + assert "value is not a valid integer" in str(e.value) + + +@pytest.mark.asyncio +async def test_saving(): + async with database: + await OverwriteTest(my_int=5, constraint_dict={"aa": 123}).save() + + test = await OverwriteTest.objects.get() + assert test.my_int == 5 + assert test.constraint_dict == {"aa": 123}