cleaner version but still dirty

This commit is contained in:
collerek
2020-11-25 13:28:51 +01:00
parent 585bba3ad3
commit f2fe41d38a
3 changed files with 300 additions and 192 deletions

View File

@ -1,228 +1,317 @@
from typing import Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union
from typing import (
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import ormar
from ormar.fields import ManyToManyField
from ormar.fields import BaseField, ManyToManyField
from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query
from ormar.queryset.utils import translate_list_to_dict
if TYPE_CHECKING: # pragma: no cover
from ormar import Model
class PrefetchQuery:
def __init__(self,
model_cls: Type["Model"],
fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]],
prefetch_related: List
):
def __init__(
self,
model_cls: Type["Model"],
fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]],
prefetch_related: List,
select_related: List,
) -> None:
self.model = model_cls
self.database = self.model.Meta.database
self._prefetch_related = prefetch_related
self._select_related = select_related
self._exclude_columns = exclude_fields
self._columns = fields
@staticmethod
def _extract_required_ids(already_extracted: Dict,
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool) -> Set:
raw_rows = already_extracted.get(parent_model.get_name(), {}).get('raw', [])
def _extract_required_ids(
already_extracted: Dict,
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
) -> Set:
raw_rows = already_extracted.get(parent_model.get_name(), {}).get("raw", [])
table_prefix = already_extracted.get(parent_model.get_name(), {}).get(
"prefix", ""
)
if reverse:
column_name = parent_model.get_column_alias(parent_model.Meta.pkname)
else:
column_name = target_model.resolve_relation_field(parent_model, target_model).get_alias()
column_name = target_model.resolve_relation_field(
parent_model, target_model
).get_alias()
list_of_ids = set()
column_name = (f"{table_prefix}_" if table_prefix else "") + column_name
for row in raw_rows:
if row[column_name]:
list_of_ids.add(row[column_name])
return list_of_ids
@staticmethod
def _get_filter_for_prefetch(already_extracted: Dict,
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool) -> List:
ids = PrefetchQuery._extract_required_ids(already_extracted=already_extracted,
parent_model=parent_model,
target_model=target_model,
reverse=reverse)
def _get_filter_for_prefetch(
already_extracted: Dict,
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
) -> List:
ids = PrefetchQuery._extract_required_ids(
already_extracted=already_extracted,
parent_model=parent_model,
target_model=target_model,
reverse=reverse,
)
if ids:
qryclause = QueryClause(
model_cls=target_model,
select_related=[],
filter_clauses=[],
model_cls=target_model, select_related=[], filter_clauses=[],
)
if reverse:
field = target_model.resolve_relation_field(target_model, parent_model)
if issubclass(field, ManyToManyField):
sub_field = target_model.resolve_relation_field(field.through, parent_model)
kwargs = {f'{sub_field.get_alias()}__in': ids}
sub_field = target_model.resolve_relation_field(
field.through, parent_model
)
kwargs = {f"{sub_field.get_alias()}__in": ids}
qryclause = QueryClause(
model_cls=field.through,
select_related=[],
filter_clauses=[],
model_cls=field.through, select_related=[], filter_clauses=[],
)
else:
kwargs = {f'{field.get_alias()}__in': ids}
kwargs = {f"{field.get_alias()}__in": ids}
else:
target_field = target_model.Meta.model_fields[target_model.Meta.pkname].get_alias()
kwargs = {f'{target_field}__in': ids}
target_field = target_model.Meta.model_fields[
target_model.Meta.pkname
].get_alias()
kwargs = {f"{target_field}__in": ids}
filter_clauses, _ = qryclause.filter(**kwargs)
return filter_clauses
return []
@staticmethod
def _populate_nested_related(model: "Model",
already_extracted: Dict) -> "Model":
for related in model.extract_related_names():
def _get_model_id_and_field_name(
target_field: Type["BaseField"], model: "Model"
) -> Tuple[bool, Optional[str], Optional[int]]:
if target_field.virtual:
reverse = True
field_name = model.resolve_relation_name(target_field.to, model)
model_id = model.pk
elif issubclass(target_field, ManyToManyField):
reverse = True
field_name = model.resolve_relation_name(target_field.through, model)
model_id = model.pk
else:
reverse = False
related_name = model.resolve_relation_name(model, target_field.to)
related_model = getattr(model, related_name)
if not related_model:
return reverse, None, None
model_id = related_model.pk
field_name = target_field.to.Meta.pkname
return reverse, field_name, model_id
@staticmethod
def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List:
related_to_extract = []
if prefetch_dict and prefetch_dict is not Ellipsis:
related_to_extract = [
related
for related in model.extract_related_names()
if related in prefetch_dict
]
return related_to_extract
@staticmethod
def _populate_nested_related(
model: "Model", already_extracted: Dict, prefetch_dict: Dict
) -> "Model":
related_to_extract = PrefetchQuery._get_names_to_extract(
prefetch_dict=prefetch_dict, model=model
)
for related in related_to_extract:
target_field = model.Meta.model_fields[related]
target_model = target_field.to.get_name()
if target_field.virtual:
reverse = True
field_name = model.resolve_relation_name(target_field.to, model)
model_id = model.pk
elif issubclass(target_field, ManyToManyField):
reverse = True
field_name = model.resolve_relation_name(target_field.through, model)
model_id = model.pk
else:
related_name = model.resolve_relation_name(model, target_field.to)
related_model = getattr(model, related_name)
if not related_model:
continue
model_id = related_model.pk
field_name = target_field.to.Meta.pkname
reverse, field_name, model_id = PrefetchQuery._get_model_id_and_field_name(
target_field=target_field, model=model
)
if target_model in already_extracted and already_extracted[target_model]['models']:
print('*****POPULATING RELATED:', target_model, field_name, '*****', end='\n')
print(already_extracted[target_model]['models'])
for ind, child_model in enumerate(already_extracted[target_model]['models']):
if (
target_model in already_extracted
and already_extracted[target_model]["models"]
):
for key, child_model in already_extracted[target_model][
"models"
].items():
if issubclass(target_field, ManyToManyField):
raw_data = already_extracted[target_model]['raw'][ind]
if raw_data[field_name] == model_id:
ind = next(
i
if key == x[target_field.to.get_column_alias(field_name)]
else -1
for i, x in enumerate(
already_extracted[target_model]["raw"]
)
)
raw_data = already_extracted[target_model]["raw"][ind]
if (
raw_data
and field_name in raw_data
and raw_data[field_name] == model_id
):
setattr(model, related, child_model)
elif isinstance(getattr(child_model, field_name), ormar.Model):
if getattr(child_model, field_name).pk == model_id:
if reverse:
setattr(model, related, child_model)
else:
setattr(child_model, related, model)
setattr(model, related, child_model)
else: # we have not reverse relation and related_model is a pk value
elif getattr(child_model, field_name) == model_id:
setattr(model, related, child_model)
return model
async def prefetch_related(self, models: Sequence["Model"], rows: List):
async def prefetch_related(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
return await self._prefetch_related_models(models=models, rows=rows)
async def _prefetch_related_models(self,
models: Sequence["Model"],
rows: List) -> Sequence["Model"]:
already_extracted = {self.model.get_name(): {'raw': rows, 'models': models}}
for related in self._prefetch_related:
target_model = self.model
fields = self._columns
exclude_fields = self._exclude_columns
for part in related.split('__'):
fields = target_model.get_included(fields, part)
select_related = []
exclude_fields = target_model.get_excluded(exclude_fields, part)
target_field = target_model.Meta.model_fields[part]
reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
if issubclass(target_field, ManyToManyField):
select_related = [target_field.through.get_name()]
parent_model = target_model
target_model = target_field.to
if target_model.get_name() not in already_extracted:
filter_clauses = self._get_filter_for_prefetch(already_extracted=already_extracted,
parent_model=parent_model,
target_model=target_model,
reverse=reverse)
if not filter_clauses: # related field is empty
continue
query_target = target_model
table_prefix = ''
if issubclass(target_field, ManyToManyField):
query_target = target_field.through
select_related = [target_field.to.get_name()]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename, to_table=target_field.to.Meta.tablename)
qry = Query(
model_cls=query_target,
select_related=select_related,
filter_clauses=filter_clauses,
exclude_clauses=[],
offset=None,
limit_count=None,
fields=fields,
exclude_fields=exclude_fields,
order_bys=None,
)
expr = qry.build_select_expression()
print(expr.compile(compile_kwargs={"literal_binds": True}))
rows = await self.database.fetch_all(expr)
already_extracted[target_model.get_name()] = {'raw': rows, 'models': []}
if part == related.split('__')[-1]:
for row in rows:
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields
)
instance = target_model(**item)
already_extracted[target_model.get_name()]['models'].append(instance)
target_model = self.model
fields = self._columns
exclude_fields = self._exclude_columns
for part in related.split('__')[:-1]:
fields = target_model.get_included(fields, part)
exclude_fields = target_model.get_excluded(exclude_fields, part)
target_field = target_model.Meta.model_fields[part]
target_model = target_field.to
table_prefix = ''
if issubclass(target_field, ManyToManyField):
from_table = target_field.through.Meta.tablename
to_name = target_field.to.Meta.tablename
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=from_table, to_table=to_name)
for row in already_extracted.get(target_model.get_name(), {}).get('raw', []):
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields
)
instance = target_model(**item)
instance = self._populate_nested_related(model=instance,
already_extracted=already_extracted)
already_extracted[target_model.get_name()]['models'].append(instance)
async def _prefetch_related_models(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
already_extracted = {
self.model.get_name(): {
"raw": rows,
"models": {model.pk: model for model in models},
}
}
select_dict = translate_list_to_dict(self._select_related)
prefetch_dict = translate_list_to_dict(self._prefetch_related)
target_model = self.model
fields = self._columns
exclude_fields = self._exclude_columns
for related in prefetch_dict.keys():
await self._extract_related_models(
related=related,
target_model=target_model,
prefetch_dict=prefetch_dict.get(related),
select_dict=select_dict.get(related),
already_extracted=already_extracted,
fields=fields,
exclude_fields=exclude_fields,
)
final_models = []
for model in models:
final_models.append(self._populate_nested_related(model=model,
already_extracted=already_extracted))
final_models.append(
self._populate_nested_related(
model=model,
already_extracted=already_extracted,
prefetch_dict=prefetch_dict,
)
)
return models
async def _extract_related_models( # noqa: CFQ002
self,
related: str,
target_model: Type["Model"],
prefetch_dict: Dict,
select_dict: Dict,
already_extracted: Dict,
fields: Dict,
exclude_fields: Dict,
) -> None:
fields = target_model.get_included(fields, related)
exclude_fields = target_model.get_excluded(exclude_fields, related)
select_related = []
target_field = target_model.Meta.model_fields[related]
reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
parent_model = target_model
target_model = target_field.to
filter_clauses = PrefetchQuery._get_filter_for_prefetch(
already_extracted=already_extracted,
parent_model=parent_model,
target_model=target_model,
reverse=reverse,
)
if not filter_clauses: # related field is empty
return
query_target = target_model
table_prefix = ""
if issubclass(target_field, ManyToManyField):
query_target = target_field.through
select_related = [target_field.to.get_name()]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename,
to_table=target_field.to.Meta.tablename,
)
already_extracted.setdefault(target_model.get_name(), {})[
"prefix"
] = table_prefix
qry = Query(
model_cls=query_target,
select_related=select_related,
filter_clauses=filter_clauses,
exclude_clauses=[],
offset=None,
limit_count=None,
fields=fields,
exclude_fields=exclude_fields,
order_bys=None,
)
expr = qry.build_select_expression()
# print(expr.compile(compile_kwargs={"literal_binds": True}))
rows = await self.database.fetch_all(expr)
already_extracted.setdefault(target_model.get_name(), {}).update(
{"raw": rows, "models": {}}
)
if prefetch_dict and prefetch_dict is not Ellipsis:
for subrelated in prefetch_dict.keys():
await self._extract_related_models(
related=subrelated,
target_model=target_model,
prefetch_dict=prefetch_dict.get(subrelated),
select_dict=select_dict.get(subrelated)
if (select_dict and subrelated in select_dict)
else {},
already_extracted=already_extracted,
fields=fields,
exclude_fields=exclude_fields,
)
for row in rows:
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields,
)
instance = target_model(**item)
instance = self._populate_nested_related(
model=instance,
already_extracted=already_extracted,
prefetch_dict=prefetch_dict,
)
already_extracted[target_model.get_name()]["models"][instance.pk] = instance

