Merge pull request #325 from collerek/add_force_overwrite_pydantic_type

Add force overwrite pydantic type
This commit is contained in:
collerek
2021-08-30 09:53:21 +02:00
committed by GitHub
10 changed files with 127 additions and 17 deletions

View File

@ -158,6 +158,37 @@ Used for data related to given model but not to be stored in the database.
Used in pydantic only. Used in pydantic only.
## overwrite_pydantic_type
By default, ormar uses predefined pydantic field types that it applies on model creation (hence the type hints are optional).
If you want to, you can apply your own type, that will be **completely** replacing the build in one.
So it's on you as a user to provide a type that is valid in the context of given ormar field type.
!!!warning
Note that by default you should use build in arguments that are passed to underlying pydantic field.
You can check what arguments are supported in field types section or in [pydantic](https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation) docs.
!!!danger
Setting a wrong type of pydantic field can break your model, so overwrite it only when you know what you are doing.
As it's easy to break functionality of ormar the `overwrite_pydantic_type` argument is not available on relation fields!
```python
# sample overwrites
class OverwriteTest(ormar.Model):
class Meta:
tablename = "overwrites"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
my_int: str = ormar.Integer(overwrite_pydantic_type=PositiveInt)
constraint_dict: Json = ormar.JSON(
overwrite_pydantic_type=Optional[Json[Dict[str, int]]])
```
## choices ## choices
`choices`: `Sequence` = `[]` `choices`: `Sequence` = `[]`

View File

