refactors in join to register complex aliases on duplicate, to do is doing the same in filter clauses

This commit is contained in:
collerek
2021-01-17 12:29:21 +01:00
parent 28cc847b57
commit d6e2c85b79
9 changed files with 137 additions and 74 deletions

View File

@ -22,12 +22,12 @@ def is_field_an_forward_ref(field: Type["BaseField"]) -> bool:
:rtype: bool
"""
return issubclass(field, ForeignKeyField) and (
field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef
field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef
)
def populate_default_options_values(
new_model: Type["Model"], model_fields: Dict
new_model: Type["Model"], model_fields: Dict
) -> None:
"""
Sets all optional Meta values to it's defaults
@ -52,7 +52,8 @@ def populate_default_options_values(
new_model.Meta.abstract = False
if any(
is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values()
is_field_an_forward_ref(field) for field in
new_model.Meta.model_fields.values()
):
new_model.Meta.requires_ref_update = True
else:
@ -77,7 +78,7 @@ def extract_annotations_and_default_vals(attrs: Dict) -> Tuple[Dict, Dict]:
# cannot be in relations helpers due to cyclical import
def validate_related_names_in_relations( # noqa CCR001
model_fields: Dict, new_model: Type["Model"]
model_fields: Dict, new_model: Type["Model"]
) -> None:
"""
Performs a validation of relation_names in relation fields.
@ -122,20 +123,24 @@ def group_related_list(list_: List) -> Dict:
will become:
{'people': {'houses': [], 'cars': ['models', 'colors']}}
Result dictionary is sorted by length of the values and by key
:param list_: list of related models used in select related
:type list_: List[str]
:return: list converted to dictionary to avoid repetition and group nested models
:rtype: Dict[str, List]
"""
test_dict: Dict[str, Any] = dict()
result_dict: Dict[str, Any] = dict()
list_.sort(key=lambda x: x.split("__")[0])
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped:
group_list = list(group)
new = [
new = sorted([
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
]
])
if any("__" in x for x in new):
test_dict[key] = group_related_list(new)
result_dict[key] = group_related_list(new)
else:
test_dict[key] = new
return test_dict
result_dict.setdefault(key, []).extend(new)
return {k: v for k, v in
sorted(result_dict.items(), key=lambda item: len(item[1]))}

View File

@ -34,7 +34,7 @@ class QueryClause:
"""
def __init__(
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
) -> None:
self._select_related = select_related[:]
@ -44,7 +44,7 @@ class QueryClause:
self.table = self.model_cls.Meta.table
def filter( # noqa: A003
self, **kwargs: Any
self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
"""
Main external access point that processes the clauses into sqlalchemy text
@ -65,7 +65,7 @@ class QueryClause:
return filter_clauses, select_related
def _populate_filter_clauses(
self, **kwargs: Any
self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
"""
Iterates all clauses and extracts used operator and field from related
@ -98,7 +98,9 @@ class QueryClause:
table_prefix,
model_cls,
) = self._determine_filter_target_table(
related_parts, select_related
related_parts=related_parts,
select_related=select_related,
field_name=field_name
)
table = model_cls.Meta.table
@ -116,12 +118,12 @@ class QueryClause:
return filter_clauses, select_related
def _process_column_clause_for_operator_and_value(
self,
value: Any,
op: str,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
self,
value: Any,
op: str,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
) -> sqlalchemy.sql.expression.TextClause:
"""
Escapes characters if it's required.
@ -158,7 +160,7 @@ class QueryClause:
return clause
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str]
self, related_parts: List[str], select_related: List[str], field_name: str
) -> Tuple[List[str], str, Type["Model"]]:
"""
Adds related strings to select_related list otherwise the clause would fail as
@ -187,27 +189,34 @@ class QueryClause:
# Walk the relationships to the actual model class
# against which the comparison is being made.
previous_model = model_cls
for part in related_parts:
part2 = part
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
through_field = model_cls.Meta.model_fields[part]
previous_model = through_field.through
part2 = through_field.default_target_field_name() # type: ignore
manager = model_cls.Meta.alias_manager
manager = model_cls.Meta.alias_manager
for relation in related_parts:
related_field = model_cls.Meta.model_fields[relation]
if issubclass(related_field, ManyToManyField):
previous_model = related_field.through
relation = related_field.default_target_field_name() # type: ignore
table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=part2
from_model=previous_model, relation_name=relation
)
model_cls = model_cls.Meta.model_fields[part].to
model_cls = related_field.to
previous_model = model_cls
# handle duplicated aliases in nested relations
# TODO: check later and remove nocover
complex_prefix = manager.resolve_relation_alias(
from_model=self.model_cls,
relation_name='__'.join([related_str, field_name])
)
if complex_prefix: # pragma: nocover
table_prefix = complex_prefix
return select_related, table_prefix, model_cls
def _compile_clause(
self,
clause: sqlalchemy.sql.expression.BinaryExpression,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
modifiers: Dict,
self,
clause: sqlalchemy.sql.expression.BinaryExpression,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause:
"""
Compiles the clause to str using appropriate database dialect, replace columns
@ -287,7 +296,7 @@ class QueryClause:
@staticmethod
def _extract_operator_field_and_related(
parts: List[str],
parts: List[str],
) -> Tuple[str, str, Optional[List]]:
"""
Splits filter query key and extracts required parts.

View File

@ -24,18 +24,20 @@ if TYPE_CHECKING: # pragma no cover
class SqlJoin:
def __init__( # noqa: CFQ002
self,
used_aliases: List,
select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List],
sorted_orders: OrderedDict,
main_model: Type["Model"],
relation_name: str,
related_models: Any = None,
own_alias: str = "",
self,
used_aliases: List,
select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List],
sorted_orders: OrderedDict,
main_model: Type["Model"],
relation_name: str,
relation_str: str,
related_models: Any = None,
own_alias: str = "",
source_model: Type["Model"] = None,
) -> None:
self.relation_name = relation_name
self.related_models = related_models or []
@ -53,6 +55,9 @@ class SqlJoin:
self._next_model: Optional[Type["Model"]] = None
self._next_alias: Optional[str] = None
self.relation_str = relation_str
self.source_model = source_model
@property
def next_model(self) -> Type["Model"]:
if not self._next_model: # pragma: nocover
@ -85,7 +90,8 @@ class SqlJoin:
"""
return self.main_model.Meta.alias_manager
def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text:
def on_clause(self, previous_alias: str, from_clause: str,
to_clause: str, ) -> text:
"""
Receives aliases and names of both ends of the join and combines them
into one text clause used in joins.
@ -117,11 +123,7 @@ class SqlJoin:
self.process_m2m_through_table()
self.next_model = self.target_field.to
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self._forward_join()
self._process_following_joins()
@ -132,6 +134,23 @@ class SqlJoin:
self.sorted_orders,
)
def _forward_join(self):
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
else:
if '__' in self.relation_str:
relation_key = f'{self.source_model.get_name()}_{self.relation_str}'
if relation_key not in self.alias_manager:
print(f'registering {relation_key}')
self.next_alias = self.alias_manager.add_alias(
alias_key=relation_key)
else:
self.next_alias = self.alias_manager[relation_key]
self._process_join()
def _process_following_joins(self) -> None:
"""
Iterates through nested models to create subsequent joins.
@ -139,8 +158,8 @@ class SqlJoin:
for related_name in self.related_models:
remainder = None
if (
isinstance(self.related_models, dict)
and self.related_models[related_name]
isinstance(self.related_models, dict)
and self.related_models[related_name]
):
remainder = self.related_models[related_name]
self._process_deeper_join(related_name=related_name, remainder=remainder)
@ -175,7 +194,9 @@ class SqlJoin:
main_model=self.next_model,
relation_name=related_name,
related_models=remainder,
relation_str='__'.join([self.relation_str, related_name]),
own_alias=self.next_alias,
source_model=self.source_model or self.main_model
)
(
self.used_aliases,
@ -203,11 +224,8 @@ class SqlJoin:
self._replace_many_to_many_order_by_columns(self.relation_name, new_part)
self.next_model = self.target_field.through
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self._forward_join()
self.relation_name = new_part
self.own_alias = self.next_alias
self.target_field = self.next_model.Meta.model_fields[self.relation_name]
@ -226,18 +244,18 @@ class SqlJoin:
"""
target_field = self.target_field
is_primary_self_ref = (
target_field.self_reference
and self.relation_name == target_field.self_reference_primary
target_field.self_reference
and self.relation_name == target_field.self_reference_primary
)
if (is_primary_self_ref and not reverse) or (
not is_primary_self_ref and reverse
not is_primary_self_ref and reverse
):
new_part = target_field.default_source_field_name() # type: ignore
else:
new_part = target_field.default_target_field_name() # type: ignore
return new_part
def _process_join(self,) -> None: # noqa: CFQ002
def _process_join(self, ) -> None: # noqa: CFQ002
"""
Resolves to and from column names and table names.
@ -317,10 +335,10 @@ class SqlJoin:
:rtype: bool
"""
return len(condition) >= 2 and (
condition[-2] == part or condition[-2][1:] == part
condition[-2] == part or condition[-2][1:] == part
)
def set_aliased_order_by(self, condition: List[str], to_table: str,) -> None:
def set_aliased_order_by(self, condition: List[str], to_table: str, ) -> None:
"""
Substitute hyphens ('-') with descending order.
Construct actual sqlalchemy text clause using aliased table and column name.
@ -335,7 +353,7 @@ class SqlJoin:
order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}")
self.sorted_orders["__".join(condition)] = order
def get_order_bys(self, to_table: str, pkname_alias: str,) -> None: # noqa: CCR001
def get_order_bys(self, to_table: str, pkname_alias: str, ) -> None: # noqa: CCR001
"""
Triggers construction of order bys if they are given.
Otherwise by default each table is sorted by a primary key column asc.

View File

@ -159,6 +159,7 @@ class Query:
sorted_orders=self.sorted_orders,
main_model=self.model_cls,
relation_name=related,
relation_str=related,
related_models=remainder,
)

View File

@ -228,6 +228,9 @@ class QuerySet:
:return: filtered QuerySet
:rtype: QuerySet
"""
# TODO: delay processing of filter clauses or switch to group one
# that keeps all aliases even if duplicated - now initialized too late
# in the join
qryclause = QueryClause(
model_cls=self.model,
select_related=self._select_related,

View File

@ -33,6 +33,12 @@ class AliasManager:
def __init__(self) -> None:
self._aliases_new: Dict[str, str] = dict()
def __contains__(self, item):
return self._aliases_new.__contains__(item)
def __getitem__(self, key):
return self._aliases_new.__getitem__(key)
@staticmethod
def prefixed_columns(
alias: str, table: sqlalchemy.Table, fields: List = None

View File

@ -169,7 +169,7 @@ async def test_other_forwardref_relation(cleanup):
async def test_m2m_self_forwardref_relation(cleanup):
async with db:
async with db.transaction(force_rollback=True):
checkers = await Game.objects.create(name="checkers")
checkers = await Game.objects.create(name="Checkers")
uno = await Game(name="Uno").save()
jenga = await Game(name="Jenga").save()
@ -186,15 +186,17 @@ async def test_m2m_self_forwardref_relation(cleanup):
await billy.friends.add(kate)
await billy.friends.add(steve)
# await steve.friends.add(kate)
# await steve.friends.add(billy)
billy_check = await Child.objects.select_related(
["friends", "favourite_game", "least_favourite_game"]
["friends", "favourite_game", "least_favourite_game",
"friends__favourite_game", "friends__least_favourite_game"]
).get(name="Billy")
assert len(billy_check.friends) == 2
assert billy_check.friends[0].name == "Kate"
assert billy_check.friends[0].favourite_game.name == 'Checkers'
assert billy_check.friends[0].least_favourite_game.name == 'Uno'
assert billy_check.friends[1].name == "Steve"
assert billy_check.friends[1].favourite_game.name == 'Jenga'
assert billy_check.friends[1].least_favourite_game.name == 'Uno'
assert billy_check.favourite_game.name == "Uno"
kate_check = await Child.objects.select_related(["also_friends",]).get(
@ -203,3 +205,13 @@ async def test_m2m_self_forwardref_relation(cleanup):
assert len(kate_check.also_friends) == 1
assert kate_check.also_friends[0].name == "Billy"
# TODO: Fix filters with complex prefixes
# billy_check = await Child.objects.select_related(
# ["friends", "favourite_game", "least_favourite_game",
# "friends__favourite_game", "friends__least_favourite_game"]
# ).filter(friends__favourite_game__name="Checkers").get(name="Billy")
# assert len(billy_check.friends) == 1
# assert billy_check.friends[0].name == "Kate"
# assert billy_check.friends[0].favourite_game.name == 'Checkers'
# assert billy_check.friends[0].least_favourite_game.name == 'Uno'

View File

@ -0,0 +1,9 @@
from ormar.models.helpers.models import group_related_list
def test_group_related_list():
given = ['friends__least_favourite_game', 'least_favourite_game', 'friends',
'favourite_game', 'friends__favourite_game']
expected = {'least_favourite_game': [], 'favourite_game': [],
'friends': ['favourite_game', 'least_favourite_game']}
assert group_related_list(given) == expected