dirty many to many pass first test

This commit is contained in:
collerek
2020-11-23 17:03:31 +01:00
parent b696156f56
commit 585bba3ad3
2 changed files with 97 additions and 20 deletions

View File

@ -57,7 +57,17 @@ class PrefetchQuery:
) )
if reverse: if reverse:
field = target_model.resolve_relation_field(target_model, parent_model) field = target_model.resolve_relation_field(target_model, parent_model)
kwargs = {f'{field.get_alias()}__in': ids} if issubclass(field, ManyToManyField):
sub_field = target_model.resolve_relation_field(field.through, parent_model)
kwargs = {f'{sub_field.get_alias()}__in': ids}
qryclause = QueryClause(
model_cls=field.through,
select_related=[],
filter_clauses=[],
)
else:
kwargs = {f'{field.get_alias()}__in': ids}
else: else:
target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias() target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias()
kwargs = {f'{target_field}__in': ids} kwargs = {f'{target_field}__in': ids}
@ -73,13 +83,15 @@ class PrefetchQuery:
reverse = False reverse = False
target_field = model.Meta.model_fields[related] target_field = model.Meta.model_fields[related]
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
target_model = target_field.to.get_name() target_model = target_field.to.get_name()
if reverse: if target_field.virtual:
reverse = True
field_name = model.resolve_relation_name(target_field.to, model) field_name = model.resolve_relation_name(target_field.to, model)
model_id = model.pk model_id = model.pk
elif issubclass(target_field, ManyToManyField):
reverse = True
field_name = model.resolve_relation_name(target_field.through, model)
model_id = model.pk
else: else:
related_name = model.resolve_relation_name(model, target_field.to) related_name = model.resolve_relation_name(model, target_field.to)
related_model = getattr(model, related_name) related_model = getattr(model, related_name)
@ -89,17 +101,16 @@ class PrefetchQuery:
field_name = target_field.to.Meta.pkname field_name = target_field.to.Meta.pkname
if target_model in already_extracted and already_extracted[target_model]['models']: if target_model in already_extracted and already_extracted[target_model]['models']:
print('*****POPULATING RELATED:', target_model, field_name) print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n')
print(already_extracted[target_model]['models']) print(already_extracted[target_model]['models'])
for child_model in already_extracted[target_model]['models']: for ind, child_model in enumerate(already_extracted[target_model]['models']):
related_model = getattr(child_model, field_name) if issubclass(target_field, ManyToManyField):
if isinstance(related_model, list): raw_data = already_extracted[target_model]['raw'][ind]
for child in related_model: if raw_data[field_name] == model_id:
if child.pk == model_id: setattr(model, related, child_model)
setattr(model, related, child)
elif isinstance(related_model, ormar.Model): elif isinstance(getattr(child_model, field_name), ormar.Model):
if related_model.pk == model_id: if getattr(child_model, field_name).pk == model_id:
if reverse: if reverse:
setattr(model, related, child_model) setattr(model, related, child_model)
else: else:
@ -123,6 +134,7 @@ class PrefetchQuery:
exclude_fields = self._exclude_columns exclude_fields = self._exclude_columns
for part in related.split('__'): for part in related.split('__'):
fields = target_model.get_included(fields, part) fields = target_model.get_included(fields, part)
select_related = []
exclude_fields = target_model.get_excluded(exclude_fields, part) exclude_fields = target_model.get_excluded(exclude_fields, part)
target_field = target_model.Meta.model_fields[part] target_field = target_model.Meta.model_fields[part]
@ -130,6 +142,9 @@ class PrefetchQuery:
if target_field.virtual or issubclass(target_field, ManyToManyField): if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True reverse = True
if issubclass(target_field, ManyToManyField):
select_related = [target_field.through.get_name()]
parent_model = target_model parent_model = target_model
target_model = target_field.to target_model = target_field.to
@ -141,9 +156,17 @@ class PrefetchQuery:
if not filter_clauses: # related field is empty if not filter_clauses: # related field is empty
continue continue
query_target = target_model
table_prefix = ''
if issubclass(target_field, ManyToManyField):
query_target = target_field.through
select_related = [target_field.to.get_name()]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename)
qry = Query( qry = Query(
model_cls=target_model, model_cls=query_target,
select_related=[], select_related=select_related,
filter_clauses=filter_clauses, filter_clauses=filter_clauses,
exclude_clauses=[], exclude_clauses=[],
offset=None, offset=None,
@ -161,7 +184,7 @@ class PrefetchQuery:
item = target_model.extract_prefixed_table_columns( item = target_model.extract_prefixed_table_columns(
item={}, item={},
row=row, row=row,
table_prefix='', table_prefix=table_prefix,
fields=fields, fields=fields,
exclude_fields=exclude_fields exclude_fields=exclude_fields
) )
@ -174,12 +197,22 @@ class PrefetchQuery:
for part in related.split('__')[:-1]: for part in related.split('__')[:-1]:
fields = target_model.get_included(fields, part) fields = target_model.get_included(fields, part)
exclude_fields = target_model.get_excluded(exclude_fields, part) exclude_fields = target_model.get_excluded(exclude_fields, part)
target_model = target_model.Meta.model_fields[part].to
target_field = target_model.Meta.model_fields[part]
target_model = target_field.to
table_prefix = ''
if issubclass(target_field, ManyToManyField):
from_table = target_field.through.Meta.tablename
to_name = target_field.to.Meta.tablename
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=from_table, to_table=to_name)
for row in already_extracted.get(target_model.get_name(), {}).get('raw', []): for row in already_extracted.get(target_model.get_name(), {}).get('raw', []):
item = target_model.extract_prefixed_table_columns( item = target_model.extract_prefixed_table_columns(
item={}, item={},
row=row, row=row,
table_prefix='', table_prefix=table_prefix,
fields=fields, fields=fields,
exclude_fields=exclude_fields exclude_fields=exclude_fields
) )

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import List, Optional
import databases import databases
import pytest import pytest
@ -21,6 +21,23 @@ class Tonation(ormar.Model):
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
class Shop(ormar.Model):
class Meta:
tablename = "shops"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class AlbumShops(ormar.Model):
class Meta:
tablename = "albums_x_shops"
metadata = metadata
database = database
class Album(ormar.Model): class Album(ormar.Model):
class Meta: class Meta:
tablename = "albums" tablename = "albums"
@ -29,6 +46,7 @@ class Album(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops)
class Track(ormar.Model): class Track(ormar.Model):
@ -115,6 +133,32 @@ async def test_prefetch_related():
assert len(tracks) == 6 assert len(tracks) == 6
@pytest.mark.asyncio
async def test_prefetch_related_with_many_to_many():
async with database:
async with database.transaction(force_rollback=True):
shop1 = await Shop.objects.create(name='Shop 1')
shop2 = await Shop.objects.create(name='Shop 2')
album = Album(name="Malibu")
await album.save()
await album.shops.add(shop1)
await album.shops.add(shop2)
await Track.objects.create(album=album, title="The Bird", position=1)
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2)
await Track.objects.create(album=album, title="The Waters", position=3)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get(
title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1'
assert len(track.album.shops) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_prefetch_related_empty(): async def test_prefetch_related_empty():
async with database: async with database: