From 4209d37364684c72e1ddf35c43b0d1814588f1c4 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 12 Jan 2021 14:38:22 +0100 Subject: [PATCH] working m2m and fk self relations with forwardref --- ormar/fields/base.py | 2 ++ ormar/fields/many_to_many.py | 4 ++-- ormar/models/helpers/relations.py | 1 + ormar/models/metaclass.py | 2 +- ormar/models/model.py | 8 +++++++- ormar/models/newbasemodel.py | 4 +++- ormar/queryset/join.py | 22 +++++++++++++++++---- ormar/queryset/queryset.py | 2 +- ormar/relations/querysetproxy.py | 9 +++++---- tests/test_forward_refs.py | 33 ++++++++++++++++++++----------- 10 files changed, 62 insertions(+), 25 deletions(-) diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 966992e..4f05aa1 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -45,6 +45,7 @@ class BaseField(FieldInfo): to: Type["Model"] through: Type["Model"] self_reference: bool = False + self_reference_primary: Optional[str] = None default: Any server_default: Any @@ -277,6 +278,7 @@ class BaseField(FieldInfo): cls.owner == cls.to or cls.owner.Meta == cls.to.Meta ): cls.self_reference = True + cls.self_reference_primary = cls.name @classmethod def has_unresolved_forward_refs(cls) -> bool: diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 29923d0..f789284 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -128,7 +128,7 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro :return: name of the field :rtype: str """ - prefix = "to_" if cls.self_reference else "" + prefix = "from_" if cls.self_reference else "" return f"{prefix}{cls.to.get_name()}" @classmethod @@ -138,7 +138,7 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro :return: name of the field :rtype: str """ - prefix = "from_" if cls.self_reference else "" + prefix = "to_" if cls.self_reference else "" return f"{prefix}{cls.owner.get_name()}" @classmethod diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index 70a558c..4cf19ea 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -110,6 +110,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None: related_name=model_field.name, owner=model_field.to, self_reference=model_field.self_reference, + self_reference_primary=model_field.self_reference_primary, ) # register foreign keys on through model adjust_through_many_to_many_model(model_field=model_field) diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index d6dc1c5..20a00d6 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -586,7 +586,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): new_model = populate_meta_tablename_columns_and_pk(name, new_model) populate_meta_sqlalchemy_table_if_required(new_model.Meta) expand_reverse_relationships(new_model) - for field_name, field in new_model.Meta.model_fields.items(): + for field in new_model.Meta.model_fields.values(): register_relation_in_alias_manager(field=field) if new_model.Meta.pkname not in attrs["__annotations__"]: diff --git a/ormar/models/model.py b/ormar/models/model.py index 63388cb..35eaab1 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -125,7 +125,13 @@ class Model(NewBaseModel): ) ): through_field = previous_model.Meta.model_fields[related_name] - rel_name2 = through_field.default_target_field_name() # type: ignore + if ( + through_field.self_reference + and related_name == through_field.self_reference_primary + ): + rel_name2 = through_field.default_source_field_name() # type: ignore + else: + rel_name2 = through_field.default_target_field_name() # type: ignore previous_model = through_field.through # type: ignore if previous_model and rel_name2: diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index c062d7a..ed69e46 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -436,6 +436,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass Populates Meta table of the Model which is left empty before. + Sets self_reference flag on models that links to themselves. + Calls the pydantic method to evaluate pydantic fields. :param localns: local namespace @@ -446,7 +448,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass globalns = sys.modules[cls.__module__].__dict__.copy() globalns.setdefault(cls.__name__, cls) fields_to_check = cls.Meta.model_fields.copy() - for field_name, field in fields_to_check.items(): + for field in fields_to_check.values(): if field.has_unresolved_forward_refs(): field = cast(Type[ForeignKeyField], field) field.evaluate_forward_ref(globalns=globalns, localns=localns) diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 0991562..2b2205f 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -135,7 +135,14 @@ class SqlJoin: join_parameters.model_cls.Meta.model_fields[part], ManyToManyField ): _fields = join_parameters.model_cls.Meta.model_fields - new_part = _fields[part].default_target_field_name() + target_field = _fields[part] + if ( + target_field.self_reference + and part == target_field.self_reference_primary + ): + new_part = target_field.default_source_field_name() # type: ignore + else: + new_part = target_field.default_target_field_name() # type: ignore self._switch_many_to_many_order_columns(part, new_part) if index > 0: # nested joins fields, exclude_fields = SqlJoin.update_inclusions( @@ -430,18 +437,25 @@ class SqlJoin: :rtype: Tuple[str, str] """ if is_multi: - to_field = join_params.prev_model.get_name() - to_key = model_cls.get_column_alias(to_field) + target_field = join_params.model_cls.Meta.model_fields[part] + if ( + target_field.self_reference + and part == target_field.self_reference_primary + ): + to_key = target_field.default_target_field_name() # type: ignore + else: + to_key = target_field.default_source_field_name() # type: ignore from_key = join_params.prev_model.get_column_alias( join_params.prev_model.Meta.pkname ) - breakpoint() + elif join_params.prev_model.Meta.model_fields[part].virtual: to_field = join_params.prev_model.Meta.model_fields[part].get_related_name() to_key = model_cls.get_column_alias(to_field) from_key = join_params.prev_model.get_column_alias( join_params.prev_model.Meta.pkname ) + else: to_key = model_cls.get_column_alias(model_cls.Meta.pkname) from_key = join_params.prev_model.get_column_alias(part) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 11387b5..4940265 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -197,7 +197,7 @@ class QuerySet: limit_raw_sql=self.limit_sql_raw, ) exp = qry.build_select_expression() - print("\n", exp.compile(compile_kwargs={"literal_binds": True})) + # print("\n", exp.compile(compile_kwargs={"literal_binds": True})) return exp def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003 diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 7003c7a..0740a17 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -107,8 +107,8 @@ class QuerysetProxy(ormar.QuerySetProtocol): :type child: Model """ model_cls = self.relation.through - owner_column = self.related_field.default_target_field_name() - child_column = self.related_field.default_source_field_name() + owner_column = self.related_field.default_target_field_name() # type: ignore + child_column = self.related_field.default_source_field_name() # type: ignore kwargs = {owner_column: self._owner.pk, child_column: child.pk} if child.pk is None: raise ModelPersistenceError( @@ -118,6 +118,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): ) expr = model_cls.Meta.table.insert() expr = expr.values(**kwargs) + # print("\n", expr.compile(compile_kwargs={"literal_binds": True})) await model_cls.Meta.database.execute(expr) async def delete_through_instance(self, child: "T") -> None: @@ -128,8 +129,8 @@ class QuerysetProxy(ormar.QuerySetProtocol): :type child: Model """ queryset = ormar.QuerySet(model_cls=self.relation.through) - owner_column = self.related_field.default_target_field_name() - child_column = self.related_field.default_source_field_name() + owner_column = self.related_field.default_target_field_name() # type: ignore + child_column = self.related_field.default_source_field_name() # type: ignore kwargs = {owner_column: self._owner, child_column: child} link_instance = await queryset.filter(**kwargs).get() # type: ignore await link_instance.delete() diff --git a/tests/test_forward_refs.py b/tests/test_forward_refs.py index e0e5303..963a623 100644 --- a/tests/test_forward_refs.py +++ b/tests/test_forward_refs.py @@ -72,6 +72,16 @@ def create_test_database(): metadata.drop_all(engine) +@pytest.fixture(scope="function") +async def cleanup(): + yield + async with db: + await ChildFriend.objects.delete(each=True) + await Child.objects.delete(each=True) + await Game.objects.delete(each=True) + await Person.objects.delete(each=True) + + @pytest.mark.asyncio async def test_not_uprated_model_raises_errors(): Person2 = ForwardRef("Person2") @@ -126,7 +136,7 @@ async def test_self_relation(): @pytest.mark.asyncio -async def test_other_forwardref_relation(): +async def test_other_forwardref_relation(cleanup): checkers = await Game.objects.create(name="checkers") uno = await Game(name="Uno").save() @@ -147,7 +157,7 @@ async def test_other_forwardref_relation(): @pytest.mark.asyncio -async def test_m2m_self_forwardref_relation(): +async def test_m2m_self_forwardref_relation(cleanup): checkers = await Game.objects.create(name="checkers") uno = await Game(name="Uno").save() jenga = await Game(name="Jenga").save() @@ -165,16 +175,17 @@ async def test_m2m_self_forwardref_relation(): await billy.friends.add(kate) await billy.friends.add(steve) - await steve.friends.add(kate) - await steve.friends.add(billy) + # await steve.friends.add(kate) + # await steve.friends.add(billy) billy_check = await Child.objects.select_related( - [ - "friends", - "favourite_game", - "least_favourite_game", - "friends__favourite_game", - "friends__least_favourite_game", - ] + ["friends", "favourite_game", "least_favourite_game",] ).get(name="Billy") assert len(billy_check.friends) == 2 + assert billy_check.friends[0].name == "Kate" + assert billy_check.friends[1].name == "Steve" + assert billy_check.favourite_game.name == "Uno" + + kate_check = await Child.objects.select_related(["also_friends",]).get(name="Kate") + assert len(kate_check.also_friends) == 1 + assert kate_check.also_friends[0].name == "Billy"