diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 22db178..3f31c0b 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -1,12 +1,15 @@ import inspect from collections import OrderedDict from typing import ( + Any, + Callable, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, + Tuple, Type, TypeVar, Union, @@ -38,6 +41,8 @@ class ModelTableProxy: Meta: ModelMeta _related_names: Set _related_names_hash: Union[str, bytes] + pk: Any + get_name: Callable def dict(self): # noqa A003 raise NotImplementedError # pragma no cover @@ -47,6 +52,66 @@ class ModelTableProxy: self_fields = {k: v for k, v in self.dict().items() if k not in related_names} return self_fields + @classmethod + def get_related_field_name(cls, target_field: Type["BaseField"]) -> str: + if issubclass(target_field, ormar.fields.ManyToManyField): + return cls.resolve_relation_name(target_field.through, cls) + if target_field.virtual: + return cls.resolve_relation_name(target_field.to, cls) + return target_field.to.Meta.pkname + + @staticmethod + def get_clause_target_and_filter_column_name( + parent_model: Type["Model"], target_model: Type["Model"], reverse: bool + ) -> Tuple[Type["Model"], str]: + if reverse: + field = target_model.resolve_relation_field(target_model, parent_model) + if issubclass(field, ormar.fields.ManyToManyField): + sub_field = target_model.resolve_relation_field( + field.through, parent_model + ) + return field.through, sub_field.get_alias() + else: + return target_model, field.get_alias() + target_field = target_model.get_column_alias(target_model.Meta.pkname) + return target_model, target_field + + @staticmethod + def get_column_name_for_id_extraction( + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + use_raw: bool, + ) -> str: + if reverse: + column_name = parent_model.Meta.pkname + return ( + parent_model.get_column_alias(column_name) if use_raw else column_name + ) + else: + column = target_model.resolve_relation_field(parent_model, target_model) + return column.get_alias() if use_raw else column.name + + @classmethod + def get_filtered_names_to_extract(cls, prefetch_dict: Dict) -> List: + related_to_extract = [] + if prefetch_dict and prefetch_dict is not Ellipsis: + related_to_extract = [ + related + for related in cls.extract_related_names() + if related in prefetch_dict + ] + return related_to_extract + + def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]: + if target_field.virtual or issubclass( + target_field, ormar.fields.ManyToManyField + ): + return self.pk + related_name = self.resolve_relation_name(self, target_field.to) + related_model = getattr(self, related_name) + return None if not related_model else related_model.pk + @classmethod def extract_db_own_fields(cls) -> Set: related_names = cls.extract_related_names() @@ -155,8 +220,18 @@ class ModelTableProxy: @staticmethod def resolve_relation_name( # noqa CCR001 - item: Union["NewBaseModel", Type["NewBaseModel"]], - related: Union["NewBaseModel", Type["NewBaseModel"]], + item: Union[ + "NewBaseModel", + Type["NewBaseModel"], + "ModelTableProxy", + Type["ModelTableProxy"], + ], + related: Union[ + "NewBaseModel", + Type["NewBaseModel"], + "ModelTableProxy", + Type["ModelTableProxy"], + ], ) -> str: for name, field in item.Meta.model_fields.items(): if issubclass(field, ForeignKeyField): diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index ed4de9d..7afef1a 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -21,6 +21,17 @@ if TYPE_CHECKING: # pragma: no cover from ormar import Model +def add_relation_field_to_fields( + fields: Union[Set[Any], Dict[Any, Any], None], related_field_name: str +) -> Union[Set[Any], Dict[Any, Any], None]: + if fields and related_field_name not in fields: + if isinstance(fields, dict): + fields[related_field_name] = ... + elif isinstance(fields, set): + fields.add(related_field_name) + return fields + + class PrefetchQuery: def __init__( self, @@ -50,22 +61,6 @@ class PrefetchQuery: self.models[self.model.get_name()] = models return await self._prefetch_related_models(models=models, rows=rows) - @staticmethod - def _get_column_name_for_id_extraction( - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool, - use_raw: bool, - ) -> str: - if reverse: - column_name = parent_model.Meta.pkname - return ( - parent_model.get_column_alias(column_name) if use_raw else column_name - ) - else: - column = target_model.resolve_relation_field(parent_model, target_model) - return column.get_alias() if use_raw else column.name - def _extract_ids_from_raw_data( self, parent_model: Type["Model"], column_name: str ) -> Set: @@ -96,7 +91,7 @@ class PrefetchQuery: use_raw = parent_model.get_name() not in self.models - column_name = self._get_column_name_for_id_extraction( + column_name = parent_model.get_column_name_for_id_extraction( parent_model=parent_model, target_model=target_model, reverse=reverse, @@ -112,22 +107,6 @@ class PrefetchQuery: parent_model=parent_model, column_name=column_name ) - @staticmethod - def _get_clause_target_and_filter_column_name( - parent_model: Type["Model"], target_model: Type["Model"], reverse: bool - ) -> Tuple[Type["Model"], str]: - if reverse: - field = target_model.resolve_relation_field(target_model, parent_model) - if issubclass(field, ManyToManyField): - sub_field = target_model.resolve_relation_field( - field.through, parent_model - ) - return field.through, sub_field.get_alias() - else: - return target_model, field.get_alias() - target_field = target_model.get_column_alias(target_model.Meta.pkname) - return target_model, target_field - def _get_filter_for_prefetch( self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, ) -> List: @@ -138,7 +117,7 @@ class PrefetchQuery: ( clause_target, filter_column, - ) = self._get_clause_target_and_filter_column_name( + ) = parent_model.get_clause_target_and_filter_column_name( parent_model=parent_model, target_model=target_model, reverse=reverse ) qryclause = QueryClause( @@ -149,52 +128,21 @@ class PrefetchQuery: return filter_clauses return [] - @staticmethod - def _get_model_id(target_field: Type["BaseField"], model: "Model") -> Optional[int]: - if target_field.virtual or issubclass(target_field, ManyToManyField): - return model.pk - related_name = model.resolve_relation_name(model, target_field.to) - related_model = getattr(model, related_name) - return None if not related_model else related_model.pk - - @staticmethod - def _get_related_field_name( - target_field: Type["BaseField"], model: Union["Model", Type["Model"]] - ) -> str: - if issubclass(target_field, ManyToManyField): - return model.resolve_relation_name(target_field.through, model) - if target_field.virtual: - return model.resolve_relation_name(target_field.to, model) - return target_field.to.Meta.pkname - - @staticmethod - def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List: - related_to_extract = [] - if prefetch_dict and prefetch_dict is not Ellipsis: - related_to_extract = [ - related - for related in model.extract_related_names() - if related in prefetch_dict - ] - return related_to_extract - def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model": - related_to_extract = self._get_names_to_extract( - prefetch_dict=prefetch_dict, model=model + related_to_extract = model.get_filtered_names_to_extract( + prefetch_dict=prefetch_dict ) for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - model_id = self._get_model_id(target_field=target_field, model=model) + model_id = model.get_relation_model_id(target_field=target_field) if model_id is None: # pragma: no cover continue - field_name = self._get_related_field_name( - target_field=target_field, model=model - ) + field_name = model.get_related_field_name(target_field=target_field) children = self.already_extracted.get(target_model, {}).get(field_name, {}) self._set_children_on_model( @@ -266,6 +214,12 @@ class PrefetchQuery: if not already_loaded: # If not already loaded with select_related + related_field_name = parent_model.get_related_field_name( + target_field=target_field + ) + fields = add_relation_field_to_fields( + fields=fields, related_field_name=related_field_name + ) table_prefix, rows = await self._run_prefetch_query( target_field=target_field, fields=fields, @@ -371,9 +325,7 @@ class PrefetchQuery: ) -> None: target_model = target_field.to for row in rows: - field_name = self._get_related_field_name( - target_field=target_field, model=parent_model - ) + field_name = parent_model.get_related_field_name(target_field=target_field) item = target_model.extract_prefixed_table_columns( item={}, row=row, diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index cd0fb2e..825bbd9 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -39,7 +39,7 @@ class Division(ormar.Model): database = database id: int = ormar.Integer(name='division_id', primary_key=True) - name: str = ormar.String(max_length=100) + name: str = ormar.String(max_length=100, nullable=True) class Shop(ormar.Model): @@ -49,7 +49,7 @@ class Shop(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(max_length=100) + name: str = ormar.String(max_length=100, nullable=True) division: Optional[Division] = ormar.ForeignKey(Division) @@ -67,7 +67,7 @@ class Album(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(max_length=100) + name: str = ormar.String(max_length=100, nullable=True) shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops) @@ -243,3 +243,50 @@ async def test_prefetch_related_with_select_related(): assert len(track.album.shops) == 2 assert track.album.shops[0].name == 'Shop 1' assert track.album.shops[0].division.name == 'Div 1' + + +@pytest.mark.asyncio +async def test_prefetch_related_with_select_related_and_fields(): + async with database: + async with database.transaction(force_rollback=True): + div = await Division.objects.create(name='Div 1') + shop1 = await Shop.objects.create(name='Shop 1', division=div) + shop2 = await Shop.objects.create(name='Shop 2', division=div) + album = Album(name="Malibu") + await album.save() + await album.shops.add(shop1) + await album.shops.add(shop2) + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + rand_set = await RandomSet.objects.create(name='Rand 1') + ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) + await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) + await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + + album = await Album.objects.select_related('tracks__tonation__rand_set').filter( + name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).exclude_fields({'shops': {'division': {'name'}}}).get() + assert len(album.tracks) == 3 + assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].artist == 'Artist 1' + + assert len(album.shops) == 2 + assert album.shops[0].name == 'Shop 1' + assert album.shops[0].division.name is None + + album = await Album.objects.select_related('tracks').filter( + name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).fields( + {'name': ..., 'shops': {'division'}, 'cover_pictures': {'id': ..., 'title': ...}} + ).exclude_fields({'shops': {'division': {'name'}}}).get() + assert len(album.tracks) == 3 + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].artist is None + assert album.cover_pictures[0].title is not None + + assert len(album.shops) == 2 + assert album.shops[0].name is None + assert album.shops[0].division is not None + assert album.shops[0].division.name is None