refactor order bys into OrderQuery, add ordering to through models too

This commit is contained in:
collerek
2021-02-25 17:28:05 +01:00
parent c139ca4f61
commit 503f589fa7
16 changed files with 388 additions and 210 deletions

View File

@ -56,7 +56,7 @@ from ormar.fields import (
) # noqa: I100 ) # noqa: I100
from ormar.models import Model from ormar.models import Model
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
from ormar.queryset import QuerySet from ormar.queryset import OrderAction, QuerySet
from ormar.relations import RelationType from ormar.relations import RelationType
from ormar.signals import Signal from ormar.signals import Signal
@ -106,4 +106,5 @@ __all__ = [
"BaseField", "BaseField",
"ManyToManyField", "ManyToManyField",
"ForeignKeyField", "ForeignKeyField",
"OrderAction",
] ]

View File

@ -2,7 +2,7 @@ from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type, cast
from ormar.models.mixins.relation_mixin import RelationMixin from ormar.models.mixins.relation_mixin import RelationMixin
if TYPE_CHECKING: if TYPE_CHECKING: # pragma: no cover
from ormar.fields import ForeignKeyField, ManyToManyField from ormar.fields import ForeignKeyField, ManyToManyField
@ -18,10 +18,10 @@ class PrefetchQueryMixin(RelationMixin):
@staticmethod @staticmethod
def get_clause_target_and_filter_column_name( def get_clause_target_and_filter_column_name(
parent_model: Type["Model"], parent_model: Type["Model"],
target_model: Type["Model"], target_model: Type["Model"],
reverse: bool, reverse: bool,
related: str, related: str,
) -> Tuple[Type["Model"], str]: ) -> Tuple[Type["Model"], str]:
""" """
Returns Model on which query clause should be performed and name of the column. Returns Model on which query clause should be performed and name of the column.
@ -51,7 +51,7 @@ class PrefetchQueryMixin(RelationMixin):
@staticmethod @staticmethod
def get_column_name_for_id_extraction( def get_column_name_for_id_extraction(
parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool, parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool,
) -> str: ) -> str:
""" """
Returns name of the column that should be used to extract ids from model. Returns name of the column that should be used to extract ids from model.

View File

@ -17,7 +17,7 @@ from ormar.models import NewBaseModel # noqa: I202
from ormar.models.helpers.models import group_related_list from ormar.models.helpers.models import group_related_list
if TYPE_CHECKING: if TYPE_CHECKING: # pragma: no cover
from ormar.fields import ForeignKeyField from ormar.fields import ForeignKeyField
from ormar.models import T from ormar.models import T
else: else:

View File

@ -1,10 +1,20 @@
""" """
Contains QuerySet and different Query classes to allow for constructing of sql queries. Contains QuerySet and different Query classes to allow for constructing of sql queries.
""" """
from ormar.queryset.actions import FilterAction, OrderAction
from ormar.queryset.filter_query import FilterQuery from ormar.queryset.filter_query import FilterQuery
from ormar.queryset.limit_query import LimitQuery from ormar.queryset.limit_query import LimitQuery
from ormar.queryset.offset_query import OffsetQuery from ormar.queryset.offset_query import OffsetQuery
from ormar.queryset.order_query import OrderQuery from ormar.queryset.order_query import OrderQuery
from ormar.queryset.queryset import QuerySet, T from ormar.queryset.queryset import QuerySet, T
__all__ = ["T", "QuerySet", "FilterQuery", "LimitQuery", "OffsetQuery", "OrderQuery"] __all__ = [
"T",
"QuerySet",
"FilterQuery",
"LimitQuery",
"OffsetQuery",
"OrderQuery",
"FilterAction",
"OrderAction",
]

View File

@ -0,0 +1,4 @@
from ormar.queryset.actions.filter_action import FilterAction
from ormar.queryset.actions.order_action import OrderAction
__all__ = ["FilterAction", "OrderAction"]

View File

