extract filters into filter actions and delay their processing time to allow for registration of complex relations, refactoring and optimization, now one join with relations with same aliases are possible

This commit is contained in:
collerek
2021-01-21 15:55:23 +01:00
parent d6e2c85b79
commit a2834666fc
13 changed files with 425 additions and 325 deletions

View File

@ -22,12 +22,12 @@ def is_field_an_forward_ref(field: Type["BaseField"]) -> bool:
:rtype: bool
"""
return issubclass(field, ForeignKeyField) and (
field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef
field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef
)
def populate_default_options_values(
new_model: Type["Model"], model_fields: Dict
new_model: Type["Model"], model_fields: Dict
) -> None:
"""
Sets all optional Meta values to it's defaults
@ -52,8 +52,7 @@ def populate_default_options_values(
new_model.Meta.abstract = False
if any(
is_field_an_forward_ref(field) for field in
new_model.Meta.model_fields.values()
is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values()
):
new_model.Meta.requires_ref_update = True
else:
@ -78,7 +77,7 @@ def extract_annotations_and_default_vals(attrs: Dict) -> Tuple[Dict, Dict]:
# cannot be in relations helpers due to cyclical import
def validate_related_names_in_relations( # noqa CCR001
model_fields: Dict, new_model: Type["Model"]
model_fields: Dict, new_model: Type["Model"]
) -> None:
"""
Performs a validation of relation_names in relation fields.
@ -135,12 +134,11 @@ def group_related_list(list_: List) -> Dict:
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped:
group_list = list(group)
new = sorted([
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
])
new = sorted(
["__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1]
)
if any("__" in x for x in new):
result_dict[key] = group_related_list(new)
else:
result_dict.setdefault(key, []).extend(new)
return {k: v for k, v in
sorted(result_dict.items(), key=lambda item: len(item[1]))}
return {k: v for k, v in sorted(result_dict.items(), key=lambda item: len(item[1]))}

View File

@ -1,40 +1,35 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
import sqlalchemy
from sqlalchemy import text
import itertools
from dataclasses import dataclass
from typing import Any, List, TYPE_CHECKING, Tuple, Type
import ormar # noqa I100
from ormar.exceptions import QueryDefinitionError
from ormar.fields.many_to_many import ManyToManyField
from ormar.queryset.filter_action import FilterAction
from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING: # pragma no cover
from ormar import Model
FILTER_OPERATORS = {
"exact": "__eq__",
"iexact": "ilike",
"contains": "like",
"icontains": "ilike",
"startswith": "like",
"istartswith": "ilike",
"endswith": "like",
"iendswith": "ilike",
"in": "in_",
"gt": "__gt__",
"gte": "__ge__",
"lt": "__lt__",
"lte": "__le__",
}
ESCAPE_CHARACTERS = ["%", "_"]
@dataclass
class Prefix:
source_model: Type["Model"]
table_prefix: str
model_cls: Type["Model"]
relation_str: str
@property
def alias_key(self) -> str:
source_model_name = self.source_model.get_name()
return f"{source_model_name}_" f"{self.relation_str}"
class QueryClause:
"""
Constructs where clauses from strings passed as arguments
Constructs FilterActions from strings passed as arguments
"""
def __init__(
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
) -> None:
self._select_related = select_related[:]
@ -43,9 +38,9 @@ class QueryClause:
self.model_cls = model_cls
self.table = self.model_cls.Meta.table
def filter( # noqa: A003
self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
def prepare_filter( # noqa: A003
self, **kwargs: Any
) -> Tuple[List[FilterAction], List[str]]:
"""
Main external access point that processes the clauses into sqlalchemy text
clauses and updates select_related list with implicit related tables
@ -65,8 +60,8 @@ class QueryClause:
return filter_clauses, select_related
def _populate_filter_clauses(
self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
self, **kwargs: Any
) -> Tuple[List[FilterAction], List[str]]:
"""
Iterates all clauses and extracts used operator and field from related
models if needed. Based on the chain of related names the target table
@ -81,238 +76,84 @@ class QueryClause:
select_related = list(self._select_related)
for key, value in kwargs.items():
table_prefix = ""
if "__" in key:
parts = key.split("__")
(
op,
field_name,
related_parts,
) = self._extract_operator_field_and_related(parts)
model_cls = self.model_cls
if related_parts:
(
select_related,
table_prefix,
model_cls,
) = self._determine_filter_target_table(
related_parts=related_parts,
select_related=select_related,
field_name=field_name
)
table = model_cls.Meta.table
column = model_cls.Meta.table.columns[field_name]
else:
op = "exact"
column = self.table.columns[self.model_cls.get_column_alias(key)]
table = self.table
clause = self._process_column_clause_for_operator_and_value(
value, op, column, table, table_prefix
filter_action = FilterAction(
filter_str=key, value=value, model_cls=self.model_cls
)
filter_clauses.append(clause)
select_related = filter_action.update_select_related(
select_related=select_related
)
filter_clauses.append(filter_action)
self._register_complex_duplicates(select_related)
filter_clauses = self._switch_filter_action_prefixes(
filter_clauses=filter_clauses
)
return filter_clauses, select_related
def _process_column_clause_for_operator_and_value(
self,
value: Any,
op: str,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
) -> sqlalchemy.sql.expression.TextClause:
def _register_complex_duplicates(self, select_related: List[str]) -> None:
"""
Escapes characters if it's required.
Substitutes values of the models if value is a ormar Model with its pk value.
Compiles the clause.
Checks if duplicate aliases are presented which can happen in self relation
or when two joins end with the same pair of models.
:param value: value of the filter
:type value: Any
:param op: filter operator
:type op: str
:param column: column on which filter should be applied
:type column: sqlalchemy.sql.schema.Column
:param table: table on which filter should be applied
:type table: sqlalchemy.sql.schema.Table
:param table_prefix: prefix from AliasManager
:type table_prefix: str
:return: complied and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
"""
value, has_escaped_character = self._escape_characters_in_clause(op, value)
If there are duplicates, the all duplicated joins are registered as source
model and whole relation key (not just last relation name).
if isinstance(value, ormar.Model):
value = value.pk
op_attr = FILTER_OPERATORS[op]
clause = getattr(column, op_attr)(value)
clause = self._compile_clause(
clause,
column,
table,
table_prefix,
modifiers={"escape": "\\" if has_escaped_character else None},
)
return clause
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str], field_name: str
) -> Tuple[List[str], str, Type["Model"]]:
"""
Adds related strings to select_related list otherwise the clause would fail as
the required columns would not be present. That means that select_related
list is filled with missing values present in filters.
Walks the relation to retrieve the actual model on which the clause should be
constructed, extracts alias based on last relation leading to target model.
:param related_parts: list of split parts of related string
:type related_parts: List[str]
:param select_related: list of related models
:param select_related: list of relation strings
:type select_related: List[str]
:return: list of related models, table_prefix, final model class
:rtype: Tuple[List[str], str, Type[Model]]
:return: None
:rtype: None
"""
table_prefix = ""
model_cls = self.model_cls
select_related = [relation for relation in select_related]
prefixes = self._parse_related_prefixes(select_related=select_related)
# Add any implied select_related
related_str = "__".join(related_parts)
if related_str not in select_related:
select_related.append(related_str)
# Walk the relationships to the actual model class
# against which the comparison is being made.
previous_model = model_cls
manager = model_cls.Meta.alias_manager
for relation in related_parts:
related_field = model_cls.Meta.model_fields[relation]
if issubclass(related_field, ManyToManyField):
previous_model = related_field.through
relation = related_field.default_target_field_name() # type: ignore
table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=relation
manager = self.model_cls.Meta.alias_manager
filtered_prefixes = sorted(prefixes, key=lambda x: x.table_prefix)
grouped = itertools.groupby(filtered_prefixes, key=lambda x: x.table_prefix)
for _, group in grouped:
sorted_group = sorted(
group, key=lambda x: len(x.relation_str), reverse=True
)
model_cls = related_field.to
previous_model = model_cls
# handle duplicated aliases in nested relations
# TODO: check later and remove nocover
complex_prefix = manager.resolve_relation_alias(
from_model=self.model_cls,
relation_name='__'.join([related_str, field_name])
)
if complex_prefix: # pragma: nocover
table_prefix = complex_prefix
return select_related, table_prefix, model_cls
for prefix in sorted_group[:-1]:
if prefix.alias_key not in manager:
manager.add_alias(alias_key=prefix.alias_key)
def _compile_clause(
self,
clause: sqlalchemy.sql.expression.BinaryExpression,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause:
def _parse_related_prefixes(self, select_related: List[str]) -> List[Prefix]:
"""
Compiles the clause to str using appropriate database dialect, replace columns
names with aliased names and converts it back to TextClause.
Walks all relation strings and parses the target models and prefixes.
:param clause: original not compiled clause
:type clause: sqlalchemy.sql.elements.BinaryExpression
:param column: column on which filter should be applied
:type column: sqlalchemy.sql.schema.Column
:param table: table on which filter should be applied
:type table: sqlalchemy.sql.schema.Table
:param table_prefix: prefix from AliasManager
:type table_prefix: str
:param modifiers: sqlalchemy modifiers - used only to escape chars here
:type modifiers: Dict[str, NoneType]
:return: compiled and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
:param select_related: list of relation strings
:type select_related: List[str]
:return: list of parsed prefixes
:rtype: List[Prefix]
"""
for modifier, modifier_value in modifiers.items():
clause.modifiers[modifier] = modifier_value
clause_text = str(
clause.compile(
dialect=self.model_cls.Meta.database._backend._dialect,
compile_kwargs={"literal_binds": True},
prefixes: List[Prefix] = []
for related in select_related:
prefix = Prefix(
self.model_cls,
*get_relationship_alias_model_and_str(
self.model_cls, related.split("__")
),
)
)
alias = f"{table_prefix}_" if table_prefix else ""
aliased_name = f"{alias}{table.name}.{column.name}"
clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name)
clause = text(clause_text)
return clause
prefixes.append(prefix)
return prefixes
@staticmethod
def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]:
def _switch_filter_action_prefixes(
self, filter_clauses: List[FilterAction]
) -> List[FilterAction]:
"""
Escapes the special characters ["%", "_"] if needed.
Adds `%` for `like` queries.
Substitutes aliases for filter action if the complex key (whole relation str) is
present in alias_manager.
:raises QueryDefinitionError: if contains or icontains is used with
ormar model instance
:param op: operator used in query
:type op: str
:param value: value of the filter
:type value: Any
:return: escaped value and flag if escaping is needed
:rtype: Tuple[Any, bool]
:param filter_clauses: raw list of actions
:type filter_clauses: List[FilterAction]
:return: list of actions with aliases changed if needed
:rtype: List[FilterAction]
"""
has_escaped_character = False
if op not in [
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
]:
return value, has_escaped_character
if isinstance(value, ormar.Model):
raise QueryDefinitionError(
"You cannot use contains and icontains with instance of the Model"
manager = self.model_cls.Meta.alias_manager
for action in filter_clauses:
new_alias = manager.resolve_relation_alias(
self.model_cls, action.related_str
)
has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value)
if has_escaped_character:
# enable escape modifier
for char in ESCAPE_CHARACTERS:
value = value.replace(char, f"\\{char}")
prefix = "%" if "start" not in op else ""
sufix = "%" if "end" not in op else ""
value = f"{prefix}{value}{sufix}"
return value, has_escaped_character
@staticmethod
def _extract_operator_field_and_related(
parts: List[str],
) -> Tuple[str, str, Optional[List]]:
"""
Splits filter query key and extracts required parts.
:param parts: split filter query key
:type parts: List[str]
:return: operator, field_name, list of related parts
:rtype: Tuple[str, str, Optional[List]]
"""
if parts[-1] in FILTER_OPERATORS:
op = parts[-1]
field_name = parts[-2]
related_parts = parts[:-2]
else:
op = "exact"
field_name = parts[-1]
related_parts = parts[:-1]
return op, field_name, related_parts
if "__" in action.related_str and new_alias:
action.table_prefix = new_alias
return filter_clauses

View File

@ -0,0 +1,201 @@
from typing import Any, Dict, List, TYPE_CHECKING, Type
import sqlalchemy
from sqlalchemy import text
import ormar # noqa: I100, I202
from ormar.exceptions import QueryDefinitionError
from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING: # pragma: nocover
from ormar import Model
FILTER_OPERATORS = {
"exact": "__eq__",
"iexact": "ilike",
"contains": "like",
"icontains": "ilike",
"startswith": "like",
"istartswith": "ilike",
"endswith": "like",
"iendswith": "ilike",
"in": "in_",
"gt": "__gt__",
"gte": "__ge__",
"lt": "__lt__",
"lte": "__le__",
}
ESCAPE_CHARACTERS = ["%", "_"]
class FilterAction:
"""
Filter Actions is populated by queryset when filter() is called.
All required params are extracted but kept raw until actual filter clause value
is required -> then the action is converted into text() clause.
Extracted in order to easily change table prefixes on complex relations.
"""
def __init__(self, filter_str: str, value: Any, model_cls: Type["Model"]) -> None:
parts = filter_str.split("__")
if parts[-1] in FILTER_OPERATORS:
self.operator = parts[-1]
self.field_name = parts[-2]
self.related_parts = parts[:-2]
else:
self.operator = "exact"
self.field_name = parts[-1]
self.related_parts = parts[:-1]
self.filter_value = value
self.table_prefix = ""
self.source_model = model_cls
self.target_model = model_cls
self._determine_filter_target_table()
self._escape_characters_in_clause()
@property
def table(self) -> sqlalchemy.Table:
"""Shortcut to sqlalchemy Table of filtered target model"""
return self.target_model.Meta.table
@property
def column(self) -> sqlalchemy.Column:
"""Shortcut to sqlalchemy column of filtered target model"""
aliased_name = self.target_model.get_column_alias(self.field_name)
return self.target_model.Meta.table.columns[aliased_name]
def has_escaped_characters(self) -> bool:
"""Check if value is a string that contains characters to escape"""
return isinstance(self.filter_value, str) and any(
c for c in ESCAPE_CHARACTERS if c in self.filter_value
)
def update_select_related(self, select_related: List[str]) -> List[str]:
"""
Updates list of select related with related part included in the filter key.
That way If you want to just filter by relation you do not have to provide
select_related separately.
:param select_related: list of relation join strings
:type select_related: List[str]
:return: list of relation joins with implied joins from filter added
:rtype: List[str]
"""
select_related = select_related[:]
if self.related_str and not any(
rel.startswith(self.related_str) for rel in select_related
):
select_related.append(self.related_str)
return select_related
def _determine_filter_target_table(self) -> None:
"""
Walks the relation to retrieve the actual model on which the clause should be
constructed, extracts alias based on last relation leading to target model.
"""
(
self.table_prefix,
self.target_model,
self.related_str,
) = get_relationship_alias_model_and_str(self.source_model, self.related_parts)
def _escape_characters_in_clause(self) -> None:
"""
Escapes the special characters ["%", "_"] if needed.
Adds `%` for `like` queries.
:raises QueryDefinitionError: if contains or icontains is used with
ormar model instance
:return: escaped value and flag if escaping is needed
:rtype: Tuple[Any, bool]
"""
self.has_escaped_character = False
if self.operator in [
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
]:
if isinstance(self.filter_value, ormar.Model):
raise QueryDefinitionError(
"You cannot use contains and icontains with instance of the Model"
)
self.has_escaped_character = self.has_escaped_characters()
if self.has_escaped_character:
self._escape_chars()
self._prefix_suffix_quote()
def _escape_chars(self) -> None:
"""Actually replaces chars to escape in value"""
for char in ESCAPE_CHARACTERS:
self.filter_value = self.filter_value.replace(char, f"\\{char}")
def _prefix_suffix_quote(self) -> None:
"""
Adds % to the beginning of the value if operator checks for containment and not
starts with.
Adds % to the end of the value if operator checks for containment and not
end with.
:return:
:rtype:
"""
prefix = "%" if "start" not in self.operator else ""
sufix = "%" if "end" not in self.operator else ""
self.filter_value = f"{prefix}{self.filter_value}{sufix}"
def get_text_clause(self,) -> sqlalchemy.sql.expression.TextClause:
"""
Escapes characters if it's required.
Substitutes values of the models if value is a ormar Model with its pk value.
Compiles the clause.
:return: complied and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
"""
if isinstance(self.filter_value, ormar.Model):
self.filter_value = self.filter_value.pk
op_attr = FILTER_OPERATORS[self.operator]
clause = getattr(self.column, op_attr)(self.filter_value)
clause = self._compile_clause(
clause, modifiers={"escape": "\\" if self.has_escaped_character else None},
)
return clause
def _compile_clause(
self, clause: sqlalchemy.sql.expression.BinaryExpression, modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause:
"""
Compiles the clause to str using appropriate database dialect, replace columns
names with aliased names and converts it back to TextClause.
:param clause: original not compiled clause
:type clause: sqlalchemy.sql.elements.BinaryExpression
:param modifiers: sqlalchemy modifiers - used only to escape chars here
:type modifiers: Dict[str, NoneType]
:return: compiled and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
"""
for modifier, modifier_value in modifiers.items():
clause.modifiers[modifier] = modifier_value
clause_text = str(
clause.compile(
dialect=self.target_model.Meta.database._backend._dialect,
compile_kwargs={"literal_binds": True},
)
)
alias = f"{self.table_prefix}_" if self.table_prefix else ""
aliased_name = f"{alias}{self.table.name}.{self.column.name}"
clause_text = clause_text.replace(
f"{self.table.name}.{self.column.name}", aliased_name
)
clause = text(clause_text)
return clause

View File

@ -1,6 +1,7 @@
from typing import List
import sqlalchemy
from ormar.queryset.filter_action import FilterAction
class FilterQuery:
@ -8,7 +9,9 @@ class FilterQuery:
Modifies the select query with given list of where/filter clauses.
"""
def __init__(self, filter_clauses: List, exclude: bool = False) -> None:
def __init__(
self, filter_clauses: List[FilterAction], exclude: bool = False
) -> None:
self.exclude = exclude
self.filter_clauses = filter_clauses
@ -23,9 +26,11 @@ class FilterQuery:
"""
if self.filter_clauses:
if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0]
clause = self.filter_clauses[0].get_text_clause()
else:
clause = sqlalchemy.sql.and_(*self.filter_clauses)
clause = sqlalchemy.sql.and_(
*[x.get_text_clause() for x in self.filter_clauses]
)
clause = sqlalchemy.sql.not_(clause) if self.exclude else clause
expr = expr.where(clause)
return expr

View File

@ -24,20 +24,20 @@ if TYPE_CHECKING: # pragma no cover
class SqlJoin:
def __init__( # noqa: CFQ002
self,
used_aliases: List,
select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List],
sorted_orders: OrderedDict,
main_model: Type["Model"],
relation_name: str,
relation_str: str,
related_models: Any = None,
own_alias: str = "",
source_model: Type["Model"] = None,
self,
used_aliases: List,
select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List],
sorted_orders: OrderedDict,
main_model: Type["Model"],
relation_name: str,
relation_str: str,
related_models: Any = None,
own_alias: str = "",
source_model: Type["Model"] = None,
) -> None:
self.relation_name = relation_name
self.related_models = related_models or []
@ -62,7 +62,7 @@ class SqlJoin:
def next_model(self) -> Type["Model"]:
if not self._next_model: # pragma: nocover
raise RelationshipInstanceError(
"Cannot link to related table if " "relation to model is not set."
"Cannot link to related table if relation.to model is not set."
)
return self._next_model
@ -90,8 +90,7 @@ class SqlJoin:
"""
return self.main_model.Meta.alias_manager
def on_clause(self, previous_alias: str, from_clause: str,
to_clause: str, ) -> text:
def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text:
"""
Receives aliases and names of both ends of the join and combines them
into one text clause used in joins.
@ -134,19 +133,23 @@ class SqlJoin:
self.sorted_orders,
)
def _forward_join(self):
def _forward_join(self) -> None:
"""
Process actual join.
Registers complex relation join on encountering of the duplicated alias.
"""
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
else:
if '__' in self.relation_str:
relation_key = f'{self.source_model.get_name()}_{self.relation_str}'
if "__" in self.relation_str and self.source_model:
relation_key = f"{self.source_model.get_name()}_{self.relation_str}"
if relation_key not in self.alias_manager:
print(f'registering {relation_key}')
self.next_alias = self.alias_manager.add_alias(
alias_key=relation_key)
alias_key=relation_key
)
else:
self.next_alias = self.alias_manager[relation_key]
self._process_join()
@ -158,8 +161,8 @@ class SqlJoin:
for related_name in self.related_models:
remainder = None
if (
isinstance(self.related_models, dict)
and self.related_models[related_name]
isinstance(self.related_models, dict)
and self.related_models[related_name]
):
remainder = self.related_models[related_name]
self._process_deeper_join(related_name=related_name, remainder=remainder)
@ -194,9 +197,9 @@ class SqlJoin:
main_model=self.next_model,
relation_name=related_name,
related_models=remainder,
relation_str='__'.join([self.relation_str, related_name]),
relation_str="__".join([self.relation_str, related_name]),
own_alias=self.next_alias,
source_model=self.source_model or self.main_model
source_model=self.source_model or self.main_model,
)
(
self.used_aliases,
@ -244,18 +247,18 @@ class SqlJoin:
"""
target_field = self.target_field
is_primary_self_ref = (
target_field.self_reference
and self.relation_name == target_field.self_reference_primary
target_field.self_reference
and self.relation_name == target_field.self_reference_primary
)
if (is_primary_self_ref and not reverse) or (
not is_primary_self_ref and reverse
not is_primary_self_ref and reverse
):
new_part = target_field.default_source_field_name() # type: ignore
else:
new_part = target_field.default_target_field_name() # type: ignore
return new_part
def _process_join(self, ) -> None: # noqa: CFQ002
def _process_join(self,) -> None: # noqa: CFQ002
"""
Resolves to and from column names and table names.
@ -335,10 +338,10 @@ class SqlJoin:
:rtype: bool
"""
return len(condition) >= 2 and (
condition[-2] == part or condition[-2][1:] == part
condition[-2] == part or condition[-2][1:] == part
)
def set_aliased_order_by(self, condition: List[str], to_table: str, ) -> None:
def set_aliased_order_by(self, condition: List[str], to_table: str,) -> None:
"""
Substitute hyphens ('-') with descending order.
Construct actual sqlalchemy text clause using aliased table and column name.
@ -353,7 +356,7 @@ class SqlJoin:
order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}")
self.sorted_orders["__".join(condition)] = order
def get_order_bys(self, to_table: str, pkname_alias: str, ) -> None: # noqa: CCR001
def get_order_bys(self, to_table: str, pkname_alias: str,) -> None: # noqa: CCR001
"""
Triggers construction of order bys if they are given.
Otherwise by default each table is sorted by a primary key column asc.

View File

@ -290,7 +290,7 @@ class PrefetchQuery:
model_cls=clause_target, select_related=[], filter_clauses=[],
)
kwargs = {f"{filter_column}__in": ids}
filter_clauses, _ = qryclause.filter(**kwargs)
filter_clauses, _ = qryclause.prepare_filter(**kwargs)
return filter_clauses
return []

View File

@ -8,6 +8,7 @@ from sqlalchemy import text
import ormar # noqa I100
from ormar.models.helpers.models import group_related_list
from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery
from ormar.queryset.filter_action import FilterAction
from ormar.queryset.join import SqlJoin
if TYPE_CHECKING: # pragma no cover
@ -18,8 +19,8 @@ class Query:
def __init__( # noqa CFQ002
self,
model_cls: Type["Model"],
filter_clauses: List,
exclude_clauses: List,
filter_clauses: List[FilterAction],
exclude_clauses: List[FilterAction],
select_related: List,
limit_count: Optional[int],
offset: Optional[int],
@ -200,12 +201,12 @@ class Query:
filters_to_use = [
filter_clause
for filter_clause in self.filter_clauses
if filter_clause.text.startswith(f"{self.table.name}.")
if filter_clause.table_prefix == ""
]
excludes_to_use = [
filter_clause
for filter_clause in self.exclude_clauses
if filter_clause.text.startswith(f"{self.table.name}.")
if filter_clause.table_prefix == ""
]
sorts_to_use = {k: v for k, v in self.sorted_orders.items() if "__" not in k}
expr = FilterQuery(filter_clauses=filters_to_use).apply(expr)

View File

@ -236,7 +236,7 @@ class QuerySet:
select_related=self._select_related,
filter_clauses=self.filter_clauses,
)
filter_clauses, select_related = qryclause.filter(**kwargs)
filter_clauses, select_related = qryclause.prepare_filter(**kwargs)
if _exclude:
exclude_clauses = filter_clauses
filter_clauses = self.filter_clauses

