cleaner version but still dirty

This commit is contained in:
collerek
2020-11-25 13:28:51 +01:00
parent 585bba3ad3
commit f2fe41d38a
3 changed files with 300 additions and 192 deletions

View File

@ -1,228 +1,317 @@
from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union from typing import (
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import ormar import ormar
from ormar.fields import ManyToManyField from ormar.fields import BaseField, ManyToManyField
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.utils import translate_list_to_dict
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ormar import Model from ormar import Model
class PrefetchQuery: class PrefetchQuery:
def __init__(
def __init__(self, self,
model_cls: Type["Model"], model_cls: Type["Model"],
fields: Optional[Union[Dict, Set]], fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]], exclude_fields: Optional[Union[Dict, Set]],
prefetch_related: List prefetch_related: List,
): select_related: List,
) -> None:
self.model = model_cls self.model = model_cls
self.database = self.model.Meta.database self.database = self.model.Meta.database
self._prefetch_related = prefetch_related self._prefetch_related = prefetch_related
self._select_related = select_related
self._exclude_columns = exclude_fields self._exclude_columns = exclude_fields
self._columns = fields self._columns = fields
@staticmethod @staticmethod
def _extract_required_ids(already_extracted: Dict, def _extract_required_ids(
parent_model: Type["Model"], already_extracted: Dict,
target_model: Type["Model"], parent_model: Type["Model"],
reverse: bool) -> Set: target_model: Type["Model"],
raw_rows = already_extracted.get(parent_model.get_name(), {}).get('raw', []) reverse: bool,
) -> Set:
raw_rows = already_extracted.get(parent_model.get_name(), {}).get("raw", [])
table_prefix = already_extracted.get(parent_model.get_name(), {}).get(
"prefix", ""
)
if reverse: if reverse:
column_name = parent_model.get_column_alias(parent_model.Meta.pkname) column_name = parent_model.get_column_alias(parent_model.Meta.pkname)
else: else:
column_name = target_model.resolve_relation_field(parent_model, target_model).get_alias() column_name = target_model.resolve_relation_field(
parent_model, target_model
).get_alias()
list_of_ids = set() list_of_ids = set()
column_name = (f"{table_prefix}_" if table_prefix else "") + column_name
for row in raw_rows: for row in raw_rows:
if row[column_name]: if row[column_name]:
list_of_ids.add(row[column_name]) list_of_ids.add(row[column_name])
return list_of_ids return list_of_ids
@staticmethod @staticmethod
def _get_filter_for_prefetch(already_extracted: Dict, def _get_filter_for_prefetch(
parent_model: Type["Model"], already_extracted: Dict,
target_model: Type["Model"], parent_model: Type["Model"],
reverse: bool) -> List: target_model: Type["Model"],
ids = PrefetchQuery._extract_required_ids(already_extracted=already_extracted, reverse: bool,
parent_model=parent_model, ) -> List:
target_model=target_model, ids = PrefetchQuery._extract_required_ids(
reverse=reverse) already_extracted=already_extracted,
parent_model=parent_model,
target_model=target_model,
reverse=reverse,
)
if ids: if ids:
qryclause = QueryClause( qryclause = QueryClause(
model_cls=target_model, model_cls=target_model, select_related=[], filter_clauses=[],
select_related=[],
filter_clauses=[],
) )
if reverse: if reverse:
field = target_model.resolve_relation_field(target_model, parent_model) field = target_model.resolve_relation_field(target_model, parent_model)
if issubclass(field, ManyToManyField): if issubclass(field, ManyToManyField):
sub_field = target_model.resolve_relation_field(field.through, parent_model) sub_field = target_model.resolve_relation_field(
kwargs = {f'{sub_field.get_alias()}__in': ids} field.through, parent_model
)
kwargs = {f"{sub_field.get_alias()}__in": ids}
qryclause = QueryClause( qryclause = QueryClause(
model_cls=field.through, model_cls=field.through, select_related=[], filter_clauses=[],
select_related=[],
filter_clauses=[],
) )
else: else:
kwargs = {f'{field.get_alias()}__in': ids} kwargs = {f"{field.get_alias()}__in": ids}
else: else:
target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias() target_field = target_model.Meta.model_fields[
kwargs = {f'{target_field}__in': ids} target_model.Meta.pkname
].get_alias()
kwargs = {f"{target_field}__in": ids}
filter_clauses, _ = qryclause.filter(**kwargs) filter_clauses, _ = qryclause.filter(**kwargs)
return filter_clauses return filter_clauses
return [] return []
@staticmethod @staticmethod
def _populate_nested_related(model: "Model", def _get_model_id_and_field_name(
already_extracted: Dict) -> "Model": target_field: Type["BaseField"], model: "Model"
) -> Tuple[bool, Optional[str], Optional[int]]:
for related in model.extract_related_names(): if target_field.virtual:
reverse = True
field_name = model.resolve_relation_name(target_field.to, model)
model_id = model.pk
elif issubclass(target_field, ManyToManyField):
reverse = True
field_name = model.resolve_relation_name(target_field.through, model)
model_id = model.pk
else:
reverse = False reverse = False
related_name = model.resolve_relation_name(model, target_field.to)
related_model = getattr(model, related_name)
if not related_model:
return reverse, None, None
model_id = related_model.pk
field_name = target_field.to.Meta.pkname
return reverse, field_name, model_id
@staticmethod
def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List:
related_to_extract = []
if prefetch_dict and prefetch_dict is not Ellipsis:
related_to_extract = [
related
for related in model.extract_related_names()
if related in prefetch_dict
]
return related_to_extract
@staticmethod
def _populate_nested_related(
model: "Model", already_extracted: Dict, prefetch_dict: Dict
) -> "Model":
related_to_extract = PrefetchQuery._get_names_to_extract(
prefetch_dict=prefetch_dict, model=model
)
for related in related_to_extract:
target_field = model.Meta.model_fields[related] target_field = model.Meta.model_fields[related]
target_model = target_field.to.get_name() target_model = target_field.to.get_name()
if target_field.virtual: reverse, field_name, model_id = PrefetchQuery._get_model_id_and_field_name(
reverse = True target_field=target_field, model=model
field_name = model.resolve_relation_name(target_field.to, model) )
model_id = model.pk
elif issubclass(target_field, ManyToManyField):
reverse = True
field_name = model.resolve_relation_name(target_field.through, model)
model_id = model.pk
else:
related_name = model.resolve_relation_name(model, target_field.to)
related_model = getattr(model, related_name)
if not related_model:
continue
model_id = related_model.pk
field_name = target_field.to.Meta.pkname
if target_model in already_extracted and already_extracted[target_model]['models']: if (
print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n') target_model in already_extracted
print(already_extracted[target_model]['models']) and already_extracted[target_model]["models"]
for ind, child_model in enumerate(already_extracted[target_model]['models']): ):
for key, child_model in already_extracted[target_model][
"models"
].items():
if issubclass(target_field, ManyToManyField): if issubclass(target_field, ManyToManyField):
raw_data = already_extracted[target_model]['raw'][ind] ind = next(
if raw_data[field_name] == model_id: i
if key == x[target_field.to.get_column_alias(field_name)]
else -1
for i, x in enumerate(
already_extracted[target_model]["raw"]
)
)
raw_data = already_extracted[target_model]["raw"][ind]
if (
raw_data
and field_name in raw_data
and raw_data[field_name] == model_id
):
setattr(model, related, child_model) setattr(model, related, child_model)
elif isinstance(getattr(child_model, field_name), ormar.Model): elif isinstance(getattr(child_model, field_name), ormar.Model):
if getattr(child_model, field_name).pk == model_id: if getattr(child_model, field_name).pk == model_id:
if reverse: setattr(model, related, child_model)
setattr(model, related, child_model)
else:
setattr(child_model, related, model)
else: # we have not reverse relation and related_model is a pk value elif getattr(child_model, field_name) == model_id:
setattr(model, related, child_model) setattr(model, related, child_model)
return model return model
async def prefetch_related(self, models: Sequence["Model"], rows: List): async def prefetch_related(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
return await self._prefetch_related_models(models=models, rows=rows) return await self._prefetch_related_models(models=models, rows=rows)
async def _prefetch_related_models(self, async def _prefetch_related_models(
models: Sequence["Model"], self, models: Sequence["Model"], rows: List
rows: List) -> Sequence["Model"]: ) -> Sequence["Model"]:
already_extracted = {self.model.get_name(): {'raw': rows, 'models': models}} already_extracted = {
for related in self._prefetch_related: self.model.get_name(): {
target_model = self.model "raw": rows,
fields = self._columns "models": {model.pk: model for model in models},
exclude_fields = self._exclude_columns }
for part in related.split('__'): }
fields = target_model.get_included(fields, part) select_dict = translate_list_to_dict(self._select_related)
select_related = [] prefetch_dict = translate_list_to_dict(self._prefetch_related)
exclude_fields = target_model.get_excluded(exclude_fields, part) target_model = self.model
fields = self._columns
target_field = target_model.Meta.model_fields[part] exclude_fields = self._exclude_columns
reverse = False for related in prefetch_dict.keys():
if target_field.virtual or issubclass(target_field, ManyToManyField): await self._extract_related_models(
reverse = True related=related,
target_model=target_model,
if issubclass(target_field, ManyToManyField): prefetch_dict=prefetch_dict.get(related),
select_related = [target_field.through.get_name()] select_dict=select_dict.get(related),
already_extracted=already_extracted,
parent_model = target_model fields=fields,
target_model = target_field.to exclude_fields=exclude_fields,
)
if target_model.get_name() not in already_extracted:
filter_clauses = self._get_filter_for_prefetch(already_extracted=already_extracted,
parent_model=parent_model,
target_model=target_model,
reverse=reverse)
if not filter_clauses: # related field is empty
continue
query_target = target_model
table_prefix = ''
if issubclass(target_field, ManyToManyField):
query_target = target_field.through
select_related = [target_field.to.get_name()]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename)
qry = Query(
model_cls=query_target,
select_related=select_related,
filter_clauses=filter_clauses,
exclude_clauses=[],
offset=None,
limit_count=None,
fields=fields,
exclude_fields=exclude_fields,
order_bys=None,
)
expr = qry.build_select_expression()
print(expr.compile(compile_kwargs={"literal_binds": True}))
rows = await self.database.fetch_all(expr)
already_extracted[target_model.get_name()] = {'raw': rows, 'models': []}
if part == related.split('__')[-1]:
for row in rows:
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields
)
instance = target_model(**item)
already_extracted[target_model.get_name()]['models'].append(instance)
target_model = self.model
fields = self._columns
exclude_fields = self._exclude_columns
for part in related.split('__')[:-1]:
fields = target_model.get_included(fields, part)
exclude_fields = target_model.get_excluded(exclude_fields, part)
target_field = target_model.Meta.model_fields[part]
target_model = target_field.to
table_prefix = ''
if issubclass(target_field, ManyToManyField):
from_table = target_field.through.Meta.tablename
to_name = target_field.to.Meta.tablename
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=from_table, to_table=to_name)
for row in already_extracted.get(target_model.get_name(), {}).get('raw', []):
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields
)
instance = target_model(**item)
instance = self._populate_nested_related(model=instance,
already_extracted=already_extracted)
already_extracted[target_model.get_name()]['models'].append(instance)
final_models = [] final_models = []
for model in models: for model in models:
final_models.append(self._populate_nested_related(model=model, final_models.append(
already_extracted=already_extracted)) self._populate_nested_related(
model=model,
already_extracted=already_extracted,
prefetch_dict=prefetch_dict,
)
)
return models return models
async def _extract_related_models( # noqa: CFQ002
self,
related: str,
target_model: Type["Model"],
prefetch_dict: Dict,
select_dict: Dict,
already_extracted: Dict,
fields: Dict,
exclude_fields: Dict,
) -> None:
fields = target_model.get_included(fields, related)
exclude_fields = target_model.get_excluded(exclude_fields, related)
select_related = []
target_field = target_model.Meta.model_fields[related]
reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
parent_model = target_model
target_model = target_field.to
filter_clauses = PrefetchQuery._get_filter_for_prefetch(
already_extracted=already_extracted,
parent_model=parent_model,
target_model=target_model,
reverse=reverse,
)
if not filter_clauses: # related field is empty
return
query_target = target_model
table_prefix = ""
if issubclass(target_field, ManyToManyField):
query_target = target_field.through
select_related = [target_field.to.get_name()]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename,
to_table=target_field.to.Meta.tablename,
)
already_extracted.setdefault(target_model.get_name(), {})[
"prefix"
] = table_prefix
qry = Query(
model_cls=query_target,
select_related=select_related,
filter_clauses=filter_clauses,
exclude_clauses=[],
offset=None,
limit_count=None,
fields=fields,
exclude_fields=exclude_fields,
order_bys=None,
)
expr = qry.build_select_expression()
# print(expr.compile(compile_kwargs={"literal_binds": True}))
rows = await self.database.fetch_all(expr)
already_extracted.setdefault(target_model.get_name(), {}).update(
{"raw": rows, "models": {}}
)
if prefetch_dict and prefetch_dict is not Ellipsis:
for subrelated in prefetch_dict.keys():
await self._extract_related_models(
related=subrelated,
target_model=target_model,
prefetch_dict=prefetch_dict.get(subrelated),
select_dict=select_dict.get(subrelated)
if (select_dict and subrelated in select_dict)
else {},
already_extracted=already_extracted,
fields=fields,
exclude_fields=exclude_fields,
)
for row in rows:
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields,
)
instance = target_model(**item)
instance = self._populate_nested_related(
model=instance,
already_extracted=already_extracted,
prefetch_dict=prefetch_dict,
)
already_extracted[target_model.get_name()]["models"][instance.pk] = instance

View File

@ -21,17 +21,17 @@ if TYPE_CHECKING: # pragma no cover
class QuerySet: class QuerySet:
def __init__( # noqa CFQ002 def __init__( # noqa CFQ002
self, self,
model_cls: Type["Model"] = None, model_cls: Type["Model"] = None,
filter_clauses: List = None, filter_clauses: List = None,
exclude_clauses: List = None, exclude_clauses: List = None,
select_related: List = None, select_related: List = None,
limit_count: int = None, limit_count: int = None,
offset: int = None, offset: int = None,
columns: Dict = None, columns: Dict = None,
exclude_columns: Dict = None, exclude_columns: Dict = None,
order_bys: List = None, order_bys: List = None,
prefetch_related: List = None, prefetch_related: List = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses self.filter_clauses = [] if filter_clauses is None else filter_clauses
@ -45,9 +45,9 @@ class QuerySet:
self.order_bys = order_bys or [] self.order_bys = order_bys or []
def __get__( def __get__(
self, self,
instance: Optional[Union["QuerySet", "QuerysetProxy"]], instance: Optional[Union["QuerySet", "QuerysetProxy"]],
owner: Union[Type["Model"], Type["QuerysetProxy"]], owner: Union[Type["Model"], Type["QuerysetProxy"]],
) -> "QuerySet": ) -> "QuerySet":
if issubclass(owner, ormar.Model): if issubclass(owner, ormar.Model):
return self.__class__(model_cls=owner) return self.__class__(model_cls=owner)
@ -66,11 +66,16 @@ class QuerySet:
raise ValueError("Model class of QuerySet is not initialized") raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls return self.model_cls
async def _prefetch_related_models(self, models: Sequence["Model"], rows: List) -> Sequence["Model"]: async def _prefetch_related_models(
query = PrefetchQuery(model_cls=self.model_cls, self, models: Sequence["Model"], rows: List
fields=self._columns, ) -> Sequence["Model"]:
exclude_fields=self._exclude_columns, query = PrefetchQuery(
prefetch_related=self._prefetch_related) model_cls=self.model_cls,
fields=self._columns,
exclude_fields=self._exclude_columns,
prefetch_related=self._prefetch_related,
select_related=self._select_related,
)
return await query.prefetch_related(models=models, rows=rows) return await query.prefetch_related(models=models, rows=rows)
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
@ -98,7 +103,7 @@ class QuerySet:
pkname = self.model_meta.pkname pkname = self.model_meta.pkname
pk = self.model_meta.model_fields[pkname] pk = self.model_meta.model_fields[pkname]
if new_kwargs.get(pkname, ormar.Undefined) is None and ( if new_kwargs.get(pkname, ormar.Undefined) is None and (
pk.nullable or pk.autoincrement pk.nullable or pk.autoincrement
): ):
del new_kwargs[pkname] del new_kwargs[pkname]
return new_kwargs return new_kwargs
@ -197,7 +202,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=related prefetch_related=related,
) )
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
@ -398,9 +403,9 @@ class QuerySet:
# refresh server side defaults # refresh server side defaults
if any( if any(
field.server_default is not None field.server_default is not None
for name, field in self.model.Meta.model_fields.items() for name, field in self.model.Meta.model_fields.items()
if name not in kwargs if name not in kwargs
): ):
instance = await instance.load() instance = await instance.load()
instance.set_save_status(True) instance.set_save_status(True)
@ -420,7 +425,7 @@ class QuerySet:
objt.set_save_status(True) objt.set_save_status(True)
async def bulk_update( # noqa: CCR001 async def bulk_update( # noqa: CCR001
self, objects: List["Model"], columns: List[str] = None self, objects: List["Model"], columns: List[str] = None
) -> None: ) -> None:
ready_objects = [] ready_objects = []
pk_name = self.model_meta.pkname pk_name = self.model_meta.pkname