@ -1,11 +1,11 @@
from typing import Any, Dict, List, TYPE_CHECKING, Type from typing import Any, Dict, TYPE_CHECKING, Type
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
import ormar # noqa: I100, I202 import ormar # noqa: I100, I202
from ormar.exceptions import QueryDefinitionError from ormar.exceptions import QueryDefinitionError
from ormar.queryset.utils import get_relationship_alias_model_and_str from ormar.queryset.actions.query_action import QueryAction
if TYPE_CHECKING: # pragma: nocover if TYPE_CHECKING: # pragma: nocover
from ormar import Model from ormar import Model
@ -28,7 +28,7 @@ FILTER_OPERATORS = {
ESCAPE_CHARACTERS = ["%", "_"] ESCAPE_CHARACTERS = ["%", "_"]
class FilterAction: class FilterAction(QueryAction):
""" """
Filter Actions is populated by queryset when filter() is called. Filter Actions is populated by queryset when filter() is called.
@ -39,7 +39,18 @@ class FilterAction:
""" """
def __init__(self, filter_str: str, value: Any, model_cls: Type["Model"]) -> None: def __init__(self, filter_str: str, value: Any, model_cls: Type["Model"]) -> None:
parts = filter_str.split("__") super().__init__(query_str=filter_str, model_cls=model_cls)
self.filter_value = value
self._escape_characters_in_clause()
def has_escaped_characters(self) -> bool:
"""Check if value is a string that contains characters to escape"""
return isinstance(self.filter_value, str) and any(
c for c in ESCAPE_CHARACTERS if c in self.filter_value
)
def _split_value_into_parts(self, query_str: str) -> None:
parts = query_str.split("__")
if parts[-1] in FILTER_OPERATORS: if parts[-1] in FILTER_OPERATORS:
self.operator = parts[-1] self.operator = parts[-1]
self.field_name = parts[-2] self.field_name = parts[-2]
@ -49,61 +60,6 @@ class FilterAction:
self.field_name = parts[-1] self.field_name = parts[-1]
self.related_parts = parts[:-1] self.related_parts = parts[:-1]
self.filter_value = value
self.table_prefix = ""
self.source_model = model_cls
self.target_model = model_cls
self.is_through = False
self._determine_filter_target_table()
self._escape_characters_in_clause()
@property
def table(self) -> sqlalchemy.Table:
"""Shortcut to sqlalchemy Table of filtered target model"""
return self.target_model.Meta.table
@property
def column(self) -> sqlalchemy.Column:
"""Shortcut to sqlalchemy column of filtered target model"""
aliased_name = self.target_model.get_column_alias(self.field_name)
return self.target_model.Meta.table.columns[aliased_name]
def has_escaped_characters(self) -> bool:
"""Check if value is a string that contains characters to escape"""
return isinstance(self.filter_value, str) and any(
c for c in ESCAPE_CHARACTERS if c in self.filter_value
)
def update_select_related(self, select_related: List[str]) -> List[str]:
"""
Updates list of select related with related part included in the filter key.
That way If you want to just filter by relation you do not have to provide
select_related separately.
:param select_related: list of relation join strings
:type select_related: List[str]
:return: list of relation joins with implied joins from filter added
:rtype: List[str]
"""
select_related = select_related[:]
if self.related_str and not any(
rel.startswith(self.related_str) for rel in select_related
):
select_related.append(self.related_str)
return select_related
def _determine_filter_target_table(self) -> None:
"""
Walks the relation to retrieve the actual model on which the clause should be
constructed, extracts alias based on last relation leading to target model.
"""
(
self.table_prefix,
self.target_model,
self.related_str,
self.is_through,
) = get_relationship_alias_model_and_str(self.source_model, self.related_parts)
def _escape_characters_in_clause(self) -> None: def _escape_characters_in_clause(self) -> None:
""" """
Escapes the special characters ["%", "_"] if needed. Escapes the special characters ["%", "_"] if needed.
@ -151,7 +107,7 @@ class FilterAction:
sufix = "%" if "end" not in self.operator else "" sufix = "%" if "end" not in self.operator else ""
self.filter_value = f"{prefix}{self.filter_value}{sufix}" self.filter_value = f"{prefix}{self.filter_value}{sufix}"
def get_text_clause(self,) -> sqlalchemy.sql.expression.TextClause: def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause:
""" """
Escapes characters if it's required. Escapes characters if it's required.
Substitutes values of the models if value is a ormar Model with its pk value. Substitutes values of the models if value is a ormar Model with its pk value.

View File

@ -0,0 +1,68 @@
from typing import TYPE_CHECKING, Type
import sqlalchemy
from sqlalchemy import text
from ormar.queryset.actions.query_action import QueryAction # noqa: I100, I202
if TYPE_CHECKING: # pragma: nocover
from ormar import Model
class OrderAction(QueryAction):
"""
Order Actions is populated by queryset when order_by() is called.
All required params are extracted but kept raw until actual filter clause value
is required -> then the action is converted into text() clause.
Extracted in order to easily change table prefixes on complex relations.
"""
def __init__(
self, order_str: str, model_cls: Type["Model"], alias: str = None
) -> None:
self.direction: str = ""
super().__init__(query_str=order_str, model_cls=model_cls)
self.is_source_model_order = False
if alias:
self.table_prefix = alias
if self.source_model == self.target_model and "__" not in self.related_str:
self.is_source_model_order = True
@property
def field_alias(self) -> str:
return self.target_model.get_column_alias(self.field_name)
def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause:
"""
Escapes characters if it's required.
Substitutes values of the models if value is a ormar Model with its pk value.
Compiles the clause.
:return: complied and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
"""
prefix = f"{self.table_prefix}_" if self.table_prefix else ""
return text(f"{prefix}{self.table}" f".{self.field_alias} {self.direction}")
def _split_value_into_parts(self, order_str: str) -> None:
if order_str.startswith("-"):
self.direction = "desc"
order_str = order_str[1:]
parts = order_str.split("__")
self.field_name = parts[-1]
self.related_parts = parts[:-1]
def check_if_filter_apply(self, target_model: Type["Model"], alias: str) -> bool:
"""
Checks filter conditions to find if they apply to current join.
:param target_model: model which is now processed
:type target_model: Type["Model"]
:param alias: prefix of the relation
:type alias: str
:return: result of the check
:rtype: bool
"""
return target_model == self.target_model and alias == self.table_prefix

View File

@ -0,0 +1,93 @@
import abc
from typing import Any, List, TYPE_CHECKING, Type
import sqlalchemy
from ormar.queryset.utils import get_relationship_alias_model_and_str # noqa: I202
if TYPE_CHECKING: # pragma: nocover
from ormar import Model
class QueryAction(abc.ABC):
"""
Base QueryAction class with common params for Filter and Order actions.
"""
def __init__(self, query_str: str, model_cls: Type["Model"]) -> None:
self.query_str = query_str
self.field_name: str = ""
self.related_parts: List[str] = []
self.related_str: str = ""
self.table_prefix = ""
self.source_model = model_cls
self.target_model = model_cls
self.is_through = False
self._split_value_into_parts(query_str)
self._determine_filter_target_table()
def __eq__(self, other: object) -> bool: # pragma: no cover
if not isinstance(other, QueryAction):
return False
return self.query_str == other.query_str
def __hash__(self) -> Any:
return hash((self.table_prefix, self.query_str))
@abc.abstractmethod
def _split_value_into_parts(self, query_str: str) -> None: # pragma: no cover
"""
Splits string into related parts and field_name
:param query_str: query action string to split (i..e filter or order by)
:type query_str: str
"""
pass
@abc.abstractmethod
def get_text_clause(
self,
) -> sqlalchemy.sql.expression.TextClause: # pragma: no cover
pass
@property
def table(self) -> sqlalchemy.Table:
"""Shortcut to sqlalchemy Table of filtered target model"""
return self.target_model.Meta.table
@property
def column(self) -> sqlalchemy.Column:
"""Shortcut to sqlalchemy column of filtered target model"""
aliased_name = self.target_model.get_column_alias(self.field_name)
return self.target_model.Meta.table.columns[aliased_name]
def update_select_related(self, select_related: List[str]) -> List[str]:
"""
Updates list of select related with related part included in the filter key.
That way If you want to just filter by relation you do not have to provide
select_related separately.
:param select_related: list of relation join strings
:type select_related: List[str]
:return: list of relation joins with implied joins from filter added
:rtype: List[str]
"""
select_related = select_related[:]
if self.related_str and not any(
rel.startswith(self.related_str) for rel in select_related
):
select_related.append(self.related_str)
return select_related
def _determine_filter_target_table(self) -> None:
"""
Walks the relation to retrieve the actual model on which the clause should be
constructed, extracts alias based on last relation leading to target model.
"""
(
self.table_prefix,
self.target_model,
self.related_str,
self.is_through,
) = get_relationship_alias_model_and_str(self.source_model, self.related_parts)

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from typing import Any, List, TYPE_CHECKING, Tuple, Type from typing import Any, List, TYPE_CHECKING, Tuple, Type
import ormar # noqa I100 import ormar # noqa I100
from ormar.queryset.filter_action import FilterAction from ormar.queryset.actions.filter_action import FilterAction
from ormar.queryset.utils import get_relationship_alias_model_and_str from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover

View File

@ -1,7 +1,7 @@
from typing import List from typing import List
import sqlalchemy import sqlalchemy
from ormar.queryset.filter_action import FilterAction from ormar.queryset.actions.filter_action import FilterAction
class FilterQuery: class FilterQuery:

View File

@ -14,11 +14,13 @@ from typing import (
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
from ormar.exceptions import RelationshipInstanceError # noqa I100 import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError
from ormar.relations import AliasManager from ormar.relations import AliasManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.queryset import OrderAction
class SqlJoin: class SqlJoin:
@ -29,7 +31,7 @@ class SqlJoin:
columns: List[sqlalchemy.Column], columns: List[sqlalchemy.Column],
fields: Optional[Union[Set, Dict]], fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]], exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List], order_columns: Optional[List["OrderAction"]],
sorted_orders: OrderedDict, sorted_orders: OrderedDict,
main_model: Type["Model"], main_model: Type["Model"],
relation_name: str, relation_name: str,
@ -89,7 +91,18 @@ class SqlJoin:
""" """
return self.main_model.Meta.alias_manager return self.main_model.Meta.alias_manager
def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text: @property
def to_table(self) -> str:
"""
Shortcut to table name of the next model
:return: name of the target table
:rtype: str
"""
return self.next_model.Meta.table.name
def _on_clause(
self, previous_alias: str, from_clause: str, to_clause: str,
) -> text:
""" """
Receives aliases and names of both ends of the join and combines them Receives aliases and names of both ends of the join and combines them
into one text clause used in joins. into one text clause used in joins.
@ -118,7 +131,7 @@ class SqlJoin:
:rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict] :rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict]
""" """
if self.target_field.is_multi: if self.target_field.is_multi:
self.process_m2m_through_table() self._process_m2m_through_table()
self.next_model = self.target_field.to self.next_model = self.target_field.to
self._forward_join() self._forward_join()
@ -207,7 +220,7 @@ class SqlJoin:
self.sorted_orders, self.sorted_orders,
) = sql_join.build_join() ) = sql_join.build_join()
def process_m2m_through_table(self) -> None: def _process_m2m_through_table(self) -> None:
""" """
Process Through table of the ManyToMany relation so that source table is Process Through table of the ManyToMany relation so that source table is
linked to the through table (one additional join) linked to the through table (one additional join)
@ -222,8 +235,7 @@ class SqlJoin:
To point to through model To point to through model
""" """
new_part = self.process_m2m_related_name_change() new_part = self._process_m2m_related_name_change()
self._replace_many_to_many_order_by_columns(self.relation_name, new_part)
self.next_model = self.target_field.through self.next_model = self.target_field.through
self._forward_join() self._forward_join()
@ -232,7 +244,7 @@ class SqlJoin:
self.own_alias = self.next_alias self.own_alias = self.next_alias
self.target_field = self.next_model.Meta.model_fields[self.relation_name] self.target_field = self.next_model.Meta.model_fields[self.relation_name]
def process_m2m_related_name_change(self, reverse: bool = False) -> str: def _process_m2m_related_name_change(self, reverse: bool = False) -> str:
""" """
Extracts relation name to link join through the Through model declared on Extracts relation name to link join through the Through model declared on
relation field. relation field.
@ -272,24 +284,21 @@ class SqlJoin:
Process order_by causes for non m2m relations. Process order_by causes for non m2m relations.
""" """
to_table = self.next_model.Meta.table.name to_key, from_key = self._get_to_and_from_keys()
to_key, from_key = self.get_to_and_from_keys()
on_clause = self.on_clause( on_clause = self._on_clause(
previous_alias=self.own_alias, previous_alias=self.own_alias,
from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}", from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}",
to_clause=f"{to_table}.{to_key}", to_clause=f"{self.to_table}.{to_key}",
)
target_table = self.alias_manager.prefixed_table_name(
self.next_alias, self.to_table
) )
target_table = self.alias_manager.prefixed_table_name(self.next_alias, to_table)
self.select_from = sqlalchemy.sql.outerjoin( self.select_from = sqlalchemy.sql.outerjoin(
self.select_from, target_table, on_clause self.select_from, target_table, on_clause
) )
pkname_alias = self.next_model.get_column_alias(self.next_model.Meta.pkname) self._get_order_bys()
if not self.target_field.is_multi:
self.get_order_bys(
to_table=to_table, pkname_alias=pkname_alias,
)
# TODO: fix fields and exclusions for through model? # TODO: fix fields and exclusions for through model?
self_related_fields = self.next_model.own_table_columns( self_related_fields = self.next_model.own_table_columns(
@ -305,88 +314,35 @@ class SqlJoin:
) )
self.used_aliases.append(self.next_alias) self.used_aliases.append(self.next_alias)
def _replace_many_to_many_order_by_columns(self, part: str, new_part: str) -> None: def _set_default_primary_key_order_by(self) -> None:
""" clause = ormar.OrderAction(
Substitutes the name of the relation with actual model name in m2m order bys. order_str=self.next_model.Meta.pkname,
model_cls=self.next_model,
:param part: name of the field with relation alias=self.next_alias,
:type part: str
:param new_part: name of the target model
:type new_part: str
"""
if self.order_columns:
split_order_columns = [
x.split("__") for x in self.order_columns if "__" in x
]
for condition in split_order_columns:
if self._check_if_condition_apply(condition, part):
condition[-2] = condition[-2].replace(part, new_part)
self.order_columns = [x for x in self.order_columns if "__" not in x] + [
"__".join(x) for x in split_order_columns
]
@staticmethod
def _check_if_condition_apply(condition: List, part: str) -> bool:
"""
Checks filter conditions to find if they apply to current join.
:param condition: list of parts of condition split by '__'
:type condition: List[str]
:param part: name of the current relation join.
:type part: str
:return: result of the check
:rtype: bool
"""
return len(condition) >= 2 and (
condition[-2] == part or condition[-2][1:] == part
) )
self.sorted_orders[clause] = clause.get_text_clause()
def set_aliased_order_by(self, condition: List[str], to_table: str,) -> None: def _get_order_bys(self) -> None: # noqa: CCR001
"""
Substitute hyphens ('-') with descending order.
Construct actual sqlalchemy text clause using aliased table and column name.
:param condition: list of parts of a current condition split by '__'
:type condition: List[str]
:param to_table: target table
:type to_table: sqlalchemy.sql.elements.quoted_name
"""
direction = f"{'desc' if condition[0][0] == '-' else ''}"
column_alias = self.next_model.get_column_alias(condition[-1])
order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}")
self.sorted_orders["__".join(condition)] = order
def get_order_bys(self, to_table: str, pkname_alias: str,) -> None: # noqa: CCR001
""" """
Triggers construction of order bys if they are given. Triggers construction of order bys if they are given.
Otherwise by default each table is sorted by a primary key column asc. Otherwise by default each table is sorted by a primary key column asc.
:param to_table: target table
:type to_table: sqlalchemy.sql.elements.quoted_name
:param pkname_alias: alias of the primary key column
:type pkname_alias: str
""" """
alias = self.next_alias alias = self.next_alias
if self.order_columns: if self.order_columns:
current_table_sorted = False current_table_sorted = False
split_order_columns = [ for condition in self.order_columns:
x.split("__") for x in self.order_columns if "__" in x if condition.check_if_filter_apply(
] target_model=self.next_model, alias=alias
for condition in split_order_columns: ):
if self._check_if_condition_apply(condition, self.relation_name):
current_table_sorted = True current_table_sorted = True
self.set_aliased_order_by( self.sorted_orders[condition] = condition.get_text_clause()
condition=condition, to_table=to_table, if not current_table_sorted and not self.target_field.is_multi:
) self._set_default_primary_key_order_by()
if not current_table_sorted:
order = text(f"{alias}_{to_table}.{pkname_alias}")
self.sorted_orders[f"{alias}.{pkname_alias}"] = order
else: elif not self.target_field.is_multi:
order = text(f"{alias}_{to_table}.{pkname_alias}") self._set_default_primary_key_order_by()
self.sorted_orders[f"{alias}.{pkname_alias}"] = order
def get_to_and_from_keys(self) -> Tuple[str, str]: def _get_to_and_from_keys(self) -> Tuple[str, str]:
""" """
Based on the relation type, name of the relation and previous models and parts Based on the relation type, name of the relation and previous models and parts
stored in JoinParameters it resolves the current to and from keys, which are stored in JoinParameters it resolves the current to and from keys, which are
@ -396,7 +352,7 @@ class SqlJoin:
:rtype: Tuple[str, str] :rtype: Tuple[str, str]
""" """
if self.target_field.is_multi: if self.target_field.is_multi:
to_key = self.process_m2m_related_name_change(reverse=True) to_key = self._process_m2m_related_name_change(reverse=True)
from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname) from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname)
elif self.target_field.virtual: elif self.target_field.virtual:

