extract filters into filter actions and delay their processing time to allow for registration of complex relations, refactoring and optimization, now one join with relations with same aliases are possible

This commit is contained in:
collerek
2021-01-21 15:55:23 +01:00
parent d6e2c85b79
commit a2834666fc
13 changed files with 425 additions and 325 deletions

View File

@ -52,8 +52,7 @@ def populate_default_options_values(
new_model.Meta.abstract = False new_model.Meta.abstract = False
if any( if any(
is_field_an_forward_ref(field) for field in is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values()
new_model.Meta.model_fields.values()
): ):
new_model.Meta.requires_ref_update = True new_model.Meta.requires_ref_update = True
else: else:
@ -135,12 +134,11 @@ def group_related_list(list_: List) -> Dict:
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0]) grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped: for key, group in grouped:
group_list = list(group) group_list = list(group)
new = sorted([ new = sorted(
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1 ["__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1]
]) )
if any("__" in x for x in new): if any("__" in x for x in new):
result_dict[key] = group_related_list(new) result_dict[key] = group_related_list(new)
else: else:
result_dict.setdefault(key, []).extend(new) result_dict.setdefault(key, []).extend(new)
return {k: v for k, v in return {k: v for k, v in sorted(result_dict.items(), key=lambda item: len(item[1]))}
sorted(result_dict.items(), key=lambda item: len(item[1]))}

View File

@ -1,36 +1,31 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type import itertools
from dataclasses import dataclass
import sqlalchemy from typing import Any, List, TYPE_CHECKING, Tuple, Type
from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar.exceptions import QueryDefinitionError from ormar.queryset.filter_action import FilterAction
from ormar.fields.many_to_many import ManyToManyField from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
FILTER_OPERATORS = {
"exact": "__eq__", @dataclass
"iexact": "ilike", class Prefix:
"contains": "like", source_model: Type["Model"]
"icontains": "ilike", table_prefix: str
"startswith": "like", model_cls: Type["Model"]
"istartswith": "ilike", relation_str: str
"endswith": "like",
"iendswith": "ilike", @property
"in": "in_", def alias_key(self) -> str:
"gt": "__gt__", source_model_name = self.source_model.get_name()
"gte": "__ge__", return f"{source_model_name}_" f"{self.relation_str}"
"lt": "__lt__",
"lte": "__le__",
}
ESCAPE_CHARACTERS = ["%", "_"]
class QueryClause: class QueryClause:
""" """
Constructs where clauses from strings passed as arguments Constructs FilterActions from strings passed as arguments
""" """
def __init__( def __init__(
@ -43,9 +38,9 @@ class QueryClause:
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.Meta.table self.table = self.model_cls.Meta.table
def filter( # noqa: A003 def prepare_filter( # noqa: A003
self, **kwargs: Any self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: ) -> Tuple[List[FilterAction], List[str]]:
""" """
Main external access point that processes the clauses into sqlalchemy text Main external access point that processes the clauses into sqlalchemy text
clauses and updates select_related list with implicit related tables clauses and updates select_related list with implicit related tables
@ -66,7 +61,7 @@ class QueryClause:
def _populate_filter_clauses( def _populate_filter_clauses(
self, **kwargs: Any self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: ) -> Tuple[List[FilterAction], List[str]]:
""" """
Iterates all clauses and extracts used operator and field from related Iterates all clauses and extracts used operator and field from related
models if needed. Based on the chain of related names the target table models if needed. Based on the chain of related names the target table
@ -81,238 +76,84 @@ class QueryClause:
select_related = list(self._select_related) select_related = list(self._select_related)
for key, value in kwargs.items(): for key, value in kwargs.items():
table_prefix = "" filter_action = FilterAction(
if "__" in key: filter_str=key, value=value, model_cls=self.model_cls
parts = key.split("__") )
select_related = filter_action.update_select_related(
( select_related=select_related
op,
field_name,
related_parts,
) = self._extract_operator_field_and_related(parts)
model_cls = self.model_cls
if related_parts:
(
select_related,
table_prefix,
model_cls,
) = self._determine_filter_target_table(
related_parts=related_parts,
select_related=select_related,
field_name=field_name
) )
table = model_cls.Meta.table filter_clauses.append(filter_action)
column = model_cls.Meta.table.columns[field_name]
else: self._register_complex_duplicates(select_related)
op = "exact" filter_clauses = self._switch_filter_action_prefixes(
column = self.table.columns[self.model_cls.get_column_alias(key)] filter_clauses=filter_clauses
table = self.table
clause = self._process_column_clause_for_operator_and_value(
value, op, column, table, table_prefix
) )
filter_clauses.append(clause)
return filter_clauses, select_related return filter_clauses, select_related
def _process_column_clause_for_operator_and_value( def _register_complex_duplicates(self, select_related: List[str]) -> None:
self,
value: Any,
op: str,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
) -> sqlalchemy.sql.expression.TextClause:
""" """
Escapes characters if it's required. Checks if duplicate aliases are presented which can happen in self relation
Substitutes values of the models if value is a ormar Model with its pk value. or when two joins end with the same pair of models.
Compiles the clause.
:param value: value of the filter If there are duplicates, the all duplicated joins are registered as source
:type value: Any model and whole relation key (not just last relation name).
:param op: filter operator
:type op: str
:param column: column on which filter should be applied
:type column: sqlalchemy.sql.schema.Column
:param table: table on which filter should be applied
:type table: sqlalchemy.sql.schema.Table
:param table_prefix: prefix from AliasManager
:type table_prefix: str
:return: complied and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
"""
value, has_escaped_character = self._escape_characters_in_clause(op, value)
if isinstance(value, ormar.Model): :param select_related: list of relation strings
value = value.pk
op_attr = FILTER_OPERATORS[op]
clause = getattr(column, op_attr)(value)
clause = self._compile_clause(
clause,
column,
table,
table_prefix,
modifiers={"escape": "\\" if has_escaped_character else None},
)
return clause
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str], field_name: str
) -> Tuple[List[str], str, Type["Model"]]:
"""
Adds related strings to select_related list otherwise the clause would fail as
the required columns would not be present. That means that select_related
list is filled with missing values present in filters.
Walks the relation to retrieve the actual model on which the clause should be
constructed, extracts alias based on last relation leading to target model.
:param related_parts: list of split parts of related string
:type related_parts: List[str]
:param select_related: list of related models
:type select_related: List[str] :type select_related: List[str]
:return: list of related models, table_prefix, final model class :return: None
:rtype: Tuple[List[str], str, Type[Model]] :rtype: None
""" """
table_prefix = "" prefixes = self._parse_related_prefixes(select_related=select_related)
model_cls = self.model_cls
select_related = [relation for relation in select_related]
# Add any implied select_related manager = self.model_cls.Meta.alias_manager
related_str = "__".join(related_parts) filtered_prefixes = sorted(prefixes, key=lambda x: x.table_prefix)
if related_str not in select_related: grouped = itertools.groupby(filtered_prefixes, key=lambda x: x.table_prefix)
select_related.append(related_str) for _, group in grouped:
sorted_group = sorted(
# Walk the relationships to the actual model class group, key=lambda x: len(x.relation_str), reverse=True
# against which the comparison is being made.
previous_model = model_cls
manager = model_cls.Meta.alias_manager
for relation in related_parts:
related_field = model_cls.Meta.model_fields[relation]
if issubclass(related_field, ManyToManyField):
previous_model = related_field.through
relation = related_field.default_target_field_name() # type: ignore
table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=relation
) )
model_cls = related_field.to for prefix in sorted_group[:-1]:
previous_model = model_cls if prefix.alias_key not in manager:
# handle duplicated aliases in nested relations manager.add_alias(alias_key=prefix.alias_key)
# TODO: check later and remove nocover
complex_prefix = manager.resolve_relation_alias( def _parse_related_prefixes(self, select_related: List[str]) -> List[Prefix]:
from_model=self.model_cls, """
relation_name='__'.join([related_str, field_name]) Walks all relation strings and parses the target models and prefixes.
:param select_related: list of relation strings
:type select_related: List[str]
:return: list of parsed prefixes
:rtype: List[Prefix]
"""
prefixes: List[Prefix] = []
for related in select_related:
prefix = Prefix(
self.model_cls,
*get_relationship_alias_model_and_str(
self.model_cls, related.split("__")
),
) )
if complex_prefix: # pragma: nocover prefixes.append(prefix)
table_prefix = complex_prefix return prefixes
return select_related, table_prefix, model_cls
def _compile_clause( def _switch_filter_action_prefixes(
self, self, filter_clauses: List[FilterAction]
clause: sqlalchemy.sql.expression.BinaryExpression, ) -> List[FilterAction]:
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause:
""" """
Compiles the clause to str using appropriate database dialect, replace columns Substitutes aliases for filter action if the complex key (whole relation str) is
names with aliased names and converts it back to TextClause. present in alias_manager.
:param clause: original not compiled clause :param filter_clauses: raw list of actions
:type clause: sqlalchemy.sql.elements.BinaryExpression :type filter_clauses: List[FilterAction]
:param column: column on which filter should be applied :return: list of actions with aliases changed if needed
:type column: sqlalchemy.sql.schema.Column :rtype: List[FilterAction]
:param table: table on which filter should be applied
:type table: sqlalchemy.sql.schema.Table
:param table_prefix: prefix from AliasManager
:type table_prefix: str
:param modifiers: sqlalchemy modifiers - used only to escape chars here
:type modifiers: Dict[str, NoneType]
:return: compiled and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
""" """
for modifier, modifier_value in modifiers.items(): manager = self.model_cls.Meta.alias_manager
clause.modifiers[modifier] = modifier_value for action in filter_clauses:
new_alias = manager.resolve_relation_alias(
clause_text = str( self.model_cls, action.related_str
clause.compile(
dialect=self.model_cls.Meta.database._backend._dialect,
compile_kwargs={"literal_binds": True},
) )
) if "__" in action.related_str and new_alias:
alias = f"{table_prefix}_" if table_prefix else "" action.table_prefix = new_alias
aliased_name = f"{alias}{table.name}.{column.name}" return filter_clauses
clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name)
clause = text(clause_text)
return clause
@staticmethod
def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]:
"""
Escapes the special characters ["%", "_"] if needed.
Adds `%` for `like` queries.
:raises QueryDefinitionError: if contains or icontains is used with
ormar model instance
:param op: operator used in query
:type op: str
:param value: value of the filter
:type value: Any
:return: escaped value and flag if escaping is needed
:rtype: Tuple[Any, bool]
"""
has_escaped_character = False
if op not in [
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
]:
return value, has_escaped_character
if isinstance(value, ormar.Model):
raise QueryDefinitionError(
"You cannot use contains and icontains with instance of the Model"
)
has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value)
if has_escaped_character:
# enable escape modifier
for char in ESCAPE_CHARACTERS:
value = value.replace(char, f"\\{char}")
prefix = "%" if "start" not in op else ""
sufix = "%" if "end" not in op else ""
value = f"{prefix}{value}{sufix}"
return value, has_escaped_character
@staticmethod
def _extract_operator_field_and_related(
parts: List[str],
) -> Tuple[str, str, Optional[List]]:
"""
Splits filter query key and extracts required parts.
:param parts: split filter query key
:type parts: List[str]
:return: operator, field_name, list of related parts
:rtype: Tuple[str, str, Optional[List]]
"""
if parts[-1] in FILTER_OPERATORS:
op = parts[-1]
field_name = parts[-2]
related_parts = parts[:-2]
else:
op = "exact"
field_name = parts[-1]
related_parts = parts[:-1]
return op, field_name, related_parts

View File

@ -0,0 +1,201 @@
from typing import Any, Dict, List, TYPE_CHECKING, Type
import sqlalchemy
from sqlalchemy import text
import ormar # noqa: I100, I202
from ormar.exceptions import QueryDefinitionError
from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING: # pragma: nocover
from ormar import Model
FILTER_OPERATORS = {
"exact": "__eq__",
"iexact": "ilike",
"contains": "like",
"icontains": "ilike",
"startswith": "like",
"istartswith": "ilike",
"endswith": "like",
"iendswith": "ilike",
"in": "in_",
"gt": "__gt__",
"gte": "__ge__",
"lt": "__lt__",
"lte": "__le__",
}
ESCAPE_CHARACTERS = ["%", "_"]
class FilterAction:
"""
Filter Actions is populated by queryset when filter() is called.
All required params are extracted but kept raw until actual filter clause value
is required -> then the action is converted into text() clause.
Extracted in order to easily change table prefixes on complex relations.
"""
def __init__(self, filter_str: str, value: Any, model_cls: Type["Model"]) -> None:
parts = filter_str.split("__")
if parts[-1] in FILTER_OPERATORS:
self.operator = parts[-1]
self.field_name = parts[-2]
self.related_parts = parts[:-2]
else:
self.operator = "exact"
self.field_name = parts[-1]
self.related_parts = parts[:-1]
self.filter_value = value
self.table_prefix = ""
self.source_model = model_cls
self.target_model = model_cls
self._determine_filter_target_table()
self._escape_characters_in_clause()
@property
def table(self) -> sqlalchemy.Table:
"""Shortcut to sqlalchemy Table of filtered target model"""
return self.target_model.Meta.table
@property
def column(self) -> sqlalchemy.Column:
"""Shortcut to sqlalchemy column of filtered target model"""
aliased_name = self.target_model.get_column_alias(self.field_name)
return self.target_model.Meta.table.columns[aliased_name]
def has_escaped_characters(self) -> bool:
"""Check if value is a string that contains characters to escape"""
return isinstance(self.filter_value, str) and any(
c for c in ESCAPE_CHARACTERS if c in self.filter_value
)
def update_select_related(self, select_related: List[str]) -> List[str]:
"""
Updates list of select related with related part included in the filter key.
That way If you want to just filter by relation you do not have to provide
select_related separately.
:param select_related: list of relation join strings
:type select_related: List[str]
:return: list of relation joins with implied joins from filter added
:rtype: List[str]
"""
select_related = select_related[:]
if self.related_str and not any(
rel.startswith(self.related_str) for rel in select_related
):
select_related.append(self.related_str)
return select_related
def _determine_filter_target_table(self) -> None:
"""
Walks the relation to retrieve the actual model on which the clause should be
constructed, extracts alias based on last relation leading to target model.
"""
(
self.table_prefix,
self.target_model,
self.related_str,
) = get_relationship_alias_model_and_str(self.source_model, self.related_parts)
def _escape_characters_in_clause(self) -> None:
"""
Escapes the special characters ["%", "_"] if needed.
Adds `%` for `like` queries.
:raises QueryDefinitionError: if contains or icontains is used with
ormar model instance
:return: escaped value and flag if escaping is needed
:rtype: Tuple[Any, bool]
"""
self.has_escaped_character = False
if self.operator in [
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
]:
if isinstance(self.filter_value, ormar.Model):
raise QueryDefinitionError(
"You cannot use contains and icontains with instance of the Model"
)
self.has_escaped_character = self.has_escaped_characters()
if self.has_escaped_character:
self._escape_chars()
self._prefix_suffix_quote()
def _escape_chars(self) -> None:
"""Actually replaces chars to escape in value"""
for char in ESCAPE_CHARACTERS:
self.filter_value = self.filter_value.replace(char, f"\\{char}")
def _prefix_suffix_quote(self) -> None:
"""
Adds % to the beginning of the value if operator checks for containment and not
starts with.
Adds % to the end of the value if operator checks for containment and not
end with.
:return:
:rtype:
"""
prefix = "%" if "start" not in self.operator else ""
sufix = "%" if "end" not in self.operator else ""
self.filter_value = f"{prefix}{self.filter_value}{sufix}"
def get_text_clause(self,) -> sqlalchemy.sql.expression.TextClause:
"""
Escapes characters if it's required.
Substitutes values of the models if value is a ormar Model with its pk value.
Compiles the clause.
:return: complied and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
"""
if isinstance(self.filter_value, ormar.Model):
self.filter_value = self.filter_value.pk
op_attr = FILTER_OPERATORS[self.operator]
clause = getattr(self.column, op_attr)(self.filter_value)
clause = self._compile_clause(
clause, modifiers={"escape": "\\" if self.has_escaped_character else None},
)
return clause
def _compile_clause(
self, clause: sqlalchemy.sql.expression.BinaryExpression, modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause:
"""
Compiles the clause to str using appropriate database dialect, replace columns
names with aliased names and converts it back to TextClause.
:param clause: original not compiled clause
:type clause: sqlalchemy.sql.elements.BinaryExpression
:param modifiers: sqlalchemy modifiers - used only to escape chars here
:type modifiers: Dict[str, NoneType]
:return: compiled and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
"""
for modifier, modifier_value in modifiers.items():
clause.modifiers[modifier] = modifier_value
clause_text = str(
clause.compile(
dialect=self.target_model.Meta.database._backend._dialect,
compile_kwargs={"literal_binds": True},
)
)
alias = f"{self.table_prefix}_" if self.table_prefix else ""
aliased_name = f"{alias}{self.table.name}.{self.column.name}"
clause_text = clause_text.replace(
f"{self.table.name}.{self.column.name}", aliased_name
)
clause = text(clause_text)
return clause

View File

@ -1,6 +1,7 @@
from typing import List from typing import List
import sqlalchemy import sqlalchemy
from ormar.queryset.filter_action import FilterAction
class FilterQuery: class FilterQuery:
@ -8,7 +9,9 @@ class FilterQuery:
Modifies the select query with given list of where/filter clauses. Modifies the select query with given list of where/filter clauses.
""" """
def __init__(self, filter_clauses: List, exclude: bool = False) -> None: def __init__(
self, filter_clauses: List[FilterAction], exclude: bool = False
) -> None:
self.exclude = exclude self.exclude = exclude
self.filter_clauses = filter_clauses self.filter_clauses = filter_clauses
@ -23,9 +26,11 @@ class FilterQuery:
""" """
if self.filter_clauses: if self.filter_clauses:
if len(self.filter_clauses) == 1: if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0] clause = self.filter_clauses[0].get_text_clause()
else: else:
clause = sqlalchemy.sql.and_(*self.filter_clauses) clause = sqlalchemy.sql.and_(
*[x.get_text_clause() for x in self.filter_clauses]
)
clause = sqlalchemy.sql.not_(clause) if self.exclude else clause clause = sqlalchemy.sql.not_(clause) if self.exclude else clause
expr = expr.where(clause) expr = expr.where(clause)
return expr return expr

View File

@ -62,7 +62,7 @@ class SqlJoin:
def next_model(self) -> Type["Model"]: def next_model(self) -> Type["Model"]:
if not self._next_model: # pragma: nocover if not self._next_model: # pragma: nocover
raise RelationshipInstanceError( raise RelationshipInstanceError(
"Cannot link to related table if " "relation to model is not set." "Cannot link to related table if relation.to model is not set."
) )
return self._next_model return self._next_model
@ -90,8 +90,7 @@ class SqlJoin:
""" """
return self.main_model.Meta.alias_manager return self.main_model.Meta.alias_manager
def on_clause(self, previous_alias: str, from_clause: str, def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text:
to_clause: str, ) -> text:
""" """
Receives aliases and names of both ends of the join and combines them Receives aliases and names of both ends of the join and combines them
into one text clause used in joins. into one text clause used in joins.
@ -134,19 +133,23 @@ class SqlJoin:
self.sorted_orders, self.sorted_orders,
) )
def _forward_join(self): def _forward_join(self) -> None:
"""
Process actual join.
Registers complex relation join on encountering of the duplicated alias.
"""
self.next_alias = self.alias_manager.resolve_relation_alias( self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name from_model=self.target_field.owner, relation_name=self.relation_name
) )
if self.next_alias not in self.used_aliases: if self.next_alias not in self.used_aliases:
self._process_join() self._process_join()
else: else:
if '__' in self.relation_str: if "__" in self.relation_str and self.source_model:
relation_key = f'{self.source_model.get_name()}_{self.relation_str}' relation_key = f"{self.source_model.get_name()}_{self.relation_str}"
if relation_key not in self.alias_manager: if relation_key not in self.alias_manager:
print(f'registering {relation_key}')
self.next_alias = self.alias_manager.add_alias( self.next_alias = self.alias_manager.add_alias(
alias_key=relation_key) alias_key=relation_key
)
else: else:
self.next_alias = self.alias_manager[relation_key] self.next_alias = self.alias_manager[relation_key]
self._process_join() self._process_join()
@ -194,9 +197,9 @@ class SqlJoin:
main_model=self.next_model, main_model=self.next_model,
relation_name=related_name, relation_name=related_name,
related_models=remainder, related_models=remainder,
relation_str='__'.join([self.relation_str, related_name]), relation_str="__".join([self.relation_str, related_name]),
own_alias=self.next_alias, own_alias=self.next_alias,
source_model=self.source_model or self.main_model source_model=self.source_model or self.main_model,
) )
( (
self.used_aliases, self.used_aliases,

View File

@ -290,7 +290,7 @@ class PrefetchQuery:
model_cls=clause_target, select_related=[], filter_clauses=[], model_cls=clause_target, select_related=[], filter_clauses=[],
) )
kwargs = {f"{filter_column}__in": ids} kwargs = {f"{filter_column}__in": ids}
filter_clauses, _ = qryclause.filter(**kwargs) filter_clauses, _ = qryclause.prepare_filter(**kwargs)
return filter_clauses return filter_clauses
return [] return []

