refactors in join to register complex aliases on duplicate, to do is doing the same in filter clauses
This commit is contained in:
@ -22,12 +22,12 @@ def is_field_an_forward_ref(field: Type["BaseField"]) -> bool:
|
||||
:rtype: bool
|
||||
"""
|
||||
return issubclass(field, ForeignKeyField) and (
|
||||
field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef
|
||||
field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef
|
||||
)
|
||||
|
||||
|
||||
def populate_default_options_values(
|
||||
new_model: Type["Model"], model_fields: Dict
|
||||
new_model: Type["Model"], model_fields: Dict
|
||||
) -> None:
|
||||
"""
|
||||
Sets all optional Meta values to it's defaults
|
||||
@ -52,7 +52,8 @@ def populate_default_options_values(
|
||||
new_model.Meta.abstract = False
|
||||
|
||||
if any(
|
||||
is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values()
|
||||
is_field_an_forward_ref(field) for field in
|
||||
new_model.Meta.model_fields.values()
|
||||
):
|
||||
new_model.Meta.requires_ref_update = True
|
||||
else:
|
||||
@ -77,7 +78,7 @@ def extract_annotations_and_default_vals(attrs: Dict) -> Tuple[Dict, Dict]:
|
||||
|
||||
# cannot be in relations helpers due to cyclical import
|
||||
def validate_related_names_in_relations( # noqa CCR001
|
||||
model_fields: Dict, new_model: Type["Model"]
|
||||
model_fields: Dict, new_model: Type["Model"]
|
||||
) -> None:
|
||||
"""
|
||||
Performs a validation of relation_names in relation fields.
|
||||
@ -122,20 +123,24 @@ def group_related_list(list_: List) -> Dict:
|
||||
will become:
|
||||
{'people': {'houses': [], 'cars': ['models', 'colors']}}
|
||||
|
||||
Result dictionary is sorted by length of the values and by key
|
||||
|
||||
:param list_: list of related models used in select related
|
||||
:type list_: List[str]
|
||||
:return: list converted to dictionary to avoid repetition and group nested models
|
||||
:rtype: Dict[str, List]
|
||||
"""
|
||||
test_dict: Dict[str, Any] = dict()
|
||||
result_dict: Dict[str, Any] = dict()
|
||||
list_.sort(key=lambda x: x.split("__")[0])
|
||||
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
|
||||
for key, group in grouped:
|
||||
group_list = list(group)
|
||||
new = [
|
||||
new = sorted([
|
||||
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
|
||||
]
|
||||
])
|
||||
if any("__" in x for x in new):
|
||||
test_dict[key] = group_related_list(new)
|
||||
result_dict[key] = group_related_list(new)
|
||||
else:
|
||||
test_dict[key] = new
|
||||
return test_dict
|
||||
result_dict.setdefault(key, []).extend(new)
|
||||
return {k: v for k, v in
|
||||
sorted(result_dict.items(), key=lambda item: len(item[1]))}
|
||||
|
||||
@ -34,7 +34,7 @@ class QueryClause:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
|
||||
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
|
||||
) -> None:
|
||||
|
||||
self._select_related = select_related[:]
|
||||
@ -44,7 +44,7 @@ class QueryClause:
|
||||
self.table = self.model_cls.Meta.table
|
||||
|
||||
def filter( # noqa: A003
|
||||
self, **kwargs: Any
|
||||
self, **kwargs: Any
|
||||
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
||||
"""
|
||||
Main external access point that processes the clauses into sqlalchemy text
|
||||
@ -65,7 +65,7 @@ class QueryClause:
|
||||
return filter_clauses, select_related
|
||||
|
||||
def _populate_filter_clauses(
|
||||
self, **kwargs: Any
|
||||
self, **kwargs: Any
|
||||
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
||||
"""
|
||||
Iterates all clauses and extracts used operator and field from related
|
||||
@ -98,7 +98,9 @@ class QueryClause:
|
||||
table_prefix,
|
||||
model_cls,
|
||||
) = self._determine_filter_target_table(
|
||||
related_parts, select_related
|
||||
related_parts=related_parts,
|
||||
select_related=select_related,
|
||||
field_name=field_name
|
||||
)
|
||||
|
||||
table = model_cls.Meta.table
|
||||
@ -116,12 +118,12 @@ class QueryClause:
|
||||
return filter_clauses, select_related
|
||||
|
||||
def _process_column_clause_for_operator_and_value(
|
||||
self,
|
||||
value: Any,
|
||||
op: str,
|
||||
column: sqlalchemy.Column,
|
||||
table: sqlalchemy.Table,
|
||||
table_prefix: str,
|
||||
self,
|
||||
value: Any,
|
||||
op: str,
|
||||
column: sqlalchemy.Column,
|
||||
table: sqlalchemy.Table,
|
||||
table_prefix: str,
|
||||
) -> sqlalchemy.sql.expression.TextClause:
|
||||
"""
|
||||
Escapes characters if it's required.
|
||||
@ -158,7 +160,7 @@ class QueryClause:
|
||||
return clause
|
||||
|
||||
def _determine_filter_target_table(
|
||||
self, related_parts: List[str], select_related: List[str]
|
||||
self, related_parts: List[str], select_related: List[str], field_name: str
|
||||
) -> Tuple[List[str], str, Type["Model"]]:
|
||||
"""
|
||||
Adds related strings to select_related list otherwise the clause would fail as
|
||||
@ -187,27 +189,34 @@ class QueryClause:
|
||||
# Walk the relationships to the actual model class
|
||||
# against which the comparison is being made.
|
||||
previous_model = model_cls
|
||||
for part in related_parts:
|
||||
part2 = part
|
||||
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
|
||||
through_field = model_cls.Meta.model_fields[part]
|
||||
previous_model = through_field.through
|
||||
part2 = through_field.default_target_field_name() # type: ignore
|
||||
manager = model_cls.Meta.alias_manager
|
||||
manager = model_cls.Meta.alias_manager
|
||||
for relation in related_parts:
|
||||
related_field = model_cls.Meta.model_fields[relation]
|
||||
if issubclass(related_field, ManyToManyField):
|
||||
previous_model = related_field.through
|
||||
relation = related_field.default_target_field_name() # type: ignore
|
||||
table_prefix = manager.resolve_relation_alias(
|
||||
from_model=previous_model, relation_name=part2
|
||||
from_model=previous_model, relation_name=relation
|
||||
)
|
||||
model_cls = model_cls.Meta.model_fields[part].to
|
||||
model_cls = related_field.to
|
||||
previous_model = model_cls
|
||||
# handle duplicated aliases in nested relations
|
||||
# TODO: check later and remove nocover
|
||||
complex_prefix = manager.resolve_relation_alias(
|
||||
from_model=self.model_cls,
|
||||
relation_name='__'.join([related_str, field_name])
|
||||
)
|
||||
if complex_prefix: # pragma: nocover
|
||||
table_prefix = complex_prefix
|
||||
return select_related, table_prefix, model_cls
|
||||
|
||||
def _compile_clause(
|
||||
self,
|
||||
clause: sqlalchemy.sql.expression.BinaryExpression,
|
||||
column: sqlalchemy.Column,
|
||||
table: sqlalchemy.Table,
|
||||
table_prefix: str,
|
||||
modifiers: Dict,
|
||||
self,
|
||||
clause: sqlalchemy.sql.expression.BinaryExpression,
|
||||
column: sqlalchemy.Column,
|
||||
table: sqlalchemy.Table,
|
||||
table_prefix: str,
|
||||
modifiers: Dict,
|
||||
) -> sqlalchemy.sql.expression.TextClause:
|
||||
"""
|
||||
Compiles the clause to str using appropriate database dialect, replace columns
|
||||
@ -287,7 +296,7 @@ class QueryClause:
|
||||
|
||||
@staticmethod
|
||||
def _extract_operator_field_and_related(
|
||||
parts: List[str],
|
||||
parts: List[str],
|
||||
) -> Tuple[str, str, Optional[List]]:
|
||||
"""
|
||||
Splits filter query key and extracts required parts.
|
||||
|
||||
@ -24,18 +24,20 @@ if TYPE_CHECKING: # pragma no cover
|
||||
|
||||
class SqlJoin:
|
||||
def __init__( # noqa: CFQ002
|
||||
self,
|
||||
used_aliases: List,
|
||||
select_from: sqlalchemy.sql.select,
|
||||
columns: List[sqlalchemy.Column],
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
order_columns: Optional[List],
|
||||
sorted_orders: OrderedDict,
|
||||
main_model: Type["Model"],
|
||||
relation_name: str,
|
||||
related_models: Any = None,
|
||||
own_alias: str = "",
|
||||
self,
|
||||
used_aliases: List,
|
||||
select_from: sqlalchemy.sql.select,
|
||||
columns: List[sqlalchemy.Column],
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
order_columns: Optional[List],
|
||||
sorted_orders: OrderedDict,
|
||||
main_model: Type["Model"],
|
||||
relation_name: str,
|
||||
relation_str: str,
|
||||
related_models: Any = None,
|
||||
own_alias: str = "",
|
||||
source_model: Type["Model"] = None,
|
||||
) -> None:
|
||||
self.relation_name = relation_name
|
||||
self.related_models = related_models or []
|
||||
@ -53,6 +55,9 @@ class SqlJoin:
|
||||
self._next_model: Optional[Type["Model"]] = None
|
||||
self._next_alias: Optional[str] = None
|
||||
|
||||
self.relation_str = relation_str
|
||||
self.source_model = source_model
|
||||
|
||||
@property
|
||||
def next_model(self) -> Type["Model"]:
|
||||
if not self._next_model: # pragma: nocover
|
||||
@ -85,7 +90,8 @@ class SqlJoin:
|
||||
"""
|
||||
return self.main_model.Meta.alias_manager
|
||||
|
||||
def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text:
|
||||
def on_clause(self, previous_alias: str, from_clause: str,
|
||||
to_clause: str, ) -> text:
|
||||
"""
|
||||
Receives aliases and names of both ends of the join and combines them
|
||||
into one text clause used in joins.
|
||||
@ -117,11 +123,7 @@ class SqlJoin:
|
||||
self.process_m2m_through_table()
|
||||
|
||||
self.next_model = self.target_field.to
|
||||
self.next_alias = self.alias_manager.resolve_relation_alias(
|
||||
from_model=self.target_field.owner, relation_name=self.relation_name
|
||||
)
|
||||
if self.next_alias not in self.used_aliases:
|
||||
self._process_join()
|
||||
self._forward_join()
|
||||
|
||||
self._process_following_joins()
|
||||
|
||||
@ -132,6 +134,23 @@ class SqlJoin:
|
||||
self.sorted_orders,
|
||||
)
|
||||
|
||||
def _forward_join(self):
|
||||
self.next_alias = self.alias_manager.resolve_relation_alias(
|
||||
from_model=self.target_field.owner, relation_name=self.relation_name
|
||||
)
|
||||
if self.next_alias not in self.used_aliases:
|
||||
self._process_join()
|
||||
else:
|
||||
if '__' in self.relation_str:
|
||||
relation_key = f'{self.source_model.get_name()}_{self.relation_str}'
|
||||
if relation_key not in self.alias_manager:
|
||||
print(f'registering {relation_key}')
|
||||
self.next_alias = self.alias_manager.add_alias(
|
||||
alias_key=relation_key)
|
||||
else:
|
||||
self.next_alias = self.alias_manager[relation_key]
|
||||
self._process_join()
|
||||
|
||||
def _process_following_joins(self) -> None:
|
||||
"""
|
||||
Iterates through nested models to create subsequent joins.
|
||||
@ -139,8 +158,8 @@ class SqlJoin:
|
||||
for related_name in self.related_models:
|
||||
remainder = None
|
||||
if (
|
||||
isinstance(self.related_models, dict)
|
||||
and self.related_models[related_name]
|
||||
isinstance(self.related_models, dict)
|
||||
and self.related_models[related_name]
|
||||
):
|
||||
remainder = self.related_models[related_name]
|
||||
self._process_deeper_join(related_name=related_name, remainder=remainder)
|
||||
@ -175,7 +194,9 @@ class SqlJoin:
|
||||
main_model=self.next_model,
|
||||
relation_name=related_name,
|
||||
related_models=remainder,
|
||||
relation_str='__'.join([self.relation_str, related_name]),
|
||||
own_alias=self.next_alias,
|
||||
source_model=self.source_model or self.main_model
|
||||
)
|
||||
(
|
||||
self.used_aliases,
|
||||
@ -203,11 +224,8 @@ class SqlJoin:
|
||||
self._replace_many_to_many_order_by_columns(self.relation_name, new_part)
|
||||
|
||||
self.next_model = self.target_field.through
|
||||
self.next_alias = self.alias_manager.resolve_relation_alias(
|
||||
from_model=self.target_field.owner, relation_name=self.relation_name
|
||||
)
|
||||
if self.next_alias not in self.used_aliases:
|
||||
self._process_join()
|
||||
self._forward_join()
|
||||
|
||||
self.relation_name = new_part
|
||||
self.own_alias = self.next_alias
|
||||
self.target_field = self.next_model.Meta.model_fields[self.relation_name]
|
||||
@ -226,18 +244,18 @@ class SqlJoin:
|
||||
"""
|
||||
target_field = self.target_field
|
||||
is_primary_self_ref = (
|
||||
target_field.self_reference
|
||||
and self.relation_name == target_field.self_reference_primary
|
||||
target_field.self_reference
|
||||
and self.relation_name == target_field.self_reference_primary
|
||||
)
|
||||
if (is_primary_self_ref and not reverse) or (
|
||||
not is_primary_self_ref and reverse
|
||||
not is_primary_self_ref and reverse
|
||||
):
|
||||
new_part = target_field.default_source_field_name() # type: ignore
|
||||
else:
|
||||
new_part = target_field.default_target_field_name() # type: ignore
|
||||
return new_part
|
||||
|
||||
def _process_join(self,) -> None: # noqa: CFQ002
|
||||
def _process_join(self, ) -> None: # noqa: CFQ002
|
||||
"""
|
||||
Resolves to and from column names and table names.
|
||||
|
||||
@ -317,10 +335,10 @@ class SqlJoin:
|
||||
:rtype: bool
|
||||
"""
|
||||
return len(condition) >= 2 and (
|
||||
condition[-2] == part or condition[-2][1:] == part
|
||||
condition[-2] == part or condition[-2][1:] == part
|
||||
)
|
||||
|
||||
def set_aliased_order_by(self, condition: List[str], to_table: str,) -> None:
|
||||
def set_aliased_order_by(self, condition: List[str], to_table: str, ) -> None:
|
||||
"""
|
||||
Substitute hyphens ('-') with descending order.
|
||||
Construct actual sqlalchemy text clause using aliased table and column name.
|
||||
@ -335,7 +353,7 @@ class SqlJoin:
|
||||
order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}")
|
||||
self.sorted_orders["__".join(condition)] = order
|
||||
|
||||
def get_order_bys(self, to_table: str, pkname_alias: str,) -> None: # noqa: CCR001
|
||||
def get_order_bys(self, to_table: str, pkname_alias: str, ) -> None: # noqa: CCR001
|
||||
"""
|
||||
Triggers construction of order bys if they are given.
|
||||
Otherwise by default each table is sorted by a primary key column asc.
|
||||
|
||||
@ -159,6 +159,7 @@ class Query:
|
||||
sorted_orders=self.sorted_orders,
|
||||
main_model=self.model_cls,
|
||||
relation_name=related,
|
||||
relation_str=related,
|
||||
related_models=remainder,
|
||||
)
|
||||
|
||||
|
||||
@ -228,6 +228,9 @@ class QuerySet:
|
||||
:return: filtered QuerySet
|
||||
:rtype: QuerySet
|
||||
"""
|
||||
# TODO: delay processing of filter clauses or switch to group one
|
||||
# that keeps all aliases even if duplicated - now initialized too late
|
||||
# in the join
|
||||
qryclause = QueryClause(
|
||||
model_cls=self.model,
|
||||
select_related=self._select_related,
|
||||
|
||||
@ -33,6 +33,12 @@ class AliasManager:
|
||||
def __init__(self) -> None:
|
||||
self._aliases_new: Dict[str, str] = dict()
|
||||
|
||||
def __contains__(self, item):
|
||||
return self._aliases_new.__contains__(item)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._aliases_new.__getitem__(key)
|
||||
|
||||
@staticmethod
|
||||
def prefixed_columns(
|
||||
alias: str, table: sqlalchemy.Table, fields: List = None
|
||||
|
||||
@ -169,7 +169,7 @@ async def test_other_forwardref_relation(cleanup):
|
||||
async def test_m2m_self_forwardref_relation(cleanup):
|
||||
async with db:
|
||||
async with db.transaction(force_rollback=True):
|
||||
checkers = await Game.objects.create(name="checkers")
|
||||
checkers = await Game.objects.create(name="Checkers")
|
||||
uno = await Game(name="Uno").save()
|
||||
jenga = await Game(name="Jenga").save()
|
||||
|
||||
@ -186,15 +186,17 @@ async def test_m2m_self_forwardref_relation(cleanup):
|
||||
await billy.friends.add(kate)
|
||||
await billy.friends.add(steve)
|
||||
|
||||
# await steve.friends.add(kate)
|
||||
# await steve.friends.add(billy)
|
||||
|
||||
billy_check = await Child.objects.select_related(
|
||||
["friends", "favourite_game", "least_favourite_game"]
|
||||
["friends", "favourite_game", "least_favourite_game",
|
||||
"friends__favourite_game", "friends__least_favourite_game"]
|
||||
).get(name="Billy")
|
||||
assert len(billy_check.friends) == 2
|
||||
assert billy_check.friends[0].name == "Kate"
|
||||
assert billy_check.friends[0].favourite_game.name == 'Checkers'
|
||||
assert billy_check.friends[0].least_favourite_game.name == 'Uno'
|
||||
assert billy_check.friends[1].name == "Steve"
|
||||
assert billy_check.friends[1].favourite_game.name == 'Jenga'
|
||||
assert billy_check.friends[1].least_favourite_game.name == 'Uno'
|
||||
assert billy_check.favourite_game.name == "Uno"
|
||||
|
||||
kate_check = await Child.objects.select_related(["also_friends",]).get(
|
||||
@ -203,3 +205,13 @@ async def test_m2m_self_forwardref_relation(cleanup):
|
||||
|
||||
assert len(kate_check.also_friends) == 1
|
||||
assert kate_check.also_friends[0].name == "Billy"
|
||||
|
||||
# TODO: Fix filters with complex prefixes
|
||||
# billy_check = await Child.objects.select_related(
|
||||
# ["friends", "favourite_game", "least_favourite_game",
|
||||
# "friends__favourite_game", "friends__least_favourite_game"]
|
||||
# ).filter(friends__favourite_game__name="Checkers").get(name="Billy")
|
||||
# assert len(billy_check.friends) == 1
|
||||
# assert billy_check.friends[0].name == "Kate"
|
||||
# assert billy_check.friends[0].favourite_game.name == 'Checkers'
|
||||
# assert billy_check.friends[0].least_favourite_game.name == 'Uno'
|
||||
|
||||
9
tests/test_models_helpers.py
Normal file
9
tests/test_models_helpers.py
Normal file
@ -0,0 +1,9 @@
|
||||
from ormar.models.helpers.models import group_related_list
|
||||
|
||||
|
||||
def test_group_related_list():
|
||||
given = ['friends__least_favourite_game', 'least_favourite_game', 'friends',
|
||||
'favourite_game', 'friends__favourite_game']
|
||||
expected = {'least_favourite_game': [], 'favourite_game': [],
|
||||
'friends': ['favourite_game', 'least_favourite_game']}
|
||||
assert group_related_list(given) == expected
|
||||
Reference in New Issue
Block a user