refactors in join to register complex aliases on duplicate, to do is doing the same in filter clauses
This commit is contained in:
@ -52,7 +52,8 @@ def populate_default_options_values(
|
|||||||
new_model.Meta.abstract = False
|
new_model.Meta.abstract = False
|
||||||
|
|
||||||
if any(
|
if any(
|
||||||
is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values()
|
is_field_an_forward_ref(field) for field in
|
||||||
|
new_model.Meta.model_fields.values()
|
||||||
):
|
):
|
||||||
new_model.Meta.requires_ref_update = True
|
new_model.Meta.requires_ref_update = True
|
||||||
else:
|
else:
|
||||||
@ -122,20 +123,24 @@ def group_related_list(list_: List) -> Dict:
|
|||||||
will become:
|
will become:
|
||||||
{'people': {'houses': [], 'cars': ['models', 'colors']}}
|
{'people': {'houses': [], 'cars': ['models', 'colors']}}
|
||||||
|
|
||||||
|
Result dictionary is sorted by length of the values and by key
|
||||||
|
|
||||||
:param list_: list of related models used in select related
|
:param list_: list of related models used in select related
|
||||||
:type list_: List[str]
|
:type list_: List[str]
|
||||||
:return: list converted to dictionary to avoid repetition and group nested models
|
:return: list converted to dictionary to avoid repetition and group nested models
|
||||||
:rtype: Dict[str, List]
|
:rtype: Dict[str, List]
|
||||||
"""
|
"""
|
||||||
test_dict: Dict[str, Any] = dict()
|
result_dict: Dict[str, Any] = dict()
|
||||||
|
list_.sort(key=lambda x: x.split("__")[0])
|
||||||
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
|
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
|
||||||
for key, group in grouped:
|
for key, group in grouped:
|
||||||
group_list = list(group)
|
group_list = list(group)
|
||||||
new = [
|
new = sorted([
|
||||||
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
|
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
|
||||||
]
|
])
|
||||||
if any("__" in x for x in new):
|
if any("__" in x for x in new):
|
||||||
test_dict[key] = group_related_list(new)
|
result_dict[key] = group_related_list(new)
|
||||||
else:
|
else:
|
||||||
test_dict[key] = new
|
result_dict.setdefault(key, []).extend(new)
|
||||||
return test_dict
|
return {k: v for k, v in
|
||||||
|
sorted(result_dict.items(), key=lambda item: len(item[1]))}
|
||||||
|
|||||||
@ -98,7 +98,9 @@ class QueryClause:
|
|||||||
table_prefix,
|
table_prefix,
|
||||||
model_cls,
|
model_cls,
|
||||||
) = self._determine_filter_target_table(
|
) = self._determine_filter_target_table(
|
||||||
related_parts, select_related
|
related_parts=related_parts,
|
||||||
|
select_related=select_related,
|
||||||
|
field_name=field_name
|
||||||
)
|
)
|
||||||
|
|
||||||
table = model_cls.Meta.table
|
table = model_cls.Meta.table
|
||||||
@ -158,7 +160,7 @@ class QueryClause:
|
|||||||
return clause
|
return clause
|
||||||
|
|
||||||
def _determine_filter_target_table(
|
def _determine_filter_target_table(
|
||||||
self, related_parts: List[str], select_related: List[str]
|
self, related_parts: List[str], select_related: List[str], field_name: str
|
||||||
) -> Tuple[List[str], str, Type["Model"]]:
|
) -> Tuple[List[str], str, Type["Model"]]:
|
||||||
"""
|
"""
|
||||||
Adds related strings to select_related list otherwise the clause would fail as
|
Adds related strings to select_related list otherwise the clause would fail as
|
||||||
@ -187,18 +189,25 @@ class QueryClause:
|
|||||||
# Walk the relationships to the actual model class
|
# Walk the relationships to the actual model class
|
||||||
# against which the comparison is being made.
|
# against which the comparison is being made.
|
||||||
previous_model = model_cls
|
previous_model = model_cls
|
||||||
for part in related_parts:
|
|
||||||
part2 = part
|
|
||||||
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
|
|
||||||
through_field = model_cls.Meta.model_fields[part]
|
|
||||||
previous_model = through_field.through
|
|
||||||
part2 = through_field.default_target_field_name() # type: ignore
|
|
||||||
manager = model_cls.Meta.alias_manager
|
manager = model_cls.Meta.alias_manager
|
||||||
|
for relation in related_parts:
|
||||||
|
related_field = model_cls.Meta.model_fields[relation]
|
||||||
|
if issubclass(related_field, ManyToManyField):
|
||||||
|
previous_model = related_field.through
|
||||||
|
relation = related_field.default_target_field_name() # type: ignore
|
||||||
table_prefix = manager.resolve_relation_alias(
|
table_prefix = manager.resolve_relation_alias(
|
||||||
from_model=previous_model, relation_name=part2
|
from_model=previous_model, relation_name=relation
|
||||||
)
|
)
|
||||||
model_cls = model_cls.Meta.model_fields[part].to
|
model_cls = related_field.to
|
||||||
previous_model = model_cls
|
previous_model = model_cls
|
||||||
|
# handle duplicated aliases in nested relations
|
||||||
|
# TODO: check later and remove nocover
|
||||||
|
complex_prefix = manager.resolve_relation_alias(
|
||||||
|
from_model=self.model_cls,
|
||||||
|
relation_name='__'.join([related_str, field_name])
|
||||||
|
)
|
||||||
|
if complex_prefix: # pragma: nocover
|
||||||
|
table_prefix = complex_prefix
|
||||||
return select_related, table_prefix, model_cls
|
return select_related, table_prefix, model_cls
|
||||||
|
|
||||||
def _compile_clause(
|
def _compile_clause(
|
||||||
|
|||||||
@ -34,8 +34,10 @@ class SqlJoin:
|
|||||||
sorted_orders: OrderedDict,
|
sorted_orders: OrderedDict,
|
||||||
main_model: Type["Model"],
|
main_model: Type["Model"],
|
||||||
relation_name: str,
|
relation_name: str,
|
||||||
|
relation_str: str,
|
||||||
related_models: Any = None,
|
related_models: Any = None,
|
||||||
own_alias: str = "",
|
own_alias: str = "",
|
||||||
|
source_model: Type["Model"] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.relation_name = relation_name
|
self.relation_name = relation_name
|
||||||
self.related_models = related_models or []
|
self.related_models = related_models or []
|
||||||
@ -53,6 +55,9 @@ class SqlJoin:
|
|||||||
self._next_model: Optional[Type["Model"]] = None
|
self._next_model: Optional[Type["Model"]] = None
|
||||||
self._next_alias: Optional[str] = None
|
self._next_alias: Optional[str] = None
|
||||||
|
|
||||||
|
self.relation_str = relation_str
|
||||||
|
self.source_model = source_model
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next_model(self) -> Type["Model"]:
|
def next_model(self) -> Type["Model"]:
|
||||||
if not self._next_model: # pragma: nocover
|
if not self._next_model: # pragma: nocover
|
||||||
@ -85,7 +90,8 @@ class SqlJoin:
|
|||||||
"""
|
"""
|
||||||
return self.main_model.Meta.alias_manager
|
return self.main_model.Meta.alias_manager
|
||||||
|
|
||||||
def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text:
|
def on_clause(self, previous_alias: str, from_clause: str,
|
||||||
|
to_clause: str, ) -> text:
|
||||||
"""
|
"""
|
||||||
Receives aliases and names of both ends of the join and combines them
|
Receives aliases and names of both ends of the join and combines them
|
||||||
into one text clause used in joins.
|
into one text clause used in joins.
|
||||||
@ -117,11 +123,7 @@ class SqlJoin:
|
|||||||
self.process_m2m_through_table()
|
self.process_m2m_through_table()
|
||||||
|
|
||||||
self.next_model = self.target_field.to
|
self.next_model = self.target_field.to
|
||||||
self.next_alias = self.alias_manager.resolve_relation_alias(
|
self._forward_join()
|
||||||
from_model=self.target_field.owner, relation_name=self.relation_name
|
|
||||||
)
|
|
||||||
if self.next_alias not in self.used_aliases:
|
|
||||||
self._process_join()
|
|
||||||
|
|
||||||
self._process_following_joins()
|
self._process_following_joins()
|
||||||
|
|
||||||
@ -132,6 +134,23 @@ class SqlJoin:
|
|||||||
self.sorted_orders,
|
self.sorted_orders,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _forward_join(self):
|
||||||
|
self.next_alias = self.alias_manager.resolve_relation_alias(
|
||||||
|
from_model=self.target_field.owner, relation_name=self.relation_name
|
||||||
|
)
|
||||||
|
if self.next_alias not in self.used_aliases:
|
||||||
|
self._process_join()
|
||||||
|
else:
|
||||||
|
if '__' in self.relation_str:
|
||||||
|
relation_key = f'{self.source_model.get_name()}_{self.relation_str}'
|
||||||
|
if relation_key not in self.alias_manager:
|
||||||
|
print(f'registering {relation_key}')
|
||||||
|
self.next_alias = self.alias_manager.add_alias(
|
||||||
|
alias_key=relation_key)
|
||||||
|
else:
|
||||||
|
self.next_alias = self.alias_manager[relation_key]
|
||||||
|
self._process_join()
|
||||||
|
|
||||||
def _process_following_joins(self) -> None:
|
def _process_following_joins(self) -> None:
|
||||||
"""
|
"""
|
||||||
Iterates through nested models to create subsequent joins.
|
Iterates through nested models to create subsequent joins.
|
||||||
@ -175,7 +194,9 @@ class SqlJoin:
|
|||||||
main_model=self.next_model,
|
main_model=self.next_model,
|
||||||
relation_name=related_name,
|
relation_name=related_name,
|
||||||
related_models=remainder,
|
related_models=remainder,
|
||||||
|
relation_str='__'.join([self.relation_str, related_name]),
|
||||||
own_alias=self.next_alias,
|
own_alias=self.next_alias,
|
||||||
|
source_model=self.source_model or self.main_model
|
||||||
)
|
)
|
||||||
(
|
(
|
||||||
self.used_aliases,
|
self.used_aliases,
|
||||||
@ -203,11 +224,8 @@ class SqlJoin:
|
|||||||
self._replace_many_to_many_order_by_columns(self.relation_name, new_part)
|
self._replace_many_to_many_order_by_columns(self.relation_name, new_part)
|
||||||
|
|
||||||
self.next_model = self.target_field.through
|
self.next_model = self.target_field.through
|
||||||
self.next_alias = self.alias_manager.resolve_relation_alias(
|
self._forward_join()
|
||||||
from_model=self.target_field.owner, relation_name=self.relation_name
|
|
||||||
)
|
|
||||||
if self.next_alias not in self.used_aliases:
|
|
||||||
self._process_join()
|
|
||||||
self.relation_name = new_part
|
self.relation_name = new_part
|
||||||
self.own_alias = self.next_alias
|
self.own_alias = self.next_alias
|
||||||
self.target_field = self.next_model.Meta.model_fields[self.relation_name]
|
self.target_field = self.next_model.Meta.model_fields[self.relation_name]
|
||||||
|
|||||||
@ -159,6 +159,7 @@ class Query:
|
|||||||
sorted_orders=self.sorted_orders,
|
sorted_orders=self.sorted_orders,
|
||||||
main_model=self.model_cls,
|
main_model=self.model_cls,
|
||||||
relation_name=related,
|
relation_name=related,
|
||||||
|
relation_str=related,
|
||||||
related_models=remainder,
|
related_models=remainder,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -228,6 +228,9 @@ class QuerySet:
|
|||||||
:return: filtered QuerySet
|
:return: filtered QuerySet
|
||||||
:rtype: QuerySet
|
:rtype: QuerySet
|
||||||
"""
|
"""
|
||||||
|
# TODO: delay processing of filter clauses or switch to group one
|
||||||
|
# that keeps all aliases even if duplicated - now initialized too late
|
||||||
|
# in the join
|
||||||
qryclause = QueryClause(
|
qryclause = QueryClause(
|
||||||
model_cls=self.model,
|
model_cls=self.model,
|
||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
|
|||||||
@ -33,6 +33,12 @@ class AliasManager:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._aliases_new: Dict[str, str] = dict()
|
self._aliases_new: Dict[str, str] = dict()
|
||||||
|
|
||||||
|
def __contains__(self, item):
|
||||||
|
return self._aliases_new.__contains__(item)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._aliases_new.__getitem__(key)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prefixed_columns(
|
def prefixed_columns(
|
||||||
alias: str, table: sqlalchemy.Table, fields: List = None
|
alias: str, table: sqlalchemy.Table, fields: List = None
|
||||||
|
|||||||
@ -169,7 +169,7 @@ async def test_other_forwardref_relation(cleanup):
|
|||||||
async def test_m2m_self_forwardref_relation(cleanup):
|
async def test_m2m_self_forwardref_relation(cleanup):
|
||||||
async with db:
|
async with db:
|
||||||
async with db.transaction(force_rollback=True):
|
async with db.transaction(force_rollback=True):
|
||||||
checkers = await Game.objects.create(name="checkers")
|
checkers = await Game.objects.create(name="Checkers")
|
||||||
uno = await Game(name="Uno").save()
|
uno = await Game(name="Uno").save()
|
||||||
jenga = await Game(name="Jenga").save()
|
jenga = await Game(name="Jenga").save()
|
||||||
|
|
||||||
@ -186,15 +186,17 @@ async def test_m2m_self_forwardref_relation(cleanup):
|
|||||||
await billy.friends.add(kate)
|
await billy.friends.add(kate)
|
||||||
await billy.friends.add(steve)
|
await billy.friends.add(steve)
|
||||||
|
|
||||||
# await steve.friends.add(kate)
|
|
||||||
# await steve.friends.add(billy)
|
|
||||||
|
|
||||||
billy_check = await Child.objects.select_related(
|
billy_check = await Child.objects.select_related(
|
||||||
["friends", "favourite_game", "least_favourite_game"]
|
["friends", "favourite_game", "least_favourite_game",
|
||||||
|
"friends__favourite_game", "friends__least_favourite_game"]
|
||||||
).get(name="Billy")
|
).get(name="Billy")
|
||||||
assert len(billy_check.friends) == 2
|
assert len(billy_check.friends) == 2
|
||||||
assert billy_check.friends[0].name == "Kate"
|
assert billy_check.friends[0].name == "Kate"
|
||||||
|
assert billy_check.friends[0].favourite_game.name == 'Checkers'
|
||||||
|
assert billy_check.friends[0].least_favourite_game.name == 'Uno'
|
||||||
assert billy_check.friends[1].name == "Steve"
|
assert billy_check.friends[1].name == "Steve"
|
||||||
|
assert billy_check.friends[1].favourite_game.name == 'Jenga'
|
||||||
|
assert billy_check.friends[1].least_favourite_game.name == 'Uno'
|
||||||
assert billy_check.favourite_game.name == "Uno"
|
assert billy_check.favourite_game.name == "Uno"
|
||||||
|
|
||||||
kate_check = await Child.objects.select_related(["also_friends",]).get(
|
kate_check = await Child.objects.select_related(["also_friends",]).get(
|
||||||
@ -203,3 +205,13 @@ async def test_m2m_self_forwardref_relation(cleanup):
|
|||||||
|
|
||||||
assert len(kate_check.also_friends) == 1
|
assert len(kate_check.also_friends) == 1
|
||||||
assert kate_check.also_friends[0].name == "Billy"
|
assert kate_check.also_friends[0].name == "Billy"
|
||||||
|
|
||||||
|
# TODO: Fix filters with complex prefixes
|
||||||
|
# billy_check = await Child.objects.select_related(
|
||||||
|
# ["friends", "favourite_game", "least_favourite_game",
|
||||||
|
# "friends__favourite_game", "friends__least_favourite_game"]
|
||||||
|
# ).filter(friends__favourite_game__name="Checkers").get(name="Billy")
|
||||||
|
# assert len(billy_check.friends) == 1
|
||||||
|
# assert billy_check.friends[0].name == "Kate"
|
||||||
|
# assert billy_check.friends[0].favourite_game.name == 'Checkers'
|
||||||
|
# assert billy_check.friends[0].least_favourite_game.name == 'Uno'
|
||||||
|
|||||||
9
tests/test_models_helpers.py
Normal file
9
tests/test_models_helpers.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from ormar.models.helpers.models import group_related_list
|
||||||
|
|
||||||
|
|
||||||
|
def test_group_related_list():
|
||||||
|
given = ['friends__least_favourite_game', 'least_favourite_game', 'friends',
|
||||||
|
'favourite_game', 'friends__favourite_game']
|
||||||
|
expected = {'least_favourite_game': [], 'favourite_game': [],
|
||||||
|
'friends': ['favourite_game', 'least_favourite_game']}
|
||||||
|
assert group_related_list(given) == expected
|
||||||
Reference in New Issue
Block a user