From d6e2c85b793ba62c9ae5181ee666ebb261323597 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 17 Jan 2021 12:29:21 +0100 Subject: [PATCH] refactors in join to register complex aliases on duplicate, to do is doing the same in filter clauses --- ormar/models/helpers/models.py | 25 ++++++---- ormar/queryset/clause.py | 63 +++++++++++++----------- ormar/queryset/join.py | 82 +++++++++++++++++++------------- ormar/queryset/query.py | 1 + ormar/queryset/queryset.py | 3 ++ ormar/relations/alias_manager.py | 6 +++ tests/test_docs/__init__.py | 0 tests/test_forward_refs.py | 22 +++++++-- tests/test_models_helpers.py | 9 ++++ 9 files changed, 137 insertions(+), 74 deletions(-) delete mode 100644 tests/test_docs/__init__.py create mode 100644 tests/test_models_helpers.py diff --git a/ormar/models/helpers/models.py b/ormar/models/helpers/models.py index 6cf1f12..75282d6 100644 --- a/ormar/models/helpers/models.py +++ b/ormar/models/helpers/models.py @@ -22,12 +22,12 @@ def is_field_an_forward_ref(field: Type["BaseField"]) -> bool: :rtype: bool """ return issubclass(field, ForeignKeyField) and ( - field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef + field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef ) def populate_default_options_values( - new_model: Type["Model"], model_fields: Dict + new_model: Type["Model"], model_fields: Dict ) -> None: """ Sets all optional Meta values to it's defaults @@ -52,7 +52,8 @@ def populate_default_options_values( new_model.Meta.abstract = False if any( - is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values() + is_field_an_forward_ref(field) for field in + new_model.Meta.model_fields.values() ): new_model.Meta.requires_ref_update = True else: @@ -77,7 +78,7 @@ def extract_annotations_and_default_vals(attrs: Dict) -> Tuple[Dict, Dict]: # cannot be in relations helpers due to cyclical import def validate_related_names_in_relations( # noqa CCR001 - model_fields: Dict, new_model: Type["Model"] + model_fields: Dict, new_model: Type["Model"] ) -> None: """ Performs a validation of relation_names in relation fields. @@ -122,20 +123,24 @@ def group_related_list(list_: List) -> Dict: will become: {'people': {'houses': [], 'cars': ['models', 'colors']}} + Result dictionary is sorted by length of the values and by key + :param list_: list of related models used in select related :type list_: List[str] :return: list converted to dictionary to avoid repetition and group nested models :rtype: Dict[str, List] """ - test_dict: Dict[str, Any] = dict() + result_dict: Dict[str, Any] = dict() + list_.sort(key=lambda x: x.split("__")[0]) grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0]) for key, group in grouped: group_list = list(group) - new = [ + new = sorted([ "__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1 - ] + ]) if any("__" in x for x in new): - test_dict[key] = group_related_list(new) + result_dict[key] = group_related_list(new) else: - test_dict[key] = new - return test_dict + result_dict.setdefault(key, []).extend(new) + return {k: v for k, v in + sorted(result_dict.items(), key=lambda item: len(item[1]))} diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 514fdf1..3985b5e 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -34,7 +34,7 @@ class QueryClause: """ def __init__( - self, model_cls: Type["Model"], filter_clauses: List, select_related: List, + self, model_cls: Type["Model"], filter_clauses: List, select_related: List, ) -> None: self._select_related = select_related[:] @@ -44,7 +44,7 @@ class QueryClause: self.table = self.model_cls.Meta.table def filter( # noqa: A003 - self, **kwargs: Any + self, **kwargs: Any ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: """ Main external access point that processes the clauses into sqlalchemy text @@ -65,7 +65,7 @@ class QueryClause: return filter_clauses, select_related def _populate_filter_clauses( - self, **kwargs: Any + self, **kwargs: Any ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: """ Iterates all clauses and extracts used operator and field from related @@ -98,7 +98,9 @@ class QueryClause: table_prefix, model_cls, ) = self._determine_filter_target_table( - related_parts, select_related + related_parts=related_parts, + select_related=select_related, + field_name=field_name ) table = model_cls.Meta.table @@ -116,12 +118,12 @@ class QueryClause: return filter_clauses, select_related def _process_column_clause_for_operator_and_value( - self, - value: Any, - op: str, - column: sqlalchemy.Column, - table: sqlalchemy.Table, - table_prefix: str, + self, + value: Any, + op: str, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, ) -> sqlalchemy.sql.expression.TextClause: """ Escapes characters if it's required. @@ -158,7 +160,7 @@ class QueryClause: return clause def _determine_filter_target_table( - self, related_parts: List[str], select_related: List[str] + self, related_parts: List[str], select_related: List[str], field_name: str ) -> Tuple[List[str], str, Type["Model"]]: """ Adds related strings to select_related list otherwise the clause would fail as @@ -187,27 +189,34 @@ class QueryClause: # Walk the relationships to the actual model class # against which the comparison is being made. previous_model = model_cls - for part in related_parts: - part2 = part - if issubclass(model_cls.Meta.model_fields[part], ManyToManyField): - through_field = model_cls.Meta.model_fields[part] - previous_model = through_field.through - part2 = through_field.default_target_field_name() # type: ignore - manager = model_cls.Meta.alias_manager + manager = model_cls.Meta.alias_manager + for relation in related_parts: + related_field = model_cls.Meta.model_fields[relation] + if issubclass(related_field, ManyToManyField): + previous_model = related_field.through + relation = related_field.default_target_field_name() # type: ignore table_prefix = manager.resolve_relation_alias( - from_model=previous_model, relation_name=part2 + from_model=previous_model, relation_name=relation ) - model_cls = model_cls.Meta.model_fields[part].to + model_cls = related_field.to previous_model = model_cls + # handle duplicated aliases in nested relations + # TODO: check later and remove nocover + complex_prefix = manager.resolve_relation_alias( + from_model=self.model_cls, + relation_name='__'.join([related_str, field_name]) + ) + if complex_prefix: # pragma: nocover + table_prefix = complex_prefix return select_related, table_prefix, model_cls def _compile_clause( - self, - clause: sqlalchemy.sql.expression.BinaryExpression, - column: sqlalchemy.Column, - table: sqlalchemy.Table, - table_prefix: str, - modifiers: Dict, + self, + clause: sqlalchemy.sql.expression.BinaryExpression, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + modifiers: Dict, ) -> sqlalchemy.sql.expression.TextClause: """ Compiles the clause to str using appropriate database dialect, replace columns @@ -287,7 +296,7 @@ class QueryClause: @staticmethod def _extract_operator_field_and_related( - parts: List[str], + parts: List[str], ) -> Tuple[str, str, Optional[List]]: """ Splits filter query key and extracts required parts. diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index e911d44..4013efe 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -24,18 +24,20 @@ if TYPE_CHECKING: # pragma no cover class SqlJoin: def __init__( # noqa: CFQ002 - self, - used_aliases: List, - select_from: sqlalchemy.sql.select, - columns: List[sqlalchemy.Column], - fields: Optional[Union[Set, Dict]], - exclude_fields: Optional[Union[Set, Dict]], - order_columns: Optional[List], - sorted_orders: OrderedDict, - main_model: Type["Model"], - relation_name: str, - related_models: Any = None, - own_alias: str = "", + self, + used_aliases: List, + select_from: sqlalchemy.sql.select, + columns: List[sqlalchemy.Column], + fields: Optional[Union[Set, Dict]], + exclude_fields: Optional[Union[Set, Dict]], + order_columns: Optional[List], + sorted_orders: OrderedDict, + main_model: Type["Model"], + relation_name: str, + relation_str: str, + related_models: Any = None, + own_alias: str = "", + source_model: Type["Model"] = None, ) -> None: self.relation_name = relation_name self.related_models = related_models or [] @@ -53,6 +55,9 @@ class SqlJoin: self._next_model: Optional[Type["Model"]] = None self._next_alias: Optional[str] = None + self.relation_str = relation_str + self.source_model = source_model + @property def next_model(self) -> Type["Model"]: if not self._next_model: # pragma: nocover @@ -85,7 +90,8 @@ class SqlJoin: """ return self.main_model.Meta.alias_manager - def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text: + def on_clause(self, previous_alias: str, from_clause: str, + to_clause: str, ) -> text: """ Receives aliases and names of both ends of the join and combines them into one text clause used in joins. @@ -117,11 +123,7 @@ class SqlJoin: self.process_m2m_through_table() self.next_model = self.target_field.to - self.next_alias = self.alias_manager.resolve_relation_alias( - from_model=self.target_field.owner, relation_name=self.relation_name - ) - if self.next_alias not in self.used_aliases: - self._process_join() + self._forward_join() self._process_following_joins() @@ -132,6 +134,23 @@ class SqlJoin: self.sorted_orders, ) + def _forward_join(self): + self.next_alias = self.alias_manager.resolve_relation_alias( + from_model=self.target_field.owner, relation_name=self.relation_name + ) + if self.next_alias not in self.used_aliases: + self._process_join() + else: + if '__' in self.relation_str: + relation_key = f'{self.source_model.get_name()}_{self.relation_str}' + if relation_key not in self.alias_manager: + print(f'registering {relation_key}') + self.next_alias = self.alias_manager.add_alias( + alias_key=relation_key) + else: + self.next_alias = self.alias_manager[relation_key] + self._process_join() + def _process_following_joins(self) -> None: """ Iterates through nested models to create subsequent joins. @@ -139,8 +158,8 @@ class SqlJoin: for related_name in self.related_models: remainder = None if ( - isinstance(self.related_models, dict) - and self.related_models[related_name] + isinstance(self.related_models, dict) + and self.related_models[related_name] ): remainder = self.related_models[related_name] self._process_deeper_join(related_name=related_name, remainder=remainder) @@ -175,7 +194,9 @@ class SqlJoin: main_model=self.next_model, relation_name=related_name, related_models=remainder, + relation_str='__'.join([self.relation_str, related_name]), own_alias=self.next_alias, + source_model=self.source_model or self.main_model ) ( self.used_aliases, @@ -203,11 +224,8 @@ class SqlJoin: self._replace_many_to_many_order_by_columns(self.relation_name, new_part) self.next_model = self.target_field.through - self.next_alias = self.alias_manager.resolve_relation_alias( - from_model=self.target_field.owner, relation_name=self.relation_name - ) - if self.next_alias not in self.used_aliases: - self._process_join() + self._forward_join() + self.relation_name = new_part self.own_alias = self.next_alias self.target_field = self.next_model.Meta.model_fields[self.relation_name] @@ -226,18 +244,18 @@ class SqlJoin: """ target_field = self.target_field is_primary_self_ref = ( - target_field.self_reference - and self.relation_name == target_field.self_reference_primary + target_field.self_reference + and self.relation_name == target_field.self_reference_primary ) if (is_primary_self_ref and not reverse) or ( - not is_primary_self_ref and reverse + not is_primary_self_ref and reverse ): new_part = target_field.default_source_field_name() # type: ignore else: new_part = target_field.default_target_field_name() # type: ignore return new_part - def _process_join(self,) -> None: # noqa: CFQ002 + def _process_join(self, ) -> None: # noqa: CFQ002 """ Resolves to and from column names and table names. @@ -317,10 +335,10 @@ class SqlJoin: :rtype: bool """ return len(condition) >= 2 and ( - condition[-2] == part or condition[-2][1:] == part + condition[-2] == part or condition[-2][1:] == part ) - def set_aliased_order_by(self, condition: List[str], to_table: str,) -> None: + def set_aliased_order_by(self, condition: List[str], to_table: str, ) -> None: """ Substitute hyphens ('-') with descending order. Construct actual sqlalchemy text clause using aliased table and column name. @@ -335,7 +353,7 @@ class SqlJoin: order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}") self.sorted_orders["__".join(condition)] = order - def get_order_bys(self, to_table: str, pkname_alias: str,) -> None: # noqa: CCR001 + def get_order_bys(self, to_table: str, pkname_alias: str, ) -> None: # noqa: CCR001 """ Triggers construction of order bys if they are given. Otherwise by default each table is sorted by a primary key column asc. diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index df2ed6d..60f3b5e 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -159,6 +159,7 @@ class Query: sorted_orders=self.sorted_orders, main_model=self.model_cls, relation_name=related, + relation_str=related, related_models=remainder, ) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 4940265..08c4be5 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -228,6 +228,9 @@ class QuerySet: :return: filtered QuerySet :rtype: QuerySet """ + # TODO: delay processing of filter clauses or switch to group one + # that keeps all aliases even if duplicated - now initialized too late + # in the join qryclause = QueryClause( model_cls=self.model, select_related=self._select_related, diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index 3c85d1c..803abc5 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -33,6 +33,12 @@ class AliasManager: def __init__(self) -> None: self._aliases_new: Dict[str, str] = dict() + def __contains__(self, item): + return self._aliases_new.__contains__(item) + + def __getitem__(self, key): + return self._aliases_new.__getitem__(key) + @staticmethod def prefixed_columns( alias: str, table: sqlalchemy.Table, fields: List = None diff --git a/tests/test_docs/__init__.py b/tests/test_docs/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_forward_refs.py b/tests/test_forward_refs.py index f002ef1..9132c41 100644 --- a/tests/test_forward_refs.py +++ b/tests/test_forward_refs.py @@ -169,7 +169,7 @@ async def test_other_forwardref_relation(cleanup): async def test_m2m_self_forwardref_relation(cleanup): async with db: async with db.transaction(force_rollback=True): - checkers = await Game.objects.create(name="checkers") + checkers = await Game.objects.create(name="Checkers") uno = await Game(name="Uno").save() jenga = await Game(name="Jenga").save() @@ -186,15 +186,17 @@ async def test_m2m_self_forwardref_relation(cleanup): await billy.friends.add(kate) await billy.friends.add(steve) - # await steve.friends.add(kate) - # await steve.friends.add(billy) - billy_check = await Child.objects.select_related( - ["friends", "favourite_game", "least_favourite_game"] + ["friends", "favourite_game", "least_favourite_game", + "friends__favourite_game", "friends__least_favourite_game"] ).get(name="Billy") assert len(billy_check.friends) == 2 assert billy_check.friends[0].name == "Kate" + assert billy_check.friends[0].favourite_game.name == 'Checkers' + assert billy_check.friends[0].least_favourite_game.name == 'Uno' assert billy_check.friends[1].name == "Steve" + assert billy_check.friends[1].favourite_game.name == 'Jenga' + assert billy_check.friends[1].least_favourite_game.name == 'Uno' assert billy_check.favourite_game.name == "Uno" kate_check = await Child.objects.select_related(["also_friends",]).get( @@ -203,3 +205,13 @@ async def test_m2m_self_forwardref_relation(cleanup): assert len(kate_check.also_friends) == 1 assert kate_check.also_friends[0].name == "Billy" + + # TODO: Fix filters with complex prefixes + # billy_check = await Child.objects.select_related( + # ["friends", "favourite_game", "least_favourite_game", + # "friends__favourite_game", "friends__least_favourite_game"] + # ).filter(friends__favourite_game__name="Checkers").get(name="Billy") + # assert len(billy_check.friends) == 1 + # assert billy_check.friends[0].name == "Kate" + # assert billy_check.friends[0].favourite_game.name == 'Checkers' + # assert billy_check.friends[0].least_favourite_game.name == 'Uno' diff --git a/tests/test_models_helpers.py b/tests/test_models_helpers.py new file mode 100644 index 0000000..56a77c0 --- /dev/null +++ b/tests/test_models_helpers.py @@ -0,0 +1,9 @@ +from ormar.models.helpers.models import group_related_list + + +def test_group_related_list(): + given = ['friends__least_favourite_game', 'least_favourite_game', 'friends', + 'favourite_game', 'friends__favourite_game'] + expected = {'least_favourite_game': [], 'favourite_game': [], + 'friends': ['favourite_game', 'least_favourite_game']} + assert group_related_list(given) == expected