refactors in join to register complex aliases on duplicate, to do is doing the same in filter clauses

This commit is contained in:
collerek
2021-01-17 12:29:21 +01:00
parent 28cc847b57
commit d6e2c85b79
9 changed files with 137 additions and 74 deletions

View File

@ -22,12 +22,12 @@ def is_field_an_forward_ref(field: Type["BaseField"]) -> bool:
:rtype: bool :rtype: bool
""" """
return issubclass(field, ForeignKeyField) and ( return issubclass(field, ForeignKeyField) and (
field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef
) )
def populate_default_options_values( def populate_default_options_values(
new_model: Type["Model"], model_fields: Dict new_model: Type["Model"], model_fields: Dict
) -> None: ) -> None:
""" """
Sets all optional Meta values to it's defaults Sets all optional Meta values to it's defaults
@ -52,7 +52,8 @@ def populate_default_options_values(
new_model.Meta.abstract = False new_model.Meta.abstract = False
if any( if any(
is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values() is_field_an_forward_ref(field) for field in
new_model.Meta.model_fields.values()
): ):
new_model.Meta.requires_ref_update = True new_model.Meta.requires_ref_update = True
else: else:
@ -77,7 +78,7 @@ def extract_annotations_and_default_vals(attrs: Dict) -> Tuple[Dict, Dict]:
# cannot be in relations helpers due to cyclical import # cannot be in relations helpers due to cyclical import
def validate_related_names_in_relations( # noqa CCR001 def validate_related_names_in_relations( # noqa CCR001
model_fields: Dict, new_model: Type["Model"] model_fields: Dict, new_model: Type["Model"]
) -> None: ) -> None:
""" """
Performs a validation of relation_names in relation fields. Performs a validation of relation_names in relation fields.
@ -122,20 +123,24 @@ def group_related_list(list_: List) -> Dict:
will become: will become:
{'people': {'houses': [], 'cars': ['models', 'colors']}} {'people': {'houses': [], 'cars': ['models', 'colors']}}
Result dictionary is sorted by length of the values and by key
:param list_: list of related models used in select related :param list_: list of related models used in select related
:type list_: List[str] :type list_: List[str]
:return: list converted to dictionary to avoid repetition and group nested models :return: list converted to dictionary to avoid repetition and group nested models
:rtype: Dict[str, List] :rtype: Dict[str, List]
""" """
test_dict: Dict[str, Any] = dict() result_dict: Dict[str, Any] = dict()
list_.sort(key=lambda x: x.split("__")[0])
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0]) grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped: for key, group in grouped:
group_list = list(group) group_list = list(group)
new = [ new = sorted([
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1 "__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
] ])
if any("__" in x for x in new): if any("__" in x for x in new):
test_dict[key] = group_related_list(new) result_dict[key] = group_related_list(new)
else: else:
test_dict[key] = new result_dict.setdefault(key, []).extend(new)
return test_dict return {k: v for k, v in
sorted(result_dict.items(), key=lambda item: len(item[1]))}

View File

@ -34,7 +34,7 @@ class QueryClause:
""" """
def __init__( def __init__(
self, model_cls: Type["Model"], filter_clauses: List, select_related: List, self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
) -> None: ) -> None:
self._select_related = select_related[:] self._select_related = select_related[:]
@ -44,7 +44,7 @@ class QueryClause:
self.table = self.model_cls.Meta.table self.table = self.model_cls.Meta.table
def filter( # noqa: A003 def filter( # noqa: A003
self, **kwargs: Any self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
""" """
Main external access point that processes the clauses into sqlalchemy text Main external access point that processes the clauses into sqlalchemy text
@ -65,7 +65,7 @@ class QueryClause:
return filter_clauses, select_related return filter_clauses, select_related
def _populate_filter_clauses( def _populate_filter_clauses(
self, **kwargs: Any self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
""" """
Iterates all clauses and extracts used operator and field from related Iterates all clauses and extracts used operator and field from related
@ -98,7 +98,9 @@ class QueryClause:
table_prefix, table_prefix,
model_cls, model_cls,
) = self._determine_filter_target_table( ) = self._determine_filter_target_table(
related_parts, select_related related_parts=related_parts,
select_related=select_related,
field_name=field_name
) )
table = model_cls.Meta.table table = model_cls.Meta.table
@ -116,12 +118,12 @@ class QueryClause:
return filter_clauses, select_related return filter_clauses, select_related
def _process_column_clause_for_operator_and_value( def _process_column_clause_for_operator_and_value(
self, self,
value: Any, value: Any,
op: str, op: str,
column: sqlalchemy.Column, column: sqlalchemy.Column,
table: sqlalchemy.Table, table: sqlalchemy.Table,
table_prefix: str, table_prefix: str,
) -> sqlalchemy.sql.expression.TextClause: ) -> sqlalchemy.sql.expression.TextClause:
""" """
Escapes characters if it's required. Escapes characters if it's required.
@ -158,7 +160,7 @@ class QueryClause:
return clause return clause
def _determine_filter_target_table( def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str] self, related_parts: List[str], select_related: List[str], field_name: str
) -> Tuple[List[str], str, Type["Model"]]: ) -> Tuple[List[str], str, Type["Model"]]:
""" """
Adds related strings to select_related list otherwise the clause would fail as Adds related strings to select_related list otherwise the clause would fail as
@ -187,27 +189,34 @@ class QueryClause:
# Walk the relationships to the actual model class # Walk the relationships to the actual model class
# against which the comparison is being made. # against which the comparison is being made.
previous_model = model_cls previous_model = model_cls
for part in related_parts: manager = model_cls.Meta.alias_manager
part2 = part for relation in related_parts:
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField): related_field = model_cls.Meta.model_fields[relation]
through_field = model_cls.Meta.model_fields[part] if issubclass(related_field, ManyToManyField):
previous_model = through_field.through previous_model = related_field.through
part2 = through_field.default_target_field_name() # type: ignore relation = related_field.default_target_field_name() # type: ignore
manager = model_cls.Meta.alias_manager
table_prefix = manager.resolve_relation_alias( table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=part2 from_model=previous_model, relation_name=relation
) )
model_cls = model_cls.Meta.model_fields[part].to model_cls = related_field.to
previous_model = model_cls previous_model = model_cls
# handle duplicated aliases in nested relations
# TODO: check later and remove nocover
complex_prefix = manager.resolve_relation_alias(
from_model=self.model_cls,
relation_name='__'.join([related_str, field_name])
)
if complex_prefix: # pragma: nocover
table_prefix = complex_prefix
return select_related, table_prefix, model_cls return select_related, table_prefix, model_cls
def _compile_clause( def _compile_clause(
self, self,
clause: sqlalchemy.sql.expression.BinaryExpression, clause: sqlalchemy.sql.expression.BinaryExpression,
column: sqlalchemy.Column, column: sqlalchemy.Column,
table: sqlalchemy.Table, table: sqlalchemy.Table,
table_prefix: str, table_prefix: str,
modifiers: Dict, modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause: ) -> sqlalchemy.sql.expression.TextClause:
""" """
Compiles the clause to str using appropriate database dialect, replace columns Compiles the clause to str using appropriate database dialect, replace columns
@ -287,7 +296,7 @@ class QueryClause:
@staticmethod @staticmethod
def _extract_operator_field_and_related( def _extract_operator_field_and_related(
parts: List[str], parts: List[str],
) -> Tuple[str, str, Optional[List]]: ) -> Tuple[str, str, Optional[List]]:
""" """
Splits filter query key and extracts required parts. Splits filter query key and extracts required parts.

View File

@ -24,18 +24,20 @@ if TYPE_CHECKING: # pragma no cover
class SqlJoin: class SqlJoin:
def __init__( # noqa: CFQ002 def __init__( # noqa: CFQ002
self, self,
used_aliases: List, used_aliases: List,
select_from: sqlalchemy.sql.select, select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column], columns: List[sqlalchemy.Column],
fields: Optional[Union[Set, Dict]], fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]], exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List], order_columns: Optional[List],
sorted_orders: OrderedDict, sorted_orders: OrderedDict,
main_model: Type["Model"], main_model: Type["Model"],
relation_name: str, relation_name: str,
related_models: Any = None, relation_str: str,
own_alias: str = "", related_models: Any = None,
own_alias: str = "",
source_model: Type["Model"] = None,
) -> None: ) -> None:
self.relation_name = relation_name self.relation_name = relation_name
self.related_models = related_models or [] self.related_models = related_models or []
@ -53,6 +55,9 @@ class SqlJoin:
self._next_model: Optional[Type["Model"]] = None self._next_model: Optional[Type["Model"]] = None
self._next_alias: Optional[str] = None self._next_alias: Optional[str] = None
self.relation_str = relation_str
self.source_model = source_model
@property @property
def next_model(self) -> Type["Model"]: def next_model(self) -> Type["Model"]:
if not self._next_model: # pragma: nocover if not self._next_model: # pragma: nocover
@ -85,7 +90,8 @@ class SqlJoin:
""" """
return self.main_model.Meta.alias_manager return self.main_model.Meta.alias_manager
def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text: def on_clause(self, previous_alias: str, from_clause: str,
to_clause: str, ) -> text:
""" """
Receives aliases and names of both ends of the join and combines them Receives aliases and names of both ends of the join and combines them
into one text clause used in joins. into one text clause used in joins.
@ -117,11 +123,7 @@ class SqlJoin:
self.process_m2m_through_table() self.process_m2m_through_table()
self.next_model = self.target_field.to self.next_model = self.target_field.to
self.next_alias = self.alias_manager.resolve_relation_alias( self._forward_join()
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self._process_following_joins() self._process_following_joins()
@ -132,6 +134,23 @@ class SqlJoin:
self.sorted_orders, self.sorted_orders,
) )
def _forward_join(self):
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
else:
if '__' in self.relation_str:
relation_key = f'{self.source_model.get_name()}_{self.relation_str}'
if relation_key not in self.alias_manager:
print(f'registering {relation_key}')
self.next_alias = self.alias_manager.add_alias(
alias_key=relation_key)
else:
self.next_alias = self.alias_manager[relation_key]
self._process_join()
def _process_following_joins(self) -> None: def _process_following_joins(self) -> None:
""" """
Iterates through nested models to create subsequent joins. Iterates through nested models to create subsequent joins.
@ -139,8 +158,8 @@ class SqlJoin:
for related_name in self.related_models: for related_name in self.related_models:
remainder = None remainder = None
if ( if (
isinstance(self.related_models, dict) isinstance(self.related_models, dict)
and self.related_models[related_name] and self.related_models[related_name]
): ):
remainder = self.related_models[related_name] remainder = self.related_models[related_name]
self._process_deeper_join(related_name=related_name, remainder=remainder) self._process_deeper_join(related_name=related_name, remainder=remainder)
@ -175,7 +194,9 @@ class SqlJoin:
main_model=self.next_model, main_model=self.next_model,
relation_name=related_name, relation_name=related_name,
related_models=remainder, related_models=remainder,
relation_str='__'.join([self.relation_str, related_name]),
own_alias=self.next_alias, own_alias=self.next_alias,
source_model=self.source_model or self.main_model
) )
( (
self.used_aliases, self.used_aliases,
@ -203,11 +224,8 @@ class SqlJoin:
self._replace_many_to_many_order_by_columns(self.relation_name, new_part) self._replace_many_to_many_order_by_columns(self.relation_name, new_part)
self.next_model = self.target_field.through self.next_model = self.target_field.through
self.next_alias = self.alias_manager.resolve_relation_alias( self._forward_join()
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self.relation_name = new_part self.relation_name = new_part
self.own_alias = self.next_alias self.own_alias = self.next_alias
self.target_field = self.next_model.Meta.model_fields[self.relation_name] self.target_field = self.next_model.Meta.model_fields[self.relation_name]
@ -226,18 +244,18 @@ class SqlJoin:
""" """
target_field = self.target_field target_field = self.target_field
is_primary_self_ref = ( is_primary_self_ref = (
target_field.self_reference target_field.self_reference
and self.relation_name == target_field.self_reference_primary and self.relation_name == target_field.self_reference_primary
) )
if (is_primary_self_ref and not reverse) or ( if (is_primary_self_ref and not reverse) or (
not is_primary_self_ref and reverse not is_primary_self_ref and reverse
): ):
new_part = target_field.default_source_field_name() # type: ignore new_part = target_field.default_source_field_name() # type: ignore
else: else:
new_part = target_field.default_target_field_name() # type: ignore new_part = target_field.default_target_field_name() # type: ignore
return new_part return new_part
def _process_join(self,) -> None: # noqa: CFQ002 def _process_join(self, ) -> None: # noqa: CFQ002
""" """
Resolves to and from column names and table names. Resolves to and from column names and table names.
@ -317,10 +335,10 @@ class SqlJoin:
:rtype: bool :rtype: bool
""" """
return len(condition) >= 2 and ( return len(condition) >= 2 and (
condition[-2] == part or condition[-2][1:] == part condition[-2] == part or condition[-2][1:] == part
) )
def set_aliased_order_by(self, condition: List[str], to_table: str,) -> None: def set_aliased_order_by(self, condition: List[str], to_table: str, ) -> None:
""" """
Substitute hyphens ('-') with descending order. Substitute hyphens ('-') with descending order.
Construct actual sqlalchemy text clause using aliased table and column name. Construct actual sqlalchemy text clause using aliased table and column name.
@ -335,7 +353,7 @@ class SqlJoin:
order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}") order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}")
self.sorted_orders["__".join(condition)] = order self.sorted_orders["__".join(condition)] = order
def get_order_bys(self, to_table: str, pkname_alias: str,) -> None: # noqa: CCR001 def get_order_bys(self, to_table: str, pkname_alias: str, ) -> None: # noqa: CCR001
""" """
Triggers construction of order bys if they are given. Triggers construction of order bys if they are given.
Otherwise by default each table is sorted by a primary key column asc. Otherwise by default each table is sorted by a primary key column asc.

View File

@ -159,6 +159,7 @@ class Query:
sorted_orders=self.sorted_orders, sorted_orders=self.sorted_orders,
main_model=self.model_cls, main_model=self.model_cls,
relation_name=related, relation_name=related,
relation_str=related,
related_models=remainder, related_models=remainder,
) )

