working m2m and fk self relations with forwardref
This commit is contained in:
@ -45,6 +45,7 @@ class BaseField(FieldInfo):
|
||||
to: Type["Model"]
|
||||
through: Type["Model"]
|
||||
self_reference: bool = False
|
||||
self_reference_primary: Optional[str] = None
|
||||
|
||||
default: Any
|
||||
server_default: Any
|
||||
@ -277,6 +278,7 @@ class BaseField(FieldInfo):
|
||||
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
|
||||
):
|
||||
cls.self_reference = True
|
||||
cls.self_reference_primary = cls.name
|
||||
|
||||
@classmethod
|
||||
def has_unresolved_forward_refs(cls) -> bool:
|
||||
|
||||
@ -128,7 +128,7 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
|
||||
:return: name of the field
|
||||
:rtype: str
|
||||
"""
|
||||
prefix = "to_" if cls.self_reference else ""
|
||||
prefix = "from_" if cls.self_reference else ""
|
||||
return f"{prefix}{cls.to.get_name()}"
|
||||
|
||||
@classmethod
|
||||
@ -138,7 +138,7 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
|
||||
:return: name of the field
|
||||
:rtype: str
|
||||
"""
|
||||
prefix = "from_" if cls.self_reference else ""
|
||||
prefix = "to_" if cls.self_reference else ""
|
||||
return f"{prefix}{cls.owner.get_name()}"
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -110,6 +110,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
|
||||
related_name=model_field.name,
|
||||
owner=model_field.to,
|
||||
self_reference=model_field.self_reference,
|
||||
self_reference_primary=model_field.self_reference_primary,
|
||||
)
|
||||
# register foreign keys on through model
|
||||
adjust_through_many_to_many_model(model_field=model_field)
|
||||
|
||||
@ -586,7 +586,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
||||
populate_meta_sqlalchemy_table_if_required(new_model.Meta)
|
||||
expand_reverse_relationships(new_model)
|
||||
for field_name, field in new_model.Meta.model_fields.items():
|
||||
for field in new_model.Meta.model_fields.values():
|
||||
register_relation_in_alias_manager(field=field)
|
||||
|
||||
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
||||
|
||||
@ -125,7 +125,13 @@ class Model(NewBaseModel):
|
||||
)
|
||||
):
|
||||
through_field = previous_model.Meta.model_fields[related_name]
|
||||
rel_name2 = through_field.default_target_field_name() # type: ignore
|
||||
if (
|
||||
through_field.self_reference
|
||||
and related_name == through_field.self_reference_primary
|
||||
):
|
||||
rel_name2 = through_field.default_source_field_name() # type: ignore
|
||||
else:
|
||||
rel_name2 = through_field.default_target_field_name() # type: ignore
|
||||
previous_model = through_field.through # type: ignore
|
||||
|
||||
if previous_model and rel_name2:
|
||||
|
||||
@ -436,6 +436,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
|
||||
Populates Meta table of the Model which is left empty before.
|
||||
|
||||
Sets self_reference flag on models that links to themselves.
|
||||
|
||||
Calls the pydantic method to evaluate pydantic fields.
|
||||
|
||||
:param localns: local namespace
|
||||
@ -446,7 +448,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
globalns = sys.modules[cls.__module__].__dict__.copy()
|
||||
globalns.setdefault(cls.__name__, cls)
|
||||
fields_to_check = cls.Meta.model_fields.copy()
|
||||
for field_name, field in fields_to_check.items():
|
||||
for field in fields_to_check.values():
|
||||
if field.has_unresolved_forward_refs():
|
||||
field = cast(Type[ForeignKeyField], field)
|
||||
field.evaluate_forward_ref(globalns=globalns, localns=localns)
|
||||
|
||||
@ -135,7 +135,14 @@ class SqlJoin:
|
||||
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
||||
):
|
||||
_fields = join_parameters.model_cls.Meta.model_fields
|
||||
new_part = _fields[part].default_target_field_name()
|
||||
target_field = _fields[part]
|
||||
if (
|
||||
target_field.self_reference
|
||||
and part == target_field.self_reference_primary
|
||||
):
|
||||
new_part = target_field.default_source_field_name() # type: ignore
|
||||
else:
|
||||
new_part = target_field.default_target_field_name() # type: ignore
|
||||
self._switch_many_to_many_order_columns(part, new_part)
|
||||
if index > 0: # nested joins
|
||||
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||
@ -430,18 +437,25 @@ class SqlJoin:
|
||||
:rtype: Tuple[str, str]
|
||||
"""
|
||||
if is_multi:
|
||||
to_field = join_params.prev_model.get_name()
|
||||
to_key = model_cls.get_column_alias(to_field)
|
||||
target_field = join_params.model_cls.Meta.model_fields[part]
|
||||
if (
|
||||
target_field.self_reference
|
||||
and part == target_field.self_reference_primary
|
||||
):
|
||||
to_key = target_field.default_target_field_name() # type: ignore
|
||||
else:
|
||||
to_key = target_field.default_source_field_name() # type: ignore
|
||||
from_key = join_params.prev_model.get_column_alias(
|
||||
join_params.prev_model.Meta.pkname
|
||||
)
|
||||
breakpoint()
|
||||
|
||||
elif join_params.prev_model.Meta.model_fields[part].virtual:
|
||||
to_field = join_params.prev_model.Meta.model_fields[part].get_related_name()
|
||||
to_key = model_cls.get_column_alias(to_field)
|
||||
from_key = join_params.prev_model.get_column_alias(
|
||||
join_params.prev_model.Meta.pkname
|
||||
)
|
||||
|
||||
else:
|
||||
to_key = model_cls.get_column_alias(model_cls.Meta.pkname)
|
||||
from_key = join_params.prev_model.get_column_alias(part)
|
||||
|
||||
@ -197,7 +197,7 @@ class QuerySet:
|
||||
limit_raw_sql=self.limit_sql_raw,
|
||||
)
|
||||
exp = qry.build_select_expression()
|
||||
print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
|
||||
# print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
|
||||
return exp
|
||||
|
||||
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||
|
||||
@ -107,8 +107,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
:type child: Model
|
||||
"""
|
||||
model_cls = self.relation.through
|
||||
owner_column = self.related_field.default_target_field_name()
|
||||
child_column = self.related_field.default_source_field_name()
|
||||
owner_column = self.related_field.default_target_field_name() # type: ignore
|
||||
child_column = self.related_field.default_source_field_name() # type: ignore
|
||||
kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
||||
if child.pk is None:
|
||||
raise ModelPersistenceError(
|
||||
@ -118,6 +118,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
)
|
||||
expr = model_cls.Meta.table.insert()
|
||||
expr = expr.values(**kwargs)
|
||||
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
|
||||
await model_cls.Meta.database.execute(expr)
|
||||
|
||||
async def delete_through_instance(self, child: "T") -> None:
|
||||
@ -128,8 +129,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
:type child: Model
|
||||
"""
|
||||
queryset = ormar.QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self.related_field.default_target_field_name()
|
||||
child_column = self.related_field.default_source_field_name()
|
||||
owner_column = self.related_field.default_target_field_name() # type: ignore
|
||||
child_column = self.related_field.default_source_field_name() # type: ignore
|
||||
kwargs = {owner_column: self._owner, child_column: child}
|
||||
link_instance = await queryset.filter(**kwargs).get() # type: ignore
|
||||
await link_instance.delete()
|
||||
|
||||
@ -72,6 +72,16 @@ def create_test_database():
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def cleanup():
|
||||
yield
|
||||
async with db:
|
||||
await ChildFriend.objects.delete(each=True)
|
||||
await Child.objects.delete(each=True)
|
||||
await Game.objects.delete(each=True)
|
||||
await Person.objects.delete(each=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_uprated_model_raises_errors():
|
||||
Person2 = ForwardRef("Person2")
|
||||
@ -126,7 +136,7 @@ async def test_self_relation():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_other_forwardref_relation():
|
||||
async def test_other_forwardref_relation(cleanup):
|
||||
checkers = await Game.objects.create(name="checkers")
|
||||
uno = await Game(name="Uno").save()
|
||||
|
||||
@ -147,7 +157,7 @@ async def test_other_forwardref_relation():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_m2m_self_forwardref_relation():
|
||||
async def test_m2m_self_forwardref_relation(cleanup):
|
||||
checkers = await Game.objects.create(name="checkers")
|
||||
uno = await Game(name="Uno").save()
|
||||
jenga = await Game(name="Jenga").save()
|
||||
@ -165,16 +175,17 @@ async def test_m2m_self_forwardref_relation():
|
||||
await billy.friends.add(kate)
|
||||
await billy.friends.add(steve)
|
||||
|
||||
await steve.friends.add(kate)
|
||||
await steve.friends.add(billy)
|
||||
# await steve.friends.add(kate)
|
||||
# await steve.friends.add(billy)
|
||||
|
||||
billy_check = await Child.objects.select_related(
|
||||
[
|
||||
"friends",
|
||||
"favourite_game",
|
||||
"least_favourite_game",
|
||||
"friends__favourite_game",
|
||||
"friends__least_favourite_game",
|
||||
]
|
||||
["friends", "favourite_game", "least_favourite_game",]
|
||||
).get(name="Billy")
|
||||
assert len(billy_check.friends) == 2
|
||||
assert billy_check.friends[0].name == "Kate"
|
||||
assert billy_check.friends[1].name == "Steve"
|
||||
assert billy_check.favourite_game.name == "Uno"
|
||||
|
||||
kate_check = await Child.objects.select_related(["also_friends",]).get(name="Kate")
|
||||
assert len(kate_check.also_friends) == 1
|
||||
assert kate_check.also_friends[0].name == "Billy"
|
||||
|
||||
Reference in New Issue
Block a user