From 585bba3ad375ef3487cc2c6f9791bb810e09823c Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 23 Nov 2020 17:03:31 +0100 Subject: [PATCH] dirty many to many pass first test --- ormar/queryset/prefetch_query.py | 71 +++++++++++++++++++++++--------- tests/test_prefetch_related.py | 46 ++++++++++++++++++++- 2 files changed, 97 insertions(+), 20 deletions(-) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 0fca239..6cc9ec3 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -57,7 +57,17 @@ class PrefetchQuery: ) if reverse: field = target_model.resolve_relation_field(target_model, parent_model) - kwargs = {f'{field.get_alias()}__in': ids} + if issubclass(field, ManyToManyField): + sub_field = target_model.resolve_relation_field(field.through, parent_model) + kwargs = {f'{sub_field.get_alias()}__in': ids} + qryclause = QueryClause( + model_cls=field.through, + select_related=[], + filter_clauses=[], + ) + + else: + kwargs = {f'{field.get_alias()}__in': ids} else: target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias() kwargs = {f'{target_field}__in': ids} @@ -73,13 +83,15 @@ class PrefetchQuery: reverse = False target_field = model.Meta.model_fields[related] - if target_field.virtual or issubclass(target_field, ManyToManyField): - reverse = True - target_model = target_field.to.get_name() - if reverse: + if target_field.virtual: + reverse = True field_name = model.resolve_relation_name(target_field.to, model) model_id = model.pk + elif issubclass(target_field, ManyToManyField): + reverse = True + field_name = model.resolve_relation_name(target_field.through, model) + model_id = model.pk else: related_name = model.resolve_relation_name(model, target_field.to) related_model = getattr(model, related_name) @@ -89,17 +101,16 @@ class PrefetchQuery: field_name = target_field.to.Meta.pkname if target_model in already_extracted and already_extracted[target_model]['models']: - print('*****POPULATING RELATED:', target_model, field_name) + print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n') print(already_extracted[target_model]['models']) - for child_model in already_extracted[target_model]['models']: - related_model = getattr(child_model, field_name) - if isinstance(related_model, list): - for child in related_model: - if child.pk == model_id: - setattr(model, related, child) + for ind, child_model in enumerate(already_extracted[target_model]['models']): + if issubclass(target_field, ManyToManyField): + raw_data = already_extracted[target_model]['raw'][ind] + if raw_data[field_name] == model_id: + setattr(model, related, child_model) - elif isinstance(related_model, ormar.Model): - if related_model.pk == model_id: + elif isinstance(getattr(child_model, field_name), ormar.Model): + if getattr(child_model, field_name).pk == model_id: if reverse: setattr(model, related, child_model) else: @@ -123,6 +134,7 @@ class PrefetchQuery: exclude_fields = self._exclude_columns for part in related.split('__'): fields = target_model.get_included(fields, part) + select_related = [] exclude_fields = target_model.get_excluded(exclude_fields, part) target_field = target_model.Meta.model_fields[part] @@ -130,6 +142,9 @@ class PrefetchQuery: if target_field.virtual or issubclass(target_field, ManyToManyField): reverse = True + if issubclass(target_field, ManyToManyField): + select_related = [target_field.through.get_name()] + parent_model = target_model target_model = target_field.to @@ -141,9 +156,17 @@ class PrefetchQuery: if not filter_clauses: # related field is empty continue + query_target = target_model + table_prefix = '' + if issubclass(target_field, ManyToManyField): + query_target = target_field.through + select_related = [target_field.to.get_name()] + table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( + from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename) + qry = Query( - model_cls=target_model, - select_related=[], + model_cls=query_target, + select_related=select_related, filter_clauses=filter_clauses, exclude_clauses=[], offset=None, @@ -161,7 +184,7 @@ class PrefetchQuery: item = target_model.extract_prefixed_table_columns( item={}, row=row, - table_prefix='', + table_prefix=table_prefix, fields=fields, exclude_fields=exclude_fields ) @@ -174,12 +197,22 @@ class PrefetchQuery: for part in related.split('__')[:-1]: fields = target_model.get_included(fields, part) exclude_fields = target_model.get_excluded(exclude_fields, part) - target_model = target_model.Meta.model_fields[part].to + + target_field = target_model.Meta.model_fields[part] + target_model = target_field.to + table_prefix = '' + + if issubclass(target_field, ManyToManyField): + from_table = target_field.through.Meta.tablename + to_name = target_field.to.Meta.tablename + table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( + from_table=from_table, to_table=to_name) + for row in already_extracted.get(target_model.get_name(), {}).get('raw', []): item = target_model.extract_prefixed_table_columns( item={}, row=row, - table_prefix='', + table_prefix=table_prefix, fields=fields, exclude_fields=exclude_fields ) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 2582587..3b098c4 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import List, Optional import databases import pytest @@ -21,6 +21,23 @@ class Tonation(ormar.Model): name: str = ormar.String(max_length=100) +class Shop(ormar.Model): + class Meta: + tablename = "shops" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class AlbumShops(ormar.Model): + class Meta: + tablename = "albums_x_shops" + metadata = metadata + database = database + + class Album(ormar.Model): class Meta: tablename = "albums" @@ -29,6 +46,7 @@ class Album(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) + shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops) class Track(ormar.Model): @@ -115,6 +133,32 @@ async def test_prefetch_related(): assert len(tracks) == 6 +@pytest.mark.asyncio +async def test_prefetch_related_with_many_to_many(): + async with database: + async with database.transaction(force_rollback=True): + shop1 = await Shop.objects.create(name='Shop 1') + shop2 = await Shop.objects.create(name='Shop 2') + album = Album(name="Malibu") + await album.save() + await album.shops.add(shop1) + await album.shops.add(shop2) + + await Track.objects.create(album=album, title="The Bird", position=1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2) + await Track.objects.create(album=album, title="The Waters", position=3) + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + + track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get( + title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist == 'Artist 1' + + assert len(track.album.shops) == 2 + + @pytest.mark.asyncio async def test_prefetch_related_empty(): async with database: