add mixin support with fields definitions

collerek
2020-12-09 14:27:10 +01:00
parent 3b4dc59e5a
commit 53e0fa8e65
5 changed files with 314 additions and 15 deletions
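
In short: plain classes that declare only ormar fields (no Meta, no ormar.Model base) can now be mixed into Models, and their fields end up on the child Model and its table. A minimal usage sketch, assuming a local sqlite URL and mirroring the test file added in this commit:

import datetime

import databases
import ormar
import sqlalchemy

metadata = sqlalchemy.MetaData()
db = databases.Database("sqlite:///test.db")  # hypothetical database URL


class DateFieldsMixins:
    # a plain mixin: only ormar field definitions, no Meta and no ormar.Model base
    created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)
    updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)


class Subject(ormar.Model, DateFieldsMixins):
    class Meta(ormar.ModelMeta):
        tablename = "subjects"
        metadata = metadata
        database = db

    id: int = ormar.Integer(primary_key=True)
    name: str = ormar.String(max_length=50)


# created_date and updated_date are inherited from the mixin and become
# regular columns on the "subjects" table
assert "created_date" in Subject.Meta.model_fields
assert "updated_date" in Subject.Meta.model_fields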

View File

@@ -6,7 +6,7 @@ import databases
import pydantic
import sqlalchemy
from pydantic import BaseConfig
-from pydantic.fields import ModelField
+from pydantic.fields import FieldInfo, ModelField
from pydantic.utils import lenient_issubclass
from sqlalchemy.sql.schema import ColumnCollectionConstraint
@@ -24,6 +24,7 @@ if TYPE_CHECKING:  # pragma no cover
    from ormar import Model

alias_manager = AliasManager()
PARSED_FIELDS_KEY = "__parsed_fields__"


class ModelMeta:
@@ -221,9 +222,7 @@ def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
            DeprecationWarning,
        )
-    potential_fields.update(
-        {k: v for k, v in attrs.items() if lenient_issubclass(v, BaseField)}
-    )
+    potential_fields.update(get_potential_fields(attrs))
    for field_name, field in potential_fields.items():
        field.name = field_name
        attrs = populate_default_pydantic_field_value(field, field_name, attrs)
@@ -331,6 +330,7 @@ def populate_default_options_values(


def add_cached_properties(new_model: Type["Model"]) -> None:
    new_model._quick_access_fields = quick_access_set
    new_model._related_names = None
    new_model._related_fields = None
    new_model._pydantic_fields = {name for name in new_model.__fields__}
@@ -362,6 +362,161 @@ def register_signals(new_model: Type["Model"]) -> None:  # noqa: CCR001
    new_model.Meta.signals = signals


def get_potential_fields(attrs: Dict) -> Dict:
    """
    Gets all the fields in current class namespace that are ormar Fields.

    :param attrs: current class namespace
    :type attrs: Dict
    :return: extracted fields that are ormar Fields
    :rtype: Dict
    """
    return {k: v for k, v in attrs.items() if lenient_issubclass(v, BaseField)}


def check_conflicting_fields(
    new_fields: Set, attrs: Dict, base_class: type, curr_class: type
) -> None:
    """
    You cannot redefine fields with the same names in inherited classes.
    Ormar will raise an exception if it encounters a field that is an ormar
    Field and at the same time was already declared in one of the base classes.

    :param new_fields: set of names of fields defined in current model
    :type new_fields: Set[str]
    :param attrs: namespace of current class
    :type attrs: Dict
    :param base_class: one of the parent classes
    :type base_class: Model or model parent class
    :param curr_class: current constructed class
    :type curr_class: Model or model parent class
    """
    previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)})
    overwrite = new_fields.intersection(previous_fields)
    if overwrite:
        raise ModelDefinitionError(
            f"Model {curr_class} redefines the fields: "
            f"{overwrite} already defined in {base_class}!"
        )


def update_attrs_and_fields(
    attrs: Dict,
    new_attrs: Dict,
    model_fields: Dict,
    new_model_fields: Dict,
    new_fields: Set,
) -> None:
    """
    Updates __annotations__, values of model fields (so pydantic FieldInfos)
    as well as model.Meta.model_fields definitions from parents.

    :param attrs: new namespace for the class being constructed
    :type attrs: Dict
    :param new_attrs: part of the namespace extracted from the parent class
    :type new_attrs: Dict
    :param model_fields: ormar fields defined in the current class
    :type model_fields: Dict[str, BaseField]
    :param new_model_fields: ormar fields defined in parent classes
    :type new_model_fields: Dict[str, BaseField]
    :param new_fields: set of new field names
    :type new_fields: Set[str]
    """
    key = "__annotations__"
    attrs[key].update(new_attrs[key])
    attrs.update({name: new_attrs[name] for name in new_fields})
    model_fields.update(new_model_fields)


def extract_mixin_fields_from_dict(
    base_class: type,
    curr_class: type,
    attrs: Dict,
    model_fields: Dict[
        str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
    ],
) -> Tuple[Dict, Dict]:
    """
    Extracts fields from base classes if they have valid ormar fields.

    If the mixin was already parsed, the field definitions have to be removed
    from the class, because pydantic complains about field re-definition - so
    after the first child they are extracted from __parsed_fields__ and not
    from the class itself.

    If the class is parsed for the first time, annotations and field definitions
    are read from class.__dict__.

    If the class is an ormar.Model it is skipped.

    :param base_class: one of the parent classes
    :type base_class: Model or model parent class
    :param curr_class: current constructed class
    :type curr_class: Model or model parent class
    :param attrs: new namespace for the class being constructed
    :type attrs: Dict
    :param model_fields: ormar fields defined in the current class
    :type model_fields: Dict[str, BaseField]
    :return: updated attrs and model_fields
    :rtype: Tuple[Dict, Dict]
    """
    if hasattr(base_class, "Meta"):
        # not a field mixin but an actual parent ormar Model - skip it
        return attrs, model_fields
    key = "__annotations__"
    if hasattr(base_class, PARSED_FIELDS_KEY):
        # the mixin was already parsed - its field definitions were removed from
        # the class (pydantic complains about field re-definition), so after the
        # first child we extract them from __parsed_fields__, not the class itself
        new_attrs, new_model_fields = getattr(base_class, PARSED_FIELDS_KEY)
        new_fields = set(new_model_fields.keys())
        check_conflicting_fields(
            new_fields=new_fields,
            attrs=attrs,
            base_class=base_class,
            curr_class=curr_class,
        )
        update_attrs_and_fields(
            attrs=attrs,
            new_attrs=new_attrs,
            model_fields=model_fields,
            new_model_fields=new_model_fields,
            new_fields=new_fields,
        )
        return attrs, model_fields
    potential_fields = get_potential_fields(base_class.__dict__)
    if potential_fields:
        # the parent class has ormar fields defined and was not parsed before
        new_attrs = {key: base_class.__dict__.get(key, {})}
        new_attrs.update(potential_fields)
        new_fields = set(potential_fields.keys())
        check_conflicting_fields(
            new_fields=new_fields,
            attrs=attrs,
            base_class=base_class,
            curr_class=curr_class,
        )
        for name in new_fields:
            delattr(base_class, name)
        new_attrs, new_model_fields = extract_annotations_and_default_vals(new_attrs)
        setattr(base_class, PARSED_FIELDS_KEY, (new_attrs, new_model_fields))
        update_attrs_and_fields(
            attrs=attrs,
            new_attrs=new_attrs,
            model_fields=model_fields,
            new_model_fields=new_model_fields,
            new_fields=new_fields,
        )
    return attrs, model_fields


class ModelMetaclass(pydantic.main.ModelMetaclass):
    def __new__(  # type: ignore
        mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
@@ -369,9 +524,15 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
        attrs["Config"] = get_pydantic_base_orm_config()
        attrs["__name__"] = name
        attrs, model_fields = extract_annotations_and_default_vals(attrs)
        for base in reversed(bases):
            attrs, model_fields = extract_mixin_fields_from_dict(
                base_class=base, curr_class=mcs, attrs=attrs, model_fields=model_fields
            )
        # print(attrs, model_fields)
        new_model = super().__new__(  # type: ignore
            mcs, name, bases, attrs
        )

        add_cached_properties(new_model)
        if hasattr(new_model, "Meta"):
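
A side effect worth noting about the __parsed_fields__ cache above: once the first child Model is built, the field definitions are physically removed from the mixin class and kept only in the cache, from which every later child reads them. A small sketch of that behaviour (class names hypothetical, setup assumed rather than taken from the repo):

import databases
import ormar
import sqlalchemy

metadata = sqlalchemy.MetaData()
db = databases.Database("sqlite:///test.db")  # hypothetical database URL


class NotesMixin:
    note: str = ormar.String(max_length=100)


class First(ormar.Model, NotesMixin):
    class Meta(ormar.ModelMeta):
        tablename = "firsts"
        metadata = metadata
        database = db

    id: int = ormar.Integer(primary_key=True)


# after the first child is constructed the mixin no longer carries the field
# attribute itself - it was moved into the __parsed_fields__ cache, so pydantic
# does not see a re-definition when the next child is created
assert not hasattr(NotesMixin, "note")
assert hasattr(NotesMixin, "__parsed_fields__")


class Second(ormar.Model, NotesMixin):
    class Meta(ormar.ModelMeta):
        tablename = "seconds"
        metadata = metadata
        database = db

    id: int = ormar.Integer(primary_key=True)


assert "note" in Second.Meta.model_fields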

View File

@@ -41,7 +41,7 @@ class ModelTableProxy:
    if TYPE_CHECKING:  # pragma no cover
        Meta: ModelMeta
        _related_names: Optional[Set]
-        _related_names_hash: Union[str, bytes]
+        _related_fields: Optional[List]
        pk: Any
        get_name: Callable
        _props: Set
@@ -202,6 +202,19 @@ class ModelTableProxy:
                return field_name
        return alias  # if not found it's not an alias but actual name

    @classmethod
    def extract_related_fields(cls) -> List:
        # relation fields are computed once and cached on the class
        if isinstance(cls._related_fields, List):
            return cls._related_fields
        related_fields = []
        for name in cls.extract_related_names():
            related_fields.append(cls.Meta.model_fields[name])
        cls._related_fields = related_fields
        return related_fields

    @classmethod
    def extract_related_names(cls) -> Set:

View File

@@ -28,7 +28,6 @@ from pydantic import BaseModel
import ormar  # noqa I100
from ormar.exceptions import ModelError
from ormar.fields import BaseField
-from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.excludable import Excludable
from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy
@@ -79,14 +78,7 @@ class NewBaseModel(
        object.__setattr__(
            self,
            "_orm",
-            RelationsManager(
-                related_fields=[
-                    field
-                    for name, field in self.Meta.model_fields.items()
-                    if issubclass(field, ForeignKeyField)
-                ],
-                owner=self,
-            ),
+            RelationsManager(related_fields=self.extract_related_fields(), owner=self,),
        )

        pk_only = kwargs.pop("__pk_only__", False)

View File

@@ -0,0 +1,133 @@
# type: ignore
import datetime
from typing import Optional

import databases
import pytest
import sqlalchemy as sa
from sqlalchemy import create_engine

import ormar
from ormar import ModelDefinitionError
from tests.settings import DATABASE_URL

metadata = sa.MetaData()
db = databases.Database(DATABASE_URL)
engine = create_engine(DATABASE_URL)


class AuditMixin:
    created_by: str = ormar.String(max_length=100)
    updated_by: str = ormar.String(max_length=100)


class DateFieldsMixins:
    created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)
    updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)


class Category(ormar.Model, DateFieldsMixins, AuditMixin):
    class Meta(ormar.ModelMeta):
        tablename = "categories"
        metadata = metadata
        database = db

    id: int = ormar.Integer(primary_key=True)
    name: str = ormar.String(max_length=50, unique=True, index=True)
    code: int = ormar.Integer()


class Subject(ormar.Model, DateFieldsMixins):
    class Meta(ormar.ModelMeta):
        tablename = "subjects"
        metadata = metadata
        database = db

    id: int = ormar.Integer(primary_key=True)
    name: str = ormar.String(max_length=50, unique=True, index=True)
    category: Optional[Category] = ormar.ForeignKey(Category)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_field_redefining_raises_error():
with pytest.raises(ModelDefinitionError):
class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover
class Meta(ormar.ModelMeta):
tablename = "wrongs"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
created_date: datetime.datetime = ormar.DateTime()
def test_field_redefining_in_second_raises_error():
class OkField(ormar.Model, DateFieldsMixins): # pragma: no cover
class Meta(ormar.ModelMeta):
tablename = "oks"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
with pytest.raises(ModelDefinitionError):
class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover
class Meta(ormar.ModelMeta):
tablename = "wrongs"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
created_date: datetime.datetime = ormar.DateTime()


@pytest.mark.asyncio
async def test_fields_inherited_from_mixin():
    async with db:
        async with db.transaction(force_rollback=True):
            cat = await Category(
                name="Foo", code=123, created_by="Sam", updated_by="Max"
            ).save()
            sub = await Subject(name="Bar", category=cat).save()
            mixin_columns = ["created_date", "updated_date"]
            mixin2_columns = ["created_by", "updated_by"]
            assert all(field in Category.Meta.model_fields for field in mixin_columns)
            assert cat.created_date is not None
            assert cat.updated_date is not None
            assert all(field in Subject.Meta.model_fields for field in mixin_columns)
            assert sub.created_date is not None
            assert sub.updated_date is not None

            assert all(field in Category.Meta.model_fields for field in mixin2_columns)
            assert all(
                field not in Subject.Meta.model_fields for field in mixin2_columns
            )

            inspector = sa.inspect(engine)
            assert "categories" in inspector.get_table_names()
            table_columns = [x.get("name") for x in inspector.get_columns("categories")]
            assert all(col in table_columns for col in mixin_columns + mixin2_columns)

            assert "subjects" in inspector.get_table_names()
            table_columns = [x.get("name") for x in inspector.get_columns("subjects")]
            assert all(col in table_columns for col in mixin_columns)

            sub2 = (
                await Subject.objects.select_related("category")
                .order_by("-created_date")
                .exclude_fields("updated_date")
                .get()
            )
            assert sub2.created_date == sub.created_date
            assert sub2.category.updated_date is not None
            assert sub2.category.created_date == cat.created_date
            assert sub2.updated_date is None
            assert sub2.category.created_by == "Sam"

View File

@@ -55,7 +55,7 @@ async def test_model_relationship():
        assert ws.topic == "Topic 1"
        assert ws.category.name == "Foo"
-        ws.topic = 'Topic 2'
+        ws.topic = "Topic 2"
        await ws.update()
        assert ws.id == 1