dirty prefetch_related working for FK and reverse FK

This commit is contained in:
collerek
2020-11-23 16:05:05 +01:00
parent 08779f4689
commit b696156f56
4 changed files with 385 additions and 30 deletions

View File

@ -0,0 +1,195 @@
from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union
import ormar
from ormar.fields import ManyToManyField
from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query
if TYPE_CHECKING: # pragma: no cover
from ormar import Model
class PrefetchQuery:
def __init__(self,
model_cls: Type["Model"],
fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]],
prefetch_related: List
):
self.model = model_cls
self.database = self.model.Meta.database
self._prefetch_related = prefetch_related
self._exclude_columns = exclude_fields
self._columns = fields
@staticmethod
def _extract_required_ids(already_extracted: Dict,
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool) -> Set:
raw_rows = already_extracted.get(parent_model.get_name(), {}).get('raw', [])
if reverse:
column_name = parent_model.get_column_alias(parent_model.Meta.pkname)
else:
column_name = target_model.resolve_relation_field(parent_model, target_model).get_alias()
list_of_ids = set()
for row in raw_rows:
if row[column_name]:
list_of_ids.add(row[column_name])
return list_of_ids
@staticmethod
def _get_filter_for_prefetch(already_extracted: Dict,
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool) -> List:
ids = PrefetchQuery._extract_required_ids(already_extracted=already_extracted,
parent_model=parent_model,
target_model=target_model,
reverse=reverse)
if ids:
qryclause = QueryClause(
model_cls=target_model,
select_related=[],
filter_clauses=[],
)
if reverse:
field = target_model.resolve_relation_field(target_model, parent_model)
kwargs = {f'{field.get_alias()}__in': ids}
else:
target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias()
kwargs = {f'{target_field}__in': ids}
filter_clauses, _ = qryclause.filter(**kwargs)
return filter_clauses
return []
@staticmethod
def _populate_nested_related(model: "Model",
already_extracted: Dict) -> "Model":
for related in model.extract_related_names():
reverse = False
target_field = model.Meta.model_fields[related]
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
target_model = target_field.to.get_name()
if reverse:
field_name = model.resolve_relation_name(target_field.to, model)
model_id = model.pk
else:
related_name = model.resolve_relation_name(model, target_field.to)
related_model = getattr(model, related_name)
if not related_model:
continue
model_id = related_model.pk
field_name = target_field.to.Meta.pkname
if target_model in already_extracted and already_extracted[target_model]['models']:
print('*****POPULATING RELATED:', target_model, field_name)
print(already_extracted[target_model]['models'])
for child_model in already_extracted[target_model]['models']:
related_model = getattr(child_model, field_name)
if isinstance(related_model, list):
for child in related_model:
if child.pk == model_id:
setattr(model, related, child)
elif isinstance(related_model, ormar.Model):
if related_model.pk == model_id:
if reverse:
setattr(model, related, child_model)
else:
setattr(child_model, related, model)
else: # we have not reverse relation and related_model is a pk value
setattr(model, related, child_model)
return model
async def prefetch_related(self, models: Sequence["Model"], rows: List):
return await self._prefetch_related_models(models=models, rows=rows)
async def _prefetch_related_models(self,
models: Sequence["Model"],
rows: List) -> Sequence["Model"]:
already_extracted = {self.model.get_name(): {'raw': rows, 'models': models}}
for related in self._prefetch_related:
target_model = self.model
fields = self._columns
exclude_fields = self._exclude_columns
for part in related.split('__'):
fields = target_model.get_included(fields, part)
exclude_fields = target_model.get_excluded(exclude_fields, part)
target_field = target_model.Meta.model_fields[part]
reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
parent_model = target_model
target_model = target_field.to
if target_model.get_name() not in already_extracted:
filter_clauses = self._get_filter_for_prefetch(already_extracted=already_extracted,
parent_model=parent_model,
target_model=target_model,
reverse=reverse)
if not filter_clauses: # related field is empty
continue
qry = Query(
model_cls=target_model,
select_related=[],
filter_clauses=filter_clauses,
exclude_clauses=[],
offset=None,
limit_count=None,
fields=fields,
exclude_fields=exclude_fields,
order_bys=None,
)
expr = qry.build_select_expression()
print(expr.compile(compile_kwargs={"literal_binds": True}))
rows = await self.database.fetch_all(expr)
already_extracted[target_model.get_name()] = {'raw': rows, 'models': []}
if part == related.split('__')[-1]:
for row in rows:
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix='',
fields=fields,
exclude_fields=exclude_fields
)
instance = target_model(**item)
already_extracted[target_model.get_name()]['models'].append(instance)
target_model = self.model
fields = self._columns
exclude_fields = self._exclude_columns
for part in related.split('__')[:-1]:
fields = target_model.get_included(fields, part)
exclude_fields = target_model.get_excluded(exclude_fields, part)
target_model = target_model.Meta.model_fields[part].to
for row in already_extracted.get(target_model.get_name(), {}).get('raw', []):
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix='',
fields=fields,
exclude_fields=exclude_fields
)
instance = target_model(**item)
instance = self._populate_nested_related(model=instance,
already_extracted=already_extracted)
already_extracted[target_model.get_name()]['models'].append(instance)
final_models = []
for model in models:
final_models.append(self._populate_nested_related(model=model,
already_extracted=already_extracted))
return models

View File

