refactor choices validation from root validator to field validator

This commit is contained in:
collerek
2021-10-10 14:11:25 +02:00
parent d992f3dc3b
commit d8f0dc92f0
11 changed files with 168 additions and 124 deletions

View File

@ -61,6 +61,7 @@ from ormar.fields import (
LargeBinary, LargeBinary,
ManyToMany, ManyToMany,
ManyToManyField, ManyToManyField,
SQL_ENCODERS_MAP,
SmallInteger, SmallInteger,
String, String,
Text, Text,
@ -132,6 +133,7 @@ __all__ = [
"or_", "or_",
"EncryptBackends", "EncryptBackends",
"ENCODERS_MAP", "ENCODERS_MAP",
"SQL_ENCODERS_MAP",
"DECODERS_MAP", "DECODERS_MAP",
"LargeBinary", "LargeBinary",
"Extra", "Extra",

View File

@ -24,7 +24,7 @@ from ormar.fields.model_fields import (
Time, Time,
UUID, UUID,
) )
from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP, SQL_ENCODERS_MAP
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends
from ormar.fields.through_field import Through, ThroughField from ormar.fields.through_field import Through, ThroughField
@ -54,6 +54,7 @@ __all__ = [
"EncryptBackend", "EncryptBackend",
"DECODERS_MAP", "DECODERS_MAP",
"ENCODERS_MAP", "ENCODERS_MAP",
"SQL_ENCODERS_MAP",
"LargeBinary", "LargeBinary",
"UniqueColumns", "UniqueColumns",
] ]

View File

