refactors in join to register complex aliases on duplicate, to do is doing the same in filter clauses

This commit is contained in:
collerek
2021-01-17 12:29:21 +01:00
parent 28cc847b57
commit d6e2c85b79
9 changed files with 137 additions and 74 deletions

View File

@ -52,7 +52,8 @@ def populate_default_options_values(
new_model.Meta.abstract = False new_model.Meta.abstract = False
if any( if any(
is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values() is_field_an_forward_ref(field) for field in
new_model.Meta.model_fields.values()
): ):
new_model.Meta.requires_ref_update = True new_model.Meta.requires_ref_update = True
else: else:
@ -122,20 +123,24 @@ def group_related_list(list_: List) -> Dict:
will become: will become:
{'people': {'houses': [], 'cars': ['models', 'colors']}} {'people': {'houses': [], 'cars': ['models', 'colors']}}
Result dictionary is sorted by length of the values and by key
:param list_: list of related models used in select related :param list_: list of related models used in select related
:type list_: List[str] :type list_: List[str]
:return: list converted to dictionary to avoid repetition and group nested models :return: list converted to dictionary to avoid repetition and group nested models
:rtype: Dict[str, List] :rtype: Dict[str, List]
""" """
test_dict: Dict[str, Any] = dict() result_dict: Dict[str, Any] = dict()
list_.sort(key=lambda x: x.split("__")[0])
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0]) grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped: for key, group in grouped:
group_list = list(group) group_list = list(group)
new = [ new = sorted([
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1 "__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
] ])
if any("__" in x for x in new): if any("__" in x for x in new):
test_dict[key] = group_related_list(new) result_dict[key] = group_related_list(new)
else: else:
test_dict[key] = new result_dict.setdefault(key, []).extend(new)
return test_dict return {k: v for k, v in
sorted(result_dict.items(), key=lambda item: len(item[1]))}

View File

