first ver of working concrete inheritance

This commit is contained in:
collerek
2020-12-10 16:09:55 +01:00
parent fc710687e6
commit c23afd17a0
2 changed files with 191 additions and 9 deletions

View File

@ -25,6 +25,7 @@ if TYPE_CHECKING: # pragma no cover
alias_manager = AliasManager() alias_manager = AliasManager()
PARSED_FIELDS_KEY = "__parsed_fields__" PARSED_FIELDS_KEY = "__parsed_fields__"
CONFIG_KEY = "Config"
class ModelMeta: class ModelMeta:
@ -478,13 +479,19 @@ def get_potential_fields(attrs: Dict) -> Dict:
def check_conflicting_fields( def check_conflicting_fields(
new_fields: Set, attrs: Dict, base_class: type, curr_class: type new_fields: Set,
attrs: Dict,
base_class: type,
curr_class: type,
previous_fields: Set = None,
) -> None: ) -> None:
""" """
You cannot redefine fields with same names in inherited classes. You cannot redefine fields with same names in inherited classes.
Ormar will raise an exception if it encounters a field that is an ormar Ormar will raise an exception if it encounters a field that is an ormar
Field and at the same time was already declared in one of base classes. Field and at the same time was already declared in one of base classes.
:param previous_fields: set of names of fields defined in base model
:type previous_fields: Set[str]
:param new_fields: set of names of fields defined in current model :param new_fields: set of names of fields defined in current model
:type new_fields: Set[str] :type new_fields: Set[str]
:param attrs: namespace of current class :param attrs: namespace of current class
@ -494,6 +501,7 @@ def check_conflicting_fields(
:param curr_class: current constructed class :param curr_class: current constructed class
:type curr_class: Model or model parent class :type curr_class: Model or model parent class
""" """
if not previous_fields:
previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)}) previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)})
overwrite = new_fields.intersection(previous_fields) overwrite = new_fields.intersection(previous_fields)
@ -539,7 +547,8 @@ def extract_mixin_fields_from_dict(
model_fields: Dict[ model_fields: Dict[
str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
], ],
) -> Tuple[Dict, Dict]: bases: Any,
) -> Tuple[Dict, Dict, Any]:
""" """
Extracts fields from base classes if they have valid oramr fields. Extracts fields from base classes if they have valid oramr fields.
@ -565,14 +574,29 @@ def extract_mixin_fields_from_dict(
""" """
if hasattr(base_class, "Meta"): if hasattr(base_class, "Meta"):
new_fields = set(base_class.Meta.model_fields.keys()) # type: ignore new_fields = set(base_class.Meta.model_fields.keys()) # type: ignore
previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)})
check_conflicting_fields( check_conflicting_fields(
new_fields=new_fields, new_fields=new_fields,
attrs=attrs, attrs=attrs,
base_class=base_class, base_class=base_class,
curr_class=curr_class, curr_class=curr_class,
previous_fields=previous_fields,
)
if previous_fields and not base_class.Meta.abstract: # type: ignore
raise ModelDefinitionError(
f"{curr_class.__name__} cannot inherit "
f"from non abstract class {base_class.__name__}"
) )
model_fields.update(base_class.Meta.model_fields) # type: ignore model_fields.update(base_class.Meta.model_fields) # type: ignore
return attrs, model_fields # keep only parent ormar models as they already have all the predecessors
# keeping also Model, NewBaseModel etc. would cause mro conflicts
new_bases = tuple(
base
for base in bases
if issubclass(base, ormar.Model) and base != ormar.Model
)
return attrs, model_fields, new_bases
key = "__annotations__" key = "__annotations__"
if hasattr(base_class, PARSED_FIELDS_KEY): if hasattr(base_class, PARSED_FIELDS_KEY):
@ -596,7 +620,7 @@ def extract_mixin_fields_from_dict(
new_model_fields=new_model_fields, new_model_fields=new_model_fields,
new_fields=new_fields, new_fields=new_fields,
) )
return attrs, model_fields return attrs, model_fields, bases
potential_fields = get_potential_fields(base_class.__dict__) potential_fields = get_potential_fields(base_class.__dict__)
if potential_fields: if potential_fields:
@ -624,7 +648,7 @@ def extract_mixin_fields_from_dict(
new_model_fields=new_model_fields, new_model_fields=new_model_fields,
new_fields=new_fields, new_fields=new_fields,
) )
return attrs, model_fields return attrs, model_fields, bases
class ModelMetaclass(pydantic.main.ModelMetaclass): class ModelMetaclass(pydantic.main.ModelMetaclass):
@ -634,13 +658,19 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
attrs["Config"] = get_pydantic_base_orm_config() attrs["Config"] = get_pydantic_base_orm_config()
attrs["__name__"] = name attrs["__name__"] = name
attrs, model_fields = extract_annotations_and_default_vals(attrs) attrs, model_fields = extract_annotations_and_default_vals(attrs)
new_bases = bases
for base in reversed(bases): for base in reversed(bases):
attrs, model_fields = extract_mixin_fields_from_dict( attrs, model_fields, new_bases = extract_mixin_fields_from_dict(
base_class=base, curr_class=mcs, attrs=attrs, model_fields=model_fields base_class=base,
curr_class=mcs,
attrs=attrs,
model_fields=model_fields,
bases=new_bases,
) )
# print(attrs, model_fields) # print(attrs, model_fields)
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, new_bases, attrs
) )
add_cached_properties(new_model) add_cached_properties(new_model)

