Files
ormar/ormar/queryset/prefetch_query.py
2020-11-26 06:33:24 +01:00

392 lines
14 KiB
Python

from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import ormar
from ormar.fields import BaseField, ManyToManyField
from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query
from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict
if TYPE_CHECKING: # pragma: no cover
from ormar import Model
class PrefetchQuery:
    """
    Load the relations listed in ``prefetch_related`` with separate queries.

    Raw rows fetched for each visited model are cached in
    ``already_extracted`` (keyed by model name) and the child instances
    built from them are attached to their parent models afterwards.
    """

    def __init__(
        self,
        model_cls: Type["Model"],
        fields: Optional[Union[Dict, Set]],
        exclude_fields: Optional[Union[Dict, Set]],
        prefetch_related: List,
        select_related: List,
    ) -> None:
        """
        Store the query configuration and prepare the empty caches.

        :param model_cls: root model class the main query was issued for
        :param fields: columns to include, forwarded to the prefetch queries
        :param exclude_fields: columns to exclude, forwarded the same way
        :param prefetch_related: relation names to load with extra queries
        :param select_related: relation names passed to the main query
        """
        self.model = model_cls
        self.database = self.model.Meta.database
        self._prefetch_related = prefetch_related
        self._select_related = select_related
        self._exclude_columns = exclude_fields
        self._columns = fields
        # model name -> {"raw": rows, "prefix": str, <field name>: {id: [models]}}
        self.already_extracted: Dict = {}
        # model name -> instances that are already loaded
        self.models: Dict = {}
        self.select_dict = translate_list_to_dict(self._select_related)

    async def prefetch_related(
        self, models: Sequence["Model"], rows: List
    ) -> Sequence["Model"]:
        """
        Entry point: load all prefetch relations for ``models``.

        :param models: instances returned by the main query
        :param rows: raw rows the main query returned
        :return: the same ``models`` with prefetched relations attached
        """
        self.models = extract_models_to_dict_of_lists(
            model_type=self.model, models=models, select_dict=self.select_dict
        )
        self.models[self.model.get_name()] = models
        return await self._prefetch_related_models(models=models, rows=rows)

    @staticmethod
    def _get_column_name_for_id_extraction(
        parent_model: Type["Model"],
        target_model: Type["Model"],
        reverse: bool,
        use_raw: bool,
    ) -> str:
        """
        Return the name of the parent column holding the ids to filter by.

        :param parent_model: model the ids are read from
        :param target_model: model that is going to be queried
        :param reverse: True when the parent's primary key is the source of ids
        :param use_raw: True to return the column alias instead of the
            field name (ids are then read from raw rows, not from models)
        :return: column alias or field name
        """
        if reverse:
            pk_name = parent_model.Meta.pkname
            return parent_model.get_column_alias(pk_name) if use_raw else pk_name
        column = target_model.resolve_relation_field(parent_model, target_model)
        return column.get_alias() if use_raw else column.name

    def _extract_ids_from_raw_data(
        self, parent_model: Type["Model"], column_name: str
    ) -> Set:
        """
        Collect ids from the raw rows cached for ``parent_model``.

        :param parent_model: model whose cached rows are scanned
        :param column_name: column alias, without the table prefix
        :return: set of non-null values found in that column
        """
        current_data = self.already_extracted.get(parent_model.get_name(), {})
        table_prefix = current_data.get("prefix", "")
        if table_prefix:
            column_name = f"{table_prefix}_{column_name}"
        ids = set()
        for row in current_data.get("raw", []):
            value = row[column_name]
            # Only nulls are skipped; a truthiness test would also drop a
            # perfectly valid key of 0.
            if value is not None:
                ids.add(value)
        return ids

    def _extract_ids_from_preloaded_models(
        self, parent_model: Type["Model"], column_name: str
    ) -> Set:
        """
        Collect ids from already-loaded instances of ``parent_model``.

        :param parent_model: model whose loaded instances are scanned
        :param column_name: attribute holding either a related model or a
            plain key value
        :return: set of non-null ids
        """
        ids = set()
        for model in self.models.get(parent_model.get_name(), []):
            child = getattr(model, column_name)
            if child is None:
                # Unset relation: a null must not end up in the IN (...) filter.
                continue
            value = child.pk if isinstance(child, ormar.Model) else child
            if value is not None:
                ids.add(value)
        return ids

    def _extract_required_ids(
        self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
    ) -> Set:
        """
        Return the ids the prefetch query has to filter on.

        Raw rows are used when no instances of ``parent_model`` are present
        in ``self.models``; otherwise the loaded instances are inspected.
        """
        use_raw = parent_model.get_name() not in self.models
        column_name = self._get_column_name_for_id_extraction(
            parent_model=parent_model,
            target_model=target_model,
            reverse=reverse,
            use_raw=use_raw,
        )
        if use_raw:
            return self._extract_ids_from_raw_data(
                parent_model=parent_model, column_name=column_name
            )
        return self._extract_ids_from_preloaded_models(
            parent_model=parent_model, column_name=column_name
        )

    @staticmethod
    def _get_clause_target_and_filter_column_name(
        parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
    ) -> Tuple[Type["Model"], str]:
        """
        Pick the model to filter and the column alias to filter it by.

        :return: ``(model, column alias)``; for a reverse many-to-many
            relation the model is the ``through`` model
        """
        if not reverse:
            return target_model, target_model.get_column_alias(
                target_model.Meta.pkname
            )
        field = target_model.resolve_relation_field(target_model, parent_model)
        if issubclass(field, ManyToManyField):
            sub_field = target_model.resolve_relation_field(
                field.through, parent_model
            )
            return field.through, sub_field.get_alias()
        return target_model, field.get_alias()

    def _get_filter_for_prefetch(
        self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
    ) -> List:
        """
        Build the ``<column>__in`` filter clauses for the prefetch query.

        :return: filter clauses, or an empty list when there are no ids
        """
        ids = self._extract_required_ids(
            parent_model=parent_model, target_model=target_model, reverse=reverse,
        )
        if not ids:
            return []
        (
            clause_target,
            filter_column,
        ) = self._get_clause_target_and_filter_column_name(
            parent_model=parent_model, target_model=target_model, reverse=reverse
        )
        qryclause = QueryClause(
            model_cls=clause_target, select_related=[], filter_clauses=[],
        )
        filter_clauses, _ = qryclause.filter(**{f"{filter_column}__in": ids})
        return filter_clauses

    @staticmethod
    def _get_model_id(target_field: Type["BaseField"], model: "Model") -> Optional[int]:
        """
        Return the key under which children of ``model`` were grouped.

        :return: ``model.pk`` for virtual and many-to-many fields, otherwise
            the pk of the related model, or None when that relation is empty
        """
        if target_field.virtual or issubclass(target_field, ManyToManyField):
            return model.pk
        related_name = model.resolve_relation_name(model, target_field.to)
        related_model = getattr(model, related_name)
        return None if not related_model else related_model.pk

    @staticmethod
    def _get_related_field_name(
        target_field: Type["BaseField"], model: Union["Model", Type["Model"]]
    ) -> str:
        """
        Return the field name under which children are grouped in the cache.
        """
        if issubclass(target_field, ManyToManyField):
            return model.resolve_relation_name(target_field.through, model)
        if target_field.virtual:
            return model.resolve_relation_name(target_field.to, model)
        return target_field.to.Meta.pkname

    @staticmethod
    def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List:
        """
        Return the related names of ``model`` that are keys of ``prefetch_dict``.

        :return: empty list when ``prefetch_dict`` is empty or ``Ellipsis``
        """
        if not prefetch_dict or prefetch_dict is Ellipsis:
            return []
        return [
            related
            for related in model.extract_related_names()
            if related in prefetch_dict
        ]

    def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model":
        """
        Attach cached children to ``model`` for each requested relation.

        :param model: instance to populate (modified in place)
        :param prefetch_dict: relations requested at this nesting level
        :return: the same ``model`` instance
        """
        related_to_extract = self._get_names_to_extract(
            prefetch_dict=prefetch_dict, model=model
        )
        for related in related_to_extract:
            target_field = model.Meta.model_fields[related]
            target_model = target_field.to.get_name()
            model_id = self._get_model_id(target_field=target_field, model=model)
            if model_id is None:  # pragma: no cover
                continue
            field_name = self._get_related_field_name(
                target_field=target_field, model=model
            )
            children = self.already_extracted.get(target_model, {}).get(field_name, {})
            self._set_children_on_model(
                model=model, related=related, children=children, model_id=model_id
            )
        return model

    @staticmethod
    def _set_children_on_model(
        model: "Model", related: str, children: Dict, model_id: int
    ) -> None:
        """
        Set every child stored under ``model_id`` on ``model.<related>``.

        :param children: mapping of parent id -> list of child models
        """
        # Direct lookup instead of scanning every entry for a matching key.
        for child in children.get(model_id, ()):
            setattr(model, related, child)

    async def _prefetch_related_models(
        self, models: Sequence["Model"], rows: List
    ) -> Sequence["Model"]:
        """
        Run the prefetch queries for all top-level relations and attach results.

        :param models: instances returned by the main query
        :param rows: raw rows of the main query, cached as the root "raw" data
        :return: the same ``models`` sequence, populated in place
        """
        self.already_extracted = {self.model.get_name(): {"raw": rows}}
        select_dict = translate_list_to_dict(self._select_related)
        prefetch_dict = translate_list_to_dict(self._prefetch_related)
        for related, nested_prefetch in prefetch_dict.items():
            await self._extract_related_models(
                related=related,
                target_model=self.model,
                prefetch_dict=nested_prefetch,
                select_dict=select_dict.get(related, {}),
                fields=self._columns,
                exclude_fields=self._exclude_columns,
            )
        for model in models:
            self._populate_nested_related(model=model, prefetch_dict=prefetch_dict)
        return models

    async def _extract_related_models(  # noqa: CFQ002, CCR001
        self,
        related: str,
        target_model: Type["Model"],
        prefetch_dict: Dict,
        select_dict: Dict,
        fields: Union[Set[Any], Dict[Any, Any], None],
        exclude_fields: Union[Set[Any], Dict[Any, Any], None],
    ) -> None:
        """
        Load one relation, recurse into its nested relations, cache the result.

        :param related: relation name on ``target_model``
        :param target_model: model owning the relation (the parent here)
        :param prefetch_dict: nested relations to prefetch below ``related``
        :param select_dict: select_related entries for this level; the
            relation counts as already loaded when this is ``Ellipsis`` or
            contains ``related``
        :param fields: columns to include, narrowed to ``related``
        :param exclude_fields: columns to exclude, narrowed to ``related``
        """
        fields = target_model.get_included(fields, related)
        exclude_fields = target_model.get_excluded(exclude_fields, related)
        target_field = target_model.Meta.model_fields[related]
        reverse = bool(
            target_field.virtual or issubclass(target_field, ManyToManyField)
        )

        parent_model = target_model
        filter_clauses = self._get_filter_for_prefetch(
            parent_model=parent_model, target_model=target_field.to, reverse=reverse,
        )
        if not filter_clauses:  # related field is empty
            return

        already_loaded = select_dict is Ellipsis or related in select_dict
        if already_loaded:
            rows: List = []
            table_prefix = ""
        else:
            # Not loaded by select_related, so it has to be queried now.
            table_prefix, rows = await self._run_prefetch_query(
                target_field=target_field,
                fields=fields,
                exclude_fields=exclude_fields,
                filter_clauses=filter_clauses,
            )

        # Nested relations are loaded first so that they are already cached
        # when the instances of this level get populated below.
        if prefetch_dict and prefetch_dict is not Ellipsis:
            for subrelated, nested_prefetch in prefetch_dict.items():
                await self._extract_related_models(
                    related=subrelated,
                    target_model=target_field.to,
                    prefetch_dict=nested_prefetch,
                    select_dict=self._get_select_related_if_apply(
                        subrelated, select_dict
                    ),
                    fields=fields,
                    exclude_fields=exclude_fields,
                )

        if already_loaded:
            self._update_already_loaded_rows(
                target_field=target_field, prefetch_dict=prefetch_dict,
            )
        else:
            self._populate_rows(
                rows=rows,
                parent_model=parent_model,
                target_field=target_field,
                table_prefix=table_prefix,
                fields=fields,
                exclude_fields=exclude_fields,
                prefetch_dict=prefetch_dict,
            )

    async def _run_prefetch_query(
        self,
        target_field: Type["BaseField"],
        fields: Union[Set[Any], Dict[Any, Any], None],
        exclude_fields: Union[Set[Any], Dict[Any, Any], None],
        filter_clauses: List,
    ) -> Tuple[str, List]:
        """
        Execute the query for one relation and cache its raw rows.

        For a many-to-many field the ``through`` model is queried with the
        target model joined in, and the join's table prefix is cached.

        :return: ``(table prefix, fetched rows)``; the prefix is empty
            unless the field is many-to-many
        """
        target_model = target_field.to
        target_name = target_model.get_name()
        select_related: List = []
        query_target = target_model
        table_prefix = ""
        if issubclass(target_field, ManyToManyField):
            query_target = target_field.through
            select_related = [target_name]
            table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
                from_table=query_target.Meta.tablename,
                to_table=target_field.to.Meta.tablename,
            )
            self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix

        qry = Query(
            model_cls=query_target,
            select_related=select_related,
            filter_clauses=filter_clauses,
            exclude_clauses=[],
            offset=None,
            limit_count=None,
            fields=fields,
            exclude_fields=exclude_fields,
            order_bys=None,
        )
        expr = qry.build_select_expression()
        rows = await self.database.fetch_all(expr)
        self.already_extracted.setdefault(target_name, {}).update({"raw": rows})
        return table_prefix, rows

    @staticmethod
    def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict:
        """
        Return ``select_dict[related]`` or an empty dict.

        An empty dict is returned when ``select_dict`` is empty, is
        ``Ellipsis``, or does not contain ``related``.
        """
        if select_dict and select_dict is not Ellipsis and related in select_dict:
            return select_dict.get(related, {})
        return {}

    def _update_already_loaded_rows(  # noqa: CFQ002
        self, target_field: Type["BaseField"], prefetch_dict: Dict,
    ) -> None:
        """
        Attach nested prefetched children to already-loaded target instances.
        """
        target_model = target_field.to
        for instance in self.models.get(target_model.get_name(), []):
            self._populate_nested_related(
                model=instance, prefetch_dict=prefetch_dict,
            )

    def _populate_rows(  # noqa: CFQ002
        self,
        rows: List,
        target_field: Type["BaseField"],
        parent_model: Type["Model"],
        table_prefix: str,
        fields: Union[Set[Any], Dict[Any, Any], None],
        exclude_fields: Union[Set[Any], Dict[Any, Any], None],
        prefetch_dict: Dict,
    ) -> None:
        """
        Build target instances from ``rows`` and cache them per parent id.

        Each instance is appended to
        ``already_extracted[<target name>][<field name>][<id from row>]``.

        :param rows: raw rows returned by the prefetch query
        :param target_field: relation field being populated
        :param parent_model: model owning ``target_field``
        :param table_prefix: prefix of the target columns in ``rows``
        :param fields: columns to include when extracting the instance data
        :param exclude_fields: columns to exclude when extracting the data
        :param prefetch_dict: nested relations to attach to each instance
        """
        if not rows:
            return
        target_model = target_field.to
        # Both names depend only on the field and the models, not on the
        # row, so they are resolved once instead of once per row.
        field_name = self._get_related_field_name(
            target_field=target_field, model=parent_model
        )
        field_db_name = target_model.get_column_alias(field_name)
        for row in rows:
            item = target_model.extract_prefixed_table_columns(
                item={},
                row=row,
                table_prefix=table_prefix,
                fields=fields,
                exclude_fields=exclude_fields,
            )
            instance = target_model(**item)
            instance = self._populate_nested_related(
                model=instance, prefetch_dict=prefetch_dict,
            )
            self.already_extracted[target_model.get_name()].setdefault(
                field_name, {}
            ).setdefault(row[field_db_name], []).append(instance)