refactor choices validation from root validator to field validator

This commit is contained in:
collerek
2021-10-10 14:11:25 +02:00
parent d992f3dc3b
commit d8f0dc92f0
11 changed files with 168 additions and 124 deletions

View File

@ -61,6 +61,7 @@ from ormar.fields import (
LargeBinary,
ManyToMany,
ManyToManyField,
SQL_ENCODERS_MAP,
SmallInteger,
String,
Text,
@ -132,6 +133,7 @@ __all__ = [
"or_",
"EncryptBackends",
"ENCODERS_MAP",
"SQL_ENCODERS_MAP",
"DECODERS_MAP",
"LargeBinary",
"Extra",

View File

@ -24,7 +24,7 @@ from ormar.fields.model_fields import (
Time,
UUID,
)
from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP
from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP, SQL_ENCODERS_MAP
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends
from ormar.fields.through_field import Through, ThroughField
@ -54,6 +54,7 @@ __all__ = [
"EncryptBackend",
"DECODERS_MAP",
"ENCODERS_MAP",
"SQL_ENCODERS_MAP",
"LargeBinary",
"UniqueColumns",
]

View File

@ -1,11 +1,13 @@
import datetime
import decimal
import uuid
from typing import Any, Optional, TYPE_CHECKING, Union, overload
from enum import Enum
from typing import Any, Optional, Set, TYPE_CHECKING, Type, Union, overload
import pydantic
import sqlalchemy
import ormar # noqa I101
from ormar import ModelDefinitionError # noqa I101
from ormar.fields import sqlalchemy_uuid
from ormar.fields.base import BaseField # noqa I101
@ -60,6 +62,39 @@ def is_auto_primary_key(primary_key: bool, autoincrement: bool) -> bool:
return primary_key and autoincrement
def convert_choices_if_needed(
field_type: "Type", choices: Set, nullable: bool, scale: int = None
) -> Set:
"""
Converts dates to isoformat as fastapi can check this condition in routes
and the fields are not yet parsed.
Converts enums to list of it's values.
Converts uuids to strings.
Converts decimal to float with given scale.
:param field_type: type o the field
:type field_type: Type
:param nullable: set of choices
:type nullable: Set
:param scale: scale for decimals
:type scale: int
:param scale: scale for decimals
:type scale: int
:return: value, choices list
:rtype: Tuple[Any, Set]
"""
choices = {o.value if isinstance(o, Enum) else o for o in choices}
encoder = ormar.ENCODERS_MAP.get(field_type, lambda x: x)
if field_type == decimal.Decimal:
precision = scale
choices = {encoder(o, precision) for o in choices}
elif encoder:
choices = {encoder(o) for o in choices}
if nullable:
choices.add(None)
return choices
class ModelFieldFactory:
"""
Default field factory that construct Field classes and populated their values.
@ -96,6 +131,15 @@ class ModelFieldFactory:
else (nullable if sql_nullable is None else sql_nullable)
)
choices = set(kwargs.pop("choices", []))
if choices:
choices = convert_choices_if_needed(
field_type=cls._type,
choices=choices,
nullable=nullable,
scale=kwargs.get("scale", None),
)
namespace = dict(
__type__=cls._type,
__pydantic_type__=overwrite_pydantic_type
@ -114,7 +158,7 @@ class ModelFieldFactory:
pydantic_only=pydantic_only,
autoincrement=autoincrement,
column_type=cls.get_column_type(**kwargs),
choices=set(kwargs.pop("choices", [])),
choices=choices,
encrypt_secret=encrypt_secret,
encrypt_backend=encrypt_backend,
encrypt_custom_backend=encrypt_custom_backend,

View File

@ -1,6 +1,8 @@
import base64
import datetime
import decimal
from typing import Any
import uuid
from typing import Any, Callable, Dict, Union
import pydantic
from pydantic.datetime_parse import parse_date, parse_datetime, parse_time
@ -19,21 +21,55 @@ def encode_bool(value: bool) -> str:
return "true" if value else "false"
def encode_decimal(value: decimal.Decimal, precision: int = None) -> float:
if precision:
return (
round(float(value), precision)
if isinstance(value, decimal.Decimal)
else value
)
return float(value)
def encode_bytes(value: Union[str, bytes], represent_as_string: bool = False) -> bytes:
if represent_as_string:
return value if isinstance(value, bytes) else base64.b64decode(value)
return value if isinstance(value, bytes) else value.encode("utf-8")
def encode_json(value: Any) -> str:
value = json.dumps(value) if not isinstance(value, str) else value
value = json.dumps(value) if not isinstance(value, str) else re_dump_value(value)
value = value.decode("utf-8") if isinstance(value, bytes) else value
return value
ENCODERS_MAP = {
bool: encode_bool,
def re_dump_value(value: str) -> Union[str, bytes]:
"""
Rw-dumps choices due to different string representation in orjson and json
:param value: string to re-dump
:type value: str
:return: re-dumped choices
:rtype: List[str]
"""
try:
result: Union[str, bytes] = json.dumps(json.loads(value))
except json.JSONDecodeError:
result = value
return result
ENCODERS_MAP: Dict[type, Callable] = {
datetime.datetime: lambda x: x.isoformat(),
datetime.date: lambda x: x.isoformat(),
datetime.time: lambda x: x.isoformat(),
pydantic.Json: encode_json,
decimal.Decimal: float,
decimal.Decimal: encode_decimal,
uuid.UUID: str,
bytes: encode_bytes,
}
SQL_ENCODERS_MAP: Dict[type, Callable] = {bool: encode_bool, **ENCODERS_MAP}
DECODERS_MAP = {
bool: parse_bool,
datetime.datetime: parse_datetime,

View File

@ -160,7 +160,7 @@ class EncryptedString(types.TypeDecorator):
try:
value = self._underlying_type.process_bind_param(value, dialect)
except AttributeError:
encoder = ormar.ENCODERS_MAP.get(self.type_, None)
encoder = ormar.SQL_ENCODERS_MAP.get(self.type_, None)
if encoder:
value = encoder(value) # type: ignore

View File

@ -1,30 +1,37 @@
import base64
import datetime
import decimal
import numbers
import uuid
from enum import Enum
from typing import Any, Dict, List, Set, TYPE_CHECKING, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Set,
TYPE_CHECKING,
Type,
Union,
)
try:
import orjson as json
except ImportError: # pragma: no cover
import json # type: ignore
import json # type: ignore # noqa: F401
import pydantic
from pydantic.fields import SHAPE_LIST
from pydantic.class_validators import make_generic_validator
from pydantic.fields import ModelField, SHAPE_LIST
from pydantic.main import SchemaExtraCallable
import ormar # noqa: I100, I202
from ormar.fields import BaseField
from ormar.models.helpers.models import meta_field_not_set
from ormar.queryset.utils import translate_list_to_dict
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.fields import BaseField
def check_if_field_has_choices(field: BaseField) -> bool:
def check_if_field_has_choices(field: "BaseField") -> bool:
"""
Checks if given field has choices populated.
A if it has one, a validator for this field needs to be attached.
@ -37,74 +44,34 @@ def check_if_field_has_choices(field: BaseField) -> bool:
return hasattr(field, "choices") and bool(field.choices)
def convert_choices_if_needed( # noqa: CCR001
field: "BaseField", value: Any
) -> Tuple[Any, List]:
def convert_value_if_needed(field: "BaseField", value: Any) -> Any:
"""
Converts dates to isoformat as fastapi can check this condition in routes
and the fields are not yet parsed.
Converts enums to list of it's values.
Converts uuids to strings.
Converts decimal to float with given scale.
:param field: ormar field to check with choices
:type field: BaseField
:param value: current values of the model to verify
:type value: Dict
:type value: Any
:return: value, choices list
:rtype: Tuple[Any, List]
:rtype: Any
"""
# TODO use same maps as with EncryptedString
choices = [o.value if isinstance(o, Enum) else o for o in field.choices]
if field.__type__ in [datetime.datetime, datetime.date, datetime.time]:
value = value.isoformat() if not isinstance(value, str) else value
choices = [o.isoformat() for o in field.choices]
elif field.__type__ == pydantic.Json:
value = (
json.dumps(value) if not isinstance(value, str) else re_dump_value(value)
)
value = value.decode("utf-8") if isinstance(value, bytes) else value
choices = [re_dump_value(x) for x in field.choices]
elif field.__type__ == uuid.UUID:
value = str(value) if not isinstance(value, str) else value
choices = [str(o) for o in field.choices]
elif field.__type__ == decimal.Decimal:
encoder = ormar.ENCODERS_MAP.get(field.__type__, lambda x: x)
if field.__type__ == decimal.Decimal:
precision = field.scale # type: ignore
value = (
round(float(value), precision)
if isinstance(value, decimal.Decimal)
else value
)
choices = [round(float(o), precision) for o in choices]
elif field.__type__ == bytes:
if field.represent_as_base64_str:
value = value if isinstance(value, bytes) else base64.b64decode(value)
else:
value = value if isinstance(value, bytes) else value.encode("utf-8")
return value, choices
value = encoder(value, precision)
elif encoder:
value = encoder(value)
return value
def re_dump_value(value: str) -> str:
"""
Rw-dumps choices due to different string representation in orjson and json
:param value: string to re-dump
:type value: str
:return: re-dumped choices
:rtype: List[str]
"""
try:
result: Union[str, bytes] = json.dumps(json.loads(value))
except json.JSONDecodeError:
result = value
return result.decode("utf-8") if isinstance(result, bytes) else result
def generate_validator(ormar_field: "BaseField") -> Callable:
choices = ormar_field.choices
def validate_choices(field: "BaseField", value: Any) -> None:
def validate_choices(cls: type, value: Any, field: "ModelField") -> None:
"""
Validates if given value is in provided choices.
@ -114,33 +81,16 @@ def validate_choices(field: "BaseField", value: Any) -> None:
:param value: value of the field
:type value: Any
"""
value, choices = convert_choices_if_needed(field=field, value=value)
if field.nullable:
choices.append(None)
if value is not ormar.Undefined and value not in choices:
adjusted_value = convert_value_if_needed(field=ormar_field, value=value)
if adjusted_value is not ormar.Undefined and adjusted_value not in choices:
raise ValueError(
f"{field.name}: '{value}' " f"not in allowed choices set:" f" {choices}"
f"{field.name}: '{adjusted_value}' "
f"not in allowed choices set:"
f" {choices}"
)
return value
def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validator that is attached to pydantic model pre root validators.
Validator checks if field value is in field.choices list.
:raises ValueError: if field value is outside of allowed choices.
:param cls: constructed class
:type cls: Model class
:param values: dictionary of field values (pydantic side)
:type values: Dict[str, Any]
:return: values if pass validation, otherwise exception is raised
:rtype: Dict[str, Any]
"""
for field_name, field in cls.Meta.model_fields.items():
if check_if_field_has_choices(field):
value = values.get(field_name, ormar.Undefined)
validate_choices(field=field, value=value)
return values
return validate_choices
def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> Dict:
@ -172,7 +122,7 @@ def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> D
def populates_sample_fields_values(
example: Dict[str, Any], name: str, field: BaseField, relation_map: Dict = None
example: Dict[str, Any], name: str, field: "BaseField", relation_map: Dict = None
) -> None:
"""
Iterates the field and sets fields to sample values
@ -350,15 +300,14 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001
"""
fields_with_choices = []
if not meta_field_not_set(model=model, field_name="model_fields"):
if hasattr(model, "_choices_fields"):
return
model._choices_fields = set()
for name, field in model.Meta.model_fields.items():
if check_if_field_has_choices(field):
fields_with_choices.append(name)
validators = getattr(model, "__pre_root_validators__", [])
if choices_validator not in validators:
validators.append(choices_validator)
model.__pre_root_validators__ = validators
if not model._choices_fields:
model._choices_fields = set()
validator = make_generic_validator(generate_validator(field))
model.__fields__[name].validators.append(validator)
model._choices_fields.add(name)
if fields_with_choices:

View File

@ -106,7 +106,6 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
new_model._through_names = None
new_model._related_fields = None
new_model._pydantic_fields = {name for name in new_model.__fields__}
new_model._choices_fields = set()
new_model._json_fields = set()
new_model._bytes_fields = set()

View File

@ -11,9 +11,10 @@ from typing import (
cast,
)
import ormar
import pydantic
import ormar # noqa: I100, I202
from ormar.exceptions import ModelPersistenceError
from ormar.models.helpers.validation import validate_choices
from ormar.models.mixins import AliasMixin
from ormar.models.mixins.relation_mixin import RelationMixin
@ -29,6 +30,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if TYPE_CHECKING: # pragma: nocover
_choices_fields: Optional[Set]
_skip_ellipsis: Callable
__fields__: Dict[str, pydantic.fields.ModelField]
@classmethod
def prepare_model_to_save(cls, new_kwargs: dict) -> dict:
@ -180,9 +182,18 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if not cls._choices_fields:
return new_kwargs
for field_name, field in cls.Meta.model_fields.items():
if field_name in new_kwargs and field_name in cls._choices_fields:
validate_choices(field=field, value=new_kwargs.get(field_name))
fields_to_check = [
field
for field in cls.Meta.model_fields.values()
if field.name in cls._choices_fields and field.name in new_kwargs
]
for field in fields_to_check:
if new_kwargs[field.name] not in field.choices:
raise ValueError(
f"{field.name}: '{new_kwargs[field.name]}' "
f"not in allowed choices set:"
f" {field.choices}"
)
return new_kwargs
@staticmethod

View File

@ -90,7 +90,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
_related_names: Optional[Set]
_through_names: Optional[Set]
_related_names_hash: str
_choices_fields: Optional[Set]
_choices_fields: Set
_pydantic_fields: Set
_quick_access_fields: Set
_json_fields: Set
@ -928,6 +928,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
:return: dictionary of fields names and values.
:rtype: Dict
"""
# TODO: Cache this dictionary?
self_fields = self._extract_own_model_fields()
self_fields = {
k: v

View File

@ -29,8 +29,8 @@ def check_node_not_dict_or_not_last_node(
:param part:
:type part: str
:param parts:
:type parts: List[str]
:param is_last: flag to check if last element
:type is_last: bool
:param current_level: current level of the traversed structure
:type current_level: Any
:return: result of the check
@ -52,7 +52,7 @@ def translate_list_to_dict( # noqa: CCR001
Default required key ise Ellipsis like in pydantic.
:param list_to_trans: input list
:type list_to_trans: set
:type list_to_trans: Union[List, Set]
:param is_order: flag if change affects order_by clauses are they require special
default value with sort order.
:type is_order: bool

View File

@ -121,7 +121,8 @@ def test_all_endpoints():
"blob_col": blob.decode("utf-8"),
},
)
if response.status_code != 200:
print(response.text)
assert response.status_code == 200
item = Organisation(**response.json())
assert item.pk is not None