Fix for prefetch related (#1275)

* fix prefetch related merging same relations refering to the same children models

* change to List for p3.8

* adapt refactored prefetch query from abandoned composite_key branch and make sure new test passes

* remove unused code, add missing test for prefetch related with self reference models
This commit is contained in:
collerek
2024-03-24 00:00:51 +01:00
committed by GitHub
parent 1ed0d5a87f
commit 52d992d8c7
13 changed files with 875 additions and 831 deletions

View File

@ -364,6 +364,64 @@ class ForeignKeyField(BaseField):
prefix = "to_" if self.self_reference else ""
return self.through_relation_name or f"{prefix}{self.owner.get_name()}"
def get_filter_clause_target(self) -> Type["Model"]:
return self.to
def get_model_relation_fields(self, use_alias: bool = False) -> str:
"""
Extract names of the database columns or model fields that are connected
with given relation based on use_alias switch and which side of the relation
the current field is - reverse or normal.
:param use_alias: use db names aliases or model fields
:type use_alias: bool
:return: name or names of the related columns/ fields
:rtype: Union[str, List[str]]
"""
if use_alias:
return self._get_model_relation_fields_alias()
return self._get_model_relation_fields_name()
def _get_model_relation_fields_name(self) -> str:
if self.virtual:
return self.owner.ormar_config.pkname
return self.name
def _get_model_relation_fields_alias(self) -> str:
if self.virtual:
return self.owner.ormar_config.model_fields[
self.owner.ormar_config.pkname
].get_alias()
return self.get_alias()
def get_related_field_alias(self) -> str:
"""
Extract names of the related database columns or that are connected
with given relation based to use as a target in filter clause.
:return: name or names of the related columns/ fields
:rtype: Union[str, Dict[str, str]]
"""
if self.virtual:
field_name = self.get_related_name()
field = self.to.ormar_config.model_fields[field_name]
return field.get_alias()
target_field = self.to.get_column_alias(self.to.ormar_config.pkname)
return target_field
def get_related_field_name(self) -> Union[str, List[str]]:
"""
Returns name of the relation field that should be used in prefetch query.
This field is later used to register relation in prefetch query,
populate relations dict, and populate nested model in prefetch query.
:return: name(s) of the field
:rtype: Union[str, List[str]]
"""
if self.virtual:
return self.get_related_name()
return self.to.ormar_config.pkname
def _evaluate_forward_ref(
self, globalns: Any, localns: Any, is_through: bool = False
) -> None:

View File

@ -268,6 +268,51 @@ class ManyToManyField( # type: ignore
"""
return self.through
def get_filter_clause_target(self) -> Type["Model"]:
return self.through
def get_model_relation_fields(self, use_alias: bool = False) -> str:
"""
Extract names of the database columns or model fields that are connected
with given relation based on use_alias switch.
:param use_alias: use db names aliases or model fields
:type use_alias: bool
:return: name or names of the related columns/ fields
:rtype: Union[str, List[str]]
"""
pk_field = self.owner.ormar_config.model_fields[self.owner.ormar_config.pkname]
result = pk_field.get_alias() if use_alias else pk_field.name
return result
def get_related_field_alias(self) -> str:
"""
Extract names of the related database columns or that are connected
with given relation based to use as a target in filter clause.
:return: name or names of the related columns/ fields
:rtype: Union[str, Dict[str, str]]
"""
if self.self_reference and self.self_reference_primary == self.name:
field_name = self.default_target_field_name()
else:
field_name = self.default_source_field_name()
sub_field = self.through.ormar_config.model_fields[field_name]
return sub_field.get_alias()
def get_related_field_name(self) -> Union[str, List[str]]:
"""
Returns name of the relation field that should be used in prefetch query.
This field is later used to register relation in prefetch query,
populate relations dict, and populate nested model in prefetch query.
:return: name(s) of the field
:rtype: Union[str, List[str]]
"""
if self.self_reference and self.self_reference_primary == self.name:
return self.default_target_field_name()
return self.default_source_field_name()
def create_default_through_model(self) -> None:
"""
Creates default empty through model if no additional fields are required.

View File

@ -8,14 +8,12 @@ it became quite complicated over time.
from ormar.models.mixins.alias_mixin import AliasMixin
from ormar.models.mixins.excludable_mixin import ExcludableMixin
from ormar.models.mixins.merge_mixin import MergeModelMixin
from ormar.models.mixins.prefetch_mixin import PrefetchQueryMixin
from ormar.models.mixins.pydantic_mixin import PydanticMixin
from ormar.models.mixins.save_mixin import SavePrepareMixin
__all__ = [
"MergeModelMixin",
"AliasMixin",
"PrefetchQueryMixin",
"SavePrepareMixin",
"ExcludableMixin",
"PydanticMixin",

View File

@ -1,123 +0,0 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Type, cast
from ormar.models.mixins.relation_mixin import RelationMixin
if TYPE_CHECKING: # pragma: no cover
from ormar.fields import ForeignKeyField, ManyToManyField
class PrefetchQueryMixin(RelationMixin):
"""
Used in PrefetchQuery to extract ids and names of models to prefetch.
"""
if TYPE_CHECKING: # pragma no cover
from ormar import Model
get_name: Callable # defined in NewBaseModel
@staticmethod
def get_clause_target_and_filter_column_name(
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
related: str,
) -> Tuple[Type["Model"], str]:
"""
Returns Model on which query clause should be performed and name of the column.
:param parent_model: related model that the relation lead to
:type parent_model: Type[Model]
:param target_model: model on which query should be performed
:type target_model: Type[Model]
:param reverse: flag if the relation is reverse
:type reverse: bool
:param related: name of the relation field
:type related: str
:return: Model on which query clause should be performed and name of the column
:rtype: Tuple[Type[Model], str]
"""
if reverse:
field_name = parent_model.ormar_config.model_fields[
related
].get_related_name()
field = target_model.ormar_config.model_fields[field_name]
if field.is_multi:
field = cast("ManyToManyField", field)
field_name = field.default_target_field_name()
sub_field = field.through.ormar_config.model_fields[field_name]
return field.through, sub_field.get_alias()
return target_model, field.get_alias()
target_field = target_model.get_column_alias(target_model.ormar_config.pkname)
return target_model, target_field
@staticmethod
def get_column_name_for_id_extraction(
parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool
) -> str:
"""
Returns name of the column that should be used to extract ids from model.
Depending on the relation side it's either primary key column of parent model
or field name specified by related parameter.
:param parent_model: model from which id column should be extracted
:type parent_model: Type[Model]
:param reverse: flag if the relation is reverse
:type reverse: bool
:param related: name of the relation field
:type related: str
:param use_raw: flag if aliases or field names should be used
:type use_raw: bool
:return:
:rtype:
"""
if reverse:
column_name = parent_model.ormar_config.pkname
return (
parent_model.get_column_alias(column_name) if use_raw else column_name
)
column = parent_model.ormar_config.model_fields[related]
return column.get_alias() if use_raw else column.name
@classmethod
def get_related_field_name(cls, target_field: "ForeignKeyField") -> str:
"""
Returns name of the relation field that should be used in prefetch query.
This field is later used to register relation in prefetch query,
populate relations dict, and populate nested model in prefetch query.
:param target_field: relation field that should be used in prefetch
:type target_field: Type[BaseField]
:return: name of the field
:rtype: str
"""
if target_field.is_multi:
return cls.get_name()
if target_field.virtual:
return target_field.get_related_name()
return target_field.to.ormar_config.pkname
@classmethod
def get_filtered_names_to_extract(cls, prefetch_dict: Dict) -> List:
"""
Returns list of related fields names that should be followed to prefetch related
models from.
List of models is translated into dict to assure each model is extracted only
once in one query, that's why this function accepts prefetch_dict not list.
Only relations from current model are returned.
:param prefetch_dict: dictionary of fields to extract
:type prefetch_dict: Dict
:return: list of fields names to extract
:rtype: List
"""
related_to_extract = []
if prefetch_dict and prefetch_dict is not Ellipsis:
related_to_extract = [
related
for related in cls.extract_related_names()
if related in prefetch_dict
]
return related_to_extract

View File

@ -1,14 +1,12 @@
from ormar.models.mixins import (
ExcludableMixin,
MergeModelMixin,
PrefetchQueryMixin,
PydanticMixin,
SavePrepareMixin,
)
class ModelTableProxy(
PrefetchQueryMixin,
MergeModelMixin,
SavePrepareMixin,
ExcludableMixin,

View File

@ -25,7 +25,6 @@ import typing_extensions
import ormar # noqa I100
from ormar.exceptions import ModelError, ModelPersistenceError
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.parsers import decode_bytes, encode_json
from ormar.models.helpers import register_relation_in_alias_manager
@ -1167,18 +1166,3 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
f"model without pk set!"
)
return self_fields
def get_relation_model_id(self, target_field: "BaseField") -> Optional[int]:
"""
Returns an id of the relation side model to use in prefetch query.
:param target_field: field with relation definition
:type target_field: "BaseField"
:return: value of pk if set
:rtype: Optional[int]
"""
if target_field.virtual or target_field.is_multi:
return self.pk
related_name = target_field.name
related_model = getattr(self, related_name)
return None if not related_model else related_model.pk

File diff suppressed because it is too large Load Diff

View File

@ -172,7 +172,7 @@ class QuerySet(Generic[T]):
select_related=self._select_related,
orders_by=self.order_bys,
)
return await query.prefetch_related(models=models, rows=rows) # type: ignore
return await query.prefetch_related(models=models) # type: ignore
async def _process_query_result_rows(self, rows: List) -> List["T"]:
"""

View File

@ -6,7 +6,6 @@ from typing import (
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
@ -42,7 +41,7 @@ def check_node_not_dict_or_not_last_node(
def translate_list_to_dict( # noqa: CCR001
list_to_trans: Union[List, Set], is_order: bool = False
list_to_trans: Union[List, Set], default: Any = ...
) -> Dict:
"""
Splits the list of strings by '__' and converts them to dictionary with nested
@ -53,6 +52,8 @@ def translate_list_to_dict( # noqa: CCR001
:param list_to_trans: input list
:type list_to_trans: Union[List, Set]
:param default: value to use as a default value
:type default: Any
:param is_order: flag if change affects order_by clauses are they require special
default value with sort order.
:type is_order: bool
@ -63,14 +64,7 @@ def translate_list_to_dict( # noqa: CCR001
for path in list_to_trans:
current_level = new_dict
parts = path.split("__")
def_val: Any = ...
if is_order:
if parts[0][0] == "-":
def_val = "desc"
parts[0] = parts[0][1:]
else:
def_val = "asc"
def_val: Any = default
for ind, part in enumerate(parts):
is_last = ind == len(parts) - 1
if check_node_not_dict_or_not_last_node(
@ -189,78 +183,6 @@ def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) ->
return updated_dict
def extract_nested_models( # noqa: CCR001
model: "Model", model_type: Type["Model"], select_dict: Dict, extracted: Dict
) -> None:
"""
Iterates over model relations and extracts all nested models from select_dict and
puts them in corresponding list under relation name in extracted dict.keys
Basically flattens all relation to dictionary of all related models, that can be
used on several models and extract all of their children into dictionary of lists
witch children models.
Goes also into nested relations if needed (specified in select_dict).
:param model: parent Model
:type model: Model
:param model_type: parent model class
:type model_type: Type[Model]
:param select_dict: dictionary of related models from select_related
:type select_dict: Dict
:param extracted: dictionary with already extracted models
:type extracted: Dict
"""
follow = [rel for rel in model_type.extract_related_names() if rel in select_dict]
for related in follow:
child = getattr(model, related)
if not child:
continue
target_model = model_type.ormar_config.model_fields[related].to
if isinstance(child, list):
extracted.setdefault(target_model.get_name(), []).extend(child)
if select_dict[related] is not Ellipsis:
for sub_child in child:
extract_nested_models(
sub_child, target_model, select_dict[related], extracted
)
else:
extracted.setdefault(target_model.get_name(), []).append(child)
if select_dict[related] is not Ellipsis:
extract_nested_models(
child, target_model, select_dict[related], extracted
)
def extract_models_to_dict_of_lists(
model_type: Type["Model"],
models: Sequence["Model"],
select_dict: Dict,
extracted: Optional[Dict] = None,
) -> Dict:
"""
Receives a list of models and extracts all of the children and their children
into dictionary of lists with children models, flattening the structure to one dict
with all children models under their relation keys.
:param model_type: parent model class
:type model_type: Type[Model]
:param models: list of models from which related models should be extracted.
:type models: List[Model]
:param select_dict: dictionary of related models from select_related
:type select_dict: Dict
:param extracted: dictionary with already extracted models
:type extracted: Dict
:return: dictionary of lists f related models
:rtype: Dict
"""
if not extracted:
extracted = dict()
for model in models:
extract_nested_models(model, model_type, select_dict, extracted)
return extracted
def get_relationship_alias_model_and_str(
source_model: Type["Model"], related_parts: List
) -> Tuple[str, Type["Model"], str, bool]: