refactor and cleanup for further optimization
This commit is contained in:
@ -1,4 +1,5 @@
|
|||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
@ -10,10 +11,11 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import ormar
|
||||||
from ormar.fields import BaseField, ManyToManyField
|
from ormar.fields import BaseField, ManyToManyField
|
||||||
from ormar.queryset.clause import QueryClause
|
from ormar.queryset.clause import QueryClause
|
||||||
from ormar.queryset.query import Query
|
from ormar.queryset.query import Query
|
||||||
from ormar.queryset.utils import translate_list_to_dict
|
from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -35,65 +37,114 @@ class PrefetchQuery:
|
|||||||
self._select_related = select_related
|
self._select_related = select_related
|
||||||
self._exclude_columns = exclude_fields
|
self._exclude_columns = exclude_fields
|
||||||
self._columns = fields
|
self._columns = fields
|
||||||
|
self.already_extracted: Dict = dict()
|
||||||
|
self.models: Dict = {}
|
||||||
|
self.select_dict = translate_list_to_dict(self._select_related)
|
||||||
|
|
||||||
|
async def prefetch_related(
|
||||||
|
self, models: Sequence["Model"], rows: List
|
||||||
|
) -> Sequence["Model"]:
|
||||||
|
self.models = extract_models_to_dict_of_lists(
|
||||||
|
model_type=self.model, models=models, select_dict=self.select_dict
|
||||||
|
)
|
||||||
|
self.models[self.model.get_name()] = models
|
||||||
|
return await self._prefetch_related_models(models=models, rows=rows)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_required_ids(
|
def _get_column_name_for_id_extraction(
|
||||||
already_extracted: Dict,
|
|
||||||
parent_model: Type["Model"],
|
parent_model: Type["Model"],
|
||||||
target_model: Type["Model"],
|
target_model: Type["Model"],
|
||||||
reverse: bool,
|
reverse: bool,
|
||||||
) -> Set:
|
use_raw: bool,
|
||||||
current_data = already_extracted.get(parent_model.get_name(), {})
|
) -> str:
|
||||||
raw_rows = current_data.get("raw", [])
|
|
||||||
table_prefix = current_data.get("prefix", "")
|
|
||||||
if reverse:
|
if reverse:
|
||||||
column_name = parent_model.get_column_alias(parent_model.Meta.pkname)
|
column_name = parent_model.Meta.pkname
|
||||||
|
return (
|
||||||
|
parent_model.get_column_alias(column_name) if use_raw else column_name
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
column_name = target_model.resolve_relation_field(
|
column = target_model.resolve_relation_field(parent_model, target_model)
|
||||||
parent_model, target_model
|
return column.get_alias() if use_raw else column.name
|
||||||
).get_alias()
|
|
||||||
|
def _extract_ids_from_raw_data(
|
||||||
|
self, parent_model: Type["Model"], column_name: str
|
||||||
|
) -> Set:
|
||||||
list_of_ids = set()
|
list_of_ids = set()
|
||||||
|
current_data = self.already_extracted.get(parent_model.get_name(), {})
|
||||||
|
table_prefix = current_data.get("prefix", "")
|
||||||
column_name = (f"{table_prefix}_" if table_prefix else "") + column_name
|
column_name = (f"{table_prefix}_" if table_prefix else "") + column_name
|
||||||
for row in raw_rows:
|
for row in current_data.get("raw", []):
|
||||||
if row[column_name]:
|
if row[column_name]:
|
||||||
list_of_ids.add(row[column_name])
|
list_of_ids.add(row[column_name])
|
||||||
return list_of_ids
|
return list_of_ids
|
||||||
|
|
||||||
@staticmethod
|
def _extract_ids_from_preloaded_models(
|
||||||
def _get_filter_for_prefetch(
|
self, parent_model: Type["Model"], column_name: str
|
||||||
already_extracted: Dict,
|
) -> Set:
|
||||||
parent_model: Type["Model"],
|
list_of_ids = set()
|
||||||
target_model: Type["Model"],
|
for model in self.models.get(parent_model.get_name(), []):
|
||||||
reverse: bool,
|
child = getattr(model, column_name)
|
||||||
) -> List:
|
if isinstance(child, ormar.Model):
|
||||||
ids = PrefetchQuery._extract_required_ids(
|
list_of_ids.add(child.pk)
|
||||||
already_extracted=already_extracted,
|
else:
|
||||||
|
list_of_ids.add(child)
|
||||||
|
return list_of_ids
|
||||||
|
|
||||||
|
def _extract_required_ids(
|
||||||
|
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
|
||||||
|
) -> Set:
|
||||||
|
|
||||||
|
use_raw = parent_model.get_name() not in self.models
|
||||||
|
|
||||||
|
column_name = self._get_column_name_for_id_extraction(
|
||||||
parent_model=parent_model,
|
parent_model=parent_model,
|
||||||
target_model=target_model,
|
target_model=target_model,
|
||||||
reverse=reverse,
|
reverse=reverse,
|
||||||
|
use_raw=use_raw,
|
||||||
)
|
)
|
||||||
if ids:
|
|
||||||
qryclause = QueryClause(
|
if use_raw:
|
||||||
model_cls=target_model, select_related=[], filter_clauses=[],
|
return self._extract_ids_from_raw_data(
|
||||||
|
parent_model=parent_model, column_name=column_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return self._extract_ids_from_preloaded_models(
|
||||||
|
parent_model=parent_model, column_name=column_name
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_clause_target_and_filter_column_name(
|
||||||
|
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
|
||||||
|
) -> Tuple[Type["Model"], str]:
|
||||||
if reverse:
|
if reverse:
|
||||||
field = target_model.resolve_relation_field(target_model, parent_model)
|
field = target_model.resolve_relation_field(target_model, parent_model)
|
||||||
if issubclass(field, ManyToManyField):
|
if issubclass(field, ManyToManyField):
|
||||||
sub_field = target_model.resolve_relation_field(
|
sub_field = target_model.resolve_relation_field(
|
||||||
field.through, parent_model
|
field.through, parent_model
|
||||||
)
|
)
|
||||||
kwargs = {f"{sub_field.get_alias()}__in": ids}
|
return field.through, sub_field.get_alias()
|
||||||
qryclause = QueryClause(
|
else:
|
||||||
model_cls=field.through, select_related=[], filter_clauses=[],
|
return target_model, field.get_alias()
|
||||||
)
|
target_field = target_model.get_column_alias(target_model.Meta.pkname)
|
||||||
|
return target_model, target_field
|
||||||
|
|
||||||
else:
|
def _get_filter_for_prefetch(
|
||||||
kwargs = {f"{field.get_alias()}__in": ids}
|
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
|
||||||
else:
|
) -> List:
|
||||||
target_field = target_model.Meta.model_fields[
|
ids = self._extract_required_ids(
|
||||||
target_model.Meta.pkname
|
parent_model=parent_model, target_model=target_model, reverse=reverse,
|
||||||
].get_alias()
|
)
|
||||||
kwargs = {f"{target_field}__in": ids}
|
if ids:
|
||||||
|
(
|
||||||
|
clause_target,
|
||||||
|
filter_column,
|
||||||
|
) = self._get_clause_target_and_filter_column_name(
|
||||||
|
parent_model=parent_model, target_model=target_model, reverse=reverse
|
||||||
|
)
|
||||||
|
qryclause = QueryClause(
|
||||||
|
model_cls=clause_target, select_related=[], filter_clauses=[],
|
||||||
|
)
|
||||||
|
kwargs = {f"{filter_column}__in": ids}
|
||||||
filter_clauses, _ = qryclause.filter(**kwargs)
|
filter_clauses, _ = qryclause.filter(**kwargs)
|
||||||
return filter_clauses
|
return filter_clauses
|
||||||
return []
|
return []
|
||||||
@ -123,7 +174,7 @@ class PrefetchQuery:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_group_field_name(
|
def _get_group_field_name(
|
||||||
target_field: Type["BaseField"], model: Type["Model"]
|
target_field: Type["BaseField"], model: Union["Model", Type["Model"]]
|
||||||
) -> str:
|
) -> str:
|
||||||
if issubclass(target_field, ManyToManyField):
|
if issubclass(target_field, ManyToManyField):
|
||||||
return model.resolve_relation_name(target_field.through, model)
|
return model.resolve_relation_name(target_field.through, model)
|
||||||
@ -142,117 +193,150 @@ class PrefetchQuery:
|
|||||||
]
|
]
|
||||||
return related_to_extract
|
return related_to_extract
|
||||||
|
|
||||||
@staticmethod
|
def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model":
|
||||||
def _populate_nested_related(
|
|
||||||
model: "Model", already_extracted: Dict, prefetch_dict: Dict
|
|
||||||
) -> "Model":
|
|
||||||
|
|
||||||
related_to_extract = PrefetchQuery._get_names_to_extract(
|
related_to_extract = self._get_names_to_extract(
|
||||||
prefetch_dict=prefetch_dict, model=model
|
prefetch_dict=prefetch_dict, model=model
|
||||||
)
|
)
|
||||||
|
|
||||||
for related in related_to_extract:
|
for related in related_to_extract:
|
||||||
target_field = model.Meta.model_fields[related]
|
target_field = model.Meta.model_fields[related]
|
||||||
target_model = target_field.to.get_name()
|
target_model = target_field.to.get_name()
|
||||||
is_multi, field_name, model_id = PrefetchQuery._get_model_id_and_field_name(
|
is_multi, field_name, model_id = self._get_model_id_and_field_name(
|
||||||
target_field=target_field, model=model
|
target_field=target_field, model=model
|
||||||
)
|
)
|
||||||
if not field_name:
|
|
||||||
|
if field_name is None or model_id is None: # pragma: no cover
|
||||||
continue
|
continue
|
||||||
|
|
||||||
children = already_extracted.get(target_model, {}).get(field_name, {})
|
children = self.already_extracted.get(target_model, {}).get(field_name, {})
|
||||||
|
self._set_children_on_model(
|
||||||
|
model=model, related=related, children=children, model_id=model_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _set_children_on_model(
|
||||||
|
model: "Model", related: str, children: Dict, model_id: int
|
||||||
|
) -> None:
|
||||||
for key, child_models in children.items():
|
for key, child_models in children.items():
|
||||||
if key == model_id:
|
if key == model_id:
|
||||||
for child in child_models:
|
for child in child_models:
|
||||||
setattr(model, related, child)
|
setattr(model, related, child)
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
async def prefetch_related(
|
|
||||||
self, models: Sequence["Model"], rows: List
|
|
||||||
) -> Sequence["Model"]:
|
|
||||||
return await self._prefetch_related_models(models=models, rows=rows)
|
|
||||||
|
|
||||||
async def _prefetch_related_models(
|
async def _prefetch_related_models(
|
||||||
self, models: Sequence["Model"], rows: List
|
self, models: Sequence["Model"], rows: List
|
||||||
) -> Sequence["Model"]:
|
) -> Sequence["Model"]:
|
||||||
already_extracted = {
|
self.already_extracted = {self.model.get_name(): {"raw": rows}}
|
||||||
self.model.get_name(): {
|
|
||||||
"raw": rows,
|
|
||||||
"models": {model.pk: model for model in models},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
select_dict = translate_list_to_dict(self._select_related)
|
select_dict = translate_list_to_dict(self._select_related)
|
||||||
prefetch_dict = translate_list_to_dict(self._prefetch_related)
|
prefetch_dict = translate_list_to_dict(self._prefetch_related)
|
||||||
target_model = self.model
|
target_model = self.model
|
||||||
fields = self._columns
|
fields = self._columns
|
||||||
exclude_fields = self._exclude_columns
|
exclude_fields = self._exclude_columns
|
||||||
for related in prefetch_dict.keys():
|
for related in prefetch_dict.keys():
|
||||||
subrelated = await self._extract_related_models(
|
await self._extract_related_models(
|
||||||
related=related,
|
related=related,
|
||||||
target_model=target_model,
|
target_model=target_model,
|
||||||
prefetch_dict=prefetch_dict.get(related),
|
prefetch_dict=prefetch_dict.get(related, {}),
|
||||||
select_dict=select_dict.get(related),
|
select_dict=select_dict.get(related, {}),
|
||||||
already_extracted=already_extracted,
|
|
||||||
fields=fields,
|
fields=fields,
|
||||||
exclude_fields=exclude_fields,
|
exclude_fields=exclude_fields,
|
||||||
)
|
)
|
||||||
print(related, subrelated)
|
|
||||||
final_models = []
|
final_models = []
|
||||||
for model in models:
|
for model in models:
|
||||||
final_models.append(
|
final_models.append(
|
||||||
self._populate_nested_related(
|
self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,)
|
||||||
model=model,
|
|
||||||
already_extracted=already_extracted,
|
|
||||||
prefetch_dict=prefetch_dict,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return models
|
return models
|
||||||
|
|
||||||
async def _extract_related_models( # noqa: CFQ002
|
async def _extract_related_models( # noqa: CFQ002, CCR001
|
||||||
self,
|
self,
|
||||||
related: str,
|
related: str,
|
||||||
target_model: Type["Model"],
|
target_model: Type["Model"],
|
||||||
prefetch_dict: Dict,
|
prefetch_dict: Dict,
|
||||||
select_dict: Dict,
|
select_dict: Dict,
|
||||||
already_extracted: Dict,
|
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
fields: Dict,
|
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
exclude_fields: Dict,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
fields = target_model.get_included(fields, related)
|
fields = target_model.get_included(fields, related)
|
||||||
exclude_fields = target_model.get_excluded(exclude_fields, related)
|
exclude_fields = target_model.get_excluded(exclude_fields, related)
|
||||||
select_related = []
|
|
||||||
|
|
||||||
target_field = target_model.Meta.model_fields[related]
|
target_field = target_model.Meta.model_fields[related]
|
||||||
reverse = False
|
reverse = False
|
||||||
if target_field.virtual or issubclass(target_field, ManyToManyField):
|
if target_field.virtual or issubclass(target_field, ManyToManyField):
|
||||||
reverse = True
|
reverse = True
|
||||||
|
|
||||||
parent_model = target_model
|
parent_model = target_model
|
||||||
target_model = target_field.to
|
|
||||||
|
|
||||||
filter_clauses = PrefetchQuery._get_filter_for_prefetch(
|
filter_clauses = self._get_filter_for_prefetch(
|
||||||
already_extracted=already_extracted,
|
parent_model=parent_model, target_model=target_field.to, reverse=reverse,
|
||||||
parent_model=parent_model,
|
|
||||||
target_model=target_model,
|
|
||||||
reverse=reverse,
|
|
||||||
)
|
)
|
||||||
if not filter_clauses: # related field is empty
|
if not filter_clauses: # related field is empty
|
||||||
return
|
return
|
||||||
|
|
||||||
|
already_loaded = select_dict is Ellipsis or related in select_dict
|
||||||
|
|
||||||
|
if not already_loaded:
|
||||||
|
# If not already loaded with select_related
|
||||||
|
table_prefix, rows = await self._run_prefetch_query(
|
||||||
|
target_field=target_field,
|
||||||
|
fields=fields,
|
||||||
|
exclude_fields=exclude_fields,
|
||||||
|
filter_clauses=filter_clauses,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rows = []
|
||||||
|
table_prefix = ""
|
||||||
|
|
||||||
|
if prefetch_dict and prefetch_dict is not Ellipsis:
|
||||||
|
for subrelated in prefetch_dict.keys():
|
||||||
|
await self._extract_related_models(
|
||||||
|
related=subrelated,
|
||||||
|
target_model=target_field.to,
|
||||||
|
prefetch_dict=prefetch_dict.get(subrelated, {}),
|
||||||
|
select_dict=self._get_select_related_if_apply(
|
||||||
|
subrelated, select_dict
|
||||||
|
),
|
||||||
|
fields=fields,
|
||||||
|
exclude_fields=exclude_fields,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not already_loaded:
|
||||||
|
self._populate_rows(
|
||||||
|
rows=rows,
|
||||||
|
parent_model=parent_model,
|
||||||
|
target_field=target_field,
|
||||||
|
table_prefix=table_prefix,
|
||||||
|
fields=fields,
|
||||||
|
exclude_fields=exclude_fields,
|
||||||
|
prefetch_dict=prefetch_dict,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._update_already_loaded_rows(
|
||||||
|
target_field=target_field, prefetch_dict=prefetch_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _run_prefetch_query(
|
||||||
|
self,
|
||||||
|
target_field: Type["BaseField"],
|
||||||
|
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
|
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
|
filter_clauses: List,
|
||||||
|
) -> Tuple[str, List]:
|
||||||
|
target_model = target_field.to
|
||||||
|
target_name = target_model.get_name()
|
||||||
|
select_related = []
|
||||||
query_target = target_model
|
query_target = target_model
|
||||||
table_prefix = ""
|
table_prefix = ""
|
||||||
if issubclass(target_field, ManyToManyField):
|
if issubclass(target_field, ManyToManyField):
|
||||||
query_target = target_field.through
|
query_target = target_field.through
|
||||||
select_related = [target_field.to.get_name()]
|
select_related = [target_name]
|
||||||
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
|
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
|
||||||
from_table=query_target.Meta.tablename,
|
from_table=query_target.Meta.tablename,
|
||||||
to_table=target_field.to.Meta.tablename,
|
to_table=target_field.to.Meta.tablename,
|
||||||
)
|
)
|
||||||
already_extracted.setdefault(target_model.get_name(), {})[
|
self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix
|
||||||
"prefix"
|
|
||||||
] = table_prefix
|
|
||||||
|
|
||||||
qry = Query(
|
qry = Query(
|
||||||
model_cls=query_target,
|
model_cls=query_target,
|
||||||
@ -268,30 +352,41 @@ class PrefetchQuery:
|
|||||||
expr = qry.build_select_expression()
|
expr = qry.build_select_expression()
|
||||||
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
||||||
rows = await self.database.fetch_all(expr)
|
rows = await self.database.fetch_all(expr)
|
||||||
already_extracted.setdefault(target_model.get_name(), {}).update(
|
self.already_extracted.setdefault(target_name, {}).update({"raw": rows})
|
||||||
{"raw": rows, "models": {}}
|
return table_prefix, rows
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict:
|
||||||
|
return (
|
||||||
|
select_dict.get(related, {})
|
||||||
|
if (select_dict and select_dict is not Ellipsis and related in select_dict)
|
||||||
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
if prefetch_dict and prefetch_dict is not Ellipsis:
|
def _update_already_loaded_rows( # noqa: CFQ002
|
||||||
for subrelated in prefetch_dict.keys():
|
self, target_field: Type["BaseField"], prefetch_dict: Dict,
|
||||||
submodels = await self._extract_related_models(
|
) -> None:
|
||||||
related=subrelated,
|
target_model = target_field.to
|
||||||
target_model=target_model,
|
for instance in self.models.get(target_model.get_name(), []):
|
||||||
prefetch_dict=prefetch_dict.get(subrelated),
|
self._populate_nested_related(
|
||||||
select_dict=select_dict.get(subrelated)
|
model=instance, prefetch_dict=prefetch_dict,
|
||||||
if (select_dict and subrelated in select_dict)
|
|
||||||
else {},
|
|
||||||
already_extracted=already_extracted,
|
|
||||||
fields=fields,
|
|
||||||
exclude_fields=exclude_fields,
|
|
||||||
)
|
)
|
||||||
print(subrelated, submodels)
|
|
||||||
|
|
||||||
|
def _populate_rows( # noqa: CFQ002
|
||||||
|
self,
|
||||||
|
rows: List,
|
||||||
|
target_field: Type["BaseField"],
|
||||||
|
parent_model: Type["Model"],
|
||||||
|
table_prefix: str,
|
||||||
|
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
|
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
|
prefetch_dict: Dict,
|
||||||
|
) -> None:
|
||||||
|
target_model = target_field.to
|
||||||
for row in rows:
|
for row in rows:
|
||||||
field_name = PrefetchQuery._get_group_field_name(
|
field_name = self._get_group_field_name(
|
||||||
target_field=target_field, model=parent_model
|
target_field=target_field, model=parent_model
|
||||||
)
|
)
|
||||||
print("TEST", field_name, target_model, row[field_name])
|
|
||||||
item = target_model.extract_prefixed_table_columns(
|
item = target_model.extract_prefixed_table_columns(
|
||||||
item={},
|
item={},
|
||||||
row=row,
|
row=row,
|
||||||
@ -301,13 +396,8 @@ class PrefetchQuery:
|
|||||||
)
|
)
|
||||||
instance = target_model(**item)
|
instance = target_model(**item)
|
||||||
instance = self._populate_nested_related(
|
instance = self._populate_nested_related(
|
||||||
model=instance,
|
model=instance, prefetch_dict=prefetch_dict,
|
||||||
already_extracted=already_extracted,
|
|
||||||
prefetch_dict=prefetch_dict,
|
|
||||||
)
|
)
|
||||||
already_extracted[target_model.get_name()].setdefault(
|
self.already_extracted[target_model.get_name()].setdefault(
|
||||||
field_name, dict()
|
field_name, dict()
|
||||||
).setdefault(row[field_name], []).append(instance)
|
).setdefault(row[field_name], []).append(instance)
|
||||||
already_extracted[target_model.get_name()]["models"][instance.pk] = instance
|
|
||||||
|
|
||||||
return already_extracted[target_model.get_name()]["models"]
|
|
||||||
|
|||||||
@ -67,16 +67,16 @@ class QuerySet:
|
|||||||
return self.model_cls
|
return self.model_cls
|
||||||
|
|
||||||
async def _prefetch_related_models(
|
async def _prefetch_related_models(
|
||||||
self, models: Sequence["Model"], rows: List
|
self, models: Sequence[Optional["Model"]], rows: List
|
||||||
) -> Sequence["Model"]:
|
) -> Sequence[Optional["Model"]]:
|
||||||
query = PrefetchQuery(
|
query = PrefetchQuery(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model,
|
||||||
fields=self._columns,
|
fields=self._columns,
|
||||||
exclude_fields=self._exclude_columns,
|
exclude_fields=self._exclude_columns,
|
||||||
prefetch_related=self._prefetch_related,
|
prefetch_related=self._prefetch_related,
|
||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
)
|
)
|
||||||
return await query.prefetch_related(models=models, rows=rows)
|
return await query.prefetch_related(models=models, rows=rows) # type: ignore
|
||||||
|
|
||||||
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
|
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
|
||||||
result_rows = [
|
result_rows = [
|
||||||
@ -191,7 +191,7 @@ class QuerySet:
|
|||||||
if not isinstance(related, list):
|
if not isinstance(related, list):
|
||||||
related = [related]
|
related = [related]
|
||||||
|
|
||||||
related = list(set(list(self._select_related) + related))
|
related = list(set(list(self._prefetch_related) + related))
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model,
|
model_cls=self.model,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
@ -352,7 +352,7 @@ class QuerySet:
|
|||||||
|
|
||||||
rows = await self.database.fetch_all(expr)
|
rows = await self.database.fetch_all(expr)
|
||||||
processed_rows = self._process_query_result_rows(rows)
|
processed_rows = self._process_query_result_rows(rows)
|
||||||
if self._prefetch_related:
|
if self._prefetch_related and processed_rows:
|
||||||
processed_rows = await self._prefetch_related_models(processed_rows, rows)
|
processed_rows = await self._prefetch_related_models(processed_rows, rows)
|
||||||
self.check_single_result_rows_count(processed_rows)
|
self.check_single_result_rows_count(processed_rows)
|
||||||
return processed_rows[0] # type: ignore
|
return processed_rows[0] # type: ignore
|
||||||
@ -379,7 +379,7 @@ class QuerySet:
|
|||||||
expr = self.build_select_expression()
|
expr = self.build_select_expression()
|
||||||
rows = await self.database.fetch_all(expr)
|
rows = await self.database.fetch_all(expr)
|
||||||
result_rows = self._process_query_result_rows(rows)
|
result_rows = self._process_query_result_rows(rows)
|
||||||
if self._prefetch_related:
|
if self._prefetch_related and result_rows:
|
||||||
result_rows = await self._prefetch_related_models(result_rows, rows)
|
result_rows = await self._prefetch_related_models(result_rows, rows)
|
||||||
|
|
||||||
return result_rows
|
return result_rows
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
import collections.abc
|
import collections.abc
|
||||||
import copy
|
import copy
|
||||||
from typing import Any, Dict, List, Set, Union
|
from typing import Any, Dict, List, Sequence, Set, TYPE_CHECKING, Type, Union
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma no cover
|
||||||
|
from ormar import Model
|
||||||
|
|
||||||
|
|
||||||
def check_node_not_dict_or_not_last_node(
|
def check_node_not_dict_or_not_last_node(
|
||||||
@ -55,3 +58,39 @@ def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) ->
|
|||||||
dict_to_update = translate_list_to_dict(list_to_update)
|
dict_to_update = translate_list_to_dict(list_to_update)
|
||||||
update(updated_dict, dict_to_update)
|
update(updated_dict, dict_to_update)
|
||||||
return updated_dict
|
return updated_dict
|
||||||
|
|
||||||
|
|
||||||
|
def extract_nested_models( # noqa: CCR001
|
||||||
|
model: "Model", model_type: Type["Model"], select_dict: Dict, extracted: Dict
|
||||||
|
) -> None:
|
||||||
|
follow = [rel for rel in model_type.extract_related_names() if rel in select_dict]
|
||||||
|
for related in follow:
|
||||||
|
child = getattr(model, related)
|
||||||
|
if child:
|
||||||
|
target_model = model_type.Meta.model_fields[related].to
|
||||||
|
if isinstance(child, list):
|
||||||
|
extracted.setdefault(target_model.get_name(), []).extend(child)
|
||||||
|
if select_dict[related] is not Ellipsis:
|
||||||
|
for sub_child in child:
|
||||||
|
extract_nested_models(
|
||||||
|
sub_child, target_model, select_dict[related], extracted,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extracted.setdefault(target_model.get_name(), []).append(child)
|
||||||
|
if select_dict[related] is not Ellipsis:
|
||||||
|
extract_nested_models(
|
||||||
|
child, target_model, select_dict[related], extracted,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_models_to_dict_of_lists(
|
||||||
|
model_type: Type["Model"],
|
||||||
|
models: Sequence["Model"],
|
||||||
|
select_dict: Dict,
|
||||||
|
extracted: Dict = None,
|
||||||
|
) -> Dict:
|
||||||
|
if not extracted:
|
||||||
|
extracted = dict()
|
||||||
|
for model in models:
|
||||||
|
extract_nested_models(model, model_type, select_dict, extracted)
|
||||||
|
return extracted
|
||||||
|
|||||||
@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
|
|||||||
metadata = sqlalchemy.MetaData()
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSet(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "randoms"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
name: str = ormar.String(max_length=100)
|
||||||
|
|
||||||
|
|
||||||
class Tonation(ormar.Model):
|
class Tonation(ormar.Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
tablename = "tonations"
|
tablename = "tonations"
|
||||||
@ -19,6 +29,7 @@ class Tonation(ormar.Model):
|
|||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
id: int = ormar.Integer(primary_key=True)
|
||||||
name: str = ormar.String(max_length=100)
|
name: str = ormar.String(max_length=100)
|
||||||
|
rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet)
|
||||||
|
|
||||||
|
|
||||||
class Division(ormar.Model):
|
class Division(ormar.Model):
|
||||||
@ -181,3 +192,53 @@ async def test_prefetch_related_empty():
|
|||||||
track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird")
|
track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird")
|
||||||
assert track.title == 'The Bird'
|
assert track.title == 'The Bird'
|
||||||
assert track.album is None
|
assert track.album is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prefetch_related_with_select_related():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
div = await Division.objects.create(name='Div 1')
|
||||||
|
shop1 = await Shop.objects.create(name='Shop 1', division=div)
|
||||||
|
shop2 = await Shop.objects.create(name='Shop 2', division=div)
|
||||||
|
album = Album(name="Malibu")
|
||||||
|
await album.save()
|
||||||
|
await album.shops.add(shop1)
|
||||||
|
await album.shops.add(shop2)
|
||||||
|
|
||||||
|
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
|
||||||
|
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
|
||||||
|
|
||||||
|
album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related(
|
||||||
|
['cover_pictures', 'shops__division']).get()
|
||||||
|
assert len(album.tracks) == 0
|
||||||
|
assert len(album.cover_pictures) == 2
|
||||||
|
assert album.shops[0].division.name == 'Div 1'
|
||||||
|
|
||||||
|
rand_set = await RandomSet.objects.create(name='Rand 1')
|
||||||
|
ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set)
|
||||||
|
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
|
||||||
|
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
|
||||||
|
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
|
||||||
|
|
||||||
|
album = await Album.objects.select_related('tracks__tonation__rand_set').filter(name='Malibu').prefetch_related(
|
||||||
|
['cover_pictures', 'shops__division']).get()
|
||||||
|
assert len(album.tracks) == 3
|
||||||
|
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
|
||||||
|
assert len(album.cover_pictures) == 2
|
||||||
|
assert album.cover_pictures[0].artist == 'Artist 1'
|
||||||
|
|
||||||
|
assert len(album.shops) == 2
|
||||||
|
assert album.shops[0].name == 'Shop 1'
|
||||||
|
assert album.shops[0].division.name == 'Div 1'
|
||||||
|
|
||||||
|
track = await Track.objects.select_related('album').prefetch_related(
|
||||||
|
["album__cover_pictures", "album__shops__division"]).get(
|
||||||
|
title="The Bird")
|
||||||
|
assert track.album.name == "Malibu"
|
||||||
|
assert len(track.album.cover_pictures) == 2
|
||||||
|
assert track.album.cover_pictures[0].artist == 'Artist 1'
|
||||||
|
|
||||||
|
assert len(track.album.shops) == 2
|
||||||
|
assert track.album.shops[0].name == 'Shop 1'
|
||||||
|
assert track.album.shops[0].division.name == 'Div 1'
|
||||||
|
|||||||
Reference in New Issue
Block a user