dirty many to many pass first test
This commit is contained in:
@ -57,6 +57,16 @@ class PrefetchQuery:
|
|||||||
)
|
)
|
||||||
if reverse:
|
if reverse:
|
||||||
field = target_model.resolve_relation_field(target_model, parent_model)
|
field = target_model.resolve_relation_field(target_model, parent_model)
|
||||||
|
if issubclass(field, ManyToManyField):
|
||||||
|
sub_field = target_model.resolve_relation_field(field.through, parent_model)
|
||||||
|
kwargs = {f'{sub_field.get_alias()}__in': ids}
|
||||||
|
qryclause = QueryClause(
|
||||||
|
model_cls=field.through,
|
||||||
|
select_related=[],
|
||||||
|
filter_clauses=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
kwargs = {f'{field.get_alias()}__in': ids}
|
kwargs = {f'{field.get_alias()}__in': ids}
|
||||||
else:
|
else:
|
||||||
target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias()
|
target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias()
|
||||||
@ -73,13 +83,15 @@ class PrefetchQuery:
|
|||||||
reverse = False
|
reverse = False
|
||||||
|
|
||||||
target_field = model.Meta.model_fields[related]
|
target_field = model.Meta.model_fields[related]
|
||||||
if target_field.virtual or issubclass(target_field, ManyToManyField):
|
|
||||||
reverse = True
|
|
||||||
|
|
||||||
target_model = target_field.to.get_name()
|
target_model = target_field.to.get_name()
|
||||||
if reverse:
|
if target_field.virtual:
|
||||||
|
reverse = True
|
||||||
field_name = model.resolve_relation_name(target_field.to, model)
|
field_name = model.resolve_relation_name(target_field.to, model)
|
||||||
model_id = model.pk
|
model_id = model.pk
|
||||||
|
elif issubclass(target_field, ManyToManyField):
|
||||||
|
reverse = True
|
||||||
|
field_name = model.resolve_relation_name(target_field.through, model)
|
||||||
|
model_id = model.pk
|
||||||
else:
|
else:
|
||||||
related_name = model.resolve_relation_name(model, target_field.to)
|
related_name = model.resolve_relation_name(model, target_field.to)
|
||||||
related_model = getattr(model, related_name)
|
related_model = getattr(model, related_name)
|
||||||
@ -89,17 +101,16 @@ class PrefetchQuery:
|
|||||||
field_name = target_field.to.Meta.pkname
|
field_name = target_field.to.Meta.pkname
|
||||||
|
|
||||||
if target_model in already_extracted and already_extracted[target_model]['models']:
|
if target_model in already_extracted and already_extracted[target_model]['models']:
|
||||||
print('*****POPULATING RELATED:', target_model, field_name)
|
print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n')
|
||||||
print(already_extracted[target_model]['models'])
|
print(already_extracted[target_model]['models'])
|
||||||
for child_model in already_extracted[target_model]['models']:
|
for ind, child_model in enumerate(already_extracted[target_model]['models']):
|
||||||
related_model = getattr(child_model, field_name)
|
if issubclass(target_field, ManyToManyField):
|
||||||
if isinstance(related_model, list):
|
raw_data = already_extracted[target_model]['raw'][ind]
|
||||||
for child in related_model:
|
if raw_data[field_name] == model_id:
|
||||||
if child.pk == model_id:
|
setattr(model, related, child_model)
|
||||||
setattr(model, related, child)
|
|
||||||
|
|
||||||
elif isinstance(related_model, ormar.Model):
|
elif isinstance(getattr(child_model, field_name), ormar.Model):
|
||||||
if related_model.pk == model_id:
|
if getattr(child_model, field_name).pk == model_id:
|
||||||
if reverse:
|
if reverse:
|
||||||
setattr(model, related, child_model)
|
setattr(model, related, child_model)
|
||||||
else:
|
else:
|
||||||
@ -123,6 +134,7 @@ class PrefetchQuery:
|
|||||||
exclude_fields = self._exclude_columns
|
exclude_fields = self._exclude_columns
|
||||||
for part in related.split('__'):
|
for part in related.split('__'):
|
||||||
fields = target_model.get_included(fields, part)
|
fields = target_model.get_included(fields, part)
|
||||||
|
select_related = []
|
||||||
exclude_fields = target_model.get_excluded(exclude_fields, part)
|
exclude_fields = target_model.get_excluded(exclude_fields, part)
|
||||||
|
|
||||||
target_field = target_model.Meta.model_fields[part]
|
target_field = target_model.Meta.model_fields[part]
|
||||||
@ -130,6 +142,9 @@ class PrefetchQuery:
|
|||||||
if target_field.virtual or issubclass(target_field, ManyToManyField):
|
if target_field.virtual or issubclass(target_field, ManyToManyField):
|
||||||
reverse = True
|
reverse = True
|
||||||
|
|
||||||
|
if issubclass(target_field, ManyToManyField):
|
||||||
|
select_related = [target_field.through.get_name()]
|
||||||
|
|
||||||
parent_model = target_model
|
parent_model = target_model
|
||||||
target_model = target_field.to
|
target_model = target_field.to
|
||||||
|
|
||||||
@ -141,9 +156,17 @@ class PrefetchQuery:
|
|||||||
if not filter_clauses: # related field is empty
|
if not filter_clauses: # related field is empty
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
query_target = target_model
|
||||||
|
table_prefix = ''
|
||||||
|
if issubclass(target_field, ManyToManyField):
|
||||||
|
query_target = target_field.through
|
||||||
|
select_related = [target_field.to.get_name()]
|
||||||
|
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
|
||||||
|
from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename)
|
||||||
|
|
||||||
qry = Query(
|
qry = Query(
|
||||||
model_cls=target_model,
|
model_cls=query_target,
|
||||||
select_related=[],
|
select_related=select_related,
|
||||||
filter_clauses=filter_clauses,
|
filter_clauses=filter_clauses,
|
||||||
exclude_clauses=[],
|
exclude_clauses=[],
|
||||||
offset=None,
|
offset=None,
|
||||||
@ -161,7 +184,7 @@ class PrefetchQuery:
|
|||||||
item = target_model.extract_prefixed_table_columns(
|
item = target_model.extract_prefixed_table_columns(
|
||||||
item={},
|
item={},
|
||||||
row=row,
|
row=row,
|
||||||
table_prefix='',
|
table_prefix=table_prefix,
|
||||||
fields=fields,
|
fields=fields,
|
||||||
exclude_fields=exclude_fields
|
exclude_fields=exclude_fields
|
||||||
)
|
)
|
||||||
@ -174,12 +197,22 @@ class PrefetchQuery:
|
|||||||
for part in related.split('__')[:-1]:
|
for part in related.split('__')[:-1]:
|
||||||
fields = target_model.get_included(fields, part)
|
fields = target_model.get_included(fields, part)
|
||||||
exclude_fields = target_model.get_excluded(exclude_fields, part)
|
exclude_fields = target_model.get_excluded(exclude_fields, part)
|
||||||
target_model = target_model.Meta.model_fields[part].to
|
|
||||||
|
target_field = target_model.Meta.model_fields[part]
|
||||||
|
target_model = target_field.to
|
||||||
|
table_prefix = ''
|
||||||
|
|
||||||
|
if issubclass(target_field, ManyToManyField):
|
||||||
|
from_table = target_field.through.Meta.tablename
|
||||||
|
to_name = target_field.to.Meta.tablename
|
||||||
|
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
|
||||||
|
from_table=from_table, to_table=to_name)
|
||||||
|
|
||||||
for row in already_extracted.get(target_model.get_name(), {}).get('raw', []):
|
for row in already_extracted.get(target_model.get_name(), {}).get('raw', []):
|
||||||
item = target_model.extract_prefixed_table_columns(
|
item = target_model.extract_prefixed_table_columns(
|
||||||
item={},
|
item={},
|
||||||
row=row,
|
row=row,
|
||||||
table_prefix='',
|
table_prefix=table_prefix,
|
||||||
fields=fields,
|
fields=fields,
|
||||||
exclude_fields=exclude_fields
|
exclude_fields=exclude_fields
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pytest
|
import pytest
|
||||||
@ -21,6 +21,23 @@ class Tonation(ormar.Model):
|
|||||||
name: str = ormar.String(max_length=100)
|
name: str = ormar.String(max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class Shop(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "shops"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
name: str = ormar.String(max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class AlbumShops(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "albums_x_shops"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
|
||||||
class Album(ormar.Model):
|
class Album(ormar.Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
tablename = "albums"
|
tablename = "albums"
|
||||||
@ -29,6 +46,7 @@ class Album(ormar.Model):
|
|||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
id: int = ormar.Integer(primary_key=True)
|
||||||
name: str = ormar.String(max_length=100)
|
name: str = ormar.String(max_length=100)
|
||||||
|
shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops)
|
||||||
|
|
||||||
|
|
||||||
class Track(ormar.Model):
|
class Track(ormar.Model):
|
||||||
@ -115,6 +133,32 @@ async def test_prefetch_related():
|
|||||||
assert len(tracks) == 6
|
assert len(tracks) == 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prefetch_related_with_many_to_many():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
shop1 = await Shop.objects.create(name='Shop 1')
|
||||||
|
shop2 = await Shop.objects.create(name='Shop 2')
|
||||||
|
album = Album(name="Malibu")
|
||||||
|
await album.save()
|
||||||
|
await album.shops.add(shop1)
|
||||||
|
await album.shops.add(shop2)
|
||||||
|
|
||||||
|
await Track.objects.create(album=album, title="The Bird", position=1)
|
||||||
|
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2)
|
||||||
|
await Track.objects.create(album=album, title="The Waters", position=3)
|
||||||
|
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
|
||||||
|
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
|
||||||
|
|
||||||
|
track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get(
|
||||||
|
title="The Bird")
|
||||||
|
assert track.album.name == "Malibu"
|
||||||
|
assert len(track.album.cover_pictures) == 2
|
||||||
|
assert track.album.cover_pictures[0].artist == 'Artist 1'
|
||||||
|
|
||||||
|
assert len(track.album.shops) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_prefetch_related_empty():
|
async def test_prefetch_related_empty():
|
||||||
async with database:
|
async with database:
|
||||||
|
|||||||
Reference in New Issue
Block a user