cleaner version but still dirty

This commit is contained in:
collerek
2020-11-25 13:28:51 +01:00
parent 585bba3ad3
commit f2fe41d38a
3 changed files with 300 additions and 192 deletions

View File

@ -1,89 +1,109 @@
from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union from typing import (
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import ormar import ormar
from ormar.fields import ManyToManyField from ormar.fields import BaseField, ManyToManyField
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.utils import translate_list_to_dict
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ormar import Model from ormar import Model
class PrefetchQuery: class PrefetchQuery:
def __init__(
def __init__(self, self,
model_cls: Type["Model"], model_cls: Type["Model"],
fields: Optional[Union[Dict, Set]], fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]], exclude_fields: Optional[Union[Dict, Set]],
prefetch_related: List prefetch_related: List,
): select_related: List,
) -> None:
self.model = model_cls self.model = model_cls
self.database = self.model.Meta.database self.database = self.model.Meta.database
self._prefetch_related = prefetch_related self._prefetch_related = prefetch_related
self._select_related = select_related
self._exclude_columns = exclude_fields self._exclude_columns = exclude_fields
self._columns = fields self._columns = fields
@staticmethod @staticmethod
def _extract_required_ids(already_extracted: Dict, def _extract_required_ids(
already_extracted: Dict,
parent_model: Type["Model"], parent_model: Type["Model"],
target_model: Type["Model"], target_model: Type["Model"],
reverse: bool) -> Set: reverse: bool,
raw_rows = already_extracted.get(parent_model.get_name(), {}).get('raw', []) ) -> Set:
raw_rows = already_extracted.get(parent_model.get_name(), {}).get("raw", [])
table_prefix = already_extracted.get(parent_model.get_name(), {}).get(
"prefix", ""
)
if reverse: if reverse:
column_name = parent_model.get_column_alias(parent_model.Meta.pkname) column_name = parent_model.get_column_alias(parent_model.Meta.pkname)
else: else:
column_name = target_model.resolve_relation_field(parent_model, target_model).get_alias() column_name = target_model.resolve_relation_field(
parent_model, target_model
).get_alias()
list_of_ids = set() list_of_ids = set()
column_name = (f"{table_prefix}_" if table_prefix else "") + column_name
for row in raw_rows: for row in raw_rows:
if row[column_name]: if row[column_name]:
list_of_ids.add(row[column_name]) list_of_ids.add(row[column_name])
return list_of_ids return list_of_ids
@staticmethod @staticmethod
def _get_filter_for_prefetch(already_extracted: Dict, def _get_filter_for_prefetch(
already_extracted: Dict,
parent_model: Type["Model"], parent_model: Type["Model"],
target_model: Type["Model"], target_model: Type["Model"],
reverse: bool) -> List: reverse: bool,
ids = PrefetchQuery._extract_required_ids(already_extracted=already_extracted, ) -> List:
ids = PrefetchQuery._extract_required_ids(
already_extracted=already_extracted,
parent_model=parent_model, parent_model=parent_model,
target_model=target_model, target_model=target_model,
reverse=reverse) reverse=reverse,
)
if ids: if ids:
qryclause = QueryClause( qryclause = QueryClause(
model_cls=target_model, model_cls=target_model, select_related=[], filter_clauses=[],
select_related=[],
filter_clauses=[],
) )
if reverse: if reverse:
field = target_model.resolve_relation_field(target_model, parent_model) field = target_model.resolve_relation_field(target_model, parent_model)
if issubclass(field, ManyToManyField): if issubclass(field, ManyToManyField):
sub_field = target_model.resolve_relation_field(field.through, parent_model) sub_field = target_model.resolve_relation_field(
kwargs = {f'{sub_field.get_alias()}__in': ids} field.through, parent_model
)
kwargs = {f"{sub_field.get_alias()}__in": ids}
qryclause = QueryClause( qryclause = QueryClause(
model_cls=field.through, model_cls=field.through, select_related=[], filter_clauses=[],
select_related=[],
filter_clauses=[],
) )
else: else:
kwargs = {f'{field.get_alias()}__in': ids} kwargs = {f"{field.get_alias()}__in": ids}
else: else:
target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias() target_field = target_model.Meta.model_fields[
kwargs = {f'{target_field}__in': ids} target_model.Meta.pkname
].get_alias()
kwargs = {f"{target_field}__in": ids}
filter_clauses, _ = qryclause.filter(**kwargs) filter_clauses, _ = qryclause.filter(**kwargs)
return filter_clauses return filter_clauses
return [] return []
@staticmethod @staticmethod
def _populate_nested_related(model: "Model", def _get_model_id_and_field_name(
already_extracted: Dict) -> "Model": target_field: Type["BaseField"], model: "Model"
) -> Tuple[bool, Optional[str], Optional[int]]:
for related in model.extract_related_names():
reverse = False
target_field = model.Meta.model_fields[related]
target_model = target_field.to.get_name()
if target_field.virtual: if target_field.virtual:
reverse = True reverse = True
field_name = model.resolve_relation_name(target_field.to, model) field_name = model.resolve_relation_name(target_field.to, model)
@ -93,76 +113,160 @@ class PrefetchQuery:
field_name = model.resolve_relation_name(target_field.through, model) field_name = model.resolve_relation_name(target_field.through, model)
model_id = model.pk model_id = model.pk
else: else:
reverse = False
related_name = model.resolve_relation_name(model, target_field.to) related_name = model.resolve_relation_name(model, target_field.to)
related_model = getattr(model, related_name) related_model = getattr(model, related_name)
if not related_model: if not related_model:
continue return reverse, None, None
model_id = related_model.pk model_id = related_model.pk
field_name = target_field.to.Meta.pkname field_name = target_field.to.Meta.pkname
if target_model in already_extracted and already_extracted[target_model]['models']: return reverse, field_name, model_id
print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n')
print(already_extracted[target_model]['models']) @staticmethod
for ind, child_model in enumerate(already_extracted[target_model]['models']): def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List:
related_to_extract = []
if prefetch_dict and prefetch_dict is not Ellipsis:
related_to_extract = [
related
for related in model.extract_related_names()
if related in prefetch_dict
]
return related_to_extract
@staticmethod
def _populate_nested_related(
model: "Model", already_extracted: Dict, prefetch_dict: Dict
) -> "Model":
related_to_extract = PrefetchQuery._get_names_to_extract(
prefetch_dict=prefetch_dict, model=model
)
for related in related_to_extract:
target_field = model.Meta.model_fields[related]
target_model = target_field.to.get_name()
reverse, field_name, model_id = PrefetchQuery._get_model_id_and_field_name(
target_field=target_field, model=model
)
if (
target_model in already_extracted
and already_extracted[target_model]["models"]
):
for key, child_model in already_extracted[target_model][
"models"
].items():
if issubclass(target_field, ManyToManyField): if issubclass(target_field, ManyToManyField):
raw_data = already_extracted[target_model]['raw'][ind] ind = next(
if raw_data[field_name] == model_id: i
if key == x[target_field.to.get_column_alias(field_name)]
else -1
for i, x in enumerate(
already_extracted[target_model]["raw"]
)
)
raw_data = already_extracted[target_model]["raw"][ind]
if (
raw_data
and field_name in raw_data
and raw_data[field_name] == model_id
):
setattr(model, related, child_model) setattr(model, related, child_model)
elif isinstance(getattr(child_model, field_name), ormar.Model): elif isinstance(getattr(child_model, field_name), ormar.Model):
if getattr(child_model, field_name).pk == model_id: if getattr(child_model, field_name).pk == model_id:
if reverse:
setattr(model, related, child_model) setattr(model, related, child_model)
else:
setattr(child_model, related, model)
else: # we have not reverse relation and related_model is a pk value elif getattr(child_model, field_name) == model_id:
setattr(model, related, child_model) setattr(model, related, child_model)
return model return model
async def prefetch_related(self, models: Sequence["Model"], rows: List): async def prefetch_related(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
return await self._prefetch_related_models(models=models, rows=rows) return await self._prefetch_related_models(models=models, rows=rows)
async def _prefetch_related_models(self, async def _prefetch_related_models(
models: Sequence["Model"], self, models: Sequence["Model"], rows: List
rows: List) -> Sequence["Model"]: ) -> Sequence["Model"]:
already_extracted = {self.model.get_name(): {'raw': rows, 'models': models}} already_extracted = {
for related in self._prefetch_related: self.model.get_name(): {
"raw": rows,
"models": {model.pk: model for model in models},
}
}
select_dict = translate_list_to_dict(self._select_related)
prefetch_dict = translate_list_to_dict(self._prefetch_related)
target_model = self.model target_model = self.model
fields = self._columns fields = self._columns
exclude_fields = self._exclude_columns exclude_fields = self._exclude_columns
for part in related.split('__'): for related in prefetch_dict.keys():
fields = target_model.get_included(fields, part) await self._extract_related_models(
select_related = [] related=related,
exclude_fields = target_model.get_excluded(exclude_fields, part) target_model=target_model,
prefetch_dict=prefetch_dict.get(related),
select_dict=select_dict.get(related),
already_extracted=already_extracted,
fields=fields,
exclude_fields=exclude_fields,
)
final_models = []
for model in models:
final_models.append(
self._populate_nested_related(
model=model,
already_extracted=already_extracted,
prefetch_dict=prefetch_dict,
)
)
return models
target_field = target_model.Meta.model_fields[part] async def _extract_related_models( # noqa: CFQ002
self,
related: str,
target_model: Type["Model"],
prefetch_dict: Dict,
select_dict: Dict,
already_extracted: Dict,
fields: Dict,
exclude_fields: Dict,
) -> None:
fields = target_model.get_included(fields, related)
exclude_fields = target_model.get_excluded(exclude_fields, related)
select_related = []
target_field = target_model.Meta.model_fields[related]
reverse = False reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField): if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True reverse = True
if issubclass(target_field, ManyToManyField):
select_related = [target_field.through.get_name()]
parent_model = target_model parent_model = target_model
target_model = target_field.to target_model = target_field.to
if target_model.get_name() not in already_extracted: filter_clauses = PrefetchQuery._get_filter_for_prefetch(
filter_clauses = self._get_filter_for_prefetch(already_extracted=already_extracted, already_extracted=already_extracted,
parent_model=parent_model, parent_model=parent_model,
target_model=target_model, target_model=target_model,
reverse=reverse) reverse=reverse,
)
if not filter_clauses: # related field is empty if not filter_clauses: # related field is empty
continue return
query_target = target_model query_target = target_model
table_prefix = '' table_prefix = ""
if issubclass(target_field, ManyToManyField): if issubclass(target_field, ManyToManyField):
query_target = target_field.through query_target = target_field.through
select_related = [target_field.to.get_name()] select_related = [target_field.to.get_name()]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename) from_table=query_target.Meta.tablename,
to_table=target_field.to.Meta.tablename,
)
already_extracted.setdefault(target_model.get_name(), {})[
"prefix"
] = table_prefix
qry = Query( qry = Query(
model_cls=query_target, model_cls=query_target,
@ -176,53 +280,38 @@ class PrefetchQuery:
order_bys=None, order_bys=None,
) )
expr = qry.build_select_expression() expr = qry.build_select_expression()
print(expr.compile(compile_kwargs={"literal_binds": True})) # print(expr.compile(compile_kwargs={"literal_binds": True}))
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
already_extracted[target_model.get_name()] = {'raw': rows, 'models': []} already_extracted.setdefault(target_model.get_name(), {}).update(
if part == related.split('__')[-1]: {"raw": rows, "models": {}}
)
if prefetch_dict and prefetch_dict is not Ellipsis:
for subrelated in prefetch_dict.keys():
await self._extract_related_models(
related=subrelated,
target_model=target_model,
prefetch_dict=prefetch_dict.get(subrelated),
select_dict=select_dict.get(subrelated)
if (select_dict and subrelated in select_dict)
else {},
already_extracted=already_extracted,
fields=fields,
exclude_fields=exclude_fields,
)
for row in rows: for row in rows:
item = target_model.extract_prefixed_table_columns( item = target_model.extract_prefixed_table_columns(
item={}, item={},
row=row, row=row,
table_prefix=table_prefix, table_prefix=table_prefix,
fields=fields, fields=fields,
exclude_fields=exclude_fields exclude_fields=exclude_fields,
) )
instance = target_model(**item) instance = target_model(**item)
already_extracted[target_model.get_name()]['models'].append(instance) instance = self._populate_nested_related(
model=instance,
target_model = self.model already_extracted=already_extracted,
fields = self._columns prefetch_dict=prefetch_dict,
exclude_fields = self._exclude_columns
for part in related.split('__')[:-1]:
fields = target_model.get_included(fields, part)
exclude_fields = target_model.get_excluded(exclude_fields, part)
target_field = target_model.Meta.model_fields[part]
target_model = target_field.to
table_prefix = ''
if issubclass(target_field, ManyToManyField):
from_table = target_field.through.Meta.tablename
to_name = target_field.to.Meta.tablename
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=from_table, to_table=to_name)
for row in already_extracted.get(target_model.get_name(), {}).get('raw', []):
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields
) )
instance = target_model(**item) already_extracted[target_model.get_name()]["models"][instance.pk] = instance
instance = self._populate_nested_related(model=instance,
already_extracted=already_extracted)
already_extracted[target_model.get_name()]['models'].append(instance)
final_models = []
for model in models:
final_models.append(self._populate_nested_related(model=model,
already_extracted=already_extracted))
return models

