From becb914e557c197893c0089ab8bd9dcd7c77a250 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 13:27:53 +0200 Subject: [PATCH] refactor query and queryclause into separate classes --- .coverage | Bin 53248 -> 53248 bytes orm/queryset.py | 375 +++++++++++++++++++++++++++--------------------- 2 files changed, 215 insertions(+), 160 deletions(-) diff --git a/.coverage b/.coverage index 09e8343a34172aa3ca53982806eb1a61d15636a3..5d7b859a5b4294ba4300067c998729bd4a6b59ad 100644 GIT binary patch delta 148 zcmV;F0Biq%paX!Q1F$MD2QxY{HaammvoSB#P#$;y5BU%358e;c52p`s404+swW z4$=;|4wepovk?$m4wHC}s7Xf$1px_x4hI6DxBQRK`oG=oKKJ|fSNrbX_xt+y_v!Dq v?SAj8-Puxij(`1ozW3Di@49>6KXdJ$e(ukEfBxLvzWY@9 None: - self.model_cls = model_cls - self.filter_clauses = [] if filter_clauses is None else filter_clauses - self._select_related = [] if select_related is None else select_related - self.limit_count = limit_count + self.query_offset = offset + self.limit_count = limit_count + self._select_related = select_related + self.filter_clauses = filter_clauses + + self.model_cls = model_cls + self.table = self.model_cls.__table__ self.auto_related = [] self.used_aliases = [] @@ -67,16 +70,45 @@ class QuerySet: self.columns = None self.order_bys = None - def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": - return self.__class__(model_cls=owner) + def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: + self.columns = list(self.table.columns) + self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] + self.select_from = self.table - @property - def database(self) -> databases.Database: - return self.model_cls.__database__ + for key in self.model_cls.__model_fields__: + if ( + not self.model_cls.__model_fields__[key].nullable + and isinstance( + self.model_cls.__model_fields__[key], orm.fields.ForeignKey + ) + and key not in self._select_related + ): + self._select_related = [key] + self._select_related - @property - def table(self) -> sqlalchemy.Table: - return self.model_cls.__table__ + start_params = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + self._extract_auto_required_relations(prev_model=start_params.prev_model) + self._include_auto_related_models() + self._select_related.sort(key=lambda item: (-len(item), item)) + + for item in self._select_related: + join_parameters = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + + for part in item.split("__"): + join_parameters = self._build_join_parameters(part, join_parameters) + + expr = sqlalchemy.sql.select(self.columns) + expr = expr.select_from(self.select_from) + + expr = self._apply_expression_modifiers(expr) + + # print(expr.compile(compile_kwargs={"literal_binds": True})) + self._reset_query_parameters() + + return expr, self._select_related @staticmethod def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: @@ -89,6 +121,25 @@ class QuerySet: def prefixed_table_name(alias: str, name: str) -> text: return text(f"{name} {alias}_{name}") + @staticmethod + def _field_is_a_foreign_key_and_no_circular_reference( + field: BaseField, field_name: str, rel_part: str + ) -> bool: + return isinstance(field, ForeignKey) and field_name not in rel_part + + def _field_qualifies_to_deeper_search( + self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + ) -> bool: + prev_part_of_related = "__".join(rel_part.split("__")[:-1]) + partial_match = any( + [x.startswith(prev_part_of_related) for x in self._select_related] + ) + already_checked = any([x.startswith(rel_part) for x in self.auto_related]) + return ( + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested + def on_clause( self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: @@ -139,25 +190,6 @@ class QuerySet: prev_model = model_cls return JoinParameters(prev_model, previous_alias, from_table, model_cls) - @staticmethod - def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str - ) -> bool: - return isinstance(field, ForeignKey) and field_name not in rel_part - - def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str - ) -> bool: - prev_part_of_related = "__".join(rel_part.split("__")[:-1]) - partial_match = any( - [x.startswith(prev_part_of_related) for x in self._select_related] - ) - already_checked = any([x.startswith(rel_part) for x in self.auto_related]) - return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested - def _extract_auto_required_relations( self, prev_model: Type["Model"], @@ -221,130 +253,21 @@ class QuerySet: self.auto_related = [] self.used_aliases = [] - def build_select_expression(self) -> sqlalchemy.sql.select: - self.columns = list(self.table.columns) - self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] - self.select_from = self.table - for key in self.model_cls.__model_fields__: - if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], orm.fields.ForeignKey - ) - and key not in self._select_related - ): - self._select_related = [key] + self._select_related +class QueryClause: + def __init__( + self, model_cls: Type["Model"], filter_clauses: List, select_related: List, + ) -> None: - start_params = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) - self._extract_auto_required_relations(prev_model=start_params.prev_model) - self._include_auto_related_models() - self._select_related.sort(key=lambda item: (-len(item), item)) + self._select_related = select_related + self.filter_clauses = filter_clauses - for item in self._select_related: - join_parameters = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) + self.model_cls = model_cls + self.table = self.model_cls.__table__ - for part in item.split("__"): - join_parameters = self._build_join_parameters(part, join_parameters) - - expr = sqlalchemy.sql.select(self.columns) - expr = expr.select_from(self.select_from) - - expr = self._apply_expression_modifiers(expr) - - # print(expr.compile(compile_kwargs={"literal_binds": True})) - self._reset_query_parameters() - - return expr - - def _determine_filter_target_table( - self, related_parts: List[str], select_related: List[str] - ) -> Tuple[List[str], str, "Model"]: - - table_prefix = "" - model_cls = self.model_cls - select_related = [relation for relation in select_related] - - # Add any implied select_related - related_str = "__".join(related_parts) - if related_str not in select_related: - select_related.append(related_str) - - # Walk the relationships to the actual model class - # against which the comparison is being made. - previous_table = model_cls.__tablename__ - for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - manager = model_cls._orm_relationship_manager - table_prefix = manager.resolve_relation_join(previous_table, current_table) - model_cls = model_cls.__model_fields__[part].to - previous_table = current_table - return select_related, table_prefix, model_cls - - def _compile_clause( - self, - clause: sqlalchemy.sql.expression.BinaryExpression, - column: sqlalchemy.Column, - table: sqlalchemy.Table, - table_prefix: str, - modifiers: Dict, - ) -> sqlalchemy.sql.expression.TextClause: - for modifier, modifier_value in modifiers.items(): - clause.modifiers[modifier] = modifier_value - - clause_text = str( - clause.compile( - dialect=self.model_cls.__database__._backend._dialect, - compile_kwargs={"literal_binds": True}, - ) - ) - alias = f"{table_prefix}_" if table_prefix else "" - aliased_name = f"{alias}{table.name}.{column.name}" - clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) - clause = text(clause_text) - return clause - - def _escape_characters_in_clause( - self, op: str, value: Union[str, "Model"] - ) -> Tuple[str, bool]: - has_escaped_character = False - - if op in ["contains", "icontains"]: - if isinstance(value, orm.Model): - raise QueryDefinitionError( - "You cannot use contains and icontains with instance of the Model" - ) - - has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in value) - - if has_escaped_character: - # enable escape modifier - for char in self.ESCAPE_CHARACTERS: - value = value.replace(char, f"\\{char}") - value = f"%{value}%" - - return value, has_escaped_character - - @staticmethod - def _extract_operator_field_and_related( - parts: List[str], - ) -> Tuple[str, str, Optional[List]]: - if parts[-1] in FILTER_OPERATORS: - op = parts[-1] - field_name = parts[-2] - related_parts = parts[:-2] - else: - op = "exact" - field_name = parts[-1] - related_parts = parts[:-1] - - return op, field_name, related_parts - - def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + def filter( # noqa: A003 + self, **kwargs: Any + ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: filter_clauses = self.filter_clauses select_related = list(self._select_related) @@ -395,9 +318,141 @@ class QuerySet: table_prefix, modifiers={"escape": "\\" if has_escaped_character else None}, ) - filter_clauses.append(clause) + return filter_clauses, select_related + + def _determine_filter_target_table( + self, related_parts: List[str], select_related: List[str] + ) -> Tuple[List[str], str, "Model"]: + + table_prefix = "" + model_cls = self.model_cls + select_related = [relation for relation in select_related] + + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + previous_table = model_cls.__tablename__ + for part in related_parts: + current_table = model_cls.__model_fields__[part].to.__tablename__ + manager = model_cls._orm_relationship_manager + table_prefix = manager.resolve_relation_join(previous_table, current_table) + model_cls = model_cls.__model_fields__[part].to + previous_table = current_table + return select_related, table_prefix, model_cls + + def _compile_clause( + self, + clause: sqlalchemy.sql.expression.BinaryExpression, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + modifiers: Dict, + ) -> sqlalchemy.sql.expression.TextClause: + for modifier, modifier_value in modifiers.items(): + clause.modifiers[modifier] = modifier_value + + clause_text = str( + clause.compile( + dialect=self.model_cls.__database__._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{table_prefix}_" if table_prefix else "" + aliased_name = f"{alias}{table.name}.{column.name}" + clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) + clause = text(clause_text) + return clause + + @staticmethod + def _escape_characters_in_clause( + op: str, value: Union[str, "Model"] + ) -> Tuple[str, bool]: + has_escaped_character = False + + if op in ["contains", "icontains"]: + if isinstance(value, orm.Model): + raise QueryDefinitionError( + "You cannot use contains and icontains with instance of the Model" + ) + + has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) + + if has_escaped_character: + # enable escape modifier + for char in ESCAPE_CHARACTERS: + value = value.replace(char, f"\\{char}") + value = f"%{value}%" + + return value, has_escaped_character + + @staticmethod + def _extract_operator_field_and_related( + parts: List[str], + ) -> Tuple[str, str, Optional[List]]: + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + return op, field_name, related_parts + + +class QuerySet: + def __init__( + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + ) -> None: + self.model_cls = model_cls + self.filter_clauses = [] if filter_clauses is None else filter_clauses + self._select_related = [] if select_related is None else select_related + self.limit_count = limit_count + self.query_offset = offset + self.order_bys = None + + def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": + return self.__class__(model_cls=owner) + + @property + def database(self) -> databases.Database: + return self.model_cls.__database__ + + @property + def table(self) -> sqlalchemy.Table: + return self.model_cls.__table__ + + def build_select_expression(self) -> sqlalchemy.sql.select: + qry = Query( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + offset=self.query_offset, + limit_count=self.limit_count, + ) + exp, self._select_related = qry.build_select_expression() + return exp + + def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + qryclause = QueryClause( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + ) + filter_clauses, select_related = qryclause.filter(**kwargs) + return self.__class__( model_cls=self.model_cls, filter_clauses=filter_clauses,