View File

@ -228,6 +228,9 @@ class QuerySet:
:return: filtered QuerySet :return: filtered QuerySet
:rtype: QuerySet :rtype: QuerySet
""" """
# TODO: delay processing of filter clauses or switch to group one
# that keeps all aliases even if duplicated - now initialized too late
# in the join
qryclause = QueryClause( qryclause = QueryClause(
model_cls=self.model, model_cls=self.model,
select_related=self._select_related, select_related=self._select_related,

View File

@ -33,6 +33,12 @@ class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._aliases_new: Dict[str, str] = dict() self._aliases_new: Dict[str, str] = dict()
def __contains__(self, item):
return self._aliases_new.__contains__(item)
def __getitem__(self, key):
return self._aliases_new.__getitem__(key)
@staticmethod @staticmethod
def prefixed_columns( def prefixed_columns(
alias: str, table: sqlalchemy.Table, fields: List = None alias: str, table: sqlalchemy.Table, fields: List = None

View File

@ -169,7 +169,7 @@ async def test_other_forwardref_relation(cleanup):
async def test_m2m_self_forwardref_relation(cleanup): async def test_m2m_self_forwardref_relation(cleanup):
async with db: async with db:
async with db.transaction(force_rollback=True): async with db.transaction(force_rollback=True):
checkers = await Game.objects.create(name="checkers") checkers = await Game.objects.create(name="Checkers")
uno = await Game(name="Uno").save() uno = await Game(name="Uno").save()
jenga = await Game(name="Jenga").save() jenga = await Game(name="Jenga").save()
@ -186,15 +186,17 @@ async def test_m2m_self_forwardref_relation(cleanup):
await billy.friends.add(kate) await billy.friends.add(kate)
await billy.friends.add(steve) await billy.friends.add(steve)
# await steve.friends.add(kate)
# await steve.friends.add(billy)
billy_check = await Child.objects.select_related( billy_check = await Child.objects.select_related(
["friends", "favourite_game", "least_favourite_game"] ["friends", "favourite_game", "least_favourite_game",
"friends__favourite_game", "friends__least_favourite_game"]
).get(name="Billy") ).get(name="Billy")
assert len(billy_check.friends) == 2 assert len(billy_check.friends) == 2
assert billy_check.friends[0].name == "Kate" assert billy_check.friends[0].name == "Kate"
assert billy_check.friends[0].favourite_game.name == 'Checkers'
assert billy_check.friends[0].least_favourite_game.name == 'Uno'
assert billy_check.friends[1].name == "Steve" assert billy_check.friends[1].name == "Steve"
assert billy_check.friends[1].favourite_game.name == 'Jenga'
assert billy_check.friends[1].least_favourite_game.name == 'Uno'
assert billy_check.favourite_game.name == "Uno" assert billy_check.favourite_game.name == "Uno"
kate_check = await Child.objects.select_related(["also_friends",]).get( kate_check = await Child.objects.select_related(["also_friends",]).get(
@ -203,3 +205,13 @@ async def test_m2m_self_forwardref_relation(cleanup):
assert len(kate_check.also_friends) == 1 assert len(kate_check.also_friends) == 1
assert kate_check.also_friends[0].name == "Billy" assert kate_check.also_friends[0].name == "Billy"
# TODO: Fix filters with complex prefixes
# billy_check = await Child.objects.select_related(
# ["friends", "favourite_game", "least_favourite_game",
# "friends__favourite_game", "friends__least_favourite_game"]
# ).filter(friends__favourite_game__name="Checkers").get(name="Billy")
# assert len(billy_check.friends) == 1
# assert billy_check.friends[0].name == "Kate"
# assert billy_check.friends[0].favourite_game.name == 'Checkers'
# assert billy_check.friends[0].least_favourite_game.name == 'Uno'

View File

@ -0,0 +1,9 @@
from ormar.models.helpers.models import group_related_list
def test_group_related_list():
given = ['friends__least_favourite_game', 'least_favourite_game', 'friends',
'favourite_game', 'friends__favourite_game']
expected = {'least_favourite_game': [], 'favourite_game': [],
'friends': ['favourite_game', 'least_favourite_game']}
assert group_related_list(given) == expected