View File

@ -20,6 +20,7 @@ from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ormar import Model from ormar import Model
from ormar.fields import ForeignKeyField, BaseField from ormar.fields import ForeignKeyField, BaseField
from ormar.queryset import OrderAction
def add_relation_field_to_fields( def add_relation_field_to_fields(
@ -128,7 +129,7 @@ class PrefetchQuery:
exclude_fields: Optional[Union[Dict, Set]], exclude_fields: Optional[Union[Dict, Set]],
prefetch_related: List, prefetch_related: List,
select_related: List, select_related: List,
orders_by: List, orders_by: List["OrderAction"],
) -> None: ) -> None:
self.model = model_cls self.model = model_cls
@ -141,7 +142,9 @@ class PrefetchQuery:
self.models: Dict = {} self.models: Dict = {}
self.select_dict = translate_list_to_dict(self._select_related) self.select_dict = translate_list_to_dict(self._select_related)
self.orders_by = orders_by or [] self.orders_by = orders_by or []
self.order_dict = translate_list_to_dict(self.orders_by, is_order=True) self.order_dict = translate_list_to_dict(
[x.query_str for x in self.orders_by], is_order=True
)
async def prefetch_related( async def prefetch_related(
self, models: Sequence["Model"], rows: List self, models: Sequence["Model"], rows: List

View File

@ -8,11 +8,12 @@ from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar.models.helpers.models import group_related_list from ormar.models.helpers.models import group_related_list
from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery
from ormar.queryset.filter_action import FilterAction from ormar.queryset.actions.filter_action import FilterAction
from ormar.queryset.join import SqlJoin from ormar.queryset.join import SqlJoin
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.queryset import OrderAction
class Query: class Query:
@ -26,7 +27,7 @@ class Query:
offset: Optional[int], offset: Optional[int],
fields: Optional[Union[Dict, Set]], fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]], exclude_fields: Optional[Union[Dict, Set]],
order_bys: Optional[List], order_bys: Optional[List["OrderAction"]],
limit_raw_sql: bool, limit_raw_sql: bool,
) -> None: ) -> None:
self.query_offset = offset self.query_offset = offset
@ -45,7 +46,7 @@ class Query:
self.select_from: List[str] = [] self.select_from: List[str] = []
self.columns = [sqlalchemy.Column] self.columns = [sqlalchemy.Column]
self.order_columns = order_bys self.order_columns = order_bys
self.sorted_orders: OrderedDict = OrderedDict() self.sorted_orders: OrderedDict[OrderAction, text] = OrderedDict()
self._init_sorted_orders() self._init_sorted_orders()
self.limit_raw_sql = limit_raw_sql self.limit_raw_sql = limit_raw_sql
@ -58,28 +59,6 @@ class Query:
for clause in self.order_columns: for clause in self.order_columns:
self.sorted_orders[clause] = None self.sorted_orders[clause] = None
@property
def prefixed_pk_name(self) -> str:
"""
Shortcut for extracting prefixed with alias primary key column name from main
model
:return: alias of pk column prefix with table name.
:rtype: str
"""
pkname_alias = self.model_cls.get_column_alias(self.model_cls.Meta.pkname)
return f"{self.table.name}.{pkname_alias}"
def alias(self, name: str) -> str:
"""
Shortcut to extracting column alias from given master model.
:param name: name of column
:type name: str
:return: alias of given column name
:rtype: str
"""
return self.model_cls.get_column_alias(name)
def apply_order_bys_for_primary_model(self) -> None: # noqa: CCR001 def apply_order_bys_for_primary_model(self) -> None: # noqa: CCR001
""" """
Applies order_by queries on main model when it's used as a subquery. Applies order_by queries on main model when it's used as a subquery.
@ -88,16 +67,13 @@ class Query:
""" """
if self.order_columns: if self.order_columns:
for clause in self.order_columns: for clause in self.order_columns:
if "__" not in clause: if clause.is_source_model_order:
text_clause = ( self.sorted_orders[clause] = clause.get_text_clause()
text(f"{self.table.name}.{self.alias(clause[1:])} desc")
if clause.startswith("-")
else text(f"{self.table.name}.{self.alias(clause)}")
)
self.sorted_orders[clause] = text_clause
else: else:
order = text(self.prefixed_pk_name) clause = ormar.OrderAction(
self.sorted_orders[self.prefixed_pk_name] = order order_str=self.model_cls.Meta.pkname, model_cls=self.model_cls
)
self.sorted_orders[clause] = clause.get_text_clause()
def _pagination_query_required(self) -> bool: def _pagination_query_required(self) -> bool:
""" """
@ -208,7 +184,9 @@ class Query:
for filter_clause in self.exclude_clauses for filter_clause in self.exclude_clauses
if filter_clause.table_prefix == "" if filter_clause.table_prefix == ""
] ]
sorts_to_use = {k: v for k, v in self.sorted_orders.items() if "__" not in k} sorts_to_use = {
k: v for k, v in self.sorted_orders.items() if k.is_source_model_order
}
expr = FilterQuery(filter_clauses=filters_to_use).apply(expr) expr = FilterQuery(filter_clauses=filters_to_use).apply(expr)
expr = FilterQuery(filter_clauses=excludes_to_use, exclude=True).apply(expr) expr = FilterQuery(filter_clauses=excludes_to_use, exclude=True).apply(expr)
expr = OrderQuery(sorted_orders=sorts_to_use).apply(expr) expr = OrderQuery(sorted_orders=sorts_to_use).apply(expr)

