dirty many to many pass first test

This commit is contained in:
collerek
2020-11-23 17:03:31 +01:00
parent b696156f56
commit 585bba3ad3
2 changed files with 97 additions and 20 deletions

View File

@ -57,6 +57,16 @@ class PrefetchQuery:
)
if reverse:
field = target_model.resolve_relation_field(target_model, parent_model)
if issubclass(field, ManyToManyField):
sub_field = target_model.resolve_relation_field(field.through, parent_model)
kwargs = {f'{sub_field.get_alias()}__in': ids}
qryclause = QueryClause(
model_cls=field.through,
select_related=[],
filter_clauses=[],
)
else:
kwargs = {f'{field.get_alias()}__in': ids}
else:
target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias()
@ -73,13 +83,15 @@ class PrefetchQuery:
reverse = False
target_field = model.Meta.model_fields[related]
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
target_model = target_field.to.get_name()
if reverse:
if target_field.virtual:
reverse = True
field_name = model.resolve_relation_name(target_field.to, model)
model_id = model.pk
elif issubclass(target_field, ManyToManyField):
reverse = True
field_name = model.resolve_relation_name(target_field.through, model)
model_id = model.pk
else:
related_name = model.resolve_relation_name(model, target_field.to)
related_model = getattr(model, related_name)
@ -89,17 +101,16 @@ class PrefetchQuery:
field_name = target_field.to.Meta.pkname
if target_model in already_extracted and already_extracted[target_model]['models']:
print('*****POPULATING RELATED:', target_model, field_name)
print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n')
print(already_extracted[target_model]['models'])
for child_model in already_extracted[target_model]['models']:
related_model = getattr(child_model, field_name)
if isinstance(related_model, list):
for child in related_model:
if child.pk == model_id:
setattr(model, related, child)
for ind, child_model in enumerate(already_extracted[target_model]['models']):
if issubclass(target_field, ManyToManyField):
raw_data = already_extracted[target_model]['raw'][ind]
if raw_data[field_name] == model_id:
setattr(model, related, child_model)
elif isinstance(related_model, ormar.Model):
if related_model.pk == model_id:
elif isinstance(getattr(child_model, field_name), ormar.Model):
if getattr(child_model, field_name).pk == model_id:
if reverse:
setattr(model, related, child_model)
else:
@ -123,6 +134,7 @@ class PrefetchQuery:
exclude_fields = self._exclude_columns
for part in related.split('__'):
fields = target_model.get_included(fields, part)
select_related = []
exclude_fields = target_model.get_excluded(exclude_fields, part)
target_field = target_model.Meta.model_fields[part]
@ -130,6 +142,9 @@ class PrefetchQuery:
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
if issubclass(target_field, ManyToManyField):
select_related = [target_field.through.get_name()]
parent_model = target_model
target_model = target_field.to
@ -141,9 +156,17 @@ class PrefetchQuery:
if not filter_clauses: # related field is empty
continue
query_target = target_model
table_prefix = ''
if issubclass(target_field, ManyToManyField):
query_target = target_field.through
select_related = [target_field.to.get_name()]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename)
qry = Query(
model_cls=target_model,
select_related=[],
model_cls=query_target,
select_related=select_related,
filter_clauses=filter_clauses,
exclude_clauses=[],
offset=None,
@ -161,7 +184,7 @@ class PrefetchQuery:
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix='',
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields
)
@ -174,12 +197,22 @@ class PrefetchQuery:
for part in related.split('__')[:-1]:
fields = target_model.get_included(fields, part)
exclude_fields = target_model.get_excluded(exclude_fields, part)
target_model = target_model.Meta.model_fields[part].to
target_field = target_model.Meta.model_fields[part]
target_model = target_field.to
table_prefix = ''
if issubclass(target_field, ManyToManyField):
from_table = target_field.through.Meta.tablename
to_name = target_field.to.Meta.tablename
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=from_table, to_table=to_name)
for row in already_extracted.get(target_model.get_name(), {}).get('raw', []):
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix='',
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields
)

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional
import databases
import pytest
@ -21,6 +21,23 @@ class Tonation(ormar.Model):
name: str = ormar.String(max_length=100)
class Shop(ormar.Model):
class Meta:
tablename = "shops"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class AlbumShops(ormar.Model):
class Meta:
tablename = "albums_x_shops"
metadata = metadata
database = database
class Album(ormar.Model):
class Meta:
tablename = "albums"
@ -29,6 +46,7 @@ class Album(ormar.Model):
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops)
class Track(ormar.Model):
@ -115,6 +133,32 @@ async def test_prefetch_related():
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_prefetch_related_with_many_to_many():
async with database:
async with database.transaction(force_rollback=True):
shop1 = await Shop.objects.create(name='Shop 1')
shop2 = await Shop.objects.create(name='Shop 2')
album = Album(name="Malibu")
await album.save()
await album.shops.add(shop1)
await album.shops.add(shop2)
await Track.objects.create(album=album, title="The Bird", position=1)
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2)
await Track.objects.create(album=album, title="The Waters", position=3)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get(
title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1'
assert len(track.album.shops) == 2
@pytest.mark.asyncio
async def test_prefetch_related_empty():
async with database: