From b696156f5603a6274b915fcf72de4d98bef48f3e Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 23 Nov 2020 16:05:05 +0100 Subject: [PATCH 01/13] dirty prefetch_related working for FK and reverse FK --- ormar/queryset/prefetch_query.py | 195 +++++++++++++++++++++++++++++++ ormar/queryset/queryset.py | 75 +++++++++--- tests/test_foreign_keys.py | 20 ++-- tests/test_prefetch_related.py | 125 ++++++++++++++++++++ 4 files changed, 385 insertions(+), 30 deletions(-) create mode 100644 ormar/queryset/prefetch_query.py create mode 100644 tests/test_prefetch_related.py diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py new file mode 100644 index 0000000..0fca239 --- /dev/null +++ b/ormar/queryset/prefetch_query.py @@ -0,0 +1,195 @@ +from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union + +import ormar +from ormar.fields import ManyToManyField +from ormar.queryset.clause import QueryClause +from ormar.queryset.query import Query + +if TYPE_CHECKING: # pragma: no cover + from ormar import Model + + +class PrefetchQuery: + + def __init__(self, + model_cls: Type["Model"], + fields: Optional[Union[Dict, Set]], + exclude_fields: Optional[Union[Dict, Set]], + prefetch_related: List + ): + + self.model = model_cls + self.database = self.model.Meta.database + self._prefetch_related = prefetch_related + self._exclude_columns = exclude_fields + self._columns = fields + + @staticmethod + def _extract_required_ids(already_extracted: Dict, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool) -> Set: + raw_rows = already_extracted.get(parent_model.get_name(), {}).get('raw', []) + if reverse: + column_name = parent_model.get_column_alias(parent_model.Meta.pkname) + else: + column_name = target_model.resolve_relation_field(parent_model, target_model).get_alias() + list_of_ids = set() + for row in raw_rows: + if row[column_name]: + list_of_ids.add(row[column_name]) + return list_of_ids + + @staticmethod + def _get_filter_for_prefetch(already_extracted: Dict, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool) -> List: + ids = PrefetchQuery._extract_required_ids(already_extracted=already_extracted, + parent_model=parent_model, + target_model=target_model, + reverse=reverse) + if ids: + qryclause = QueryClause( + model_cls=target_model, + select_related=[], + filter_clauses=[], + ) + if reverse: + field = target_model.resolve_relation_field(target_model, parent_model) + kwargs = {f'{field.get_alias()}__in': ids} + else: + target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias() + kwargs = {f'{target_field}__in': ids} + filter_clauses, _ = qryclause.filter(**kwargs) + return filter_clauses + return [] + + @staticmethod + def _populate_nested_related(model: "Model", + already_extracted: Dict) -> "Model": + + for related in model.extract_related_names(): + reverse = False + + target_field = model.Meta.model_fields[related] + if target_field.virtual or issubclass(target_field, ManyToManyField): + reverse = True + + target_model = target_field.to.get_name() + if reverse: + field_name = model.resolve_relation_name(target_field.to, model) + model_id = model.pk + else: + related_name = model.resolve_relation_name(model, target_field.to) + related_model = getattr(model, related_name) + if not related_model: + continue + model_id = related_model.pk + field_name = target_field.to.Meta.pkname + + if target_model in already_extracted and already_extracted[target_model]['models']: + print('*****POPULATING RELATED:', target_model, field_name) + print(already_extracted[target_model]['models']) + for child_model in already_extracted[target_model]['models']: + related_model = getattr(child_model, field_name) + if isinstance(related_model, list): + for child in related_model: + if child.pk == model_id: + setattr(model, related, child) + + elif isinstance(related_model, ormar.Model): + if related_model.pk == model_id: + if reverse: + setattr(model, related, child_model) + else: + setattr(child_model, related, model) + + else: # we have not reverse relation and related_model is a pk value + setattr(model, related, child_model) + + return model + + async def prefetch_related(self, models: Sequence["Model"], rows: List): + return await self._prefetch_related_models(models=models, rows=rows) + + async def _prefetch_related_models(self, + models: Sequence["Model"], + rows: List) -> Sequence["Model"]: + already_extracted = {self.model.get_name(): {'raw': rows, 'models': models}} + for related in self._prefetch_related: + target_model = self.model + fields = self._columns + exclude_fields = self._exclude_columns + for part in related.split('__'): + fields = target_model.get_included(fields, part) + exclude_fields = target_model.get_excluded(exclude_fields, part) + + target_field = target_model.Meta.model_fields[part] + reverse = False + if target_field.virtual or issubclass(target_field, ManyToManyField): + reverse = True + + parent_model = target_model + target_model = target_field.to + + if target_model.get_name() not in already_extracted: + filter_clauses = self._get_filter_for_prefetch(already_extracted=already_extracted, + parent_model=parent_model, + target_model=target_model, + reverse=reverse) + if not filter_clauses: # related field is empty + continue + + qry = Query( + model_cls=target_model, + select_related=[], + filter_clauses=filter_clauses, + exclude_clauses=[], + offset=None, + limit_count=None, + fields=fields, + exclude_fields=exclude_fields, + order_bys=None, + ) + expr = qry.build_select_expression() + print(expr.compile(compile_kwargs={"literal_binds": True})) + rows = await self.database.fetch_all(expr) + already_extracted[target_model.get_name()] = {'raw': rows, 'models': []} + if part == related.split('__')[-1]: + for row in rows: + item = target_model.extract_prefixed_table_columns( + item={}, + row=row, + table_prefix='', + fields=fields, + exclude_fields=exclude_fields + ) + instance = target_model(**item) + already_extracted[target_model.get_name()]['models'].append(instance) + + target_model = self.model + fields = self._columns + exclude_fields = self._exclude_columns + for part in related.split('__')[:-1]: + fields = target_model.get_included(fields, part) + exclude_fields = target_model.get_excluded(exclude_fields, part) + target_model = target_model.Meta.model_fields[part].to + for row in already_extracted.get(target_model.get_name(), {}).get('raw', []): + item = target_model.extract_prefixed_table_columns( + item={}, + row=row, + table_prefix='', + fields=fields, + exclude_fields=exclude_fields + ) + instance = target_model(**item) + instance = self._populate_nested_related(model=instance, + already_extracted=already_extracted) + + already_extracted[target_model.get_name()]['models'].append(instance) + final_models = [] + for model in models: + final_models.append(self._populate_nested_related(model=model, + already_extracted=already_extracted)) + return models diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index b8b4aa3..6e54129 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -9,6 +9,7 @@ from ormar import MultipleMatches, NoMatch from ormar.exceptions import QueryDefinitionError from ormar.queryset import FilterQuery from ormar.queryset.clause import QueryClause +from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.query import Query from ormar.queryset.utils import update, update_dict_from_list @@ -20,21 +21,23 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( # noqa CFQ002 - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - exclude_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, - columns: Dict = None, - exclude_columns: Dict = None, - order_bys: List = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + exclude_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + columns: Dict = None, + exclude_columns: Dict = None, + order_bys: List = None, + prefetch_related: List = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses self._select_related = [] if select_related is None else select_related + self._prefetch_related = [] if prefetch_related is None else prefetch_related self.limit_count = limit_count self.query_offset = offset self._columns = columns or {} @@ -42,9 +45,9 @@ class QuerySet: self.order_bys = order_bys or [] def __get__( - self, - instance: Optional[Union["QuerySet", "QuerysetProxy"]], - owner: Union[Type["Model"], Type["QuerysetProxy"]], + self, + instance: Optional[Union["QuerySet", "QuerysetProxy"]], + owner: Union[Type["Model"], Type["QuerysetProxy"]], ) -> "QuerySet": if issubclass(owner, ormar.Model): return self.__class__(model_cls=owner) @@ -63,6 +66,13 @@ class QuerySet: raise ValueError("Model class of QuerySet is not initialized") return self.model_cls + async def _prefetch_related_models(self, models: Sequence["Model"], rows: List) -> Sequence["Model"]: + query = PrefetchQuery(model_cls=self.model_cls, + fields=self._columns, + exclude_fields=self._exclude_columns, + prefetch_related=self._prefetch_related) + return await query.prefetch_related(models=models, rows=rows) + def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: result_rows = [ self.model.from_row( @@ -88,7 +98,7 @@ class QuerySet: pkname = self.model_meta.pkname pk = self.model_meta.model_fields[pkname] if new_kwargs.get(pkname, ormar.Undefined) is None and ( - pk.nullable or pk.autoincrement + pk.nullable or pk.autoincrement ): del new_kwargs[pkname] return new_kwargs @@ -148,6 +158,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003 @@ -168,6 +179,25 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, + ) + + def prefetch_related(self, related: Union[List, str]) -> "QuerySet": + if not isinstance(related, list): + related = [related] + + related = list(set(list(self._select_related) + related)) + return self.__class__( + model_cls=self.model, + filter_clauses=self.filter_clauses, + exclude_clauses=self.exclude_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=self.query_offset, + columns=self._columns, + exclude_columns=self._exclude_columns, + order_bys=self.order_bys, + prefetch_related=related ) def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": @@ -190,6 +220,7 @@ class QuerySet: columns=self._columns, exclude_columns=current_excluded, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": @@ -212,6 +243,7 @@ class QuerySet: columns=current_included, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def order_by(self, columns: Union[List, str]) -> "QuerySet": @@ -229,6 +261,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=order_bys, + prefetch_related=self._prefetch_related, ) async def exists(self) -> bool: @@ -279,6 +312,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def offset(self, offset: int) -> "QuerySet": @@ -292,6 +326,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) async def first(self, **kwargs: Any) -> "Model": @@ -312,6 +347,8 @@ class QuerySet: rows = await self.database.fetch_all(expr) processed_rows = self._process_query_result_rows(rows) + if self._prefetch_related: + processed_rows = await self._prefetch_related_models(processed_rows, rows) self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore @@ -337,6 +374,8 @@ class QuerySet: expr = self.build_select_expression() rows = await self.database.fetch_all(expr) result_rows = self._process_query_result_rows(rows) + if self._prefetch_related: + result_rows = await self._prefetch_related_models(result_rows, rows) return result_rows @@ -359,9 +398,9 @@ class QuerySet: # refresh server side defaults if any( - field.server_default is not None - for name, field in self.model.Meta.model_fields.items() - if name not in kwargs + field.server_default is not None + for name, field in self.model.Meta.model_fields.items() + if name not in kwargs ): instance = await instance.load() instance.set_save_status(True) @@ -381,7 +420,7 @@ class QuerySet: objt.set_save_status(True) async def bulk_update( # noqa: CCR001 - self, objects: List["Model"], columns: List[str] = None + self, objects: List["Model"], columns: List[str] = None ) -> None: ready_objects = [] pk_name = self.model_meta.pkname diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index e7bf4e5..7015f7a 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -55,10 +55,6 @@ class Organisation(ormar.Model): ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"]) -class Organization(object): - pass - - class Team(ormar.Model): class Meta: tablename = "teams" @@ -239,8 +235,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -248,8 +244,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -294,8 +290,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -327,8 +323,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py new file mode 100644 index 0000000..2582587 --- /dev/null +++ b/tests/test_prefetch_related.py @@ -0,0 +1,125 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Tonation(ormar.Model): + class Meta: + tablename = "tonations" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Track(ormar.Model): + class Meta: + tablename = "tracks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + album: Optional[Album] = ormar.ForeignKey(Album) + title: str = ormar.String(max_length=100) + position: int = ormar.Integer() + tonation: Optional[Tonation] = ormar.ForeignKey(Tonation) + + +class Cover(ormar.Model): + class Meta: + tablename = "covers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures") + title: str = ormar.String(max_length=100) + artist: str = ormar.String(max_length=200, nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_prefetch_related(): + async with database: + async with database.transaction(force_rollback=True): + album = Album(name="Malibu") + await album.save() + ton1 = await Tonation.objects.create(name='B-mol') + await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) + await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + + fantasies = Album(name="Fantasies") + await fantasies.save() + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + await Cover.objects.create(title='Cover3', album=fantasies, artist='Artist 3') + await Cover.objects.create(title='Cover4', album=fantasies, artist='Artist 4') + + album = await Album.objects.filter(name='Malibu').prefetch_related( + ['tracks__tonation', 'cover_pictures']).get() + assert len(album.tracks) == 3 + assert album.tracks[0].title == 'The Bird' + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].title == 'Cover1' + assert album.tracks[0].tonation.name == album.tracks[2].tonation.name == 'B-mol' + + albums = await Album.objects.prefetch_related('tracks').all() + assert len(albums[0].tracks) == 3 + assert len(albums[1].tracks) == 3 + assert albums[0].tracks[0].title == "The Bird" + assert albums[1].tracks[0].title == "Help I'm Alive" + + track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist == 'Artist 1' + + track = await Track.objects.prefetch_related(["album__cover_pictures"]).exclude_fields( + 'album__cover_pictures__artist').get(title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist is None + + tracks = await Track.objects.prefetch_related("album").all() + assert len(tracks) == 6 + + +@pytest.mark.asyncio +async def test_prefetch_related_empty(): + async with database: + async with database.transaction(force_rollback=True): + await Track.objects.create(title="The Bird", position=1) + track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") + assert track.title == 'The Bird' + assert track.album is None From 585bba3ad375ef3487cc2c6f9791bb810e09823c Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 23 Nov 2020 17:03:31 +0100 Subject: [PATCH 02/13] dirty many to many pass first test --- ormar/queryset/prefetch_query.py | 71 +++++++++++++++++++++++--------- tests/test_prefetch_related.py | 46 ++++++++++++++++++++- 2 files changed, 97 insertions(+), 20 deletions(-) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 0fca239..6cc9ec3 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -57,7 +57,17 @@ class PrefetchQuery: ) if reverse: field = target_model.resolve_relation_field(target_model, parent_model) - kwargs = {f'{field.get_alias()}__in': ids} + if issubclass(field, ManyToManyField): + sub_field = target_model.resolve_relation_field(field.through, parent_model) + kwargs = {f'{sub_field.get_alias()}__in': ids} + qryclause = QueryClause( + model_cls=field.through, + select_related=[], + filter_clauses=[], + ) + + else: + kwargs = {f'{field.get_alias()}__in': ids} else: target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias() kwargs = {f'{target_field}__in': ids} @@ -73,13 +83,15 @@ class PrefetchQuery: reverse = False target_field = model.Meta.model_fields[related] - if target_field.virtual or issubclass(target_field, ManyToManyField): - reverse = True - target_model = target_field.to.get_name() - if reverse: + if target_field.virtual: + reverse = True field_name = model.resolve_relation_name(target_field.to, model) model_id = model.pk + elif issubclass(target_field, ManyToManyField): + reverse = True + field_name = model.resolve_relation_name(target_field.through, model) + model_id = model.pk else: related_name = model.resolve_relation_name(model, target_field.to) related_model = getattr(model, related_name) @@ -89,17 +101,16 @@ class PrefetchQuery: field_name = target_field.to.Meta.pkname if target_model in already_extracted and already_extracted[target_model]['models']: - print('*****POPULATING RELATED:', target_model, field_name) + print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n') print(already_extracted[target_model]['models']) - for child_model in already_extracted[target_model]['models']: - related_model = getattr(child_model, field_name) - if isinstance(related_model, list): - for child in related_model: - if child.pk == model_id: - setattr(model, related, child) + for ind, child_model in enumerate(already_extracted[target_model]['models']): + if issubclass(target_field, ManyToManyField): + raw_data = already_extracted[target_model]['raw'][ind] + if raw_data[field_name] == model_id: + setattr(model, related, child_model) - elif isinstance(related_model, ormar.Model): - if related_model.pk == model_id: + elif isinstance(getattr(child_model, field_name), ormar.Model): + if getattr(child_model, field_name).pk == model_id: if reverse: setattr(model, related, child_model) else: @@ -123,6 +134,7 @@ class PrefetchQuery: exclude_fields = self._exclude_columns for part in related.split('__'): fields = target_model.get_included(fields, part) + select_related = [] exclude_fields = target_model.get_excluded(exclude_fields, part) target_field = target_model.Meta.model_fields[part] @@ -130,6 +142,9 @@ class PrefetchQuery: if target_field.virtual or issubclass(target_field, ManyToManyField): reverse = True + if issubclass(target_field, ManyToManyField): + select_related = [target_field.through.get_name()] + parent_model = target_model target_model = target_field.to @@ -141,9 +156,17 @@ class PrefetchQuery: if not filter_clauses: # related field is empty continue + query_target = target_model + table_prefix = '' + if issubclass(target_field, ManyToManyField): + query_target = target_field.through + select_related = [target_field.to.get_name()] + table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( + from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename) + qry = Query( - model_cls=target_model, - select_related=[], + model_cls=query_target, + select_related=select_related, filter_clauses=filter_clauses, exclude_clauses=[], offset=None, @@ -161,7 +184,7 @@ class PrefetchQuery: item = target_model.extract_prefixed_table_columns( item={}, row=row, - table_prefix='', + table_prefix=table_prefix, fields=fields, exclude_fields=exclude_fields ) @@ -174,12 +197,22 @@ class PrefetchQuery: for part in related.split('__')[:-1]: fields = target_model.get_included(fields, part) exclude_fields = target_model.get_excluded(exclude_fields, part) - target_model = target_model.Meta.model_fields[part].to + + target_field = target_model.Meta.model_fields[part] + target_model = target_field.to + table_prefix = '' + + if issubclass(target_field, ManyToManyField): + from_table = target_field.through.Meta.tablename + to_name = target_field.to.Meta.tablename + table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( + from_table=from_table, to_table=to_name) + for row in already_extracted.get(target_model.get_name(), {}).get('raw', []): item = target_model.extract_prefixed_table_columns( item={}, row=row, - table_prefix='', + table_prefix=table_prefix, fields=fields, exclude_fields=exclude_fields ) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 2582587..3b098c4 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import List, Optional import databases import pytest @@ -21,6 +21,23 @@ class Tonation(ormar.Model): name: str = ormar.String(max_length=100) +class Shop(ormar.Model): + class Meta: + tablename = "shops" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class AlbumShops(ormar.Model): + class Meta: + tablename = "albums_x_shops" + metadata = metadata + database = database + + class Album(ormar.Model): class Meta: tablename = "albums" @@ -29,6 +46,7 @@ class Album(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) + shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops) class Track(ormar.Model): @@ -115,6 +133,32 @@ async def test_prefetch_related(): assert len(tracks) == 6 +@pytest.mark.asyncio +async def test_prefetch_related_with_many_to_many(): + async with database: + async with database.transaction(force_rollback=True): + shop1 = await Shop.objects.create(name='Shop 1') + shop2 = await Shop.objects.create(name='Shop 2') + album = Album(name="Malibu") + await album.save() + await album.shops.add(shop1) + await album.shops.add(shop2) + + await Track.objects.create(album=album, title="The Bird", position=1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2) + await Track.objects.create(album=album, title="The Waters", position=3) + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + + track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get( + title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist == 'Artist 1' + + assert len(track.album.shops) == 2 + + @pytest.mark.asyncio async def test_prefetch_related_empty(): async with database: From f2fe41d38abb2740740baca43987efe52a1d40e2 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 25 Nov 2020 13:28:51 +0100 Subject: [PATCH 03/13] cleaner version but still dirty --- ormar/queryset/prefetch_query.py | 417 +++++++++++++++++++------------ ormar/queryset/queryset.py | 55 ++-- tests/test_prefetch_related.py | 20 +- 3 files changed, 300 insertions(+), 192 deletions(-) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 6cc9ec3..8b7ade2 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -1,228 +1,317 @@ -from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union +from typing import ( + Dict, + List, + Optional, + Sequence, + Set, + TYPE_CHECKING, + Tuple, + Type, + Union, +) import ormar -from ormar.fields import ManyToManyField +from ormar.fields import BaseField, ManyToManyField from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query +from ormar.queryset.utils import translate_list_to_dict if TYPE_CHECKING: # pragma: no cover from ormar import Model class PrefetchQuery: - - def __init__(self, - model_cls: Type["Model"], - fields: Optional[Union[Dict, Set]], - exclude_fields: Optional[Union[Dict, Set]], - prefetch_related: List - ): + def __init__( + self, + model_cls: Type["Model"], + fields: Optional[Union[Dict, Set]], + exclude_fields: Optional[Union[Dict, Set]], + prefetch_related: List, + select_related: List, + ) -> None: self.model = model_cls self.database = self.model.Meta.database self._prefetch_related = prefetch_related + self._select_related = select_related self._exclude_columns = exclude_fields self._columns = fields @staticmethod - def _extract_required_ids(already_extracted: Dict, - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool) -> Set: - raw_rows = already_extracted.get(parent_model.get_name(), {}).get('raw', []) + def _extract_required_ids( + already_extracted: Dict, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + ) -> Set: + raw_rows = already_extracted.get(parent_model.get_name(), {}).get("raw", []) + table_prefix = already_extracted.get(parent_model.get_name(), {}).get( + "prefix", "" + ) if reverse: column_name = parent_model.get_column_alias(parent_model.Meta.pkname) else: - column_name = target_model.resolve_relation_field(parent_model, target_model).get_alias() + column_name = target_model.resolve_relation_field( + parent_model, target_model + ).get_alias() list_of_ids = set() + column_name = (f"{table_prefix}_" if table_prefix else "") + column_name for row in raw_rows: if row[column_name]: list_of_ids.add(row[column_name]) return list_of_ids @staticmethod - def _get_filter_for_prefetch(already_extracted: Dict, - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool) -> List: - ids = PrefetchQuery._extract_required_ids(already_extracted=already_extracted, - parent_model=parent_model, - target_model=target_model, - reverse=reverse) + def _get_filter_for_prefetch( + already_extracted: Dict, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + ) -> List: + ids = PrefetchQuery._extract_required_ids( + already_extracted=already_extracted, + parent_model=parent_model, + target_model=target_model, + reverse=reverse, + ) if ids: qryclause = QueryClause( - model_cls=target_model, - select_related=[], - filter_clauses=[], + model_cls=target_model, select_related=[], filter_clauses=[], ) if reverse: field = target_model.resolve_relation_field(target_model, parent_model) if issubclass(field, ManyToManyField): - sub_field = target_model.resolve_relation_field(field.through, parent_model) - kwargs = {f'{sub_field.get_alias()}__in': ids} + sub_field = target_model.resolve_relation_field( + field.through, parent_model + ) + kwargs = {f"{sub_field.get_alias()}__in": ids} qryclause = QueryClause( - model_cls=field.through, - select_related=[], - filter_clauses=[], + model_cls=field.through, select_related=[], filter_clauses=[], ) else: - kwargs = {f'{field.get_alias()}__in': ids} + kwargs = {f"{field.get_alias()}__in": ids} else: - target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias() - kwargs = {f'{target_field}__in': ids} + target_field = target_model.Meta.model_fields[ + target_model.Meta.pkname + ].get_alias() + kwargs = {f"{target_field}__in": ids} filter_clauses, _ = qryclause.filter(**kwargs) return filter_clauses return [] @staticmethod - def _populate_nested_related(model: "Model", - already_extracted: Dict) -> "Model": - - for related in model.extract_related_names(): + def _get_model_id_and_field_name( + target_field: Type["BaseField"], model: "Model" + ) -> Tuple[bool, Optional[str], Optional[int]]: + if target_field.virtual: + reverse = True + field_name = model.resolve_relation_name(target_field.to, model) + model_id = model.pk + elif issubclass(target_field, ManyToManyField): + reverse = True + field_name = model.resolve_relation_name(target_field.through, model) + model_id = model.pk + else: reverse = False + related_name = model.resolve_relation_name(model, target_field.to) + related_model = getattr(model, related_name) + if not related_model: + return reverse, None, None + model_id = related_model.pk + field_name = target_field.to.Meta.pkname + return reverse, field_name, model_id + + @staticmethod + def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List: + related_to_extract = [] + if prefetch_dict and prefetch_dict is not Ellipsis: + related_to_extract = [ + related + for related in model.extract_related_names() + if related in prefetch_dict + ] + return related_to_extract + + @staticmethod + def _populate_nested_related( + model: "Model", already_extracted: Dict, prefetch_dict: Dict + ) -> "Model": + + related_to_extract = PrefetchQuery._get_names_to_extract( + prefetch_dict=prefetch_dict, model=model + ) + + for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - if target_field.virtual: - reverse = True - field_name = model.resolve_relation_name(target_field.to, model) - model_id = model.pk - elif issubclass(target_field, ManyToManyField): - reverse = True - field_name = model.resolve_relation_name(target_field.through, model) - model_id = model.pk - else: - related_name = model.resolve_relation_name(model, target_field.to) - related_model = getattr(model, related_name) - if not related_model: - continue - model_id = related_model.pk - field_name = target_field.to.Meta.pkname + reverse, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( + target_field=target_field, model=model + ) - if target_model in already_extracted and already_extracted[target_model]['models']: - print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n') - print(already_extracted[target_model]['models']) - for ind, child_model in enumerate(already_extracted[target_model]['models']): + if ( + target_model in already_extracted + and already_extracted[target_model]["models"] + ): + for key, child_model in already_extracted[target_model][ + "models" + ].items(): if issubclass(target_field, ManyToManyField): - raw_data = already_extracted[target_model]['raw'][ind] - if raw_data[field_name] == model_id: + ind = next( + i + if key == x[target_field.to.get_column_alias(field_name)] + else -1 + for i, x in enumerate( + already_extracted[target_model]["raw"] + ) + ) + raw_data = already_extracted[target_model]["raw"][ind] + if ( + raw_data + and field_name in raw_data + and raw_data[field_name] == model_id + ): setattr(model, related, child_model) elif isinstance(getattr(child_model, field_name), ormar.Model): if getattr(child_model, field_name).pk == model_id: - if reverse: - setattr(model, related, child_model) - else: - setattr(child_model, related, model) + setattr(model, related, child_model) - else: # we have not reverse relation and related_model is a pk value + elif getattr(child_model, field_name) == model_id: setattr(model, related, child_model) return model - async def prefetch_related(self, models: Sequence["Model"], rows: List): + async def prefetch_related( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: return await self._prefetch_related_models(models=models, rows=rows) - async def _prefetch_related_models(self, - models: Sequence["Model"], - rows: List) -> Sequence["Model"]: - already_extracted = {self.model.get_name(): {'raw': rows, 'models': models}} - for related in self._prefetch_related: - target_model = self.model - fields = self._columns - exclude_fields = self._exclude_columns - for part in related.split('__'): - fields = target_model.get_included(fields, part) - select_related = [] - exclude_fields = target_model.get_excluded(exclude_fields, part) - - target_field = target_model.Meta.model_fields[part] - reverse = False - if target_field.virtual or issubclass(target_field, ManyToManyField): - reverse = True - - if issubclass(target_field, ManyToManyField): - select_related = [target_field.through.get_name()] - - parent_model = target_model - target_model = target_field.to - - if target_model.get_name() not in already_extracted: - filter_clauses = self._get_filter_for_prefetch(already_extracted=already_extracted, - parent_model=parent_model, - target_model=target_model, - reverse=reverse) - if not filter_clauses: # related field is empty - continue - - query_target = target_model - table_prefix = '' - if issubclass(target_field, ManyToManyField): - query_target = target_field.through - select_related = [target_field.to.get_name()] - table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( - from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename) - - qry = Query( - model_cls=query_target, - select_related=select_related, - filter_clauses=filter_clauses, - exclude_clauses=[], - offset=None, - limit_count=None, - fields=fields, - exclude_fields=exclude_fields, - order_bys=None, - ) - expr = qry.build_select_expression() - print(expr.compile(compile_kwargs={"literal_binds": True})) - rows = await self.database.fetch_all(expr) - already_extracted[target_model.get_name()] = {'raw': rows, 'models': []} - if part == related.split('__')[-1]: - for row in rows: - item = target_model.extract_prefixed_table_columns( - item={}, - row=row, - table_prefix=table_prefix, - fields=fields, - exclude_fields=exclude_fields - ) - instance = target_model(**item) - already_extracted[target_model.get_name()]['models'].append(instance) - - target_model = self.model - fields = self._columns - exclude_fields = self._exclude_columns - for part in related.split('__')[:-1]: - fields = target_model.get_included(fields, part) - exclude_fields = target_model.get_excluded(exclude_fields, part) - - target_field = target_model.Meta.model_fields[part] - target_model = target_field.to - table_prefix = '' - - if issubclass(target_field, ManyToManyField): - from_table = target_field.through.Meta.tablename - to_name = target_field.to.Meta.tablename - table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( - from_table=from_table, to_table=to_name) - - for row in already_extracted.get(target_model.get_name(), {}).get('raw', []): - item = target_model.extract_prefixed_table_columns( - item={}, - row=row, - table_prefix=table_prefix, - fields=fields, - exclude_fields=exclude_fields - ) - instance = target_model(**item) - instance = self._populate_nested_related(model=instance, - already_extracted=already_extracted) - - already_extracted[target_model.get_name()]['models'].append(instance) + async def _prefetch_related_models( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: + already_extracted = { + self.model.get_name(): { + "raw": rows, + "models": {model.pk: model for model in models}, + } + } + select_dict = translate_list_to_dict(self._select_related) + prefetch_dict = translate_list_to_dict(self._prefetch_related) + target_model = self.model + fields = self._columns + exclude_fields = self._exclude_columns + for related in prefetch_dict.keys(): + await self._extract_related_models( + related=related, + target_model=target_model, + prefetch_dict=prefetch_dict.get(related), + select_dict=select_dict.get(related), + already_extracted=already_extracted, + fields=fields, + exclude_fields=exclude_fields, + ) final_models = [] for model in models: - final_models.append(self._populate_nested_related(model=model, - already_extracted=already_extracted)) + final_models.append( + self._populate_nested_related( + model=model, + already_extracted=already_extracted, + prefetch_dict=prefetch_dict, + ) + ) return models + + async def _extract_related_models( # noqa: CFQ002 + self, + related: str, + target_model: Type["Model"], + prefetch_dict: Dict, + select_dict: Dict, + already_extracted: Dict, + fields: Dict, + exclude_fields: Dict, + ) -> None: + + fields = target_model.get_included(fields, related) + exclude_fields = target_model.get_excluded(exclude_fields, related) + select_related = [] + + target_field = target_model.Meta.model_fields[related] + reverse = False + if target_field.virtual or issubclass(target_field, ManyToManyField): + reverse = True + + parent_model = target_model + target_model = target_field.to + + filter_clauses = PrefetchQuery._get_filter_for_prefetch( + already_extracted=already_extracted, + parent_model=parent_model, + target_model=target_model, + reverse=reverse, + ) + if not filter_clauses: # related field is empty + return + + query_target = target_model + table_prefix = "" + if issubclass(target_field, ManyToManyField): + query_target = target_field.through + select_related = [target_field.to.get_name()] + table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( + from_table=query_target.Meta.tablename, + to_table=target_field.to.Meta.tablename, + ) + already_extracted.setdefault(target_model.get_name(), {})[ + "prefix" + ] = table_prefix + + qry = Query( + model_cls=query_target, + select_related=select_related, + filter_clauses=filter_clauses, + exclude_clauses=[], + offset=None, + limit_count=None, + fields=fields, + exclude_fields=exclude_fields, + order_bys=None, + ) + expr = qry.build_select_expression() + # print(expr.compile(compile_kwargs={"literal_binds": True})) + rows = await self.database.fetch_all(expr) + already_extracted.setdefault(target_model.get_name(), {}).update( + {"raw": rows, "models": {}} + ) + + if prefetch_dict and prefetch_dict is not Ellipsis: + for subrelated in prefetch_dict.keys(): + await self._extract_related_models( + related=subrelated, + target_model=target_model, + prefetch_dict=prefetch_dict.get(subrelated), + select_dict=select_dict.get(subrelated) + if (select_dict and subrelated in select_dict) + else {}, + already_extracted=already_extracted, + fields=fields, + exclude_fields=exclude_fields, + ) + + for row in rows: + item = target_model.extract_prefixed_table_columns( + item={}, + row=row, + table_prefix=table_prefix, + fields=fields, + exclude_fields=exclude_fields, + ) + instance = target_model(**item) + instance = self._populate_nested_related( + model=instance, + already_extracted=already_extracted, + prefetch_dict=prefetch_dict, + ) + already_extracted[target_model.get_name()]["models"][instance.pk] = instance diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 6e54129..ce902f7 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -21,17 +21,17 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( # noqa CFQ002 - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - exclude_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, - columns: Dict = None, - exclude_columns: Dict = None, - order_bys: List = None, - prefetch_related: List = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + exclude_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + columns: Dict = None, + exclude_columns: Dict = None, + order_bys: List = None, + prefetch_related: List = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses @@ -45,9 +45,9 @@ class QuerySet: self.order_bys = order_bys or [] def __get__( - self, - instance: Optional[Union["QuerySet", "QuerysetProxy"]], - owner: Union[Type["Model"], Type["QuerysetProxy"]], + self, + instance: Optional[Union["QuerySet", "QuerysetProxy"]], + owner: Union[Type["Model"], Type["QuerysetProxy"]], ) -> "QuerySet": if issubclass(owner, ormar.Model): return self.__class__(model_cls=owner) @@ -66,11 +66,16 @@ class QuerySet: raise ValueError("Model class of QuerySet is not initialized") return self.model_cls - async def _prefetch_related_models(self, models: Sequence["Model"], rows: List) -> Sequence["Model"]: - query = PrefetchQuery(model_cls=self.model_cls, - fields=self._columns, - exclude_fields=self._exclude_columns, - prefetch_related=self._prefetch_related) + async def _prefetch_related_models( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: + query = PrefetchQuery( + model_cls=self.model_cls, + fields=self._columns, + exclude_fields=self._exclude_columns, + prefetch_related=self._prefetch_related, + select_related=self._select_related, + ) return await query.prefetch_related(models=models, rows=rows) def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: @@ -98,7 +103,7 @@ class QuerySet: pkname = self.model_meta.pkname pk = self.model_meta.model_fields[pkname] if new_kwargs.get(pkname, ormar.Undefined) is None and ( - pk.nullable or pk.autoincrement + pk.nullable or pk.autoincrement ): del new_kwargs[pkname] return new_kwargs @@ -197,7 +202,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, - prefetch_related=related + prefetch_related=related, ) def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": @@ -398,9 +403,9 @@ class QuerySet: # refresh server side defaults if any( - field.server_default is not None - for name, field in self.model.Meta.model_fields.items() - if name not in kwargs + field.server_default is not None + for name, field in self.model.Meta.model_fields.items() + if name not in kwargs ): instance = await instance.load() instance.set_save_status(True) @@ -420,7 +425,7 @@ class QuerySet: objt.set_save_status(True) async def bulk_update( # noqa: CCR001 - self, objects: List["Model"], columns: List[str] = None + self, objects: List["Model"], columns: List[str] = None ) -> None: ready_objects = [] pk_name = self.model_meta.pkname diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 3b098c4..a6b8831 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -21,6 +21,16 @@ class Tonation(ormar.Model): name: str = ormar.String(max_length=100) +class Division(ormar.Model): + class Meta: + tablename = "divisions" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + class Shop(ormar.Model): class Meta: tablename = "shops" @@ -29,6 +39,7 @@ class Shop(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) + division: Optional[Division] = ormar.ForeignKey(Division) class AlbumShops(ormar.Model): @@ -137,8 +148,9 @@ async def test_prefetch_related(): async def test_prefetch_related_with_many_to_many(): async with database: async with database.transaction(force_rollback=True): - shop1 = await Shop.objects.create(name='Shop 1') - shop2 = await Shop.objects.create(name='Shop 2') + div = await Division.objects.create(name='Div 1') + shop1 = await Shop.objects.create(name='Shop 1', division=div) + shop2 = await Shop.objects.create(name='Shop 2', division=div) album = Album(name="Malibu") await album.save() await album.shops.add(shop1) @@ -150,13 +162,15 @@ async def test_prefetch_related_with_many_to_many(): await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') - track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get( + track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get( title="The Bird") assert track.album.name == "Malibu" assert len(track.album.cover_pictures) == 2 assert track.album.cover_pictures[0].artist == 'Artist 1' assert len(track.album.shops) == 2 + assert track.album.shops[0].name == 'Shop 1' + assert track.album.shops[0].division.name == 'Div 1' @pytest.mark.asyncio From e0223f8a2251df8a92c0559f4934c5a7356cc15e Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 25 Nov 2020 16:56:54 +0100 Subject: [PATCH 04/13] cleanup by related field --- ormar/queryset/prefetch_query.py | 82 +++++++++++++++----------------- 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 8b7ade2..39ef64e 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -10,7 +10,6 @@ from typing import ( Union, ) -import ormar from ormar.fields import BaseField, ManyToManyField from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query @@ -44,10 +43,9 @@ class PrefetchQuery: target_model: Type["Model"], reverse: bool, ) -> Set: - raw_rows = already_extracted.get(parent_model.get_name(), {}).get("raw", []) - table_prefix = already_extracted.get(parent_model.get_name(), {}).get( - "prefix", "" - ) + current_data = already_extracted.get(parent_model.get_name(), {}) + raw_rows = current_data.get("raw", []) + table_prefix = current_data.get("prefix", "") if reverse: column_name = parent_model.get_column_alias(parent_model.Meta.pkname) else: @@ -105,23 +103,33 @@ class PrefetchQuery: target_field: Type["BaseField"], model: "Model" ) -> Tuple[bool, Optional[str], Optional[int]]: if target_field.virtual: - reverse = True + is_multi = False field_name = model.resolve_relation_name(target_field.to, model) model_id = model.pk elif issubclass(target_field, ManyToManyField): - reverse = True + is_multi = True field_name = model.resolve_relation_name(target_field.through, model) model_id = model.pk else: - reverse = False + is_multi = False related_name = model.resolve_relation_name(model, target_field.to) related_model = getattr(model, related_name) if not related_model: - return reverse, None, None + return is_multi, None, None model_id = related_model.pk field_name = target_field.to.Meta.pkname - return reverse, field_name, model_id + return is_multi, field_name, model_id + + @staticmethod + def _get_group_field_name( + target_field: Type["BaseField"], model: Type["Model"] + ) -> str: + if issubclass(target_field, ManyToManyField): + return model.resolve_relation_name(target_field.through, model) + if target_field.virtual: + return model.resolve_relation_name(target_field.to, model) + return target_field.to.Meta.pkname @staticmethod def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List: @@ -146,40 +154,17 @@ class PrefetchQuery: for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - reverse, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( + is_multi, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( target_field=target_field, model=model ) + if not field_name: + continue - if ( - target_model in already_extracted - and already_extracted[target_model]["models"] - ): - for key, child_model in already_extracted[target_model][ - "models" - ].items(): - if issubclass(target_field, ManyToManyField): - ind = next( - i - if key == x[target_field.to.get_column_alias(field_name)] - else -1 - for i, x in enumerate( - already_extracted[target_model]["raw"] - ) - ) - raw_data = already_extracted[target_model]["raw"][ind] - if ( - raw_data - and field_name in raw_data - and raw_data[field_name] == model_id - ): - setattr(model, related, child_model) - - elif isinstance(getattr(child_model, field_name), ormar.Model): - if getattr(child_model, field_name).pk == model_id: - setattr(model, related, child_model) - - elif getattr(child_model, field_name) == model_id: - setattr(model, related, child_model) + children = already_extracted.get(target_model, {}).get(field_name, {}) + for key, child_models in children.items(): + if key == model_id: + for child in child_models: + setattr(model, related, child) return model @@ -203,7 +188,7 @@ class PrefetchQuery: fields = self._columns exclude_fields = self._exclude_columns for related in prefetch_dict.keys(): - await self._extract_related_models( + subrelated = await self._extract_related_models( related=related, target_model=target_model, prefetch_dict=prefetch_dict.get(related), @@ -212,6 +197,7 @@ class PrefetchQuery: fields=fields, exclude_fields=exclude_fields, ) + print(related, subrelated) final_models = [] for model in models: final_models.append( @@ -288,7 +274,7 @@ class PrefetchQuery: if prefetch_dict and prefetch_dict is not Ellipsis: for subrelated in prefetch_dict.keys(): - await self._extract_related_models( + submodels = await self._extract_related_models( related=subrelated, target_model=target_model, prefetch_dict=prefetch_dict.get(subrelated), @@ -299,8 +285,13 @@ class PrefetchQuery: fields=fields, exclude_fields=exclude_fields, ) + print(subrelated, submodels) for row in rows: + field_name = PrefetchQuery._get_group_field_name( + target_field=target_field, model=parent_model + ) + print("TEST", field_name, target_model, row[field_name]) item = target_model.extract_prefixed_table_columns( item={}, row=row, @@ -314,4 +305,9 @@ class PrefetchQuery: already_extracted=already_extracted, prefetch_dict=prefetch_dict, ) + already_extracted[target_model.get_name()].setdefault( + field_name, dict() + ).setdefault(row[field_name], []).append(instance) already_extracted[target_model.get_name()]["models"][instance.pk] = instance + + return already_extracted[target_model.get_name()]["models"] From d6f995d3499b4aa5fd52fc5ba83d74629e1584c1 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 25 Nov 2020 20:52:01 +0100 Subject: [PATCH 05/13] refactor and cleanup for further optimization --- ormar/queryset/prefetch_query.py | 320 ++++++++++++++++++++----------- ormar/queryset/queryset.py | 14 +- ormar/queryset/utils.py | 41 +++- tests/test_prefetch_related.py | 61 ++++++ 4 files changed, 313 insertions(+), 123 deletions(-) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 39ef64e..243e66e 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -1,4 +1,5 @@ from typing import ( + Any, Dict, List, Optional, @@ -10,10 +11,11 @@ from typing import ( Union, ) +import ormar from ormar.fields import BaseField, ManyToManyField from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query -from ormar.queryset.utils import translate_list_to_dict +from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict if TYPE_CHECKING: # pragma: no cover from ormar import Model @@ -35,65 +37,114 @@ class PrefetchQuery: self._select_related = select_related self._exclude_columns = exclude_fields self._columns = fields + self.already_extracted: Dict = dict() + self.models: Dict = {} + self.select_dict = translate_list_to_dict(self._select_related) + + async def prefetch_related( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: + self.models = extract_models_to_dict_of_lists( + model_type=self.model, models=models, select_dict=self.select_dict + ) + self.models[self.model.get_name()] = models + return await self._prefetch_related_models(models=models, rows=rows) @staticmethod - def _extract_required_ids( - already_extracted: Dict, + def _get_column_name_for_id_extraction( parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, - ) -> Set: - current_data = already_extracted.get(parent_model.get_name(), {}) - raw_rows = current_data.get("raw", []) - table_prefix = current_data.get("prefix", "") + use_raw: bool, + ) -> str: if reverse: - column_name = parent_model.get_column_alias(parent_model.Meta.pkname) + column_name = parent_model.Meta.pkname + return ( + parent_model.get_column_alias(column_name) if use_raw else column_name + ) else: - column_name = target_model.resolve_relation_field( - parent_model, target_model - ).get_alias() + column = target_model.resolve_relation_field(parent_model, target_model) + return column.get_alias() if use_raw else column.name + + def _extract_ids_from_raw_data( + self, parent_model: Type["Model"], column_name: str + ) -> Set: list_of_ids = set() + current_data = self.already_extracted.get(parent_model.get_name(), {}) + table_prefix = current_data.get("prefix", "") column_name = (f"{table_prefix}_" if table_prefix else "") + column_name - for row in raw_rows: + for row in current_data.get("raw", []): if row[column_name]: list_of_ids.add(row[column_name]) return list_of_ids - @staticmethod - def _get_filter_for_prefetch( - already_extracted: Dict, - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool, - ) -> List: - ids = PrefetchQuery._extract_required_ids( - already_extracted=already_extracted, + def _extract_ids_from_preloaded_models( + self, parent_model: Type["Model"], column_name: str + ) -> Set: + list_of_ids = set() + for model in self.models.get(parent_model.get_name(), []): + child = getattr(model, column_name) + if isinstance(child, ormar.Model): + list_of_ids.add(child.pk) + else: + list_of_ids.add(child) + return list_of_ids + + def _extract_required_ids( + self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, + ) -> Set: + + use_raw = parent_model.get_name() not in self.models + + column_name = self._get_column_name_for_id_extraction( parent_model=parent_model, target_model=target_model, reverse=reverse, + use_raw=use_raw, + ) + + if use_raw: + return self._extract_ids_from_raw_data( + parent_model=parent_model, column_name=column_name + ) + + return self._extract_ids_from_preloaded_models( + parent_model=parent_model, column_name=column_name + ) + + @staticmethod + def _get_clause_target_and_filter_column_name( + parent_model: Type["Model"], target_model: Type["Model"], reverse: bool + ) -> Tuple[Type["Model"], str]: + if reverse: + field = target_model.resolve_relation_field(target_model, parent_model) + if issubclass(field, ManyToManyField): + sub_field = target_model.resolve_relation_field( + field.through, parent_model + ) + return field.through, sub_field.get_alias() + else: + return target_model, field.get_alias() + target_field = target_model.get_column_alias(target_model.Meta.pkname) + return target_model, target_field + + def _get_filter_for_prefetch( + self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, + ) -> List: + ids = self._extract_required_ids( + parent_model=parent_model, target_model=target_model, reverse=reverse, ) if ids: - qryclause = QueryClause( - model_cls=target_model, select_related=[], filter_clauses=[], + ( + clause_target, + filter_column, + ) = self._get_clause_target_and_filter_column_name( + parent_model=parent_model, target_model=target_model, reverse=reverse ) - if reverse: - field = target_model.resolve_relation_field(target_model, parent_model) - if issubclass(field, ManyToManyField): - sub_field = target_model.resolve_relation_field( - field.through, parent_model - ) - kwargs = {f"{sub_field.get_alias()}__in": ids} - qryclause = QueryClause( - model_cls=field.through, select_related=[], filter_clauses=[], - ) - - else: - kwargs = {f"{field.get_alias()}__in": ids} - else: - target_field = target_model.Meta.model_fields[ - target_model.Meta.pkname - ].get_alias() - kwargs = {f"{target_field}__in": ids} + qryclause = QueryClause( + model_cls=clause_target, select_related=[], filter_clauses=[], + ) + kwargs = {f"{filter_column}__in": ids} filter_clauses, _ = qryclause.filter(**kwargs) return filter_clauses return [] @@ -123,7 +174,7 @@ class PrefetchQuery: @staticmethod def _get_group_field_name( - target_field: Type["BaseField"], model: Type["Model"] + target_field: Type["BaseField"], model: Union["Model", Type["Model"]] ) -> str: if issubclass(target_field, ManyToManyField): return model.resolve_relation_name(target_field.through, model) @@ -142,117 +193,150 @@ class PrefetchQuery: ] return related_to_extract - @staticmethod - def _populate_nested_related( - model: "Model", already_extracted: Dict, prefetch_dict: Dict - ) -> "Model": + def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model": - related_to_extract = PrefetchQuery._get_names_to_extract( + related_to_extract = self._get_names_to_extract( prefetch_dict=prefetch_dict, model=model ) for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - is_multi, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( + is_multi, field_name, model_id = self._get_model_id_and_field_name( target_field=target_field, model=model ) - if not field_name: + + if field_name is None or model_id is None: # pragma: no cover continue - children = already_extracted.get(target_model, {}).get(field_name, {}) - for key, child_models in children.items(): - if key == model_id: - for child in child_models: - setattr(model, related, child) + children = self.already_extracted.get(target_model, {}).get(field_name, {}) + self._set_children_on_model( + model=model, related=related, children=children, model_id=model_id + ) return model - async def prefetch_related( - self, models: Sequence["Model"], rows: List - ) -> Sequence["Model"]: - return await self._prefetch_related_models(models=models, rows=rows) + @staticmethod + def _set_children_on_model( + model: "Model", related: str, children: Dict, model_id: int + ) -> None: + for key, child_models in children.items(): + if key == model_id: + for child in child_models: + setattr(model, related, child) async def _prefetch_related_models( self, models: Sequence["Model"], rows: List ) -> Sequence["Model"]: - already_extracted = { - self.model.get_name(): { - "raw": rows, - "models": {model.pk: model for model in models}, - } - } + self.already_extracted = {self.model.get_name(): {"raw": rows}} select_dict = translate_list_to_dict(self._select_related) prefetch_dict = translate_list_to_dict(self._prefetch_related) target_model = self.model fields = self._columns exclude_fields = self._exclude_columns for related in prefetch_dict.keys(): - subrelated = await self._extract_related_models( + await self._extract_related_models( related=related, target_model=target_model, - prefetch_dict=prefetch_dict.get(related), - select_dict=select_dict.get(related), - already_extracted=already_extracted, + prefetch_dict=prefetch_dict.get(related, {}), + select_dict=select_dict.get(related, {}), fields=fields, exclude_fields=exclude_fields, ) - print(related, subrelated) final_models = [] for model in models: final_models.append( - self._populate_nested_related( - model=model, - already_extracted=already_extracted, - prefetch_dict=prefetch_dict, - ) + self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,) ) return models - async def _extract_related_models( # noqa: CFQ002 + async def _extract_related_models( # noqa: CFQ002, CCR001 self, related: str, target_model: Type["Model"], prefetch_dict: Dict, select_dict: Dict, - already_extracted: Dict, - fields: Dict, - exclude_fields: Dict, + fields: Union[Set[Any], Dict[Any, Any], None], + exclude_fields: Union[Set[Any], Dict[Any, Any], None], ) -> None: fields = target_model.get_included(fields, related) exclude_fields = target_model.get_excluded(exclude_fields, related) - select_related = [] - target_field = target_model.Meta.model_fields[related] reverse = False if target_field.virtual or issubclass(target_field, ManyToManyField): reverse = True parent_model = target_model - target_model = target_field.to - filter_clauses = PrefetchQuery._get_filter_for_prefetch( - already_extracted=already_extracted, - parent_model=parent_model, - target_model=target_model, - reverse=reverse, + filter_clauses = self._get_filter_for_prefetch( + parent_model=parent_model, target_model=target_field.to, reverse=reverse, ) if not filter_clauses: # related field is empty return + already_loaded = select_dict is Ellipsis or related in select_dict + + if not already_loaded: + # If not already loaded with select_related + table_prefix, rows = await self._run_prefetch_query( + target_field=target_field, + fields=fields, + exclude_fields=exclude_fields, + filter_clauses=filter_clauses, + ) + else: + rows = [] + table_prefix = "" + + if prefetch_dict and prefetch_dict is not Ellipsis: + for subrelated in prefetch_dict.keys(): + await self._extract_related_models( + related=subrelated, + target_model=target_field.to, + prefetch_dict=prefetch_dict.get(subrelated, {}), + select_dict=self._get_select_related_if_apply( + subrelated, select_dict + ), + fields=fields, + exclude_fields=exclude_fields, + ) + + if not already_loaded: + self._populate_rows( + rows=rows, + parent_model=parent_model, + target_field=target_field, + table_prefix=table_prefix, + fields=fields, + exclude_fields=exclude_fields, + prefetch_dict=prefetch_dict, + ) + else: + self._update_already_loaded_rows( + target_field=target_field, prefetch_dict=prefetch_dict, + ) + + async def _run_prefetch_query( + self, + target_field: Type["BaseField"], + fields: Union[Set[Any], Dict[Any, Any], None], + exclude_fields: Union[Set[Any], Dict[Any, Any], None], + filter_clauses: List, + ) -> Tuple[str, List]: + target_model = target_field.to + target_name = target_model.get_name() + select_related = [] query_target = target_model table_prefix = "" if issubclass(target_field, ManyToManyField): query_target = target_field.through - select_related = [target_field.to.get_name()] + select_related = [target_name] table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename, ) - already_extracted.setdefault(target_model.get_name(), {})[ - "prefix" - ] = table_prefix + self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix qry = Query( model_cls=query_target, @@ -268,30 +352,41 @@ class PrefetchQuery: expr = qry.build_select_expression() # print(expr.compile(compile_kwargs={"literal_binds": True})) rows = await self.database.fetch_all(expr) - already_extracted.setdefault(target_model.get_name(), {}).update( - {"raw": rows, "models": {}} + self.already_extracted.setdefault(target_name, {}).update({"raw": rows}) + return table_prefix, rows + + @staticmethod + def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict: + return ( + select_dict.get(related, {}) + if (select_dict and select_dict is not Ellipsis and related in select_dict) + else {} ) - if prefetch_dict and prefetch_dict is not Ellipsis: - for subrelated in prefetch_dict.keys(): - submodels = await self._extract_related_models( - related=subrelated, - target_model=target_model, - prefetch_dict=prefetch_dict.get(subrelated), - select_dict=select_dict.get(subrelated) - if (select_dict and subrelated in select_dict) - else {}, - already_extracted=already_extracted, - fields=fields, - exclude_fields=exclude_fields, - ) - print(subrelated, submodels) + def _update_already_loaded_rows( # noqa: CFQ002 + self, target_field: Type["BaseField"], prefetch_dict: Dict, + ) -> None: + target_model = target_field.to + for instance in self.models.get(target_model.get_name(), []): + self._populate_nested_related( + model=instance, prefetch_dict=prefetch_dict, + ) + def _populate_rows( # noqa: CFQ002 + self, + rows: List, + target_field: Type["BaseField"], + parent_model: Type["Model"], + table_prefix: str, + fields: Union[Set[Any], Dict[Any, Any], None], + exclude_fields: Union[Set[Any], Dict[Any, Any], None], + prefetch_dict: Dict, + ) -> None: + target_model = target_field.to for row in rows: - field_name = PrefetchQuery._get_group_field_name( + field_name = self._get_group_field_name( target_field=target_field, model=parent_model ) - print("TEST", field_name, target_model, row[field_name]) item = target_model.extract_prefixed_table_columns( item={}, row=row, @@ -301,13 +396,8 @@ class PrefetchQuery: ) instance = target_model(**item) instance = self._populate_nested_related( - model=instance, - already_extracted=already_extracted, - prefetch_dict=prefetch_dict, + model=instance, prefetch_dict=prefetch_dict, ) - already_extracted[target_model.get_name()].setdefault( + self.already_extracted[target_model.get_name()].setdefault( field_name, dict() ).setdefault(row[field_name], []).append(instance) - already_extracted[target_model.get_name()]["models"][instance.pk] = instance - - return already_extracted[target_model.get_name()]["models"] diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index ce902f7..bc36965 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -67,16 +67,16 @@ class QuerySet: return self.model_cls async def _prefetch_related_models( - self, models: Sequence["Model"], rows: List - ) -> Sequence["Model"]: + self, models: Sequence[Optional["Model"]], rows: List + ) -> Sequence[Optional["Model"]]: query = PrefetchQuery( - model_cls=self.model_cls, + model_cls=self.model, fields=self._columns, exclude_fields=self._exclude_columns, prefetch_related=self._prefetch_related, select_related=self._select_related, ) - return await query.prefetch_related(models=models, rows=rows) + return await query.prefetch_related(models=models, rows=rows) # type: ignore def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: result_rows = [ @@ -191,7 +191,7 @@ class QuerySet: if not isinstance(related, list): related = [related] - related = list(set(list(self._select_related) + related)) + related = list(set(list(self._prefetch_related) + related)) return self.__class__( model_cls=self.model, filter_clauses=self.filter_clauses, @@ -352,7 +352,7 @@ class QuerySet: rows = await self.database.fetch_all(expr) processed_rows = self._process_query_result_rows(rows) - if self._prefetch_related: + if self._prefetch_related and processed_rows: processed_rows = await self._prefetch_related_models(processed_rows, rows) self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore @@ -379,7 +379,7 @@ class QuerySet: expr = self.build_select_expression() rows = await self.database.fetch_all(expr) result_rows = self._process_query_result_rows(rows) - if self._prefetch_related: + if self._prefetch_related and result_rows: result_rows = await self._prefetch_related_models(result_rows, rows) return result_rows diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index c3c8fa9..ed422dd 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -1,6 +1,9 @@ import collections.abc import copy -from typing import Any, Dict, List, Set, Union +from typing import Any, Dict, List, Sequence, Set, TYPE_CHECKING, Type, Union + +if TYPE_CHECKING: # pragma no cover + from ormar import Model def check_node_not_dict_or_not_last_node( @@ -55,3 +58,39 @@ def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> dict_to_update = translate_list_to_dict(list_to_update) update(updated_dict, dict_to_update) return updated_dict + + +def extract_nested_models( # noqa: CCR001 + model: "Model", model_type: Type["Model"], select_dict: Dict, extracted: Dict +) -> None: + follow = [rel for rel in model_type.extract_related_names() if rel in select_dict] + for related in follow: + child = getattr(model, related) + if child: + target_model = model_type.Meta.model_fields[related].to + if isinstance(child, list): + extracted.setdefault(target_model.get_name(), []).extend(child) + if select_dict[related] is not Ellipsis: + for sub_child in child: + extract_nested_models( + sub_child, target_model, select_dict[related], extracted, + ) + else: + extracted.setdefault(target_model.get_name(), []).append(child) + if select_dict[related] is not Ellipsis: + extract_nested_models( + child, target_model, select_dict[related], extracted, + ) + + +def extract_models_to_dict_of_lists( + model_type: Type["Model"], + models: Sequence["Model"], + select_dict: Dict, + extracted: Dict = None, +) -> Dict: + if not extracted: + extracted = dict() + for model in models: + extract_nested_models(model, model_type, select_dict, extracted) + return extracted diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index a6b8831..378fdb4 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class RandomSet(ormar.Model): + class Meta: + tablename = "randoms" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + class Tonation(ormar.Model): class Meta: tablename = "tonations" @@ -19,6 +29,7 @@ class Tonation(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) + rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet) class Division(ormar.Model): @@ -181,3 +192,53 @@ async def test_prefetch_related_empty(): track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") assert track.title == 'The Bird' assert track.album is None + + +@pytest.mark.asyncio +async def test_prefetch_related_with_select_related(): + async with database: + async with database.transaction(force_rollback=True): + div = await Division.objects.create(name='Div 1') + shop1 = await Shop.objects.create(name='Shop 1', division=div) + shop2 = await Shop.objects.create(name='Shop 2', division=div) + album = Album(name="Malibu") + await album.save() + await album.shops.add(shop1) + await album.shops.add(shop2) + + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + + album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).get() + assert len(album.tracks) == 0 + assert len(album.cover_pictures) == 2 + assert album.shops[0].division.name == 'Div 1' + + rand_set = await RandomSet.objects.create(name='Rand 1') + ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) + await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) + await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + + album = await Album.objects.select_related('tracks__tonation__rand_set').filter(name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).get() + assert len(album.tracks) == 3 + assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].artist == 'Artist 1' + + assert len(album.shops) == 2 + assert album.shops[0].name == 'Shop 1' + assert album.shops[0].division.name == 'Div 1' + + track = await Track.objects.select_related('album').prefetch_related( + ["album__cover_pictures", "album__shops__division"]).get( + title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist == 'Artist 1' + + assert len(track.album.shops) == 2 + assert track.album.shops[0].name == 'Shop 1' + assert track.album.shops[0].division.name == 'Div 1' From 2fd2bd9e3f083235a0359501ecf2accb1c3285c7 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 25 Nov 2020 21:14:37 +0100 Subject: [PATCH 06/13] further optimization --- ormar/queryset/prefetch_query.py | 41 +++++++++++--------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 243e66e..9f1930c 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -150,30 +150,15 @@ class PrefetchQuery: return [] @staticmethod - def _get_model_id_and_field_name( - target_field: Type["BaseField"], model: "Model" - ) -> Tuple[bool, Optional[str], Optional[int]]: - if target_field.virtual: - is_multi = False - field_name = model.resolve_relation_name(target_field.to, model) - model_id = model.pk - elif issubclass(target_field, ManyToManyField): - is_multi = True - field_name = model.resolve_relation_name(target_field.through, model) - model_id = model.pk - else: - is_multi = False - related_name = model.resolve_relation_name(model, target_field.to) - related_model = getattr(model, related_name) - if not related_model: - return is_multi, None, None - model_id = related_model.pk - field_name = target_field.to.Meta.pkname - - return is_multi, field_name, model_id + def _get_model_id(target_field: Type["BaseField"], model: "Model") -> Optional[int]: + if target_field.virtual or issubclass(target_field, ManyToManyField): + return model.pk + related_name = model.resolve_relation_name(model, target_field.to) + related_model = getattr(model, related_name) + return None if not related_model else related_model.pk @staticmethod - def _get_group_field_name( + def _get_related_field_name( target_field: Type["BaseField"], model: Union["Model", Type["Model"]] ) -> str: if issubclass(target_field, ManyToManyField): @@ -202,13 +187,15 @@ class PrefetchQuery: for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - is_multi, field_name, model_id = self._get_model_id_and_field_name( + model_id = self._get_model_id(target_field=target_field, model=model) + + if model_id is None: # pragma: no cover + continue + + field_name = self._get_related_field_name( target_field=target_field, model=model ) - if field_name is None or model_id is None: # pragma: no cover - continue - children = self.already_extracted.get(target_model, {}).get(field_name, {}) self._set_children_on_model( model=model, related=related, children=children, model_id=model_id @@ -384,7 +371,7 @@ class PrefetchQuery: ) -> None: target_model = target_field.to for row in rows: - field_name = self._get_group_field_name( + field_name = self._get_related_field_name( target_field=target_field, model=parent_model ) item = target_model.extract_prefixed_table_columns( From 3438928608eddc9cb4c4f9fc4d1a73540ebf221f Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 25 Nov 2020 21:18:27 +0100 Subject: [PATCH 07/13] change tests to run on all branches except docs --- .github/workflows/test-package.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index 7bd02a9..a953a77 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -5,7 +5,8 @@ name: build on: push: - branches: [ master ] + branches-ignore: + - 'gh-pages' pull_request: branches: [ master ] From f8dbb7696520193a8fdb6ff609c0ef9d032e7f19 Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 26 Nov 2020 06:33:24 +0100 Subject: [PATCH 08/13] add aliases to test prefetch_related --- ormar/models/modelproxy.py | 4 +--- ormar/queryset/prefetch_query.py | 3 ++- tests/test_prefetch_related.py | 15 ++++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 1ae257e..22db178 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -236,9 +236,7 @@ class ModelTableProxy: @staticmethod def _populate_pk_column( - model: Type["Model"], - columns: List[str], - use_alias: bool = False, + model: Type["Model"], columns: List[str], use_alias: bool = False, ) -> List[str]: pk_alias = ( model.get_column_alias(model.Meta.pkname) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 9f1930c..ed4de9d 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -385,6 +385,7 @@ class PrefetchQuery: instance = self._populate_nested_related( model=instance, prefetch_dict=prefetch_dict, ) + field_db_name = target_model.get_column_alias(field_name) self.already_extracted[target_model.get_name()].setdefault( field_name, dict() - ).setdefault(row[field_name], []).append(instance) + ).setdefault(row[field_db_name], []).append(instance) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 378fdb4..cd0fb2e 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -17,7 +17,7 @@ class RandomSet(ormar.Model): metadata = metadata database = database - id: int = ormar.Integer(primary_key=True) + id: int = ormar.Integer(name='random_id', primary_key=True) name: str = ormar.String(max_length=100) @@ -28,7 +28,7 @@ class Tonation(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(max_length=100) + name: str = ormar.String(name='tonation_name', max_length=100) rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet) @@ -38,7 +38,7 @@ class Division(ormar.Model): metadata = metadata database = database - id: int = ormar.Integer(primary_key=True) + id: int = ormar.Integer(name='division_id', primary_key=True) name: str = ormar.String(max_length=100) @@ -77,11 +77,11 @@ class Track(ormar.Model): metadata = metadata database = database - id: int = ormar.Integer(primary_key=True) + id: int = ormar.Integer(name='track_id', primary_key=True) album: Optional[Album] = ormar.ForeignKey(Album) title: str = ormar.String(max_length=100) position: int = ormar.Integer() - tonation: Optional[Tonation] = ormar.ForeignKey(Tonation) + tonation: Optional[Tonation] = ormar.ForeignKey(Tonation, name='tonation_id') class Cover(ormar.Model): @@ -91,7 +91,7 @@ class Cover(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures") + album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures", name='album_id') title: str = ormar.String(max_length=100) artist: str = ormar.String(max_length=200, nullable=True) @@ -221,7 +221,8 @@ async def test_prefetch_related_with_select_related(): await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) - album = await Album.objects.select_related('tracks__tonation__rand_set').filter(name='Malibu').prefetch_related( + album = await Album.objects.select_related('tracks__tonation__rand_set').filter( + name='Malibu').prefetch_related( ['cover_pictures', 'shops__division']).get() assert len(album.tracks) == 3 assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 From ba360974def203bd29f1ca16d3c350597489ee02 Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 26 Nov 2020 07:26:16 +0100 Subject: [PATCH 09/13] some refactors and cleanup --- ormar/models/modelproxy.py | 79 +++++++++++++++++++++++++- ormar/queryset/prefetch_query.py | 96 ++++++++------------------------ tests/test_prefetch_related.py | 53 +++++++++++++++++- 3 files changed, 151 insertions(+), 77 deletions(-) diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 22db178..3f31c0b 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -1,12 +1,15 @@ import inspect from collections import OrderedDict from typing import ( + Any, + Callable, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, + Tuple, Type, TypeVar, Union, @@ -38,6 +41,8 @@ class ModelTableProxy: Meta: ModelMeta _related_names: Set _related_names_hash: Union[str, bytes] + pk: Any + get_name: Callable def dict(self): # noqa A003 raise NotImplementedError # pragma no cover @@ -47,6 +52,66 @@ class ModelTableProxy: self_fields = {k: v for k, v in self.dict().items() if k not in related_names} return self_fields + @classmethod + def get_related_field_name(cls, target_field: Type["BaseField"]) -> str: + if issubclass(target_field, ormar.fields.ManyToManyField): + return cls.resolve_relation_name(target_field.through, cls) + if target_field.virtual: + return cls.resolve_relation_name(target_field.to, cls) + return target_field.to.Meta.pkname + + @staticmethod + def get_clause_target_and_filter_column_name( + parent_model: Type["Model"], target_model: Type["Model"], reverse: bool + ) -> Tuple[Type["Model"], str]: + if reverse: + field = target_model.resolve_relation_field(target_model, parent_model) + if issubclass(field, ormar.fields.ManyToManyField): + sub_field = target_model.resolve_relation_field( + field.through, parent_model + ) + return field.through, sub_field.get_alias() + else: + return target_model, field.get_alias() + target_field = target_model.get_column_alias(target_model.Meta.pkname) + return target_model, target_field + + @staticmethod + def get_column_name_for_id_extraction( + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + use_raw: bool, + ) -> str: + if reverse: + column_name = parent_model.Meta.pkname + return ( + parent_model.get_column_alias(column_name) if use_raw else column_name + ) + else: + column = target_model.resolve_relation_field(parent_model, target_model) + return column.get_alias() if use_raw else column.name + + @classmethod + def get_filtered_names_to_extract(cls, prefetch_dict: Dict) -> List: + related_to_extract = [] + if prefetch_dict and prefetch_dict is not Ellipsis: + related_to_extract = [ + related + for related in cls.extract_related_names() + if related in prefetch_dict + ] + return related_to_extract + + def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]: + if target_field.virtual or issubclass( + target_field, ormar.fields.ManyToManyField + ): + return self.pk + related_name = self.resolve_relation_name(self, target_field.to) + related_model = getattr(self, related_name) + return None if not related_model else related_model.pk + @classmethod def extract_db_own_fields(cls) -> Set: related_names = cls.extract_related_names() @@ -155,8 +220,18 @@ class ModelTableProxy: @staticmethod def resolve_relation_name( # noqa CCR001 - item: Union["NewBaseModel", Type["NewBaseModel"]], - related: Union["NewBaseModel", Type["NewBaseModel"]], + item: Union[ + "NewBaseModel", + Type["NewBaseModel"], + "ModelTableProxy", + Type["ModelTableProxy"], + ], + related: Union[ + "NewBaseModel", + Type["NewBaseModel"], + "ModelTableProxy", + Type["ModelTableProxy"], + ], ) -> str: for name, field in item.Meta.model_fields.items(): if issubclass(field, ForeignKeyField): diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index ed4de9d..7afef1a 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -21,6 +21,17 @@ if TYPE_CHECKING: # pragma: no cover from ormar import Model +def add_relation_field_to_fields( + fields: Union[Set[Any], Dict[Any, Any], None], related_field_name: str +) -> Union[Set[Any], Dict[Any, Any], None]: + if fields and related_field_name not in fields: + if isinstance(fields, dict): + fields[related_field_name] = ... + elif isinstance(fields, set): + fields.add(related_field_name) + return fields + + class PrefetchQuery: def __init__( self, @@ -50,22 +61,6 @@ class PrefetchQuery: self.models[self.model.get_name()] = models return await self._prefetch_related_models(models=models, rows=rows) - @staticmethod - def _get_column_name_for_id_extraction( - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool, - use_raw: bool, - ) -> str: - if reverse: - column_name = parent_model.Meta.pkname - return ( - parent_model.get_column_alias(column_name) if use_raw else column_name - ) - else: - column = target_model.resolve_relation_field(parent_model, target_model) - return column.get_alias() if use_raw else column.name - def _extract_ids_from_raw_data( self, parent_model: Type["Model"], column_name: str ) -> Set: @@ -96,7 +91,7 @@ class PrefetchQuery: use_raw = parent_model.get_name() not in self.models - column_name = self._get_column_name_for_id_extraction( + column_name = parent_model.get_column_name_for_id_extraction( parent_model=parent_model, target_model=target_model, reverse=reverse, @@ -112,22 +107,6 @@ class PrefetchQuery: parent_model=parent_model, column_name=column_name ) - @staticmethod - def _get_clause_target_and_filter_column_name( - parent_model: Type["Model"], target_model: Type["Model"], reverse: bool - ) -> Tuple[Type["Model"], str]: - if reverse: - field = target_model.resolve_relation_field(target_model, parent_model) - if issubclass(field, ManyToManyField): - sub_field = target_model.resolve_relation_field( - field.through, parent_model - ) - return field.through, sub_field.get_alias() - else: - return target_model, field.get_alias() - target_field = target_model.get_column_alias(target_model.Meta.pkname) - return target_model, target_field - def _get_filter_for_prefetch( self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, ) -> List: @@ -138,7 +117,7 @@ class PrefetchQuery: ( clause_target, filter_column, - ) = self._get_clause_target_and_filter_column_name( + ) = parent_model.get_clause_target_and_filter_column_name( parent_model=parent_model, target_model=target_model, reverse=reverse ) qryclause = QueryClause( @@ -149,52 +128,21 @@ class PrefetchQuery: return filter_clauses return [] - @staticmethod - def _get_model_id(target_field: Type["BaseField"], model: "Model") -> Optional[int]: - if target_field.virtual or issubclass(target_field, ManyToManyField): - return model.pk - related_name = model.resolve_relation_name(model, target_field.to) - related_model = getattr(model, related_name) - return None if not related_model else related_model.pk - - @staticmethod - def _get_related_field_name( - target_field: Type["BaseField"], model: Union["Model", Type["Model"]] - ) -> str: - if issubclass(target_field, ManyToManyField): - return model.resolve_relation_name(target_field.through, model) - if target_field.virtual: - return model.resolve_relation_name(target_field.to, model) - return target_field.to.Meta.pkname - - @staticmethod - def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List: - related_to_extract = [] - if prefetch_dict and prefetch_dict is not Ellipsis: - related_to_extract = [ - related - for related in model.extract_related_names() - if related in prefetch_dict - ] - return related_to_extract - def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model": - related_to_extract = self._get_names_to_extract( - prefetch_dict=prefetch_dict, model=model + related_to_extract = model.get_filtered_names_to_extract( + prefetch_dict=prefetch_dict ) for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - model_id = self._get_model_id(target_field=target_field, model=model) + model_id = model.get_relation_model_id(target_field=target_field) if model_id is None: # pragma: no cover continue - field_name = self._get_related_field_name( - target_field=target_field, model=model - ) + field_name = model.get_related_field_name(target_field=target_field) children = self.already_extracted.get(target_model, {}).get(field_name, {}) self._set_children_on_model( @@ -266,6 +214,12 @@ class PrefetchQuery: if not already_loaded: # If not already loaded with select_related + related_field_name = parent_model.get_related_field_name( + target_field=target_field + ) + fields = add_relation_field_to_fields( + fields=fields, related_field_name=related_field_name + ) table_prefix, rows = await self._run_prefetch_query( target_field=target_field, fields=fields, @@ -371,9 +325,7 @@ class PrefetchQuery: ) -> None: target_model = target_field.to for row in rows: - field_name = self._get_related_field_name( - target_field=target_field, model=parent_model - ) + field_name = parent_model.get_related_field_name(target_field=target_field) item = target_model.extract_prefixed_table_columns( item={}, row=row, diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index cd0fb2e..825bbd9 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -39,7 +39,7 @@ class Division(ormar.Model): database = database id: int = ormar.Integer(name='division_id', primary_key=True) - name: str = ormar.String(max_length=100) + name: str = ormar.String(max_length=100, nullable=True) class Shop(ormar.Model): @@ -49,7 +49,7 @@ class Shop(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(max_length=100) + name: str = ormar.String(max_length=100, nullable=True) division: Optional[Division] = ormar.ForeignKey(Division) @@ -67,7 +67,7 @@ class Album(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(max_length=100) + name: str = ormar.String(max_length=100, nullable=True) shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops) @@ -243,3 +243,50 @@ async def test_prefetch_related_with_select_related(): assert len(track.album.shops) == 2 assert track.album.shops[0].name == 'Shop 1' assert track.album.shops[0].division.name == 'Div 1' + + +@pytest.mark.asyncio +async def test_prefetch_related_with_select_related_and_fields(): + async with database: + async with database.transaction(force_rollback=True): + div = await Division.objects.create(name='Div 1') + shop1 = await Shop.objects.create(name='Shop 1', division=div) + shop2 = await Shop.objects.create(name='Shop 2', division=div) + album = Album(name="Malibu") + await album.save() + await album.shops.add(shop1) + await album.shops.add(shop2) + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + rand_set = await RandomSet.objects.create(name='Rand 1') + ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) + await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) + await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + + album = await Album.objects.select_related('tracks__tonation__rand_set').filter( + name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).exclude_fields({'shops': {'division': {'name'}}}).get() + assert len(album.tracks) == 3 + assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].artist == 'Artist 1' + + assert len(album.shops) == 2 + assert album.shops[0].name == 'Shop 1' + assert album.shops[0].division.name is None + + album = await Album.objects.select_related('tracks').filter( + name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).fields( + {'name': ..., 'shops': {'division'}, 'cover_pictures': {'id': ..., 'title': ...}} + ).exclude_fields({'shops': {'division': {'name'}}}).get() + assert len(album.tracks) == 3 + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].artist is None + assert album.cover_pictures[0].title is not None + + assert len(album.shops) == 2 + assert album.shops[0].name is None + assert album.shops[0].division is not None + assert album.shops[0].division.name is None From 8a75379b44e0ef915489a0f1563f479e816511f7 Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 26 Nov 2020 09:15:13 +0100 Subject: [PATCH 10/13] update documentation, optimize for memory saving, update docs for select_related, fields and exclude_fields, bump version --- README.md | 7 +- docs/index.md | 6 +- docs/queries.md | 164 ++++++++++++++++++++++++++++++- docs/releases.md | 10 ++ ormar/__init__.py | 2 +- ormar/queryset/prefetch_query.py | 18 +++- tests/test_prefetch_related.py | 17 ++++ 7 files changed, 213 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 25fd50f..b4b1890 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Ormar is built with: * [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`databases`][databases] for cross-database async support. * [`pydantic`][pydantic] for data validation. - * typing_extensions for python 3.6 - 3.7 + * `typing_extensions` for python 3.6 - 3.7 ### Migrations @@ -53,7 +53,8 @@ Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to p database migrations. -**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.4.0` +**ormar is still under development:** +We recommend pinning any dependencies (with i.e. `ormar~=0.5.2`) ### Quick Start @@ -157,6 +158,7 @@ assert len(tracks) == 1 * `filter(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet` +* `prefetch_related(related: Union[List, str]) -> QuerySet` * `limit(limit_count: int) -> QuerySet` * `offset(offset: int) -> QuerySet` * `count() -> int` @@ -165,6 +167,7 @@ assert len(tracks) == 1 * `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet` * `order_by(columns:Union[List, str]) -> QuerySet` + #### Relation types * One to many - with `ForeignKey(to: Model)` diff --git a/docs/index.md b/docs/index.md index f69de18..b4b1890 100644 --- a/docs/index.md +++ b/docs/index.md @@ -45,7 +45,7 @@ Ormar is built with: * [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`databases`][databases] for cross-database async support. * [`pydantic`][pydantic] for data validation. - * typing_extensions for python 3.6 - 3.7 + * `typing_extensions` for python 3.6 - 3.7 ### Migrations @@ -53,7 +53,8 @@ Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to p database migrations. -**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.4.0` +**ormar is still under development:** +We recommend pinning any dependencies (with i.e. `ormar~=0.5.2`) ### Quick Start @@ -157,6 +158,7 @@ assert len(tracks) == 1 * `filter(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet` +* `prefetch_related(related: Union[List, str]) -> QuerySet` * `limit(limit_count: int) -> QuerySet` * `offset(offset: int) -> QuerySet` * `count() -> int` diff --git a/docs/queries.md b/docs/queries.md index f7f3d7b..9b1f5fa 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -253,11 +253,22 @@ notes = await Track.objects.exclude(position_gt=3).all() `select_related(related: Union[List, str]) -> QuerySet` -Allows to prefetch related models. +Allows to prefetch related models during the same query. + +**With `select_related` always only one query is run against the database**, meaning that one +(sometimes complicated) join is generated and later nested models are processed in python. To fetch related model use `ForeignKey` names. -To chain related `Models` relation use double underscore. +To chain related `Models` relation use double underscores between names. + +!!!note + If you are coming from `django` note that `ormar` `select_related` differs -> in `django` you can `select_related` + only singe relation types, while in `ormar` you can select related across `ForeignKey` relation, + reverse side of `ForeignKey` (so virtual auto generated keys) and `ManyToMany` fields (so all relations as of current version). + +!!!note + To control which model fields to select use `fields()` and `exclude_fields()` `QuerySet` methods. ```python album = await Album.objects.select_related("tracks").all() @@ -286,6 +297,147 @@ Exactly the same behavior is for Many2Many fields, where you put the names of Ma Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` +### prefetch_related + +`prefetch_related(related: Union[List, str]) -> QuerySet` + +Allows to prefetch related models during query - but opposite to `select_related` each +subsequent model is fetched in a separate database query. + +**With `prefetch_related` always one query per Model is run against the database**, +meaning that you will have multiple queries executed one after another. + +To fetch related model use `ForeignKey` names. + +To chain related `Models` relation use double underscores between names. + +!!!note + To control which model fields to select use `fields()` and `exclude_fields()` `QuerySet` methods. + +```python +album = await Album.objects.prefetch_related("tracks").all() +# will return album will all columns tracks +``` + +You can provide a string or a list of strings + +```python +classes = await SchoolClass.objects.prefetch_related( +["teachers__category", "students"]).all() +# will return classes with teachers and teachers categories +# as well as classes students +``` + +Exactly the same behavior is for Many2Many fields, where you put the names of Many2Many fields and the final `Models` are fetched for you. + +!!!warning + If you set `ForeignKey` field as not nullable (so required) during + all queries the not nullable `Models` will be auto prefetched, even if you do not include them in select_related. + +!!!note + All methods that do not return the rows explicitly returns a QueySet instance so you can chain them together + + So operations like `filter()`, `select_related()`, `limit()` and `offset()` etc. can be chained. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` + +### select_related vs prefetch_related + +Which should you use -> `select_related` or `prefetch_related`? + +Well, it really depends on your data. The best answer is try yourself and see which one performs faster/better in your system constraints. + +What to keep in mind: + +#### Performance + +**Number of queries**: +`select_related` always executes one query against the database, while `prefetch_related` executes multiple queries. +Usually the query (I/O) operation is the slowest one but it does not have to be. + +**Number of rows**: +Imagine that you have 10 000 object in one table A and each of those objects have 3 children in table B, +and subsequently each object in table B has 2 children in table C. Something like this: + +``` + Model C + / + Model B - Model C + / +Model A - Model B - Model C + \ \ + \ Model C + \ + Model B - Model C + \ + Model C +``` + +That means that `select_related` will always return 60 000 rows (10 000 * 3 * 2) later compacted to 10 000 models. + +How many rows will return `prefetch_related`? + +Well, that depends, if each of models B and C is unique it will return 10 000 rows in first query, 30 000 rows +(each of 3 children of A in table B are unique) in second query and 60 000 rows (each of 2 children of model B +in table C are unique) in 3rd query. + +In this case `select_related` seems like a better choice, not only it will run one query comparing to 3 of +`prefetch_related` but will also return 60 000 rows comparing to 100 000 of `prefetch_related` (10+30+60k). + +But what if each Model A has exactly the same 3 models B and each models C has exactly same models C? `select_related` +will still return 60 000 rows, while `prefetch_related` will return 10 000 for model A, 3 rows for model B and 2 rows for Model C. +So in total 10 006 rows. Now depending on the structure of models (i.e. if it has long Text() fields etc.) `prefetch_related` +might be faster despite it needs to perform three separate queries instead of one. + +#### Memory + +`ormar` is a mini ORM meaning that it does not keep a registry of already loaded models. + +That means that in `select_related` example above you will always have 10 000 Models A, 30 000 Models B +(even if the unique number of rows in db is 3 - processing of `select_related` spawns **new** child models for each parent model). +And 60 000 Models C. + +If the same Model B is shared by rows 1, 10, 100 etc. and you update one of those, the rest of rows +that share the same child will **not** be updated on the spot. +If you persist your changes into the database the change **will be available only after reload +(either each child separately or the whole query again)**. +That means that `select_related` will use more memory as each child is instantiated as a new object - obviously using it's own space. + +!!!note + This might change in future versions if we decide to introduce caching. + +!!!warning + By default all children (or event the same models loaded 2+ times) are completely independent, distinct python objects, despite that they represent the same row in db. + + They will evaluate to True when compared, so in example above: + + ```python + # will return True if child1 of both rows is the same child db row + row1.child1 == row100.child1 + + # same here: + model1 = await Model.get(pk=1) + model2 = await Model.get(pk=1) # same pk = same row in db + # will return `True` + model1 == model2 + ``` + + but + + ```python + # will return False (note that id is a python `builtin` function not ormar one). + id(row1.child1) == (ro100.child1) + + # from above - will also return False + id(model1) == id(model2) + ``` + + +On the contrary - with `prefetch_related` each unique distinct child model is instantiated +only once and the same child models is shared across all parent models. +That means that in `prefetch_related` example above if there are 3 distinct models in table B and 2 in table C, +there will be only 5 children nested models shared between all model A instances. That also means that if you update +any attribute it will be updated on all parents as they share the same child object. ### limit @@ -352,6 +504,10 @@ has_sample = await Book.objects.filter(title='Sample').exists() With `fields()` you can select subset of model columns to limit the data load. +!!!note + Note that `fields()` and `exclude_fields()` works both for main models (on normal queries like `get`, `all` etc.) + as well as `select_related` and `prefetch_related` models (with nested notation). + Given a sample data like following: ```python @@ -433,6 +589,10 @@ It's the opposite of `fields()` method so check documentation above to see what Especially check above how you can pass also nested dictionaries and sets as a mask to exclude fields from whole hierarchy. +!!!note + Note that `fields()` and `exclude_fields()` works both for main models (on normal queries like `get`, `all` etc.) + as well as `select_related` and `prefetch_related` models (with nested notation). + Below you can find few simple examples: ```python hl_lines="47 48 60 61 67" diff --git a/docs/releases.md b/docs/releases.md index e636ceb..cdbd6b0 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,13 @@ +# 0.5.2 + +* Added `prefetch_related` method to load subsequent models in separate queries. +* Update docs + +# 0.5.1 + +* Switched to github actions instead of travis +* Update badges in the docs + # 0.5.0 * Added save status -> you can check if model is saved with `ModelInstance.saved` property diff --git a/ormar/__init__.py b/ormar/__init__.py index 87f1b5b..cd3712a 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.5.1" +__version__ = "0.5.2" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 7afef1a..788fe5d 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -145,20 +145,25 @@ class PrefetchQuery: field_name = model.get_related_field_name(target_field=target_field) children = self.already_extracted.get(target_model, {}).get(field_name, {}) + models = self.already_extracted.get(target_model, {}).get("pk_models", {}) self._set_children_on_model( - model=model, related=related, children=children, model_id=model_id + model=model, + related=related, + children=children, + model_id=model_id, + models=models, ) return model @staticmethod def _set_children_on_model( - model: "Model", related: str, children: Dict, model_id: int + model: "Model", related: str, children: Dict, model_id: int, models: Dict ) -> None: for key, child_models in children.items(): if key == model_id: for child in child_models: - setattr(model, related, child) + setattr(model, related, models.get(child)) async def _prefetch_related_models( self, models: Sequence["Model"], rows: List @@ -338,6 +343,11 @@ class PrefetchQuery: model=instance, prefetch_dict=prefetch_dict, ) field_db_name = target_model.get_column_alias(field_name) + models = self.already_extracted[target_model.get_name()].setdefault( + "pk_models", {} + ) + if instance.pk not in models: + models[instance.pk] = instance self.already_extracted[target_model.get_name()].setdefault( field_name, dict() - ).setdefault(row[field_db_name], []).append(instance) + ).setdefault(row[field_db_name], set()).add(instance.pk) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 825bbd9..4d3fba9 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -183,6 +183,23 @@ async def test_prefetch_related_with_many_to_many(): assert track.album.shops[0].name == 'Shop 1' assert track.album.shops[0].division.name == 'Div 1' + album2 = Album(name="Malibu 2") + await album2.save() + await album2.shops.add(shop1) + await album2.shops.add(shop2) + await Track.objects.create(album=album2, title="The Bird 2", position=1) + + tracks = await Track.objects.prefetch_related(["album__shops"]).all() + assert tracks[0].album.name == 'Malibu' + assert tracks[0].album.shops[0].name == "Shop 1" + assert tracks[3].album.name == 'Malibu 2' + assert tracks[3].album.shops[0].name == "Shop 1" + + assert tracks[0].album.shops[0] == tracks[3].album.shops[0] + assert id(tracks[0].album.shops[0]) == id(tracks[3].album.shops[0]) + tracks[0].album.shops[0].name = 'Dummy' + assert tracks[0].album.shops[0].name == tracks[3].album.shops[0].name + @pytest.mark.asyncio async def test_prefetch_related_empty(): From 78d1241807709782eaa11978d78ac53dc662f5db Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 26 Nov 2020 11:17:33 +0100 Subject: [PATCH 11/13] add order_by support for prefetch_related --- ormar/queryset/prefetch_query.py | 74 ++++++++++++++++++++++++-------- ormar/queryset/queryset.py | 1 + ormar/queryset/utils.py | 14 +++++- tests/test_prefetch_related.py | 8 ++-- tests/test_queryset_utils.py | 46 ++++++++++++++++++++ 5 files changed, 121 insertions(+), 22 deletions(-) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 788fe5d..198c9b6 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -32,14 +32,48 @@ def add_relation_field_to_fields( return fields +def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]: + sort_criteria = [ + (key, value) for key, value in orders_by.items() if isinstance(value, str) + ] + sort_criteria = sort_criteria[::-1] + for criteria in sort_criteria: + if criteria[1] == "desc": + models.sort(key=lambda x: getattr(x, criteria[0]), reverse=True) + else: + models.sort(key=lambda x: getattr(x, criteria[0])) + return models + + +def set_children_on_model( # noqa: CCR001 + model: "Model", + related: str, + children: Dict, + model_id: int, + models: Dict, + orders_by: Dict, +) -> None: + for key, child_models in children.items(): + if key == model_id: + models_to_set = [models[child] for child in sorted(child_models)] + if models_to_set: + if orders_by and any(isinstance(x, str) for x in orders_by.values()): + models_to_set = sort_models( + models=models_to_set, orders_by=orders_by + ) + for child in models_to_set: + setattr(model, related, child) + + class PrefetchQuery: - def __init__( + def __init__( # noqa: CFQ002 self, model_cls: Type["Model"], fields: Optional[Union[Dict, Set]], exclude_fields: Optional[Union[Dict, Set]], prefetch_related: List, select_related: List, + orders_by: List, ) -> None: self.model = model_cls @@ -51,6 +85,8 @@ class PrefetchQuery: self.already_extracted: Dict = dict() self.models: Dict = {} self.select_dict = translate_list_to_dict(self._select_related) + self.orders_by = orders_by or [] + self.order_dict = translate_list_to_dict(self.orders_by, is_order=True) async def prefetch_related( self, models: Sequence["Model"], rows: List @@ -128,7 +164,9 @@ class PrefetchQuery: return filter_clauses return [] - def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model": + def _populate_nested_related( + self, model: "Model", prefetch_dict: Dict, orders_by: Dict, + ) -> "Model": related_to_extract = model.get_filtered_names_to_extract( prefetch_dict=prefetch_dict @@ -146,25 +184,17 @@ class PrefetchQuery: children = self.already_extracted.get(target_model, {}).get(field_name, {}) models = self.already_extracted.get(target_model, {}).get("pk_models", {}) - self._set_children_on_model( + set_children_on_model( model=model, related=related, children=children, model_id=model_id, models=models, + orders_by=orders_by.get(related, {}), ) return model - @staticmethod - def _set_children_on_model( - model: "Model", related: str, children: Dict, model_id: int, models: Dict - ) -> None: - for key, child_models in children.items(): - if key == model_id: - for child in child_models: - setattr(model, related, models.get(child)) - async def _prefetch_related_models( self, models: Sequence["Model"], rows: List ) -> Sequence["Model"]: @@ -174,6 +204,7 @@ class PrefetchQuery: target_model = self.model fields = self._columns exclude_fields = self._exclude_columns + orders_by = self.order_dict for related in prefetch_dict.keys(): await self._extract_related_models( related=related, @@ -182,11 +213,14 @@ class PrefetchQuery: select_dict=select_dict.get(related, {}), fields=fields, exclude_fields=exclude_fields, + orders_by=orders_by.get(related, {}), ) final_models = [] for model in models: final_models.append( - self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,) + self._populate_nested_related( + model=model, prefetch_dict=prefetch_dict, orders_by=self.order_dict + ) ) return models @@ -198,6 +232,7 @@ class PrefetchQuery: select_dict: Dict, fields: Union[Set[Any], Dict[Any, Any], None], exclude_fields: Union[Set[Any], Dict[Any, Any], None], + orders_by: Dict, ) -> None: fields = target_model.get_included(fields, related) @@ -246,6 +281,7 @@ class PrefetchQuery: ), fields=fields, exclude_fields=exclude_fields, + orders_by=self._get_select_related_if_apply(subrelated, orders_by), ) if not already_loaded: @@ -257,10 +293,13 @@ class PrefetchQuery: fields=fields, exclude_fields=exclude_fields, prefetch_dict=prefetch_dict, + orders_by=orders_by, ) else: self._update_already_loaded_rows( - target_field=target_field, prefetch_dict=prefetch_dict, + target_field=target_field, + prefetch_dict=prefetch_dict, + orders_by=orders_by, ) async def _run_prefetch_query( @@ -310,12 +349,12 @@ class PrefetchQuery: ) def _update_already_loaded_rows( # noqa: CFQ002 - self, target_field: Type["BaseField"], prefetch_dict: Dict, + self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict, ) -> None: target_model = target_field.to for instance in self.models.get(target_model.get_name(), []): self._populate_nested_related( - model=instance, prefetch_dict=prefetch_dict, + model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by ) def _populate_rows( # noqa: CFQ002 @@ -327,6 +366,7 @@ class PrefetchQuery: fields: Union[Set[Any], Dict[Any, Any], None], exclude_fields: Union[Set[Any], Dict[Any, Any], None], prefetch_dict: Dict, + orders_by: Dict, ) -> None: target_model = target_field.to for row in rows: @@ -340,7 +380,7 @@ class PrefetchQuery: ) instance = target_model(**item) instance = self._populate_nested_related( - model=instance, prefetch_dict=prefetch_dict, + model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by ) field_db_name = target_model.get_column_alias(field_name) models = self.already_extracted[target_model.get_name()].setdefault( diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index bc36965..093abf7 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -75,6 +75,7 @@ class QuerySet: exclude_fields=self._exclude_columns, prefetch_related=self._prefetch_related, select_related=self._select_related, + orders_by=self.order_bys, ) return await query.prefetch_related(models=models, rows=rows) # type: ignore diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index ed422dd..bed2e25 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -14,18 +14,28 @@ def check_node_not_dict_or_not_last_node( ) -def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001 +def translate_list_to_dict( # noqa: CCR001 + list_to_trans: Union[List, Set], is_order: bool = False +) -> Dict: new_dict: Dict = dict() for path in list_to_trans: current_level = new_dict parts = path.split("__") + def_val: Any = ... + if is_order: + if parts[0][0] == "-": + def_val = "desc" + parts[0] = parts[0][1:] + else: + def_val = "asc" + for part in parts: if check_node_not_dict_or_not_last_node( part=part, parts=parts, current_level=current_level ): current_level[part] = dict() elif part not in current_level: - current_level[part] = ... + current_level[part] = def_val current_level = current_level[part] return new_dict diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 4d3fba9..bfc09d7 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -228,6 +228,7 @@ async def test_prefetch_related_with_select_related(): album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related( ['cover_pictures', 'shops__division']).get() + assert len(album.tracks) == 0 assert len(album.cover_pictures) == 2 assert album.shops[0].division.name == 'Div 1' @@ -240,14 +241,15 @@ async def test_prefetch_related_with_select_related(): album = await Album.objects.select_related('tracks__tonation__rand_set').filter( name='Malibu').prefetch_related( - ['cover_pictures', 'shops__division']).get() + ['cover_pictures', 'shops__division']).order_by( + ['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get() assert len(album.tracks) == 3 assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 assert len(album.cover_pictures) == 2 - assert album.cover_pictures[0].artist == 'Artist 1' + assert album.cover_pictures[0].artist == 'Artist 2' assert len(album.shops) == 2 - assert album.shops[0].name == 'Shop 1' + assert album.shops[0].name == 'Shop 2' assert album.shops[0].division.name == 'Div 1' track = await Track.objects.select_related('album').prefetch_related( diff --git a/tests/test_queryset_utils.py b/tests/test_queryset_utils.py index bac0a26..97695b9 100644 --- a/tests/test_queryset_utils.py +++ b/tests/test_queryset_utils.py @@ -1,5 +1,11 @@ +import databases +import sqlalchemy + +import ormar from ormar.models.excludable import Excludable +from ormar.queryset.prefetch_query import sort_models from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update +from tests.settings import DATABASE_URL def test_empty_excludable(): @@ -96,3 +102,43 @@ def test_updating_dict_inc_set_with_dict_inc_set(): "cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis}, "uu": Ellipsis, } + + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class SortModel(ormar.Model): + class Meta: + tablename = "sorts" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + sort_order: int = ormar.Integer() + + +def test_sorting_models(): + models = [ + SortModel(id=1, name='Alice', sort_order=0), + SortModel(id=2, name='Al', sort_order=1), + SortModel(id=3, name='Zake', sort_order=1), + SortModel(id=4, name='Will', sort_order=0), + SortModel(id=5, name='Al', sort_order=2), + SortModel(id=6, name='Alice', sort_order=2) + ] + orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'} + models = sort_models(models, orders_by) + assert models[5].name == 'Zake' + assert models[0].name == 'Al' + assert models[1].name == 'Al' + assert [model.id for model in models] == [5, 2, 6, 1, 4, 3] + + orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'} + models = sort_models(models, orders_by) + assert [model.id for model in models] == [2, 5, 1, 6, 4, 3] + + orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None} + models = sort_models(models, orders_by) + assert [model.id for model in models] == [1, 4, 2, 3, 5, 6] From 72b0336b75d6adce6edce85102932363ebf9a0b7 Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 26 Nov 2020 12:22:12 +0100 Subject: [PATCH 12/13] update docs with tip on order_by --- docs/queries.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/queries.md b/docs/queries.md index 9b1f5fa..c848bc1 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -267,8 +267,11 @@ To chain related `Models` relation use double underscores between names. only singe relation types, while in `ormar` you can select related across `ForeignKey` relation, reverse side of `ForeignKey` (so virtual auto generated keys) and `ManyToMany` fields (so all relations as of current version). -!!!note +!!!tip To control which model fields to select use `fields()` and `exclude_fields()` `QuerySet` methods. + +!!!tip + To control order of models (both main or nested) use `order_by()` method. ```python album = await Album.objects.select_related("tracks").all() @@ -311,8 +314,11 @@ To fetch related model use `ForeignKey` names. To chain related `Models` relation use double underscores between names. -!!!note +!!!tip To control which model fields to select use `fields()` and `exclude_fields()` `QuerySet` methods. + +!!!tip + To control order of models (both main or nested) use `order_by()` method. ```python album = await Album.objects.prefetch_related("tracks").all() From 164ea17c734e96d4f520cd0e665f404a141a3fa0 Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 26 Nov 2020 12:31:56 +0100 Subject: [PATCH 13/13] fix minor code smells --- ormar/models/modelproxy.py | 8 +++----- ormar/queryset/prefetch_query.py | 7 ++++--- ormar/queryset/queryset.py | 3 +-- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 3f31c0b..06b5060 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -71,8 +71,7 @@ class ModelTableProxy: field.through, parent_model ) return field.through, sub_field.get_alias() - else: - return target_model, field.get_alias() + return target_model, field.get_alias() target_field = target_model.get_column_alias(target_model.Meta.pkname) return target_model, target_field @@ -88,9 +87,8 @@ class ModelTableProxy: return ( parent_model.get_column_alias(column_name) if use_raw else column_name ) - else: - column = target_model.resolve_relation_field(parent_model, target_model) - return column.get_alias() if use_raw else column.name + column = target_model.resolve_relation_field(parent_model, target_model) + return column.get_alias() if use_raw else column.name @classmethod def get_filtered_names_to_extract(cls, prefetch_dict: Dict) -> List: diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 198c9b6..13ad785 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -38,10 +38,11 @@ def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]: ] sort_criteria = sort_criteria[::-1] for criteria in sort_criteria: - if criteria[1] == "desc": - models.sort(key=lambda x: getattr(x, criteria[0]), reverse=True) + key, value = criteria + if value == "desc": + models.sort(key=lambda x: getattr(x, key), reverse=True) else: - models.sort(key=lambda x: getattr(x, criteria[0])) + models.sort(key=lambda x: getattr(x, key)) return models diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 093abf7..02a0566 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -51,8 +51,7 @@ class QuerySet: ) -> "QuerySet": if issubclass(owner, ormar.Model): return self.__class__(model_cls=owner) - else: # pragma nocover - return self.__class__() + return self.__class__() # pragma: no cover @property def model_meta(self) -> "ModelMeta":