From e0223f8a2251df8a92c0559f4934c5a7356cc15e Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 25 Nov 2020 16:56:54 +0100 Subject: [PATCH] cleanup by related field --- ormar/queryset/prefetch_query.py | 82 +++++++++++++++----------------- 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 8b7ade2..39ef64e 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -10,7 +10,6 @@ from typing import ( Union, ) -import ormar from ormar.fields import BaseField, ManyToManyField from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query @@ -44,10 +43,9 @@ class PrefetchQuery: target_model: Type["Model"], reverse: bool, ) -> Set: - raw_rows = already_extracted.get(parent_model.get_name(), {}).get("raw", []) - table_prefix = already_extracted.get(parent_model.get_name(), {}).get( - "prefix", "" - ) + current_data = already_extracted.get(parent_model.get_name(), {}) + raw_rows = current_data.get("raw", []) + table_prefix = current_data.get("prefix", "") if reverse: column_name = parent_model.get_column_alias(parent_model.Meta.pkname) else: @@ -105,23 +103,33 @@ class PrefetchQuery: target_field: Type["BaseField"], model: "Model" ) -> Tuple[bool, Optional[str], Optional[int]]: if target_field.virtual: - reverse = True + is_multi = False field_name = model.resolve_relation_name(target_field.to, model) model_id = model.pk elif issubclass(target_field, ManyToManyField): - reverse = True + is_multi = True field_name = model.resolve_relation_name(target_field.through, model) model_id = model.pk else: - reverse = False + is_multi = False related_name = model.resolve_relation_name(model, target_field.to) related_model = getattr(model, related_name) if not related_model: - return reverse, None, None + return is_multi, None, None model_id = related_model.pk field_name = target_field.to.Meta.pkname - return reverse, field_name, model_id + return is_multi, field_name, model_id + + @staticmethod + def _get_group_field_name( + target_field: Type["BaseField"], model: Type["Model"] + ) -> str: + if issubclass(target_field, ManyToManyField): + return model.resolve_relation_name(target_field.through, model) + if target_field.virtual: + return model.resolve_relation_name(target_field.to, model) + return target_field.to.Meta.pkname @staticmethod def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List: @@ -146,40 +154,17 @@ class PrefetchQuery: for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - reverse, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( + is_multi, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( target_field=target_field, model=model ) + if not field_name: + continue - if ( - target_model in already_extracted - and already_extracted[target_model]["models"] - ): - for key, child_model in already_extracted[target_model][ - "models" - ].items(): - if issubclass(target_field, ManyToManyField): - ind = next( - i - if key == x[target_field.to.get_column_alias(field_name)] - else -1 - for i, x in enumerate( - already_extracted[target_model]["raw"] - ) - ) - raw_data = already_extracted[target_model]["raw"][ind] - if ( - raw_data - and field_name in raw_data - and raw_data[field_name] == model_id - ): - setattr(model, related, child_model) - - elif isinstance(getattr(child_model, field_name), ormar.Model): - if getattr(child_model, field_name).pk == model_id: - setattr(model, related, child_model) - - elif getattr(child_model, field_name) == model_id: - setattr(model, related, child_model) + children = already_extracted.get(target_model, {}).get(field_name, {}) + for key, child_models in children.items(): + if key == model_id: + for child in child_models: + setattr(model, related, child) return model @@ -203,7 +188,7 @@ class PrefetchQuery: fields = self._columns exclude_fields = self._exclude_columns for related in prefetch_dict.keys(): - await self._extract_related_models( + subrelated = await self._extract_related_models( related=related, target_model=target_model, prefetch_dict=prefetch_dict.get(related), @@ -212,6 +197,7 @@ class PrefetchQuery: fields=fields, exclude_fields=exclude_fields, ) + print(related, subrelated) final_models = [] for model in models: final_models.append( @@ -288,7 +274,7 @@ class PrefetchQuery: if prefetch_dict and prefetch_dict is not Ellipsis: for subrelated in prefetch_dict.keys(): - await self._extract_related_models( + submodels = await self._extract_related_models( related=subrelated, target_model=target_model, prefetch_dict=prefetch_dict.get(subrelated), @@ -299,8 +285,13 @@ class PrefetchQuery: fields=fields, exclude_fields=exclude_fields, ) + print(subrelated, submodels) for row in rows: + field_name = PrefetchQuery._get_group_field_name( + target_field=target_field, model=parent_model + ) + print("TEST", field_name, target_model, row[field_name]) item = target_model.extract_prefixed_table_columns( item={}, row=row, @@ -314,4 +305,9 @@ class PrefetchQuery: already_extracted=already_extracted, prefetch_dict=prefetch_dict, ) + already_extracted[target_model.get_name()].setdefault( + field_name, dict() + ).setdefault(row[field_name], []).append(instance) already_extracted[target_model.get_name()]["models"][instance.pk] = instance + + return already_extracted[target_model.get_name()]["models"]