From 3e615a80571574097f5f3b96a5c0fe33541b2bf0 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 2 Dec 2020 19:15:55 +0100 Subject: [PATCH] work in progres pydantic_only and properties --- ormar/fields/model_fields.py | 24 +++- ormar/models/metaclass.py | 136 +++++++++++++++------- ormar/models/modelproxy.py | 70 +++++------ ormar/models/newbasemodel.py | 94 +++++++-------- tests/test_excluding_fields_in_fastapi.py | 38 ++++++ tests/test_pydantic_only_fields.py | 75 ++++++++++++ 6 files changed, 299 insertions(+), 138 deletions(-) create mode 100644 tests/test_pydantic_only_fields.py diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index b92da65..f90ed87 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -12,13 +12,20 @@ from ormar.fields.base import BaseField # noqa I101 def is_field_nullable( - nullable: Optional[bool], default: Any, server_default: Any + nullable: Optional[bool], + default: Any, + server_default: Any, + pydantic_only: Optional[bool], ) -> bool: if nullable is None: - return default is not None or server_default is not None + return default is not None or server_default is not None or pydantic_only return nullable +def is_auto_primary_key(primary_key: bool, autoincrement: bool): + return primary_key and autoincrement + + class ModelFieldFactory: _bases: Any = (BaseField,) _type: Any = None @@ -29,19 +36,24 @@ class ModelFieldFactory: default = kwargs.pop("default", None) server_default = kwargs.pop("server_default", None) nullable = kwargs.pop("nullable", None) + pydantic_only = kwargs.pop("pydantic_only", False) + + primary_key = kwargs.pop("primary_key", False) + autoincrement = kwargs.pop("autoincrement", False) namespace = dict( __type__=cls._type, alias=kwargs.pop("name", None), name=None, - primary_key=kwargs.pop("primary_key", False), + primary_key=primary_key, default=default, server_default=server_default, - nullable=is_field_nullable(nullable, default, server_default), + nullable=is_field_nullable(nullable, default, server_default, pydantic_only) + or is_auto_primary_key(primary_key, autoincrement), index=kwargs.pop("index", False), unique=kwargs.pop("unique", False), - pydantic_only=kwargs.pop("pydantic_only", False), - autoincrement=kwargs.pop("autoincrement", False), + pydantic_only=pydantic_only, + autoincrement=autoincrement, column_type=cls.get_column_type(**kwargs), choices=set(kwargs.pop("choices", [])), **kwargs diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 80265bd..2280c11 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,6 +1,7 @@ +import inspect import logging import warnings -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union import databases import pydantic @@ -15,6 +16,7 @@ from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100 from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.many_to_many import ManyToMany, ManyToManyField +from ormar.fields.model_fields import ModelFieldFactory from ormar.queryset import QuerySet from ormar.relations.alias_manager import AliasManager @@ -36,6 +38,8 @@ class ModelMeta: str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] ] alias_manager: AliasManager + include_props_in_dict: bool + include_props_in_fields: bool def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None: @@ -43,7 +47,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] + table_name: str, field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( @@ -52,11 +56,11 @@ def register_many_to_many_relation_on_build( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -67,7 +71,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -75,10 +79,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -93,7 +97,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, real_name=model.get_name(), ondelete="CASCADE" @@ -110,7 +114,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -121,8 +125,18 @@ def create_pydantic_field( ) +def get_pydantic_field(field_name: str, model: Type["Model"]) -> "ModelField": + return ModelField( + name=field_name, + type_=model.Meta.model_fields[field_name].__type__, + model_config=model.__config__, + required=not model.Meta.model_fields[field_name].nullable, + class_validators={}, + ) + + def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), @@ -138,7 +152,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -148,7 +162,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -162,9 +176,9 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field.get_alias())) register_relation_in_alias_manager(table_name, field) @@ -172,7 +186,7 @@ def sqlalchemy_columns_from_model_fields( def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] + table_name: str, field: Type[ForeignKeyField] ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build(table_name, field) @@ -181,7 +195,7 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - ormar_field: Type[BaseField], field_name: str, attrs: dict + ormar_field: Type[BaseField], field_name: str, attrs: dict ) -> dict: curr_def_value = attrs.get(field_name, ormar.Undefined) if lenient_issubclass(curr_def_value, ormar.fields.BaseField): @@ -228,7 +242,7 @@ def extract_annotations_and_default_vals(attrs: dict) -> Tuple[Dict, Dict]: def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = ( @@ -254,7 +268,7 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: if not hasattr(new_model.Meta, "table"): new_model.Meta.table = sqlalchemy.Table( @@ -295,21 +309,66 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A return values -def populate_choices_validators( # noqa CCR001 - model: Type["Model"], attrs: Dict -) -> None: +def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001 if model_initialized_and_has_model_fields(model): for _, field in model.Meta.model_fields.items(): if check_if_field_has_choices(field): - validators = attrs.get("__pre_root_validators__", []) + validators = getattr(model, "__pre_root_validators__", []) if choices_validator not in validators: validators.append(choices_validator) - attrs["__pre_root_validators__"] = validators + setattr(model, "__pre_root_validators__", validators) + + +def populate_default_options_values(new_model: Type["Model"], model_fields: Dict): + if not hasattr(new_model.Meta, "constraints"): + new_model.Meta.constraints = [] + if not hasattr(new_model.Meta, "model_fields"): + new_model.Meta.model_fields = model_fields + + if not hasattr(new_model.Meta, "include_props_in_dict"): + new_model.Meta.include_props_in_dict = True + if not hasattr(new_model.Meta, "include_props_in_fields"): + new_model.Meta.include_props_in_fields = False + + +def add_cached_properties(new_model): + new_model._props = { + prop + for prop in vars(new_model) + if isinstance(getattr(new_model, prop), property) + and prop + not in ("__values__", "__fields__", "fields", "pk_column", "saved") + } + new_model._quick_access_fields = { + "_orm_id", + "_orm_saved", + "_orm", + "__fields__", + "_related_names", + "_props", + "__class__", + "extract_related_names", + } + new_model._related_names = None + + +def add_property_fields(new_model): + if new_model.Meta.include_props_in_fields: + for prop in new_model._props: + field_type = getattr(new_model, prop).fget.__annotations__.get('return') + new_model.Meta.model_fields[prop] = ModelFieldFactory(nullable=True, pydantic_only=True) + new_model.__fields__[prop] = ModelField( + name=prop, + type_=Optional[field_type] if field_type else Any, + model_config=new_model.__config__, + required=False, + class_validators={}, + ) class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name @@ -319,28 +378,21 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): ) if hasattr(new_model, "Meta"): - if not hasattr(new_model.Meta, "constraints"): - new_model.Meta.constraints = [] - if not hasattr(new_model.Meta, "model_fields"): - new_model.Meta.model_fields = model_fields + populate_default_options_values(new_model, model_fields) new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model) expand_reverse_relationships(new_model) - populate_choices_validators(new_model, attrs) - + populate_choices_validators(new_model) if new_model.Meta.pkname not in attrs["__annotations__"]: field_name = new_model.Meta.pkname - field = Integer(name=field_name, primary_key=True) attrs["__annotations__"][field_name] = Optional[int] # type: ignore - populate_default_pydantic_field_value( - field, field_name, attrs # type: ignore + attrs[field_name] = None + new_model.__fields__[field_name] = get_pydantic_field( + field_name=field_name, model=new_model ) - - new_model = super().__new__( # type: ignore - mcs, name, bases, attrs - ) - new_model.Meta.alias_manager = alias_manager new_model.objects = QuerySet(new_model) + add_cached_properties(new_model) + add_property_fields(new_model) return new_model diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index ac3c5ed..1f637cc 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -49,6 +49,7 @@ class ModelTableProxy: _related_names_hash: Union[str, bytes] pk: Any get_name: Callable + _props: Set def dict(self): # noqa A003 raise NotImplementedError # pragma no cover @@ -68,7 +69,7 @@ class ModelTableProxy: @staticmethod def get_clause_target_and_filter_column_name( - parent_model: Type["Model"], target_model: Type["Model"], reverse: bool + parent_model: Type["Model"], target_model: Type["Model"], reverse: bool ) -> Tuple[Type["Model"], str]: if reverse: field = target_model.resolve_relation_field(target_model, parent_model) @@ -83,10 +84,10 @@ class ModelTableProxy: @staticmethod def get_column_name_for_id_extraction( - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool, - use_raw: bool, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + use_raw: bool, ) -> str: if reverse: column_name = parent_model.Meta.pkname @@ -109,7 +110,7 @@ class ModelTableProxy: def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]: if target_field.virtual or issubclass( - target_field, ormar.fields.ManyToManyField + target_field, ormar.fields.ManyToManyField ): return self.pk related_name = self.resolve_relation_name(self, target_field.to) @@ -126,9 +127,9 @@ class ModelTableProxy: @classmethod def get_names_to_exclude( - cls, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> Set: fields_names = cls.extract_db_own_fields() if fields and fields is not Ellipsis: @@ -207,14 +208,13 @@ class ModelTableProxy: @classmethod def extract_related_names(cls) -> Set: - if isinstance(cls._related_names_hash, (str, bytes)): + if isinstance(cls._related_names, Set): return cls._related_names related_names = set() for name, field in cls.Meta.model_fields.items(): if inspect.isclass(field) and issubclass(field, ForeignKeyField): related_names.add(name) - cls._related_names_hash = json.dumps(list(cls.Meta.model_fields.keys())) cls._related_names = related_names return related_names @@ -241,9 +241,9 @@ class ModelTableProxy: @classmethod def _update_excluded_with_related_not_required( - cls, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None], - nested: bool = False, + cls, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None], + nested: bool = False, ) -> Union[Set, Dict]: exclude = exclude or {} related_set = cls._exclude_related_names_not_required(nested=nested) @@ -269,18 +269,18 @@ class ModelTableProxy: @staticmethod def resolve_relation_name( # noqa CCR001 - item: Union[ - "NewBaseModel", - Type["NewBaseModel"], - "ModelTableProxy", - Type["ModelTableProxy"], - ], - related: Union[ - "NewBaseModel", - Type["NewBaseModel"], - "ModelTableProxy", - Type["ModelTableProxy"], - ], + item: Union[ + "NewBaseModel", + Type["NewBaseModel"], + "ModelTableProxy", + Type["ModelTableProxy"], + ], + related: Union[ + "NewBaseModel", + Type["NewBaseModel"], + "ModelTableProxy", + Type["ModelTableProxy"], + ], ) -> str: for name, field in item.Meta.model_fields.items(): if issubclass(field, ForeignKeyField): @@ -296,7 +296,7 @@ class ModelTableProxy: @staticmethod def resolve_relation_field( - item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] + item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] ) -> Type[BaseField]: name = ModelTableProxy.resolve_relation_name(item, related) to_field = item.Meta.model_fields.get(name) @@ -343,12 +343,12 @@ class ModelTableProxy: for field in one.Meta.model_fields.keys(): current_field = getattr(one, field) if isinstance(current_field, list) and not isinstance( - current_field, ormar.Model + current_field, ormar.Model ): setattr(other, field, current_field + getattr(other, field)) elif ( - isinstance(current_field, ormar.Model) - and current_field.pk == getattr(other, field).pk + isinstance(current_field, ormar.Model) + and current_field.pk == getattr(other, field).pk ): setattr( other, @@ -360,7 +360,7 @@ class ModelTableProxy: @staticmethod def _populate_pk_column( - model: Type["Model"], columns: List[str], use_alias: bool = False, + model: Type["Model"], columns: List[str], use_alias: bool = False, ) -> List[str]: pk_alias = ( model.get_column_alias(model.Meta.pkname) @@ -373,10 +373,10 @@ class ModelTableProxy: @staticmethod def own_table_columns( - model: Type["Model"], - fields: Optional[Union[Set, Dict]], - exclude_fields: Optional[Union[Set, Dict]], - use_alias: bool = False, + model: Type["Model"], + fields: Optional[Union[Set, Dict]], + exclude_fields: Optional[Union[Set, Dict]], + use_alias: bool = False, ) -> List[str]: columns = [ model.get_column_name_from_alias(col.name) if not use_alias else col.name diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 837c10f..77879c4 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -51,9 +51,6 @@ class NewBaseModel( "_orm_id", "_orm_saved", "_orm", - "_related_names", - "_related_names_hash", - "_props", ) if TYPE_CHECKING: # pragma no cover @@ -70,7 +67,7 @@ class NewBaseModel( _orm_saved: bool _related_names: Set _related_names_hash: str - _props: List[str] + _props: Set Meta: ModelMeta # noinspection PyMissingConstructor @@ -158,27 +155,22 @@ class NewBaseModel( self.set_save_status(False) def __getattribute__(self, item: str) -> Any: - if item in ( - "_orm_id", - "_orm_saved", - "_orm", - "__fields__", - "_related_names", - "_props", - ): + if item in object.__getattribute__(self, '_quick_access_fields'): return object.__getattribute__(self, item) if item == "pk": return self.__dict__.get(self.Meta.pkname, None) - if item != "extract_related_names" and item in self.extract_related_names(): + if item in self.extract_related_names(): return self._extract_related_model_instead_of_field(item) - if item != "__fields__" and item in self.__fields__: + if item in self._props: + return object.__getattribute__(self, item) + if item in self.__fields__: value = self.__dict__.get(item, None) value = self._convert_json(item, value, "loads") return value return super().__getattribute__(item) def _extract_related_model_instead_of_field( - self, item: str + self, item: str ) -> Optional[Union["T", Sequence["T"]]]: # alias = self.get_column_alias(item) if item in self._orm: @@ -192,9 +184,9 @@ class NewBaseModel( def __same__(self, other: "NewBaseModel") -> bool: return ( - self._orm_id == other._orm_id - or self.dict() == other.dict() - or (self.pk == other.pk and self.pk is not None) + self._orm_id == other._orm_id + or self.dict() == other.dict() + or (self.pk == other.pk and self.pk is not None) ) @classmethod @@ -228,19 +220,10 @@ class NewBaseModel( @classmethod def get_properties( - cls, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None] + cls, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None] ) -> List[str]: - if isinstance(cls._props, list): - props = cls._props - else: - props = [ - prop - for prop in dir(cls) - if isinstance(getattr(cls, prop), property) - and prop - not in ("__values__", "__fields__", "fields", "pk_column", "saved") - ] - cls._props = props + + props = cls._props if include: props = [prop for prop in props if prop in include] if exclude: @@ -248,7 +231,7 @@ class NewBaseModel( return props def _get_related_not_excluded_fields( - self, include: Optional[Dict], exclude: Optional[Dict], + self, include: Optional[Dict], exclude: Optional[Dict], ) -> List: fields = [field for field in self.extract_related_names()] if include: @@ -263,15 +246,15 @@ class NewBaseModel( @staticmethod def _extract_nested_models_from_list( - models: MutableSequence, - include: Union[Set, Dict, None], - exclude: Union[Set, Dict, None], + models: MutableSequence, + include: Union[Set, Dict, None], + exclude: Union[Set, Dict, None], ) -> List: result = [] for model in models: try: result.append( - model.dict(nested=True, include=include, exclude=exclude,) + model.dict(nested=True, include=include, exclude=exclude, ) ) except ReferenceError: # pragma no cover continue @@ -279,17 +262,17 @@ class NewBaseModel( @staticmethod def _skip_ellipsis( - items: Union[Set, Dict, None], key: str + items: Union[Set, Dict, None], key: str ) -> Union[Set, Dict, None]: result = Excludable.get_child(items, key) return result if result is not Ellipsis else None def _extract_nested_models( # noqa: CCR001 - self, - nested: bool, - dict_instance: Dict, - include: Optional[Dict], - exclude: Optional[Dict], + self, + nested: bool, + dict_instance: Dict, + include: Optional[Dict], + exclude: Optional[Dict], ) -> Dict: fields = self._get_related_not_excluded_fields(include=include, exclude=exclude) @@ -315,16 +298,16 @@ class NewBaseModel( return dict_instance def dict( # type: ignore # noqa A003 - self, - *, - include: Union[Set, Dict] = None, - exclude: Union[Set, Dict] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - nested: bool = False, + self, + *, + include: Union[Set, Dict] = None, + exclude: Union[Set, Dict] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False, ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, @@ -349,9 +332,10 @@ class NewBaseModel( ) # include model properties as fields - props = self.get_properties(include=include, exclude=exclude) - if props: - dict_instance.update({prop: getattr(self, prop) for prop in props}) + if self.Meta.include_props_in_dict: + props = self.get_properties(include=include, exclude=exclude) + if props: + dict_instance.update({prop: getattr(self, prop) for prop in props}) return dict_instance @@ -379,4 +363,4 @@ class NewBaseModel( return value def _is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.Meta.model_fields[column_name].__type__ == pydantic.Json + return column_name in self.Meta.model_fields and self.Meta.model_fields[column_name].__type__ == pydantic.Json diff --git a/tests/test_excluding_fields_in_fastapi.py b/tests/test_excluding_fields_in_fastapi.py index c9ebc46..c04ac94 100644 --- a/tests/test_excluding_fields_in_fastapi.py +++ b/tests/test_excluding_fields_in_fastapi.py @@ -64,6 +64,8 @@ class RandomModel(ormar.Model): metadata = metadata database = database + include_props_in_fields = True + id: int = ormar.Integer(primary_key=True) password: str = ormar.String(max_length=255, default=gen_pass) first_name: str = ormar.String(max_length=255, default="John") @@ -72,6 +74,10 @@ class RandomModel(ormar.Model): server_default=sqlalchemy.func.now() ) + @property + def full_name(self): + return ' '.join([self.first_name, self.last_name]) + class User(ormar.Model): class Meta: @@ -136,6 +142,12 @@ async def create_user5(user: RandomModel): return await user.save() +@app.post("/random2/", response_model=RandomModel) +async def create_user6(user: RandomModel): + user = await user.save() + return user.dict() + + def test_all_endpoints(): client = TestClient(app) with client as client: @@ -187,4 +199,30 @@ def test_all_endpoints(): "first_name", "last_name", "created_date", + "full_name" + ] + assert response.json().get("full_name") == "John Test" + + RandomModel.Meta.include_props_in_fields = False + user3 = {"last_name": "Test"} + response = client.post("/random/", json=user3) + assert list(response.json().keys()) == [ + "id", + "password", + "first_name", + "last_name", + "created_date", + "full_name" + ] + + RandomModel.Meta.include_props_in_dict = True + user3 = {"last_name": "Test"} + response = client.post("/random2/", json=user3) + assert list(response.json().keys()) == [ + "id", + "password", + "first_name", + "last_name", + "created_date", + "full_name" ] diff --git a/tests/test_pydantic_only_fields.py b/tests/test_pydantic_only_fields.py new file mode 100644 index 0000000..19ac37f --- /dev/null +++ b/tests/test_pydantic_only_fields.py @@ -0,0 +1,75 @@ +import datetime + +import databases +import pytest +import sqlalchemy +from pydantic import validator + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + include_props_in_dict = True + include_props_in_fields = True + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + timestamp: datetime.datetime = ormar.DateTime(pydantic_only=True) + + @property + def name10(self) -> str: + return self.name + '_10' + + @validator('name') + def test(cls, v): + return v + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_pydantic_only_fields(): + async with database: + async with database.transaction(force_rollback=True): + album = await Album.objects.create(name='Hitchcock') + assert album.pk is not None + assert album.saved + assert album.timestamp is None + + album = await Album.objects.exclude_fields('timestamp').get() + assert album.timestamp is None + + album = await Album.objects.fields({'name', 'timestamp'}).get() + assert album.timestamp is None + + test_dict = album.dict() + assert 'timestamp' in test_dict + assert test_dict['timestamp'] is None + + album.timestamp = datetime.datetime.now() + test_dict = album.dict() + assert 'timestamp' in test_dict + assert test_dict['timestamp'] is not None + assert test_dict.get('name10') == 'Hitchcock_10' + + Album.Meta.include_props_in_dict = False + test_dict = album.dict() + assert 'timestamp' in test_dict + assert test_dict['timestamp'] is not None + # key is still there as now it's a field + assert test_dict['name10'] is None