View File

@ -0,0 +1,152 @@
# type: ignore
import datetime
from typing import Optional
import databases
import pytest
import sqlalchemy as sa
from sqlalchemy import create_engine
import ormar
from ormar import ModelDefinitionError
from tests.settings import DATABASE_URL
metadata = sa.MetaData()
db = databases.Database(DATABASE_URL)
engine = create_engine(DATABASE_URL)
class AuditModel(ormar.Model):
class Meta:
abstract = True
created_by: str = ormar.String(max_length=100)
updated_by: str = ormar.String(max_length=100, default="Sam")
class DateFieldsModelNoSubclass(ormar.Model):
class Meta:
tablename = "test_date_models"
metadata = metadata
database = db
date_id: int = ormar.Integer(primary_key=True)
created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)
updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)
class DateFieldsModel(ormar.Model):
class Meta:
abstract = True
created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)
updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)
class Category(ormar.Model, DateFieldsModel, AuditModel):
class Meta(ormar.ModelMeta):
tablename = "categories"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=50, unique=True, index=True)
code: int = ormar.Integer()
class Subject(ormar.Model, DateFieldsModel):
class Meta(ormar.ModelMeta):
tablename = "subjects"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=50, unique=True, index=True)
category: Optional[Category] = ormar.ForeignKey(Category)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_field_redefining_raises_error():
with pytest.raises(ModelDefinitionError):
class WrongField(ormar.Model, DateFieldsModel): # pragma: no cover
class Meta(ormar.ModelMeta):
tablename = "wrongs"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
created_date: datetime.datetime = ormar.DateTime()
def test_model_subclassing_non_abstract_raises_error():
with pytest.raises(ModelDefinitionError):
class WrongField2(ormar.Model, DateFieldsModelNoSubclass): # pragma: no cover
class Meta(ormar.ModelMeta):
tablename = "wrongs"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
def round_date_to_seconds(
date: datetime.datetime,
) -> datetime.datetime: # pragma: no cover
if date.microsecond >= 500000:
date = date + datetime.timedelta(seconds=1)
return date.replace(microsecond=0)
@pytest.mark.asyncio
async def test_fields_inherited_from_mixin():
async with db:
async with db.transaction(force_rollback=True):
cat = await Category(
name="Foo", code=123, created_by="Sam", updated_by="Max"
).save()
sub = await Subject(name="Bar", category=cat).save()
mixin_columns = ["created_date", "updated_date"]
mixin2_columns = ["created_by", "updated_by"]
assert all(field in Category.Meta.model_fields for field in mixin_columns)
assert cat.created_date is not None
assert cat.updated_date is not None
assert all(field in Subject.Meta.model_fields for field in mixin_columns)
assert sub.created_date is not None
assert sub.updated_date is not None
assert all(field in Category.Meta.model_fields for field in mixin2_columns)
assert all(
field not in Subject.Meta.model_fields for field in mixin2_columns
)
inspector = sa.inspect(engine)
assert "categories" in inspector.get_table_names()
table_columns = [x.get("name") for x in inspector.get_columns("categories")]
assert all(col in table_columns for col in mixin_columns) # + mixin2_columns)
assert "subjects" in inspector.get_table_names()
table_columns = [x.get("name") for x in inspector.get_columns("subjects")]
assert all(col in table_columns for col in mixin_columns)
sub2 = (
await Subject.objects.select_related("category")
.order_by("-created_date")
.exclude_fields("updated_date")
.get()
)
assert round_date_to_seconds(sub2.created_date) == round_date_to_seconds(
sub.created_date
)
assert sub2.category.updated_date is not None
assert round_date_to_seconds(
sub2.category.created_date
) == round_date_to_seconds(cat.created_date)
assert sub2.updated_date is None
assert sub2.category.created_by == "Sam"
assert sub2.category.updated_by == cat.updated_by