View File

@ -8,6 +8,7 @@ from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar.models.helpers.models import group_related_list from ormar.models.helpers.models import group_related_list
from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery
from ormar.queryset.filter_action import FilterAction
from ormar.queryset.join import SqlJoin from ormar.queryset.join import SqlJoin
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -18,8 +19,8 @@ class Query:
def __init__( # noqa CFQ002 def __init__( # noqa CFQ002
self, self,
model_cls: Type["Model"], model_cls: Type["Model"],
filter_clauses: List, filter_clauses: List[FilterAction],
exclude_clauses: List, exclude_clauses: List[FilterAction],
select_related: List, select_related: List,
limit_count: Optional[int], limit_count: Optional[int],
offset: Optional[int], offset: Optional[int],
@ -200,12 +201,12 @@ class Query:
filters_to_use = [ filters_to_use = [
filter_clause filter_clause
for filter_clause in self.filter_clauses for filter_clause in self.filter_clauses
if filter_clause.text.startswith(f"{self.table.name}.") if filter_clause.table_prefix == ""
] ]
excludes_to_use = [ excludes_to_use = [
filter_clause filter_clause
for filter_clause in self.exclude_clauses for filter_clause in self.exclude_clauses
if filter_clause.text.startswith(f"{self.table.name}.") if filter_clause.table_prefix == ""
] ]
sorts_to_use = {k: v for k, v in self.sorted_orders.items() if "__" not in k} sorts_to_use = {k: v for k, v in self.sorted_orders.items() if "__" not in k}
expr = FilterQuery(filter_clauses=filters_to_use).apply(expr) expr = FilterQuery(filter_clauses=filters_to_use).apply(expr)

View File

@ -236,7 +236,7 @@ class QuerySet:
select_related=self._select_related, select_related=self._select_related,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
) )
filter_clauses, select_related = qryclause.filter(**kwargs) filter_clauses, select_related = qryclause.prepare_filter(**kwargs)
if _exclude: if _exclude:
exclude_clauses = filter_clauses exclude_clauses = filter_clauses
filter_clauses = self.filter_clauses filter_clauses = self.filter_clauses

