wip - remove literal binds
This commit is contained in:
@ -1,7 +1,7 @@
|
|||||||
|
import datetime
|
||||||
from typing import Any, Dict, TYPE_CHECKING, Type
|
from typing import Any, Dict, TYPE_CHECKING, Type
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import text
|
|
||||||
|
|
||||||
import ormar # noqa: I100, I202
|
import ormar # noqa: I100, I202
|
||||||
from ormar.exceptions import QueryDefinitionError
|
from ormar.exceptions import QueryDefinitionError
|
||||||
@ -125,7 +125,7 @@ class FilterAction(QueryAction):
|
|||||||
sufix = "%" if "end" not in self.operator else ""
|
sufix = "%" if "end" not in self.operator else ""
|
||||||
self.filter_value = f"{prefix}{self.filter_value}{sufix}"
|
self.filter_value = f"{prefix}{self.filter_value}{sufix}"
|
||||||
|
|
||||||
def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause:
|
def get_text_clause(self) -> sqlalchemy.sql.expression.BinaryExpression:
|
||||||
"""
|
"""
|
||||||
Escapes characters if it's required.
|
Escapes characters if it's required.
|
||||||
Substitutes values of the models if value is a ormar Model with its pk value.
|
Substitutes values of the models if value is a ormar Model with its pk value.
|
||||||
@ -137,18 +137,44 @@ class FilterAction(QueryAction):
|
|||||||
if isinstance(self.filter_value, ormar.Model):
|
if isinstance(self.filter_value, ormar.Model):
|
||||||
self.filter_value = self.filter_value.pk
|
self.filter_value = self.filter_value.pk
|
||||||
|
|
||||||
|
# self._convert_dates_if_required()
|
||||||
|
|
||||||
op_attr = FILTER_OPERATORS[self.operator]
|
op_attr = FILTER_OPERATORS[self.operator]
|
||||||
if self.operator == "isnull":
|
if self.operator == "isnull":
|
||||||
op_attr = "is_" if self.filter_value else "isnot"
|
op_attr = "is_" if self.filter_value else "isnot"
|
||||||
filter_value = None
|
filter_value = None
|
||||||
else:
|
else:
|
||||||
filter_value = self.filter_value
|
filter_value = self.filter_value
|
||||||
clause = getattr(self.column, op_attr)(filter_value)
|
if self.table_prefix:
|
||||||
|
aliased_table = self.source_model.Meta.alias_manager.prefixed_table_name(
|
||||||
|
self.table_prefix, self.column.table
|
||||||
|
)
|
||||||
|
aliased_column = getattr(aliased_table.c, self.column.name)
|
||||||
|
else:
|
||||||
|
aliased_column = self.column
|
||||||
|
clause = getattr(aliased_column, op_attr)(filter_value)
|
||||||
clause = self._compile_clause(
|
clause = self._compile_clause(
|
||||||
clause, modifiers={"escape": "\\" if self.has_escaped_character else None}
|
clause, modifiers={"escape": "\\" if self.has_escaped_character else None}
|
||||||
)
|
)
|
||||||
return clause
|
return clause
|
||||||
|
|
||||||
|
def _convert_dates_if_required(self) -> None:
|
||||||
|
"""
|
||||||
|
Converts dates, time and datetime to isoformat
|
||||||
|
"""
|
||||||
|
if isinstance(
|
||||||
|
self.filter_value, (datetime.date, datetime.time, datetime.datetime)
|
||||||
|
):
|
||||||
|
self.filter_value = self.filter_value.isoformat()
|
||||||
|
|
||||||
|
if isinstance(self.filter_value, (list, tuple, set)):
|
||||||
|
self.filter_value = [
|
||||||
|
x.isoformat()
|
||||||
|
if isinstance(x, (datetime.date, datetime.time, datetime.datetime))
|
||||||
|
else x
|
||||||
|
for x in self.filter_value
|
||||||
|
]
|
||||||
|
|
||||||
def _compile_clause(
|
def _compile_clause(
|
||||||
self, clause: sqlalchemy.sql.expression.BinaryExpression, modifiers: Dict
|
self, clause: sqlalchemy.sql.expression.BinaryExpression, modifiers: Dict
|
||||||
) -> sqlalchemy.sql.expression.TextClause:
|
) -> sqlalchemy.sql.expression.TextClause:
|
||||||
@ -166,19 +192,24 @@ class FilterAction(QueryAction):
|
|||||||
for modifier, modifier_value in modifiers.items():
|
for modifier, modifier_value in modifiers.items():
|
||||||
clause.modifiers[modifier] = modifier_value
|
clause.modifiers[modifier] = modifier_value
|
||||||
|
|
||||||
clause_text = str(
|
# compiled_clause = clause.compile(
|
||||||
clause.compile(
|
# dialect=self.target_model.Meta.database._backend._dialect,
|
||||||
dialect=self.target_model.Meta.database._backend._dialect,
|
# # compile_kwargs={"literal_binds": True},
|
||||||
compile_kwargs={"literal_binds": True},
|
# )
|
||||||
)
|
#
|
||||||
)
|
# compiled_clause2 = clause.compile(
|
||||||
alias = f"{self.table_prefix}_" if self.table_prefix else ""
|
# dialect=self.target_model.Meta.database._backend._dialect,
|
||||||
aliased_name = f"{alias}{self.table.name}.{self.column.name}"
|
# compile_kwargs={"literal_binds": True},
|
||||||
clause_text = clause_text.replace(
|
# )
|
||||||
f"{self.table.name}.{self.column.name}", aliased_name
|
#
|
||||||
)
|
# alias = f"{self.table_prefix}_" if self.table_prefix else ""
|
||||||
dialect_name = self.target_model.Meta.database._backend._dialect.name
|
# aliased_name = f"{alias}{self.table.name}.{self.column.name}"
|
||||||
if dialect_name != "sqlite": # pragma: no cover
|
# clause_text = self.compile_query(compiled_query=compiled_clause)
|
||||||
clause_text = clause_text.replace("%%", "%") # remove %% in some dialects
|
# clause_text = clause_text.replace(
|
||||||
clause = text(clause_text)
|
# f"{self.table.name}.{self.column.name}", aliased_name
|
||||||
|
# )
|
||||||
|
# # dialect_name = self.target_model.Meta.database._backend._dialect.name
|
||||||
|
# # if dialect_name != "sqlite": # pragma: no cover
|
||||||
|
# # clause_text = clause_text.replace("%%", "%")
|
||||||
|
# clause = text(clause_text)
|
||||||
return clause
|
return clause
|
||||||
|
|||||||
@ -121,19 +121,23 @@ class FilterGroup:
|
|||||||
:return: complied and escaped clause
|
:return: complied and escaped clause
|
||||||
:rtype: sqlalchemy.sql.elements.TextClause
|
:rtype: sqlalchemy.sql.elements.TextClause
|
||||||
"""
|
"""
|
||||||
prefix = " NOT " if self.exclude else ""
|
# prefix = " NOT " if self.exclude else ""
|
||||||
if self.filter_type == FilterType.AND:
|
if self.filter_type == FilterType.AND:
|
||||||
clause = sqlalchemy.text(
|
# clause = sqlalchemy.text(
|
||||||
f"{prefix}( "
|
# f"{prefix}( "
|
||||||
+ str(sqlalchemy.sql.and_(*self._get_text_clauses()))
|
# + str(sqlalchemy.sql.and_(*self._get_text_clauses()))
|
||||||
+ " )"
|
# + " )"
|
||||||
)
|
# )
|
||||||
|
clause = sqlalchemy.sql.and_(*self._get_text_clauses())
|
||||||
else:
|
else:
|
||||||
clause = sqlalchemy.text(
|
# clause = sqlalchemy.text(
|
||||||
f"{prefix}( "
|
# f"{prefix}( "
|
||||||
+ str(sqlalchemy.sql.or_(*self._get_text_clauses()))
|
# + str(sqlalchemy.sql.or_(*self._get_text_clauses()))
|
||||||
+ " )"
|
# + " )"
|
||||||
)
|
# )
|
||||||
|
clause = sqlalchemy.sql.or_(*self._get_text_clauses())
|
||||||
|
if self.exclude:
|
||||||
|
clause = sqlalchemy.sql.not_(clause)
|
||||||
return clause
|
return clause
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -187,6 +187,7 @@ class Query:
|
|||||||
for order in list(self.sorted_orders.keys()):
|
for order in list(self.sorted_orders.keys()):
|
||||||
if order is not None and order.get_field_name_text() != pk_aliased_name:
|
if order is not None and order.get_field_name_text() != pk_aliased_name:
|
||||||
aliased_col = order.get_field_name_text()
|
aliased_col = order.get_field_name_text()
|
||||||
|
# maxes[aliased_col] = order.get_text_clause()
|
||||||
maxes[aliased_col] = order.get_min_or_max()
|
maxes[aliased_col] = order.get_min_or_max()
|
||||||
elif order.get_field_name_text() == pk_aliased_name:
|
elif order.get_field_name_text() == pk_aliased_name:
|
||||||
maxes[pk_aliased_name] = order.get_text_clause()
|
maxes[pk_aliased_name] = order.get_text_clause()
|
||||||
|
|||||||
@ -35,6 +35,7 @@ class AliasManager:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._aliases_new: Dict[str, str] = dict()
|
self._aliases_new: Dict[str, str] = dict()
|
||||||
self._reversed_aliases: Dict[str, str] = dict()
|
self._reversed_aliases: Dict[str, str] = dict()
|
||||||
|
self._prefixed_tables: Dict[str, text] = dict()
|
||||||
|
|
||||||
def __contains__(self, item: str) -> bool:
|
def __contains__(self, item: str) -> bool:
|
||||||
return self._aliases_new.__contains__(item)
|
return self._aliases_new.__contains__(item)
|
||||||
@ -77,15 +78,20 @@ class AliasManager:
|
|||||||
:rtype: List[text]
|
:rtype: List[text]
|
||||||
"""
|
"""
|
||||||
alias = f"{alias}_" if alias else ""
|
alias = f"{alias}_" if alias else ""
|
||||||
|
aliased_fields = [f"{alias}{x}" for x in fields]
|
||||||
|
# TODO: check if normal fields still needed or only aliased one
|
||||||
all_columns = (
|
all_columns = (
|
||||||
table.columns
|
table.columns
|
||||||
if not fields
|
if not fields
|
||||||
else [col for col in table.columns if col.name in fields]
|
else [
|
||||||
|
col
|
||||||
|
for col in table.columns
|
||||||
|
if col.name in fields or col.name in aliased_fields
|
||||||
|
]
|
||||||
)
|
)
|
||||||
return [column.label(f"{alias}{column.name}") for column in all_columns]
|
return [column.label(f"{alias}{column.name}") for column in all_columns]
|
||||||
|
|
||||||
@staticmethod
|
def prefixed_table_name(self, alias: str, table: sqlalchemy.Table) -> text:
|
||||||
def prefixed_table_name(alias: str, table: sqlalchemy.Table) -> text:
|
|
||||||
"""
|
"""
|
||||||
Creates text clause with table name with aliased name.
|
Creates text clause with table name with aliased name.
|
||||||
|
|
||||||
@ -96,7 +102,9 @@ class AliasManager:
|
|||||||
:return: sqlalchemy text clause as "table_name aliased_name"
|
:return: sqlalchemy text clause as "table_name aliased_name"
|
||||||
:rtype: sqlalchemy text clause
|
:rtype: sqlalchemy text clause
|
||||||
"""
|
"""
|
||||||
return table.alias(f"{alias}_{table.name}")
|
full_alias = f"{alias}_{table.name}"
|
||||||
|
key = f"{full_alias}_{id(table)}"
|
||||||
|
return self._prefixed_tables.setdefault(key, table.alias(full_alias))
|
||||||
|
|
||||||
def add_relation_type(
|
def add_relation_type(
|
||||||
self, source_model: Type["Model"], relation_name: str, reverse_name: str = None
|
self, source_model: Type["Model"], relation_name: str, reverse_name: str = None
|
||||||
|
|||||||
Reference in New Issue
Block a user