refactor query and queryclause into separate classes

This commit is contained in:
collerek
2020-08-09 13:27:53 +02:00
parent 836836c136
commit becb914e55
2 changed files with 215 additions and 160 deletions

BIN
.coverage

Binary file not shown.

View File

@ -35,6 +35,8 @@ FILTER_OPERATORS = {
"lte": "__le__", "lte": "__le__",
} }
ESCAPE_CHARACTERS = ["%", "_"]
class JoinParameters(NamedTuple): class JoinParameters(NamedTuple):
prev_model: Type["Model"] prev_model: Type["Model"]
@ -43,22 +45,23 @@ class JoinParameters(NamedTuple):
model_cls: Type["Model"] model_cls: Type["Model"]
class QuerySet: class Query:
ESCAPE_CHARACTERS = ["%", "_"]
def __init__( def __init__(
self, self,
model_cls: Type["Model"] = None, model_cls: Type["Model"],
filter_clauses: List = None, filter_clauses: List,
select_related: List = None, select_related: List,
limit_count: int = None, limit_count: int,
offset: int = None, offset: int,
) -> None: ) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count
self.query_offset = offset self.query_offset = offset
self.limit_count = limit_count
self._select_related = select_related
self.filter_clauses = filter_clauses
self.model_cls = model_cls
self.table = self.model_cls.__table__
self.auto_related = [] self.auto_related = []
self.used_aliases = [] self.used_aliases = []
@ -67,16 +70,45 @@ class QuerySet:
self.columns = None self.columns = None
self.order_bys = None self.order_bys = None
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
return self.__class__(model_cls=owner) self.columns = list(self.table.columns)
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
self.select_from = self.table
@property for key in self.model_cls.__model_fields__:
def database(self) -> databases.Database: if (
return self.model_cls.__database__ not self.model_cls.__model_fields__[key].nullable
and isinstance(
self.model_cls.__model_fields__[key], orm.fields.ForeignKey
)
and key not in self._select_related
):
self._select_related = [key] + self._select_related
@property start_params = JoinParameters(
def table(self) -> sqlalchemy.Table: self.model_cls, "", self.table.name, self.model_cls
return self.model_cls.__table__ )
self._extract_auto_required_relations(prev_model=start_params.prev_model)
self._include_auto_related_models()
self._select_related.sort(key=lambda item: (-len(item), item))
for item in self._select_related:
join_parameters = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
)
for part in item.split("__"):
join_parameters = self._build_join_parameters(part, join_parameters)
expr = sqlalchemy.sql.select(self.columns)
expr = expr.select_from(self.select_from)
expr = self._apply_expression_modifiers(expr)
# print(expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters()
return expr, self._select_related
@staticmethod @staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
@ -89,6 +121,25 @@ class QuerySet:
def prefixed_table_name(alias: str, name: str) -> text: def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}") return text(f"{name} {alias}_{name}")
@staticmethod
def _field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str
) -> bool:
return isinstance(field, ForeignKey) and field_name not in rel_part
def _field_qualifies_to_deeper_search(
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any(
[x.startswith(prev_part_of_related) for x in self._select_related]
)
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def on_clause( def on_clause(
self, previous_alias: str, alias: str, from_clause: str, to_clause: str, self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
) -> text: ) -> text:
@ -139,25 +190,6 @@ class QuerySet:
prev_model = model_cls prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls) return JoinParameters(prev_model, previous_alias, from_table, model_cls)
@staticmethod
def _field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str
) -> bool:
return isinstance(field, ForeignKey) and field_name not in rel_part
def _field_qualifies_to_deeper_search(
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any(
[x.startswith(prev_part_of_related) for x in self._select_related]
)
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def _extract_auto_required_relations( def _extract_auto_required_relations(
self, self,
prev_model: Type["Model"], prev_model: Type["Model"],
@ -221,130 +253,21 @@ class QuerySet:
self.auto_related = [] self.auto_related = []
self.used_aliases = [] self.used_aliases = []
def build_select_expression(self) -> sqlalchemy.sql.select:
self.columns = list(self.table.columns)
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
self.select_from = self.table
for key in self.model_cls.__model_fields__: class QueryClause:
if ( def __init__(
not self.model_cls.__model_fields__[key].nullable self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
and isinstance( ) -> None:
self.model_cls.__model_fields__[key], orm.fields.ForeignKey
)
and key not in self._select_related
):
self._select_related = [key] + self._select_related
start_params = JoinParameters( self._select_related = select_related
self.model_cls, "", self.table.name, self.model_cls self.filter_clauses = filter_clauses
)
self._extract_auto_required_relations(prev_model=start_params.prev_model)
self._include_auto_related_models()
self._select_related.sort(key=lambda item: (-len(item), item))
for item in self._select_related: self.model_cls = model_cls
join_parameters = JoinParameters( self.table = self.model_cls.__table__
self.model_cls, "", self.table.name, self.model_cls
)
for part in item.split("__"): def filter( # noqa: A003
join_parameters = self._build_join_parameters(part, join_parameters) self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
expr = sqlalchemy.sql.select(self.columns)
expr = expr.select_from(self.select_from)
expr = self._apply_expression_modifiers(expr)
# print(expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters()
return expr
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str]
) -> Tuple[List[str], str, "Model"]:
table_prefix = ""
model_cls = self.model_cls
select_related = [relation for relation in select_related]
# Add any implied select_related
related_str = "__".join(related_parts)
if related_str not in select_related:
select_related.append(related_str)
# Walk the relationships to the actual model class
# against which the comparison is being made.
previous_table = model_cls.__tablename__
for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__
manager = model_cls._orm_relationship_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table)
model_cls = model_cls.__model_fields__[part].to
previous_table = current_table
return select_related, table_prefix, model_cls
def _compile_clause(
self,
clause: sqlalchemy.sql.expression.BinaryExpression,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause:
for modifier, modifier_value in modifiers.items():
clause.modifiers[modifier] = modifier_value
clause_text = str(
clause.compile(
dialect=self.model_cls.__database__._backend._dialect,
compile_kwargs={"literal_binds": True},
)
)
alias = f"{table_prefix}_" if table_prefix else ""
aliased_name = f"{alias}{table.name}.{column.name}"
clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name)
clause = text(clause_text)
return clause
def _escape_characters_in_clause(
self, op: str, value: Union[str, "Model"]
) -> Tuple[str, bool]:
has_escaped_character = False
if op in ["contains", "icontains"]:
if isinstance(value, orm.Model):
raise QueryDefinitionError(
"You cannot use contains and icontains with instance of the Model"
)
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in value)
if has_escaped_character:
# enable escape modifier
for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f"\\{char}")
value = f"%{value}%"
return value, has_escaped_character
@staticmethod
def _extract_operator_field_and_related(
parts: List[str],
) -> Tuple[str, str, Optional[List]]:
if parts[-1] in FILTER_OPERATORS:
op = parts[-1]
field_name = parts[-2]
related_parts = parts[:-2]
else:
op = "exact"
field_name = parts[-1]
related_parts = parts[:-1]
return op, field_name, related_parts
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
filter_clauses = self.filter_clauses filter_clauses = self.filter_clauses
select_related = list(self._select_related) select_related = list(self._select_related)
@ -395,9 +318,141 @@ class QuerySet:
table_prefix, table_prefix,
modifiers={"escape": "\\" if has_escaped_character else None}, modifiers={"escape": "\\" if has_escaped_character else None},
) )
filter_clauses.append(clause) filter_clauses.append(clause)
return filter_clauses, select_related
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str]
) -> Tuple[List[str], str, "Model"]:
table_prefix = ""
model_cls = self.model_cls
select_related = [relation for relation in select_related]
# Add any implied select_related
related_str = "__".join(related_parts)
if related_str not in select_related:
select_related.append(related_str)
# Walk the relationships to the actual model class
# against which the comparison is being made.
previous_table = model_cls.__tablename__
for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__
manager = model_cls._orm_relationship_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table)
model_cls = model_cls.__model_fields__[part].to
previous_table = current_table
return select_related, table_prefix, model_cls
def _compile_clause(
self,
clause: sqlalchemy.sql.expression.BinaryExpression,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause:
for modifier, modifier_value in modifiers.items():
clause.modifiers[modifier] = modifier_value
clause_text = str(
clause.compile(
dialect=self.model_cls.__database__._backend._dialect,
compile_kwargs={"literal_binds": True},
)
)
alias = f"{table_prefix}_" if table_prefix else ""
aliased_name = f"{alias}{table.name}.{column.name}"
clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name)
clause = text(clause_text)
return clause
@staticmethod
def _escape_characters_in_clause(
op: str, value: Union[str, "Model"]
) -> Tuple[str, bool]:
has_escaped_character = False
if op in ["contains", "icontains"]:
if isinstance(value, orm.Model):
raise QueryDefinitionError(
"You cannot use contains and icontains with instance of the Model"
)
has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value)
if has_escaped_character:
# enable escape modifier
for char in ESCAPE_CHARACTERS:
value = value.replace(char, f"\\{char}")
value = f"%{value}%"
return value, has_escaped_character
@staticmethod
def _extract_operator_field_and_related(
parts: List[str],
) -> Tuple[str, str, Optional[List]]:
if parts[-1] in FILTER_OPERATORS:
op = parts[-1]
field_name = parts[-2]
related_parts = parts[:-2]
else:
op = "exact"
field_name = parts[-1]
related_parts = parts[:-1]
return op, field_name, related_parts
class QuerySet:
def __init__(
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count
self.query_offset = offset
self.order_bys = None
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
return self.__class__(model_cls=owner)
@property
def database(self) -> databases.Database:
return self.model_cls.__database__
@property
def table(self) -> sqlalchemy.Table:
return self.model_cls.__table__
def build_select_expression(self) -> sqlalchemy.sql.select:
qry = Query(
model_cls=self.model_cls,
select_related=self._select_related,
filter_clauses=self.filter_clauses,
offset=self.query_offset,
limit_count=self.limit_count,
)
exp, self._select_related = qry.build_select_expression()
return exp
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
qryclause = QueryClause(
model_cls=self.model_cls,
select_related=self._select_related,
filter_clauses=self.filter_clauses,
)
filter_clauses, select_related = qryclause.filter(**kwargs)
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model_cls,
filter_clauses=filter_clauses, filter_clauses=filter_clauses,