View File

@ -7,10 +7,13 @@ from typing import (
Sequence,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
from ormar.fields import ManyToManyField
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -212,3 +215,35 @@ def extract_models_to_dict_of_lists(
for model in models:
extract_nested_models(model, model_type, select_dict, extracted)
return extracted
def get_relationship_alias_model_and_str(
source_model: Type["Model"], related_parts: List
) -> Tuple[str, Type["Model"], str]:
"""
Walks the relation to retrieve the actual model on which the clause should be
constructed, extracts alias based on last relation leading to target model.
:param related_parts: list of related names extracted from string
:type related_parts: Union[List, List[str]]
:param source_model: model from which relation starts
:type source_model: Type[Model]
:return: table prefix, target model and relation string
:rtype: Tuple[str, Type["Model"], str]
"""
table_prefix = ""
model_cls = source_model
previous_model = model_cls
manager = model_cls.Meta.alias_manager
for relation in related_parts:
related_field = model_cls.Meta.model_fields[relation]
if issubclass(related_field, ManyToManyField):
previous_model = related_field.through
relation = related_field.default_target_field_name() # type: ignore
table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=relation
)
model_cls = related_field.to
previous_model = model_cls
relation_str = "__".join(related_parts)
return table_prefix, model_cls, relation_str

View File

@ -1,7 +1,7 @@
import string
import uuid
from random import choices
from typing import Dict, List, TYPE_CHECKING, Type
from typing import Any, Dict, List, TYPE_CHECKING, Type
import sqlalchemy
from sqlalchemy import text
@ -33,10 +33,10 @@ class AliasManager:
def __init__(self) -> None:
self._aliases_new: Dict[str, str] = dict()
def __contains__(self, item):
def __contains__(self, item: str) -> bool:
return self._aliases_new.__contains__(item)
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
return self._aliases_new.__getitem__(key)
@staticmethod

View File

@ -187,31 +187,44 @@ async def test_m2m_self_forwardref_relation(cleanup):
await billy.friends.add(steve)
billy_check = await Child.objects.select_related(
["friends", "favourite_game", "least_favourite_game",
"friends__favourite_game", "friends__least_favourite_game"]
[
"friends",
"favourite_game",
"least_favourite_game",
"friends__favourite_game",
"friends__least_favourite_game",
]
).get(name="Billy")
assert len(billy_check.friends) == 2
assert billy_check.friends[0].name == "Kate"
assert billy_check.friends[0].favourite_game.name == 'Checkers'
assert billy_check.friends[0].least_favourite_game.name == 'Uno'
assert billy_check.friends[0].favourite_game.name == "Checkers"
assert billy_check.friends[0].least_favourite_game.name == "Uno"
assert billy_check.friends[1].name == "Steve"
assert billy_check.friends[1].favourite_game.name == 'Jenga'
assert billy_check.friends[1].least_favourite_game.name == 'Uno'
assert billy_check.friends[1].favourite_game.name == "Jenga"
assert billy_check.friends[1].least_favourite_game.name == "Uno"
assert billy_check.favourite_game.name == "Uno"
kate_check = await Child.objects.select_related(["also_friends",]).get(
kate_check = await Child.objects.select_related(["also_friends"]).get(
name="Kate"
)
assert len(kate_check.also_friends) == 1
assert kate_check.also_friends[0].name == "Billy"
# TODO: Fix filters with complex prefixes
# billy_check = await Child.objects.select_related(
# ["friends", "favourite_game", "least_favourite_game",
# "friends__favourite_game", "friends__least_favourite_game"]
# ).filter(friends__favourite_game__name="Checkers").get(name="Billy")
# assert len(billy_check.friends) == 1
# assert billy_check.friends[0].name == "Kate"
# assert billy_check.friends[0].favourite_game.name == 'Checkers'
# assert billy_check.friends[0].least_favourite_game.name == 'Uno'
billy_check = (
await Child.objects.select_related(
[
"friends",
"favourite_game",
"least_favourite_game",
"friends__favourite_game",
"friends__least_favourite_game",
]
)
.filter(friends__favourite_game__name="Checkers")
.get(name="Billy")
)
assert len(billy_check.friends) == 1
assert billy_check.friends[0].name == "Kate"
assert billy_check.friends[0].favourite_game.name == "Checkers"
assert billy_check.friends[0].least_favourite_game.name == "Uno"

View File

@ -2,8 +2,16 @@ from ormar.models.helpers.models import group_related_list
def test_group_related_list():
given = ['friends__least_favourite_game', 'least_favourite_game', 'friends',
'favourite_game', 'friends__favourite_game']
expected = {'least_favourite_game': [], 'favourite_game': [],
'friends': ['favourite_game', 'least_favourite_game']}
given = [
"friends__least_favourite_game",
"least_favourite_game",
"friends",
"favourite_game",
"friends__favourite_game",
]
expected = {
"least_favourite_game": [],
"favourite_game": [],
"friends": ["favourite_game", "least_favourite_game"],
}
assert group_related_list(given) == expected

View File

@ -101,15 +101,10 @@ async def test_model_multiple_instances_of_same_table_in_schema():
async with database:
await create_data()
classes = await SchoolClass.objects.select_related(
["teachers__category__department", "students"]
["teachers__category__department", "students__category__department"]
).all()
assert classes[0].name == "Math"
assert classes[0].students[0].name == "Jane"
assert len(classes[0].dict().get("students")) == 2
assert classes[0].teachers[0].category.department.name == "Law Department"
assert classes[0].students[0].category.pk is not None
assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load()
await classes[0].students[0].category.department.load()
assert classes[0].students[0].category.department.name == "Math Department"