@ -1,3 +1,10 @@
# 0.10.17
## ✨ Features
* Allow overwriting the default pydantic type for model fields [#312](https://github.com/collerek/ormar/issues/285)
* Add support for `sqlalchemy` >=1.4 (requires `databases` >= 0.5.0) [#142](https://github.com/collerek/ormar/issues/142)
# 0.10.16 # 0.10.16
## ✨ Features ## ✨ Features

View File

@ -77,7 +77,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.10.16" __version__ = "0.10.17"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -31,6 +31,7 @@ class BaseField(FieldInfo):
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
self.__type__: type = kwargs.pop("__type__", None) self.__type__: type = kwargs.pop("__type__", None)
self.__pydantic_type__: type = kwargs.pop("__pydantic_type__", None)
self.__sample__: type = kwargs.pop("__sample__", None) self.__sample__: type = kwargs.pop("__sample__", None)
self.related_name = kwargs.pop("related_name", None) self.related_name = kwargs.pop("related_name", None)

View File

@ -143,12 +143,14 @@ def validate_not_allowed_fields(kwargs: Dict) -> None:
encrypt_secret = kwargs.pop("encrypt_secret", None) encrypt_secret = kwargs.pop("encrypt_secret", None)
encrypt_backend = kwargs.pop("encrypt_backend", None) encrypt_backend = kwargs.pop("encrypt_backend", None)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
overwrite_pydantic_type = kwargs.pop("overwrite_pydantic_type", None)
not_supported = [ not_supported = [
default, default,
encrypt_secret, encrypt_secret,
encrypt_backend, encrypt_backend,
encrypt_custom_backend, encrypt_custom_backend,
overwrite_pydantic_type,
] ]
if any(x is not None for x in not_supported): if any(x is not None for x in not_supported):
raise ModelDefinitionError( raise ModelDefinitionError(

View File

@ -84,8 +84,13 @@ class ModelFieldFactory:
encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE) encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
overwrite_pydantic_type = kwargs.pop("overwrite_pydantic_type", None)
namespace = dict( namespace = dict(
__type__=cls._type, __type__=cls._type,
__pydantic_type__=overwrite_pydantic_type
if overwrite_pydantic_type is not None
else cls._type,
__sample__=cls._sample, __sample__=cls._sample,
alias=kwargs.pop("name", None), alias=kwargs.pop("name", None),
name=None, name=None,

View File

@ -5,8 +5,8 @@ import pydantic
from pydantic.fields import ModelField from pydantic.fields import ModelField
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
from ormar.fields import BaseField # noqa: I100, I202 from ormar.exceptions import ModelDefinitionError # noqa: I100, I202
from ormar.exceptions import ModelDefinitionError from ormar.fields import BaseField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -84,9 +84,15 @@ def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
for field_name, field in potential_fields.items(): for field_name, field in potential_fields.items():
field.name = field_name field.name = field_name
model_fields[field_name] = field model_fields[field_name] = field
attrs["__annotations__"][field_name] = ( default_type = (
field.__type__ if not field.nullable else Optional[field.__type__] field.__type__ if not field.nullable else Optional[field.__type__]
) )
overwrite_type = (
field.__pydantic_type__
if field.__type__ != field.__pydantic_type__
else None
)
attrs["__annotations__"][field_name] = overwrite_type or default_type
return attrs, model_fields return attrs, model_fields

View File

@ -10,7 +10,11 @@ from typing import (
cast, cast,
) )
import sqlalchemy try:
from sqlalchemy.engine.result import ResultProxy
except ImportError: # pragma: no cover
from sqlalchemy.engine.result import Row as ResultProxy # type: ignore
from ormar.models import NewBaseModel # noqa: I202 from ormar.models import NewBaseModel # noqa: I202
from ormar.models.excludable import ExcludableItems from ormar.models.excludable import ExcludableItems
@ -25,7 +29,7 @@ class ModelRow(NewBaseModel):
@classmethod @classmethod
def from_row( # noqa: CFQ002 def from_row( # noqa: CFQ002
cls, cls,
row: sqlalchemy.engine.ResultProxy, row: ResultProxy,
source_model: Type["Model"], source_model: Type["Model"],
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
@ -59,7 +63,7 @@ class ModelRow(NewBaseModel):
:param source_model: model on which relation was defined :param source_model: model on which relation was defined
:type source_model: Type[Model] :type source_model: Type[Model]
:param row: raw result row from the database :param row: raw result row from the database
:type row: sqlalchemy.engine.result.ResultProxy :type row: ResultProxy
:param select_related: list of names of related models fetched from database :param select_related: list of names of related models fetched from database
:type select_related: List :type select_related: List
:param related_models: list or dict of related models :param related_models: list or dict of related models
@ -153,7 +157,7 @@ class ModelRow(NewBaseModel):
def _populate_nested_models_from_row( # noqa: CFQ002 def _populate_nested_models_from_row( # noqa: CFQ002
cls, cls,
item: dict, item: dict,
row: sqlalchemy.engine.ResultProxy, row: ResultProxy,
source_model: Type["Model"], source_model: Type["Model"],
related_models: Any, related_models: Any,
excludable: ExcludableItems, excludable: ExcludableItems,
@ -183,7 +187,7 @@ class ModelRow(NewBaseModel):
:param item: dictionary of already populated nested models, otherwise empty dict :param item: dictionary of already populated nested models, otherwise empty dict
:type item: Dict :type item: Dict
:param row: raw result row from the database :param row: raw result row from the database
:type row: sqlalchemy.engine.result.ResultProxy :type row: ResultProxy
:param related_models: list or dict of related models :param related_models: list or dict of related models
:type related_models: Union[Dict, List] :type related_models: Union[Dict, List]
:return: dictionary with keys corresponding to model fields names :return: dictionary with keys corresponding to model fields names
@ -263,7 +267,7 @@ class ModelRow(NewBaseModel):
@classmethod @classmethod
def _populate_through_instance( # noqa: CFQ002 def _populate_through_instance( # noqa: CFQ002
cls, cls,
row: sqlalchemy.engine.ResultProxy, row: ResultProxy,
item: Dict, item: Dict,
related: str, related: str,
excludable: ExcludableItems, excludable: ExcludableItems,
@ -275,7 +279,7 @@ class ModelRow(NewBaseModel):
Normally it's child class, unless the query is from queryset. Normally it's child class, unless the query is from queryset.
:param row: row from db result :param row: row from db result
:type row: sqlalchemy.engine.ResultProxy :type row: ResultProxy
:param item: parent item dict :param item: parent item dict
:type item: Dict :type item: Dict
:param related: current relation name :param related: current relation name
@ -301,7 +305,7 @@ class ModelRow(NewBaseModel):
@classmethod @classmethod
def _create_through_instance( def _create_through_instance(
cls, cls,
row: sqlalchemy.engine.ResultProxy, row: ResultProxy,
through_name: str, through_name: str,
related: str, related: str,
excludable: ExcludableItems, excludable: ExcludableItems,
@ -343,7 +347,7 @@ class ModelRow(NewBaseModel):
def extract_prefixed_table_columns( def extract_prefixed_table_columns(
cls, cls,
item: dict, item: dict,
row: sqlalchemy.engine.result.ResultProxy, row: ResultProxy,
table_prefix: str, table_prefix: str,
excludable: ExcludableItems, excludable: ExcludableItems,
) -> Dict: ) -> Dict:

View File

@ -1,8 +1,8 @@
databases[sqlite]>=0.3.2,<=0.4.1 databases[sqlite]>=0.3.2,<0.5.1
databases[postgresql]>=0.3.2,<=0.4.1 databases[postgresql]>=0.3.2,<0.5.1
databases[mysql]>=0.3.2,<=0.4.1 databases[mysql]>=0.3.2,<0.5.1
pydantic >=1.6.1,!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<=1.8.2 pydantic >=1.6.1,!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<=1.8.2
sqlalchemy>=1.3.18,<=1.3.23 sqlalchemy>=1.3.18,<=1.4.23
typing_extensions>=3.7,<=3.7.4.3 typing_extensions>=3.7,<=3.7.4.3
orjson orjson
cryptography cryptography

View File

@ -0,0 +1,54 @@
from typing import Dict, Optional
import databases
import pytest
import sqlalchemy
from pydantic import Json, PositiveInt, ValidationError
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class OverwriteTest(ormar.Model):
class Meta:
tablename = "overwrites"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
my_int: str = ormar.Integer(overwrite_pydantic_type=PositiveInt)
constraint_dict: Json = ormar.JSON(
overwrite_pydantic_type=Optional[Json[Dict[str, int]]]
)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_constraints():
with pytest.raises(ValidationError) as e:
OverwriteTest(my_int=-10)
assert "ensure this value is greater than 0" in str(e.value)
with pytest.raises(ValidationError) as e:
OverwriteTest(my_int=10, constraint_dict={"aa": "ab"})
assert "value is not a valid integer" in str(e.value)
@pytest.mark.asyncio
async def test_saving():
async with database:
await OverwriteTest(my_int=5, constraint_dict={"aa": 123}).save()
test = await OverwriteTest.objects.get()
assert test.my_int == 5
assert test.constraint_dict == {"aa": 123}