Fix for prefetch related (#1275)
* fix prefetch related merging same relations refering to the same children models * change to List for p3.8 * adapt refactored prefetch query from abandoned composite_key branch and make sure new test passes * remove unused code, add missing test for prefetch related with self reference models
This commit is contained in:
@ -364,6 +364,64 @@ class ForeignKeyField(BaseField):
|
||||
prefix = "to_" if self.self_reference else ""
|
||||
return self.through_relation_name or f"{prefix}{self.owner.get_name()}"
|
||||
|
||||
def get_filter_clause_target(self) -> Type["Model"]:
|
||||
return self.to
|
||||
|
||||
def get_model_relation_fields(self, use_alias: bool = False) -> str:
|
||||
"""
|
||||
Extract names of the database columns or model fields that are connected
|
||||
with given relation based on use_alias switch and which side of the relation
|
||||
the current field is - reverse or normal.
|
||||
|
||||
:param use_alias: use db names aliases or model fields
|
||||
:type use_alias: bool
|
||||
:return: name or names of the related columns/ fields
|
||||
:rtype: Union[str, List[str]]
|
||||
"""
|
||||
if use_alias:
|
||||
return self._get_model_relation_fields_alias()
|
||||
return self._get_model_relation_fields_name()
|
||||
|
||||
def _get_model_relation_fields_name(self) -> str:
|
||||
if self.virtual:
|
||||
return self.owner.ormar_config.pkname
|
||||
return self.name
|
||||
|
||||
def _get_model_relation_fields_alias(self) -> str:
|
||||
if self.virtual:
|
||||
return self.owner.ormar_config.model_fields[
|
||||
self.owner.ormar_config.pkname
|
||||
].get_alias()
|
||||
return self.get_alias()
|
||||
|
||||
def get_related_field_alias(self) -> str:
|
||||
"""
|
||||
Extract names of the related database columns or that are connected
|
||||
with given relation based to use as a target in filter clause.
|
||||
|
||||
:return: name or names of the related columns/ fields
|
||||
:rtype: Union[str, Dict[str, str]]
|
||||
"""
|
||||
if self.virtual:
|
||||
field_name = self.get_related_name()
|
||||
field = self.to.ormar_config.model_fields[field_name]
|
||||
return field.get_alias()
|
||||
target_field = self.to.get_column_alias(self.to.ormar_config.pkname)
|
||||
return target_field
|
||||
|
||||
def get_related_field_name(self) -> Union[str, List[str]]:
|
||||
"""
|
||||
Returns name of the relation field that should be used in prefetch query.
|
||||
This field is later used to register relation in prefetch query,
|
||||
populate relations dict, and populate nested model in prefetch query.
|
||||
|
||||
:return: name(s) of the field
|
||||
:rtype: Union[str, List[str]]
|
||||
"""
|
||||
if self.virtual:
|
||||
return self.get_related_name()
|
||||
return self.to.ormar_config.pkname
|
||||
|
||||
def _evaluate_forward_ref(
|
||||
self, globalns: Any, localns: Any, is_through: bool = False
|
||||
) -> None:
|
||||
|
||||
@ -268,6 +268,51 @@ class ManyToManyField( # type: ignore
|
||||
"""
|
||||
return self.through
|
||||
|
||||
def get_filter_clause_target(self) -> Type["Model"]:
|
||||
return self.through
|
||||
|
||||
def get_model_relation_fields(self, use_alias: bool = False) -> str:
|
||||
"""
|
||||
Extract names of the database columns or model fields that are connected
|
||||
with given relation based on use_alias switch.
|
||||
|
||||
:param use_alias: use db names aliases or model fields
|
||||
:type use_alias: bool
|
||||
:return: name or names of the related columns/ fields
|
||||
:rtype: Union[str, List[str]]
|
||||
"""
|
||||
pk_field = self.owner.ormar_config.model_fields[self.owner.ormar_config.pkname]
|
||||
result = pk_field.get_alias() if use_alias else pk_field.name
|
||||
return result
|
||||
|
||||
def get_related_field_alias(self) -> str:
|
||||
"""
|
||||
Extract names of the related database columns or that are connected
|
||||
with given relation based to use as a target in filter clause.
|
||||
|
||||
:return: name or names of the related columns/ fields
|
||||
:rtype: Union[str, Dict[str, str]]
|
||||
"""
|
||||
if self.self_reference and self.self_reference_primary == self.name:
|
||||
field_name = self.default_target_field_name()
|
||||
else:
|
||||
field_name = self.default_source_field_name()
|
||||
sub_field = self.through.ormar_config.model_fields[field_name]
|
||||
return sub_field.get_alias()
|
||||
|
||||
def get_related_field_name(self) -> Union[str, List[str]]:
|
||||
"""
|
||||
Returns name of the relation field that should be used in prefetch query.
|
||||
This field is later used to register relation in prefetch query,
|
||||
populate relations dict, and populate nested model in prefetch query.
|
||||
|
||||
:return: name(s) of the field
|
||||
:rtype: Union[str, List[str]]
|
||||
"""
|
||||
if self.self_reference and self.self_reference_primary == self.name:
|
||||
return self.default_target_field_name()
|
||||
return self.default_source_field_name()
|
||||
|
||||
def create_default_through_model(self) -> None:
|
||||
"""
|
||||
Creates default empty through model if no additional fields are required.
|
||||
|
||||
@ -8,14 +8,12 @@ it became quite complicated over time.
|
||||
from ormar.models.mixins.alias_mixin import AliasMixin
|
||||
from ormar.models.mixins.excludable_mixin import ExcludableMixin
|
||||
from ormar.models.mixins.merge_mixin import MergeModelMixin
|
||||
from ormar.models.mixins.prefetch_mixin import PrefetchQueryMixin
|
||||
from ormar.models.mixins.pydantic_mixin import PydanticMixin
|
||||
from ormar.models.mixins.save_mixin import SavePrepareMixin
|
||||
|
||||
__all__ = [
|
||||
"MergeModelMixin",
|
||||
"AliasMixin",
|
||||
"PrefetchQueryMixin",
|
||||
"SavePrepareMixin",
|
||||
"ExcludableMixin",
|
||||
"PydanticMixin",
|
||||
|
||||
@ -1,123 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Type, cast
|
||||
|
||||
from ormar.models.mixins.relation_mixin import RelationMixin
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ormar.fields import ForeignKeyField, ManyToManyField
|
||||
|
||||
|
||||
class PrefetchQueryMixin(RelationMixin):
|
||||
"""
|
||||
Used in PrefetchQuery to extract ids and names of models to prefetch.
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
|
||||
get_name: Callable # defined in NewBaseModel
|
||||
|
||||
@staticmethod
|
||||
def get_clause_target_and_filter_column_name(
|
||||
parent_model: Type["Model"],
|
||||
target_model: Type["Model"],
|
||||
reverse: bool,
|
||||
related: str,
|
||||
) -> Tuple[Type["Model"], str]:
|
||||
"""
|
||||
Returns Model on which query clause should be performed and name of the column.
|
||||
|
||||
:param parent_model: related model that the relation lead to
|
||||
:type parent_model: Type[Model]
|
||||
:param target_model: model on which query should be performed
|
||||
:type target_model: Type[Model]
|
||||
:param reverse: flag if the relation is reverse
|
||||
:type reverse: bool
|
||||
:param related: name of the relation field
|
||||
:type related: str
|
||||
:return: Model on which query clause should be performed and name of the column
|
||||
:rtype: Tuple[Type[Model], str]
|
||||
"""
|
||||
if reverse:
|
||||
field_name = parent_model.ormar_config.model_fields[
|
||||
related
|
||||
].get_related_name()
|
||||
field = target_model.ormar_config.model_fields[field_name]
|
||||
if field.is_multi:
|
||||
field = cast("ManyToManyField", field)
|
||||
field_name = field.default_target_field_name()
|
||||
sub_field = field.through.ormar_config.model_fields[field_name]
|
||||
return field.through, sub_field.get_alias()
|
||||
return target_model, field.get_alias()
|
||||
target_field = target_model.get_column_alias(target_model.ormar_config.pkname)
|
||||
return target_model, target_field
|
||||
|
||||
@staticmethod
|
||||
def get_column_name_for_id_extraction(
|
||||
parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool
|
||||
) -> str:
|
||||
"""
|
||||
Returns name of the column that should be used to extract ids from model.
|
||||
Depending on the relation side it's either primary key column of parent model
|
||||
or field name specified by related parameter.
|
||||
|
||||
:param parent_model: model from which id column should be extracted
|
||||
:type parent_model: Type[Model]
|
||||
:param reverse: flag if the relation is reverse
|
||||
:type reverse: bool
|
||||
:param related: name of the relation field
|
||||
:type related: str
|
||||
:param use_raw: flag if aliases or field names should be used
|
||||
:type use_raw: bool
|
||||
:return:
|
||||
:rtype:
|
||||
"""
|
||||
if reverse:
|
||||
column_name = parent_model.ormar_config.pkname
|
||||
return (
|
||||
parent_model.get_column_alias(column_name) if use_raw else column_name
|
||||
)
|
||||
column = parent_model.ormar_config.model_fields[related]
|
||||
return column.get_alias() if use_raw else column.name
|
||||
|
||||
@classmethod
|
||||
def get_related_field_name(cls, target_field: "ForeignKeyField") -> str:
|
||||
"""
|
||||
Returns name of the relation field that should be used in prefetch query.
|
||||
This field is later used to register relation in prefetch query,
|
||||
populate relations dict, and populate nested model in prefetch query.
|
||||
|
||||
:param target_field: relation field that should be used in prefetch
|
||||
:type target_field: Type[BaseField]
|
||||
:return: name of the field
|
||||
:rtype: str
|
||||
"""
|
||||
if target_field.is_multi:
|
||||
return cls.get_name()
|
||||
if target_field.virtual:
|
||||
return target_field.get_related_name()
|
||||
return target_field.to.ormar_config.pkname
|
||||
|
||||
@classmethod
|
||||
def get_filtered_names_to_extract(cls, prefetch_dict: Dict) -> List:
|
||||
"""
|
||||
Returns list of related fields names that should be followed to prefetch related
|
||||
models from.
|
||||
|
||||
List of models is translated into dict to assure each model is extracted only
|
||||
once in one query, that's why this function accepts prefetch_dict not list.
|
||||
|
||||
Only relations from current model are returned.
|
||||
|
||||
:param prefetch_dict: dictionary of fields to extract
|
||||
:type prefetch_dict: Dict
|
||||
:return: list of fields names to extract
|
||||
:rtype: List
|
||||
"""
|
||||
related_to_extract = []
|
||||
if prefetch_dict and prefetch_dict is not Ellipsis:
|
||||
related_to_extract = [
|
||||
related
|
||||
for related in cls.extract_related_names()
|
||||
if related in prefetch_dict
|
||||
]
|
||||
return related_to_extract
|
||||
@ -1,14 +1,12 @@
|
||||
from ormar.models.mixins import (
|
||||
ExcludableMixin,
|
||||
MergeModelMixin,
|
||||
PrefetchQueryMixin,
|
||||
PydanticMixin,
|
||||
SavePrepareMixin,
|
||||
)
|
||||
|
||||
|
||||
class ModelTableProxy(
|
||||
PrefetchQueryMixin,
|
||||
MergeModelMixin,
|
||||
SavePrepareMixin,
|
||||
ExcludableMixin,
|
||||
|
||||
@ -25,7 +25,6 @@ import typing_extensions
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar.exceptions import ModelError, ModelPersistenceError
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.fields.parsers import decode_bytes, encode_json
|
||||
from ormar.models.helpers import register_relation_in_alias_manager
|
||||
@ -1167,18 +1166,3 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
f"model without pk set!"
|
||||
)
|
||||
return self_fields
|
||||
|
||||
def get_relation_model_id(self, target_field: "BaseField") -> Optional[int]:
|
||||
"""
|
||||
Returns an id of the relation side model to use in prefetch query.
|
||||
|
||||
:param target_field: field with relation definition
|
||||
:type target_field: "BaseField"
|
||||
:return: value of pk if set
|
||||
:rtype: Optional[int]
|
||||
"""
|
||||
if target_field.virtual or target_field.is_multi:
|
||||
return self.pk
|
||||
related_name = target_field.name
|
||||
related_model = getattr(self, related_name)
|
||||
return None if not related_model else related_model.pk
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -172,7 +172,7 @@ class QuerySet(Generic[T]):
|
||||
select_related=self._select_related,
|
||||
orders_by=self.order_bys,
|
||||
)
|
||||
return await query.prefetch_related(models=models, rows=rows) # type: ignore
|
||||
return await query.prefetch_related(models=models) # type: ignore
|
||||
|
||||
async def _process_query_result_rows(self, rows: List) -> List["T"]:
|
||||
"""
|
||||
|
||||
@ -6,7 +6,6 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
@ -42,7 +41,7 @@ def check_node_not_dict_or_not_last_node(
|
||||
|
||||
|
||||
def translate_list_to_dict( # noqa: CCR001
|
||||
list_to_trans: Union[List, Set], is_order: bool = False
|
||||
list_to_trans: Union[List, Set], default: Any = ...
|
||||
) -> Dict:
|
||||
"""
|
||||
Splits the list of strings by '__' and converts them to dictionary with nested
|
||||
@ -53,6 +52,8 @@ def translate_list_to_dict( # noqa: CCR001
|
||||
|
||||
:param list_to_trans: input list
|
||||
:type list_to_trans: Union[List, Set]
|
||||
:param default: value to use as a default value
|
||||
:type default: Any
|
||||
:param is_order: flag if change affects order_by clauses are they require special
|
||||
default value with sort order.
|
||||
:type is_order: bool
|
||||
@ -63,14 +64,7 @@ def translate_list_to_dict( # noqa: CCR001
|
||||
for path in list_to_trans:
|
||||
current_level = new_dict
|
||||
parts = path.split("__")
|
||||
def_val: Any = ...
|
||||
if is_order:
|
||||
if parts[0][0] == "-":
|
||||
def_val = "desc"
|
||||
parts[0] = parts[0][1:]
|
||||
else:
|
||||
def_val = "asc"
|
||||
|
||||
def_val: Any = default
|
||||
for ind, part in enumerate(parts):
|
||||
is_last = ind == len(parts) - 1
|
||||
if check_node_not_dict_or_not_last_node(
|
||||
@ -189,78 +183,6 @@ def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) ->
|
||||
return updated_dict
|
||||
|
||||
|
||||
def extract_nested_models( # noqa: CCR001
|
||||
model: "Model", model_type: Type["Model"], select_dict: Dict, extracted: Dict
|
||||
) -> None:
|
||||
"""
|
||||
Iterates over model relations and extracts all nested models from select_dict and
|
||||
puts them in corresponding list under relation name in extracted dict.keys
|
||||
|
||||
Basically flattens all relation to dictionary of all related models, that can be
|
||||
used on several models and extract all of their children into dictionary of lists
|
||||
witch children models.
|
||||
|
||||
Goes also into nested relations if needed (specified in select_dict).
|
||||
|
||||
:param model: parent Model
|
||||
:type model: Model
|
||||
:param model_type: parent model class
|
||||
:type model_type: Type[Model]
|
||||
:param select_dict: dictionary of related models from select_related
|
||||
:type select_dict: Dict
|
||||
:param extracted: dictionary with already extracted models
|
||||
:type extracted: Dict
|
||||
"""
|
||||
follow = [rel for rel in model_type.extract_related_names() if rel in select_dict]
|
||||
for related in follow:
|
||||
child = getattr(model, related)
|
||||
if not child:
|
||||
continue
|
||||
target_model = model_type.ormar_config.model_fields[related].to
|
||||
if isinstance(child, list):
|
||||
extracted.setdefault(target_model.get_name(), []).extend(child)
|
||||
if select_dict[related] is not Ellipsis:
|
||||
for sub_child in child:
|
||||
extract_nested_models(
|
||||
sub_child, target_model, select_dict[related], extracted
|
||||
)
|
||||
else:
|
||||
extracted.setdefault(target_model.get_name(), []).append(child)
|
||||
if select_dict[related] is not Ellipsis:
|
||||
extract_nested_models(
|
||||
child, target_model, select_dict[related], extracted
|
||||
)
|
||||
|
||||
|
||||
def extract_models_to_dict_of_lists(
|
||||
model_type: Type["Model"],
|
||||
models: Sequence["Model"],
|
||||
select_dict: Dict,
|
||||
extracted: Optional[Dict] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Receives a list of models and extracts all of the children and their children
|
||||
into dictionary of lists with children models, flattening the structure to one dict
|
||||
with all children models under their relation keys.
|
||||
|
||||
:param model_type: parent model class
|
||||
:type model_type: Type[Model]
|
||||
:param models: list of models from which related models should be extracted.
|
||||
:type models: List[Model]
|
||||
:param select_dict: dictionary of related models from select_related
|
||||
:type select_dict: Dict
|
||||
:param extracted: dictionary with already extracted models
|
||||
:type extracted: Dict
|
||||
:return: dictionary of lists f related models
|
||||
:rtype: Dict
|
||||
"""
|
||||
if not extracted:
|
||||
extracted = dict()
|
||||
for model in models:
|
||||
extract_nested_models(model, model_type, select_dict, extracted)
|
||||
return extracted
|
||||
|
||||
|
||||
def get_relationship_alias_model_and_str(
|
||||
source_model: Type["Model"], related_parts: List
|
||||
) -> Tuple[str, Type["Model"], str, bool]:
|
||||
|
||||
Reference in New Issue
Block a user