working m2m and fk self relations with forwardref

This commit is contained in:
collerek
2021-01-12 14:38:22 +01:00
parent 8b67c83d0c
commit 4209d37364
10 changed files with 62 additions and 25 deletions

View File

@ -45,6 +45,7 @@ class BaseField(FieldInfo):
to: Type["Model"] to: Type["Model"]
through: Type["Model"] through: Type["Model"]
self_reference: bool = False self_reference: bool = False
self_reference_primary: Optional[str] = None
default: Any default: Any
server_default: Any server_default: Any
@ -277,6 +278,7 @@ class BaseField(FieldInfo):
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
): ):
cls.self_reference = True cls.self_reference = True
cls.self_reference_primary = cls.name
@classmethod @classmethod
def has_unresolved_forward_refs(cls) -> bool: def has_unresolved_forward_refs(cls) -> bool:

View File

@ -128,7 +128,7 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
:return: name of the field :return: name of the field
:rtype: str :rtype: str
""" """
prefix = "to_" if cls.self_reference else "" prefix = "from_" if cls.self_reference else ""
return f"{prefix}{cls.to.get_name()}" return f"{prefix}{cls.to.get_name()}"
@classmethod @classmethod
@ -138,7 +138,7 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
:return: name of the field :return: name of the field
:rtype: str :rtype: str
""" """
prefix = "from_" if cls.self_reference else "" prefix = "to_" if cls.self_reference else ""
return f"{prefix}{cls.owner.get_name()}" return f"{prefix}{cls.owner.get_name()}"
@classmethod @classmethod

View File

@ -110,6 +110,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
related_name=model_field.name, related_name=model_field.name,
owner=model_field.to, owner=model_field.to,
self_reference=model_field.self_reference, self_reference=model_field.self_reference,
self_reference_primary=model_field.self_reference_primary,
) )
# register foreign keys on through model # register foreign keys on through model
adjust_through_many_to_many_model(model_field=model_field) adjust_through_many_to_many_model(model_field=model_field)

View File

@ -586,7 +586,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model)
populate_meta_sqlalchemy_table_if_required(new_model.Meta) populate_meta_sqlalchemy_table_if_required(new_model.Meta)
expand_reverse_relationships(new_model) expand_reverse_relationships(new_model)
for field_name, field in new_model.Meta.model_fields.items(): for field in new_model.Meta.model_fields.values():
register_relation_in_alias_manager(field=field) register_relation_in_alias_manager(field=field)
if new_model.Meta.pkname not in attrs["__annotations__"]: if new_model.Meta.pkname not in attrs["__annotations__"]:

View File

@ -125,6 +125,12 @@ class Model(NewBaseModel):
) )
): ):
through_field = previous_model.Meta.model_fields[related_name] through_field = previous_model.Meta.model_fields[related_name]
if (
through_field.self_reference
and related_name == through_field.self_reference_primary
):
rel_name2 = through_field.default_source_field_name() # type: ignore
else:
rel_name2 = through_field.default_target_field_name() # type: ignore rel_name2 = through_field.default_target_field_name() # type: ignore
previous_model = through_field.through # type: ignore previous_model = through_field.through # type: ignore

View File

@ -436,6 +436,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
Populates Meta table of the Model which is left empty before. Populates Meta table of the Model which is left empty before.
Sets self_reference flag on models that links to themselves.
Calls the pydantic method to evaluate pydantic fields. Calls the pydantic method to evaluate pydantic fields.
:param localns: local namespace :param localns: local namespace
@ -446,7 +448,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
globalns = sys.modules[cls.__module__].__dict__.copy() globalns = sys.modules[cls.__module__].__dict__.copy()
globalns.setdefault(cls.__name__, cls) globalns.setdefault(cls.__name__, cls)
fields_to_check = cls.Meta.model_fields.copy() fields_to_check = cls.Meta.model_fields.copy()
for field_name, field in fields_to_check.items(): for field in fields_to_check.values():
if field.has_unresolved_forward_refs(): if field.has_unresolved_forward_refs():
field = cast(Type[ForeignKeyField], field) field = cast(Type[ForeignKeyField], field)
field.evaluate_forward_ref(globalns=globalns, localns=localns) field.evaluate_forward_ref(globalns=globalns, localns=localns)

View File

@ -135,7 +135,14 @@ class SqlJoin:
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
): ):
_fields = join_parameters.model_cls.Meta.model_fields _fields = join_parameters.model_cls.Meta.model_fields
new_part = _fields[part].default_target_field_name() target_field = _fields[part]
if (
target_field.self_reference
and part == target_field.self_reference_primary
):
new_part = target_field.default_source_field_name() # type: ignore
else:
new_part = target_field.default_target_field_name() # type: ignore
self._switch_many_to_many_order_columns(part, new_part) self._switch_many_to_many_order_columns(part, new_part)
if index > 0: # nested joins if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions( fields, exclude_fields = SqlJoin.update_inclusions(
@ -430,18 +437,25 @@ class SqlJoin:
:rtype: Tuple[str, str] :rtype: Tuple[str, str]
""" """
if is_multi: if is_multi:
to_field = join_params.prev_model.get_name() target_field = join_params.model_cls.Meta.model_fields[part]
to_key = model_cls.get_column_alias(to_field) if (
target_field.self_reference
and part == target_field.self_reference_primary
):
to_key = target_field.default_target_field_name() # type: ignore
else:
to_key = target_field.default_source_field_name() # type: ignore
from_key = join_params.prev_model.get_column_alias( from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname join_params.prev_model.Meta.pkname
) )
breakpoint()
elif join_params.prev_model.Meta.model_fields[part].virtual: elif join_params.prev_model.Meta.model_fields[part].virtual:
to_field = join_params.prev_model.Meta.model_fields[part].get_related_name() to_field = join_params.prev_model.Meta.model_fields[part].get_related_name()
to_key = model_cls.get_column_alias(to_field) to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias( from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname join_params.prev_model.Meta.pkname
) )
else: else:
to_key = model_cls.get_column_alias(model_cls.Meta.pkname) to_key = model_cls.get_column_alias(model_cls.Meta.pkname)
from_key = join_params.prev_model.get_column_alias(part) from_key = join_params.prev_model.get_column_alias(part)