View File

@ -7,10 +7,13 @@ from typing import (
Sequence, Sequence,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Tuple,
Type, Type,
Union, Union,
) )
from ormar.fields import ManyToManyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -212,3 +215,35 @@ def extract_models_to_dict_of_lists(
for model in models: for model in models:
extract_nested_models(model, model_type, select_dict, extracted) extract_nested_models(model, model_type, select_dict, extracted)
return extracted return extracted
def get_relationship_alias_model_and_str(
source_model: Type["Model"], related_parts: List
) -> Tuple[str, Type["Model"], str]:
"""
Walks the relation to retrieve the actual model on which the clause should be
constructed, extracts alias based on last relation leading to target model.
:param related_parts: list of related names extracted from string
:type related_parts: Union[List, List[str]]
:param source_model: model from which relation starts
:type source_model: Type[Model]
:return: table prefix, target model and relation string
:rtype: Tuple[str, Type["Model"], str]
"""
table_prefix = ""
model_cls = source_model
previous_model = model_cls
manager = model_cls.Meta.alias_manager
for relation in related_parts:
related_field = model_cls.Meta.model_fields[relation]
if issubclass(related_field, ManyToManyField):
previous_model = related_field.through
relation = related_field.default_target_field_name() # type: ignore
table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=relation
)
model_cls = related_field.to
previous_model = model_cls
relation_str = "__".join(related_parts)
return table_prefix, model_cls, relation_str

