From ace348e172ff5681bc5f228c77e4da25f85a8b2a Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 15:27:10 +0200 Subject: [PATCH] refactored reverse relation registration into the metaclass --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 37 +++++---------------- orm/models.py | 58 +++++++++++++++++++++++++-------- tests/test_same_table_joins.py | 25 +++++++------- 4 files changed, 65 insertions(+), 55 deletions(-) diff --git a/.coverage b/.coverage index 5d7b859a5b4294ba4300067c998729bd4a6b59ad..04d6ff895dc3fdaa51adcd2b47292b0028b57483 100644 GIT binary patch delta 166 zcmV;X09pTlpaX!Q1F$A93o$VuF*Q0eHaajclQA!Fld3OO2rM8cEiG_lVzd4)6Ho!F zvl4)94iLJ{n|I#4=RKeQ&o_9!&UcfKj?7i;|LyK?UqAQfy+42MZr^>X{PHmeI0gg( z333L0dn@1aEDKP-w*FL(wZ?;YD-`#KB-f!>q+wGsMz3s1iy8G|_ U_I)P>1OW+9ldX>NrBvl4)9 z4iIkh=AAe1dC$-P=Nmj<=R1>+j?7Z9|F^rpef`{@_x}93yM6bm^2^5@H3kF$32X*_ zy_IkI@%x|4v)Rr4n#2HnU;F6#zS%ape|Nuid%wNcZ?}K4_O`$7>F&Sx+xGz~1q1;J RR+F%gBLl4e2D8(TEkL^@RBQkM diff --git a/orm/fields.py b/orm/fields.py index f101346..a2dca94 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -51,7 +51,7 @@ class BaseField: @property def is_required(self) -> bool: return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key + not self.nullable and not self.has_default and not self.is_auto_primary_key ) @property @@ -204,12 +204,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -229,7 +229,7 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Union["Model", List["Model"]]: if isinstance(value, orm.models.Model) and not isinstance(value, self.to): @@ -261,7 +261,6 @@ class ForeignKey(BaseField): return model def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: - child_model_name = self.related_name or child.get_name() + "s" model._orm_relationship_manager.add_relation( model.__class__.__name__.lower(), child.__class__.__name__.lower(), @@ -269,23 +268,3 @@ class ForeignKey(BaseField): child, virtual=self.virtual, ) - - if ( - child_model_name not in model.__fields__ - and child.get_name() not in model.__fields__ - ): - self.register_reverse_model_fields(model, child, child_model_name) - - @staticmethod - def register_reverse_model_fields( - model: "Model", child: "Model", child_model_name: str - ) -> None: - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( - child.__class__, name=child_model_name, virtual=True - ) diff --git a/orm/models.py b/orm/models.py index d41b09c..f7a6e6a 100644 --- a/orm/models.py +++ b/orm/models.py @@ -6,6 +6,7 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from typing import Callable, Dict, Set import databases +from pydantic.fields import ModelField import orm.queryset as qry from orm.exceptions import ModelDefinitionError @@ -41,8 +42,35 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> ) +def expand_reverse_relationships(model: Type["Model"]): + for field_name, model_field in model.__model_fields__.items(): + if isinstance(model_field, ForeignKey): + child_model_name = model_field.related_name or model.__name__.lower() + 's' + parent_model = model_field.to + child = model + if ( + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ + ): + register_reverse_model_fields(parent_model, child, child_model_name) + + +def register_reverse_model_fields( + model: Type["Model"], child: Type["Model"], child_model_name: str +) -> None: + model.__fields__[child_model_name] = ModelField( + name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__, + ) + model.__model_fields__[child_model_name] = ForeignKey( + child, name=child_model_name, virtual=True + ) + + def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] @@ -100,14 +128,16 @@ class ModelMetaclass(type): attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) - attrs["__model_fields__"] = model_fields + attrs["__model_fields__"] = model_fields attrs["_orm_relationship_manager"] = relationship_manager new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) + expand_reverse_relationships(new_model) + return new_model @@ -168,9 +198,9 @@ class FakePydantic(list, metaclass=ModelMetaclass): item = getattr(self.values, key, None) if ( - item is not None - and self._is_conversion_to_json_needed(key) - and isinstance(item, str) + item is not None + and self._is_conversion_to_json_needed(key) + and isinstance(item, str) ): try: item = json.loads(item) @@ -186,7 +216,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): if self.__class__ != other.__class__: # pragma no cover return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk + self.values is not None and other.values is not None and self.pk == other.pk ) def __repr__(self) -> str: # pragma no cover @@ -242,7 +272,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): related_names = set() for name, field in cls.__fields__.items(): if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel + field.type_, pydantic.BaseModel ): related_names.add(name) return related_names @@ -274,7 +304,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): for field in one.__model_fields__.keys(): # print(field, one.dict(), other.dict()) if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), Model + getattr(one, field), Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), Model): @@ -296,10 +326,10 @@ class Model(FakePydantic): @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, ) -> "Model": item = {} @@ -357,8 +387,8 @@ class Model(FakePydantic): self_fields.pop(self.__pkname__) expr = ( self.__table__.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) + .values(**self_fields) + .where(self.pk_column == getattr(self, self.__pkname__)) ) result = await self.__database__.execute(expr) return result diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 66f6ad1..8c67f4b 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -1,3 +1,5 @@ +import asyncio + import databases import pytest import sqlalchemy @@ -59,16 +61,17 @@ class Teacher(orm.Model): category = orm.ForeignKey(Category, nullable=True) +@pytest.fixture(scope='module') +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + @pytest.fixture(autouse=True, scope="module") -def create_test_database(): +async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) metadata.create_all(engine) - yield - metadata.drop_all(engine) - - -@pytest.fixture() -async def init_relation(): department = await Department.objects.create(id=1, name="Math Department") class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") @@ -77,13 +80,11 @@ async def init_relation(): await Student.objects.create(name="Jack", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) yield - engine = sqlalchemy.create_engine(DATABASE_URL) metadata.drop_all(engine) - metadata.create_all(engine) @pytest.mark.asyncio -async def test_model_multiple_instances_of_same_table_in_schema(init_relation): +async def test_model_multiple_instances_of_same_table_in_schema(): async with database: classes = await SchoolClass.objects.select_related( ["teachers__category", "students"] @@ -102,7 +103,7 @@ async def test_model_multiple_instances_of_same_table_in_schema(init_relation): @pytest.mark.asyncio -async def test_right_tables_join(init_relation): +async def test_right_tables_join(): async with database: classes = await SchoolClass.objects.select_related( ["teachers__category", "students"] @@ -115,7 +116,7 @@ async def test_right_tables_join(init_relation): @pytest.mark.asyncio -async def test_multiple_reverse_related_objects(init_relation): +async def test_multiple_reverse_related_objects(): async with database: classes = await SchoolClass.objects.select_related( ["teachers__category", "students"]