refactor choices validation from root validator to field validator

This commit is contained in:
collerek
2021-10-10 14:11:25 +02:00
parent d992f3dc3b
commit d8f0dc92f0
11 changed files with 168 additions and 124 deletions

View File

@ -1,30 +1,37 @@
import base64
import datetime
import decimal
import numbers
import uuid
from enum import Enum
from typing import Any, Dict, List, Set, TYPE_CHECKING, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Set,
TYPE_CHECKING,
Type,
Union,
)
try:
import orjson as json
except ImportError: # pragma: no cover
import json # type: ignore
import json # type: ignore # noqa: F401
import pydantic
from pydantic.fields import SHAPE_LIST
from pydantic.class_validators import make_generic_validator
from pydantic.fields import ModelField, SHAPE_LIST
from pydantic.main import SchemaExtraCallable
import ormar # noqa: I100, I202
from ormar.fields import BaseField
from ormar.models.helpers.models import meta_field_not_set
from ormar.queryset.utils import translate_list_to_dict
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.fields import BaseField
def check_if_field_has_choices(field: BaseField) -> bool:
def check_if_field_has_choices(field: "BaseField") -> bool:
"""
Checks if given field has choices populated.
A if it has one, a validator for this field needs to be attached.
@ -37,110 +44,53 @@ def check_if_field_has_choices(field: BaseField) -> bool:
return hasattr(field, "choices") and bool(field.choices)
def convert_choices_if_needed( # noqa: CCR001
field: "BaseField", value: Any
) -> Tuple[Any, List]:
def convert_value_if_needed(field: "BaseField", value: Any) -> Any:
"""
Converts dates to isoformat as fastapi can check this condition in routes
and the fields are not yet parsed.
Converts enums to list of it's values.
Converts uuids to strings.
Converts decimal to float with given scale.
:param field: ormar field to check with choices
:type field: BaseField
:param value: current values of the model to verify
:type value: Dict
:return: value, choices list
:rtype: Tuple[Any, List]
"""
# TODO use same maps as with EncryptedString
choices = [o.value if isinstance(o, Enum) else o for o in field.choices]
if field.__type__ in [datetime.datetime, datetime.date, datetime.time]:
value = value.isoformat() if not isinstance(value, str) else value
choices = [o.isoformat() for o in field.choices]
elif field.__type__ == pydantic.Json:
value = (
json.dumps(value) if not isinstance(value, str) else re_dump_value(value)
)
value = value.decode("utf-8") if isinstance(value, bytes) else value
choices = [re_dump_value(x) for x in field.choices]
elif field.__type__ == uuid.UUID:
value = str(value) if not isinstance(value, str) else value
choices = [str(o) for o in field.choices]
elif field.__type__ == decimal.Decimal:
precision = field.scale # type: ignore
value = (
round(float(value), precision)
if isinstance(value, decimal.Decimal)
else value
)
choices = [round(float(o), precision) for o in choices]
elif field.__type__ == bytes:
if field.represent_as_base64_str:
value = value if isinstance(value, bytes) else base64.b64decode(value)
else:
value = value if isinstance(value, bytes) else value.encode("utf-8")
return value, choices
def re_dump_value(value: str) -> str:
"""
Rw-dumps choices due to different string representation in orjson and json
:param value: string to re-dump
:type value: str
:return: re-dumped choices
:rtype: List[str]
"""
try:
result: Union[str, bytes] = json.dumps(json.loads(value))
except json.JSONDecodeError:
result = value
return result.decode("utf-8") if isinstance(result, bytes) else result
def validate_choices(field: "BaseField", value: Any) -> None:
"""
Validates if given value is in provided choices.
:raises ValueError: If value is not in choices.
:param field:field to validate
:type field: BaseField
:param value: value of the field
:type value: Any
:return: value, choices list
:rtype: Any
"""
value, choices = convert_choices_if_needed(field=field, value=value)
if field.nullable:
choices.append(None)
if value is not ormar.Undefined and value not in choices:
raise ValueError(
f"{field.name}: '{value}' " f"not in allowed choices set:" f" {choices}"
)
encoder = ormar.ENCODERS_MAP.get(field.__type__, lambda x: x)
if field.__type__ == decimal.Decimal:
precision = field.scale # type: ignore
value = encoder(value, precision)
elif encoder:
value = encoder(value)
return value
def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validator that is attached to pydantic model pre root validators.
Validator checks if field value is in field.choices list.
def generate_validator(ormar_field: "BaseField") -> Callable:
choices = ormar_field.choices
:raises ValueError: if field value is outside of allowed choices.
:param cls: constructed class
:type cls: Model class
:param values: dictionary of field values (pydantic side)
:type values: Dict[str, Any]
:return: values if pass validation, otherwise exception is raised
:rtype: Dict[str, Any]
"""
for field_name, field in cls.Meta.model_fields.items():
if check_if_field_has_choices(field):
value = values.get(field_name, ormar.Undefined)
validate_choices(field=field, value=value)
return values
def validate_choices(cls: type, value: Any, field: "ModelField") -> None:
"""
Validates if given value is in provided choices.
:raises ValueError: If value is not in choices.
:param field:field to validate
:type field: BaseField
:param value: value of the field
:type value: Any
"""
adjusted_value = convert_value_if_needed(field=ormar_field, value=value)
if adjusted_value is not ormar.Undefined and adjusted_value not in choices:
raise ValueError(
f"{field.name}: '{adjusted_value}' "
f"not in allowed choices set:"
f" {choices}"
)
return value
return validate_choices
def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> Dict:
@ -172,7 +122,7 @@ def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> D
def populates_sample_fields_values(
example: Dict[str, Any], name: str, field: BaseField, relation_map: Dict = None
example: Dict[str, Any], name: str, field: "BaseField", relation_map: Dict = None
) -> None:
"""
Iterates the field and sets fields to sample values
@ -350,15 +300,14 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001
"""
fields_with_choices = []
if not meta_field_not_set(model=model, field_name="model_fields"):
if hasattr(model, "_choices_fields"):
return
model._choices_fields = set()
for name, field in model.Meta.model_fields.items():
if check_if_field_has_choices(field):
fields_with_choices.append(name)
validators = getattr(model, "__pre_root_validators__", [])
if choices_validator not in validators:
validators.append(choices_validator)
model.__pre_root_validators__ = validators
if not model._choices_fields:
model._choices_fields = set()
validator = make_generic_validator(generate_validator(field))
model.__fields__[name].validators.append(validator)
model._choices_fields.add(name)
if fields_with_choices:

