diff --git a/ormar/__init__.py b/ormar/__init__.py index f444fb6..21ab87e 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -61,6 +61,7 @@ from ormar.fields import ( LargeBinary, ManyToMany, ManyToManyField, + SQL_ENCODERS_MAP, SmallInteger, String, Text, @@ -132,6 +133,7 @@ __all__ = [ "or_", "EncryptBackends", "ENCODERS_MAP", + "SQL_ENCODERS_MAP", "DECODERS_MAP", "LargeBinary", "Extra", diff --git a/ormar/fields/__init__.py b/ormar/fields/__init__.py index 9ea9387..e90b5df 100644 --- a/ormar/fields/__init__.py +++ b/ormar/fields/__init__.py @@ -24,7 +24,7 @@ from ormar.fields.model_fields import ( Time, UUID, ) -from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP +from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP, SQL_ENCODERS_MAP from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends from ormar.fields.through_field import Through, ThroughField @@ -54,6 +54,7 @@ __all__ = [ "EncryptBackend", "DECODERS_MAP", "ENCODERS_MAP", + "SQL_ENCODERS_MAP", "LargeBinary", "UniqueColumns", ] diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 6003620..a404204 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,11 +1,13 @@ import datetime import decimal import uuid -from typing import Any, Optional, TYPE_CHECKING, Union, overload +from enum import Enum +from typing import Any, Optional, Set, TYPE_CHECKING, Type, Union, overload import pydantic import sqlalchemy +import ormar # noqa I101 from ormar import ModelDefinitionError # noqa I101 from ormar.fields import sqlalchemy_uuid from ormar.fields.base import BaseField # noqa I101 @@ -60,6 +62,39 @@ def is_auto_primary_key(primary_key: bool, autoincrement: bool) -> bool: return primary_key and autoincrement +def convert_choices_if_needed( + field_type: "Type", choices: Set, nullable: bool, scale: int = None +) -> Set: + """ + Converts dates to isoformat as fastapi can check this condition in routes + and the fields are not yet parsed. + Converts enums to list of it's values. + Converts uuids to strings. + Converts decimal to float with given scale. + + :param field_type: type o the field + :type field_type: Type + :param nullable: set of choices + :type nullable: Set + :param scale: scale for decimals + :type scale: int + :param scale: scale for decimals + :type scale: int + :return: value, choices list + :rtype: Tuple[Any, Set] + """ + choices = {o.value if isinstance(o, Enum) else o for o in choices} + encoder = ormar.ENCODERS_MAP.get(field_type, lambda x: x) + if field_type == decimal.Decimal: + precision = scale + choices = {encoder(o, precision) for o in choices} + elif encoder: + choices = {encoder(o) for o in choices} + if nullable: + choices.add(None) + return choices + + class ModelFieldFactory: """ Default field factory that construct Field classes and populated their values. @@ -96,6 +131,15 @@ class ModelFieldFactory: else (nullable if sql_nullable is None else sql_nullable) ) + choices = set(kwargs.pop("choices", [])) + if choices: + choices = convert_choices_if_needed( + field_type=cls._type, + choices=choices, + nullable=nullable, + scale=kwargs.get("scale", None), + ) + namespace = dict( __type__=cls._type, __pydantic_type__=overwrite_pydantic_type @@ -114,7 +158,7 @@ class ModelFieldFactory: pydantic_only=pydantic_only, autoincrement=autoincrement, column_type=cls.get_column_type(**kwargs), - choices=set(kwargs.pop("choices", [])), + choices=choices, encrypt_secret=encrypt_secret, encrypt_backend=encrypt_backend, encrypt_custom_backend=encrypt_custom_backend, diff --git a/ormar/fields/parsers.py b/ormar/fields/parsers.py index e0f1a53..e8b7301 100644 --- a/ormar/fields/parsers.py +++ b/ormar/fields/parsers.py @@ -1,6 +1,8 @@ +import base64 import datetime import decimal -from typing import Any +import uuid +from typing import Any, Callable, Dict, Union import pydantic from pydantic.datetime_parse import parse_date, parse_datetime, parse_time @@ -19,21 +21,55 @@ def encode_bool(value: bool) -> str: return "true" if value else "false" +def encode_decimal(value: decimal.Decimal, precision: int = None) -> float: + if precision: + return ( + round(float(value), precision) + if isinstance(value, decimal.Decimal) + else value + ) + return float(value) + + +def encode_bytes(value: Union[str, bytes], represent_as_string: bool = False) -> bytes: + if represent_as_string: + return value if isinstance(value, bytes) else base64.b64decode(value) + return value if isinstance(value, bytes) else value.encode("utf-8") + + def encode_json(value: Any) -> str: - value = json.dumps(value) if not isinstance(value, str) else value + value = json.dumps(value) if not isinstance(value, str) else re_dump_value(value) value = value.decode("utf-8") if isinstance(value, bytes) else value return value -ENCODERS_MAP = { - bool: encode_bool, +def re_dump_value(value: str) -> Union[str, bytes]: + """ + Rw-dumps choices due to different string representation in orjson and json + :param value: string to re-dump + :type value: str + :return: re-dumped choices + :rtype: List[str] + """ + try: + result: Union[str, bytes] = json.dumps(json.loads(value)) + except json.JSONDecodeError: + result = value + return result + + +ENCODERS_MAP: Dict[type, Callable] = { datetime.datetime: lambda x: x.isoformat(), datetime.date: lambda x: x.isoformat(), datetime.time: lambda x: x.isoformat(), pydantic.Json: encode_json, - decimal.Decimal: float, + decimal.Decimal: encode_decimal, + uuid.UUID: str, + bytes: encode_bytes, } +SQL_ENCODERS_MAP: Dict[type, Callable] = {bool: encode_bool, **ENCODERS_MAP} + DECODERS_MAP = { bool: parse_bool, datetime.datetime: parse_datetime, diff --git a/ormar/fields/sqlalchemy_encrypted.py b/ormar/fields/sqlalchemy_encrypted.py index 97769ba..88dc0af 100644 --- a/ormar/fields/sqlalchemy_encrypted.py +++ b/ormar/fields/sqlalchemy_encrypted.py @@ -160,7 +160,7 @@ class EncryptedString(types.TypeDecorator): try: value = self._underlying_type.process_bind_param(value, dialect) except AttributeError: - encoder = ormar.ENCODERS_MAP.get(self.type_, None) + encoder = ormar.SQL_ENCODERS_MAP.get(self.type_, None) if encoder: value = encoder(value) # type: ignore diff --git a/ormar/models/helpers/validation.py b/ormar/models/helpers/validation.py index 11db796..0ed1635 100644 --- a/ormar/models/helpers/validation.py +++ b/ormar/models/helpers/validation.py @@ -1,30 +1,37 @@ import base64 -import datetime import decimal import numbers -import uuid -from enum import Enum -from typing import Any, Dict, List, Set, TYPE_CHECKING, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + Set, + TYPE_CHECKING, + Type, + Union, +) try: import orjson as json except ImportError: # pragma: no cover - import json # type: ignore + import json # type: ignore # noqa: F401 import pydantic -from pydantic.fields import SHAPE_LIST +from pydantic.class_validators import make_generic_validator +from pydantic.fields import ModelField, SHAPE_LIST from pydantic.main import SchemaExtraCallable import ormar # noqa: I100, I202 -from ormar.fields import BaseField from ormar.models.helpers.models import meta_field_not_set from ormar.queryset.utils import translate_list_to_dict if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.fields import BaseField -def check_if_field_has_choices(field: BaseField) -> bool: +def check_if_field_has_choices(field: "BaseField") -> bool: """ Checks if given field has choices populated. A if it has one, a validator for this field needs to be attached. @@ -37,110 +44,53 @@ def check_if_field_has_choices(field: BaseField) -> bool: return hasattr(field, "choices") and bool(field.choices) -def convert_choices_if_needed( # noqa: CCR001 - field: "BaseField", value: Any -) -> Tuple[Any, List]: +def convert_value_if_needed(field: "BaseField", value: Any) -> Any: """ Converts dates to isoformat as fastapi can check this condition in routes and the fields are not yet parsed. - Converts enums to list of it's values. - Converts uuids to strings. - Converts decimal to float with given scale. :param field: ormar field to check with choices :type field: BaseField :param value: current values of the model to verify - :type value: Dict - :return: value, choices list - :rtype: Tuple[Any, List] - """ - # TODO use same maps as with EncryptedString - choices = [o.value if isinstance(o, Enum) else o for o in field.choices] - - if field.__type__ in [datetime.datetime, datetime.date, datetime.time]: - value = value.isoformat() if not isinstance(value, str) else value - choices = [o.isoformat() for o in field.choices] - elif field.__type__ == pydantic.Json: - value = ( - json.dumps(value) if not isinstance(value, str) else re_dump_value(value) - ) - value = value.decode("utf-8") if isinstance(value, bytes) else value - choices = [re_dump_value(x) for x in field.choices] - elif field.__type__ == uuid.UUID: - value = str(value) if not isinstance(value, str) else value - choices = [str(o) for o in field.choices] - elif field.__type__ == decimal.Decimal: - precision = field.scale # type: ignore - value = ( - round(float(value), precision) - if isinstance(value, decimal.Decimal) - else value - ) - choices = [round(float(o), precision) for o in choices] - elif field.__type__ == bytes: - if field.represent_as_base64_str: - value = value if isinstance(value, bytes) else base64.b64decode(value) - else: - value = value if isinstance(value, bytes) else value.encode("utf-8") - - return value, choices - - -def re_dump_value(value: str) -> str: - """ - Rw-dumps choices due to different string representation in orjson and json - :param value: string to re-dump - :type value: str - :return: re-dumped choices - :rtype: List[str] - """ - try: - result: Union[str, bytes] = json.dumps(json.loads(value)) - except json.JSONDecodeError: - result = value - return result.decode("utf-8") if isinstance(result, bytes) else result - - -def validate_choices(field: "BaseField", value: Any) -> None: - """ - Validates if given value is in provided choices. - - :raises ValueError: If value is not in choices. - :param field:field to validate - :type field: BaseField - :param value: value of the field :type value: Any + :return: value, choices list + :rtype: Any """ - value, choices = convert_choices_if_needed(field=field, value=value) - if field.nullable: - choices.append(None) - if value is not ormar.Undefined and value not in choices: - raise ValueError( - f"{field.name}: '{value}' " f"not in allowed choices set:" f" {choices}" - ) + encoder = ormar.ENCODERS_MAP.get(field.__type__, lambda x: x) + if field.__type__ == decimal.Decimal: + precision = field.scale # type: ignore + value = encoder(value, precision) + elif encoder: + value = encoder(value) + return value -def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, Any]: - """ - Validator that is attached to pydantic model pre root validators. - Validator checks if field value is in field.choices list. +def generate_validator(ormar_field: "BaseField") -> Callable: + choices = ormar_field.choices - :raises ValueError: if field value is outside of allowed choices. - :param cls: constructed class - :type cls: Model class - :param values: dictionary of field values (pydantic side) - :type values: Dict[str, Any] - :return: values if pass validation, otherwise exception is raised - :rtype: Dict[str, Any] - """ - for field_name, field in cls.Meta.model_fields.items(): - if check_if_field_has_choices(field): - value = values.get(field_name, ormar.Undefined) - validate_choices(field=field, value=value) - return values + def validate_choices(cls: type, value: Any, field: "ModelField") -> None: + """ + Validates if given value is in provided choices. + + :raises ValueError: If value is not in choices. + :param field:field to validate + :type field: BaseField + :param value: value of the field + :type value: Any + """ + adjusted_value = convert_value_if_needed(field=ormar_field, value=value) + if adjusted_value is not ormar.Undefined and adjusted_value not in choices: + raise ValueError( + f"{field.name}: '{adjusted_value}' " + f"not in allowed choices set:" + f" {choices}" + ) + return value + + return validate_choices def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> Dict: @@ -172,7 +122,7 @@ def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> D def populates_sample_fields_values( - example: Dict[str, Any], name: str, field: BaseField, relation_map: Dict = None + example: Dict[str, Any], name: str, field: "BaseField", relation_map: Dict = None ) -> None: """ Iterates the field and sets fields to sample values @@ -350,15 +300,14 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001 """ fields_with_choices = [] if not meta_field_not_set(model=model, field_name="model_fields"): + if hasattr(model, "_choices_fields"): + return + model._choices_fields = set() for name, field in model.Meta.model_fields.items(): if check_if_field_has_choices(field): fields_with_choices.append(name) - validators = getattr(model, "__pre_root_validators__", []) - if choices_validator not in validators: - validators.append(choices_validator) - model.__pre_root_validators__ = validators - if not model._choices_fields: - model._choices_fields = set() + validator = make_generic_validator(generate_validator(field)) + model.__fields__[name].validators.append(validator) model._choices_fields.add(name) if fields_with_choices: diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 49043cc..2bee725 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -106,7 +106,6 @@ def add_cached_properties(new_model: Type["Model"]) -> None: new_model._through_names = None new_model._related_fields = None new_model._pydantic_fields = {name for name in new_model.__fields__} - new_model._choices_fields = set() new_model._json_fields = set() new_model._bytes_fields = set() diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index d1769f9..de3923e 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -11,9 +11,10 @@ from typing import ( cast, ) -import ormar +import pydantic + +import ormar # noqa: I100, I202 from ormar.exceptions import ModelPersistenceError -from ormar.models.helpers.validation import validate_choices from ormar.models.mixins import AliasMixin from ormar.models.mixins.relation_mixin import RelationMixin @@ -29,6 +30,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): if TYPE_CHECKING: # pragma: nocover _choices_fields: Optional[Set] _skip_ellipsis: Callable + __fields__: Dict[str, pydantic.fields.ModelField] @classmethod def prepare_model_to_save(cls, new_kwargs: dict) -> dict: @@ -180,9 +182,18 @@ class SavePrepareMixin(RelationMixin, AliasMixin): if not cls._choices_fields: return new_kwargs - for field_name, field in cls.Meta.model_fields.items(): - if field_name in new_kwargs and field_name in cls._choices_fields: - validate_choices(field=field, value=new_kwargs.get(field_name)) + fields_to_check = [ + field + for field in cls.Meta.model_fields.values() + if field.name in cls._choices_fields and field.name in new_kwargs + ] + for field in fields_to_check: + if new_kwargs[field.name] not in field.choices: + raise ValueError( + f"{field.name}: '{new_kwargs[field.name]}' " + f"not in allowed choices set:" + f" {field.choices}" + ) return new_kwargs @staticmethod diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 19f6634..1fbf109 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -90,7 +90,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass _related_names: Optional[Set] _through_names: Optional[Set] _related_names_hash: str - _choices_fields: Optional[Set] + _choices_fields: Set _pydantic_fields: Set _quick_access_fields: Set _json_fields: Set @@ -928,6 +928,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass :return: dictionary of fields names and values. :rtype: Dict """ + # TODO: Cache this dictionary? self_fields = self._extract_own_model_fields() self_fields = { k: v diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index f99c698..4ebe07c 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -29,8 +29,8 @@ def check_node_not_dict_or_not_last_node( :param part: :type part: str - :param parts: - :type parts: List[str] + :param is_last: flag to check if last element + :type is_last: bool :param current_level: current level of the traversed structure :type current_level: Any :return: result of the check @@ -52,7 +52,7 @@ def translate_list_to_dict( # noqa: CCR001 Default required key ise Ellipsis like in pydantic. :param list_to_trans: input list - :type list_to_trans: set + :type list_to_trans: Union[List, Set] :param is_order: flag if change affects order_by clauses are they require special default value with sort order. :type is_order: bool diff --git a/tests/test_fastapi/test_choices_schema.py b/tests/test_fastapi/test_choices_schema.py index 86336c5..da8a167 100644 --- a/tests/test_fastapi/test_choices_schema.py +++ b/tests/test_fastapi/test_choices_schema.py @@ -121,7 +121,8 @@ def test_all_endpoints(): "blob_col": blob.decode("utf-8"), }, ) - + if response.status_code != 200: + print(response.text) assert response.status_code == 200 item = Organisation(**response.json()) assert item.pk is not None