modify save_related to be able to save whole tree from dict - including reverse fk and m2m relations - with correct order of saving

This commit is contained in:
collerek
2021-04-12 17:39:42 +02:00
parent 6780c9de8a
commit 854b27947a
7 changed files with 474 additions and 58 deletions

View File

@ -1,5 +1,5 @@
import uuid
from typing import Dict, Optional, Set, TYPE_CHECKING
from typing import Callable, Collection, Dict, Optional, Set, TYPE_CHECKING, cast
import ormar
from ormar.exceptions import ModelPersistenceError
@ -7,6 +7,9 @@ from ormar.models.helpers.validation import validate_choices
from ormar.models.mixins import AliasMixin
from ormar.models.mixins.relation_mixin import RelationMixin
if TYPE_CHECKING: # pragma: no cover
from ormar import ForeignKeyField, Model
class SavePrepareMixin(RelationMixin, AliasMixin):
"""
@ -15,6 +18,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if TYPE_CHECKING: # pragma: nocover
_choices_fields: Optional[Set]
_skip_ellipsis: Callable
@classmethod
def prepare_model_to_save(cls, new_kwargs: dict) -> dict:
@ -170,3 +174,130 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if field_name in new_kwargs and field_name in cls._choices_fields:
validate_choices(field=field, value=new_kwargs.get(field_name))
return new_kwargs
@staticmethod
async def _upsert_model(
instance: "Model",
save_all: bool,
previous_model: Optional["Model"],
relation_field: Optional["ForeignKeyField"],
update_count: int,
) -> int:
"""
Method updates given instance if:
* instance is not saved or
* instance have no pk or
* save_all=True flag is set
and instance is not __pk_only__.
If relation leading to instance is a ManyToMany also the through model is saved
:param instance: current model to upsert
:type instance: Model
:param save_all: flag if all models should be saved or only not saved ones
:type save_all: bool
:param relation_field: field with relation
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
:param update_count: no of updated models
:type update_count: int
:return: no of updated models
:rtype: int
"""
if (
save_all or not instance.pk or not instance.saved
) and not instance.__pk_only__:
await instance.upsert()
if relation_field and relation_field.is_multi:
await instance._upsert_through_model(
instance=instance,
relation_field=relation_field,
previous_model=cast("Model", previous_model),
)
update_count += 1
return update_count
@staticmethod
async def _upsert_through_model(
instance: "Model",
previous_model: "Model",
relation_field: Optional["ForeignKeyField"],
) -> None:
"""
Upsert through model for m2m relation.
:param instance: current model to upsert
:type instance: Model
:param relation_field: field with relation
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
"""
through_name = previous_model.Meta.model_fields[
relation_field.name
].through.get_name()
through = getattr(instance, through_name)
if through:
through_dict = through.dict(exclude=through.extract_related_names())
else:
through_dict = {}
await getattr(
previous_model, relation_field.name
).queryset_proxy.upsert_through_instance(instance, **through_dict)
async def _update_relation_list(
self,
fields_list: Collection["ForeignKeyField"],
follow: bool,
save_all: bool,
relation_map: Dict,
update_count: int,
) -> int:
"""
Internal method used in save_related to follow deeper from
related models and update numbers of updated related instances.
:type save_all: flag if all models should be saved
:type save_all: bool
:param fields_list: list of ormar fields to follow and save
:type fields_list: Collection["ForeignKeyField"]
:param relation_map: map of relations to follow
:type relation_map: Dict
:param follow: flag to trigger deep save -
by default only directly related models are saved
with follow=True also related models of related models are saved
:type follow: bool
:param update_count: internal parameter for recursive calls -
number of updated instances
:type update_count: int
:return: tuple of update count and visited
:rtype: int
"""
for field in fields_list:
value = getattr(self, field.name) or []
if not isinstance(value, list):
value = [value]
for val in value:
if follow:
update_count = await val.save_related(
follow=follow,
save_all=save_all,
relation_map=self._skip_ellipsis( # type: ignore
relation_map, field.name, default_return={}
),
update_count=update_count,
previous_model=self,
relation_field=field,
)
else:
update_count = await val._upsert_model(
instance=val,
save_all=save_all,
previous_model=self,
relation_field=field,
update_count=update_count,
)
return update_count