diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py new file mode 100644 index 0000000..0fca239 --- /dev/null +++ b/ormar/queryset/prefetch_query.py @@ -0,0 +1,195 @@ +from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union + +import ormar +from ormar.fields import ManyToManyField +from ormar.queryset.clause import QueryClause +from ormar.queryset.query import Query + +if TYPE_CHECKING: # pragma: no cover + from ormar import Model + + +class PrefetchQuery: + + def __init__(self, + model_cls: Type["Model"], + fields: Optional[Union[Dict, Set]], + exclude_fields: Optional[Union[Dict, Set]], + prefetch_related: List + ): + + self.model = model_cls + self.database = self.model.Meta.database + self._prefetch_related = prefetch_related + self._exclude_columns = exclude_fields + self._columns = fields + + @staticmethod + def _extract_required_ids(already_extracted: Dict, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool) -> Set: + raw_rows = already_extracted.get(parent_model.get_name(), {}).get('raw', []) + if reverse: + column_name = parent_model.get_column_alias(parent_model.Meta.pkname) + else: + column_name = target_model.resolve_relation_field(parent_model, target_model).get_alias() + list_of_ids = set() + for row in raw_rows: + if row[column_name]: + list_of_ids.add(row[column_name]) + return list_of_ids + + @staticmethod + def _get_filter_for_prefetch(already_extracted: Dict, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool) -> List: + ids = PrefetchQuery._extract_required_ids(already_extracted=already_extracted, + parent_model=parent_model, + target_model=target_model, + reverse=reverse) + if ids: + qryclause = QueryClause( + model_cls=target_model, + select_related=[], + filter_clauses=[], + ) + if reverse: + field = target_model.resolve_relation_field(target_model, parent_model) + kwargs = {f'{field.get_alias()}__in': ids} + else: + target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias() + kwargs = {f'{target_field}__in': ids} + filter_clauses, _ = qryclause.filter(**kwargs) + return filter_clauses + return [] + + @staticmethod + def _populate_nested_related(model: "Model", + already_extracted: Dict) -> "Model": + + for related in model.extract_related_names(): + reverse = False + + target_field = model.Meta.model_fields[related] + if target_field.virtual or issubclass(target_field, ManyToManyField): + reverse = True + + target_model = target_field.to.get_name() + if reverse: + field_name = model.resolve_relation_name(target_field.to, model) + model_id = model.pk + else: + related_name = model.resolve_relation_name(model, target_field.to) + related_model = getattr(model, related_name) + if not related_model: + continue + model_id = related_model.pk + field_name = target_field.to.Meta.pkname + + if target_model in already_extracted and already_extracted[target_model]['models']: + print('*****POPULATING RELATED:', target_model, field_name) + print(already_extracted[target_model]['models']) + for child_model in already_extracted[target_model]['models']: + related_model = getattr(child_model, field_name) + if isinstance(related_model, list): + for child in related_model: + if child.pk == model_id: + setattr(model, related, child) + + elif isinstance(related_model, ormar.Model): + if related_model.pk == model_id: + if reverse: + setattr(model, related, child_model) + else: + setattr(child_model, related, model) + + else: # we have not reverse relation and related_model is a pk value + setattr(model, related, child_model) + + return model + + async def prefetch_related(self, models: Sequence["Model"], rows: List): + return await self._prefetch_related_models(models=models, rows=rows) + + async def _prefetch_related_models(self, + models: Sequence["Model"], + rows: List) -> Sequence["Model"]: + already_extracted = {self.model.get_name(): {'raw': rows, 'models': models}} + for related in self._prefetch_related: + target_model = self.model + fields = self._columns + exclude_fields = self._exclude_columns + for part in related.split('__'): + fields = target_model.get_included(fields, part) + exclude_fields = target_model.get_excluded(exclude_fields, part) + + target_field = target_model.Meta.model_fields[part] + reverse = False + if target_field.virtual or issubclass(target_field, ManyToManyField): + reverse = True + + parent_model = target_model + target_model = target_field.to + + if target_model.get_name() not in already_extracted: + filter_clauses = self._get_filter_for_prefetch(already_extracted=already_extracted, + parent_model=parent_model, + target_model=target_model, + reverse=reverse) + if not filter_clauses: # related field is empty + continue + + qry = Query( + model_cls=target_model, + select_related=[], + filter_clauses=filter_clauses, + exclude_clauses=[], + offset=None, + limit_count=None, + fields=fields, + exclude_fields=exclude_fields, + order_bys=None, + ) + expr = qry.build_select_expression() + print(expr.compile(compile_kwargs={"literal_binds": True})) + rows = await self.database.fetch_all(expr) + already_extracted[target_model.get_name()] = {'raw': rows, 'models': []} + if part == related.split('__')[-1]: + for row in rows: + item = target_model.extract_prefixed_table_columns( + item={}, + row=row, + table_prefix='', + fields=fields, + exclude_fields=exclude_fields + ) + instance = target_model(**item) + already_extracted[target_model.get_name()]['models'].append(instance) + + target_model = self.model + fields = self._columns + exclude_fields = self._exclude_columns + for part in related.split('__')[:-1]: + fields = target_model.get_included(fields, part) + exclude_fields = target_model.get_excluded(exclude_fields, part) + target_model = target_model.Meta.model_fields[part].to + for row in already_extracted.get(target_model.get_name(), {}).get('raw', []): + item = target_model.extract_prefixed_table_columns( + item={}, + row=row, + table_prefix='', + fields=fields, + exclude_fields=exclude_fields + ) + instance = target_model(**item) + instance = self._populate_nested_related(model=instance, + already_extracted=already_extracted) + + already_extracted[target_model.get_name()]['models'].append(instance) + final_models = [] + for model in models: + final_models.append(self._populate_nested_related(model=model, + already_extracted=already_extracted)) + return models diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index b8b4aa3..6e54129 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -9,6 +9,7 @@ from ormar import MultipleMatches, NoMatch from ormar.exceptions import QueryDefinitionError from ormar.queryset import FilterQuery from ormar.queryset.clause import QueryClause +from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.query import Query from ormar.queryset.utils import update, update_dict_from_list @@ -20,21 +21,23 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( # noqa CFQ002 - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - exclude_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, - columns: Dict = None, - exclude_columns: Dict = None, - order_bys: List = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + exclude_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + columns: Dict = None, + exclude_columns: Dict = None, + order_bys: List = None, + prefetch_related: List = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses self._select_related = [] if select_related is None else select_related + self._prefetch_related = [] if prefetch_related is None else prefetch_related self.limit_count = limit_count self.query_offset = offset self._columns = columns or {} @@ -42,9 +45,9 @@ class QuerySet: self.order_bys = order_bys or [] def __get__( - self, - instance: Optional[Union["QuerySet", "QuerysetProxy"]], - owner: Union[Type["Model"], Type["QuerysetProxy"]], + self, + instance: Optional[Union["QuerySet", "QuerysetProxy"]], + owner: Union[Type["Model"], Type["QuerysetProxy"]], ) -> "QuerySet": if issubclass(owner, ormar.Model): return self.__class__(model_cls=owner) @@ -63,6 +66,13 @@ class QuerySet: raise ValueError("Model class of QuerySet is not initialized") return self.model_cls + async def _prefetch_related_models(self, models: Sequence["Model"], rows: List) -> Sequence["Model"]: + query = PrefetchQuery(model_cls=self.model_cls, + fields=self._columns, + exclude_fields=self._exclude_columns, + prefetch_related=self._prefetch_related) + return await query.prefetch_related(models=models, rows=rows) + def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: result_rows = [ self.model.from_row( @@ -88,7 +98,7 @@ class QuerySet: pkname = self.model_meta.pkname pk = self.model_meta.model_fields[pkname] if new_kwargs.get(pkname, ormar.Undefined) is None and ( - pk.nullable or pk.autoincrement + pk.nullable or pk.autoincrement ): del new_kwargs[pkname] return new_kwargs @@ -148,6 +158,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003 @@ -168,6 +179,25 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, + ) + + def prefetch_related(self, related: Union[List, str]) -> "QuerySet": + if not isinstance(related, list): + related = [related] + + related = list(set(list(self._select_related) + related)) + return self.__class__( + model_cls=self.model, + filter_clauses=self.filter_clauses, + exclude_clauses=self.exclude_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=self.query_offset, + columns=self._columns, + exclude_columns=self._exclude_columns, + order_bys=self.order_bys, + prefetch_related=related ) def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": @@ -190,6 +220,7 @@ class QuerySet: columns=self._columns, exclude_columns=current_excluded, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": @@ -212,6 +243,7 @@ class QuerySet: columns=current_included, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def order_by(self, columns: Union[List, str]) -> "QuerySet": @@ -229,6 +261,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=order_bys, + prefetch_related=self._prefetch_related, ) async def exists(self) -> bool: @@ -279,6 +312,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def offset(self, offset: int) -> "QuerySet": @@ -292,6 +326,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) async def first(self, **kwargs: Any) -> "Model": @@ -312,6 +347,8 @@ class QuerySet: rows = await self.database.fetch_all(expr) processed_rows = self._process_query_result_rows(rows) + if self._prefetch_related: + processed_rows = await self._prefetch_related_models(processed_rows, rows) self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore @@ -337,6 +374,8 @@ class QuerySet: expr = self.build_select_expression() rows = await self.database.fetch_all(expr) result_rows = self._process_query_result_rows(rows) + if self._prefetch_related: + result_rows = await self._prefetch_related_models(result_rows, rows) return result_rows @@ -359,9 +398,9 @@ class QuerySet: # refresh server side defaults if any( - field.server_default is not None - for name, field in self.model.Meta.model_fields.items() - if name not in kwargs + field.server_default is not None + for name, field in self.model.Meta.model_fields.items() + if name not in kwargs ): instance = await instance.load() instance.set_save_status(True) @@ -381,7 +420,7 @@ class QuerySet: objt.set_save_status(True) async def bulk_update( # noqa: CCR001 - self, objects: List["Model"], columns: List[str] = None + self, objects: List["Model"], columns: List[str] = None ) -> None: ready_objects = [] pk_name = self.model_meta.pkname diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index e7bf4e5..7015f7a 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -55,10 +55,6 @@ class Organisation(ormar.Model): ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"]) -class Organization(object): - pass - - class Team(ormar.Model): class Meta: tablename = "teams" @@ -239,8 +235,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -248,8 +244,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -294,8 +290,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -327,8 +323,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py new file mode 100644 index 0000000..2582587 --- /dev/null +++ b/tests/test_prefetch_related.py @@ -0,0 +1,125 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Tonation(ormar.Model): + class Meta: + tablename = "tonations" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Track(ormar.Model): + class Meta: + tablename = "tracks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + album: Optional[Album] = ormar.ForeignKey(Album) + title: str = ormar.String(max_length=100) + position: int = ormar.Integer() + tonation: Optional[Tonation] = ormar.ForeignKey(Tonation) + + +class Cover(ormar.Model): + class Meta: + tablename = "covers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures") + title: str = ormar.String(max_length=100) + artist: str = ormar.String(max_length=200, nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_prefetch_related(): + async with database: + async with database.transaction(force_rollback=True): + album = Album(name="Malibu") + await album.save() + ton1 = await Tonation.objects.create(name='B-mol') + await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) + await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + + fantasies = Album(name="Fantasies") + await fantasies.save() + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + await Cover.objects.create(title='Cover3', album=fantasies, artist='Artist 3') + await Cover.objects.create(title='Cover4', album=fantasies, artist='Artist 4') + + album = await Album.objects.filter(name='Malibu').prefetch_related( + ['tracks__tonation', 'cover_pictures']).get() + assert len(album.tracks) == 3 + assert album.tracks[0].title == 'The Bird' + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].title == 'Cover1' + assert album.tracks[0].tonation.name == album.tracks[2].tonation.name == 'B-mol' + + albums = await Album.objects.prefetch_related('tracks').all() + assert len(albums[0].tracks) == 3 + assert len(albums[1].tracks) == 3 + assert albums[0].tracks[0].title == "The Bird" + assert albums[1].tracks[0].title == "Help I'm Alive" + + track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist == 'Artist 1' + + track = await Track.objects.prefetch_related(["album__cover_pictures"]).exclude_fields( + 'album__cover_pictures__artist').get(title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist is None + + tracks = await Track.objects.prefetch_related("album").all() + assert len(tracks) == 6 + + +@pytest.mark.asyncio +async def test_prefetch_related_empty(): + async with database: + async with database.transaction(force_rollback=True): + await Track.objects.create(title="The Bird", position=1) + track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") + assert track.title == 'The Bird' + assert track.album is None