View File

@ -66,11 +66,16 @@ class QuerySet:
raise ValueError("Model class of QuerySet is not initialized") raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls return self.model_cls
async def _prefetch_related_models(self, models: Sequence["Model"], rows: List) -> Sequence["Model"]: async def _prefetch_related_models(
query = PrefetchQuery(model_cls=self.model_cls, self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
query = PrefetchQuery(
model_cls=self.model_cls,
fields=self._columns, fields=self._columns,
exclude_fields=self._exclude_columns, exclude_fields=self._exclude_columns,
prefetch_related=self._prefetch_related) prefetch_related=self._prefetch_related,
select_related=self._select_related,
)
return await query.prefetch_related(models=models, rows=rows) return await query.prefetch_related(models=models, rows=rows)
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
@ -197,7 +202,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=related prefetch_related=related,
) )
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":

View File

@ -21,6 +21,16 @@ class Tonation(ormar.Model):
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
class Division(ormar.Model):
class Meta:
tablename = "divisions"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Shop(ormar.Model): class Shop(ormar.Model):
class Meta: class Meta:
tablename = "shops" tablename = "shops"
@ -29,6 +39,7 @@ class Shop(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
division: Optional[Division] = ormar.ForeignKey(Division)
class AlbumShops(ormar.Model): class AlbumShops(ormar.Model):
@ -137,8 +148,9 @@ async def test_prefetch_related():
async def test_prefetch_related_with_many_to_many(): async def test_prefetch_related_with_many_to_many():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
shop1 = await Shop.objects.create(name='Shop 1') div = await Division.objects.create(name='Div 1')
shop2 = await Shop.objects.create(name='Shop 2') shop1 = await Shop.objects.create(name='Shop 1', division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div)
album = Album(name="Malibu") album = Album(name="Malibu")
await album.save() await album.save()
await album.shops.add(shop1) await album.shops.add(shop1)
@ -150,13 +162,15 @@ async def test_prefetch_related_with_many_to_many():
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get( track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get(
title="The Bird") title="The Bird")
assert track.album.name == "Malibu" assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2 assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1' assert track.album.cover_pictures[0].artist == 'Artist 1'
assert len(track.album.shops) == 2 assert len(track.album.shops) == 2
assert track.album.shops[0].name == 'Shop 1'
assert track.album.shops[0].division.name == 'Div 1'
@pytest.mark.asyncio @pytest.mark.asyncio