working m2m and fk self relations with forwardref
This commit is contained in:
@ -45,6 +45,7 @@ class BaseField(FieldInfo):
|
|||||||
to: Type["Model"]
|
to: Type["Model"]
|
||||||
through: Type["Model"]
|
through: Type["Model"]
|
||||||
self_reference: bool = False
|
self_reference: bool = False
|
||||||
|
self_reference_primary: Optional[str] = None
|
||||||
|
|
||||||
default: Any
|
default: Any
|
||||||
server_default: Any
|
server_default: Any
|
||||||
@ -277,6 +278,7 @@ class BaseField(FieldInfo):
|
|||||||
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
|
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
|
||||||
):
|
):
|
||||||
cls.self_reference = True
|
cls.self_reference = True
|
||||||
|
cls.self_reference_primary = cls.name
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def has_unresolved_forward_refs(cls) -> bool:
|
def has_unresolved_forward_refs(cls) -> bool:
|
||||||
|
|||||||
@ -128,7 +128,7 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
|
|||||||
:return: name of the field
|
:return: name of the field
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
prefix = "to_" if cls.self_reference else ""
|
prefix = "from_" if cls.self_reference else ""
|
||||||
return f"{prefix}{cls.to.get_name()}"
|
return f"{prefix}{cls.to.get_name()}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -138,7 +138,7 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
|
|||||||
:return: name of the field
|
:return: name of the field
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
prefix = "from_" if cls.self_reference else ""
|
prefix = "to_" if cls.self_reference else ""
|
||||||
return f"{prefix}{cls.owner.get_name()}"
|
return f"{prefix}{cls.owner.get_name()}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -110,6 +110,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
|
|||||||
related_name=model_field.name,
|
related_name=model_field.name,
|
||||||
owner=model_field.to,
|
owner=model_field.to,
|
||||||
self_reference=model_field.self_reference,
|
self_reference=model_field.self_reference,
|
||||||
|
self_reference_primary=model_field.self_reference_primary,
|
||||||
)
|
)
|
||||||
# register foreign keys on through model
|
# register foreign keys on through model
|
||||||
adjust_through_many_to_many_model(model_field=model_field)
|
adjust_through_many_to_many_model(model_field=model_field)
|
||||||
|
|||||||
@ -586,7 +586,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
|||||||
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
||||||
populate_meta_sqlalchemy_table_if_required(new_model.Meta)
|
populate_meta_sqlalchemy_table_if_required(new_model.Meta)
|
||||||
expand_reverse_relationships(new_model)
|
expand_reverse_relationships(new_model)
|
||||||
for field_name, field in new_model.Meta.model_fields.items():
|
for field in new_model.Meta.model_fields.values():
|
||||||
register_relation_in_alias_manager(field=field)
|
register_relation_in_alias_manager(field=field)
|
||||||
|
|
||||||
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
||||||
|
|||||||
@ -125,7 +125,13 @@ class Model(NewBaseModel):
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
through_field = previous_model.Meta.model_fields[related_name]
|
through_field = previous_model.Meta.model_fields[related_name]
|
||||||
rel_name2 = through_field.default_target_field_name() # type: ignore
|
if (
|
||||||
|
through_field.self_reference
|
||||||
|
and related_name == through_field.self_reference_primary
|
||||||
|
):
|
||||||
|
rel_name2 = through_field.default_source_field_name() # type: ignore
|
||||||
|
else:
|
||||||
|
rel_name2 = through_field.default_target_field_name() # type: ignore
|
||||||
previous_model = through_field.through # type: ignore
|
previous_model = through_field.through # type: ignore
|
||||||
|
|
||||||
if previous_model and rel_name2:
|
if previous_model and rel_name2:
|
||||||
|
|||||||
@ -436,6 +436,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
|
|
||||||
Populates Meta table of the Model which is left empty before.
|
Populates Meta table of the Model which is left empty before.
|
||||||
|
|
||||||
|
Sets self_reference flag on models that links to themselves.
|
||||||
|
|
||||||
Calls the pydantic method to evaluate pydantic fields.
|
Calls the pydantic method to evaluate pydantic fields.
|
||||||
|
|
||||||
:param localns: local namespace
|
:param localns: local namespace
|
||||||
@ -446,7 +448,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
globalns = sys.modules[cls.__module__].__dict__.copy()
|
globalns = sys.modules[cls.__module__].__dict__.copy()
|
||||||
globalns.setdefault(cls.__name__, cls)
|
globalns.setdefault(cls.__name__, cls)
|
||||||
fields_to_check = cls.Meta.model_fields.copy()
|
fields_to_check = cls.Meta.model_fields.copy()
|
||||||
for field_name, field in fields_to_check.items():
|
for field in fields_to_check.values():
|
||||||
if field.has_unresolved_forward_refs():
|
if field.has_unresolved_forward_refs():
|
||||||
field = cast(Type[ForeignKeyField], field)
|
field = cast(Type[ForeignKeyField], field)
|
||||||
field.evaluate_forward_ref(globalns=globalns, localns=localns)
|
field.evaluate_forward_ref(globalns=globalns, localns=localns)
|
||||||
|
|||||||
@ -135,7 +135,14 @@ class SqlJoin:
|
|||||||
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
||||||
):
|
):
|
||||||
_fields = join_parameters.model_cls.Meta.model_fields
|
_fields = join_parameters.model_cls.Meta.model_fields
|
||||||
new_part = _fields[part].default_target_field_name()
|
target_field = _fields[part]
|
||||||
|
if (
|
||||||
|
target_field.self_reference
|
||||||
|
and part == target_field.self_reference_primary
|
||||||
|
):
|
||||||
|
new_part = target_field.default_source_field_name() # type: ignore
|
||||||
|
else:
|
||||||
|
new_part = target_field.default_target_field_name() # type: ignore
|
||||||
self._switch_many_to_many_order_columns(part, new_part)
|
self._switch_many_to_many_order_columns(part, new_part)
|
||||||
if index > 0: # nested joins
|
if index > 0: # nested joins
|
||||||
fields, exclude_fields = SqlJoin.update_inclusions(
|
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||||
@ -430,18 +437,25 @@ class SqlJoin:
|
|||||||
:rtype: Tuple[str, str]
|
:rtype: Tuple[str, str]
|
||||||
"""
|
"""
|
||||||
if is_multi:
|
if is_multi:
|
||||||
to_field = join_params.prev_model.get_name()
|
target_field = join_params.model_cls.Meta.model_fields[part]
|
||||||
to_key = model_cls.get_column_alias(to_field)
|
if (
|
||||||
|
target_field.self_reference
|
||||||
|
and part == target_field.self_reference_primary
|
||||||
|
):
|
||||||
|
to_key = target_field.default_target_field_name() # type: ignore
|
||||||
|
else:
|
||||||
|
to_key = target_field.default_source_field_name() # type: ignore
|
||||||
from_key = join_params.prev_model.get_column_alias(
|
from_key = join_params.prev_model.get_column_alias(
|
||||||
join_params.prev_model.Meta.pkname
|
join_params.prev_model.Meta.pkname
|
||||||
)
|
)
|
||||||
breakpoint()
|
|
||||||
elif join_params.prev_model.Meta.model_fields[part].virtual:
|
elif join_params.prev_model.Meta.model_fields[part].virtual:
|
||||||
to_field = join_params.prev_model.Meta.model_fields[part].get_related_name()
|
to_field = join_params.prev_model.Meta.model_fields[part].get_related_name()
|
||||||
to_key = model_cls.get_column_alias(to_field)
|
to_key = model_cls.get_column_alias(to_field)
|
||||||
from_key = join_params.prev_model.get_column_alias(
|
from_key = join_params.prev_model.get_column_alias(
|
||||||
join_params.prev_model.Meta.pkname
|
join_params.prev_model.Meta.pkname
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
to_key = model_cls.get_column_alias(model_cls.Meta.pkname)
|
to_key = model_cls.get_column_alias(model_cls.Meta.pkname)
|
||||||
from_key = join_params.prev_model.get_column_alias(part)
|
from_key = join_params.prev_model.get_column_alias(part)
|
||||||
|
|||||||
@ -197,7 +197,7 @@ class QuerySet:
|
|||||||
limit_raw_sql=self.limit_sql_raw,
|
limit_raw_sql=self.limit_sql_raw,
|
||||||
)
|
)
|
||||||
exp = qry.build_select_expression()
|
exp = qry.build_select_expression()
|
||||||
print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
|
# print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
|
||||||
return exp
|
return exp
|
||||||
|
|
||||||
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
|
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||||
|
|||||||
@ -107,8 +107,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
|||||||
:type child: Model
|
:type child: Model
|
||||||
"""
|
"""
|
||||||
model_cls = self.relation.through
|
model_cls = self.relation.through
|
||||||
owner_column = self.related_field.default_target_field_name()
|
owner_column = self.related_field.default_target_field_name() # type: ignore
|
||||||
child_column = self.related_field.default_source_field_name()
|
child_column = self.related_field.default_source_field_name() # type: ignore
|
||||||
kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
||||||
if child.pk is None:
|
if child.pk is None:
|
||||||
raise ModelPersistenceError(
|
raise ModelPersistenceError(
|
||||||
@ -118,6 +118,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
|||||||
)
|
)
|
||||||
expr = model_cls.Meta.table.insert()
|
expr = model_cls.Meta.table.insert()
|
||||||
expr = expr.values(**kwargs)
|
expr = expr.values(**kwargs)
|
||||||
|
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
|
||||||
await model_cls.Meta.database.execute(expr)
|
await model_cls.Meta.database.execute(expr)
|
||||||
|
|
||||||
async def delete_through_instance(self, child: "T") -> None:
|
async def delete_through_instance(self, child: "T") -> None:
|
||||||
@ -128,8 +129,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
|||||||
:type child: Model
|
:type child: Model
|
||||||
"""
|
"""
|
||||||
queryset = ormar.QuerySet(model_cls=self.relation.through)
|
queryset = ormar.QuerySet(model_cls=self.relation.through)
|
||||||
owner_column = self.related_field.default_target_field_name()
|
owner_column = self.related_field.default_target_field_name() # type: ignore
|
||||||
child_column = self.related_field.default_source_field_name()
|
child_column = self.related_field.default_source_field_name() # type: ignore
|
||||||
kwargs = {owner_column: self._owner, child_column: child}
|
kwargs = {owner_column: self._owner, child_column: child}
|
||||||
link_instance = await queryset.filter(**kwargs).get() # type: ignore
|
link_instance = await queryset.filter(**kwargs).get() # type: ignore
|
||||||
await link_instance.delete()
|
await link_instance.delete()
|
||||||
|
|||||||
@ -72,6 +72,16 @@ def create_test_database():
|
|||||||
metadata.drop_all(engine)
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def cleanup():
|
||||||
|
yield
|
||||||
|
async with db:
|
||||||
|
await ChildFriend.objects.delete(each=True)
|
||||||
|
await Child.objects.delete(each=True)
|
||||||
|
await Game.objects.delete(each=True)
|
||||||
|
await Person.objects.delete(each=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_not_uprated_model_raises_errors():
|
async def test_not_uprated_model_raises_errors():
|
||||||
Person2 = ForwardRef("Person2")
|
Person2 = ForwardRef("Person2")
|
||||||
@ -126,7 +136,7 @@ async def test_self_relation():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_other_forwardref_relation():
|
async def test_other_forwardref_relation(cleanup):
|
||||||
checkers = await Game.objects.create(name="checkers")
|
checkers = await Game.objects.create(name="checkers")
|
||||||
uno = await Game(name="Uno").save()
|
uno = await Game(name="Uno").save()
|
||||||
|
|
||||||
@ -147,7 +157,7 @@ async def test_other_forwardref_relation():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_m2m_self_forwardref_relation():
|
async def test_m2m_self_forwardref_relation(cleanup):
|
||||||
checkers = await Game.objects.create(name="checkers")
|
checkers = await Game.objects.create(name="checkers")
|
||||||
uno = await Game(name="Uno").save()
|
uno = await Game(name="Uno").save()
|
||||||
jenga = await Game(name="Jenga").save()
|
jenga = await Game(name="Jenga").save()
|
||||||
@ -165,16 +175,17 @@ async def test_m2m_self_forwardref_relation():
|
|||||||
await billy.friends.add(kate)
|
await billy.friends.add(kate)
|
||||||
await billy.friends.add(steve)
|
await billy.friends.add(steve)
|
||||||
|
|
||||||
await steve.friends.add(kate)
|
# await steve.friends.add(kate)
|
||||||
await steve.friends.add(billy)
|
# await steve.friends.add(billy)
|
||||||
|
|
||||||
billy_check = await Child.objects.select_related(
|
billy_check = await Child.objects.select_related(
|
||||||
[
|
["friends", "favourite_game", "least_favourite_game",]
|
||||||
"friends",
|
|
||||||
"favourite_game",
|
|
||||||
"least_favourite_game",
|
|
||||||
"friends__favourite_game",
|
|
||||||
"friends__least_favourite_game",
|
|
||||||
]
|
|
||||||
).get(name="Billy")
|
).get(name="Billy")
|
||||||
assert len(billy_check.friends) == 2
|
assert len(billy_check.friends) == 2
|
||||||
|
assert billy_check.friends[0].name == "Kate"
|
||||||
|
assert billy_check.friends[1].name == "Steve"
|
||||||
|
assert billy_check.favourite_game.name == "Uno"
|
||||||
|
|
||||||
|
kate_check = await Child.objects.select_related(["also_friends",]).get(name="Kate")
|
||||||
|
assert len(kate_check.also_friends) == 1
|
||||||
|
assert kate_check.also_friends[0].name == "Billy"
|
||||||
|
|||||||
Reference in New Issue
Block a user