refactor and cleanup for further optimization

This commit is contained in:
collerek
2020-11-25 20:52:01 +01:00
parent e0223f8a22
commit d6f995d349
4 changed files with 313 additions and 123 deletions

View File

@ -1,4 +1,5 @@
from typing import ( from typing import (
Any,
Dict, Dict,
List, List,
Optional, Optional,
@ -10,10 +11,11 @@ from typing import (
Union, Union,
) )
import ormar
from ormar.fields import BaseField, ManyToManyField from ormar.fields import BaseField, ManyToManyField
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.utils import translate_list_to_dict from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ormar import Model from ormar import Model
@ -35,65 +37,114 @@ class PrefetchQuery:
self._select_related = select_related self._select_related = select_related
self._exclude_columns = exclude_fields self._exclude_columns = exclude_fields
self._columns = fields self._columns = fields
self.already_extracted: Dict = dict()
self.models: Dict = {}
self.select_dict = translate_list_to_dict(self._select_related)
async def prefetch_related(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
self.models = extract_models_to_dict_of_lists(
model_type=self.model, models=models, select_dict=self.select_dict
)
self.models[self.model.get_name()] = models
return await self._prefetch_related_models(models=models, rows=rows)
@staticmethod @staticmethod
def _extract_required_ids( def _get_column_name_for_id_extraction(
already_extracted: Dict,
parent_model: Type["Model"], parent_model: Type["Model"],
target_model: Type["Model"], target_model: Type["Model"],
reverse: bool, reverse: bool,
) -> Set: use_raw: bool,
current_data = already_extracted.get(parent_model.get_name(), {}) ) -> str:
raw_rows = current_data.get("raw", [])
table_prefix = current_data.get("prefix", "")
if reverse: if reverse:
column_name = parent_model.get_column_alias(parent_model.Meta.pkname) column_name = parent_model.Meta.pkname
return (
parent_model.get_column_alias(column_name) if use_raw else column_name
)
else: else:
column_name = target_model.resolve_relation_field( column = target_model.resolve_relation_field(parent_model, target_model)
parent_model, target_model return column.get_alias() if use_raw else column.name
).get_alias()
def _extract_ids_from_raw_data(
self, parent_model: Type["Model"], column_name: str
) -> Set:
list_of_ids = set() list_of_ids = set()
current_data = self.already_extracted.get(parent_model.get_name(), {})
table_prefix = current_data.get("prefix", "")
column_name = (f"{table_prefix}_" if table_prefix else "") + column_name column_name = (f"{table_prefix}_" if table_prefix else "") + column_name
for row in raw_rows: for row in current_data.get("raw", []):
if row[column_name]: if row[column_name]:
list_of_ids.add(row[column_name]) list_of_ids.add(row[column_name])
return list_of_ids return list_of_ids
@staticmethod def _extract_ids_from_preloaded_models(
def _get_filter_for_prefetch( self, parent_model: Type["Model"], column_name: str
already_extracted: Dict, ) -> Set:
parent_model: Type["Model"], list_of_ids = set()
target_model: Type["Model"], for model in self.models.get(parent_model.get_name(), []):
reverse: bool, child = getattr(model, column_name)
) -> List: if isinstance(child, ormar.Model):
ids = PrefetchQuery._extract_required_ids( list_of_ids.add(child.pk)
already_extracted=already_extracted, else:
list_of_ids.add(child)
return list_of_ids
def _extract_required_ids(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
) -> Set:
use_raw = parent_model.get_name() not in self.models
column_name = self._get_column_name_for_id_extraction(
parent_model=parent_model, parent_model=parent_model,
target_model=target_model, target_model=target_model,
reverse=reverse, reverse=reverse,
use_raw=use_raw,
)
if use_raw:
return self._extract_ids_from_raw_data(
parent_model=parent_model, column_name=column_name
)
return self._extract_ids_from_preloaded_models(
parent_model=parent_model, column_name=column_name
)
@staticmethod
def _get_clause_target_and_filter_column_name(
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
) -> Tuple[Type["Model"], str]:
if reverse:
field = target_model.resolve_relation_field(target_model, parent_model)
if issubclass(field, ManyToManyField):
sub_field = target_model.resolve_relation_field(
field.through, parent_model
)
return field.through, sub_field.get_alias()
else:
return target_model, field.get_alias()
target_field = target_model.get_column_alias(target_model.Meta.pkname)
return target_model, target_field
def _get_filter_for_prefetch(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
) -> List:
ids = self._extract_required_ids(
parent_model=parent_model, target_model=target_model, reverse=reverse,
) )
if ids: if ids:
qryclause = QueryClause( (
model_cls=target_model, select_related=[], filter_clauses=[], clause_target,
filter_column,
) = self._get_clause_target_and_filter_column_name(
parent_model=parent_model, target_model=target_model, reverse=reverse
) )
if reverse: qryclause = QueryClause(
field = target_model.resolve_relation_field(target_model, parent_model) model_cls=clause_target, select_related=[], filter_clauses=[],
if issubclass(field, ManyToManyField): )
sub_field = target_model.resolve_relation_field( kwargs = {f"{filter_column}__in": ids}
field.through, parent_model
)
kwargs = {f"{sub_field.get_alias()}__in": ids}
qryclause = QueryClause(
model_cls=field.through, select_related=[], filter_clauses=[],
)
else:
kwargs = {f"{field.get_alias()}__in": ids}
else:
target_field = target_model.Meta.model_fields[
target_model.Meta.pkname
].get_alias()
kwargs = {f"{target_field}__in": ids}
filter_clauses, _ = qryclause.filter(**kwargs) filter_clauses, _ = qryclause.filter(**kwargs)
return filter_clauses return filter_clauses
return [] return []
@ -123,7 +174,7 @@ class PrefetchQuery:
@staticmethod @staticmethod
def _get_group_field_name( def _get_group_field_name(
target_field: Type["BaseField"], model: Type["Model"] target_field: Type["BaseField"], model: Union["Model", Type["Model"]]
) -> str: ) -> str:
if issubclass(target_field, ManyToManyField): if issubclass(target_field, ManyToManyField):
return model.resolve_relation_name(target_field.through, model) return model.resolve_relation_name(target_field.through, model)
@ -142,117 +193,150 @@ class PrefetchQuery:
] ]
return related_to_extract return related_to_extract
@staticmethod def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model":
def _populate_nested_related(
model: "Model", already_extracted: Dict, prefetch_dict: Dict
) -> "Model":
related_to_extract = PrefetchQuery._get_names_to_extract( related_to_extract = self._get_names_to_extract(
prefetch_dict=prefetch_dict, model=model prefetch_dict=prefetch_dict, model=model
) )
for related in related_to_extract: for related in related_to_extract:
target_field = model.Meta.model_fields[related] target_field = model.Meta.model_fields[related]
target_model = target_field.to.get_name() target_model = target_field.to.get_name()
is_multi, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( is_multi, field_name, model_id = self._get_model_id_and_field_name(
target_field=target_field, model=model target_field=target_field, model=model
) )
if not field_name:
if field_name is None or model_id is None: # pragma: no cover
continue continue
children = already_extracted.get(target_model, {}).get(field_name, {}) children = self.already_extracted.get(target_model, {}).get(field_name, {})
for key, child_models in children.items(): self._set_children_on_model(
if key == model_id: model=model, related=related, children=children, model_id=model_id
for child in child_models: )
setattr(model, related, child)
return model return model
async def prefetch_related( @staticmethod
self, models: Sequence["Model"], rows: List def _set_children_on_model(
) -> Sequence["Model"]: model: "Model", related: str, children: Dict, model_id: int
return await self._prefetch_related_models(models=models, rows=rows) ) -> None:
for key, child_models in children.items():
if key == model_id:
for child in child_models:
setattr(model, related, child)
async def _prefetch_related_models( async def _prefetch_related_models(
self, models: Sequence["Model"], rows: List self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]: ) -> Sequence["Model"]:
already_extracted = { self.already_extracted = {self.model.get_name(): {"raw": rows}}
self.model.get_name(): {
"raw": rows,
"models": {model.pk: model for model in models},
}
}
select_dict = translate_list_to_dict(self._select_related) select_dict = translate_list_to_dict(self._select_related)
prefetch_dict = translate_list_to_dict(self._prefetch_related) prefetch_dict = translate_list_to_dict(self._prefetch_related)
target_model = self.model target_model = self.model
fields = self._columns fields = self._columns
exclude_fields = self._exclude_columns exclude_fields = self._exclude_columns
for related in prefetch_dict.keys(): for related in prefetch_dict.keys():
subrelated = await self._extract_related_models( await self._extract_related_models(
related=related, related=related,
target_model=target_model, target_model=target_model,
prefetch_dict=prefetch_dict.get(related), prefetch_dict=prefetch_dict.get(related, {}),
select_dict=select_dict.get(related), select_dict=select_dict.get(related, {}),
already_extracted=already_extracted,
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
) )
print(related, subrelated)
final_models = [] final_models = []
for model in models: for model in models:
final_models.append( final_models.append(
self._populate_nested_related( self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,)
model=model,
already_extracted=already_extracted,
prefetch_dict=prefetch_dict,
)
) )
return models return models
async def _extract_related_models( # noqa: CFQ002 async def _extract_related_models( # noqa: CFQ002, CCR001
self, self,
related: str, related: str,
target_model: Type["Model"], target_model: Type["Model"],
prefetch_dict: Dict, prefetch_dict: Dict,
select_dict: Dict, select_dict: Dict,
already_extracted: Dict, fields: Union[Set[Any], Dict[Any, Any], None],
fields: Dict, exclude_fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Dict,
) -> None: ) -> None:
fields = target_model.get_included(fields, related) fields = target_model.get_included(fields, related)
exclude_fields = target_model.get_excluded(exclude_fields, related) exclude_fields = target_model.get_excluded(exclude_fields, related)
select_related = []
target_field = target_model.Meta.model_fields[related] target_field = target_model.Meta.model_fields[related]
reverse = False reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField): if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True reverse = True
parent_model = target_model parent_model = target_model
target_model = target_field.to
filter_clauses = PrefetchQuery._get_filter_for_prefetch( filter_clauses = self._get_filter_for_prefetch(
already_extracted=already_extracted, parent_model=parent_model, target_model=target_field.to, reverse=reverse,
parent_model=parent_model,
target_model=target_model,
reverse=reverse,
) )
if not filter_clauses: # related field is empty if not filter_clauses: # related field is empty
return return
already_loaded = select_dict is Ellipsis or related in select_dict
if not already_loaded:
# If not already loaded with select_related
table_prefix, rows = await self._run_prefetch_query(
target_field=target_field,
fields=fields,
exclude_fields=exclude_fields,
filter_clauses=filter_clauses,
)
else:
rows = []
table_prefix = ""
if prefetch_dict and prefetch_dict is not Ellipsis:
for subrelated in prefetch_dict.keys():
await self._extract_related_models(
related=subrelated,
target_model=target_field.to,
prefetch_dict=prefetch_dict.get(subrelated, {}),
select_dict=self._get_select_related_if_apply(
subrelated, select_dict
),
fields=fields,
exclude_fields=exclude_fields,
)
if not already_loaded:
self._populate_rows(
rows=rows,
parent_model=parent_model,
target_field=target_field,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields,
prefetch_dict=prefetch_dict,
)
else:
self._update_already_loaded_rows(
target_field=target_field, prefetch_dict=prefetch_dict,
)
async def _run_prefetch_query(
self,
target_field: Type["BaseField"],
fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
filter_clauses: List,
) -> Tuple[str, List]:
target_model = target_field.to
target_name = target_model.get_name()
select_related = []
query_target = target_model query_target = target_model
table_prefix = "" table_prefix = ""
if issubclass(target_field, ManyToManyField): if issubclass(target_field, ManyToManyField):
query_target = target_field.through query_target = target_field.through
select_related = [target_field.to.get_name()] select_related = [target_name]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename, from_table=query_target.Meta.tablename,
to_table=target_field.to.Meta.tablename, to_table=target_field.to.Meta.tablename,
) )
already_extracted.setdefault(target_model.get_name(), {})[ self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix
"prefix"
] = table_prefix
qry = Query( qry = Query(
model_cls=query_target, model_cls=query_target,
@ -268,30 +352,41 @@ class PrefetchQuery:
expr = qry.build_select_expression() expr = qry.build_select_expression()
# print(expr.compile(compile_kwargs={"literal_binds": True})) # print(expr.compile(compile_kwargs={"literal_binds": True}))
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
already_extracted.setdefault(target_model.get_name(), {}).update( self.already_extracted.setdefault(target_name, {}).update({"raw": rows})
{"raw": rows, "models": {}} return table_prefix, rows
@staticmethod
def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict:
return (
select_dict.get(related, {})
if (select_dict and select_dict is not Ellipsis and related in select_dict)
else {}
) )
if prefetch_dict and prefetch_dict is not Ellipsis: def _update_already_loaded_rows( # noqa: CFQ002
for subrelated in prefetch_dict.keys(): self, target_field: Type["BaseField"], prefetch_dict: Dict,
submodels = await self._extract_related_models( ) -> None:
related=subrelated, target_model = target_field.to
target_model=target_model, for instance in self.models.get(target_model.get_name(), []):
prefetch_dict=prefetch_dict.get(subrelated), self._populate_nested_related(
select_dict=select_dict.get(subrelated) model=instance, prefetch_dict=prefetch_dict,
if (select_dict and subrelated in select_dict) )
else {},
already_extracted=already_extracted,
fields=fields,
exclude_fields=exclude_fields,
)
print(subrelated, submodels)
def _populate_rows( # noqa: CFQ002
self,
rows: List,
target_field: Type["BaseField"],
parent_model: Type["Model"],
table_prefix: str,
fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
prefetch_dict: Dict,
) -> None:
target_model = target_field.to
for row in rows: for row in rows:
field_name = PrefetchQuery._get_group_field_name( field_name = self._get_group_field_name(
target_field=target_field, model=parent_model target_field=target_field, model=parent_model
) )
print("TEST", field_name, target_model, row[field_name])
item = target_model.extract_prefixed_table_columns( item = target_model.extract_prefixed_table_columns(
item={}, item={},
row=row, row=row,
@ -301,13 +396,8 @@ class PrefetchQuery:
) )
instance = target_model(**item) instance = target_model(**item)
instance = self._populate_nested_related( instance = self._populate_nested_related(
model=instance, model=instance, prefetch_dict=prefetch_dict,
already_extracted=already_extracted,
prefetch_dict=prefetch_dict,
) )
already_extracted[target_model.get_name()].setdefault( self.already_extracted[target_model.get_name()].setdefault(
field_name, dict() field_name, dict()
).setdefault(row[field_name], []).append(instance) ).setdefault(row[field_name], []).append(instance)
already_extracted[target_model.get_name()]["models"][instance.pk] = instance
return already_extracted[target_model.get_name()]["models"]

View File

@ -67,16 +67,16 @@ class QuerySet:
return self.model_cls return self.model_cls
async def _prefetch_related_models( async def _prefetch_related_models(
self, models: Sequence["Model"], rows: List self, models: Sequence[Optional["Model"]], rows: List
) -> Sequence["Model"]: ) -> Sequence[Optional["Model"]]:
query = PrefetchQuery( query = PrefetchQuery(
model_cls=self.model_cls, model_cls=self.model,
fields=self._columns, fields=self._columns,
exclude_fields=self._exclude_columns, exclude_fields=self._exclude_columns,
prefetch_related=self._prefetch_related, prefetch_related=self._prefetch_related,
select_related=self._select_related, select_related=self._select_related,
) )
return await query.prefetch_related(models=models, rows=rows) return await query.prefetch_related(models=models, rows=rows) # type: ignore
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
result_rows = [ result_rows = [
@ -191,7 +191,7 @@ class QuerySet:
if not isinstance(related, list): if not isinstance(related, list):
related = [related] related = [related]
related = list(set(list(self._select_related) + related)) related = list(set(list(self._prefetch_related) + related))
return self.__class__( return self.__class__(
model_cls=self.model, model_cls=self.model,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
@ -352,7 +352,7 @@ class QuerySet:
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
processed_rows = self._process_query_result_rows(rows) processed_rows = self._process_query_result_rows(rows)
if self._prefetch_related: if self._prefetch_related and processed_rows:
processed_rows = await self._prefetch_related_models(processed_rows, rows) processed_rows = await self._prefetch_related_models(processed_rows, rows)
self.check_single_result_rows_count(processed_rows) self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore return processed_rows[0] # type: ignore
@ -379,7 +379,7 @@ class QuerySet:
expr = self.build_select_expression() expr = self.build_select_expression()
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
result_rows = self._process_query_result_rows(rows) result_rows = self._process_query_result_rows(rows)
if self._prefetch_related: if self._prefetch_related and result_rows:
result_rows = await self._prefetch_related_models(result_rows, rows) result_rows = await self._prefetch_related_models(result_rows, rows)
return result_rows return result_rows

View File

@ -1,6 +1,9 @@
import collections.abc import collections.abc
import copy import copy
from typing import Any, Dict, List, Set, Union from typing import Any, Dict, List, Sequence, Set, TYPE_CHECKING, Type, Union
if TYPE_CHECKING: # pragma no cover
from ormar import Model
def check_node_not_dict_or_not_last_node( def check_node_not_dict_or_not_last_node(
@ -55,3 +58,39 @@ def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) ->
dict_to_update = translate_list_to_dict(list_to_update) dict_to_update = translate_list_to_dict(list_to_update)
update(updated_dict, dict_to_update) update(updated_dict, dict_to_update)
return updated_dict return updated_dict
def extract_nested_models( # noqa: CCR001
model: "Model", model_type: Type["Model"], select_dict: Dict, extracted: Dict
) -> None:
follow = [rel for rel in model_type.extract_related_names() if rel in select_dict]
for related in follow:
child = getattr(model, related)
if child:
target_model = model_type.Meta.model_fields[related].to
if isinstance(child, list):
extracted.setdefault(target_model.get_name(), []).extend(child)
if select_dict[related] is not Ellipsis:
for sub_child in child:
extract_nested_models(
sub_child, target_model, select_dict[related], extracted,
)
else:
extracted.setdefault(target_model.get_name(), []).append(child)
if select_dict[related] is not Ellipsis:
extract_nested_models(
child, target_model, select_dict[related], extracted,
)
def extract_models_to_dict_of_lists(
model_type: Type["Model"],
models: Sequence["Model"],
select_dict: Dict,
extracted: Dict = None,
) -> Dict:
if not extracted:
extracted = dict()
for model in models:
extract_nested_models(model, model_type, select_dict, extracted)
return extracted

View File

@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
class RandomSet(ormar.Model):
class Meta:
tablename = "randoms"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Tonation(ormar.Model): class Tonation(ormar.Model):
class Meta: class Meta:
tablename = "tonations" tablename = "tonations"
@ -19,6 +29,7 @@ class Tonation(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet)
class Division(ormar.Model): class Division(ormar.Model):
@ -181,3 +192,53 @@ async def test_prefetch_related_empty():
track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird")
assert track.title == 'The Bird' assert track.title == 'The Bird'
assert track.album is None assert track.album is None
@pytest.mark.asyncio
async def test_prefetch_related_with_select_related():
async with database:
async with database.transaction(force_rollback=True):
div = await Division.objects.create(name='Div 1')
shop1 = await Shop.objects.create(name='Shop 1', division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div)
album = Album(name="Malibu")
await album.save()
await album.shops.add(shop1)
await album.shops.add(shop2)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related(
['cover_pictures', 'shops__division']).get()
assert len(album.tracks) == 0
assert len(album.cover_pictures) == 2
assert album.shops[0].division.name == 'Div 1'
rand_set = await RandomSet.objects.create(name='Rand 1')
ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set)
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
album = await Album.objects.select_related('tracks__tonation__rand_set').filter(name='Malibu').prefetch_related(
['cover_pictures', 'shops__division']).get()
assert len(album.tracks) == 3
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].artist == 'Artist 1'
assert len(album.shops) == 2
assert album.shops[0].name == 'Shop 1'
assert album.shops[0].division.name == 'Div 1'
track = await Track.objects.select_related('album').prefetch_related(
["album__cover_pictures", "album__shops__division"]).get(
title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1'
assert len(track.album.shops) == 2
assert track.album.shops[0].name == 'Shop 1'
assert track.album.shops[0].division.name == 'Div 1'