From c23afd17a05fb1fa68084aa1fd9200f31591e645 Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 10 Dec 2020 16:09:55 +0100 Subject: [PATCH] first ver of working concrete inheritance --- ormar/models/metaclass.py | 48 +++++++-- tests/test_inheritance_concrete.py | 152 +++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+), 9 deletions(-) create mode 100644 tests/test_inheritance_concrete.py diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 8a6ab3c..653d2e8 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -25,6 +25,7 @@ if TYPE_CHECKING: # pragma no cover alias_manager = AliasManager() PARSED_FIELDS_KEY = "__parsed_fields__" +CONFIG_KEY = "Config" class ModelMeta: @@ -478,13 +479,19 @@ def get_potential_fields(attrs: Dict) -> Dict: def check_conflicting_fields( - new_fields: Set, attrs: Dict, base_class: type, curr_class: type + new_fields: Set, + attrs: Dict, + base_class: type, + curr_class: type, + previous_fields: Set = None, ) -> None: """ You cannot redefine fields with same names in inherited classes. Ormar will raise an exception if it encounters a field that is an ormar Field and at the same time was already declared in one of base classes. + :param previous_fields: set of names of fields defined in base model + :type previous_fields: Set[str] :param new_fields: set of names of fields defined in current model :type new_fields: Set[str] :param attrs: namespace of current class @@ -494,7 +501,8 @@ def check_conflicting_fields( :param curr_class: current constructed class :type curr_class: Model or model parent class """ - previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)}) + if not previous_fields: + previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)}) overwrite = new_fields.intersection(previous_fields) if overwrite: @@ -539,7 +547,8 @@ def extract_mixin_fields_from_dict( model_fields: Dict[ str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] ], -) -> Tuple[Dict, Dict]: + bases: Any, +) -> Tuple[Dict, Dict, Any]: """ Extracts fields from base classes if they have valid oramr fields. @@ -565,14 +574,29 @@ def extract_mixin_fields_from_dict( """ if hasattr(base_class, "Meta"): new_fields = set(base_class.Meta.model_fields.keys()) # type: ignore + previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)}) check_conflicting_fields( new_fields=new_fields, attrs=attrs, base_class=base_class, curr_class=curr_class, + previous_fields=previous_fields, ) + if previous_fields and not base_class.Meta.abstract: # type: ignore + raise ModelDefinitionError( + f"{curr_class.__name__} cannot inherit " + f"from non abstract class {base_class.__name__}" + ) model_fields.update(base_class.Meta.model_fields) # type: ignore - return attrs, model_fields + # keep only parent ormar models as they already have all the predecessors + # keeping also Model, NewBaseModel etc. would cause mro conflicts + new_bases = tuple( + base + for base in bases + if issubclass(base, ormar.Model) and base != ormar.Model + ) + + return attrs, model_fields, new_bases key = "__annotations__" if hasattr(base_class, PARSED_FIELDS_KEY): @@ -596,7 +620,7 @@ def extract_mixin_fields_from_dict( new_model_fields=new_model_fields, new_fields=new_fields, ) - return attrs, model_fields + return attrs, model_fields, bases potential_fields = get_potential_fields(base_class.__dict__) if potential_fields: @@ -624,7 +648,7 @@ def extract_mixin_fields_from_dict( new_model_fields=new_model_fields, new_fields=new_fields, ) - return attrs, model_fields + return attrs, model_fields, bases class ModelMetaclass(pydantic.main.ModelMetaclass): @@ -634,13 +658,19 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name attrs, model_fields = extract_annotations_and_default_vals(attrs) + new_bases = bases for base in reversed(bases): - attrs, model_fields = extract_mixin_fields_from_dict( - base_class=base, curr_class=mcs, attrs=attrs, model_fields=model_fields + attrs, model_fields, new_bases = extract_mixin_fields_from_dict( + base_class=base, + curr_class=mcs, + attrs=attrs, + model_fields=model_fields, + bases=new_bases, ) # print(attrs, model_fields) + new_model = super().__new__( # type: ignore - mcs, name, bases, attrs + mcs, name, new_bases, attrs ) add_cached_properties(new_model) diff --git a/tests/test_inheritance_concrete.py b/tests/test_inheritance_concrete.py new file mode 100644 index 0000000..69e3532 --- /dev/null +++ b/tests/test_inheritance_concrete.py @@ -0,0 +1,152 @@ +# type: ignore +import datetime +from typing import Optional + +import databases +import pytest +import sqlalchemy as sa +from sqlalchemy import create_engine + +import ormar +from ormar import ModelDefinitionError +from tests.settings import DATABASE_URL + +metadata = sa.MetaData() +db = databases.Database(DATABASE_URL) +engine = create_engine(DATABASE_URL) + + +class AuditModel(ormar.Model): + class Meta: + abstract = True + + created_by: str = ormar.String(max_length=100) + updated_by: str = ormar.String(max_length=100, default="Sam") + + +class DateFieldsModelNoSubclass(ormar.Model): + class Meta: + tablename = "test_date_models" + metadata = metadata + database = db + + date_id: int = ormar.Integer(primary_key=True) + created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + + +class DateFieldsModel(ormar.Model): + class Meta: + abstract = True + + created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + + +class Category(ormar.Model, DateFieldsModel, AuditModel): + class Meta(ormar.ModelMeta): + tablename = "categories" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50, unique=True, index=True) + code: int = ormar.Integer() + + +class Subject(ormar.Model, DateFieldsModel): + class Meta(ormar.ModelMeta): + tablename = "subjects" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50, unique=True, index=True) + category: Optional[Category] = ormar.ForeignKey(Category) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_field_redefining_raises_error(): + with pytest.raises(ModelDefinitionError): + class WrongField(ormar.Model, DateFieldsModel): # pragma: no cover + class Meta(ormar.ModelMeta): + tablename = "wrongs" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + created_date: datetime.datetime = ormar.DateTime() + + +def test_model_subclassing_non_abstract_raises_error(): + with pytest.raises(ModelDefinitionError): + class WrongField2(ormar.Model, DateFieldsModelNoSubclass): # pragma: no cover + class Meta(ormar.ModelMeta): + tablename = "wrongs" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + + +def round_date_to_seconds( + date: datetime.datetime, +) -> datetime.datetime: # pragma: no cover + if date.microsecond >= 500000: + date = date + datetime.timedelta(seconds=1) + return date.replace(microsecond=0) + + +@pytest.mark.asyncio +async def test_fields_inherited_from_mixin(): + async with db: + async with db.transaction(force_rollback=True): + cat = await Category( + name="Foo", code=123, created_by="Sam", updated_by="Max" + ).save() + sub = await Subject(name="Bar", category=cat).save() + mixin_columns = ["created_date", "updated_date"] + mixin2_columns = ["created_by", "updated_by"] + assert all(field in Category.Meta.model_fields for field in mixin_columns) + assert cat.created_date is not None + assert cat.updated_date is not None + assert all(field in Subject.Meta.model_fields for field in mixin_columns) + assert sub.created_date is not None + assert sub.updated_date is not None + + assert all(field in Category.Meta.model_fields for field in mixin2_columns) + assert all( + field not in Subject.Meta.model_fields for field in mixin2_columns + ) + + inspector = sa.inspect(engine) + assert "categories" in inspector.get_table_names() + table_columns = [x.get("name") for x in inspector.get_columns("categories")] + assert all(col in table_columns for col in mixin_columns) # + mixin2_columns) + + assert "subjects" in inspector.get_table_names() + table_columns = [x.get("name") for x in inspector.get_columns("subjects")] + assert all(col in table_columns for col in mixin_columns) + + sub2 = ( + await Subject.objects.select_related("category") + .order_by("-created_date") + .exclude_fields("updated_date") + .get() + ) + assert round_date_to_seconds(sub2.created_date) == round_date_to_seconds( + sub.created_date + ) + assert sub2.category.updated_date is not None + assert round_date_to_seconds( + sub2.category.created_date + ) == round_date_to_seconds(cat.created_date) + assert sub2.updated_date is None + assert sub2.category.created_by == "Sam" + assert sub2.category.updated_by == cat.updated_by