wip - remove literal binds

This commit is contained in:
collerek
2022-01-05 18:19:14 +01:00
parent aab46de800
commit 3f264d974b
4 changed files with 77 additions and 33 deletions

View File

@ -1,7 +1,7 @@
import datetime
from typing import Any, Dict, TYPE_CHECKING, Type from typing import Any, Dict, TYPE_CHECKING, Type
import sqlalchemy import sqlalchemy
from sqlalchemy import text
import ormar # noqa: I100, I202 import ormar # noqa: I100, I202
from ormar.exceptions import QueryDefinitionError from ormar.exceptions import QueryDefinitionError
@ -125,7 +125,7 @@ class FilterAction(QueryAction):
sufix = "%" if "end" not in self.operator else "" sufix = "%" if "end" not in self.operator else ""
self.filter_value = f"{prefix}{self.filter_value}{sufix}" self.filter_value = f"{prefix}{self.filter_value}{sufix}"
def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause: def get_text_clause(self) -> sqlalchemy.sql.expression.BinaryExpression:
""" """
Escapes characters if it's required. Escapes characters if it's required.
Substitutes values of the models if value is a ormar Model with its pk value. Substitutes values of the models if value is a ormar Model with its pk value.
@ -137,18 +137,44 @@ class FilterAction(QueryAction):
if isinstance(self.filter_value, ormar.Model): if isinstance(self.filter_value, ormar.Model):
self.filter_value = self.filter_value.pk self.filter_value = self.filter_value.pk
# self._convert_dates_if_required()
op_attr = FILTER_OPERATORS[self.operator] op_attr = FILTER_OPERATORS[self.operator]
if self.operator == "isnull": if self.operator == "isnull":
op_attr = "is_" if self.filter_value else "isnot" op_attr = "is_" if self.filter_value else "isnot"
filter_value = None filter_value = None
else: else:
filter_value = self.filter_value filter_value = self.filter_value
clause = getattr(self.column, op_attr)(filter_value) if self.table_prefix:
aliased_table = self.source_model.Meta.alias_manager.prefixed_table_name(
self.table_prefix, self.column.table
)
aliased_column = getattr(aliased_table.c, self.column.name)
else:
aliased_column = self.column
clause = getattr(aliased_column, op_attr)(filter_value)
clause = self._compile_clause( clause = self._compile_clause(
clause, modifiers={"escape": "\\" if self.has_escaped_character else None} clause, modifiers={"escape": "\\" if self.has_escaped_character else None}
) )
return clause return clause
def _convert_dates_if_required(self) -> None:
"""
Converts dates, time and datetime to isoformat
"""
if isinstance(
self.filter_value, (datetime.date, datetime.time, datetime.datetime)
):
self.filter_value = self.filter_value.isoformat()
if isinstance(self.filter_value, (list, tuple, set)):
self.filter_value = [
x.isoformat()
if isinstance(x, (datetime.date, datetime.time, datetime.datetime))
else x
for x in self.filter_value
]
def _compile_clause( def _compile_clause(
self, clause: sqlalchemy.sql.expression.BinaryExpression, modifiers: Dict self, clause: sqlalchemy.sql.expression.BinaryExpression, modifiers: Dict
) -> sqlalchemy.sql.expression.TextClause: ) -> sqlalchemy.sql.expression.TextClause:
@ -166,19 +192,24 @@ class FilterAction(QueryAction):
for modifier, modifier_value in modifiers.items(): for modifier, modifier_value in modifiers.items():
clause.modifiers[modifier] = modifier_value clause.modifiers[modifier] = modifier_value
clause_text = str( # compiled_clause = clause.compile(
clause.compile( # dialect=self.target_model.Meta.database._backend._dialect,
dialect=self.target_model.Meta.database._backend._dialect, # # compile_kwargs={"literal_binds": True},
compile_kwargs={"literal_binds": True}, # )
) #
) # compiled_clause2 = clause.compile(
alias = f"{self.table_prefix}_" if self.table_prefix else "" # dialect=self.target_model.Meta.database._backend._dialect,
aliased_name = f"{alias}{self.table.name}.{self.column.name}" # compile_kwargs={"literal_binds": True},
clause_text = clause_text.replace( # )
f"{self.table.name}.{self.column.name}", aliased_name #
) # alias = f"{self.table_prefix}_" if self.table_prefix else ""
dialect_name = self.target_model.Meta.database._backend._dialect.name # aliased_name = f"{alias}{self.table.name}.{self.column.name}"
if dialect_name != "sqlite": # pragma: no cover # clause_text = self.compile_query(compiled_query=compiled_clause)
clause_text = clause_text.replace("%%", "%") # remove %% in some dialects # clause_text = clause_text.replace(
clause = text(clause_text) # f"{self.table.name}.{self.column.name}", aliased_name
# )
# # dialect_name = self.target_model.Meta.database._backend._dialect.name
# # if dialect_name != "sqlite": # pragma: no cover
# # clause_text = clause_text.replace("%%", "%")
# clause = text(clause_text)
return clause return clause

