diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 5649d22..2a3cccf 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -6,7 +6,7 @@ import databases import pydantic import sqlalchemy from pydantic import BaseConfig -from pydantic.fields import ModelField +from pydantic.fields import FieldInfo, ModelField from pydantic.utils import lenient_issubclass from sqlalchemy.sql.schema import ColumnCollectionConstraint @@ -24,6 +24,7 @@ if TYPE_CHECKING: # pragma no cover from ormar import Model alias_manager = AliasManager() +PARSED_FIELDS_KEY = "__parsed_fields__" class ModelMeta: @@ -221,9 +222,7 @@ def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]: DeprecationWarning, ) - potential_fields.update( - {k: v for k, v in attrs.items() if lenient_issubclass(v, BaseField)} - ) + potential_fields.update(get_potential_fields(attrs)) for field_name, field in potential_fields.items(): field.name = field_name attrs = populate_default_pydantic_field_value(field, field_name, attrs) @@ -331,6 +330,7 @@ def populate_default_options_values( def add_cached_properties(new_model: Type["Model"]) -> None: new_model._quick_access_fields = quick_access_set new_model._related_names = None + new_model._related_fields = None new_model._pydantic_fields = {name for name in new_model.__fields__} @@ -362,6 +362,161 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001 new_model.Meta.signals = signals +def get_potential_fields(attrs: Dict) -> Dict: + """ + Gets all the fields in current class namespace that are Fields. + + :param attrs: current class namespace + :type attrs: Dict + :return: extracted fields that are ormar Fields + :rtype: Dict + """ + return {k: v for k, v in attrs.items() if lenient_issubclass(v, BaseField)} + + +def check_conflicting_fields( + new_fields: Set, attrs: Dict, base_class: type, curr_class: type +) -> None: + """ + You cannot redefine fields with same names in inherited classes. + Ormar will raise an exception if it encounters a field that is an ormar + Field and at the same time was already declared in one of base classes. + + :param new_fields: set of names of fields defined in current model + :type new_fields: Set[str] + :param attrs: namespace of current class + :type attrs: Dict + :param base_class: one of the parent classes + :type base_class: Model or model parent class + :param curr_class: current constructed class + :type curr_class: Model or model parent class + """ + previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)}) + overwrite = new_fields.intersection(previous_fields) + + if overwrite: + raise ModelDefinitionError( + f"Model {curr_class} redefines the fields: " + f"{overwrite} already defined in {base_class}!" + ) + + +def update_attrs_and_fields( + attrs: Dict, + new_attrs: Dict, + model_fields: Dict, + new_model_fields: Dict, + new_fields: Set, +) -> None: + """ + Updates __annotations__, values of model fields (so pydantic FieldInfos) + as well as model.Meta.model_fields definitions from parents. + + :param attrs: new namespace for class being constructed + :type attrs: Dict + :param new_attrs: part of the namespace extracted from parent class + :type new_attrs: Dict + :param model_fields: ormar fields in defined in current class + :type model_fields: Dict[str, BaseField] + :param new_model_fields: ormar fields defined in parent classes + :type new_model_fields: Dict[str, BaseField] + :param new_fields: set of new fields names + :type new_fields: Set[str] + """ + key = "__annotations__" + attrs[key].update(new_attrs[key]) + attrs.update({name: new_attrs[name] for name in new_fields}) + model_fields.update(new_model_fields) + + +def extract_mixin_fields_from_dict( + base_class: type, + curr_class: type, + attrs: Dict, + model_fields: Dict[ + str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] + ], +) -> Tuple[Dict, Dict]: + """ + Extracts fields from base classes if they have valid oramr fields. + + If model was already parsed -> fields definitions need to be removed from class + cause pydantic complains about field re-definition so after first child + we need to extract from __parsed_fields__ not the class itself. + + If the class is parsed first time annotations and field definition is parsed + from the class.__dict__. + + If the class is a ormar.Model it is skipped. + + :param base_class: one of the parent classes + :type base_class: Model or model parent class + :param curr_class: current constructed class + :type curr_class: Model or model parent class + :param attrs: new namespace for class being constructed + :type attrs: Dict + :param model_fields: ormar fields in defined in current class + :type model_fields: Dict[str, BaseField] + :return: updated attrs and model_fields + :rtype: Tuple[Dict, Dict] + """ + if hasattr(base_class, "Meta"): + # not a mixin base parent Model + return attrs, model_fields + + key = "__annotations__" + if hasattr(base_class, PARSED_FIELDS_KEY): + # model was already parsed -> fields definitions need to be removed from class + # cause pydantic complains about field re-definition so after first child + # we need to extract from __parsed_fields__ not the class itself + new_attrs, new_model_fields = getattr(base_class, PARSED_FIELDS_KEY) + + new_fields = set(new_model_fields.keys()) + check_conflicting_fields( + new_fields=new_fields, + attrs=attrs, + base_class=base_class, + curr_class=curr_class, + ) + + update_attrs_and_fields( + attrs=attrs, + new_attrs=new_attrs, + model_fields=model_fields, + new_model_fields=new_model_fields, + new_fields=new_fields, + ) + return attrs, model_fields + + potential_fields = get_potential_fields(base_class.__dict__) + if potential_fields: + # parent model has ormar fields defined and was not parsed before + new_attrs = {key: base_class.__dict__.get(key, {})} + new_attrs.update(potential_fields) + + new_fields = set(potential_fields.keys()) + check_conflicting_fields( + new_fields=new_fields, + attrs=attrs, + base_class=base_class, + curr_class=curr_class, + ) + for name in new_fields: + delattr(base_class, name) + + new_attrs, new_model_fields = extract_annotations_and_default_vals(new_attrs) + setattr(base_class, PARSED_FIELDS_KEY, (new_attrs, new_model_fields)) + + update_attrs_and_fields( + attrs=attrs, + new_attrs=new_attrs, + model_fields=model_fields, + new_model_fields=new_model_fields, + new_fields=new_fields, + ) + return attrs, model_fields + + class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict @@ -369,9 +524,15 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name attrs, model_fields = extract_annotations_and_default_vals(attrs) + for base in reversed(bases): + attrs, model_fields = extract_mixin_fields_from_dict( + base_class=base, curr_class=mcs, attrs=attrs, model_fields=model_fields + ) + # print(attrs, model_fields) new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) + add_cached_properties(new_model) if hasattr(new_model, "Meta"): diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index b545ecf..e330028 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -41,7 +41,7 @@ class ModelTableProxy: if TYPE_CHECKING: # pragma no cover Meta: ModelMeta _related_names: Optional[Set] - _related_names_hash: Union[str, bytes] + _related_fields: Optional[List] pk: Any get_name: Callable _props: Set @@ -202,6 +202,19 @@ class ModelTableProxy: return field_name return alias # if not found it's not an alias but actual name + @classmethod + def extract_related_fields(cls) -> List: + + if isinstance(cls._related_fields, List): + return cls._related_fields + + related_fields = [] + for name in cls.extract_related_names(): + related_fields.append(cls.Meta.model_fields[name]) + cls._related_fields = related_fields + + return related_fields + @classmethod def extract_related_names(cls) -> Set: diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 4ab1b0b..4bfecab 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -28,7 +28,6 @@ from pydantic import BaseModel import ormar # noqa I100 from ormar.exceptions import ModelError from ormar.fields import BaseField -from ormar.fields.foreign_key import ForeignKeyField from ormar.models.excludable import Excludable from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.modelproxy import ModelTableProxy @@ -79,14 +78,7 @@ class NewBaseModel( object.__setattr__( self, "_orm", - RelationsManager( - related_fields=[ - field - for name, field in self.Meta.model_fields.items() - if issubclass(field, ForeignKeyField) - ], - owner=self, - ), + RelationsManager(related_fields=self.extract_related_fields(), owner=self,), ) pk_only = kwargs.pop("__pk_only__", False) diff --git a/tests/test_inheritance_mixins.py b/tests/test_inheritance_mixins.py new file mode 100644 index 0000000..80de023 --- /dev/null +++ b/tests/test_inheritance_mixins.py @@ -0,0 +1,133 @@ +# type: ignore +import datetime +from typing import Optional + +import databases +import pytest +import sqlalchemy as sa +from sqlalchemy import create_engine + +import ormar +from ormar import ModelDefinitionError +from tests.settings import DATABASE_URL + +metadata = sa.MetaData() +db = databases.Database(DATABASE_URL) +engine = create_engine(DATABASE_URL) + + +class AuditMixin: + created_by: str = ormar.String(max_length=100) + updated_by: str = ormar.String(max_length=100) + + +class DateFieldsMixins: + created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + + +class Category(ormar.Model, DateFieldsMixins, AuditMixin): + class Meta(ormar.ModelMeta): + tablename = "categories" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50, unique=True, index=True) + code: int = ormar.Integer() + + +class Subject(ormar.Model, DateFieldsMixins): + class Meta(ormar.ModelMeta): + tablename = "subjects" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50, unique=True, index=True) + category: Optional[Category] = ormar.ForeignKey(Category) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_field_redefining_raises_error(): + with pytest.raises(ModelDefinitionError): + + class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover + class Meta(ormar.ModelMeta): + tablename = "wrongs" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + created_date: datetime.datetime = ormar.DateTime() + + +def test_field_redefining_in_second_raises_error(): + class OkField(ormar.Model, DateFieldsMixins): # pragma: no cover + class Meta(ormar.ModelMeta): + tablename = "oks" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + + with pytest.raises(ModelDefinitionError): + + class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover + class Meta(ormar.ModelMeta): + tablename = "wrongs" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + created_date: datetime.datetime = ormar.DateTime() + + +@pytest.mark.asyncio +async def test_fields_inherited_from_mixin(): + async with db: + async with db.transaction(force_rollback=True): + cat = await Category( + name="Foo", code=123, created_by="Sam", updated_by="Max" + ).save() + sub = await Subject(name="Bar", category=cat).save() + mixin_columns = ["created_date", "updated_date"] + mixin2_columns = ["created_by", "updated_by"] + assert all(field in Category.Meta.model_fields for field in mixin_columns) + assert cat.created_date is not None + assert cat.updated_date is not None + assert all(field in Subject.Meta.model_fields for field in mixin_columns) + assert sub.created_date is not None + assert sub.updated_date is not None + + assert all(field in Category.Meta.model_fields for field in mixin2_columns) + assert all( + field not in Subject.Meta.model_fields for field in mixin2_columns + ) + + inspector = sa.inspect(engine) + assert "categories" in inspector.get_table_names() + table_columns = [x.get("name") for x in inspector.get_columns("categories")] + assert all(col in table_columns for col in mixin_columns + mixin2_columns) + + assert "subjects" in inspector.get_table_names() + table_columns = [x.get("name") for x in inspector.get_columns("subjects")] + assert all(col in table_columns for col in mixin_columns) + + sub2 = ( + await Subject.objects.select_related("category") + .order_by("-created_date") + .exclude_fields("updated_date") + .get() + ) + assert sub2.created_date == sub.created_date + assert sub2.category.updated_date is not None + assert sub2.category.created_date == cat.created_date + assert sub2.updated_date is None + assert sub2.category.created_by == "Sam" diff --git a/tests/test_saving_related.py b/tests/test_saving_related.py index 26c7000..6dd4fd2 100644 --- a/tests/test_saving_related.py +++ b/tests/test_saving_related.py @@ -55,7 +55,7 @@ async def test_model_relationship(): assert ws.topic == "Topic 1" assert ws.category.name == "Foo" - ws.topic = 'Topic 2' + ws.topic = "Topic 2" await ws.update() assert ws.id == 1