View File

@ -21,6 +21,16 @@ class Tonation(ormar.Model):
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
class Division(ormar.Model):
class Meta:
tablename = "divisions"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Shop(ormar.Model): class Shop(ormar.Model):
class Meta: class Meta:
tablename = "shops" tablename = "shops"
@ -29,6 +39,7 @@ class Shop(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
division: Optional[Division] = ormar.ForeignKey(Division)
class AlbumShops(ormar.Model): class AlbumShops(ormar.Model):
@ -137,8 +148,9 @@ async def test_prefetch_related():
async def test_prefetch_related_with_many_to_many(): async def test_prefetch_related_with_many_to_many():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
shop1 = await Shop.objects.create(name='Shop 1') div = await Division.objects.create(name='Div 1')
shop2 = await Shop.objects.create(name='Shop 2') shop1 = await Shop.objects.create(name='Shop 1', division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div)
album = Album(name="Malibu") album = Album(name="Malibu")
await album.save() await album.save()
await album.shops.add(shop1) await album.shops.add(shop1)
@ -150,13 +162,15 @@ async def test_prefetch_related_with_many_to_many():
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get( track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get(
title="The Bird") title="The Bird")
assert track.album.name == "Malibu" assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2 assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1' assert track.album.cover_pictures[0].artist == 'Artist 1'
assert len(track.album.shops) == 2 assert len(track.album.shops) == 2
assert track.album.shops[0].name == 'Shop 1'
assert track.album.shops[0].division.name == 'Div 1'
@pytest.mark.asyncio @pytest.mark.asyncio