diff --git a/ormar/__init__.py b/ormar/__init__.py index 41913b8..c970f7d 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -1,3 +1,4 @@ +from ormar.decorators import property_field from ormar.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch from ormar.protocols import QuerySetProtocol, RelationProtocol # noqa: I100 from ormar.fields import ( # noqa: I100 @@ -58,4 +59,5 @@ __all__ = [ "QuerySetProtocol", "RelationProtocol", "ModelMeta", + "property_field", ] diff --git a/ormar/decorators/__init__.py b/ormar/decorators/__init__.py new file mode 100644 index 0000000..7dfbe5e --- /dev/null +++ b/ormar/decorators/__init__.py @@ -0,0 +1,5 @@ +from ormar.decorators.property_field import property_field + +__all__ = [ + "property_field", +] diff --git a/ormar/decorators/property_field.py b/ormar/decorators/property_field.py new file mode 100644 index 0000000..8d6a2e2 --- /dev/null +++ b/ormar/decorators/property_field.py @@ -0,0 +1,19 @@ +import inspect +from collections.abc import Callable +from typing import Union + +from ormar.exceptions import ModelDefinitionError + + +def property_field(func: Callable) -> Union[property, Callable]: + if isinstance(func, property): # pragma: no cover + func.fget.__property_field__ = True + else: + arguments = list(inspect.signature(func).parameters.keys()) + if len(arguments) > 1 or arguments[0] != "self": + raise ModelDefinitionError( + "property_field decorator can be used " + "only on class methods with no arguments" + ) + func.__dict__["__property_field__"] = True + return func diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 55b0bbb..8460e13 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,7 +1,6 @@ -import inspect import logging import warnings -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union import databases import pydantic @@ -16,7 +15,6 @@ from ormar import ForeignKey, Integer, ModelDefinitionError # noqa I100 from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.many_to_many import ManyToMany, ManyToManyField -from ormar.fields.model_fields import ModelFieldFactory from ormar.models.quick_access_views import quick_access_set from ormar.queryset import QuerySet from ormar.relations.alias_manager import AliasManager @@ -39,8 +37,7 @@ class ModelMeta: str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] ] alias_manager: AliasManager - include_props_in_dict: bool - include_props_in_fields: bool + property_fields: Set def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None: @@ -48,7 +45,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] + table_name: str, field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( @@ -57,11 +54,11 @@ def register_many_to_many_relation_on_build( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -72,7 +69,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -80,10 +77,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -98,7 +95,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, real_name=model.get_name(), ondelete="CASCADE" @@ -115,7 +112,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -137,7 +134,7 @@ def get_pydantic_field(field_name: str, model: Type["Model"]) -> "ModelField": def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), @@ -153,7 +150,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -163,7 +160,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -177,9 +174,9 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field.get_alias())) register_relation_in_alias_manager(table_name, field) @@ -187,7 +184,7 @@ def sqlalchemy_columns_from_model_fields( def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] + table_name: str, field: Type[ForeignKeyField] ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build(table_name, field) @@ -196,7 +193,7 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - ormar_field: Type[BaseField], field_name: str, attrs: dict + ormar_field: Type[BaseField], field_name: str, attrs: dict ) -> dict: curr_def_value = attrs.get(field_name, ormar.Undefined) if lenient_issubclass(curr_def_value, ormar.fields.BaseField): @@ -243,7 +240,7 @@ def extract_annotations_and_default_vals(attrs: dict) -> Tuple[Dict, Dict]: def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = ( @@ -269,7 +266,7 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: if not hasattr(new_model.Meta, "table"): new_model.Meta.table = sqlalchemy.Table( @@ -321,52 +318,42 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001 def populate_default_options_values( - new_model: Type["Model"], model_fields: Dict + new_model: Type["Model"], model_fields: Dict ) -> None: if not hasattr(new_model.Meta, "constraints"): new_model.Meta.constraints = [] if not hasattr(new_model.Meta, "model_fields"): new_model.Meta.model_fields = model_fields - if not hasattr(new_model.Meta, "include_props_in_dict"): - new_model.Meta.include_props_in_dict = True - if not hasattr(new_model.Meta, "include_props_in_fields"): - new_model.Meta.include_props_in_fields = False - def add_cached_properties(new_model: Type["Model"]) -> None: - new_model._props = { - prop[0] - for prop in inspect.getmembers(new_model, lambda o: isinstance(o, property)) - if prop[0] not in ("__values__", "__fields__", "fields", "pk_column", "saved") - } new_model._quick_access_fields = quick_access_set new_model._related_names = None new_model._pydantic_fields = {name for name in new_model.__fields__} -def add_property_fields(new_model: Type["Model"]) -> None: - pass +def property_fields_not_set(new_model: Type["Model"]) -> bool: + return ( + not hasattr(new_model.Meta, "property_fields") + or not new_model.Meta.property_fields + ) -# if new_model.Meta.include_props_in_fields: -# for prop in new_model._props: -# field_type = getattr(new_model, prop).fget.__annotations__.get("return") -# new_model.Meta.model_fields[prop] = ModelFieldFactory( # type: ignore -# nullable=True, pydantic_only=True -# ) -# new_model.__fields__[prop] = ModelField( -# name=prop, -# type_=Optional[field_type] if field_type is not None else Any, # type: ignore -# model_config=new_model.__config__, -# required=False, -# class_validators={}, -# ) +def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001 + if property_fields_not_set(new_model): + props = set() + for var_name, value in attrs.items(): + if isinstance(value, property): + value = value.fget + field_config = getattr(value, "__property_field__", None) + if field_config: + props.add(var_name) + new_model.Meta.property_fields = props class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name @@ -391,6 +378,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): ) new_model.Meta.alias_manager = alias_manager new_model.objects = QuerySet(new_model) - add_property_fields(new_model) + add_property_fields(new_model, attrs) return new_model diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index eee9ae9..3a82230 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -45,7 +45,7 @@ class ModelTableProxy: pk: Any get_name: Callable _props: Set - dict: Callable + dict: Callable # noqa: A001, VNE003 def _extract_own_model_fields(self) -> Dict: related_names = self.extract_related_names() diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 7a7ee59..35a5d57 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -1,4 +1,3 @@ -import inspect import json import uuid from typing import ( @@ -68,14 +67,12 @@ class NewBaseModel( _orm_saved: bool _related_names: Optional[Set] _related_names_hash: str - _props: Set _pydantic_fields: Set _quick_access_fields: Set Meta: ModelMeta # noinspection PyMissingConstructor def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore - caller_name = inspect.currentframe().f_back.f_code.co_name object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_saved", False) object.__setattr__( @@ -96,9 +93,14 @@ class NewBaseModel( if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") + + # remove property fields values from validation + kwargs = { + k: v + for k, v in kwargs.items() + if k not in object.__getattribute__(self, "Meta").property_fields + } # build the models to set them and validate but don't register - if self.Meta.include_props_in_dict: - kwargs = {k: v for k, v in kwargs.items() if k not in object.__getattribute__(self, '_props')} try: new_kwargs: Dict[str, Any] = { k: self._convert_json( @@ -136,7 +138,7 @@ class NewBaseModel( ) def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 - if name in ("_orm_id", "_orm_saved", "_orm", "_related_names", "_props"): + if name in object.__getattribute__(self, "_quick_access_fields"): object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) @@ -168,8 +170,9 @@ class NewBaseModel( return object.__getattribute__( self, "_extract_related_model_instead_of_field" )(item) - if item in object.__getattribute__(self, "_props"): - return object.__getattribute__(self, item) + if item in object.__getattribute__(self, "Meta").property_fields: + value = object.__getattribute__(self, item) + return value() if callable(value) else value if item in object.__getattribute__(self, "_pydantic_fields"): value = object.__getattribute__(self, "__dict__").get(item, None) value = object.__getattribute__(self, "_convert_json")(item, value, "loads") @@ -177,7 +180,7 @@ class NewBaseModel( return object.__getattribute__(self, item) # pragma: no cover def _extract_related_model_instead_of_field( - self, item: str + self, item: str ) -> Optional[Union["T", Sequence["T"]]]: if item in self._orm: return self._orm.get(item) @@ -190,9 +193,9 @@ class NewBaseModel( def __same__(self, other: "NewBaseModel") -> bool: return ( - self._orm_id == other._orm_id - or (self.pk == other.pk and self.pk is not None) - or self.dict() == other.dict() + self._orm_id == other._orm_id + or (self.pk == other.pk and self.pk is not None) + or self.dict() == other.dict() ) @classmethod @@ -226,10 +229,10 @@ class NewBaseModel( @classmethod def get_properties( - cls, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None] + cls, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None] ) -> Set[str]: - props = cls._props + props = cls.Meta.property_fields if include: props = {prop for prop in props if prop in include} if exclude: @@ -237,7 +240,7 @@ class NewBaseModel( return props def _get_related_not_excluded_fields( - self, include: Optional[Dict], exclude: Optional[Dict], + self, include: Optional[Dict], exclude: Optional[Dict], ) -> List: fields = [field for field in self.extract_related_names()] if include: @@ -252,15 +255,15 @@ class NewBaseModel( @staticmethod def _extract_nested_models_from_list( - models: MutableSequence, - include: Union[Set, Dict, None], - exclude: Union[Set, Dict, None], + models: MutableSequence, + include: Union[Set, Dict, None], + exclude: Union[Set, Dict, None], ) -> List: result = [] for model in models: try: result.append( - model.dict(nested=True, include=include, exclude=exclude, ) + model.dict(nested=True, include=include, exclude=exclude,) ) except ReferenceError: # pragma no cover continue @@ -268,17 +271,17 @@ class NewBaseModel( @staticmethod def _skip_ellipsis( - items: Union[Set, Dict, None], key: str + items: Union[Set, Dict, None], key: str ) -> Union[Set, Dict, None]: result = Excludable.get_child(items, key) return result if result is not Ellipsis else None def _extract_nested_models( # noqa: CCR001 - self, - nested: bool, - dict_instance: Dict, - include: Optional[Dict], - exclude: Optional[Dict], + self, + nested: bool, + dict_instance: Dict, + include: Optional[Dict], + exclude: Optional[Dict], ) -> Dict: fields = self._get_related_not_excluded_fields(include=include, exclude=exclude) @@ -304,16 +307,16 @@ class NewBaseModel( return dict_instance def dict( # type: ignore # noqa A003 - self, - *, - include: Union[Set, Dict] = None, - exclude: Union[Set, Dict] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - nested: bool = False, + self, + *, + include: Union[Set, Dict] = None, + exclude: Union[Set, Dict] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False, ) -> "DictStrAny": # noqa: A003' # callable_name = inspect.currentframe().f_back.f_code.co_name # print('dict', callable_name) @@ -339,8 +342,8 @@ class NewBaseModel( exclude=exclude, # type: ignore ) - # include model properties as fields - if self.Meta.include_props_in_dict: + # include model properties as fields in dict + if object.__getattribute__(self, "Meta").property_fields: props = self.get_properties(include=include, exclude=exclude) if props: dict_instance.update({prop: getattr(self, prop) for prop in props}) @@ -372,6 +375,6 @@ class NewBaseModel( def _is_conversion_to_json_needed(self, column_name: str) -> bool: return ( - column_name in self.Meta.model_fields - and self.Meta.model_fields[column_name].__type__ == pydantic.Json + column_name in self.Meta.model_fields + and self.Meta.model_fields[column_name].__type__ == pydantic.Json ) diff --git a/ormar/models/quick_access_views.py b/ormar/models/quick_access_views.py index bbab89e..471c809 100644 --- a/ormar/models/quick_access_views.py +++ b/ormar/models/quick_access_views.py @@ -25,7 +25,6 @@ quick_access_set = { "_orm", "_orm_id", "_orm_saved", - "_props", "_related_names", "_skip_ellipsis", "_update_and_follow", diff --git a/tests/test_excluding_fields_in_fastapi.py b/tests/test_excluding_fields_in_fastapi.py index cb801ce..b9398ca 100644 --- a/tests/test_excluding_fields_in_fastapi.py +++ b/tests/test_excluding_fields_in_fastapi.py @@ -10,6 +10,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient import ormar +from ormar import property_field from tests.settings import DATABASE_URL app = FastAPI() @@ -74,8 +75,8 @@ class RandomModel(ormar.Model): server_default=sqlalchemy.func.now() ) - @property - def full_name(self): + @property_field + def full_name(self) -> str: return " ".join([self.first_name, self.last_name]) @@ -140,7 +141,6 @@ async def create_user4(user: User2): @app.post("/random/", response_model=RandomModel) async def create_user5(user: RandomModel): user = await user.save() - print('returning') return user @@ -170,24 +170,20 @@ def test_excluding_fields_in_endpoints(): "last_name": "Doe", } - print('before call') response = client.post("/users/", json=user2) created_user = User(**response.json()) assert created_user.pk is not None assert created_user.password is None - print('before call') response = client.post("/users2/", json=user) created_user2 = User(**response.json()) assert created_user2.pk is not None assert created_user2.password is None # response has only 3 fields from UserBase - print('before call') response = client.post("/users3/", json=user) assert list(response.json().keys()) == ["email", "first_name", "last_name"] - print('before call') response = client.post("/users4/", json=user) assert list(response.json().keys()) == [ "id", @@ -197,40 +193,41 @@ def test_excluding_fields_in_endpoints(): "category", ] - # user3 = {"last_name": "Test"} - # print('before call') - # response = client.post("/random/", json=user3) - # assert list(response.json().keys()) == [ - # "id", - # "password", - # "first_name", - # "last_name", - # "created_date", - # "full_name", - # ] - # assert response.json().get("full_name") == "John Test" - # - # RandomModel.Meta.include_props_in_fields = False - # user3 = {"last_name": "Test"} - # print('before call') - # response = client.post("/random/", json=user3) - # assert list(response.json().keys()) == [ - # "id", - # "password", - # "first_name", - # "last_name", - # "created_date", - # "full_name", - # ] - # assert response.json().get("full_name") == "John Test" - def test_adding_fields_in_endpoints(): + client = TestClient(app) + with client as client: + user3 = {"last_name": "Test"} + response = client.post("/random/", json=user3) + assert list(response.json().keys()) == [ + "id", + "password", + "first_name", + "last_name", + "created_date", + "full_name", + ] + assert response.json().get("full_name") == "John Test" + + RandomModel.Meta.include_props_in_fields = False + user3 = {"last_name": "Test"} + response = client.post("/random/", json=user3) + assert list(response.json().keys()) == [ + "id", + "password", + "first_name", + "last_name", + "created_date", + "full_name", + ] + assert response.json().get("full_name") == "John Test" + + +def test_adding_fields_in_endpoints2(): client = TestClient(app) with client as client: RandomModel.Meta.include_props_in_dict = True user3 = {"last_name": "Test"} - print('before call') response = client.post("/random2/", json=user3) assert list(response.json().keys()) == [ "id", diff --git a/tests/test_properties.py b/tests/test_properties.py index b89db67..6456c6a 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,8 +1,10 @@ +# type: ignore import databases import pytest import sqlalchemy import ormar +from ormar import ModelDefinitionError, property_field from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -19,15 +21,15 @@ class Song(ormar.Model): name: str = ormar.String(max_length=100) sort_order: int = ormar.Integer() - @property + @property_field def sorted_name(self): return f"{self.sort_order}: {self.name}" - @property + @property_field def sample(self): return "sample" - @property + @property_field def sample2(self): return "sample2" @@ -66,3 +68,12 @@ async def test_sort_order_on_main_model(): assert "sample" not in check_include assert "sample2" in check_include assert "sorted_name" in check_include + + +def test_wrong_definition(): + with pytest.raises(ModelDefinitionError): + + class WrongModel(ormar.Model): # pragma: no cover + @property_field + def test(self, aa=10, bb=30): + pass diff --git a/tests/test_pydantic_only_fields.py b/tests/test_pydantic_only_fields.py index 6c6b14e..ee58177 100644 --- a/tests/test_pydantic_only_fields.py +++ b/tests/test_pydantic_only_fields.py @@ -3,9 +3,9 @@ import datetime import databases import pytest import sqlalchemy -from pydantic import validator import ormar +from ormar import property_field from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -17,17 +17,27 @@ class Album(ormar.Model): tablename = "albums" metadata = metadata database = database - include_props_in_dict = True - include_props_in_fields = True id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) timestamp: datetime.datetime = ormar.DateTime(pydantic_only=True) - @property + @property_field def name10(self) -> str: return self.name + "_10" + @property_field + def name20(self) -> str: + return self.name + "_20" + + @property + def name30(self) -> str: + return self.name + "_30" + + @property_field + def name40(self) -> str: + return self.name + "_40" + @pytest.fixture(autouse=True, scope="module") def create_test_database(): @@ -57,14 +67,13 @@ async def test_pydantic_only_fields(): assert "timestamp" in test_dict assert test_dict["timestamp"] is None + assert album.name30 == "Hitchcock_30" + album.timestamp = datetime.datetime.now() test_dict = album.dict() assert "timestamp" in test_dict assert test_dict["timestamp"] is not None assert test_dict.get("name10") == "Hitchcock_10" - - Album.Meta.include_props_in_dict = False - test_dict = album.dict() - assert "timestamp" in test_dict - assert test_dict["timestamp"] is not None - assert test_dict.get("name10", 'aa') == 'aa' + assert test_dict.get("name20") == "Hitchcock_20" + assert test_dict.get("name40") == "Hitchcock_40" + assert "name30" not in test_dict