refactors in join to register complex aliases on duplicate, to do is doing the same in filter clauses

This commit is contained in:
collerek
2021-01-17 12:29:21 +01:00
parent 28cc847b57
commit d6e2c85b79
9 changed files with 137 additions and 74 deletions

View File

@ -52,7 +52,8 @@ def populate_default_options_values(
new_model.Meta.abstract = False
if any(
is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values()
is_field_an_forward_ref(field) for field in
new_model.Meta.model_fields.values()
):
new_model.Meta.requires_ref_update = True
else:
@ -122,20 +123,24 @@ def group_related_list(list_: List) -> Dict:
will become:
{'people': {'houses': [], 'cars': ['models', 'colors']}}
Result dictionary is sorted by length of the values and by key
:param list_: list of related models used in select related
:type list_: List[str]
:return: list converted to dictionary to avoid repetition and group nested models
:rtype: Dict[str, List]
"""
test_dict: Dict[str, Any] = dict()
result_dict: Dict[str, Any] = dict()
list_.sort(key=lambda x: x.split("__")[0])
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped:
group_list = list(group)
new = [
new = sorted([
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
]
])
if any("__" in x for x in new):
test_dict[key] = group_related_list(new)
result_dict[key] = group_related_list(new)
else:
test_dict[key] = new
return test_dict
result_dict.setdefault(key, []).extend(new)
return {k: v for k, v in
sorted(result_dict.items(), key=lambda item: len(item[1]))}

View File

@ -98,7 +98,9 @@ class QueryClause:
table_prefix,
model_cls,
) = self._determine_filter_target_table(
related_parts, select_related
related_parts=related_parts,
select_related=select_related,
field_name=field_name
)
table = model_cls.Meta.table
@ -158,7 +160,7 @@ class QueryClause:
return clause
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str]
self, related_parts: List[str], select_related: List[str], field_name: str
) -> Tuple[List[str], str, Type["Model"]]:
"""
Adds related strings to select_related list otherwise the clause would fail as
@ -187,18 +189,25 @@ class QueryClause:
# Walk the relationships to the actual model class
# against which the comparison is being made.
previous_model = model_cls
for part in related_parts:
part2 = part
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
through_field = model_cls.Meta.model_fields[part]
previous_model = through_field.through
part2 = through_field.default_target_field_name() # type: ignore
manager = model_cls.Meta.alias_manager
for relation in related_parts:
related_field = model_cls.Meta.model_fields[relation]
if issubclass(related_field, ManyToManyField):
previous_model = related_field.through
relation = related_field.default_target_field_name() # type: ignore
table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=part2
from_model=previous_model, relation_name=relation
)
model_cls = model_cls.Meta.model_fields[part].to
model_cls = related_field.to
previous_model = model_cls
# handle duplicated aliases in nested relations
# TODO: check later and remove nocover
complex_prefix = manager.resolve_relation_alias(
from_model=self.model_cls,
relation_name='__'.join([related_str, field_name])
)
if complex_prefix: # pragma: nocover
table_prefix = complex_prefix
return select_related, table_prefix, model_cls
def _compile_clause(

View File

@ -34,8 +34,10 @@ class SqlJoin:
sorted_orders: OrderedDict,
main_model: Type["Model"],
relation_name: str,
relation_str: str,
related_models: Any = None,
own_alias: str = "",
source_model: Type["Model"] = None,
) -> None:
self.relation_name = relation_name
self.related_models = related_models or []
@ -53,6 +55,9 @@ class SqlJoin:
self._next_model: Optional[Type["Model"]] = None
self._next_alias: Optional[str] = None
self.relation_str = relation_str
self.source_model = source_model
@property
def next_model(self) -> Type["Model"]:
if not self._next_model: # pragma: nocover
@ -85,7 +90,8 @@ class SqlJoin:
"""
return self.main_model.Meta.alias_manager
def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text:
def on_clause(self, previous_alias: str, from_clause: str,
to_clause: str, ) -> text:
"""
Receives aliases and names of both ends of the join and combines them
into one text clause used in joins.
@ -117,11 +123,7 @@ class SqlJoin:
self.process_m2m_through_table()
self.next_model = self.target_field.to
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self._forward_join()
self._process_following_joins()
@ -132,6 +134,23 @@ class SqlJoin:
self.sorted_orders,
)
def _forward_join(self):
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
else:
if '__' in self.relation_str:
relation_key = f'{self.source_model.get_name()}_{self.relation_str}'
if relation_key not in self.alias_manager:
print(f'registering {relation_key}')
self.next_alias = self.alias_manager.add_alias(
alias_key=relation_key)
else:
self.next_alias = self.alias_manager[relation_key]
self._process_join()
def _process_following_joins(self) -> None:
"""
Iterates through nested models to create subsequent joins.
@ -175,7 +194,9 @@ class SqlJoin:
main_model=self.next_model,
relation_name=related_name,
related_models=remainder,
relation_str='__'.join([self.relation_str, related_name]),
own_alias=self.next_alias,
source_model=self.source_model or self.main_model
)
(
self.used_aliases,
@ -203,11 +224,8 @@ class SqlJoin:
self._replace_many_to_many_order_by_columns(self.relation_name, new_part)
self.next_model = self.target_field.through
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self._forward_join()
self.relation_name = new_part
self.own_alias = self.next_alias
self.target_field = self.next_model.Meta.model_fields[self.relation_name]

View File

@ -159,6 +159,7 @@ class Query:
sorted_orders=self.sorted_orders,
main_model=self.model_cls,
relation_name=related,
relation_str=related,
related_models=remainder,
)

View File

@ -228,6 +228,9 @@ class QuerySet:
:return: filtered QuerySet
:rtype: QuerySet
"""
# TODO: delay processing of filter clauses or switch to group one
# that keeps all aliases even if duplicated - now initialized too late
# in the join
qryclause = QueryClause(
model_cls=self.model,
select_related=self._select_related,

View File

@ -33,6 +33,12 @@ class AliasManager:
def __init__(self) -> None:
self._aliases_new: Dict[str, str] = dict()
def __contains__(self, item):
return self._aliases_new.__contains__(item)
def __getitem__(self, key):
return self._aliases_new.__getitem__(key)
@staticmethod
def prefixed_columns(
alias: str, table: sqlalchemy.Table, fields: List = None

View File

@ -169,7 +169,7 @@ async def test_other_forwardref_relation(cleanup):
async def test_m2m_self_forwardref_relation(cleanup):
async with db:
async with db.transaction(force_rollback=True):
checkers = await Game.objects.create(name="checkers")
checkers = await Game.objects.create(name="Checkers")
uno = await Game(name="Uno").save()
jenga = await Game(name="Jenga").save()
@ -186,15 +186,17 @@ async def test_m2m_self_forwardref_relation(cleanup):
await billy.friends.add(kate)
await billy.friends.add(steve)
# await steve.friends.add(kate)
# await steve.friends.add(billy)
billy_check = await Child.objects.select_related(
["friends", "favourite_game", "least_favourite_game"]
["friends", "favourite_game", "least_favourite_game",
"friends__favourite_game", "friends__least_favourite_game"]
).get(name="Billy")
assert len(billy_check.friends) == 2
assert billy_check.friends[0].name == "Kate"
assert billy_check.friends[0].favourite_game.name == 'Checkers'
assert billy_check.friends[0].least_favourite_game.name == 'Uno'
assert billy_check.friends[1].name == "Steve"
assert billy_check.friends[1].favourite_game.name == 'Jenga'
assert billy_check.friends[1].least_favourite_game.name == 'Uno'
assert billy_check.favourite_game.name == "Uno"
kate_check = await Child.objects.select_related(["also_friends",]).get(
@ -203,3 +205,13 @@ async def test_m2m_self_forwardref_relation(cleanup):
assert len(kate_check.also_friends) == 1
assert kate_check.also_friends[0].name == "Billy"
# TODO: Fix filters with complex prefixes
# billy_check = await Child.objects.select_related(
# ["friends", "favourite_game", "least_favourite_game",
# "friends__favourite_game", "friends__least_favourite_game"]
# ).filter(friends__favourite_game__name="Checkers").get(name="Billy")
# assert len(billy_check.friends) == 1
# assert billy_check.friends[0].name == "Kate"
# assert billy_check.friends[0].favourite_game.name == 'Checkers'
# assert billy_check.friends[0].least_favourite_game.name == 'Uno'

View File

@ -0,0 +1,9 @@
from ormar.models.helpers.models import group_related_list
def test_group_related_list():
given = ['friends__least_favourite_game', 'least_favourite_game', 'friends',
'favourite_game', 'friends__favourite_game']
expected = {'least_favourite_game': [], 'favourite_game': [],
'friends': ['favourite_game', 'least_favourite_game']}
assert group_related_list(given) == expected