diff --git a/ormar/models/helpers/models.py b/ormar/models/helpers/models.py index 75282d6..727bdd7 100644 --- a/ormar/models/helpers/models.py +++ b/ormar/models/helpers/models.py @@ -22,12 +22,12 @@ def is_field_an_forward_ref(field: Type["BaseField"]) -> bool: :rtype: bool """ return issubclass(field, ForeignKeyField) and ( - field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef + field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef ) def populate_default_options_values( - new_model: Type["Model"], model_fields: Dict + new_model: Type["Model"], model_fields: Dict ) -> None: """ Sets all optional Meta values to it's defaults @@ -52,8 +52,7 @@ def populate_default_options_values( new_model.Meta.abstract = False if any( - is_field_an_forward_ref(field) for field in - new_model.Meta.model_fields.values() + is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values() ): new_model.Meta.requires_ref_update = True else: @@ -78,7 +77,7 @@ def extract_annotations_and_default_vals(attrs: Dict) -> Tuple[Dict, Dict]: # cannot be in relations helpers due to cyclical import def validate_related_names_in_relations( # noqa CCR001 - model_fields: Dict, new_model: Type["Model"] + model_fields: Dict, new_model: Type["Model"] ) -> None: """ Performs a validation of relation_names in relation fields. @@ -135,12 +134,11 @@ def group_related_list(list_: List) -> Dict: grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0]) for key, group in grouped: group_list = list(group) - new = sorted([ - "__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1 - ]) + new = sorted( + ["__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1] + ) if any("__" in x for x in new): result_dict[key] = group_related_list(new) else: result_dict.setdefault(key, []).extend(new) - return {k: v for k, v in - sorted(result_dict.items(), key=lambda item: len(item[1]))} + return {k: v for k, v in sorted(result_dict.items(), key=lambda item: len(item[1]))} diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 3985b5e..e52ae4a 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -1,40 +1,35 @@ -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type - -import sqlalchemy -from sqlalchemy import text +import itertools +from dataclasses import dataclass +from typing import Any, List, TYPE_CHECKING, Tuple, Type import ormar # noqa I100 -from ormar.exceptions import QueryDefinitionError -from ormar.fields.many_to_many import ManyToManyField +from ormar.queryset.filter_action import FilterAction +from ormar.queryset.utils import get_relationship_alias_model_and_str if TYPE_CHECKING: # pragma no cover from ormar import Model -FILTER_OPERATORS = { - "exact": "__eq__", - "iexact": "ilike", - "contains": "like", - "icontains": "ilike", - "startswith": "like", - "istartswith": "ilike", - "endswith": "like", - "iendswith": "ilike", - "in": "in_", - "gt": "__gt__", - "gte": "__ge__", - "lt": "__lt__", - "lte": "__le__", -} -ESCAPE_CHARACTERS = ["%", "_"] + +@dataclass +class Prefix: + source_model: Type["Model"] + table_prefix: str + model_cls: Type["Model"] + relation_str: str + + @property + def alias_key(self) -> str: + source_model_name = self.source_model.get_name() + return f"{source_model_name}_" f"{self.relation_str}" class QueryClause: """ - Constructs where clauses from strings passed as arguments + Constructs FilterActions from strings passed as arguments """ def __init__( - self, model_cls: Type["Model"], filter_clauses: List, select_related: List, + self, model_cls: Type["Model"], filter_clauses: List, select_related: List, ) -> None: self._select_related = select_related[:] @@ -43,9 +38,9 @@ class QueryClause: self.model_cls = model_cls self.table = self.model_cls.Meta.table - def filter( # noqa: A003 - self, **kwargs: Any - ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: + def prepare_filter( # noqa: A003 + self, **kwargs: Any + ) -> Tuple[List[FilterAction], List[str]]: """ Main external access point that processes the clauses into sqlalchemy text clauses and updates select_related list with implicit related tables @@ -65,8 +60,8 @@ class QueryClause: return filter_clauses, select_related def _populate_filter_clauses( - self, **kwargs: Any - ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: + self, **kwargs: Any + ) -> Tuple[List[FilterAction], List[str]]: """ Iterates all clauses and extracts used operator and field from related models if needed. Based on the chain of related names the target table @@ -81,238 +76,84 @@ class QueryClause: select_related = list(self._select_related) for key, value in kwargs.items(): - table_prefix = "" - if "__" in key: - parts = key.split("__") - - ( - op, - field_name, - related_parts, - ) = self._extract_operator_field_and_related(parts) - - model_cls = self.model_cls - if related_parts: - ( - select_related, - table_prefix, - model_cls, - ) = self._determine_filter_target_table( - related_parts=related_parts, - select_related=select_related, - field_name=field_name - ) - - table = model_cls.Meta.table - column = model_cls.Meta.table.columns[field_name] - - else: - op = "exact" - column = self.table.columns[self.model_cls.get_column_alias(key)] - table = self.table - - clause = self._process_column_clause_for_operator_and_value( - value, op, column, table, table_prefix + filter_action = FilterAction( + filter_str=key, value=value, model_cls=self.model_cls ) - filter_clauses.append(clause) + select_related = filter_action.update_select_related( + select_related=select_related + ) + + filter_clauses.append(filter_action) + + self._register_complex_duplicates(select_related) + filter_clauses = self._switch_filter_action_prefixes( + filter_clauses=filter_clauses + ) return filter_clauses, select_related - def _process_column_clause_for_operator_and_value( - self, - value: Any, - op: str, - column: sqlalchemy.Column, - table: sqlalchemy.Table, - table_prefix: str, - ) -> sqlalchemy.sql.expression.TextClause: + def _register_complex_duplicates(self, select_related: List[str]) -> None: """ - Escapes characters if it's required. - Substitutes values of the models if value is a ormar Model with its pk value. - Compiles the clause. + Checks if duplicate aliases are presented which can happen in self relation + or when two joins end with the same pair of models. - :param value: value of the filter - :type value: Any - :param op: filter operator - :type op: str - :param column: column on which filter should be applied - :type column: sqlalchemy.sql.schema.Column - :param table: table on which filter should be applied - :type table: sqlalchemy.sql.schema.Table - :param table_prefix: prefix from AliasManager - :type table_prefix: str - :return: complied and escaped clause - :rtype: sqlalchemy.sql.elements.TextClause - """ - value, has_escaped_character = self._escape_characters_in_clause(op, value) + If there are duplicates, the all duplicated joins are registered as source + model and whole relation key (not just last relation name). - if isinstance(value, ormar.Model): - value = value.pk - - op_attr = FILTER_OPERATORS[op] - clause = getattr(column, op_attr)(value) - clause = self._compile_clause( - clause, - column, - table, - table_prefix, - modifiers={"escape": "\\" if has_escaped_character else None}, - ) - return clause - - def _determine_filter_target_table( - self, related_parts: List[str], select_related: List[str], field_name: str - ) -> Tuple[List[str], str, Type["Model"]]: - """ - Adds related strings to select_related list otherwise the clause would fail as - the required columns would not be present. That means that select_related - list is filled with missing values present in filters. - - Walks the relation to retrieve the actual model on which the clause should be - constructed, extracts alias based on last relation leading to target model. - - :param related_parts: list of split parts of related string - :type related_parts: List[str] - :param select_related: list of related models + :param select_related: list of relation strings :type select_related: List[str] - :return: list of related models, table_prefix, final model class - :rtype: Tuple[List[str], str, Type[Model]] + :return: None + :rtype: None """ - table_prefix = "" - model_cls = self.model_cls - select_related = [relation for relation in select_related] + prefixes = self._parse_related_prefixes(select_related=select_related) - # Add any implied select_related - related_str = "__".join(related_parts) - if related_str not in select_related: - select_related.append(related_str) - - # Walk the relationships to the actual model class - # against which the comparison is being made. - previous_model = model_cls - manager = model_cls.Meta.alias_manager - for relation in related_parts: - related_field = model_cls.Meta.model_fields[relation] - if issubclass(related_field, ManyToManyField): - previous_model = related_field.through - relation = related_field.default_target_field_name() # type: ignore - table_prefix = manager.resolve_relation_alias( - from_model=previous_model, relation_name=relation + manager = self.model_cls.Meta.alias_manager + filtered_prefixes = sorted(prefixes, key=lambda x: x.table_prefix) + grouped = itertools.groupby(filtered_prefixes, key=lambda x: x.table_prefix) + for _, group in grouped: + sorted_group = sorted( + group, key=lambda x: len(x.relation_str), reverse=True ) - model_cls = related_field.to - previous_model = model_cls - # handle duplicated aliases in nested relations - # TODO: check later and remove nocover - complex_prefix = manager.resolve_relation_alias( - from_model=self.model_cls, - relation_name='__'.join([related_str, field_name]) - ) - if complex_prefix: # pragma: nocover - table_prefix = complex_prefix - return select_related, table_prefix, model_cls + for prefix in sorted_group[:-1]: + if prefix.alias_key not in manager: + manager.add_alias(alias_key=prefix.alias_key) - def _compile_clause( - self, - clause: sqlalchemy.sql.expression.BinaryExpression, - column: sqlalchemy.Column, - table: sqlalchemy.Table, - table_prefix: str, - modifiers: Dict, - ) -> sqlalchemy.sql.expression.TextClause: + def _parse_related_prefixes(self, select_related: List[str]) -> List[Prefix]: """ - Compiles the clause to str using appropriate database dialect, replace columns - names with aliased names and converts it back to TextClause. + Walks all relation strings and parses the target models and prefixes. - :param clause: original not compiled clause - :type clause: sqlalchemy.sql.elements.BinaryExpression - :param column: column on which filter should be applied - :type column: sqlalchemy.sql.schema.Column - :param table: table on which filter should be applied - :type table: sqlalchemy.sql.schema.Table - :param table_prefix: prefix from AliasManager - :type table_prefix: str - :param modifiers: sqlalchemy modifiers - used only to escape chars here - :type modifiers: Dict[str, NoneType] - :return: compiled and escaped clause - :rtype: sqlalchemy.sql.elements.TextClause + :param select_related: list of relation strings + :type select_related: List[str] + :return: list of parsed prefixes + :rtype: List[Prefix] """ - for modifier, modifier_value in modifiers.items(): - clause.modifiers[modifier] = modifier_value - - clause_text = str( - clause.compile( - dialect=self.model_cls.Meta.database._backend._dialect, - compile_kwargs={"literal_binds": True}, + prefixes: List[Prefix] = [] + for related in select_related: + prefix = Prefix( + self.model_cls, + *get_relationship_alias_model_and_str( + self.model_cls, related.split("__") + ), ) - ) - alias = f"{table_prefix}_" if table_prefix else "" - aliased_name = f"{alias}{table.name}.{column.name}" - clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) - clause = text(clause_text) - return clause + prefixes.append(prefix) + return prefixes - @staticmethod - def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]: + def _switch_filter_action_prefixes( + self, filter_clauses: List[FilterAction] + ) -> List[FilterAction]: """ - Escapes the special characters ["%", "_"] if needed. - Adds `%` for `like` queries. + Substitutes aliases for filter action if the complex key (whole relation str) is + present in alias_manager. - :raises QueryDefinitionError: if contains or icontains is used with - ormar model instance - :param op: operator used in query - :type op: str - :param value: value of the filter - :type value: Any - :return: escaped value and flag if escaping is needed - :rtype: Tuple[Any, bool] + :param filter_clauses: raw list of actions + :type filter_clauses: List[FilterAction] + :return: list of actions with aliases changed if needed + :rtype: List[FilterAction] """ - has_escaped_character = False - - if op not in [ - "contains", - "icontains", - "startswith", - "istartswith", - "endswith", - "iendswith", - ]: - return value, has_escaped_character - - if isinstance(value, ormar.Model): - raise QueryDefinitionError( - "You cannot use contains and icontains with instance of the Model" + manager = self.model_cls.Meta.alias_manager + for action in filter_clauses: + new_alias = manager.resolve_relation_alias( + self.model_cls, action.related_str ) - - has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) - - if has_escaped_character: - # enable escape modifier - for char in ESCAPE_CHARACTERS: - value = value.replace(char, f"\\{char}") - prefix = "%" if "start" not in op else "" - sufix = "%" if "end" not in op else "" - value = f"{prefix}{value}{sufix}" - - return value, has_escaped_character - - @staticmethod - def _extract_operator_field_and_related( - parts: List[str], - ) -> Tuple[str, str, Optional[List]]: - """ - Splits filter query key and extracts required parts. - - :param parts: split filter query key - :type parts: List[str] - :return: operator, field_name, list of related parts - :rtype: Tuple[str, str, Optional[List]] - """ - if parts[-1] in FILTER_OPERATORS: - op = parts[-1] - field_name = parts[-2] - related_parts = parts[:-2] - else: - op = "exact" - field_name = parts[-1] - related_parts = parts[:-1] - - return op, field_name, related_parts + if "__" in action.related_str and new_alias: + action.table_prefix = new_alias + return filter_clauses diff --git a/ormar/queryset/filter_action.py b/ormar/queryset/filter_action.py new file mode 100644 index 0000000..4f26864 --- /dev/null +++ b/ormar/queryset/filter_action.py @@ -0,0 +1,201 @@ +from typing import Any, Dict, List, TYPE_CHECKING, Type + +import sqlalchemy +from sqlalchemy import text + +import ormar # noqa: I100, I202 +from ormar.exceptions import QueryDefinitionError +from ormar.queryset.utils import get_relationship_alias_model_and_str + +if TYPE_CHECKING: # pragma: nocover + from ormar import Model + +FILTER_OPERATORS = { + "exact": "__eq__", + "iexact": "ilike", + "contains": "like", + "icontains": "ilike", + "startswith": "like", + "istartswith": "ilike", + "endswith": "like", + "iendswith": "ilike", + "in": "in_", + "gt": "__gt__", + "gte": "__ge__", + "lt": "__lt__", + "lte": "__le__", +} +ESCAPE_CHARACTERS = ["%", "_"] + + +class FilterAction: + """ + Filter Actions is populated by queryset when filter() is called. + + All required params are extracted but kept raw until actual filter clause value + is required -> then the action is converted into text() clause. + + Extracted in order to easily change table prefixes on complex relations. + """ + + def __init__(self, filter_str: str, value: Any, model_cls: Type["Model"]) -> None: + parts = filter_str.split("__") + if parts[-1] in FILTER_OPERATORS: + self.operator = parts[-1] + self.field_name = parts[-2] + self.related_parts = parts[:-2] + else: + self.operator = "exact" + self.field_name = parts[-1] + self.related_parts = parts[:-1] + + self.filter_value = value + self.table_prefix = "" + self.source_model = model_cls + self.target_model = model_cls + self._determine_filter_target_table() + self._escape_characters_in_clause() + + @property + def table(self) -> sqlalchemy.Table: + """Shortcut to sqlalchemy Table of filtered target model""" + return self.target_model.Meta.table + + @property + def column(self) -> sqlalchemy.Column: + """Shortcut to sqlalchemy column of filtered target model""" + aliased_name = self.target_model.get_column_alias(self.field_name) + return self.target_model.Meta.table.columns[aliased_name] + + def has_escaped_characters(self) -> bool: + """Check if value is a string that contains characters to escape""" + return isinstance(self.filter_value, str) and any( + c for c in ESCAPE_CHARACTERS if c in self.filter_value + ) + + def update_select_related(self, select_related: List[str]) -> List[str]: + """ + Updates list of select related with related part included in the filter key. + That way If you want to just filter by relation you do not have to provide + select_related separately. + + :param select_related: list of relation join strings + :type select_related: List[str] + :return: list of relation joins with implied joins from filter added + :rtype: List[str] + """ + select_related = select_related[:] + if self.related_str and not any( + rel.startswith(self.related_str) for rel in select_related + ): + select_related.append(self.related_str) + return select_related + + def _determine_filter_target_table(self) -> None: + """ + Walks the relation to retrieve the actual model on which the clause should be + constructed, extracts alias based on last relation leading to target model. + """ + ( + self.table_prefix, + self.target_model, + self.related_str, + ) = get_relationship_alias_model_and_str(self.source_model, self.related_parts) + + def _escape_characters_in_clause(self) -> None: + """ + Escapes the special characters ["%", "_"] if needed. + Adds `%` for `like` queries. + + :raises QueryDefinitionError: if contains or icontains is used with + ormar model instance + :return: escaped value and flag if escaping is needed + :rtype: Tuple[Any, bool] + """ + self.has_escaped_character = False + if self.operator in [ + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + ]: + if isinstance(self.filter_value, ormar.Model): + raise QueryDefinitionError( + "You cannot use contains and icontains with instance of the Model" + ) + self.has_escaped_character = self.has_escaped_characters() + if self.has_escaped_character: + self._escape_chars() + self._prefix_suffix_quote() + + def _escape_chars(self) -> None: + """Actually replaces chars to escape in value""" + for char in ESCAPE_CHARACTERS: + self.filter_value = self.filter_value.replace(char, f"\\{char}") + + def _prefix_suffix_quote(self) -> None: + """ + Adds % to the beginning of the value if operator checks for containment and not + starts with. + + Adds % to the end of the value if operator checks for containment and not + end with. + :return: + :rtype: + """ + prefix = "%" if "start" not in self.operator else "" + sufix = "%" if "end" not in self.operator else "" + self.filter_value = f"{prefix}{self.filter_value}{sufix}" + + def get_text_clause(self,) -> sqlalchemy.sql.expression.TextClause: + """ + Escapes characters if it's required. + Substitutes values of the models if value is a ormar Model with its pk value. + Compiles the clause. + + :return: complied and escaped clause + :rtype: sqlalchemy.sql.elements.TextClause + """ + + if isinstance(self.filter_value, ormar.Model): + self.filter_value = self.filter_value.pk + + op_attr = FILTER_OPERATORS[self.operator] + clause = getattr(self.column, op_attr)(self.filter_value) + clause = self._compile_clause( + clause, modifiers={"escape": "\\" if self.has_escaped_character else None}, + ) + return clause + + def _compile_clause( + self, clause: sqlalchemy.sql.expression.BinaryExpression, modifiers: Dict, + ) -> sqlalchemy.sql.expression.TextClause: + """ + Compiles the clause to str using appropriate database dialect, replace columns + names with aliased names and converts it back to TextClause. + + :param clause: original not compiled clause + :type clause: sqlalchemy.sql.elements.BinaryExpression + :param modifiers: sqlalchemy modifiers - used only to escape chars here + :type modifiers: Dict[str, NoneType] + :return: compiled and escaped clause + :rtype: sqlalchemy.sql.elements.TextClause + """ + for modifier, modifier_value in modifiers.items(): + clause.modifiers[modifier] = modifier_value + + clause_text = str( + clause.compile( + dialect=self.target_model.Meta.database._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{self.table_prefix}_" if self.table_prefix else "" + aliased_name = f"{alias}{self.table.name}.{self.column.name}" + clause_text = clause_text.replace( + f"{self.table.name}.{self.column.name}", aliased_name + ) + clause = text(clause_text) + return clause diff --git a/ormar/queryset/filter_query.py b/ormar/queryset/filter_query.py index cb43170..4100f16 100644 --- a/ormar/queryset/filter_query.py +++ b/ormar/queryset/filter_query.py @@ -1,6 +1,7 @@ from typing import List import sqlalchemy +from ormar.queryset.filter_action import FilterAction class FilterQuery: @@ -8,7 +9,9 @@ class FilterQuery: Modifies the select query with given list of where/filter clauses. """ - def __init__(self, filter_clauses: List, exclude: bool = False) -> None: + def __init__( + self, filter_clauses: List[FilterAction], exclude: bool = False + ) -> None: self.exclude = exclude self.filter_clauses = filter_clauses @@ -23,9 +26,11 @@ class FilterQuery: """ if self.filter_clauses: if len(self.filter_clauses) == 1: - clause = self.filter_clauses[0] + clause = self.filter_clauses[0].get_text_clause() else: - clause = sqlalchemy.sql.and_(*self.filter_clauses) + clause = sqlalchemy.sql.and_( + *[x.get_text_clause() for x in self.filter_clauses] + ) clause = sqlalchemy.sql.not_(clause) if self.exclude else clause expr = expr.where(clause) return expr diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 4013efe..0b44078 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -24,20 +24,20 @@ if TYPE_CHECKING: # pragma no cover class SqlJoin: def __init__( # noqa: CFQ002 - self, - used_aliases: List, - select_from: sqlalchemy.sql.select, - columns: List[sqlalchemy.Column], - fields: Optional[Union[Set, Dict]], - exclude_fields: Optional[Union[Set, Dict]], - order_columns: Optional[List], - sorted_orders: OrderedDict, - main_model: Type["Model"], - relation_name: str, - relation_str: str, - related_models: Any = None, - own_alias: str = "", - source_model: Type["Model"] = None, + self, + used_aliases: List, + select_from: sqlalchemy.sql.select, + columns: List[sqlalchemy.Column], + fields: Optional[Union[Set, Dict]], + exclude_fields: Optional[Union[Set, Dict]], + order_columns: Optional[List], + sorted_orders: OrderedDict, + main_model: Type["Model"], + relation_name: str, + relation_str: str, + related_models: Any = None, + own_alias: str = "", + source_model: Type["Model"] = None, ) -> None: self.relation_name = relation_name self.related_models = related_models or [] @@ -62,7 +62,7 @@ class SqlJoin: def next_model(self) -> Type["Model"]: if not self._next_model: # pragma: nocover raise RelationshipInstanceError( - "Cannot link to related table if " "relation to model is not set." + "Cannot link to related table if relation.to model is not set." ) return self._next_model @@ -90,8 +90,7 @@ class SqlJoin: """ return self.main_model.Meta.alias_manager - def on_clause(self, previous_alias: str, from_clause: str, - to_clause: str, ) -> text: + def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text: """ Receives aliases and names of both ends of the join and combines them into one text clause used in joins. @@ -134,19 +133,23 @@ class SqlJoin: self.sorted_orders, ) - def _forward_join(self): + def _forward_join(self) -> None: + """ + Process actual join. + Registers complex relation join on encountering of the duplicated alias. + """ self.next_alias = self.alias_manager.resolve_relation_alias( from_model=self.target_field.owner, relation_name=self.relation_name ) if self.next_alias not in self.used_aliases: self._process_join() else: - if '__' in self.relation_str: - relation_key = f'{self.source_model.get_name()}_{self.relation_str}' + if "__" in self.relation_str and self.source_model: + relation_key = f"{self.source_model.get_name()}_{self.relation_str}" if relation_key not in self.alias_manager: - print(f'registering {relation_key}') self.next_alias = self.alias_manager.add_alias( - alias_key=relation_key) + alias_key=relation_key + ) else: self.next_alias = self.alias_manager[relation_key] self._process_join() @@ -158,8 +161,8 @@ class SqlJoin: for related_name in self.related_models: remainder = None if ( - isinstance(self.related_models, dict) - and self.related_models[related_name] + isinstance(self.related_models, dict) + and self.related_models[related_name] ): remainder = self.related_models[related_name] self._process_deeper_join(related_name=related_name, remainder=remainder) @@ -194,9 +197,9 @@ class SqlJoin: main_model=self.next_model, relation_name=related_name, related_models=remainder, - relation_str='__'.join([self.relation_str, related_name]), + relation_str="__".join([self.relation_str, related_name]), own_alias=self.next_alias, - source_model=self.source_model or self.main_model + source_model=self.source_model or self.main_model, ) ( self.used_aliases, @@ -244,18 +247,18 @@ class SqlJoin: """ target_field = self.target_field is_primary_self_ref = ( - target_field.self_reference - and self.relation_name == target_field.self_reference_primary + target_field.self_reference + and self.relation_name == target_field.self_reference_primary ) if (is_primary_self_ref and not reverse) or ( - not is_primary_self_ref and reverse + not is_primary_self_ref and reverse ): new_part = target_field.default_source_field_name() # type: ignore else: new_part = target_field.default_target_field_name() # type: ignore return new_part - def _process_join(self, ) -> None: # noqa: CFQ002 + def _process_join(self,) -> None: # noqa: CFQ002 """ Resolves to and from column names and table names. @@ -335,10 +338,10 @@ class SqlJoin: :rtype: bool """ return len(condition) >= 2 and ( - condition[-2] == part or condition[-2][1:] == part + condition[-2] == part or condition[-2][1:] == part ) - def set_aliased_order_by(self, condition: List[str], to_table: str, ) -> None: + def set_aliased_order_by(self, condition: List[str], to_table: str,) -> None: """ Substitute hyphens ('-') with descending order. Construct actual sqlalchemy text clause using aliased table and column name. @@ -353,7 +356,7 @@ class SqlJoin: order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}") self.sorted_orders["__".join(condition)] = order - def get_order_bys(self, to_table: str, pkname_alias: str, ) -> None: # noqa: CCR001 + def get_order_bys(self, to_table: str, pkname_alias: str,) -> None: # noqa: CCR001 """ Triggers construction of order bys if they are given. Otherwise by default each table is sorted by a primary key column asc. diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 96de967..4c8c6d7 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -290,7 +290,7 @@ class PrefetchQuery: model_cls=clause_target, select_related=[], filter_clauses=[], ) kwargs = {f"{filter_column}__in": ids} - filter_clauses, _ = qryclause.filter(**kwargs) + filter_clauses, _ = qryclause.prepare_filter(**kwargs) return filter_clauses return [] diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 60f3b5e..edb28c1 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -8,6 +8,7 @@ from sqlalchemy import text import ormar # noqa I100 from ormar.models.helpers.models import group_related_list from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery +from ormar.queryset.filter_action import FilterAction from ormar.queryset.join import SqlJoin if TYPE_CHECKING: # pragma no cover @@ -18,8 +19,8 @@ class Query: def __init__( # noqa CFQ002 self, model_cls: Type["Model"], - filter_clauses: List, - exclude_clauses: List, + filter_clauses: List[FilterAction], + exclude_clauses: List[FilterAction], select_related: List, limit_count: Optional[int], offset: Optional[int], @@ -200,12 +201,12 @@ class Query: filters_to_use = [ filter_clause for filter_clause in self.filter_clauses - if filter_clause.text.startswith(f"{self.table.name}.") + if filter_clause.table_prefix == "" ] excludes_to_use = [ filter_clause for filter_clause in self.exclude_clauses - if filter_clause.text.startswith(f"{self.table.name}.") + if filter_clause.table_prefix == "" ] sorts_to_use = {k: v for k, v in self.sorted_orders.items() if "__" not in k} expr = FilterQuery(filter_clauses=filters_to_use).apply(expr) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 08c4be5..dff9635 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -236,7 +236,7 @@ class QuerySet: select_related=self._select_related, filter_clauses=self.filter_clauses, ) - filter_clauses, select_related = qryclause.filter(**kwargs) + filter_clauses, select_related = qryclause.prepare_filter(**kwargs) if _exclude: exclude_clauses = filter_clauses filter_clauses = self.filter_clauses diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index 12a7fa5..e2cf33a 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -7,10 +7,13 @@ from typing import ( Sequence, Set, TYPE_CHECKING, + Tuple, Type, Union, ) +from ormar.fields import ManyToManyField + if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -212,3 +215,35 @@ def extract_models_to_dict_of_lists( for model in models: extract_nested_models(model, model_type, select_dict, extracted) return extracted + + +def get_relationship_alias_model_and_str( + source_model: Type["Model"], related_parts: List +) -> Tuple[str, Type["Model"], str]: + """ + Walks the relation to retrieve the actual model on which the clause should be + constructed, extracts alias based on last relation leading to target model. + :param related_parts: list of related names extracted from string + :type related_parts: Union[List, List[str]] + :param source_model: model from which relation starts + :type source_model: Type[Model] + :return: table prefix, target model and relation string + :rtype: Tuple[str, Type["Model"], str] + """ + table_prefix = "" + model_cls = source_model + previous_model = model_cls + manager = model_cls.Meta.alias_manager + for relation in related_parts: + related_field = model_cls.Meta.model_fields[relation] + if issubclass(related_field, ManyToManyField): + previous_model = related_field.through + relation = related_field.default_target_field_name() # type: ignore + table_prefix = manager.resolve_relation_alias( + from_model=previous_model, relation_name=relation + ) + model_cls = related_field.to + previous_model = model_cls + relation_str = "__".join(related_parts) + + return table_prefix, model_cls, relation_str diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index 803abc5..cd3dc8b 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -1,7 +1,7 @@ import string import uuid from random import choices -from typing import Dict, List, TYPE_CHECKING, Type +from typing import Any, Dict, List, TYPE_CHECKING, Type import sqlalchemy from sqlalchemy import text @@ -33,10 +33,10 @@ class AliasManager: def __init__(self) -> None: self._aliases_new: Dict[str, str] = dict() - def __contains__(self, item): + def __contains__(self, item: str) -> bool: return self._aliases_new.__contains__(item) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self._aliases_new.__getitem__(key) @staticmethod diff --git a/tests/test_forward_refs.py b/tests/test_forward_refs.py index 9132c41..60505b6 100644 --- a/tests/test_forward_refs.py +++ b/tests/test_forward_refs.py @@ -187,31 +187,44 @@ async def test_m2m_self_forwardref_relation(cleanup): await billy.friends.add(steve) billy_check = await Child.objects.select_related( - ["friends", "favourite_game", "least_favourite_game", - "friends__favourite_game", "friends__least_favourite_game"] + [ + "friends", + "favourite_game", + "least_favourite_game", + "friends__favourite_game", + "friends__least_favourite_game", + ] ).get(name="Billy") assert len(billy_check.friends) == 2 assert billy_check.friends[0].name == "Kate" - assert billy_check.friends[0].favourite_game.name == 'Checkers' - assert billy_check.friends[0].least_favourite_game.name == 'Uno' + assert billy_check.friends[0].favourite_game.name == "Checkers" + assert billy_check.friends[0].least_favourite_game.name == "Uno" assert billy_check.friends[1].name == "Steve" - assert billy_check.friends[1].favourite_game.name == 'Jenga' - assert billy_check.friends[1].least_favourite_game.name == 'Uno' + assert billy_check.friends[1].favourite_game.name == "Jenga" + assert billy_check.friends[1].least_favourite_game.name == "Uno" assert billy_check.favourite_game.name == "Uno" - kate_check = await Child.objects.select_related(["also_friends",]).get( + kate_check = await Child.objects.select_related(["also_friends"]).get( name="Kate" ) assert len(kate_check.also_friends) == 1 assert kate_check.also_friends[0].name == "Billy" - # TODO: Fix filters with complex prefixes - # billy_check = await Child.objects.select_related( - # ["friends", "favourite_game", "least_favourite_game", - # "friends__favourite_game", "friends__least_favourite_game"] - # ).filter(friends__favourite_game__name="Checkers").get(name="Billy") - # assert len(billy_check.friends) == 1 - # assert billy_check.friends[0].name == "Kate" - # assert billy_check.friends[0].favourite_game.name == 'Checkers' - # assert billy_check.friends[0].least_favourite_game.name == 'Uno' + billy_check = ( + await Child.objects.select_related( + [ + "friends", + "favourite_game", + "least_favourite_game", + "friends__favourite_game", + "friends__least_favourite_game", + ] + ) + .filter(friends__favourite_game__name="Checkers") + .get(name="Billy") + ) + assert len(billy_check.friends) == 1 + assert billy_check.friends[0].name == "Kate" + assert billy_check.friends[0].favourite_game.name == "Checkers" + assert billy_check.friends[0].least_favourite_game.name == "Uno" diff --git a/tests/test_models_helpers.py b/tests/test_models_helpers.py index 56a77c0..a397e91 100644 --- a/tests/test_models_helpers.py +++ b/tests/test_models_helpers.py @@ -2,8 +2,16 @@ from ormar.models.helpers.models import group_related_list def test_group_related_list(): - given = ['friends__least_favourite_game', 'least_favourite_game', 'friends', - 'favourite_game', 'friends__favourite_game'] - expected = {'least_favourite_game': [], 'favourite_game': [], - 'friends': ['favourite_game', 'least_favourite_game']} + given = [ + "friends__least_favourite_game", + "least_favourite_game", + "friends", + "favourite_game", + "friends__favourite_game", + ] + expected = { + "least_favourite_game": [], + "favourite_game": [], + "friends": ["favourite_game", "least_favourite_game"], + } assert group_related_list(given) == expected diff --git a/tests/test_more_same_table_joins.py b/tests/test_more_same_table_joins.py index 3fe22fe..9dc086e 100644 --- a/tests/test_more_same_table_joins.py +++ b/tests/test_more_same_table_joins.py @@ -101,15 +101,10 @@ async def test_model_multiple_instances_of_same_table_in_schema(): async with database: await create_data() classes = await SchoolClass.objects.select_related( - ["teachers__category__department", "students"] + ["teachers__category__department", "students__category__department"] ).all() assert classes[0].name == "Math" assert classes[0].students[0].name == "Jane" assert len(classes[0].dict().get("students")) == 2 assert classes[0].teachers[0].category.department.name == "Law Department" - - assert classes[0].students[0].category.pk is not None - assert classes[0].students[0].category.name is None - await classes[0].students[0].category.load() - await classes[0].students[0].category.department.load() assert classes[0].students[0].category.department.name == "Math Department"