View File

@ -21,17 +21,17 @@ if TYPE_CHECKING: # pragma no cover
class QuerySet:
def __init__( # noqa CFQ002
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
exclude_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
columns: Dict = None,
exclude_columns: Dict = None,
order_bys: List = None,
prefetch_related: List = None,
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
exclude_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
columns: Dict = None,
exclude_columns: Dict = None,
order_bys: List = None,
prefetch_related: List = None,
) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
@ -45,9 +45,9 @@ class QuerySet:
self.order_bys = order_bys or []
def __get__(
self,
instance: Optional[Union["QuerySet", "QuerysetProxy"]],
owner: Union[Type["Model"], Type["QuerysetProxy"]],
self,
instance: Optional[Union["QuerySet", "QuerysetProxy"]],
owner: Union[Type["Model"], Type["QuerysetProxy"]],
) -> "QuerySet":
if issubclass(owner, ormar.Model):
return self.__class__(model_cls=owner)
@ -66,11 +66,16 @@ class QuerySet:
raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls
async def _prefetch_related_models(self, models: Sequence["Model"], rows: List) -> Sequence["Model"]:
query = PrefetchQuery(model_cls=self.model_cls,
fields=self._columns,
exclude_fields=self._exclude_columns,
prefetch_related=self._prefetch_related)
async def _prefetch_related_models(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
query = PrefetchQuery(
model_cls=self.model_cls,
fields=self._columns,
exclude_fields=self._exclude_columns,
prefetch_related=self._prefetch_related,
select_related=self._select_related,
)
return await query.prefetch_related(models=models, rows=rows)
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
@ -98,7 +103,7 @@ class QuerySet:
pkname = self.model_meta.pkname
pk = self.model_meta.model_fields[pkname]
if new_kwargs.get(pkname, ormar.Undefined) is None and (
pk.nullable or pk.autoincrement
pk.nullable or pk.autoincrement
):
del new_kwargs[pkname]
return new_kwargs
@ -197,7 +202,7 @@ class QuerySet:
columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
prefetch_related=related
prefetch_related=related,
)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
@ -398,9 +403,9 @@ class QuerySet:
# refresh server side defaults
if any(
field.server_default is not None
for name, field in self.model.Meta.model_fields.items()
if name not in kwargs
field.server_default is not None
for name, field in self.model.Meta.model_fields.items()
if name not in kwargs
):
instance = await instance.load()
instance.set_save_status(True)
@ -420,7 +425,7 @@ class QuerySet:
objt.set_save_status(True)
async def bulk_update( # noqa: CCR001
self, objects: List["Model"], columns: List[str] = None
self, objects: List["Model"], columns: List[str] = None
) -> None:
ready_objects = []
pk_name = self.model_meta.pkname

View File

@ -21,6 +21,16 @@ class Tonation(ormar.Model):
name: str = ormar.String(max_length=100)
class Division(ormar.Model):
class Meta:
tablename = "divisions"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Shop(ormar.Model):
class Meta:
tablename = "shops"
@ -29,6 +39,7 @@ class Shop(ormar.Model):
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
division: Optional[Division] = ormar.ForeignKey(Division)
class AlbumShops(ormar.Model):
@ -137,8 +148,9 @@ async def test_prefetch_related():
async def test_prefetch_related_with_many_to_many():
async with database:
async with database.transaction(force_rollback=True):
shop1 = await Shop.objects.create(name='Shop 1')
shop2 = await Shop.objects.create(name='Shop 2')
div = await Division.objects.create(name='Div 1')
shop1 = await Shop.objects.create(name='Shop 1', division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div)
album = Album(name="Malibu")
await album.save()
await album.shops.add(shop1)
@ -150,13 +162,15 @@ async def test_prefetch_related_with_many_to_many():
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops"]).get(
track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get(
title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1'
assert len(track.album.shops) == 2
assert track.album.shops[0].name == 'Shop 1'
assert track.album.shops[0].division.name == 'Div 1'
@pytest.mark.asyncio