View File

@ -1,7 +1,7 @@
import string import string
import uuid import uuid
from random import choices from random import choices
from typing import Dict, List, TYPE_CHECKING, Type from typing import Any, Dict, List, TYPE_CHECKING, Type
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
@ -33,10 +33,10 @@ class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._aliases_new: Dict[str, str] = dict() self._aliases_new: Dict[str, str] = dict()
def __contains__(self, item): def __contains__(self, item: str) -> bool:
return self._aliases_new.__contains__(item) return self._aliases_new.__contains__(item)
def __getitem__(self, key): def __getitem__(self, key: str) -> Any:
return self._aliases_new.__getitem__(key) return self._aliases_new.__getitem__(key)
@staticmethod @staticmethod

View File

@ -187,31 +187,44 @@ async def test_m2m_self_forwardref_relation(cleanup):
await billy.friends.add(steve) await billy.friends.add(steve)
billy_check = await Child.objects.select_related( billy_check = await Child.objects.select_related(
["friends", "favourite_game", "least_favourite_game", [
"friends__favourite_game", "friends__least_favourite_game"] "friends",
"favourite_game",
"least_favourite_game",
"friends__favourite_game",
"friends__least_favourite_game",
]
).get(name="Billy") ).get(name="Billy")
assert len(billy_check.friends) == 2 assert len(billy_check.friends) == 2
assert billy_check.friends[0].name == "Kate" assert billy_check.friends[0].name == "Kate"
assert billy_check.friends[0].favourite_game.name == 'Checkers' assert billy_check.friends[0].favourite_game.name == "Checkers"
assert billy_check.friends[0].least_favourite_game.name == 'Uno' assert billy_check.friends[0].least_favourite_game.name == "Uno"
assert billy_check.friends[1].name == "Steve" assert billy_check.friends[1].name == "Steve"
assert billy_check.friends[1].favourite_game.name == 'Jenga' assert billy_check.friends[1].favourite_game.name == "Jenga"
assert billy_check.friends[1].least_favourite_game.name == 'Uno' assert billy_check.friends[1].least_favourite_game.name == "Uno"
assert billy_check.favourite_game.name == "Uno" assert billy_check.favourite_game.name == "Uno"
kate_check = await Child.objects.select_related(["also_friends",]).get( kate_check = await Child.objects.select_related(["also_friends"]).get(
name="Kate" name="Kate"
) )
assert len(kate_check.also_friends) == 1 assert len(kate_check.also_friends) == 1
assert kate_check.also_friends[0].name == "Billy" assert kate_check.also_friends[0].name == "Billy"
# TODO: Fix filters with complex prefixes billy_check = (
# billy_check = await Child.objects.select_related( await Child.objects.select_related(
# ["friends", "favourite_game", "least_favourite_game", [
# "friends__favourite_game", "friends__least_favourite_game"] "friends",
# ).filter(friends__favourite_game__name="Checkers").get(name="Billy") "favourite_game",
# assert len(billy_check.friends) == 1 "least_favourite_game",
# assert billy_check.friends[0].name == "Kate" "friends__favourite_game",
# assert billy_check.friends[0].favourite_game.name == 'Checkers' "friends__least_favourite_game",
# assert billy_check.friends[0].least_favourite_game.name == 'Uno' ]
)
.filter(friends__favourite_game__name="Checkers")
.get(name="Billy")
)
assert len(billy_check.friends) == 1
assert billy_check.friends[0].name == "Kate"
assert billy_check.friends[0].favourite_game.name == "Checkers"
assert billy_check.friends[0].least_favourite_game.name == "Uno"