@ -1,11 +1,13 @@
import datetime import datetime
import decimal import decimal
import uuid import uuid
from typing import Any, Optional, TYPE_CHECKING, Union, overload from enum import Enum
from typing import Any, Optional, Set, TYPE_CHECKING, Type, Union, overload
import pydantic import pydantic
import sqlalchemy import sqlalchemy
import ormar # noqa I101
from ormar import ModelDefinitionError # noqa I101 from ormar import ModelDefinitionError # noqa I101
from ormar.fields import sqlalchemy_uuid from ormar.fields import sqlalchemy_uuid
from ormar.fields.base import BaseField # noqa I101 from ormar.fields.base import BaseField # noqa I101
@ -60,6 +62,39 @@ def is_auto_primary_key(primary_key: bool, autoincrement: bool) -> bool:
return primary_key and autoincrement return primary_key and autoincrement
def convert_choices_if_needed(
field_type: "Type", choices: Set, nullable: bool, scale: int = None
) -> Set:
"""
Converts dates to isoformat as fastapi can check this condition in routes
and the fields are not yet parsed.
Converts enums to list of it's values.
Converts uuids to strings.
Converts decimal to float with given scale.
:param field_type: type o the field
:type field_type: Type
:param nullable: set of choices
:type nullable: Set
:param scale: scale for decimals
:type scale: int
:param scale: scale for decimals
:type scale: int
:return: value, choices list
:rtype: Tuple[Any, Set]
"""
choices = {o.value if isinstance(o, Enum) else o for o in choices}
encoder = ormar.ENCODERS_MAP.get(field_type, lambda x: x)
if field_type == decimal.Decimal:
precision = scale
choices = {encoder(o, precision) for o in choices}
elif encoder:
choices = {encoder(o) for o in choices}
if nullable:
choices.add(None)
return choices
class ModelFieldFactory: class ModelFieldFactory:
""" """
Default field factory that construct Field classes and populated their values. Default field factory that construct Field classes and populated their values.
@ -96,6 +131,15 @@ class ModelFieldFactory:
else (nullable if sql_nullable is None else sql_nullable) else (nullable if sql_nullable is None else sql_nullable)
) )
choices = set(kwargs.pop("choices", []))
if choices:
choices = convert_choices_if_needed(
field_type=cls._type,
choices=choices,
nullable=nullable,
scale=kwargs.get("scale", None),
)
namespace = dict( namespace = dict(
__type__=cls._type, __type__=cls._type,
__pydantic_type__=overwrite_pydantic_type __pydantic_type__=overwrite_pydantic_type
@ -114,7 +158,7 @@ class ModelFieldFactory:
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
autoincrement=autoincrement, autoincrement=autoincrement,
column_type=cls.get_column_type(**kwargs), column_type=cls.get_column_type(**kwargs),
choices=set(kwargs.pop("choices", [])), choices=choices,
encrypt_secret=encrypt_secret, encrypt_secret=encrypt_secret,
encrypt_backend=encrypt_backend, encrypt_backend=encrypt_backend,
encrypt_custom_backend=encrypt_custom_backend, encrypt_custom_backend=encrypt_custom_backend,

View File

@ -1,6 +1,8 @@
import base64
import datetime import datetime
import decimal import decimal
from typing import Any import uuid
from typing import Any, Callable, Dict, Union
import pydantic import pydantic
from pydantic.datetime_parse import parse_date, parse_datetime, parse_time from pydantic.datetime_parse import parse_date, parse_datetime, parse_time
@ -19,21 +21,55 @@ def encode_bool(value: bool) -> str:
return "true" if value else "false" return "true" if value else "false"
def encode_decimal(value: decimal.Decimal, precision: int = None) -> float:
if precision:
return (
round(float(value), precision)
if isinstance(value, decimal.Decimal)
else value
)
return float(value)
def encode_bytes(value: Union[str, bytes], represent_as_string: bool = False) -> bytes:
if represent_as_string:
return value if isinstance(value, bytes) else base64.b64decode(value)
return value if isinstance(value, bytes) else value.encode("utf-8")
def encode_json(value: Any) -> str: def encode_json(value: Any) -> str:
value = json.dumps(value) if not isinstance(value, str) else value value = json.dumps(value) if not isinstance(value, str) else re_dump_value(value)
value = value.decode("utf-8") if isinstance(value, bytes) else value value = value.decode("utf-8") if isinstance(value, bytes) else value
return value return value
ENCODERS_MAP = { def re_dump_value(value: str) -> Union[str, bytes]:
bool: encode_bool, """
Rw-dumps choices due to different string representation in orjson and json
:param value: string to re-dump
:type value: str
:return: re-dumped choices
:rtype: List[str]
"""
try:
result: Union[str, bytes] = json.dumps(json.loads(value))
except json.JSONDecodeError:
result = value
return result
ENCODERS_MAP: Dict[type, Callable] = {
datetime.datetime: lambda x: x.isoformat(), datetime.datetime: lambda x: x.isoformat(),
datetime.date: lambda x: x.isoformat(), datetime.date: lambda x: x.isoformat(),
datetime.time: lambda x: x.isoformat(), datetime.time: lambda x: x.isoformat(),
pydantic.Json: encode_json, pydantic.Json: encode_json,
decimal.Decimal: float, decimal.Decimal: encode_decimal,
uuid.UUID: str,
bytes: encode_bytes,
} }
SQL_ENCODERS_MAP: Dict[type, Callable] = {bool: encode_bool, **ENCODERS_MAP}
DECODERS_MAP = { DECODERS_MAP = {
bool: parse_bool, bool: parse_bool,
datetime.datetime: parse_datetime, datetime.datetime: parse_datetime,

View File

@ -160,7 +160,7 @@ class EncryptedString(types.TypeDecorator):
try: try:
value = self._underlying_type.process_bind_param(value, dialect) value = self._underlying_type.process_bind_param(value, dialect)
except AttributeError: except AttributeError:
encoder = ormar.ENCODERS_MAP.get(self.type_, None) encoder = ormar.SQL_ENCODERS_MAP.get(self.type_, None)
if encoder: if encoder:
value = encoder(value) # type: ignore value = encoder(value) # type: ignore

View File

@ -1,30 +1,37 @@
import base64 import base64
import datetime
import decimal import decimal
import numbers import numbers
import uuid from typing import (
from enum import Enum Any,
from typing import Any, Dict, List, Set, TYPE_CHECKING, Tuple, Type, Union Callable,
Dict,
List,
Set,
TYPE_CHECKING,
Type,
Union,
)
try: try:
import orjson as json import orjson as json
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
import json # type: ignore import json # type: ignore # noqa: F401
import pydantic import pydantic
from pydantic.fields import SHAPE_LIST from pydantic.class_validators import make_generic_validator
from pydantic.fields import ModelField, SHAPE_LIST
from pydantic.main import SchemaExtraCallable from pydantic.main import SchemaExtraCallable
import ormar # noqa: I100, I202 import ormar # noqa: I100, I202
from ormar.fields import BaseField
from ormar.models.helpers.models import meta_field_not_set from ormar.models.helpers.models import meta_field_not_set
from ormar.queryset.utils import translate_list_to_dict from ormar.queryset.utils import translate_list_to_dict
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.fields import BaseField
def check_if_field_has_choices(field: BaseField) -> bool: def check_if_field_has_choices(field: "BaseField") -> bool:
""" """
Checks if given field has choices populated. Checks if given field has choices populated.
A if it has one, a validator for this field needs to be attached. A if it has one, a validator for this field needs to be attached.
@ -37,110 +44,53 @@ def check_if_field_has_choices(field: BaseField) -> bool:
return hasattr(field, "choices") and bool(field.choices) return hasattr(field, "choices") and bool(field.choices)
def convert_choices_if_needed( # noqa: CCR001 def convert_value_if_needed(field: "BaseField", value: Any) -> Any:
field: "BaseField", value: Any
) -> Tuple[Any, List]:
""" """
Converts dates to isoformat as fastapi can check this condition in routes Converts dates to isoformat as fastapi can check this condition in routes
and the fields are not yet parsed. and the fields are not yet parsed.
Converts enums to list of it's values. Converts enums to list of it's values.
Converts uuids to strings. Converts uuids to strings.
Converts decimal to float with given scale. Converts decimal to float with given scale.
:param field: ormar field to check with choices :param field: ormar field to check with choices
:type field: BaseField :type field: BaseField
:param value: current values of the model to verify :param value: current values of the model to verify
:type value: Dict
:return: value, choices list
:rtype: Tuple[Any, List]
"""
# TODO use same maps as with EncryptedString
choices = [o.value if isinstance(o, Enum) else o for o in field.choices]
if field.__type__ in [datetime.datetime, datetime.date, datetime.time]:
value = value.isoformat() if not isinstance(value, str) else value
choices = [o.isoformat() for o in field.choices]
elif field.__type__ == pydantic.Json:
value = (
json.dumps(value) if not isinstance(value, str) else re_dump_value(value)
)
value = value.decode("utf-8") if isinstance(value, bytes) else value
choices = [re_dump_value(x) for x in field.choices]
elif field.__type__ == uuid.UUID:
value = str(value) if not isinstance(value, str) else value
choices = [str(o) for o in field.choices]
elif field.__type__ == decimal.Decimal:
precision = field.scale # type: ignore
value = (
round(float(value), precision)
if isinstance(value, decimal.Decimal)
else value
)
choices = [round(float(o), precision) for o in choices]
elif field.__type__ == bytes:
if field.represent_as_base64_str:
value = value if isinstance(value, bytes) else base64.b64decode(value)
else:
value = value if isinstance(value, bytes) else value.encode("utf-8")
return value, choices
def re_dump_value(value: str) -> str:
"""
Rw-dumps choices due to different string representation in orjson and json
:param value: string to re-dump
:type value: str
:return: re-dumped choices
:rtype: List[str]
"""
try:
result: Union[str, bytes] = json.dumps(json.loads(value))
except json.JSONDecodeError:
result = value
return result.decode("utf-8") if isinstance(result, bytes) else result
def validate_choices(field: "BaseField", value: Any) -> None:
"""
Validates if given value is in provided choices.
:raises ValueError: If value is not in choices.
:param field:field to validate
:type field: BaseField
:param value: value of the field
:type value: Any :type value: Any
:return: value, choices list
:rtype: Any
""" """
value, choices = convert_choices_if_needed(field=field, value=value) encoder = ormar.ENCODERS_MAP.get(field.__type__, lambda x: x)
if field.nullable: if field.__type__ == decimal.Decimal:
choices.append(None) precision = field.scale # type: ignore
if value is not ormar.Undefined and value not in choices: value = encoder(value, precision)
raise ValueError( elif encoder:
f"{field.name}: '{value}' " f"not in allowed choices set:" f" {choices}" value = encoder(value)
) return value
def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, Any]: def generate_validator(ormar_field: "BaseField") -> Callable:
""" choices = ormar_field.choices
Validator that is attached to pydantic model pre root validators.
Validator checks if field value is in field.choices list.
:raises ValueError: if field value is outside of allowed choices. def validate_choices(cls: type, value: Any, field: "ModelField") -> None:
:param cls: constructed class """
:type cls: Model class Validates if given value is in provided choices.
:param values: dictionary of field values (pydantic side)
:type values: Dict[str, Any] :raises ValueError: If value is not in choices.
:return: values if pass validation, otherwise exception is raised :param field:field to validate
:rtype: Dict[str, Any] :type field: BaseField
""" :param value: value of the field
for field_name, field in cls.Meta.model_fields.items(): :type value: Any
if check_if_field_has_choices(field): """
value = values.get(field_name, ormar.Undefined) adjusted_value = convert_value_if_needed(field=ormar_field, value=value)
validate_choices(field=field, value=value) if adjusted_value is not ormar.Undefined and adjusted_value not in choices:
return values raise ValueError(
f"{field.name}: '{adjusted_value}' "
f"not in allowed choices set:"
f" {choices}"
)
return value
return validate_choices
def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> Dict: def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> Dict:
@ -172,7 +122,7 @@ def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> D
def populates_sample_fields_values( def populates_sample_fields_values(
example: Dict[str, Any], name: str, field: BaseField, relation_map: Dict = None example: Dict[str, Any], name: str, field: "BaseField", relation_map: Dict = None
) -> None: ) -> None:
""" """
Iterates the field and sets fields to sample values Iterates the field and sets fields to sample values
@ -350,15 +300,14 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001
""" """
fields_with_choices = [] fields_with_choices = []
if not meta_field_not_set(model=model, field_name="model_fields"): if not meta_field_not_set(model=model, field_name="model_fields"):
if hasattr(model, "_choices_fields"):
return
model._choices_fields = set()
for name, field in model.Meta.model_fields.items(): for name, field in model.Meta.model_fields.items():
if check_if_field_has_choices(field): if check_if_field_has_choices(field):
fields_with_choices.append(name) fields_with_choices.append(name)
validators = getattr(model, "__pre_root_validators__", []) validator = make_generic_validator(generate_validator(field))
if choices_validator not in validators: model.__fields__[name].validators.append(validator)
validators.append(choices_validator)
model.__pre_root_validators__ = validators
if not model._choices_fields:
model._choices_fields = set()
model._choices_fields.add(name) model._choices_fields.add(name)
if fields_with_choices: if fields_with_choices:

View File

@ -106,7 +106,6 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
new_model._through_names = None new_model._through_names = None
new_model._related_fields = None new_model._related_fields = None
new_model._pydantic_fields = {name for name in new_model.__fields__} new_model._pydantic_fields = {name for name in new_model.__fields__}
new_model._choices_fields = set()
new_model._json_fields = set() new_model._json_fields = set()
new_model._bytes_fields = set() new_model._bytes_fields = set()

View File

@ -11,9 +11,10 @@ from typing import (
cast, cast,
) )
import ormar import pydantic
import ormar # noqa: I100, I202
from ormar.exceptions import ModelPersistenceError from ormar.exceptions import ModelPersistenceError
from ormar.models.helpers.validation import validate_choices
from ormar.models.mixins import AliasMixin from ormar.models.mixins import AliasMixin
from ormar.models.mixins.relation_mixin import RelationMixin from ormar.models.mixins.relation_mixin import RelationMixin
@ -29,6 +30,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if TYPE_CHECKING: # pragma: nocover if TYPE_CHECKING: # pragma: nocover
_choices_fields: Optional[Set] _choices_fields: Optional[Set]
_skip_ellipsis: Callable _skip_ellipsis: Callable
__fields__: Dict[str, pydantic.fields.ModelField]
@classmethod @classmethod
def prepare_model_to_save(cls, new_kwargs: dict) -> dict: def prepare_model_to_save(cls, new_kwargs: dict) -> dict:
@ -180,9 +182,18 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if not cls._choices_fields: if not cls._choices_fields:
return new_kwargs return new_kwargs
for field_name, field in cls.Meta.model_fields.items(): fields_to_check = [
if field_name in new_kwargs and field_name in cls._choices_fields: field
validate_choices(field=field, value=new_kwargs.get(field_name)) for field in cls.Meta.model_fields.values()
if field.name in cls._choices_fields and field.name in new_kwargs
]
for field in fields_to_check:
if new_kwargs[field.name] not in field.choices:
raise ValueError(
f"{field.name}: '{new_kwargs[field.name]}' "
f"not in allowed choices set:"
f" {field.choices}"
)
return new_kwargs return new_kwargs
@staticmethod @staticmethod

View File

@ -90,7 +90,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
_related_names: Optional[Set] _related_names: Optional[Set]
_through_names: Optional[Set] _through_names: Optional[Set]
_related_names_hash: str _related_names_hash: str
_choices_fields: Optional[Set] _choices_fields: Set
_pydantic_fields: Set _pydantic_fields: Set
_quick_access_fields: Set _quick_access_fields: Set
_json_fields: Set _json_fields: Set
@ -928,6 +928,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
:return: dictionary of fields names and values. :return: dictionary of fields names and values.
:rtype: Dict :rtype: Dict
""" """
# TODO: Cache this dictionary?
self_fields = self._extract_own_model_fields() self_fields = self._extract_own_model_fields()
self_fields = { self_fields = {
k: v k: v

View File

@ -29,8 +29,8 @@ def check_node_not_dict_or_not_last_node(
:param part: :param part:
:type part: str :type part: str
:param parts: :param is_last: flag to check if last element
:type parts: List[str] :type is_last: bool
:param current_level: current level of the traversed structure :param current_level: current level of the traversed structure
:type current_level: Any :type current_level: Any
:return: result of the check :return: result of the check
@ -52,7 +52,7 @@ def translate_list_to_dict( # noqa: CCR001
Default required key ise Ellipsis like in pydantic. Default required key ise Ellipsis like in pydantic.
:param list_to_trans: input list :param list_to_trans: input list
:type list_to_trans: set :type list_to_trans: Union[List, Set]
:param is_order: flag if change affects order_by clauses are they require special :param is_order: flag if change affects order_by clauses are they require special
default value with sort order. default value with sort order.
:type is_order: bool :type is_order: bool

View File

@ -121,7 +121,8 @@ def test_all_endpoints():
"blob_col": blob.decode("utf-8"), "blob_col": blob.decode("utf-8"),
}, },
) )
if response.status_code != 200:
print(response.text)
assert response.status_code == 200 assert response.status_code == 200
item = Organisation(**response.json()) item = Organisation(**response.json())
assert item.pk is not None assert item.pk is not None