From 79ad69e355543631e28452a03eaedc866fd31ea8 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 7 Mar 2021 12:50:40 +0100 Subject: [PATCH] check complex prefixes in groups, refactor limit queries, finish docstrings, refactors and cleanup in long methods --- ormar/models/helpers/sqlalchemy.py | 53 +++++++- ormar/models/mixins/relation_mixin.py | 41 ++++++- ormar/models/model_row.py | 155 ++++++++++++++++++------ ormar/models/newbasemodel.py | 2 +- ormar/queryset/actions/filter_action.py | 3 - ormar/queryset/clause.py | 95 +++++++++++---- ormar/queryset/query.py | 48 +++----- ormar/queryset/queryset.py | 43 +++++-- ormar/queryset/utils.py | 58 +++++++-- tests/test_filter_groups.py | 8 -- tests/test_more_same_table_joins.py | 42 +++++++ tests/test_or_filters.py | 56 ++++++++- 12 files changed, 468 insertions(+), 136 deletions(-) diff --git a/ormar/models/helpers/sqlalchemy.py b/ormar/models/helpers/sqlalchemy.py index 3475c66..e40239d 100644 --- a/ormar/models/helpers/sqlalchemy.py +++ b/ormar/models/helpers/sqlalchemy.py @@ -150,19 +150,68 @@ def sqlalchemy_columns_from_model_fields( "Integer primary key named `id` created." ) validate_related_names_in_relations(model_fields, new_model) + return _process_fields(model_fields=model_fields, new_model=new_model) + + +def _process_fields( + model_fields: Dict, new_model: Type["Model"] +) -> Tuple[Optional[str], List[sqlalchemy.Column]]: + """ + Helper method. + + Populates pkname and columns. + Trigger validation of primary_key - only one and required pk can be set, + cannot be pydantic_only. + + Append fields to columns if it's not pydantic_only, + virtual ForeignKey or ManyToMany field. + + Sets `owner` on each model_field as reference to newly created Model. + + :raises ModelDefinitionError: if validation of related_names fail, + or pkname validation fails. + :param model_fields: dictionary of declared ormar model fields + :type model_fields: Dict[str, ormar.Field] + :param new_model: + :type new_model: Model class + :return: pkname, list of sqlalchemy columns + :rtype: Tuple[Optional[str], List[sqlalchemy.Column]] + """ columns = [] pkname = None for field_name, field in model_fields.items(): field.owner = new_model - if field.is_multi and not field.through: + if _is_through_model_not_set(field): field.create_default_through_model() if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) - if not field.pydantic_only and not field.virtual and not field.is_multi: + if _is_db_field(field): columns.append(field.get_column(field.get_alias())) return pkname, columns +def _is_through_model_not_set(field: Type["BaseField"]) -> bool: + """ + Alias to if check that verifies if through model was created. + :param field: field to check + :type field: Type["BaseField"] + :return: result of the check + :rtype: bool + """ + return field.is_multi and not field.through + + +def _is_db_field(field: Type["BaseField"]) -> bool: + """ + Alias to if check that verifies if field should be included in database. + :param field: field to check + :type field: Type["BaseField"] + :return: result of the check + :rtype: bool + """ + return not field.pydantic_only and not field.virtual and not field.is_multi + + def populate_meta_tablename_columns_and_pk( name: str, new_model: Type["Model"] ) -> Type["Model"]: diff --git a/ormar/models/mixins/relation_mixin.py b/ormar/models/mixins/relation_mixin.py index 53fab3c..8955acf 100644 --- a/ormar/models/mixins/relation_mixin.py +++ b/ormar/models/mixins/relation_mixin.py @@ -129,7 +129,7 @@ class RelationMixin: return related_names @classmethod - def _iterate_related_models( + def _iterate_related_models( # noqa: CCR001 cls, visited: Set[str] = None, source_visited: Set[str] = None, @@ -149,14 +149,12 @@ class RelationMixin: :return: list of relation strings to be passed to select_related :rtype: List[str] """ - source_visited = source_visited or set() - if not source_model: - source_visited = cls._populate_source_model_prefixes() + source_visited = source_visited or cls._populate_source_model_prefixes() relations = cls.extract_related_names() processed_relations = [] for relation in relations: target_model = cls.Meta.model_fields[relation].to - if source_model and target_model == source_model: + if cls._is_reverse_side_of_same_relation(source_model, target_model): continue if target_model not in source_visited or not source_model: deep_relations = target_model._iterate_related_models( @@ -168,6 +166,39 @@ class RelationMixin: processed_relations.extend(deep_relations) else: processed_relations.append(relation) + + return cls._get_final_relations(processed_relations, source_relation) + + @staticmethod + def _is_reverse_side_of_same_relation( + source_model: Optional[Union[Type["Model"], Type["RelationMixin"]]], + target_model: Type["Model"], + ) -> bool: + """ + Alias to check if source model is the same as target + :param source_model: source model - relation comes from it + :type source_model: Type["Model"] + :param target_model: target model - relation leads to it + :type target_model: Type["Model"] + :return: result of the check + :rtype: bool + """ + return bool(source_model and target_model == source_model) + + @staticmethod + def _get_final_relations( + processed_relations: List, source_relation: Optional[str] + ) -> List[str]: + """ + Helper method to prefix nested relation strings with current source relation + + :param processed_relations: list of already processed relation str + :type processed_relations: List[str] + :param source_relation: name of the current relation + :type source_relation: str + :return: list of relation strings to be passed to select_related + :rtype: List[str] + """ if processed_relations: final_relations = [ f"{source_relation + '__' if source_relation else ''}{relation}" diff --git a/ormar/models/model_row.py b/ormar/models/model_row.py index d9c674d..2cd488a 100644 --- a/ormar/models/model_row.py +++ b/ormar/models/model_row.py @@ -4,7 +4,9 @@ from typing import ( List, Optional, TYPE_CHECKING, + Tuple, Type, + Union, cast, ) @@ -78,21 +80,12 @@ class ModelRow(NewBaseModel): related_models = group_related_list(select_related) if related_field: - if related_field.is_multi: - previous_model = related_field.through - else: - previous_model = related_field.owner - table_prefix = cls.Meta.alias_manager.resolve_relation_alias( - from_model=previous_model, relation_name=related_field.name + table_prefix = cls._process_table_prefix( + source_model=source_model, + current_relation_str=current_relation_str, + related_field=related_field, + used_prefixes=used_prefixes, ) - if not table_prefix or table_prefix in used_prefixes: - manager = cls.Meta.alias_manager - table_prefix = manager.resolve_relation_alias_after_complex( - source_model=source_model, - relation_str=current_relation_str, - relation_field=related_field, - ) - used_prefixes.append(table_prefix) item = cls._populate_nested_models_from_row( item=item, @@ -118,6 +111,44 @@ class ModelRow(NewBaseModel): instance.set_save_status(True) return instance + @classmethod + def _process_table_prefix( + cls, + source_model: Type["Model"], + current_relation_str: str, + related_field: Type["ForeignKeyField"], + used_prefixes: List[str], + ) -> str: + """ + + :param source_model: model on which relation was defined + :type source_model: Type[Model] + :param current_relation_str: current relation string + :type current_relation_str: str + :param related_field: field with relation declaration + :type related_field: Type["ForeignKeyField"] + :param used_prefixes: list of already extracted prefixes + :type used_prefixes: List[str] + :return: table_prefix to use + :rtype: str + """ + if related_field.is_multi: + previous_model = related_field.through + else: + previous_model = related_field.owner + table_prefix = cls.Meta.alias_manager.resolve_relation_alias( + from_model=previous_model, relation_name=related_field.name + ) + if not table_prefix or table_prefix in used_prefixes: + manager = cls.Meta.alias_manager + table_prefix = manager.resolve_relation_alias_after_complex( + source_model=source_model, + relation_str=current_relation_str, + relation_field=related_field, + ) + used_prefixes.append(table_prefix) + return table_prefix + @classmethod def _populate_nested_models_from_row( # noqa: CFQ002 cls, @@ -170,14 +201,11 @@ class ModelRow(NewBaseModel): if model_excludable.is_excluded(related): return item - relation_str = ( - "__".join([current_relation_str, related]) - if current_relation_str - else related + relation_str, remainder = cls._process_remainder_and_relation_string( + related_models=related_models, + current_relation_str=current_relation_str, + related=related, ) - remainder = None - if isinstance(related_models, dict) and related_models[related]: - remainder = related_models[related] child = model_cls.from_row( row, related_models=remainder, @@ -190,24 +218,84 @@ class ModelRow(NewBaseModel): ) item[model_cls.get_column_name_from_alias(related)] = child if field.is_multi and child: - through_name = cls.Meta.model_fields[related].through.get_name() - through_child = cls.populate_through_instance( + cls._populate_through_instance( row=row, + item=item, related=related, - through_name=through_name, excludable=excludable, + child=child, + proxy_source_model=proxy_source_model, ) - if child.__class__ != proxy_source_model: - setattr(child, through_name, through_child) - else: - item[through_name] = through_child - child.set_save_status(True) - return item + @staticmethod + def _process_remainder_and_relation_string( + related_models: Union[Dict, List], + current_relation_str: Optional[str], + related: str, + ) -> Tuple[str, Optional[Union[Dict, List]]]: + """ + Process remainder models and relation string + + :param related_models: list or dict of related models + :type related_models: Union[Dict, List] + :param current_relation_str: current relation string + :type current_relation_str: Optional[str] + :param related: name of the relation + :type related: str + """ + relation_str = ( + "__".join([current_relation_str, related]) + if current_relation_str + else related + ) + + remainder = None + if isinstance(related_models, dict) and related_models[related]: + remainder = related_models[related] + return relation_str, remainder + @classmethod - def populate_through_instance( + def _populate_through_instance( # noqa: CFQ002 + cls, + row: sqlalchemy.engine.ResultProxy, + item: Dict, + related: str, + excludable: ExcludableItems, + child: "Model", + proxy_source_model: Optional[Type["Model"]], + ) -> None: + """ + Populates the through model on reverse side of current query. + Normally it's child class, unless the query is from queryset. + + :param row: row from db result + :type row: sqlalchemy.engine.ResultProxy + :param item: parent item dict + :type item: Dict + :param related: current relation name + :type related: str + :param excludable: structure of fields to include and exclude + :type excludable: ExcludableItems + :param child: child item of parent + :type child: "Model" + :param proxy_source_model: source model from which querysetproxy is constructed + :type proxy_source_model: Type["Model"] + """ + through_name = cls.Meta.model_fields[related].through.get_name() + through_child = cls._create_through_instance( + row=row, related=related, through_name=through_name, excludable=excludable, + ) + + if child.__class__ != proxy_source_model: + setattr(child, through_name, through_child) + else: + item[through_name] = through_child + child.set_save_status(True) + + @classmethod + def _create_through_instance( cls, row: sqlalchemy.engine.ResultProxy, through_name: str, @@ -288,12 +376,11 @@ class ModelRow(NewBaseModel): model=cls, excludable=excludable, alias=table_prefix, use_alias=False, ) + column_prefix = table_prefix + "_" if table_prefix else "" for column in cls.Meta.table.columns: alias = cls.get_column_name_from_alias(column.name) if alias not in item and alias in selected_columns: - prefixed_name = ( - f'{table_prefix + "_" if table_prefix else ""}{column.name}' - ) + prefixed_name = f"{column_prefix}{column.name}" item[alias] = source[prefixed_name] return item diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 9308d83..8ffafc7 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -227,7 +227,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass super().__setattr__(name, value) self.set_save_status(False) - def __getattribute__(self, item: str) -> Any: + def __getattribute__(self, item: str) -> Any: # noqa: CCR001 """ Because we need to overwrite getting the attribute by ormar instead of pydantic as well as returning related models and not the value stored on the model the diff --git a/ormar/queryset/actions/filter_action.py b/ormar/queryset/actions/filter_action.py index 7b3bb9e..279e0fa 100644 --- a/ormar/queryset/actions/filter_action.py +++ b/ormar/queryset/actions/filter_action.py @@ -43,9 +43,6 @@ class FilterAction(QueryAction): super().__init__(query_str=filter_str, model_cls=model_cls) self.filter_value = value self._escape_characters_in_clause() - self.is_source_model_filter = False - if self.source_model == self.target_model and "__" not in self.related_str: - self.is_source_model_filter = True def has_escaped_characters(self) -> bool: """Check if value is a string that contains characters to escape""" diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 4a1ef97..504a2eb 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -19,6 +19,11 @@ class FilterType(Enum): class FilterGroup: + """ + Filter groups are used in complex queries condition to group and and or + clauses in where condition + """ + def __init__( self, *args: Any, _filter_type: FilterType = FilterType.AND, **kwargs: Any, ) -> None: @@ -36,6 +41,19 @@ class FilterGroup: select_related: List = None, filter_clauses: List = None, ) -> Tuple[List[FilterAction], List[str]]: + """ + Resolves the FilterGroups actions to use proper target model, replace + complex relation prefixes if needed and nested groups also resolved. + + :param model_cls: model from which the query is run + :type model_cls: Type["Model"] + :param select_related: list of models to join + :type select_related: List[str] + :param filter_clauses: list of filter conditions + :type filter_clauses: List[FilterAction] + :return: list of filter conditions and select_related list + :rtype: Tuple[List[FilterAction], List[str]] + """ select_related = select_related if select_related is not None else [] filter_clauses = filter_clauses if filter_clauses is not None else [] qryclause = QueryClause( @@ -51,42 +69,44 @@ class FilterGroup: self._resolved = True if self._nested_groups: for group in self._nested_groups: - if not group._resolved: - (filter_clauses, select_related) = group.resolve( - model_cls=model_cls, - select_related=select_related, - filter_clauses=filter_clauses, - ) - self._is_self_model_group() + (filter_clauses, select_related) = group.resolve( + model_cls=model_cls, + select_related=select_related, + filter_clauses=filter_clauses, + ) return filter_clauses, select_related def _iter(self) -> Generator: - if not self._nested_groups: - yield from self.actions - return + """ + Iterates all actions in a tree + :return: generator yielding from own actions and nested groups + :rtype: Generator + """ for group in self._nested_groups: yield from group._iter() yield from self.actions - def _is_self_model_group(self) -> None: - if self.actions and self._nested_groups: - if all([action.is_source_model_filter for action in self.actions]) and all( - group.is_source_model_filter for group in self._nested_groups - ): - self.is_source_model_filter = True - elif self.actions: - if all([action.is_source_model_filter for action in self.actions]): - self.is_source_model_filter = True - else: - if all(group.is_source_model_filter for group in self._nested_groups): - self.is_source_model_filter = True - def _get_text_clauses(self) -> List[sqlalchemy.sql.expression.TextClause]: + """ + Helper to return list of text queries from actions and nested groups + :return: list of text queries from actions and nested groups + :rtype: List[sqlalchemy.sql.elements.TextClause] + """ return [x.get_text_clause() for x in self._nested_groups] + [ x.get_text_clause() for x in self.actions ] def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause: + """ + Returns all own actions and nested groups conditions compiled and joined + inside parentheses. + Escapes characters if it's required. + Substitutes values of the models if value is a ormar Model with its pk value. + Compiles the clause. + + :return: complied and escaped clause + :rtype: sqlalchemy.sql.elements.TextClause + """ if self.filter_type == FilterType.AND: clause = sqlalchemy.text( "( " + str(sqlalchemy.sql.and_(*self._get_text_clauses())) + " )" @@ -98,11 +118,31 @@ class FilterGroup: return clause -def or_(*args: Any, **kwargs: Any) -> FilterGroup: +def or_(*args: FilterGroup, **kwargs: Any) -> FilterGroup: + """ + Construct or filter from nested groups and keyword arguments + + :param args: nested filter groups + :type args: Tuple[FilterGroup] + :param kwargs: fields names and proper value types + :type kwargs: Any + :return: FilterGroup ready to be resolved + :rtype: ormar.queryset.clause.FilterGroup + """ return FilterGroup(_filter_type=FilterType.OR, *args, **kwargs) -def and_(*args: Any, **kwargs: Any) -> FilterGroup: +def and_(*args: FilterGroup, **kwargs: Any) -> FilterGroup: + """ + Construct and filter from nested groups and keyword arguments + + :param args: nested filter groups + :type args: Tuple[FilterGroup] + :param kwargs: fields names and proper value types + :type kwargs: Any + :return: FilterGroup ready to be resolved + :rtype: ormar.queryset.clause.FilterGroup + """ return FilterGroup(_filter_type=FilterType.AND, *args, **kwargs) @@ -263,6 +303,11 @@ class QueryClause: return filter_clauses def _verify_prefix_and_switch(self, action: "FilterAction") -> None: + """ + Helper to switch prefix to complex relation one if required + :param action: action to switch prefix in + :type action: ormar.queryset.actions.filter_action.FilterAction + """ manager = self.model_cls.Meta.alias_manager new_alias = manager.resolve_relation_alias(self.model_cls, action.related_str) if "__" in action.related_str and new_alias: diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 0987bac..be46e74 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -108,10 +108,7 @@ class Query: "", self.table, self_related_fields ) self.apply_order_bys_for_primary_model() - if self._pagination_query_required(): - self.select_from = self._build_pagination_subquery() - else: - self.select_from = self.table + self.select_from = self.table related_models = group_related_list(self._select_related) @@ -139,6 +136,9 @@ class Query: self.sorted_orders, ) = sql_join.build_join() + if self._pagination_query_required(): + self._build_pagination_condition() + expr = sqlalchemy.sql.select(self.columns) expr = expr.select_from(self.select_from) @@ -149,7 +149,7 @@ class Query: return expr - def _build_pagination_subquery(self) -> sqlalchemy.sql.select: + def _build_pagination_condition(self) -> None: """ In order to apply limit and offset on main table in join only (otherwise you can get only partially constructed main model @@ -160,32 +160,20 @@ class Query: and query has select_related applied. Otherwise we can limit/offset normally at the end of whole query. - :return: constructed subquery on main table with limit, offset and order applied - :rtype: sqlalchemy.sql.select + The condition is added to filters to filter out desired number of main model + primary key values. Whole query is used to determine the values. """ - expr = sqlalchemy.sql.select(self.model_cls.Meta.table.columns) - expr = LimitQuery(limit_count=self.limit_count).apply(expr) - expr = OffsetQuery(query_offset=self.query_offset).apply(expr) - filters_to_use = [ - filter_clause - for filter_clause in self.filter_clauses - if filter_clause.is_source_model_filter - ] - excludes_to_use = [ - filter_clause - for filter_clause in self.exclude_clauses - if filter_clause.is_source_model_filter - ] - sorts_to_use = { - k: v for k, v in self.sorted_orders.items() if k.is_source_model_order - } - expr = FilterQuery(filter_clauses=filters_to_use).apply(expr) - expr = FilterQuery(filter_clauses=excludes_to_use, exclude=True).apply(expr) - expr = OrderQuery(sorted_orders=sorts_to_use).apply(expr) - expr = expr.alias(f"{self.table}") - self.filter_clauses = list(set(self.filter_clauses) - set(filters_to_use)) - self.exclude_clauses = list(set(self.exclude_clauses) - set(excludes_to_use)) - return expr + pk_alias = self.model_cls.get_column_alias(self.model_cls.Meta.pkname) + qry_text = sqlalchemy.text(f"distinct {self.table.name}.{pk_alias}") + limit_qry = sqlalchemy.sql.select([qry_text]) + limit_qry = limit_qry.select_from(self.select_from) + limit_qry = self._apply_expression_modifiers(limit_qry) + limit_qry = LimitQuery(limit_count=self.limit_count).apply(limit_qry) + limit_qry = OffsetQuery(query_offset=self.query_offset).apply(limit_qry) + limit_action = FilterAction( + filter_str=f"{pk_alias}__in", value=limit_qry, model_cls=self.model_cls + ) + self.filter_clauses.append(limit_action) def _apply_expression_modifiers( self, expr: sqlalchemy.sql.select diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 944550e..e84b67d 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -20,7 +20,7 @@ from ormar import MultipleMatches, NoMatch from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError from ormar.queryset import FilterQuery from ormar.queryset.actions.order_action import OrderAction -from ormar.queryset.clause import QueryClause +from ormar.queryset.clause import FilterGroup, QueryClause from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.query import Query @@ -192,6 +192,34 @@ class QuerySet: return self.model.merge_instances_list(result_rows) # type: ignore return result_rows + def _resolve_filter_groups(self, groups: Any) -> List[FilterGroup]: + """ + Resolves filter groups to populate FilterAction params in group tree. + + :param groups: tuple of FilterGroups + :type groups: Any + :return: list of resolver groups + :rtype: List[FilterGroup] + """ + filter_groups = [] + if groups: + for group in groups: + if not isinstance(group, FilterGroup): + raise QueryDefinitionError( + "Only ormar.and_ and ormar.or_ " + "can be passed as filter positional" + " arguments," + "other values need to be passed by" + "keyword arguments" + ) + group.resolve( + model_cls=self.model, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + ) + filter_groups.append(group) + return filter_groups + @staticmethod def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None: """ @@ -288,23 +316,14 @@ class QuerySet: :return: filtered QuerySet :rtype: QuerySet """ - filter_groups = [] - if args: - for arg in args: - arg.resolve( - model_cls=self.model, - select_related=self._select_related, - filter_clauses=self.filter_clauses, - ) - filter_groups.append(arg) - + filter_groups = self._resolve_filter_groups(groups=args) qryclause = QueryClause( model_cls=self.model, select_related=self._select_related, filter_clauses=self.filter_clauses, ) filter_clauses, select_related = qryclause.prepare_filter(**kwargs) - filter_clauses = filter_clauses + filter_groups + filter_clauses = filter_clauses + filter_groups # type: ignore if _exclude: exclude_clauses = filter_clauses filter_clauses = self.filter_clauses diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index ca3358d..64e8af2 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -4,6 +4,7 @@ from typing import ( Any, Dict, List, + Optional, Sequence, Set, TYPE_CHECKING, @@ -13,7 +14,7 @@ from typing import ( ) if TYPE_CHECKING: # pragma no cover - from ormar import Model + from ormar import Model, BaseField def check_node_not_dict_or_not_last_node( @@ -238,18 +239,13 @@ def get_relationship_alias_model_and_str( related_field = target_model.Meta.model_fields[relation] if related_field.is_through: - # through is always last - cannot go further - is_through = True - related_parts.remove(relation) - through_field = related_field.owner.Meta.model_fields[ - related_field.related_name or "" - ] - if len(previous_models) > 1 and previous_models[-2] == through_field.to: - previous_model = through_field.to - relation = through_field.related_name - else: - relation = related_field.related_name - + (previous_model, relation, is_through) = _process_through_field( + related_parts=related_parts, + relation=relation, + related_field=related_field, + previous_model=previous_model, + previous_models=previous_models, + ) if related_field.is_multi: previous_model = related_field.through relation = related_field.default_target_field_name() # type: ignore @@ -263,3 +259,39 @@ def get_relationship_alias_model_and_str( relation_str = "__".join(related_parts) return table_prefix, target_model, relation_str, is_through + + +def _process_through_field( + related_parts: List, + relation: Optional[str], + related_field: Type["BaseField"], + previous_model: Type["Model"], + previous_models: List[Type["Model"]], +) -> Tuple[Type["Model"], Optional[str], bool]: + """ + Helper processing through models as they need to be treated differently. + + :param related_parts: split relation string + :type related_parts: List[str] + :param relation: relation name + :type relation: str + :param related_field: field with relation declaration + :type related_field: Type["ForeignKeyField"] + :param previous_model: model from which relation is coming + :type previous_model: Type["Model"] + :param previous_models: list of already visited models in relation chain + :type previous_models: List[Type["Model"]] + :return: previous_model, relation, is_through + :rtype: Tuple[Type["Model"], str, bool] + """ + is_through = True + related_parts.remove(relation) + through_field = related_field.owner.Meta.model_fields[ + related_field.related_name or "" + ] + if len(previous_models) > 1 and previous_models[-2] == through_field.to: + previous_model = through_field.to + relation = through_field.related_name + else: + relation = related_field.related_name + return previous_model, relation, is_through diff --git a/tests/test_filter_groups.py b/tests/test_filter_groups.py index ec6c4a0..21581b1 100644 --- a/tests/test_filter_groups.py +++ b/tests/test_filter_groups.py @@ -44,7 +44,6 @@ def test_or_group(): f"{result.actions[1].table_prefix}" f"_books.title = 'bb' )" ) - assert not result.is_source_model_filter def test_and_group(): @@ -58,7 +57,6 @@ def test_and_group(): f"{result.actions[1].table_prefix}" f"_books.title = 'bb' )" ) - assert not result.is_source_model_filter def test_nested_and(): @@ -77,7 +75,6 @@ def test_nested_and(): f"{book_prefix}" f"_books.title = 'dd' ) )" ) - assert not result.is_source_model_filter def test_nested_group_and_action(): @@ -93,7 +90,6 @@ def test_nested_group_and_action(): f"{book_prefix}" f"_books.title = 'dd' )" ) - assert not result.is_source_model_filter def test_deeply_nested_or(): @@ -120,7 +116,6 @@ def test_deeply_nested_or(): f"( {book_prefix}_books.year > 'xx' OR {book_prefix}_books.title = '22' ) ) )" ) assert result_qry.replace("\n", "") == expected_qry.replace("\n", "") - assert not result.is_source_model_filter def test_one_model_group(): @@ -128,7 +123,6 @@ def test_one_model_group(): result.resolve(model_cls=Book) assert len(result.actions) == 2 assert len(result._nested_groups) == 0 - assert result.is_source_model_filter def test_one_model_nested_group(): @@ -138,7 +132,6 @@ def test_one_model_nested_group(): result.resolve(model_cls=Book) assert len(result.actions) == 0 assert len(result._nested_groups) == 2 - assert result.is_source_model_filter def test_one_model_with_group(): @@ -146,4 +139,3 @@ def test_one_model_with_group(): result.resolve(model_cls=Book) assert len(result.actions) == 1 assert len(result._nested_groups) == 1 - assert result.is_source_model_filter diff --git a/tests/test_more_same_table_joins.py b/tests/test_more_same_table_joins.py index b991d13..13bab15 100644 --- a/tests/test_more_same_table_joins.py +++ b/tests/test_more_same_table_joins.py @@ -122,3 +122,45 @@ async def test_load_all_multiple_instances_of_same_table_in_schema(): assert len(math_class.dict().get("students")) == 2 assert math_class.teachers[0].category.department.name == "Law Department" assert math_class.students[0].category.department.name == "Math Department" + + +@pytest.mark.asyncio +async def test_filter_groups_with_instances_of_same_table_in_schema(): + async with database: + await create_data() + math_class = ( + await SchoolClass.objects.select_related( + ["teachers__category__department", "students__category__department"] + ) + .filter( + ormar.or_( + students__name="Jane", + teachers__category__name="Domestic", + students__category__name="Foreign", + ) + ) + .get(name="Math") + ) + assert math_class.name == "Math" + assert math_class.students[0].name == "Jane" + assert len(math_class.dict().get("students")) == 2 + assert math_class.teachers[0].category.department.name == "Law Department" + assert math_class.students[0].category.department.name == "Math Department" + + classes = ( + await SchoolClass.objects.select_related( + ["students__category__department", "teachers__category__department"] + ) + .filter( + ormar.and_( + ormar.or_( + students__name="Jane", students__category__name="Foreign" + ), + teachers__category__department__name="Law Department", + ) + ) + .all() + ) + assert len(classes) == 1 + assert classes[0].teachers[0].category.department.name == "Law Department" + assert classes[0].students[0].category.department.name == "Math Department" diff --git a/tests/test_or_filters.py b/tests/test_or_filters.py index b3300d1..2a33dc7 100644 --- a/tests/test_or_filters.py +++ b/tests/test_or_filters.py @@ -5,6 +5,7 @@ import pytest import sqlalchemy import ormar +from ormar.exceptions import QueryDefinitionError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL) @@ -108,11 +109,60 @@ async def test_or_filters(): assert len(books) == 3 assert not any([x.title in ["The Silmarillion", "The Witcher"] for x in books]) + books = ( + await Book.objects.select_related("author") + .filter(ormar.or_(year__gt=1980, year__lt=1910)) + .filter(title__startswith="The") + .limit(1) + .all() + ) + assert len(books) == 1 + assert books[0].title == "The Witcher" + + books = ( + await Book.objects.select_related("author") + .filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski")) + .filter(title__startswith="The") + .limit(1) + .all() + ) + assert len(books) == 1 + assert books[0].title == "The Witcher" + + books = ( + await Book.objects.select_related("author") + .filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski")) + .filter(title__startswith="The") + .limit(1) + .offset(1) + .all() + ) + assert len(books) == 1 + assert books[0].title == "The Tower of Fools" + + books = ( + await Book.objects.select_related("author") + .filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski")) + .filter(title__startswith="The") + .limit(1) + .offset(1) + .order_by("-id") + .all() + ) + assert len(books) == 1 + assert books[0].title == "The Witcher" + + with pytest.raises(QueryDefinitionError): + await Book.objects.select_related("author").filter('wrong').all() + + # TODO: Check / modify # process and and or into filter groups (V) # check exclude queries working (V) +# check complex prefixes properly resolved (V) +# fix limit -> change to where subquery to extract number of distinct pk values (V) +# finish docstrings (V) +# fix types for FilterAction and FilterGroup (X) -# when limit and no sql do not allow main model and other models -# check complex prefixes properly resolved -# fix types for FilterAction and FilterGroup (?) +# add docs