View File

@ -197,7 +197,7 @@ class QuerySet:
limit_raw_sql=self.limit_sql_raw, limit_raw_sql=self.limit_sql_raw,
) )
exp = qry.build_select_expression() exp = qry.build_select_expression()
print("\n", exp.compile(compile_kwargs={"literal_binds": True})) # print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
return exp return exp
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003 def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003

View File

@ -107,8 +107,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
:type child: Model :type child: Model
""" """
model_cls = self.relation.through model_cls = self.relation.through
owner_column = self.related_field.default_target_field_name() owner_column = self.related_field.default_target_field_name() # type: ignore
child_column = self.related_field.default_source_field_name() child_column = self.related_field.default_source_field_name() # type: ignore
kwargs = {owner_column: self._owner.pk, child_column: child.pk} kwargs = {owner_column: self._owner.pk, child_column: child.pk}
if child.pk is None: if child.pk is None:
raise ModelPersistenceError( raise ModelPersistenceError(
@ -118,6 +118,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
) )
expr = model_cls.Meta.table.insert() expr = model_cls.Meta.table.insert()
expr = expr.values(**kwargs) expr = expr.values(**kwargs)
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
await model_cls.Meta.database.execute(expr) await model_cls.Meta.database.execute(expr)
async def delete_through_instance(self, child: "T") -> None: async def delete_through_instance(self, child: "T") -> None:
@ -128,8 +129,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
:type child: Model :type child: Model
""" """
queryset = ormar.QuerySet(model_cls=self.relation.through) queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.related_field.default_target_field_name() owner_column = self.related_field.default_target_field_name() # type: ignore
child_column = self.related_field.default_source_field_name() child_column = self.related_field.default_source_field_name() # type: ignore
kwargs = {owner_column: self._owner, child_column: child} kwargs = {owner_column: self._owner, child_column: child}
link_instance = await queryset.filter(**kwargs).get() # type: ignore link_instance = await queryset.filter(**kwargs).get() # type: ignore
await link_instance.delete() await link_instance.delete()

View File

@ -72,6 +72,16 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
@pytest.fixture(scope="function")
async def cleanup():
yield
async with db:
await ChildFriend.objects.delete(each=True)
await Child.objects.delete(each=True)
await Game.objects.delete(each=True)
await Person.objects.delete(each=True)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_not_uprated_model_raises_errors(): async def test_not_uprated_model_raises_errors():
Person2 = ForwardRef("Person2") Person2 = ForwardRef("Person2")
@ -126,7 +136,7 @@ async def test_self_relation():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_other_forwardref_relation(): async def test_other_forwardref_relation(cleanup):
checkers = await Game.objects.create(name="checkers") checkers = await Game.objects.create(name="checkers")
uno = await Game(name="Uno").save() uno = await Game(name="Uno").save()
@ -147,7 +157,7 @@ async def test_other_forwardref_relation():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_m2m_self_forwardref_relation(): async def test_m2m_self_forwardref_relation(cleanup):
checkers = await Game.objects.create(name="checkers") checkers = await Game.objects.create(name="checkers")
uno = await Game(name="Uno").save() uno = await Game(name="Uno").save()
jenga = await Game(name="Jenga").save() jenga = await Game(name="Jenga").save()
@ -165,16 +175,17 @@ async def test_m2m_self_forwardref_relation():
await billy.friends.add(kate) await billy.friends.add(kate)
await billy.friends.add(steve) await billy.friends.add(steve)
await steve.friends.add(kate) # await steve.friends.add(kate)
await steve.friends.add(billy) # await steve.friends.add(billy)
billy_check = await Child.objects.select_related( billy_check = await Child.objects.select_related(
[ ["friends", "favourite_game", "least_favourite_game",]
"friends",
"favourite_game",
"least_favourite_game",
"friends__favourite_game",
"friends__least_favourite_game",
]
).get(name="Billy") ).get(name="Billy")
assert len(billy_check.friends) == 2 assert len(billy_check.friends) == 2
assert billy_check.friends[0].name == "Kate"
assert billy_check.friends[1].name == "Steve"
assert billy_check.favourite_game.name == "Uno"
kate_check = await Child.objects.select_related(["also_friends",]).get(name="Kate")
assert len(kate_check.also_friends) == 1
assert kate_check.also_friends[0].name == "Billy"