View File

@ -21,6 +21,7 @@ import ormar # noqa I100
from ormar import MultipleMatches, NoMatch from ormar import MultipleMatches, NoMatch
from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError
from ormar.queryset import FilterQuery from ormar.queryset import FilterQuery
from ormar.queryset.actions.order_action import OrderAction
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.prefetch_query import PrefetchQuery
from ormar.queryset.query import Query from ormar.queryset.query import Query
@ -514,7 +515,12 @@ class QuerySet(Generic[T]):
if not isinstance(columns, list): if not isinstance(columns, list):
columns = [columns] columns = [columns]
order_bys = self.order_bys + [x for x in columns if x not in self.order_bys] orders_by = [
OrderAction(order_str=x, model_cls=self.model_cls) # type: ignore
for x in columns
]
order_bys = self.order_bys + [x for x in orders_by if x not in self.order_bys]
return self.__class__( return self.__class__(
model_cls=self.model, model_cls=self.model,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
@ -713,7 +719,14 @@ class QuerySet(Generic[T]):
return await self.filter(**kwargs).first() return await self.filter(**kwargs).first()
expr = self.build_select_expression( expr = self.build_select_expression(
limit=1, order_bys=[f"{self.model.Meta.pkname}"] + self.order_bys limit=1,
order_bys=[
OrderAction(
order_str=f"{self.model.Meta.pkname}",
model_cls=self.model_cls, # type: ignore
)
]
+ self.order_bys,
) )
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
processed_rows = self._process_query_result_rows(rows) processed_rows = self._process_query_result_rows(rows)
@ -742,7 +755,14 @@ class QuerySet(Generic[T]):
if not self.filter_clauses: if not self.filter_clauses:
expr = self.build_select_expression( expr = self.build_select_expression(
limit=1, order_bys=[f"-{self.model.Meta.pkname}"] + self.order_bys limit=1,
order_bys=[
OrderAction(
order_str=f"-{self.model.Meta.pkname}",
model_cls=self.model_cls, # type: ignore
)
]
+ self.order_bys,
) )
else: else:
expr = self.build_select_expression() expr = self.build_select_expression()

View File

@ -232,16 +232,24 @@ def get_relationship_alias_model_and_str(
is_through = False is_through = False
model_cls = source_model model_cls = source_model
previous_model = model_cls previous_model = model_cls
previous_models = [model_cls]
manager = model_cls.Meta.alias_manager manager = model_cls.Meta.alias_manager
for relation in related_parts[:]: for relation in related_parts[:]:
related_field = model_cls.Meta.model_fields[relation] related_field = model_cls.Meta.model_fields[relation]
if related_field.is_through: if related_field.is_through:
# through is always last - cannot go further
is_through = True is_through = True
related_parts = [ related_parts.remove(relation)
x.replace(relation, related_field.related_name) if x == relation else x through_field = related_field.owner.Meta.model_fields[
for x in related_parts related_field.related_name or ""
] ]
relation = related_field.related_name if len(previous_models) > 1 and previous_models[-2] == through_field.to:
previous_model = through_field.to
relation = through_field.related_name
else:
relation = related_field.related_name
if related_field.is_multi: if related_field.is_multi:
previous_model = related_field.through previous_model = related_field.through
relation = related_field.default_target_field_name() # type: ignore relation = related_field.default_target_field_name() # type: ignore
@ -250,6 +258,8 @@ def get_relationship_alias_model_and_str(
) )
model_cls = related_field.to model_cls = related_field.to
previous_model = model_cls previous_model = model_cls
if not is_through:
previous_models.append(previous_model)
relation_str = "__".join(related_parts) relation_str = "__".join(related_parts)
return table_prefix, model_cls, relation_str, is_through return table_prefix, model_cls, relation_str, is_through

View File

@ -34,6 +34,14 @@ class PostCategory(ormar.Model):
param_name: str = ormar.String(default="Name", max_length=200) param_name: str = ormar.String(default="Name", max_length=200)
class Blog(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
class Post(ormar.Model): class Post(ormar.Model):
class Meta(BaseMeta): class Meta(BaseMeta):
pass pass
@ -41,6 +49,7 @@ class Post(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200) title: str = ormar.String(max_length=200)
categories = ormar.ManyToMany(Category, through=PostCategory) categories = ormar.ManyToMany(Category, through=PostCategory)
blog = ormar.ForeignKey(Blog)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -146,18 +155,86 @@ async def test_filtering_by_through_model() -> Any:
) )
post2 = ( post2 = (
await Post.objects.filter(postcategory__sort_order__gt=1) await Post.objects.select_related("categories")
.select_related("categories") .filter(postcategory__sort_order__gt=1)
.get() .get()
) )
assert len(post2.categories) == 1 assert len(post2.categories) == 1
assert post2.categories[0].postcategory.sort_order == 2 assert post2.categories[0].postcategory.sort_order == 2
post3 = await Post.objects.filter( post3 = await Post.objects.filter(
categories__postcategory__param_name="volume").get() categories__postcategory__param_name="volume"
).get()
assert len(post3.categories) == 1 assert len(post3.categories) == 1
assert post3.categories[0].postcategory.param_name == "volume" assert post3.categories[0].postcategory.param_name == "volume"
@pytest.mark.asyncio
async def test_deep_filtering_by_through_model() -> Any:
async with database:
blog = await Blog(title="My Blog").save()
post = await Post(title="Test post", blog=blog).save()
await post.categories.create(
name="Test category1",
postcategory={"sort_order": 1, "param_name": "volume"},
)
await post.categories.create(
name="Test category2", postcategory={"sort_order": 2, "param_name": "area"}
)
blog2 = (
await Blog.objects.select_related("posts__categories")
.filter(posts__postcategory__sort_order__gt=1)
.get()
)
assert len(blog2.posts) == 1
assert len(blog2.posts[0].categories) == 1
assert blog2.posts[0].categories[0].postcategory.sort_order == 2
blog3 = await Blog.objects.filter(
posts__categories__postcategory__param_name="volume"
).get()
assert len(blog3.posts) == 1
assert len(blog3.posts[0].categories) == 1
assert blog3.posts[0].categories[0].postcategory.param_name == "volume"
@pytest.mark.asyncio
async def test_ordering_by_through_model() -> Any:
async with database:
post = await Post(title="Test post").save()
await post.categories.create(
name="Test category1",
postcategory={"sort_order": 2, "param_name": "volume"},
)
await post.categories.create(
name="Test category2", postcategory={"sort_order": 1, "param_name": "area"}
)
await post.categories.create(
name="Test category3",
postcategory={"sort_order": 3, "param_name": "velocity"},
)
post2 = (
await Post.objects.select_related("categories")
.order_by("-postcategory__sort_order")
.get()
)
assert len(post2.categories) == 3
assert post2.categories[0].name == "Test category3"
assert post2.categories[2].name == "Test category2"
post3 = (
await Post.objects.select_related("categories")
.order_by("categories__postcategory__param_name")
.get()
)
assert len(post3.categories) == 3
assert post3.categories[0].postcategory.param_name == "area"
assert post3.categories[2].postcategory.param_name == "volume"
# TODO: check/ modify following # TODO: check/ modify following
# add to fields with class lower name (V) # add to fields with class lower name (V)
@ -166,10 +243,12 @@ async def test_filtering_by_through_model() -> Any:
# creating in queryset proxy (dict with through name and kwargs) (V) # creating in queryset proxy (dict with through name and kwargs) (V)
# loading the data into model instance of though model (V) <- fix fields ane exclude # loading the data into model instance of though model (V) <- fix fields ane exclude
# accessing from instance (V) <- no both sides only nested one is relevant, fix one side # accessing from instance (V) <- no both sides only nested one is relevant, fix one side
# filtering in filter (through name normally) (V) < - table prefix from normal relation, check if is_through needed # filtering in filter (through name normally) (V) < - table prefix from normal relation,
# check if is_through needed, resolved side of relation
# ordering by in order_by
# updating in query # updating in query
# ordering by in order_by
# modifying from instance (both sides?) # modifying from instance (both sides?)
# including/excluding in fields? # including/excluding in fields?
# allowing to change fk fields names in through model? # allowing to change fk fields names in through model?