View File

@ -121,19 +121,23 @@ class FilterGroup:
:return: complied and escaped clause :return: complied and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause :rtype: sqlalchemy.sql.elements.TextClause
""" """
prefix = " NOT " if self.exclude else "" # prefix = " NOT " if self.exclude else ""
if self.filter_type == FilterType.AND: if self.filter_type == FilterType.AND:
clause = sqlalchemy.text( # clause = sqlalchemy.text(
f"{prefix}( " # f"{prefix}( "
+ str(sqlalchemy.sql.and_(*self._get_text_clauses())) # + str(sqlalchemy.sql.and_(*self._get_text_clauses()))
+ " )" # + " )"
) # )
clause = sqlalchemy.sql.and_(*self._get_text_clauses())
else: else:
clause = sqlalchemy.text( # clause = sqlalchemy.text(
f"{prefix}( " # f"{prefix}( "
+ str(sqlalchemy.sql.or_(*self._get_text_clauses())) # + str(sqlalchemy.sql.or_(*self._get_text_clauses()))
+ " )" # + " )"
) # )
clause = sqlalchemy.sql.or_(*self._get_text_clauses())
if self.exclude:
clause = sqlalchemy.sql.not_(clause)
return clause return clause

View File

@ -187,6 +187,7 @@ class Query:
for order in list(self.sorted_orders.keys()): for order in list(self.sorted_orders.keys()):
if order is not None and order.get_field_name_text() != pk_aliased_name: if order is not None and order.get_field_name_text() != pk_aliased_name:
aliased_col = order.get_field_name_text() aliased_col = order.get_field_name_text()
# maxes[aliased_col] = order.get_text_clause()
maxes[aliased_col] = order.get_min_or_max() maxes[aliased_col] = order.get_min_or_max()
elif order.get_field_name_text() == pk_aliased_name: elif order.get_field_name_text() == pk_aliased_name:
maxes[pk_aliased_name] = order.get_text_clause() maxes[pk_aliased_name] = order.get_text_clause()

View File

@ -35,6 +35,7 @@ class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._aliases_new: Dict[str, str] = dict() self._aliases_new: Dict[str, str] = dict()
self._reversed_aliases: Dict[str, str] = dict() self._reversed_aliases: Dict[str, str] = dict()
self._prefixed_tables: Dict[str, text] = dict()
def __contains__(self, item: str) -> bool: def __contains__(self, item: str) -> bool:
return self._aliases_new.__contains__(item) return self._aliases_new.__contains__(item)
@ -77,15 +78,20 @@ class AliasManager:
:rtype: List[text] :rtype: List[text]
""" """
alias = f"{alias}_" if alias else "" alias = f"{alias}_" if alias else ""
aliased_fields = [f"{alias}{x}" for x in fields]
# TODO: check if normal fields still needed or only aliased one
all_columns = ( all_columns = (
table.columns table.columns
if not fields if not fields
else [col for col in table.columns if col.name in fields] else [
col
for col in table.columns
if col.name in fields or col.name in aliased_fields
]
) )
return [column.label(f"{alias}{column.name}") for column in all_columns] return [column.label(f"{alias}{column.name}") for column in all_columns]
@staticmethod def prefixed_table_name(self, alias: str, table: sqlalchemy.Table) -> text:
def prefixed_table_name(alias: str, table: sqlalchemy.Table) -> text:
""" """
Creates text clause with table name with aliased name. Creates text clause with table name with aliased name.
@ -96,7 +102,9 @@ class AliasManager:
:return: sqlalchemy text clause as "table_name aliased_name" :return: sqlalchemy text clause as "table_name aliased_name"
:rtype: sqlalchemy text clause :rtype: sqlalchemy text clause
""" """
return table.alias(f"{alias}_{table.name}") full_alias = f"{alias}_{table.name}"
key = f"{full_alias}_{id(table)}"
return self._prefixed_tables.setdefault(key, table.alias(full_alias))
def add_relation_type( def add_relation_type(
self, source_model: Type["Model"], relation_name: str, reverse_name: str = None self, source_model: Type["Model"], relation_name: str, reverse_name: str = None