@ -9,6 +9,7 @@ from ormar import MultipleMatches, NoMatch
from ormar.exceptions import QueryDefinitionError from ormar.exceptions import QueryDefinitionError
from ormar.queryset import FilterQuery from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.prefetch_query import PrefetchQuery
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.utils import update, update_dict_from_list from ormar.queryset.utils import update, update_dict_from_list
@ -30,11 +31,13 @@ class QuerySet:
columns: Dict = None, columns: Dict = None,
exclude_columns: Dict = None, exclude_columns: Dict = None,
order_bys: List = None, order_bys: List = None,
prefetch_related: List = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses self.filter_clauses = [] if filter_clauses is None else filter_clauses
self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses
self._select_related = [] if select_related is None else select_related self._select_related = [] if select_related is None else select_related
self._prefetch_related = [] if prefetch_related is None else prefetch_related
self.limit_count = limit_count self.limit_count = limit_count
self.query_offset = offset self.query_offset = offset
self._columns = columns or {} self._columns = columns or {}
@ -63,6 +66,13 @@ class QuerySet:
raise ValueError("Model class of QuerySet is not initialized") raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls return self.model_cls
async def _prefetch_related_models(self, models: Sequence["Model"], rows: List) -> Sequence["Model"]:
query = PrefetchQuery(model_cls=self.model_cls,
fields=self._columns,
exclude_fields=self._exclude_columns,
prefetch_related=self._prefetch_related)
return await query.prefetch_related(models=models, rows=rows)
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
result_rows = [ result_rows = [
self.model.from_row( self.model.from_row(
@ -148,6 +158,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003 def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -168,6 +179,25 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
)
def prefetch_related(self, related: Union[List, str]) -> "QuerySet":
if not isinstance(related, list):
related = [related]
related = list(set(list(self._select_related) + related))
return self.__class__(
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=self.query_offset,
columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
prefetch_related=related
) )
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
@ -190,6 +220,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=current_excluded, exclude_columns=current_excluded,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
@ -212,6 +243,7 @@ class QuerySet:
columns=current_included, columns=current_included,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
def order_by(self, columns: Union[List, str]) -> "QuerySet": def order_by(self, columns: Union[List, str]) -> "QuerySet":
@ -229,6 +261,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=order_bys, order_bys=order_bys,
prefetch_related=self._prefetch_related,
) )
async def exists(self) -> bool: async def exists(self) -> bool:
@ -279,6 +312,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
def offset(self, offset: int) -> "QuerySet": def offset(self, offset: int) -> "QuerySet":
@ -292,6 +326,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> "Model":
@ -312,6 +347,8 @@ class QuerySet:
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
processed_rows = self._process_query_result_rows(rows) processed_rows = self._process_query_result_rows(rows)
if self._prefetch_related:
processed_rows = await self._prefetch_related_models(processed_rows, rows)
self.check_single_result_rows_count(processed_rows) self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore return processed_rows[0] # type: ignore
@ -337,6 +374,8 @@ class QuerySet:
expr = self.build_select_expression() expr = self.build_select_expression()
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
result_rows = self._process_query_result_rows(rows) result_rows = self._process_query_result_rows(rows)
if self._prefetch_related:
result_rows = await self._prefetch_related_models(result_rows, rows)
return result_rows return result_rows

View File

@ -55,10 +55,6 @@ class Organisation(ormar.Model):
ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"]) ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
class Organization(object):
pass
class Team(ormar.Model): class Team(ormar.Model):
class Meta: class Meta:
tablename = "teams" tablename = "teams"

View File

@ -0,0 +1,125 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Tonation(ormar.Model):
class Meta:
tablename = "tonations"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Track(ormar.Model):
class Meta:
tablename = "tracks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100)
position: int = ormar.Integer()
tonation: Optional[Tonation] = ormar.ForeignKey(Tonation)
class Cover(ormar.Model):
class Meta:
tablename = "covers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures")
title: str = ormar.String(max_length=100)
artist: str = ormar.String(max_length=200, nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_prefetch_related():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Malibu")
await album.save()
ton1 = await Tonation.objects.create(name='B-mol')
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
fantasies = Album(name="Fantasies")
await fantasies.save()
await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)
await Cover.objects.create(title='Cover3', album=fantasies, artist='Artist 3')
await Cover.objects.create(title='Cover4', album=fantasies, artist='Artist 4')
album = await Album.objects.filter(name='Malibu').prefetch_related(
['tracks__tonation', 'cover_pictures']).get()
assert len(album.tracks) == 3
assert album.tracks[0].title == 'The Bird'
assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].title == 'Cover1'
assert album.tracks[0].tonation.name == album.tracks[2].tonation.name == 'B-mol'
albums = await Album.objects.prefetch_related('tracks').all()
assert len(albums[0].tracks) == 3
assert len(albums[1].tracks) == 3
assert albums[0].tracks[0].title == "The Bird"
assert albums[1].tracks[0].title == "Help I'm Alive"
track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1'
track = await Track.objects.prefetch_related(["album__cover_pictures"]).exclude_fields(
'album__cover_pictures__artist').get(title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist is None
tracks = await Track.objects.prefetch_related("album").all()
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_prefetch_related_empty():
async with database:
async with database.transaction(force_rollback=True):
await Track.objects.create(title="The Bird", position=1)
track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird")
assert track.title == 'The Bird'
assert track.album is None