diff --git a/ormar/queryset/actions/filter_action.py b/ormar/queryset/actions/filter_action.py index 443b36f..dd20ec7 100644 --- a/ormar/queryset/actions/filter_action.py +++ b/ormar/queryset/actions/filter_action.py @@ -1,7 +1,7 @@ +import datetime from typing import Any, Dict, TYPE_CHECKING, Type import sqlalchemy -from sqlalchemy import text import ormar # noqa: I100, I202 from ormar.exceptions import QueryDefinitionError @@ -125,7 +125,7 @@ class FilterAction(QueryAction): sufix = "%" if "end" not in self.operator else "" self.filter_value = f"{prefix}{self.filter_value}{sufix}" - def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause: + def get_text_clause(self) -> sqlalchemy.sql.expression.BinaryExpression: """ Escapes characters if it's required. Substitutes values of the models if value is a ormar Model with its pk value. @@ -137,18 +137,44 @@ class FilterAction(QueryAction): if isinstance(self.filter_value, ormar.Model): self.filter_value = self.filter_value.pk + # self._convert_dates_if_required() + op_attr = FILTER_OPERATORS[self.operator] if self.operator == "isnull": op_attr = "is_" if self.filter_value else "isnot" filter_value = None else: filter_value = self.filter_value - clause = getattr(self.column, op_attr)(filter_value) + if self.table_prefix: + aliased_table = self.source_model.Meta.alias_manager.prefixed_table_name( + self.table_prefix, self.column.table + ) + aliased_column = getattr(aliased_table.c, self.column.name) + else: + aliased_column = self.column + clause = getattr(aliased_column, op_attr)(filter_value) clause = self._compile_clause( clause, modifiers={"escape": "\\" if self.has_escaped_character else None} ) return clause + def _convert_dates_if_required(self) -> None: + """ + Converts dates, time and datetime to isoformat + """ + if isinstance( + self.filter_value, (datetime.date, datetime.time, datetime.datetime) + ): + self.filter_value = self.filter_value.isoformat() + + if isinstance(self.filter_value, (list, tuple, set)): + self.filter_value = [ + x.isoformat() + if isinstance(x, (datetime.date, datetime.time, datetime.datetime)) + else x + for x in self.filter_value + ] + def _compile_clause( self, clause: sqlalchemy.sql.expression.BinaryExpression, modifiers: Dict ) -> sqlalchemy.sql.expression.TextClause: @@ -166,19 +192,24 @@ class FilterAction(QueryAction): for modifier, modifier_value in modifiers.items(): clause.modifiers[modifier] = modifier_value - clause_text = str( - clause.compile( - dialect=self.target_model.Meta.database._backend._dialect, - compile_kwargs={"literal_binds": True}, - ) - ) - alias = f"{self.table_prefix}_" if self.table_prefix else "" - aliased_name = f"{alias}{self.table.name}.{self.column.name}" - clause_text = clause_text.replace( - f"{self.table.name}.{self.column.name}", aliased_name - ) - dialect_name = self.target_model.Meta.database._backend._dialect.name - if dialect_name != "sqlite": # pragma: no cover - clause_text = clause_text.replace("%%", "%") # remove %% in some dialects - clause = text(clause_text) + # compiled_clause = clause.compile( + # dialect=self.target_model.Meta.database._backend._dialect, + # # compile_kwargs={"literal_binds": True}, + # ) + # + # compiled_clause2 = clause.compile( + # dialect=self.target_model.Meta.database._backend._dialect, + # compile_kwargs={"literal_binds": True}, + # ) + # + # alias = f"{self.table_prefix}_" if self.table_prefix else "" + # aliased_name = f"{alias}{self.table.name}.{self.column.name}" + # clause_text = self.compile_query(compiled_query=compiled_clause) + # clause_text = clause_text.replace( + # f"{self.table.name}.{self.column.name}", aliased_name + # ) + # # dialect_name = self.target_model.Meta.database._backend._dialect.name + # # if dialect_name != "sqlite": # pragma: no cover + # # clause_text = clause_text.replace("%%", "%") + # clause = text(clause_text) return clause diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 402dd6b..3ee6288 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -121,19 +121,23 @@ class FilterGroup: :return: complied and escaped clause :rtype: sqlalchemy.sql.elements.TextClause """ - prefix = " NOT " if self.exclude else "" + # prefix = " NOT " if self.exclude else "" if self.filter_type == FilterType.AND: - clause = sqlalchemy.text( - f"{prefix}( " - + str(sqlalchemy.sql.and_(*self._get_text_clauses())) - + " )" - ) + # clause = sqlalchemy.text( + # f"{prefix}( " + # + str(sqlalchemy.sql.and_(*self._get_text_clauses())) + # + " )" + # ) + clause = sqlalchemy.sql.and_(*self._get_text_clauses()) else: - clause = sqlalchemy.text( - f"{prefix}( " - + str(sqlalchemy.sql.or_(*self._get_text_clauses())) - + " )" - ) + # clause = sqlalchemy.text( + # f"{prefix}( " + # + str(sqlalchemy.sql.or_(*self._get_text_clauses())) + # + " )" + # ) + clause = sqlalchemy.sql.or_(*self._get_text_clauses()) + if self.exclude: + clause = sqlalchemy.sql.not_(clause) return clause diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index ddee8ad..bbfe012 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -187,6 +187,7 @@ class Query: for order in list(self.sorted_orders.keys()): if order is not None and order.get_field_name_text() != pk_aliased_name: aliased_col = order.get_field_name_text() + # maxes[aliased_col] = order.get_text_clause() maxes[aliased_col] = order.get_min_or_max() elif order.get_field_name_text() == pk_aliased_name: maxes[pk_aliased_name] = order.get_text_clause() diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index d13b7ea..e091e5e 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -35,6 +35,7 @@ class AliasManager: def __init__(self) -> None: self._aliases_new: Dict[str, str] = dict() self._reversed_aliases: Dict[str, str] = dict() + self._prefixed_tables: Dict[str, text] = dict() def __contains__(self, item: str) -> bool: return self._aliases_new.__contains__(item) @@ -77,15 +78,20 @@ class AliasManager: :rtype: List[text] """ alias = f"{alias}_" if alias else "" + aliased_fields = [f"{alias}{x}" for x in fields] + # TODO: check if normal fields still needed or only aliased one all_columns = ( table.columns if not fields - else [col for col in table.columns if col.name in fields] + else [ + col + for col in table.columns + if col.name in fields or col.name in aliased_fields + ] ) return [column.label(f"{alias}{column.name}") for column in all_columns] - @staticmethod - def prefixed_table_name(alias: str, table: sqlalchemy.Table) -> text: + def prefixed_table_name(self, alias: str, table: sqlalchemy.Table) -> text: """ Creates text clause with table name with aliased name. @@ -96,7 +102,9 @@ class AliasManager: :return: sqlalchemy text clause as "table_name aliased_name" :rtype: sqlalchemy text clause """ - return table.alias(f"{alias}_{table.name}") + full_alias = f"{alias}_{table.name}" + key = f"{full_alias}_{id(table)}" + return self._prefixed_tables.setdefault(key, table.alias(full_alias)) def add_relation_type( self, source_model: Type["Model"], relation_name: str, reverse_name: str = None