@ -98,7 +98,9 @@ class QueryClause:
table_prefix, table_prefix,
model_cls, model_cls,
) = self._determine_filter_target_table( ) = self._determine_filter_target_table(
related_parts, select_related related_parts=related_parts,
select_related=select_related,
field_name=field_name
) )
table = model_cls.Meta.table table = model_cls.Meta.table
@ -158,7 +160,7 @@ class QueryClause:
return clause return clause
def _determine_filter_target_table( def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str] self, related_parts: List[str], select_related: List[str], field_name: str
) -> Tuple[List[str], str, Type["Model"]]: ) -> Tuple[List[str], str, Type["Model"]]:
""" """
Adds related strings to select_related list otherwise the clause would fail as Adds related strings to select_related list otherwise the clause would fail as
@ -187,18 +189,25 @@ class QueryClause:
# Walk the relationships to the actual model class # Walk the relationships to the actual model class
# against which the comparison is being made. # against which the comparison is being made.
previous_model = model_cls previous_model = model_cls
for part in related_parts:
part2 = part
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
through_field = model_cls.Meta.model_fields[part]
previous_model = through_field.through
part2 = through_field.default_target_field_name() # type: ignore
manager = model_cls.Meta.alias_manager manager = model_cls.Meta.alias_manager
for relation in related_parts:
related_field = model_cls.Meta.model_fields[relation]
if issubclass(related_field, ManyToManyField):
previous_model = related_field.through
relation = related_field.default_target_field_name() # type: ignore
table_prefix = manager.resolve_relation_alias( table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=part2 from_model=previous_model, relation_name=relation
) )
model_cls = model_cls.Meta.model_fields[part].to model_cls = related_field.to
previous_model = model_cls previous_model = model_cls
# handle duplicated aliases in nested relations
# TODO: check later and remove nocover
complex_prefix = manager.resolve_relation_alias(
from_model=self.model_cls,
relation_name='__'.join([related_str, field_name])
)
if complex_prefix: # pragma: nocover
table_prefix = complex_prefix
return select_related, table_prefix, model_cls return select_related, table_prefix, model_cls
def _compile_clause( def _compile_clause(

View File

@ -34,8 +34,10 @@ class SqlJoin:
sorted_orders: OrderedDict, sorted_orders: OrderedDict,
main_model: Type["Model"], main_model: Type["Model"],
relation_name: str, relation_name: str,
relation_str: str,
related_models: Any = None, related_models: Any = None,
own_alias: str = "", own_alias: str = "",
source_model: Type["Model"] = None,
) -> None: ) -> None:
self.relation_name = relation_name self.relation_name = relation_name
self.related_models = related_models or [] self.related_models = related_models or []
@ -53,6 +55,9 @@ class SqlJoin:
self._next_model: Optional[Type["Model"]] = None self._next_model: Optional[Type["Model"]] = None
self._next_alias: Optional[str] = None self._next_alias: Optional[str] = None
self.relation_str = relation_str
self.source_model = source_model
@property @property
def next_model(self) -> Type["Model"]: def next_model(self) -> Type["Model"]:
if not self._next_model: # pragma: nocover if not self._next_model: # pragma: nocover
@ -85,7 +90,8 @@ class SqlJoin:
""" """
return self.main_model.Meta.alias_manager return self.main_model.Meta.alias_manager
def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text: def on_clause(self, previous_alias: str, from_clause: str,
to_clause: str, ) -> text:
""" """
Receives aliases and names of both ends of the join and combines them Receives aliases and names of both ends of the join and combines them
into one text clause used in joins. into one text clause used in joins.
@ -117,11 +123,7 @@ class SqlJoin:
self.process_m2m_through_table() self.process_m2m_through_table()
self.next_model = self.target_field.to self.next_model = self.target_field.to
self.next_alias = self.alias_manager.resolve_relation_alias( self._forward_join()
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self._process_following_joins() self._process_following_joins()
@ -132,6 +134,23 @@ class SqlJoin:
self.sorted_orders, self.sorted_orders,
) )
def _forward_join(self):
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
else:
if '__' in self.relation_str:
relation_key = f'{self.source_model.get_name()}_{self.relation_str}'
if relation_key not in self.alias_manager:
print(f'registering {relation_key}')
self.next_alias = self.alias_manager.add_alias(
alias_key=relation_key)
else:
self.next_alias = self.alias_manager[relation_key]
self._process_join()
def _process_following_joins(self) -> None: def _process_following_joins(self) -> None:
""" """
Iterates through nested models to create subsequent joins. Iterates through nested models to create subsequent joins.
@ -175,7 +194,9 @@ class SqlJoin:
main_model=self.next_model, main_model=self.next_model,
relation_name=related_name, relation_name=related_name,
related_models=remainder, related_models=remainder,
relation_str='__'.join([self.relation_str, related_name]),
own_alias=self.next_alias, own_alias=self.next_alias,
source_model=self.source_model or self.main_model
) )
( (
self.used_aliases, self.used_aliases,
@ -203,11 +224,8 @@ class SqlJoin:
self._replace_many_to_many_order_by_columns(self.relation_name, new_part) self._replace_many_to_many_order_by_columns(self.relation_name, new_part)
self.next_model = self.target_field.through self.next_model = self.target_field.through
self.next_alias = self.alias_manager.resolve_relation_alias( self._forward_join()
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self.relation_name = new_part self.relation_name = new_part
self.own_alias = self.next_alias self.own_alias = self.next_alias
self.target_field = self.next_model.Meta.model_fields[self.relation_name] self.target_field = self.next_model.Meta.model_fields[self.relation_name]

View File

@ -159,6 +159,7 @@ class Query:
sorted_orders=self.sorted_orders, sorted_orders=self.sorted_orders,
main_model=self.model_cls, main_model=self.model_cls,
relation_name=related, relation_name=related,
relation_str=related,
related_models=remainder, related_models=remainder,
) )

View File

@ -228,6 +228,9 @@ class QuerySet:
:return: filtered QuerySet :return: filtered QuerySet
:rtype: QuerySet :rtype: QuerySet
""" """
# TODO: delay processing of filter clauses or switch to group one
# that keeps all aliases even if duplicated - now initialized too late
# in the join
qryclause = QueryClause( qryclause = QueryClause(
model_cls=self.model, model_cls=self.model,
select_related=self._select_related, select_related=self._select_related,

View File

@ -33,6 +33,12 @@ class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._aliases_new: Dict[str, str] = dict() self._aliases_new: Dict[str, str] = dict()
def __contains__(self, item):
return self._aliases_new.__contains__(item)
def __getitem__(self, key):
return self._aliases_new.__getitem__(key)
@staticmethod @staticmethod
def prefixed_columns( def prefixed_columns(
alias: str, table: sqlalchemy.Table, fields: List = None alias: str, table: sqlalchemy.Table, fields: List = None

View File

@ -169,7 +169,7 @@ async def test_other_forwardref_relation(cleanup):
async def test_m2m_self_forwardref_relation(cleanup): async def test_m2m_self_forwardref_relation(cleanup):
async with db: async with db:
async with db.transaction(force_rollback=True): async with db.transaction(force_rollback=True):
checkers = await Game.objects.create(name="checkers") checkers = await Game.objects.create(name="Checkers")
uno = await Game(name="Uno").save() uno = await Game(name="Uno").save()
jenga = await Game(name="Jenga").save() jenga = await Game(name="Jenga").save()
@ -186,15 +186,17 @@ async def test_m2m_self_forwardref_relation(cleanup):
await billy.friends.add(kate) await billy.friends.add(kate)
await billy.friends.add(steve) await billy.friends.add(steve)
# await steve.friends.add(kate)
# await steve.friends.add(billy)
billy_check = await Child.objects.select_related( billy_check = await Child.objects.select_related(
["friends", "favourite_game", "least_favourite_game"] ["friends", "favourite_game", "least_favourite_game",
"friends__favourite_game", "friends__least_favourite_game"]
).get(name="Billy") ).get(name="Billy")
assert len(billy_check.friends) == 2 assert len(billy_check.friends) == 2
assert billy_check.friends[0].name == "Kate" assert billy_check.friends[0].name == "Kate"
assert billy_check.friends[0].favourite_game.name == 'Checkers'
assert billy_check.friends[0].least_favourite_game.name == 'Uno'
assert billy_check.friends[1].name == "Steve" assert billy_check.friends[1].name == "Steve"
assert billy_check.friends[1].favourite_game.name == 'Jenga'
assert billy_check.friends[1].least_favourite_game.name == 'Uno'
assert billy_check.favourite_game.name == "Uno" assert billy_check.favourite_game.name == "Uno"
kate_check = await Child.objects.select_related(["also_friends",]).get( kate_check = await Child.objects.select_related(["also_friends",]).get(
@ -203,3 +205,13 @@ async def test_m2m_self_forwardref_relation(cleanup):
assert len(kate_check.also_friends) == 1 assert len(kate_check.also_friends) == 1
assert kate_check.also_friends[0].name == "Billy" assert kate_check.also_friends[0].name == "Billy"
# TODO: Fix filters with complex prefixes
# billy_check = await Child.objects.select_related(
# ["friends", "favourite_game", "least_favourite_game",
# "friends__favourite_game", "friends__least_favourite_game"]
# ).filter(friends__favourite_game__name="Checkers").get(name="Billy")
# assert len(billy_check.friends) == 1
# assert billy_check.friends[0].name == "Kate"
# assert billy_check.friends[0].favourite_game.name == 'Checkers'
# assert billy_check.friends[0].least_favourite_game.name == 'Uno'

View File

@ -0,0 +1,9 @@
from ormar.models.helpers.models import group_related_list
def test_group_related_list():
given = ['friends__least_favourite_game', 'least_favourite_game', 'friends',
'favourite_game', 'friends__favourite_game']
expected = {'least_favourite_game': [], 'favourite_game': [],
'friends': ['favourite_game', 'least_favourite_game']}
assert group_related_list(given) == expected