refactor query and queryclause into separate classes
This commit is contained in:
375
orm/queryset.py
375
orm/queryset.py
@ -35,6 +35,8 @@ FILTER_OPERATORS = {
|
|||||||
"lte": "__le__",
|
"lte": "__le__",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ESCAPE_CHARACTERS = ["%", "_"]
|
||||||
|
|
||||||
|
|
||||||
class JoinParameters(NamedTuple):
|
class JoinParameters(NamedTuple):
|
||||||
prev_model: Type["Model"]
|
prev_model: Type["Model"]
|
||||||
@ -43,22 +45,23 @@ class JoinParameters(NamedTuple):
|
|||||||
model_cls: Type["Model"]
|
model_cls: Type["Model"]
|
||||||
|
|
||||||
|
|
||||||
class QuerySet:
|
class Query:
|
||||||
ESCAPE_CHARACTERS = ["%", "_"]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_cls: Type["Model"] = None,
|
model_cls: Type["Model"],
|
||||||
filter_clauses: List = None,
|
filter_clauses: List,
|
||||||
select_related: List = None,
|
select_related: List,
|
||||||
limit_count: int = None,
|
limit_count: int,
|
||||||
offset: int = None,
|
offset: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_cls = model_cls
|
|
||||||
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
|
||||||
self._select_related = [] if select_related is None else select_related
|
|
||||||
self.limit_count = limit_count
|
|
||||||
self.query_offset = offset
|
self.query_offset = offset
|
||||||
|
self.limit_count = limit_count
|
||||||
|
self._select_related = select_related
|
||||||
|
self.filter_clauses = filter_clauses
|
||||||
|
|
||||||
|
self.model_cls = model_cls
|
||||||
|
self.table = self.model_cls.__table__
|
||||||
|
|
||||||
self.auto_related = []
|
self.auto_related = []
|
||||||
self.used_aliases = []
|
self.used_aliases = []
|
||||||
@ -67,16 +70,45 @@ class QuerySet:
|
|||||||
self.columns = None
|
self.columns = None
|
||||||
self.order_bys = None
|
self.order_bys = None
|
||||||
|
|
||||||
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
|
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
|
||||||
return self.__class__(model_cls=owner)
|
self.columns = list(self.table.columns)
|
||||||
|
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
|
||||||
|
self.select_from = self.table
|
||||||
|
|
||||||
@property
|
for key in self.model_cls.__model_fields__:
|
||||||
def database(self) -> databases.Database:
|
if (
|
||||||
return self.model_cls.__database__
|
not self.model_cls.__model_fields__[key].nullable
|
||||||
|
and isinstance(
|
||||||
|
self.model_cls.__model_fields__[key], orm.fields.ForeignKey
|
||||||
|
)
|
||||||
|
and key not in self._select_related
|
||||||
|
):
|
||||||
|
self._select_related = [key] + self._select_related
|
||||||
|
|
||||||
@property
|
start_params = JoinParameters(
|
||||||
def table(self) -> sqlalchemy.Table:
|
self.model_cls, "", self.table.name, self.model_cls
|
||||||
return self.model_cls.__table__
|
)
|
||||||
|
self._extract_auto_required_relations(prev_model=start_params.prev_model)
|
||||||
|
self._include_auto_related_models()
|
||||||
|
self._select_related.sort(key=lambda item: (-len(item), item))
|
||||||
|
|
||||||
|
for item in self._select_related:
|
||||||
|
join_parameters = JoinParameters(
|
||||||
|
self.model_cls, "", self.table.name, self.model_cls
|
||||||
|
)
|
||||||
|
|
||||||
|
for part in item.split("__"):
|
||||||
|
join_parameters = self._build_join_parameters(part, join_parameters)
|
||||||
|
|
||||||
|
expr = sqlalchemy.sql.select(self.columns)
|
||||||
|
expr = expr.select_from(self.select_from)
|
||||||
|
|
||||||
|
expr = self._apply_expression_modifiers(expr)
|
||||||
|
|
||||||
|
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
||||||
|
self._reset_query_parameters()
|
||||||
|
|
||||||
|
return expr, self._select_related
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||||
@ -89,6 +121,25 @@ class QuerySet:
|
|||||||
def prefixed_table_name(alias: str, name: str) -> text:
|
def prefixed_table_name(alias: str, name: str) -> text:
|
||||||
return text(f"{name} {alias}_{name}")
|
return text(f"{name} {alias}_{name}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _field_is_a_foreign_key_and_no_circular_reference(
|
||||||
|
field: BaseField, field_name: str, rel_part: str
|
||||||
|
) -> bool:
|
||||||
|
return isinstance(field, ForeignKey) and field_name not in rel_part
|
||||||
|
|
||||||
|
def _field_qualifies_to_deeper_search(
|
||||||
|
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
|
||||||
|
) -> bool:
|
||||||
|
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
|
||||||
|
partial_match = any(
|
||||||
|
[x.startswith(prev_part_of_related) for x in self._select_related]
|
||||||
|
)
|
||||||
|
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
|
||||||
|
return (
|
||||||
|
(field.virtual and parent_virtual)
|
||||||
|
or (partial_match and not already_checked)
|
||||||
|
) or not nested
|
||||||
|
|
||||||
def on_clause(
|
def on_clause(
|
||||||
self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
|
self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
|
||||||
) -> text:
|
) -> text:
|
||||||
@ -139,25 +190,6 @@ class QuerySet:
|
|||||||
prev_model = model_cls
|
prev_model = model_cls
|
||||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _field_is_a_foreign_key_and_no_circular_reference(
|
|
||||||
field: BaseField, field_name: str, rel_part: str
|
|
||||||
) -> bool:
|
|
||||||
return isinstance(field, ForeignKey) and field_name not in rel_part
|
|
||||||
|
|
||||||
def _field_qualifies_to_deeper_search(
|
|
||||||
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
|
|
||||||
) -> bool:
|
|
||||||
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
|
|
||||||
partial_match = any(
|
|
||||||
[x.startswith(prev_part_of_related) for x in self._select_related]
|
|
||||||
)
|
|
||||||
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
|
|
||||||
return (
|
|
||||||
(field.virtual and parent_virtual)
|
|
||||||
or (partial_match and not already_checked)
|
|
||||||
) or not nested
|
|
||||||
|
|
||||||
def _extract_auto_required_relations(
|
def _extract_auto_required_relations(
|
||||||
self,
|
self,
|
||||||
prev_model: Type["Model"],
|
prev_model: Type["Model"],
|
||||||
@ -221,130 +253,21 @@ class QuerySet:
|
|||||||
self.auto_related = []
|
self.auto_related = []
|
||||||
self.used_aliases = []
|
self.used_aliases = []
|
||||||
|
|
||||||
def build_select_expression(self) -> sqlalchemy.sql.select:
|
|
||||||
self.columns = list(self.table.columns)
|
|
||||||
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
|
|
||||||
self.select_from = self.table
|
|
||||||
|
|
||||||
for key in self.model_cls.__model_fields__:
|
class QueryClause:
|
||||||
if (
|
def __init__(
|
||||||
not self.model_cls.__model_fields__[key].nullable
|
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
|
||||||
and isinstance(
|
) -> None:
|
||||||
self.model_cls.__model_fields__[key], orm.fields.ForeignKey
|
|
||||||
)
|
|
||||||
and key not in self._select_related
|
|
||||||
):
|
|
||||||
self._select_related = [key] + self._select_related
|
|
||||||
|
|
||||||
start_params = JoinParameters(
|
self._select_related = select_related
|
||||||
self.model_cls, "", self.table.name, self.model_cls
|
self.filter_clauses = filter_clauses
|
||||||
)
|
|
||||||
self._extract_auto_required_relations(prev_model=start_params.prev_model)
|
|
||||||
self._include_auto_related_models()
|
|
||||||
self._select_related.sort(key=lambda item: (-len(item), item))
|
|
||||||
|
|
||||||
for item in self._select_related:
|
self.model_cls = model_cls
|
||||||
join_parameters = JoinParameters(
|
self.table = self.model_cls.__table__
|
||||||
self.model_cls, "", self.table.name, self.model_cls
|
|
||||||
)
|
|
||||||
|
|
||||||
for part in item.split("__"):
|
def filter( # noqa: A003
|
||||||
join_parameters = self._build_join_parameters(part, join_parameters)
|
self, **kwargs: Any
|
||||||
|
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
||||||
expr = sqlalchemy.sql.select(self.columns)
|
|
||||||
expr = expr.select_from(self.select_from)
|
|
||||||
|
|
||||||
expr = self._apply_expression_modifiers(expr)
|
|
||||||
|
|
||||||
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
|
||||||
self._reset_query_parameters()
|
|
||||||
|
|
||||||
return expr
|
|
||||||
|
|
||||||
def _determine_filter_target_table(
|
|
||||||
self, related_parts: List[str], select_related: List[str]
|
|
||||||
) -> Tuple[List[str], str, "Model"]:
|
|
||||||
|
|
||||||
table_prefix = ""
|
|
||||||
model_cls = self.model_cls
|
|
||||||
select_related = [relation for relation in select_related]
|
|
||||||
|
|
||||||
# Add any implied select_related
|
|
||||||
related_str = "__".join(related_parts)
|
|
||||||
if related_str not in select_related:
|
|
||||||
select_related.append(related_str)
|
|
||||||
|
|
||||||
# Walk the relationships to the actual model class
|
|
||||||
# against which the comparison is being made.
|
|
||||||
previous_table = model_cls.__tablename__
|
|
||||||
for part in related_parts:
|
|
||||||
current_table = model_cls.__model_fields__[part].to.__tablename__
|
|
||||||
manager = model_cls._orm_relationship_manager
|
|
||||||
table_prefix = manager.resolve_relation_join(previous_table, current_table)
|
|
||||||
model_cls = model_cls.__model_fields__[part].to
|
|
||||||
previous_table = current_table
|
|
||||||
return select_related, table_prefix, model_cls
|
|
||||||
|
|
||||||
def _compile_clause(
|
|
||||||
self,
|
|
||||||
clause: sqlalchemy.sql.expression.BinaryExpression,
|
|
||||||
column: sqlalchemy.Column,
|
|
||||||
table: sqlalchemy.Table,
|
|
||||||
table_prefix: str,
|
|
||||||
modifiers: Dict,
|
|
||||||
) -> sqlalchemy.sql.expression.TextClause:
|
|
||||||
for modifier, modifier_value in modifiers.items():
|
|
||||||
clause.modifiers[modifier] = modifier_value
|
|
||||||
|
|
||||||
clause_text = str(
|
|
||||||
clause.compile(
|
|
||||||
dialect=self.model_cls.__database__._backend._dialect,
|
|
||||||
compile_kwargs={"literal_binds": True},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
alias = f"{table_prefix}_" if table_prefix else ""
|
|
||||||
aliased_name = f"{alias}{table.name}.{column.name}"
|
|
||||||
clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name)
|
|
||||||
clause = text(clause_text)
|
|
||||||
return clause
|
|
||||||
|
|
||||||
def _escape_characters_in_clause(
|
|
||||||
self, op: str, value: Union[str, "Model"]
|
|
||||||
) -> Tuple[str, bool]:
|
|
||||||
has_escaped_character = False
|
|
||||||
|
|
||||||
if op in ["contains", "icontains"]:
|
|
||||||
if isinstance(value, orm.Model):
|
|
||||||
raise QueryDefinitionError(
|
|
||||||
"You cannot use contains and icontains with instance of the Model"
|
|
||||||
)
|
|
||||||
|
|
||||||
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in value)
|
|
||||||
|
|
||||||
if has_escaped_character:
|
|
||||||
# enable escape modifier
|
|
||||||
for char in self.ESCAPE_CHARACTERS:
|
|
||||||
value = value.replace(char, f"\\{char}")
|
|
||||||
value = f"%{value}%"
|
|
||||||
|
|
||||||
return value, has_escaped_character
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_operator_field_and_related(
|
|
||||||
parts: List[str],
|
|
||||||
) -> Tuple[str, str, Optional[List]]:
|
|
||||||
if parts[-1] in FILTER_OPERATORS:
|
|
||||||
op = parts[-1]
|
|
||||||
field_name = parts[-2]
|
|
||||||
related_parts = parts[:-2]
|
|
||||||
else:
|
|
||||||
op = "exact"
|
|
||||||
field_name = parts[-1]
|
|
||||||
related_parts = parts[:-1]
|
|
||||||
|
|
||||||
return op, field_name, related_parts
|
|
||||||
|
|
||||||
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
|
||||||
filter_clauses = self.filter_clauses
|
filter_clauses = self.filter_clauses
|
||||||
select_related = list(self._select_related)
|
select_related = list(self._select_related)
|
||||||
|
|
||||||
@ -395,9 +318,141 @@ class QuerySet:
|
|||||||
table_prefix,
|
table_prefix,
|
||||||
modifiers={"escape": "\\" if has_escaped_character else None},
|
modifiers={"escape": "\\" if has_escaped_character else None},
|
||||||
)
|
)
|
||||||
|
|
||||||
filter_clauses.append(clause)
|
filter_clauses.append(clause)
|
||||||
|
|
||||||
|
return filter_clauses, select_related
|
||||||
|
|
||||||
|
def _determine_filter_target_table(
|
||||||
|
self, related_parts: List[str], select_related: List[str]
|
||||||
|
) -> Tuple[List[str], str, "Model"]:
|
||||||
|
|
||||||
|
table_prefix = ""
|
||||||
|
model_cls = self.model_cls
|
||||||
|
select_related = [relation for relation in select_related]
|
||||||
|
|
||||||
|
# Add any implied select_related
|
||||||
|
related_str = "__".join(related_parts)
|
||||||
|
if related_str not in select_related:
|
||||||
|
select_related.append(related_str)
|
||||||
|
|
||||||
|
# Walk the relationships to the actual model class
|
||||||
|
# against which the comparison is being made.
|
||||||
|
previous_table = model_cls.__tablename__
|
||||||
|
for part in related_parts:
|
||||||
|
current_table = model_cls.__model_fields__[part].to.__tablename__
|
||||||
|
manager = model_cls._orm_relationship_manager
|
||||||
|
table_prefix = manager.resolve_relation_join(previous_table, current_table)
|
||||||
|
model_cls = model_cls.__model_fields__[part].to
|
||||||
|
previous_table = current_table
|
||||||
|
return select_related, table_prefix, model_cls
|
||||||
|
|
||||||
|
def _compile_clause(
|
||||||
|
self,
|
||||||
|
clause: sqlalchemy.sql.expression.BinaryExpression,
|
||||||
|
column: sqlalchemy.Column,
|
||||||
|
table: sqlalchemy.Table,
|
||||||
|
table_prefix: str,
|
||||||
|
modifiers: Dict,
|
||||||
|
) -> sqlalchemy.sql.expression.TextClause:
|
||||||
|
for modifier, modifier_value in modifiers.items():
|
||||||
|
clause.modifiers[modifier] = modifier_value
|
||||||
|
|
||||||
|
clause_text = str(
|
||||||
|
clause.compile(
|
||||||
|
dialect=self.model_cls.__database__._backend._dialect,
|
||||||
|
compile_kwargs={"literal_binds": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
alias = f"{table_prefix}_" if table_prefix else ""
|
||||||
|
aliased_name = f"{alias}{table.name}.{column.name}"
|
||||||
|
clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name)
|
||||||
|
clause = text(clause_text)
|
||||||
|
return clause
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _escape_characters_in_clause(
|
||||||
|
op: str, value: Union[str, "Model"]
|
||||||
|
) -> Tuple[str, bool]:
|
||||||
|
has_escaped_character = False
|
||||||
|
|
||||||
|
if op in ["contains", "icontains"]:
|
||||||
|
if isinstance(value, orm.Model):
|
||||||
|
raise QueryDefinitionError(
|
||||||
|
"You cannot use contains and icontains with instance of the Model"
|
||||||
|
)
|
||||||
|
|
||||||
|
has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value)
|
||||||
|
|
||||||
|
if has_escaped_character:
|
||||||
|
# enable escape modifier
|
||||||
|
for char in ESCAPE_CHARACTERS:
|
||||||
|
value = value.replace(char, f"\\{char}")
|
||||||
|
value = f"%{value}%"
|
||||||
|
|
||||||
|
return value, has_escaped_character
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_operator_field_and_related(
|
||||||
|
parts: List[str],
|
||||||
|
) -> Tuple[str, str, Optional[List]]:
|
||||||
|
if parts[-1] in FILTER_OPERATORS:
|
||||||
|
op = parts[-1]
|
||||||
|
field_name = parts[-2]
|
||||||
|
related_parts = parts[:-2]
|
||||||
|
else:
|
||||||
|
op = "exact"
|
||||||
|
field_name = parts[-1]
|
||||||
|
related_parts = parts[:-1]
|
||||||
|
|
||||||
|
return op, field_name, related_parts
|
||||||
|
|
||||||
|
|
||||||
|
class QuerySet:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_cls: Type["Model"] = None,
|
||||||
|
filter_clauses: List = None,
|
||||||
|
select_related: List = None,
|
||||||
|
limit_count: int = None,
|
||||||
|
offset: int = None,
|
||||||
|
) -> None:
|
||||||
|
self.model_cls = model_cls
|
||||||
|
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
||||||
|
self._select_related = [] if select_related is None else select_related
|
||||||
|
self.limit_count = limit_count
|
||||||
|
self.query_offset = offset
|
||||||
|
self.order_bys = None
|
||||||
|
|
||||||
|
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
|
||||||
|
return self.__class__(model_cls=owner)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def database(self) -> databases.Database:
|
||||||
|
return self.model_cls.__database__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def table(self) -> sqlalchemy.Table:
|
||||||
|
return self.model_cls.__table__
|
||||||
|
|
||||||
|
def build_select_expression(self) -> sqlalchemy.sql.select:
|
||||||
|
qry = Query(
|
||||||
|
model_cls=self.model_cls,
|
||||||
|
select_related=self._select_related,
|
||||||
|
filter_clauses=self.filter_clauses,
|
||||||
|
offset=self.query_offset,
|
||||||
|
limit_count=self.limit_count,
|
||||||
|
)
|
||||||
|
exp, self._select_related = qry.build_select_expression()
|
||||||
|
return exp
|
||||||
|
|
||||||
|
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||||
|
qryclause = QueryClause(
|
||||||
|
model_cls=self.model_cls,
|
||||||
|
select_related=self._select_related,
|
||||||
|
filter_clauses=self.filter_clauses,
|
||||||
|
)
|
||||||
|
filter_clauses, select_related = qryclause.filter(**kwargs)
|
||||||
|
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model_cls,
|
||||||
filter_clauses=filter_clauses,
|
filter_clauses=filter_clauses,
|
||||||
|
|||||||
Reference in New Issue
Block a user