diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 6cc9ec3..8b7ade2 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -1,228 +1,317 @@ -from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union +from typing import ( + Dict, + List, + Optional, + Sequence, + Set, + TYPE_CHECKING, + Tuple, + Type, + Union, +) import ormar -from ormar.fields import ManyToManyField +from ormar.fields import BaseField, ManyToManyField from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query +from ormar.queryset.utils import translate_list_to_dict if TYPE_CHECKING: # pragma: no cover from ormar import Model class PrefetchQuery: - - def __init__(self, - model_cls: Type["Model"], - fields: Optional[Union[Dict, Set]], - exclude_fields: Optional[Union[Dict, Set]], - prefetch_related: List - ): + def __init__( + self, + model_cls: Type["Model"], + fields: Optional[Union[Dict, Set]], + exclude_fields: Optional[Union[Dict, Set]], + prefetch_related: List, + select_related: List, + ) -> None: self.model = model_cls self.database = self.model.Meta.database self._prefetch_related = prefetch_related + self._select_related = select_related self._exclude_columns = exclude_fields self._columns = fields @staticmethod - def _extract_required_ids(already_extracted: Dict, - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool) -> Set: - raw_rows = already_extracted.get(parent_model.get_name(), {}).get('raw', []) + def _extract_required_ids( + already_extracted: Dict, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + ) -> Set: + raw_rows = already_extracted.get(parent_model.get_name(), {}).get("raw", []) + table_prefix = already_extracted.get(parent_model.get_name(), {}).get( + "prefix", "" + ) if reverse: column_name = parent_model.get_column_alias(parent_model.Meta.pkname) else: - column_name = target_model.resolve_relation_field(parent_model, target_model).get_alias() + column_name = target_model.resolve_relation_field( + parent_model, target_model + ).get_alias() list_of_ids = set() + column_name = (f"{table_prefix}_" if table_prefix else "") + column_name for row in raw_rows: if row[column_name]: list_of_ids.add(row[column_name]) return list_of_ids @staticmethod - def _get_filter_for_prefetch(already_extracted: Dict, - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool) -> List: - ids = PrefetchQuery._extract_required_ids(already_extracted=already_extracted, - parent_model=parent_model, - target_model=target_model, - reverse=reverse) + def _get_filter_for_prefetch( + already_extracted: Dict, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + ) -> List: + ids = PrefetchQuery._extract_required_ids( + already_extracted=already_extracted, + parent_model=parent_model, + target_model=target_model, + reverse=reverse, + ) if ids: qryclause = QueryClause( - model_cls=target_model, - select_related=[], - filter_clauses=[], + model_cls=target_model, select_related=[], filter_clauses=[], ) if reverse: field = target_model.resolve_relation_field(target_model, parent_model) if issubclass(field, ManyToManyField): - sub_field = target_model.resolve_relation_field(field.through, parent_model) - kwargs = {f'{sub_field.get_alias()}__in': ids} + sub_field = target_model.resolve_relation_field( + field.through, parent_model + ) + kwargs = {f"{sub_field.get_alias()}__in": ids} qryclause = QueryClause( - model_cls=field.through, - select_related=[], - filter_clauses=[], + model_cls=field.through, select_related=[], filter_clauses=[], ) else: - kwargs = {f'{field.get_alias()}__in': ids} + kwargs = {f"{field.get_alias()}__in": ids} else: - target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias() - kwargs = {f'{target_field}__in': ids} + target_field = target_model.Meta.model_fields[ + target_model.Meta.pkname + ].get_alias() + kwargs = {f"{target_field}__in": ids} filter_clauses, _ = qryclause.filter(**kwargs) return filter_clauses return [] @staticmethod - def _populate_nested_related(model: "Model", - already_extracted: Dict) -> "Model": - - for related in model.extract_related_names(): + def _get_model_id_and_field_name( + target_field: Type["BaseField"], model: "Model" + ) -> Tuple[bool, Optional[str], Optional[int]]: + if target_field.virtual: + reverse = True + field_name = model.resolve_relation_name(target_field.to, model) + model_id = model.pk + elif issubclass(target_field, ManyToManyField): + reverse = True + field_name = model.resolve_relation_name(target_field.through, model) + model_id = model.pk + else: reverse = False + related_name = model.resolve_relation_name(model, target_field.to) + related_model = getattr(model, related_name) + if not related_model: + return reverse, None, None + model_id = related_model.pk + field_name = target_field.to.Meta.pkname + return reverse, field_name, model_id + + @staticmethod + def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List: + related_to_extract = [] + if prefetch_dict and prefetch_dict is not Ellipsis: + related_to_extract = [ + related + for related in model.extract_related_names() + if related in prefetch_dict + ] + return related_to_extract + + @staticmethod + def _populate_nested_related( + model: "Model", already_extracted: Dict, prefetch_dict: Dict + ) -> "Model": + + related_to_extract = PrefetchQuery._get_names_to_extract( + prefetch_dict=prefetch_dict, model=model + ) + + for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - if target_field.virtual: - reverse = True - field_name = model.resolve_relation_name(target_field.to, model) - model_id = model.pk - elif issubclass(target_field, ManyToManyField): - reverse = True - field_name = model.resolve_relation_name(target_field.through, model) - model_id = model.pk - else: - related_name = model.resolve_relation_name(model, target_field.to) - related_model = getattr(model, related_name) - if not related_model: - continue - model_id = related_model.pk - field_name = target_field.to.Meta.pkname + reverse, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( + target_field=target_field, model=model + ) - if target_model in already_extracted and already_extracted[target_model]['models']: - print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n') - print(already_extracted[target_model]['models']) - for ind, child_model in enumerate(already_extracted[target_model]['models']): + if ( + target_model in already_extracted + and already_extracted[target_model]["models"] + ): + for key, child_model in already_extracted[target_model][ + "models" + ].items(): if issubclass(target_field, ManyToManyField): - raw_data = already_extracted[target_model]['raw'][ind] - if raw_data[field_name] == model_id: + ind = next( + i + if key == x[target_field.to.get_column_alias(field_name)] + else -1 + for i, x in enumerate( + already_extracted[target_model]["raw"] + ) + ) + raw_data = already_extracted[target_model]["raw"][ind] + if ( + raw_data + and field_name in raw_data + and raw_data[field_name] == model_id + ): setattr(model, related, child_model) elif isinstance(getattr(child_model, field_name), ormar.Model): if getattr(child_model, field_name).pk == model_id: - if reverse: - setattr(model, related, child_model) - else: - setattr(child_model, related, model) + setattr(model, related, child_model) - else: # we have not reverse relation and related_model is a pk value + elif getattr(child_model, field_name) == model_id: setattr(model, related, child_model) return model - async def prefetch_related(self, models: Sequence["Model"], rows: List): + async def prefetch_related( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: return await self._prefetch_related_models(models=models, rows=rows) - async def _prefetch_related_models(self, - models: Sequence["Model"], - rows: List) -> Sequence["Model"]: - already_extracted = {self.model.get_name(): {'raw': rows, 'models': models}} - for related in self._prefetch_related: - target_model = self.model - fields = self._columns - exclude_fields = self._exclude_columns - for part in related.split('__'): - fields = target_model.get_included(fields, part) - select_related = [] - exclude_fields = target_model.get_excluded(exclude_fields, part) - - target_field = target_model.Meta.model_fields[part] - reverse = False - if target_field.virtual or issubclass(target_field, ManyToManyField): - reverse = True - - if issubclass(target_field, ManyToManyField): - select_related = [target_field.through.get_name()] - - parent_model = target_model - target_model = target_field.to - - if target_model.get_name() not in already_extracted: - filter_clauses = self._get_filter_for_prefetch(already_extracted=already_extracted, - parent_model=parent_model, - target_model=target_model, - reverse=reverse) - if not filter_clauses: # related field is empty - continue - - query_target = target_model - table_prefix = '' - if issubclass(target_field, ManyToManyField): - query_target = target_field.through - select_related = [target_field.to.get_name()] - table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( - from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename) - - qry = Query( - model_cls=query_target, - select_related=select_related, - filter_clauses=filter_clauses, - exclude_clauses=[], - offset=None, - limit_count=None, - fields=fields, - exclude_fields=exclude_fields, - order_bys=None, - ) - expr = qry.build_select_expression() - print(expr.compile(compile_kwargs={"literal_binds": True})) - rows = await self.database.fetch_all(expr) - already_extracted[target_model.get_name()] = {'raw': rows, 'models': []} - if part == related.split('__')[-1]: - for row in rows: - item = target_model.extract_prefixed_table_columns( - item={}, - row=row, - table_prefix=table_prefix, - fields=fields, - exclude_fields=exclude_fields - ) - instance = target_model(**item) - already_extracted[target_model.get_name()]['models'].append(instance) - - target_model = self.model - fields = self._columns - exclude_fields = self._exclude_columns - for part in related.split('__')[:-1]: - fields = target_model.get_included(fields, part) - exclude_fields = target_model.get_excluded(exclude_fields, part) - - target_field = target_model.Meta.model_fields[part] - target_model = target_field.to - table_prefix = '' - - if issubclass(target_field, ManyToManyField): - from_table = target_field.through.Meta.tablename - to_name = target_field.to.Meta.tablename - table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( - from_table=from_table, to_table=to_name) - - for row in already_extracted.get(target_model.get_name(), {}).get('raw', []): - item = target_model.extract_prefixed_table_columns( - item={}, - row=row, - table_prefix=table_prefix, - fields=fields, - exclude_fields=exclude_fields - ) - instance = target_model(**item) - instance = self._populate_nested_related(model=instance, - already_extracted=already_extracted) - - already_extracted[target_model.get_name()]['models'].append(instance) + async def _prefetch_related_models( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: + already_extracted = { + self.model.get_name(): { + "raw": rows, + "models": {model.pk: model for model in models}, + } + } + select_dict = translate_list_to_dict(self._select_related) + prefetch_dict = translate_list_to_dict(self._prefetch_related) + target_model = self.model + fields = self._columns + exclude_fields = self._exclude_columns + for related in prefetch_dict.keys(): + await self._extract_related_models( + related=related, + target_model=target_model, + prefetch_dict=prefetch_dict.get(related), + select_dict=select_dict.get(related), + already_extracted=already_extracted, + fields=fields, + exclude_fields=exclude_fields, + ) final_models = [] for model in models: - final_models.append(self._populate_nested_related(model=model, - already_extracted=already_extracted)) + final_models.append( + self._populate_nested_related( + model=model, + already_extracted=already_extracted, + prefetch_dict=prefetch_dict, + ) + ) return models + + async def _extract_related_models( # noqa: CFQ002 + self, + related: str, + target_model: Type["Model"], + prefetch_dict: Dict, + select_dict: Dict, + already_extracted: Dict, + fields: Dict, + exclude_fields: Dict, + ) -> None: + + fields = target_model.get_included(fields, related) + exclude_fields = target_model.get_excluded(exclude_fields, related) + select_related = [] + + target_field = target_model.Meta.model_fields[related] + reverse = False + if target_field.virtual or issubclass(target_field, ManyToManyField): + reverse = True + + parent_model = target_model + target_model = target_field.to + + filter_clauses = PrefetchQuery._get_filter_for_prefetch( + already_extracted=already_extracted, + parent_model=parent_model, + target_model=target_model, + reverse=reverse, + ) + if not filter_clauses: # related field is empty + return + + query_target = target_model + table_prefix = "" + if issubclass(target_field, ManyToManyField): + query_target = target_field.through + select_related = [target_field.to.get_name()] + table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( + from_table=query_target.Meta.tablename, + to_table=target_field.to.Meta.tablename, + ) + already_extracted.setdefault(target_model.get_name(), {})[ + "prefix" + ] = table_prefix + + qry = Query( + model_cls=query_target, + select_related=select_related, + filter_clauses=filter_clauses, + exclude_clauses=[], + offset=None, + limit_count=None, + fields=fields, + exclude_fields=exclude_fields, + order_bys=None, + ) + expr = qry.build_select_expression() + # print(expr.compile(compile_kwargs={"literal_binds": True})) + rows = await self.database.fetch_all(expr) + already_extracted.setdefault(target_model.get_name(), {}).update( + {"raw": rows, "models": {}} + ) + + if prefetch_dict and prefetch_dict is not Ellipsis: + for subrelated in prefetch_dict.keys(): + await self._extract_related_models( + related=subrelated, + target_model=target_model, + prefetch_dict=prefetch_dict.get(subrelated), + select_dict=select_dict.get(subrelated) + if (select_dict and subrelated in select_dict) + else {}, + already_extracted=already_extracted, + fields=fields, + exclude_fields=exclude_fields, + ) + + for row in rows: + item = target_model.extract_prefixed_table_columns( + item={}, + row=row, + table_prefix=table_prefix, + fields=fields, + exclude_fields=exclude_fields, + ) + instance = target_model(**item) + instance = self._populate_nested_related( + model=instance, + already_extracted=already_extracted, + prefetch_dict=prefetch_dict, + ) + already_extracted[target_model.get_name()]["models"][instance.pk] = instance diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 6e54129..ce902f7 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -21,17 +21,17 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( # noqa CFQ002 - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - exclude_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, - columns: Dict = None, - exclude_columns: Dict = None, - order_bys: List = None, - prefetch_related: List = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + exclude_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + columns: Dict = None, + exclude_columns: Dict = None, + order_bys: List = None, + prefetch_related: List = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses @@ -45,9 +45,9 @@ class QuerySet: self.order_bys = order_bys or [] def __get__( - self, - instance: Optional[Union["QuerySet", "QuerysetProxy"]], - owner: Union[Type["Model"], Type["QuerysetProxy"]], + self, + instance: Optional[Union["QuerySet", "QuerysetProxy"]], + owner: Union[Type["Model"], Type["QuerysetProxy"]], ) -> "QuerySet": if issubclass(owner, ormar.Model): return self.__class__(model_cls=owner) @@ -66,11 +66,16 @@ class QuerySet: raise ValueError("Model class of QuerySet is not initialized") return self.model_cls - async def _prefetch_related_models(self, models: Sequence["Model"], rows: List) -> Sequence["Model"]: - query = PrefetchQuery(model_cls=self.model_cls, - fields=self._columns, - exclude_fields=self._exclude_columns, - prefetch_related=self._prefetch_related) + async def _prefetch_related_models( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: + query = PrefetchQuery( + model_cls=self.model_cls, + fields=self._columns, + exclude_fields=self._exclude_columns, + prefetch_related=self._prefetch_related, + select_related=self._select_related, + ) return await query.prefetch_related(models=models, rows=rows) def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: @@ -98,7 +103,7 @@ class QuerySet: pkname = self.model_meta.pkname pk = self.model_meta.model_fields[pkname] if new_kwargs.get(pkname, ormar.Undefined) is None and ( - pk.nullable or pk.autoincrement + pk.nullable or pk.autoincrement ): del new_kwargs[pkname] return new_kwargs @@ -197,7 +202,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, - prefetch_related=related + prefetch_related=related, ) def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": @@ -398,9 +403,9 @@ class QuerySet: # refresh server side defaults if any( - field.server_default is not None - for name, field in self.model.Meta.model_fields.items() - if name not in kwargs + field.server_default is not None + for name, field in self.model.Meta.model_fields.items() + if name not in kwargs ): instance = await instance.load() instance.set_save_status(True) @@ -420,7 +425,7 @@ class QuerySet: objt.set_save_status(True) async def bulk_update( # noqa: CCR001 - self, objects: List["Model"], columns: List[str] = None + self, objects: List["Model"], columns: List[str] = None ) -> None: ready_objects = [] pk_name = self.model_meta.pkname diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 3b098c4..a6b8831 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -21,6 +21,16 @@ class Tonation(ormar.Model): name: str = ormar.String(max_length=100) +class Division(ormar.Model): + class Meta: + tablename = "divisions" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + class Shop(ormar.Model): class Meta: tablename = "shops" @@ -29,6 +39,7 @@ class Shop(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) + division: Optional[Division] = ormar.ForeignKey(Division) class AlbumShops(ormar.Model): @@ -137,8 +148,9 @@ async def test_prefetch_related(): async def test_prefetch_related_with_many_to_many(): async with database: async with database.transaction(force_rollback=True): - shop1 = await Shop.objects.create(name='Shop 1') - shop2 = await Shop.objects.create(name='Shop 2') + div = await Division.objects.create(name='Div 1') + shop1 = await Shop.objects.create(name='Shop 1', division=div) + shop2 = await Shop.objects.create(name='Shop 2', division=div) album = Album(name="Malibu") await album.save() await album.shops.add(shop1) @@ -150,13 +162,15 @@ async def test_prefetch_related_with_many_to_many(): await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') - track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get( + track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get( title="The Bird") assert track.album.name == "Malibu" assert len(track.album.cover_pictures) == 2 assert track.album.cover_pictures[0].artist == 'Artist 1' assert len(track.album.shops) == 2 + assert track.album.shops[0].name == 'Shop 1' + assert track.album.shops[0].division.name == 'Div 1' @pytest.mark.asyncio