diff --git a/ormar/__init__.py b/ormar/__init__.py index b2c7020..328e894 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -56,7 +56,7 @@ from ormar.fields import ( ) # noqa: I100 from ormar.models import Model from ormar.models.metaclass import ModelMeta -from ormar.queryset import QuerySet +from ormar.queryset import OrderAction, QuerySet from ormar.relations import RelationType from ormar.signals import Signal @@ -106,4 +106,5 @@ __all__ = [ "BaseField", "ManyToManyField", "ForeignKeyField", + "OrderAction", ] diff --git a/ormar/models/mixins/prefetch_mixin.py b/ormar/models/mixins/prefetch_mixin.py index d8ee350..440052a 100644 --- a/ormar/models/mixins/prefetch_mixin.py +++ b/ormar/models/mixins/prefetch_mixin.py @@ -2,7 +2,7 @@ from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type, cast from ormar.models.mixins.relation_mixin import RelationMixin -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from ormar.fields import ForeignKeyField, ManyToManyField @@ -18,10 +18,10 @@ class PrefetchQueryMixin(RelationMixin): @staticmethod def get_clause_target_and_filter_column_name( - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool, - related: str, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + related: str, ) -> Tuple[Type["Model"], str]: """ Returns Model on which query clause should be performed and name of the column. @@ -51,7 +51,7 @@ class PrefetchQueryMixin(RelationMixin): @staticmethod def get_column_name_for_id_extraction( - parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool, + parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool, ) -> str: """ Returns name of the column that should be used to extract ids from model. diff --git a/ormar/models/model_row.py b/ormar/models/model_row.py index f0a4a31..476e274 100644 --- a/ormar/models/model_row.py +++ b/ormar/models/model_row.py @@ -17,7 +17,7 @@ from ormar.models import NewBaseModel # noqa: I202 from ormar.models.helpers.models import group_related_list -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from ormar.fields import ForeignKeyField from ormar.models import T else: diff --git a/ormar/queryset/__init__.py b/ormar/queryset/__init__.py index ebfab7b..11b072e 100644 --- a/ormar/queryset/__init__.py +++ b/ormar/queryset/__init__.py @@ -1,10 +1,20 @@ """ Contains QuerySet and different Query classes to allow for constructing of sql queries. """ +from ormar.queryset.actions import FilterAction, OrderAction from ormar.queryset.filter_query import FilterQuery from ormar.queryset.limit_query import LimitQuery from ormar.queryset.offset_query import OffsetQuery from ormar.queryset.order_query import OrderQuery from ormar.queryset.queryset import QuerySet, T -__all__ = ["T", "QuerySet", "FilterQuery", "LimitQuery", "OffsetQuery", "OrderQuery"] +__all__ = [ + "T", + "QuerySet", + "FilterQuery", + "LimitQuery", + "OffsetQuery", + "OrderQuery", + "FilterAction", + "OrderAction", +] diff --git a/ormar/queryset/actions/__init__.py b/ormar/queryset/actions/__init__.py new file mode 100644 index 0000000..088d68a --- /dev/null +++ b/ormar/queryset/actions/__init__.py @@ -0,0 +1,4 @@ +from ormar.queryset.actions.filter_action import FilterAction +from ormar.queryset.actions.order_action import OrderAction + +__all__ = ["FilterAction", "OrderAction"] diff --git a/ormar/queryset/filter_action.py b/ormar/queryset/actions/filter_action.py similarity index 71% rename from ormar/queryset/filter_action.py rename to ormar/queryset/actions/filter_action.py index d2d8e45..43c71df 100644 --- a/ormar/queryset/filter_action.py +++ b/ormar/queryset/actions/filter_action.py @@ -1,11 +1,11 @@ -from typing import Any, Dict, List, TYPE_CHECKING, Type +from typing import Any, Dict, TYPE_CHECKING, Type import sqlalchemy from sqlalchemy import text import ormar # noqa: I100, I202 from ormar.exceptions import QueryDefinitionError -from ormar.queryset.utils import get_relationship_alias_model_and_str +from ormar.queryset.actions.query_action import QueryAction if TYPE_CHECKING: # pragma: nocover from ormar import Model @@ -28,7 +28,7 @@ FILTER_OPERATORS = { ESCAPE_CHARACTERS = ["%", "_"] -class FilterAction: +class FilterAction(QueryAction): """ Filter Actions is populated by queryset when filter() is called. @@ -39,7 +39,18 @@ class FilterAction: """ def __init__(self, filter_str: str, value: Any, model_cls: Type["Model"]) -> None: - parts = filter_str.split("__") + super().__init__(query_str=filter_str, model_cls=model_cls) + self.filter_value = value + self._escape_characters_in_clause() + + def has_escaped_characters(self) -> bool: + """Check if value is a string that contains characters to escape""" + return isinstance(self.filter_value, str) and any( + c for c in ESCAPE_CHARACTERS if c in self.filter_value + ) + + def _split_value_into_parts(self, query_str: str) -> None: + parts = query_str.split("__") if parts[-1] in FILTER_OPERATORS: self.operator = parts[-1] self.field_name = parts[-2] @@ -49,61 +60,6 @@ class FilterAction: self.field_name = parts[-1] self.related_parts = parts[:-1] - self.filter_value = value - self.table_prefix = "" - self.source_model = model_cls - self.target_model = model_cls - self.is_through = False - self._determine_filter_target_table() - self._escape_characters_in_clause() - - @property - def table(self) -> sqlalchemy.Table: - """Shortcut to sqlalchemy Table of filtered target model""" - return self.target_model.Meta.table - - @property - def column(self) -> sqlalchemy.Column: - """Shortcut to sqlalchemy column of filtered target model""" - aliased_name = self.target_model.get_column_alias(self.field_name) - return self.target_model.Meta.table.columns[aliased_name] - - def has_escaped_characters(self) -> bool: - """Check if value is a string that contains characters to escape""" - return isinstance(self.filter_value, str) and any( - c for c in ESCAPE_CHARACTERS if c in self.filter_value - ) - - def update_select_related(self, select_related: List[str]) -> List[str]: - """ - Updates list of select related with related part included in the filter key. - That way If you want to just filter by relation you do not have to provide - select_related separately. - - :param select_related: list of relation join strings - :type select_related: List[str] - :return: list of relation joins with implied joins from filter added - :rtype: List[str] - """ - select_related = select_related[:] - if self.related_str and not any( - rel.startswith(self.related_str) for rel in select_related - ): - select_related.append(self.related_str) - return select_related - - def _determine_filter_target_table(self) -> None: - """ - Walks the relation to retrieve the actual model on which the clause should be - constructed, extracts alias based on last relation leading to target model. - """ - ( - self.table_prefix, - self.target_model, - self.related_str, - self.is_through, - ) = get_relationship_alias_model_and_str(self.source_model, self.related_parts) - def _escape_characters_in_clause(self) -> None: """ Escapes the special characters ["%", "_"] if needed. @@ -151,7 +107,7 @@ class FilterAction: sufix = "%" if "end" not in self.operator else "" self.filter_value = f"{prefix}{self.filter_value}{sufix}" - def get_text_clause(self,) -> sqlalchemy.sql.expression.TextClause: + def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause: """ Escapes characters if it's required. Substitutes values of the models if value is a ormar Model with its pk value. diff --git a/ormar/queryset/actions/order_action.py b/ormar/queryset/actions/order_action.py new file mode 100644 index 0000000..2173e24 --- /dev/null +++ b/ormar/queryset/actions/order_action.py @@ -0,0 +1,68 @@ +from typing import TYPE_CHECKING, Type + +import sqlalchemy +from sqlalchemy import text + +from ormar.queryset.actions.query_action import QueryAction # noqa: I100, I202 + +if TYPE_CHECKING: # pragma: nocover + from ormar import Model + + +class OrderAction(QueryAction): + """ + Order Actions is populated by queryset when order_by() is called. + + All required params are extracted but kept raw until actual filter clause value + is required -> then the action is converted into text() clause. + + Extracted in order to easily change table prefixes on complex relations. + """ + + def __init__( + self, order_str: str, model_cls: Type["Model"], alias: str = None + ) -> None: + self.direction: str = "" + super().__init__(query_str=order_str, model_cls=model_cls) + self.is_source_model_order = False + if alias: + self.table_prefix = alias + if self.source_model == self.target_model and "__" not in self.related_str: + self.is_source_model_order = True + + @property + def field_alias(self) -> str: + return self.target_model.get_column_alias(self.field_name) + + def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause: + """ + Escapes characters if it's required. + Substitutes values of the models if value is a ormar Model with its pk value. + Compiles the clause. + + :return: complied and escaped clause + :rtype: sqlalchemy.sql.elements.TextClause + """ + prefix = f"{self.table_prefix}_" if self.table_prefix else "" + return text(f"{prefix}{self.table}" f".{self.field_alias} {self.direction}") + + def _split_value_into_parts(self, order_str: str) -> None: + if order_str.startswith("-"): + self.direction = "desc" + order_str = order_str[1:] + parts = order_str.split("__") + self.field_name = parts[-1] + self.related_parts = parts[:-1] + + def check_if_filter_apply(self, target_model: Type["Model"], alias: str) -> bool: + """ + Checks filter conditions to find if they apply to current join. + + :param target_model: model which is now processed + :type target_model: Type["Model"] + :param alias: prefix of the relation + :type alias: str + :return: result of the check + :rtype: bool + """ + return target_model == self.target_model and alias == self.table_prefix diff --git a/ormar/queryset/actions/query_action.py b/ormar/queryset/actions/query_action.py new file mode 100644 index 0000000..2c6ee84 --- /dev/null +++ b/ormar/queryset/actions/query_action.py @@ -0,0 +1,93 @@ +import abc +from typing import Any, List, TYPE_CHECKING, Type + +import sqlalchemy + +from ormar.queryset.utils import get_relationship_alias_model_and_str # noqa: I202 + +if TYPE_CHECKING: # pragma: nocover + from ormar import Model + + +class QueryAction(abc.ABC): + """ + Base QueryAction class with common params for Filter and Order actions. + """ + + def __init__(self, query_str: str, model_cls: Type["Model"]) -> None: + self.query_str = query_str + self.field_name: str = "" + self.related_parts: List[str] = [] + self.related_str: str = "" + + self.table_prefix = "" + self.source_model = model_cls + self.target_model = model_cls + self.is_through = False + + self._split_value_into_parts(query_str) + self._determine_filter_target_table() + + def __eq__(self, other: object) -> bool: # pragma: no cover + if not isinstance(other, QueryAction): + return False + return self.query_str == other.query_str + + def __hash__(self) -> Any: + return hash((self.table_prefix, self.query_str)) + + @abc.abstractmethod + def _split_value_into_parts(self, query_str: str) -> None: # pragma: no cover + """ + Splits string into related parts and field_name + :param query_str: query action string to split (i..e filter or order by) + :type query_str: str + """ + pass + + @abc.abstractmethod + def get_text_clause( + self, + ) -> sqlalchemy.sql.expression.TextClause: # pragma: no cover + pass + + @property + def table(self) -> sqlalchemy.Table: + """Shortcut to sqlalchemy Table of filtered target model""" + return self.target_model.Meta.table + + @property + def column(self) -> sqlalchemy.Column: + """Shortcut to sqlalchemy column of filtered target model""" + aliased_name = self.target_model.get_column_alias(self.field_name) + return self.target_model.Meta.table.columns[aliased_name] + + def update_select_related(self, select_related: List[str]) -> List[str]: + """ + Updates list of select related with related part included in the filter key. + That way If you want to just filter by relation you do not have to provide + select_related separately. + + :param select_related: list of relation join strings + :type select_related: List[str] + :return: list of relation joins with implied joins from filter added + :rtype: List[str] + """ + select_related = select_related[:] + if self.related_str and not any( + rel.startswith(self.related_str) for rel in select_related + ): + select_related.append(self.related_str) + return select_related + + def _determine_filter_target_table(self) -> None: + """ + Walks the relation to retrieve the actual model on which the clause should be + constructed, extracts alias based on last relation leading to target model. + """ + ( + self.table_prefix, + self.target_model, + self.related_str, + self.is_through, + ) = get_relationship_alias_model_and_str(self.source_model, self.related_parts) diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index b5a3c5b..b98616d 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Any, List, TYPE_CHECKING, Tuple, Type import ormar # noqa I100 -from ormar.queryset.filter_action import FilterAction +from ormar.queryset.actions.filter_action import FilterAction from ormar.queryset.utils import get_relationship_alias_model_and_str if TYPE_CHECKING: # pragma no cover diff --git a/ormar/queryset/filter_query.py b/ormar/queryset/filter_query.py index 4100f16..cb9b880 100644 --- a/ormar/queryset/filter_query.py +++ b/ormar/queryset/filter_query.py @@ -1,7 +1,7 @@ from typing import List import sqlalchemy -from ormar.queryset.filter_action import FilterAction +from ormar.queryset.actions.filter_action import FilterAction class FilterQuery: diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 4626fbd..f18c81d 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -14,11 +14,13 @@ from typing import ( import sqlalchemy from sqlalchemy import text -from ormar.exceptions import RelationshipInstanceError # noqa I100 +import ormar # noqa I100 +from ormar.exceptions import RelationshipInstanceError from ormar.relations import AliasManager if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.queryset import OrderAction class SqlJoin: @@ -29,7 +31,7 @@ class SqlJoin: columns: List[sqlalchemy.Column], fields: Optional[Union[Set, Dict]], exclude_fields: Optional[Union[Set, Dict]], - order_columns: Optional[List], + order_columns: Optional[List["OrderAction"]], sorted_orders: OrderedDict, main_model: Type["Model"], relation_name: str, @@ -89,7 +91,18 @@ class SqlJoin: """ return self.main_model.Meta.alias_manager - def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text: + @property + def to_table(self) -> str: + """ + Shortcut to table name of the next model + :return: name of the target table + :rtype: str + """ + return self.next_model.Meta.table.name + + def _on_clause( + self, previous_alias: str, from_clause: str, to_clause: str, + ) -> text: """ Receives aliases and names of both ends of the join and combines them into one text clause used in joins. @@ -118,7 +131,7 @@ class SqlJoin: :rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict] """ if self.target_field.is_multi: - self.process_m2m_through_table() + self._process_m2m_through_table() self.next_model = self.target_field.to self._forward_join() @@ -207,7 +220,7 @@ class SqlJoin: self.sorted_orders, ) = sql_join.build_join() - def process_m2m_through_table(self) -> None: + def _process_m2m_through_table(self) -> None: """ Process Through table of the ManyToMany relation so that source table is linked to the through table (one additional join) @@ -222,8 +235,7 @@ class SqlJoin: To point to through model """ - new_part = self.process_m2m_related_name_change() - self._replace_many_to_many_order_by_columns(self.relation_name, new_part) + new_part = self._process_m2m_related_name_change() self.next_model = self.target_field.through self._forward_join() @@ -232,7 +244,7 @@ class SqlJoin: self.own_alias = self.next_alias self.target_field = self.next_model.Meta.model_fields[self.relation_name] - def process_m2m_related_name_change(self, reverse: bool = False) -> str: + def _process_m2m_related_name_change(self, reverse: bool = False) -> str: """ Extracts relation name to link join through the Through model declared on relation field. @@ -272,24 +284,21 @@ class SqlJoin: Process order_by causes for non m2m relations. """ - to_table = self.next_model.Meta.table.name - to_key, from_key = self.get_to_and_from_keys() + to_key, from_key = self._get_to_and_from_keys() - on_clause = self.on_clause( + on_clause = self._on_clause( previous_alias=self.own_alias, from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}", - to_clause=f"{to_table}.{to_key}", + to_clause=f"{self.to_table}.{to_key}", + ) + target_table = self.alias_manager.prefixed_table_name( + self.next_alias, self.to_table ) - target_table = self.alias_manager.prefixed_table_name(self.next_alias, to_table) self.select_from = sqlalchemy.sql.outerjoin( self.select_from, target_table, on_clause ) - pkname_alias = self.next_model.get_column_alias(self.next_model.Meta.pkname) - if not self.target_field.is_multi: - self.get_order_bys( - to_table=to_table, pkname_alias=pkname_alias, - ) + self._get_order_bys() # TODO: fix fields and exclusions for through model? self_related_fields = self.next_model.own_table_columns( @@ -305,88 +314,35 @@ class SqlJoin: ) self.used_aliases.append(self.next_alias) - def _replace_many_to_many_order_by_columns(self, part: str, new_part: str) -> None: - """ - Substitutes the name of the relation with actual model name in m2m order bys. - - :param part: name of the field with relation - :type part: str - :param new_part: name of the target model - :type new_part: str - """ - if self.order_columns: - split_order_columns = [ - x.split("__") for x in self.order_columns if "__" in x - ] - for condition in split_order_columns: - if self._check_if_condition_apply(condition, part): - condition[-2] = condition[-2].replace(part, new_part) - self.order_columns = [x for x in self.order_columns if "__" not in x] + [ - "__".join(x) for x in split_order_columns - ] - - @staticmethod - def _check_if_condition_apply(condition: List, part: str) -> bool: - """ - Checks filter conditions to find if they apply to current join. - - :param condition: list of parts of condition split by '__' - :type condition: List[str] - :param part: name of the current relation join. - :type part: str - :return: result of the check - :rtype: bool - """ - return len(condition) >= 2 and ( - condition[-2] == part or condition[-2][1:] == part + def _set_default_primary_key_order_by(self) -> None: + clause = ormar.OrderAction( + order_str=self.next_model.Meta.pkname, + model_cls=self.next_model, + alias=self.next_alias, ) + self.sorted_orders[clause] = clause.get_text_clause() - def set_aliased_order_by(self, condition: List[str], to_table: str,) -> None: - """ - Substitute hyphens ('-') with descending order. - Construct actual sqlalchemy text clause using aliased table and column name. - - :param condition: list of parts of a current condition split by '__' - :type condition: List[str] - :param to_table: target table - :type to_table: sqlalchemy.sql.elements.quoted_name - """ - direction = f"{'desc' if condition[0][0] == '-' else ''}" - column_alias = self.next_model.get_column_alias(condition[-1]) - order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}") - self.sorted_orders["__".join(condition)] = order - - def get_order_bys(self, to_table: str, pkname_alias: str,) -> None: # noqa: CCR001 + def _get_order_bys(self) -> None: # noqa: CCR001 """ Triggers construction of order bys if they are given. Otherwise by default each table is sorted by a primary key column asc. - - :param to_table: target table - :type to_table: sqlalchemy.sql.elements.quoted_name - :param pkname_alias: alias of the primary key column - :type pkname_alias: str """ alias = self.next_alias if self.order_columns: current_table_sorted = False - split_order_columns = [ - x.split("__") for x in self.order_columns if "__" in x - ] - for condition in split_order_columns: - if self._check_if_condition_apply(condition, self.relation_name): + for condition in self.order_columns: + if condition.check_if_filter_apply( + target_model=self.next_model, alias=alias + ): current_table_sorted = True - self.set_aliased_order_by( - condition=condition, to_table=to_table, - ) - if not current_table_sorted: - order = text(f"{alias}_{to_table}.{pkname_alias}") - self.sorted_orders[f"{alias}.{pkname_alias}"] = order + self.sorted_orders[condition] = condition.get_text_clause() + if not current_table_sorted and not self.target_field.is_multi: + self._set_default_primary_key_order_by() - else: - order = text(f"{alias}_{to_table}.{pkname_alias}") - self.sorted_orders[f"{alias}.{pkname_alias}"] = order + elif not self.target_field.is_multi: + self._set_default_primary_key_order_by() - def get_to_and_from_keys(self) -> Tuple[str, str]: + def _get_to_and_from_keys(self) -> Tuple[str, str]: """ Based on the relation type, name of the relation and previous models and parts stored in JoinParameters it resolves the current to and from keys, which are @@ -396,7 +352,7 @@ class SqlJoin: :rtype: Tuple[str, str] """ if self.target_field.is_multi: - to_key = self.process_m2m_related_name_change(reverse=True) + to_key = self._process_m2m_related_name_change(reverse=True) from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname) elif self.target_field.virtual: diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 7abf4c6..08d2675 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -20,6 +20,7 @@ from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list if TYPE_CHECKING: # pragma: no cover from ormar import Model from ormar.fields import ForeignKeyField, BaseField + from ormar.queryset import OrderAction def add_relation_field_to_fields( @@ -128,7 +129,7 @@ class PrefetchQuery: exclude_fields: Optional[Union[Dict, Set]], prefetch_related: List, select_related: List, - orders_by: List, + orders_by: List["OrderAction"], ) -> None: self.model = model_cls @@ -141,7 +142,9 @@ class PrefetchQuery: self.models: Dict = {} self.select_dict = translate_list_to_dict(self._select_related) self.orders_by = orders_by or [] - self.order_dict = translate_list_to_dict(self.orders_by, is_order=True) + self.order_dict = translate_list_to_dict( + [x.query_str for x in self.orders_by], is_order=True + ) async def prefetch_related( self, models: Sequence["Model"], rows: List diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index edb28c1..7c5a211 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -8,11 +8,12 @@ from sqlalchemy import text import ormar # noqa I100 from ormar.models.helpers.models import group_related_list from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery -from ormar.queryset.filter_action import FilterAction +from ormar.queryset.actions.filter_action import FilterAction from ormar.queryset.join import SqlJoin if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.queryset import OrderAction class Query: @@ -26,7 +27,7 @@ class Query: offset: Optional[int], fields: Optional[Union[Dict, Set]], exclude_fields: Optional[Union[Dict, Set]], - order_bys: Optional[List], + order_bys: Optional[List["OrderAction"]], limit_raw_sql: bool, ) -> None: self.query_offset = offset @@ -45,7 +46,7 @@ class Query: self.select_from: List[str] = [] self.columns = [sqlalchemy.Column] self.order_columns = order_bys - self.sorted_orders: OrderedDict = OrderedDict() + self.sorted_orders: OrderedDict[OrderAction, text] = OrderedDict() self._init_sorted_orders() self.limit_raw_sql = limit_raw_sql @@ -58,28 +59,6 @@ class Query: for clause in self.order_columns: self.sorted_orders[clause] = None - @property - def prefixed_pk_name(self) -> str: - """ - Shortcut for extracting prefixed with alias primary key column name from main - model - :return: alias of pk column prefix with table name. - :rtype: str - """ - pkname_alias = self.model_cls.get_column_alias(self.model_cls.Meta.pkname) - return f"{self.table.name}.{pkname_alias}" - - def alias(self, name: str) -> str: - """ - Shortcut to extracting column alias from given master model. - - :param name: name of column - :type name: str - :return: alias of given column name - :rtype: str - """ - return self.model_cls.get_column_alias(name) - def apply_order_bys_for_primary_model(self) -> None: # noqa: CCR001 """ Applies order_by queries on main model when it's used as a subquery. @@ -88,16 +67,13 @@ class Query: """ if self.order_columns: for clause in self.order_columns: - if "__" not in clause: - text_clause = ( - text(f"{self.table.name}.{self.alias(clause[1:])} desc") - if clause.startswith("-") - else text(f"{self.table.name}.{self.alias(clause)}") - ) - self.sorted_orders[clause] = text_clause + if clause.is_source_model_order: + self.sorted_orders[clause] = clause.get_text_clause() else: - order = text(self.prefixed_pk_name) - self.sorted_orders[self.prefixed_pk_name] = order + clause = ormar.OrderAction( + order_str=self.model_cls.Meta.pkname, model_cls=self.model_cls + ) + self.sorted_orders[clause] = clause.get_text_clause() def _pagination_query_required(self) -> bool: """ @@ -208,7 +184,9 @@ class Query: for filter_clause in self.exclude_clauses if filter_clause.table_prefix == "" ] - sorts_to_use = {k: v for k, v in self.sorted_orders.items() if "__" not in k} + sorts_to_use = { + k: v for k, v in self.sorted_orders.items() if k.is_source_model_order + } expr = FilterQuery(filter_clauses=filters_to_use).apply(expr) expr = FilterQuery(filter_clauses=excludes_to_use, exclude=True).apply(expr) expr = OrderQuery(sorted_orders=sorts_to_use).apply(expr) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index ba55586..2202abd 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -21,6 +21,7 @@ import ormar # noqa I100 from ormar import MultipleMatches, NoMatch from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError from ormar.queryset import FilterQuery +from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.clause import QueryClause from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.query import Query @@ -514,7 +515,12 @@ class QuerySet(Generic[T]): if not isinstance(columns, list): columns = [columns] - order_bys = self.order_bys + [x for x in columns if x not in self.order_bys] + orders_by = [ + OrderAction(order_str=x, model_cls=self.model_cls) # type: ignore + for x in columns + ] + + order_bys = self.order_bys + [x for x in orders_by if x not in self.order_bys] return self.__class__( model_cls=self.model, filter_clauses=self.filter_clauses, @@ -713,7 +719,14 @@ class QuerySet(Generic[T]): return await self.filter(**kwargs).first() expr = self.build_select_expression( - limit=1, order_bys=[f"{self.model.Meta.pkname}"] + self.order_bys + limit=1, + order_bys=[ + OrderAction( + order_str=f"{self.model.Meta.pkname}", + model_cls=self.model_cls, # type: ignore + ) + ] + + self.order_bys, ) rows = await self.database.fetch_all(expr) processed_rows = self._process_query_result_rows(rows) @@ -742,7 +755,14 @@ class QuerySet(Generic[T]): if not self.filter_clauses: expr = self.build_select_expression( - limit=1, order_bys=[f"-{self.model.Meta.pkname}"] + self.order_bys + limit=1, + order_bys=[ + OrderAction( + order_str=f"-{self.model.Meta.pkname}", + model_cls=self.model_cls, # type: ignore + ) + ] + + self.order_bys, ) else: expr = self.build_select_expression() diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index 6b98028..9445fe0 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -232,16 +232,24 @@ def get_relationship_alias_model_and_str( is_through = False model_cls = source_model previous_model = model_cls + previous_models = [model_cls] manager = model_cls.Meta.alias_manager for relation in related_parts[:]: related_field = model_cls.Meta.model_fields[relation] + if related_field.is_through: + # through is always last - cannot go further is_through = True - related_parts = [ - x.replace(relation, related_field.related_name) if x == relation else x - for x in related_parts + related_parts.remove(relation) + through_field = related_field.owner.Meta.model_fields[ + related_field.related_name or "" ] - relation = related_field.related_name + if len(previous_models) > 1 and previous_models[-2] == through_field.to: + previous_model = through_field.to + relation = through_field.related_name + else: + relation = related_field.related_name + if related_field.is_multi: previous_model = related_field.through relation = related_field.default_target_field_name() # type: ignore @@ -250,6 +258,8 @@ def get_relationship_alias_model_and_str( ) model_cls = related_field.to previous_model = model_cls + if not is_through: + previous_models.append(previous_model) relation_str = "__".join(related_parts) return table_prefix, model_cls, relation_str, is_through diff --git a/tests/test_m2m_through_fields.py b/tests/test_m2m_through_fields.py index 9d7f38a..5c1b44f 100644 --- a/tests/test_m2m_through_fields.py +++ b/tests/test_m2m_through_fields.py @@ -34,6 +34,14 @@ class PostCategory(ormar.Model): param_name: str = ormar.String(default="Name", max_length=200) +class Blog(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=200) + + class Post(ormar.Model): class Meta(BaseMeta): pass @@ -41,6 +49,7 @@ class Post(ormar.Model): id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=200) categories = ormar.ManyToMany(Category, through=PostCategory) + blog = ormar.ForeignKey(Blog) @pytest.fixture(autouse=True, scope="module") @@ -146,18 +155,86 @@ async def test_filtering_by_through_model() -> Any: ) post2 = ( - await Post.objects.filter(postcategory__sort_order__gt=1) - .select_related("categories") - .get() + await Post.objects.select_related("categories") + .filter(postcategory__sort_order__gt=1) + .get() ) assert len(post2.categories) == 1 assert post2.categories[0].postcategory.sort_order == 2 post3 = await Post.objects.filter( - categories__postcategory__param_name="volume").get() + categories__postcategory__param_name="volume" + ).get() assert len(post3.categories) == 1 assert post3.categories[0].postcategory.param_name == "volume" + +@pytest.mark.asyncio +async def test_deep_filtering_by_through_model() -> Any: + async with database: + blog = await Blog(title="My Blog").save() + post = await Post(title="Test post", blog=blog).save() + + await post.categories.create( + name="Test category1", + postcategory={"sort_order": 1, "param_name": "volume"}, + ) + await post.categories.create( + name="Test category2", postcategory={"sort_order": 2, "param_name": "area"} + ) + + blog2 = ( + await Blog.objects.select_related("posts__categories") + .filter(posts__postcategory__sort_order__gt=1) + .get() + ) + assert len(blog2.posts) == 1 + assert len(blog2.posts[0].categories) == 1 + assert blog2.posts[0].categories[0].postcategory.sort_order == 2 + + blog3 = await Blog.objects.filter( + posts__categories__postcategory__param_name="volume" + ).get() + assert len(blog3.posts) == 1 + assert len(blog3.posts[0].categories) == 1 + assert blog3.posts[0].categories[0].postcategory.param_name == "volume" + + +@pytest.mark.asyncio +async def test_ordering_by_through_model() -> Any: + async with database: + post = await Post(title="Test post").save() + await post.categories.create( + name="Test category1", + postcategory={"sort_order": 2, "param_name": "volume"}, + ) + await post.categories.create( + name="Test category2", postcategory={"sort_order": 1, "param_name": "area"} + ) + await post.categories.create( + name="Test category3", + postcategory={"sort_order": 3, "param_name": "velocity"}, + ) + + post2 = ( + await Post.objects.select_related("categories") + .order_by("-postcategory__sort_order") + .get() + ) + assert len(post2.categories) == 3 + assert post2.categories[0].name == "Test category3" + assert post2.categories[2].name == "Test category2" + + post3 = ( + await Post.objects.select_related("categories") + .order_by("categories__postcategory__param_name") + .get() + ) + assert len(post3.categories) == 3 + assert post3.categories[0].postcategory.param_name == "area" + assert post3.categories[2].postcategory.param_name == "volume" + + # TODO: check/ modify following # add to fields with class lower name (V) @@ -166,10 +243,12 @@ async def test_filtering_by_through_model() -> Any: # creating in queryset proxy (dict with through name and kwargs) (V) # loading the data into model instance of though model (V) <- fix fields ane exclude # accessing from instance (V) <- no both sides only nested one is relevant, fix one side -# filtering in filter (through name normally) (V) < - table prefix from normal relation, check if is_through needed +# filtering in filter (through name normally) (V) < - table prefix from normal relation, +# check if is_through needed, resolved side of relation +# ordering by in order_by + # updating in query -# ordering by in order_by # modifying from instance (both sides?) # including/excluding in fields? # allowing to change fk fields names in through model?