View File

@ -106,7 +106,6 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
new_model._through_names = None
new_model._related_fields = None
new_model._pydantic_fields = {name for name in new_model.__fields__}
new_model._choices_fields = set()
new_model._json_fields = set()
new_model._bytes_fields = set()

View File

@ -11,9 +11,10 @@ from typing import (
cast,
)
import ormar
import pydantic
import ormar # noqa: I100, I202
from ormar.exceptions import ModelPersistenceError
from ormar.models.helpers.validation import validate_choices
from ormar.models.mixins import AliasMixin
from ormar.models.mixins.relation_mixin import RelationMixin
@ -29,6 +30,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if TYPE_CHECKING: # pragma: nocover
_choices_fields: Optional[Set]
_skip_ellipsis: Callable
__fields__: Dict[str, pydantic.fields.ModelField]
@classmethod
def prepare_model_to_save(cls, new_kwargs: dict) -> dict:
@ -180,9 +182,18 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if not cls._choices_fields:
return new_kwargs
for field_name, field in cls.Meta.model_fields.items():
if field_name in new_kwargs and field_name in cls._choices_fields:
validate_choices(field=field, value=new_kwargs.get(field_name))
fields_to_check = [
field
for field in cls.Meta.model_fields.values()
if field.name in cls._choices_fields and field.name in new_kwargs
]
for field in fields_to_check:
if new_kwargs[field.name] not in field.choices:
raise ValueError(
f"{field.name}: '{new_kwargs[field.name]}' "
f"not in allowed choices set:"
f" {field.choices}"
)
return new_kwargs
@staticmethod

View File

@ -90,7 +90,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
_related_names: Optional[Set]
_through_names: Optional[Set]
_related_names_hash: str
_choices_fields: Optional[Set]
_choices_fields: Set
_pydantic_fields: Set
_quick_access_fields: Set
_json_fields: Set
@ -928,6 +928,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
:return: dictionary of fields names and values.
:rtype: Dict
"""
# TODO: Cache this dictionary?
self_fields = self._extract_own_model_fields()
self_fields = {
k: v