diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 39ef64e..243e66e 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -1,4 +1,5 @@ from typing import ( + Any, Dict, List, Optional, @@ -10,10 +11,11 @@ from typing import ( Union, ) +import ormar from ormar.fields import BaseField, ManyToManyField from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query -from ormar.queryset.utils import translate_list_to_dict +from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict if TYPE_CHECKING: # pragma: no cover from ormar import Model @@ -35,65 +37,114 @@ class PrefetchQuery: self._select_related = select_related self._exclude_columns = exclude_fields self._columns = fields + self.already_extracted: Dict = dict() + self.models: Dict = {} + self.select_dict = translate_list_to_dict(self._select_related) + + async def prefetch_related( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: + self.models = extract_models_to_dict_of_lists( + model_type=self.model, models=models, select_dict=self.select_dict + ) + self.models[self.model.get_name()] = models + return await self._prefetch_related_models(models=models, rows=rows) @staticmethod - def _extract_required_ids( - already_extracted: Dict, + def _get_column_name_for_id_extraction( parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, - ) -> Set: - current_data = already_extracted.get(parent_model.get_name(), {}) - raw_rows = current_data.get("raw", []) - table_prefix = current_data.get("prefix", "") + use_raw: bool, + ) -> str: if reverse: - column_name = parent_model.get_column_alias(parent_model.Meta.pkname) + column_name = parent_model.Meta.pkname + return ( + parent_model.get_column_alias(column_name) if use_raw else column_name + ) else: - column_name = target_model.resolve_relation_field( - parent_model, target_model - ).get_alias() + column = target_model.resolve_relation_field(parent_model, target_model) + return column.get_alias() if use_raw else column.name + + def _extract_ids_from_raw_data( + self, parent_model: Type["Model"], column_name: str + ) -> Set: list_of_ids = set() + current_data = self.already_extracted.get(parent_model.get_name(), {}) + table_prefix = current_data.get("prefix", "") column_name = (f"{table_prefix}_" if table_prefix else "") + column_name - for row in raw_rows: + for row in current_data.get("raw", []): if row[column_name]: list_of_ids.add(row[column_name]) return list_of_ids - @staticmethod - def _get_filter_for_prefetch( - already_extracted: Dict, - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool, - ) -> List: - ids = PrefetchQuery._extract_required_ids( - already_extracted=already_extracted, + def _extract_ids_from_preloaded_models( + self, parent_model: Type["Model"], column_name: str + ) -> Set: + list_of_ids = set() + for model in self.models.get(parent_model.get_name(), []): + child = getattr(model, column_name) + if isinstance(child, ormar.Model): + list_of_ids.add(child.pk) + else: + list_of_ids.add(child) + return list_of_ids + + def _extract_required_ids( + self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, + ) -> Set: + + use_raw = parent_model.get_name() not in self.models + + column_name = self._get_column_name_for_id_extraction( parent_model=parent_model, target_model=target_model, reverse=reverse, + use_raw=use_raw, + ) + + if use_raw: + return self._extract_ids_from_raw_data( + parent_model=parent_model, column_name=column_name + ) + + return self._extract_ids_from_preloaded_models( + parent_model=parent_model, column_name=column_name + ) + + @staticmethod + def _get_clause_target_and_filter_column_name( + parent_model: Type["Model"], target_model: Type["Model"], reverse: bool + ) -> Tuple[Type["Model"], str]: + if reverse: + field = target_model.resolve_relation_field(target_model, parent_model) + if issubclass(field, ManyToManyField): + sub_field = target_model.resolve_relation_field( + field.through, parent_model + ) + return field.through, sub_field.get_alias() + else: + return target_model, field.get_alias() + target_field = target_model.get_column_alias(target_model.Meta.pkname) + return target_model, target_field + + def _get_filter_for_prefetch( + self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, + ) -> List: + ids = self._extract_required_ids( + parent_model=parent_model, target_model=target_model, reverse=reverse, ) if ids: - qryclause = QueryClause( - model_cls=target_model, select_related=[], filter_clauses=[], + ( + clause_target, + filter_column, + ) = self._get_clause_target_and_filter_column_name( + parent_model=parent_model, target_model=target_model, reverse=reverse ) - if reverse: - field = target_model.resolve_relation_field(target_model, parent_model) - if issubclass(field, ManyToManyField): - sub_field = target_model.resolve_relation_field( - field.through, parent_model - ) - kwargs = {f"{sub_field.get_alias()}__in": ids} - qryclause = QueryClause( - model_cls=field.through, select_related=[], filter_clauses=[], - ) - - else: - kwargs = {f"{field.get_alias()}__in": ids} - else: - target_field = target_model.Meta.model_fields[ - target_model.Meta.pkname - ].get_alias() - kwargs = {f"{target_field}__in": ids} + qryclause = QueryClause( + model_cls=clause_target, select_related=[], filter_clauses=[], + ) + kwargs = {f"{filter_column}__in": ids} filter_clauses, _ = qryclause.filter(**kwargs) return filter_clauses return [] @@ -123,7 +174,7 @@ class PrefetchQuery: @staticmethod def _get_group_field_name( - target_field: Type["BaseField"], model: Type["Model"] + target_field: Type["BaseField"], model: Union["Model", Type["Model"]] ) -> str: if issubclass(target_field, ManyToManyField): return model.resolve_relation_name(target_field.through, model) @@ -142,117 +193,150 @@ class PrefetchQuery: ] return related_to_extract - @staticmethod - def _populate_nested_related( - model: "Model", already_extracted: Dict, prefetch_dict: Dict - ) -> "Model": + def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model": - related_to_extract = PrefetchQuery._get_names_to_extract( + related_to_extract = self._get_names_to_extract( prefetch_dict=prefetch_dict, model=model ) for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - is_multi, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( + is_multi, field_name, model_id = self._get_model_id_and_field_name( target_field=target_field, model=model ) - if not field_name: + + if field_name is None or model_id is None: # pragma: no cover continue - children = already_extracted.get(target_model, {}).get(field_name, {}) - for key, child_models in children.items(): - if key == model_id: - for child in child_models: - setattr(model, related, child) + children = self.already_extracted.get(target_model, {}).get(field_name, {}) + self._set_children_on_model( + model=model, related=related, children=children, model_id=model_id + ) return model - async def prefetch_related( - self, models: Sequence["Model"], rows: List - ) -> Sequence["Model"]: - return await self._prefetch_related_models(models=models, rows=rows) + @staticmethod + def _set_children_on_model( + model: "Model", related: str, children: Dict, model_id: int + ) -> None: + for key, child_models in children.items(): + if key == model_id: + for child in child_models: + setattr(model, related, child) async def _prefetch_related_models( self, models: Sequence["Model"], rows: List ) -> Sequence["Model"]: - already_extracted = { - self.model.get_name(): { - "raw": rows, - "models": {model.pk: model for model in models}, - } - } + self.already_extracted = {self.model.get_name(): {"raw": rows}} select_dict = translate_list_to_dict(self._select_related) prefetch_dict = translate_list_to_dict(self._prefetch_related) target_model = self.model fields = self._columns exclude_fields = self._exclude_columns for related in prefetch_dict.keys(): - subrelated = await self._extract_related_models( + await self._extract_related_models( related=related, target_model=target_model, - prefetch_dict=prefetch_dict.get(related), - select_dict=select_dict.get(related), - already_extracted=already_extracted, + prefetch_dict=prefetch_dict.get(related, {}), + select_dict=select_dict.get(related, {}), fields=fields, exclude_fields=exclude_fields, ) - print(related, subrelated) final_models = [] for model in models: final_models.append( - self._populate_nested_related( - model=model, - already_extracted=already_extracted, - prefetch_dict=prefetch_dict, - ) + self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,) ) return models - async def _extract_related_models( # noqa: CFQ002 + async def _extract_related_models( # noqa: CFQ002, CCR001 self, related: str, target_model: Type["Model"], prefetch_dict: Dict, select_dict: Dict, - already_extracted: Dict, - fields: Dict, - exclude_fields: Dict, + fields: Union[Set[Any], Dict[Any, Any], None], + exclude_fields: Union[Set[Any], Dict[Any, Any], None], ) -> None: fields = target_model.get_included(fields, related) exclude_fields = target_model.get_excluded(exclude_fields, related) - select_related = [] - target_field = target_model.Meta.model_fields[related] reverse = False if target_field.virtual or issubclass(target_field, ManyToManyField): reverse = True parent_model = target_model - target_model = target_field.to - filter_clauses = PrefetchQuery._get_filter_for_prefetch( - already_extracted=already_extracted, - parent_model=parent_model, - target_model=target_model, - reverse=reverse, + filter_clauses = self._get_filter_for_prefetch( + parent_model=parent_model, target_model=target_field.to, reverse=reverse, ) if not filter_clauses: # related field is empty return + already_loaded = select_dict is Ellipsis or related in select_dict + + if not already_loaded: + # If not already loaded with select_related + table_prefix, rows = await self._run_prefetch_query( + target_field=target_field, + fields=fields, + exclude_fields=exclude_fields, + filter_clauses=filter_clauses, + ) + else: + rows = [] + table_prefix = "" + + if prefetch_dict and prefetch_dict is not Ellipsis: + for subrelated in prefetch_dict.keys(): + await self._extract_related_models( + related=subrelated, + target_model=target_field.to, + prefetch_dict=prefetch_dict.get(subrelated, {}), + select_dict=self._get_select_related_if_apply( + subrelated, select_dict + ), + fields=fields, + exclude_fields=exclude_fields, + ) + + if not already_loaded: + self._populate_rows( + rows=rows, + parent_model=parent_model, + target_field=target_field, + table_prefix=table_prefix, + fields=fields, + exclude_fields=exclude_fields, + prefetch_dict=prefetch_dict, + ) + else: + self._update_already_loaded_rows( + target_field=target_field, prefetch_dict=prefetch_dict, + ) + + async def _run_prefetch_query( + self, + target_field: Type["BaseField"], + fields: Union[Set[Any], Dict[Any, Any], None], + exclude_fields: Union[Set[Any], Dict[Any, Any], None], + filter_clauses: List, + ) -> Tuple[str, List]: + target_model = target_field.to + target_name = target_model.get_name() + select_related = [] query_target = target_model table_prefix = "" if issubclass(target_field, ManyToManyField): query_target = target_field.through - select_related = [target_field.to.get_name()] + select_related = [target_name] table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename, ) - already_extracted.setdefault(target_model.get_name(), {})[ - "prefix" - ] = table_prefix + self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix qry = Query( model_cls=query_target, @@ -268,30 +352,41 @@ class PrefetchQuery: expr = qry.build_select_expression() # print(expr.compile(compile_kwargs={"literal_binds": True})) rows = await self.database.fetch_all(expr) - already_extracted.setdefault(target_model.get_name(), {}).update( - {"raw": rows, "models": {}} + self.already_extracted.setdefault(target_name, {}).update({"raw": rows}) + return table_prefix, rows + + @staticmethod + def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict: + return ( + select_dict.get(related, {}) + if (select_dict and select_dict is not Ellipsis and related in select_dict) + else {} ) - if prefetch_dict and prefetch_dict is not Ellipsis: - for subrelated in prefetch_dict.keys(): - submodels = await self._extract_related_models( - related=subrelated, - target_model=target_model, - prefetch_dict=prefetch_dict.get(subrelated), - select_dict=select_dict.get(subrelated) - if (select_dict and subrelated in select_dict) - else {}, - already_extracted=already_extracted, - fields=fields, - exclude_fields=exclude_fields, - ) - print(subrelated, submodels) + def _update_already_loaded_rows( # noqa: CFQ002 + self, target_field: Type["BaseField"], prefetch_dict: Dict, + ) -> None: + target_model = target_field.to + for instance in self.models.get(target_model.get_name(), []): + self._populate_nested_related( + model=instance, prefetch_dict=prefetch_dict, + ) + def _populate_rows( # noqa: CFQ002 + self, + rows: List, + target_field: Type["BaseField"], + parent_model: Type["Model"], + table_prefix: str, + fields: Union[Set[Any], Dict[Any, Any], None], + exclude_fields: Union[Set[Any], Dict[Any, Any], None], + prefetch_dict: Dict, + ) -> None: + target_model = target_field.to for row in rows: - field_name = PrefetchQuery._get_group_field_name( + field_name = self._get_group_field_name( target_field=target_field, model=parent_model ) - print("TEST", field_name, target_model, row[field_name]) item = target_model.extract_prefixed_table_columns( item={}, row=row, @@ -301,13 +396,8 @@ class PrefetchQuery: ) instance = target_model(**item) instance = self._populate_nested_related( - model=instance, - already_extracted=already_extracted, - prefetch_dict=prefetch_dict, + model=instance, prefetch_dict=prefetch_dict, ) - already_extracted[target_model.get_name()].setdefault( + self.already_extracted[target_model.get_name()].setdefault( field_name, dict() ).setdefault(row[field_name], []).append(instance) - already_extracted[target_model.get_name()]["models"][instance.pk] = instance - - return already_extracted[target_model.get_name()]["models"] diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index ce902f7..bc36965 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -67,16 +67,16 @@ class QuerySet: return self.model_cls async def _prefetch_related_models( - self, models: Sequence["Model"], rows: List - ) -> Sequence["Model"]: + self, models: Sequence[Optional["Model"]], rows: List + ) -> Sequence[Optional["Model"]]: query = PrefetchQuery( - model_cls=self.model_cls, + model_cls=self.model, fields=self._columns, exclude_fields=self._exclude_columns, prefetch_related=self._prefetch_related, select_related=self._select_related, ) - return await query.prefetch_related(models=models, rows=rows) + return await query.prefetch_related(models=models, rows=rows) # type: ignore def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: result_rows = [ @@ -191,7 +191,7 @@ class QuerySet: if not isinstance(related, list): related = [related] - related = list(set(list(self._select_related) + related)) + related = list(set(list(self._prefetch_related) + related)) return self.__class__( model_cls=self.model, filter_clauses=self.filter_clauses, @@ -352,7 +352,7 @@ class QuerySet: rows = await self.database.fetch_all(expr) processed_rows = self._process_query_result_rows(rows) - if self._prefetch_related: + if self._prefetch_related and processed_rows: processed_rows = await self._prefetch_related_models(processed_rows, rows) self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore @@ -379,7 +379,7 @@ class QuerySet: expr = self.build_select_expression() rows = await self.database.fetch_all(expr) result_rows = self._process_query_result_rows(rows) - if self._prefetch_related: + if self._prefetch_related and result_rows: result_rows = await self._prefetch_related_models(result_rows, rows) return result_rows diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index c3c8fa9..ed422dd 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -1,6 +1,9 @@ import collections.abc import copy -from typing import Any, Dict, List, Set, Union +from typing import Any, Dict, List, Sequence, Set, TYPE_CHECKING, Type, Union + +if TYPE_CHECKING: # pragma no cover + from ormar import Model def check_node_not_dict_or_not_last_node( @@ -55,3 +58,39 @@ def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> dict_to_update = translate_list_to_dict(list_to_update) update(updated_dict, dict_to_update) return updated_dict + + +def extract_nested_models( # noqa: CCR001 + model: "Model", model_type: Type["Model"], select_dict: Dict, extracted: Dict +) -> None: + follow = [rel for rel in model_type.extract_related_names() if rel in select_dict] + for related in follow: + child = getattr(model, related) + if child: + target_model = model_type.Meta.model_fields[related].to + if isinstance(child, list): + extracted.setdefault(target_model.get_name(), []).extend(child) + if select_dict[related] is not Ellipsis: + for sub_child in child: + extract_nested_models( + sub_child, target_model, select_dict[related], extracted, + ) + else: + extracted.setdefault(target_model.get_name(), []).append(child) + if select_dict[related] is not Ellipsis: + extract_nested_models( + child, target_model, select_dict[related], extracted, + ) + + +def extract_models_to_dict_of_lists( + model_type: Type["Model"], + models: Sequence["Model"], + select_dict: Dict, + extracted: Dict = None, +) -> Dict: + if not extracted: + extracted = dict() + for model in models: + extract_nested_models(model, model_type, select_dict, extracted) + return extracted diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index a6b8831..378fdb4 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class RandomSet(ormar.Model): + class Meta: + tablename = "randoms" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + class Tonation(ormar.Model): class Meta: tablename = "tonations" @@ -19,6 +29,7 @@ class Tonation(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) + rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet) class Division(ormar.Model): @@ -181,3 +192,53 @@ async def test_prefetch_related_empty(): track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") assert track.title == 'The Bird' assert track.album is None + + +@pytest.mark.asyncio +async def test_prefetch_related_with_select_related(): + async with database: + async with database.transaction(force_rollback=True): + div = await Division.objects.create(name='Div 1') + shop1 = await Shop.objects.create(name='Shop 1', division=div) + shop2 = await Shop.objects.create(name='Shop 2', division=div) + album = Album(name="Malibu") + await album.save() + await album.shops.add(shop1) + await album.shops.add(shop2) + + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + + album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).get() + assert len(album.tracks) == 0 + assert len(album.cover_pictures) == 2 + assert album.shops[0].division.name == 'Div 1' + + rand_set = await RandomSet.objects.create(name='Rand 1') + ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) + await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) + await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + + album = await Album.objects.select_related('tracks__tonation__rand_set').filter(name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).get() + assert len(album.tracks) == 3 + assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].artist == 'Artist 1' + + assert len(album.shops) == 2 + assert album.shops[0].name == 'Shop 1' + assert album.shops[0].division.name == 'Div 1' + + track = await Track.objects.select_related('album').prefetch_related( + ["album__cover_pictures", "album__shops__division"]).get( + title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist == 'Artist 1' + + assert len(track.album.shops) == 2 + assert track.album.shops[0].name == 'Shop 1' + assert track.album.shops[0].division.name == 'Div 1'