View File

@ -2,8 +2,16 @@ from ormar.models.helpers.models import group_related_list
def test_group_related_list(): def test_group_related_list():
given = ['friends__least_favourite_game', 'least_favourite_game', 'friends', given = [
'favourite_game', 'friends__favourite_game'] "friends__least_favourite_game",
expected = {'least_favourite_game': [], 'favourite_game': [], "least_favourite_game",
'friends': ['favourite_game', 'least_favourite_game']} "friends",
"favourite_game",
"friends__favourite_game",
]
expected = {
"least_favourite_game": [],
"favourite_game": [],
"friends": ["favourite_game", "least_favourite_game"],
}
assert group_related_list(given) == expected assert group_related_list(given) == expected

View File

@ -101,15 +101,10 @@ async def test_model_multiple_instances_of_same_table_in_schema():
async with database: async with database:
await create_data() await create_data()
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category__department", "students"] ["teachers__category__department", "students__category__department"]
).all() ).all()
assert classes[0].name == "Math" assert classes[0].name == "Math"
assert classes[0].students[0].name == "Jane" assert classes[0].students[0].name == "Jane"
assert len(classes[0].dict().get("students")) == 2 assert len(classes[0].dict().get("students")) == 2
assert classes[0].teachers[0].category.department.name == "Law Department" assert classes[0].teachers[0].category.department.name == "Law Department"
assert classes[0].students[0].category.pk is not None
assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load()
await classes[0].students[0].category.department.load()
assert classes[0].students[0].category.department.name == "Math Department" assert classes[0].students[0].category.department.name == "Math Department"