modify save_related to be able to save whole tree from dict - including reverse fk and m2m relations - with correct order of saving

This commit is contained in:
collerek
2021-04-12 17:39:42 +02:00
parent 6780c9de8a
commit 854b27947a
7 changed files with 474 additions and 58 deletions

View File

@ -13,6 +13,17 @@
* even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by`) * even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by`)
* the relation won't be populated in `dict()` and `json()` * the relation won't be populated in `dict()` and `json()`
* you cannot pass the nested related objects when populating from `dict()` or `json()` (also through `fastapi`). It will be either ignored or raise error depending on `extra` setting in pydantic `Config`. * you cannot pass the nested related objects when populating from `dict()` or `json()` (also through `fastapi`). It will be either ignored or raise error depending on `extra` setting in pydantic `Config`.
* `Model.save_related()` now can save whole data tree in once [#148](https://github.com/collerek/ormar/discussions/148)
meaning:
* it knows if it should save main `Model` or related `Model` first to preserve the relation
* it saves main `Model` if
* it's not `saved`,
* has no `pk` value
* or `save_all=True` flag is set
in those cases you don't have to split save into two calls (`save()` and `save_related()`)
* it supports also `ManyToMany` relations
* it supports also optional `Through` model values for m2m relations
## 🐛 Fixes ## 🐛 Fixes

View File

@ -4,9 +4,10 @@ from typing import (
Optional, Optional,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
cast,
) )
from ormar import BaseField from ormar import BaseField, ForeignKeyField
from ormar.models.traversible import NodeList from ormar.models.traversible import NodeList
@ -39,7 +40,7 @@ class RelationMixin:
return self_fields return self_fields
@classmethod @classmethod
def extract_related_fields(cls) -> List: def extract_related_fields(cls) -> List["ForeignKeyField"]:
""" """
Returns List of ormar Fields for all relations declared on a model. Returns List of ormar Fields for all relations declared on a model.
List is cached in cls._related_fields for quicker access. List is cached in cls._related_fields for quicker access.
@ -52,7 +53,7 @@ class RelationMixin:
related_fields = [] related_fields = []
for name in cls.extract_related_names().union(cls.extract_through_names()): for name in cls.extract_related_names().union(cls.extract_through_names()):
related_fields.append(cls.Meta.model_fields[name]) related_fields.append(cast("ForeignKeyField", cls.Meta.model_fields[name]))
cls._related_fields = related_fields cls._related_fields = related_fields
return related_fields return related_fields

View File

@ -1,5 +1,5 @@
import uuid import uuid
from typing import Dict, Optional, Set, TYPE_CHECKING from typing import Callable, Collection, Dict, Optional, Set, TYPE_CHECKING, cast
import ormar import ormar
from ormar.exceptions import ModelPersistenceError from ormar.exceptions import ModelPersistenceError
@ -7,6 +7,9 @@ from ormar.models.helpers.validation import validate_choices
from ormar.models.mixins import AliasMixin from ormar.models.mixins import AliasMixin
from ormar.models.mixins.relation_mixin import RelationMixin from ormar.models.mixins.relation_mixin import RelationMixin
if TYPE_CHECKING: # pragma: no cover
from ormar import ForeignKeyField, Model
class SavePrepareMixin(RelationMixin, AliasMixin): class SavePrepareMixin(RelationMixin, AliasMixin):
""" """
@ -15,6 +18,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if TYPE_CHECKING: # pragma: nocover if TYPE_CHECKING: # pragma: nocover
_choices_fields: Optional[Set] _choices_fields: Optional[Set]
_skip_ellipsis: Callable
@classmethod @classmethod
def prepare_model_to_save(cls, new_kwargs: dict) -> dict: def prepare_model_to_save(cls, new_kwargs: dict) -> dict:
@ -170,3 +174,130 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if field_name in new_kwargs and field_name in cls._choices_fields: if field_name in new_kwargs and field_name in cls._choices_fields:
validate_choices(field=field, value=new_kwargs.get(field_name)) validate_choices(field=field, value=new_kwargs.get(field_name))
return new_kwargs return new_kwargs
@staticmethod
async def _upsert_model(
instance: "Model",
save_all: bool,
previous_model: Optional["Model"],
relation_field: Optional["ForeignKeyField"],
update_count: int,
) -> int:
"""
Method updates given instance if:
* instance is not saved or
* instance have no pk or
* save_all=True flag is set
and instance is not __pk_only__.
If relation leading to instance is a ManyToMany also the through model is saved
:param instance: current model to upsert
:type instance: Model
:param save_all: flag if all models should be saved or only not saved ones
:type save_all: bool
:param relation_field: field with relation
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
:param update_count: no of updated models
:type update_count: int
:return: no of updated models
:rtype: int
"""
if (
save_all or not instance.pk or not instance.saved
) and not instance.__pk_only__:
await instance.upsert()
if relation_field and relation_field.is_multi:
await instance._upsert_through_model(
instance=instance,
relation_field=relation_field,
previous_model=cast("Model", previous_model),
)
update_count += 1
return update_count
@staticmethod
async def _upsert_through_model(
instance: "Model",
previous_model: "Model",
relation_field: Optional["ForeignKeyField"],
) -> None:
"""
Upsert through model for m2m relation.
:param instance: current model to upsert
:type instance: Model
:param relation_field: field with relation
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
"""
through_name = previous_model.Meta.model_fields[
relation_field.name
].through.get_name()
through = getattr(instance, through_name)
if through:
through_dict = through.dict(exclude=through.extract_related_names())
else:
through_dict = {}
await getattr(
previous_model, relation_field.name
).queryset_proxy.upsert_through_instance(instance, **through_dict)
async def _update_relation_list(
self,
fields_list: Collection["ForeignKeyField"],
follow: bool,
save_all: bool,
relation_map: Dict,
update_count: int,
) -> int:
"""
Internal method used in save_related to follow deeper from
related models and update numbers of updated related instances.
:type save_all: flag if all models should be saved
:type save_all: bool
:param fields_list: list of ormar fields to follow and save
:type fields_list: Collection["ForeignKeyField"]
:param relation_map: map of relations to follow
:type relation_map: Dict
:param follow: flag to trigger deep save -
by default only directly related models are saved
with follow=True also related models of related models are saved
:type follow: bool
:param update_count: internal parameter for recursive calls -
number of updated instances
:type update_count: int
:return: tuple of update count and visited
:rtype: int
"""
for field in fields_list:
value = getattr(self, field.name) or []
if not isinstance(value, list):
value = [value]
for val in value:
if follow:
update_count = await val.save_related(
follow=follow,
save_all=save_all,
relation_map=self._skip_ellipsis( # type: ignore
relation_map, field.name, default_return={}
),
update_count=update_count,
previous_model=self,
relation_field=field,
)
else:
update_count = await val._upsert_model(
instance=val,
save_all=save_all,
previous_model=self,
relation_field=field,
update_count=update_count,
)
return update_count

View File

@ -2,6 +2,7 @@ from typing import (
Any, Any,
Dict, Dict,
List, List,
Optional,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
TypeVar, TypeVar,
@ -17,6 +18,9 @@ from ormar.queryset.utils import subtract_dict, translate_list_to_dict
T = TypeVar("T", bound="Model") T = TypeVar("T", bound="Model")
if TYPE_CHECKING: # pragma: no cover
from ormar import ForeignKeyField
class Model(ModelRow): class Model(ModelRow):
__abstract__ = False __abstract__ = False
@ -110,6 +114,8 @@ class Model(ModelRow):
relation_map: Dict = None, relation_map: Dict = None,
exclude: Union[Set, Dict] = None, exclude: Union[Set, Dict] = None,
update_count: int = 0, update_count: int = 0,
previous_model: "Model" = None,
relation_field: Optional["ForeignKeyField"] = None,
) -> int: ) -> int:
""" """
Triggers a upsert method on all related models Triggers a upsert method on all related models
@ -126,6 +132,10 @@ class Model(ModelRow):
Model A but will never follow into Model C. Model A but will never follow into Model C.
Nested relations of those kind need to be persisted manually. Nested relations of those kind need to be persisted manually.
:param relation_field: field with relation leading to this model
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
:param exclude: items to exclude during saving of relations :param exclude: items to exclude during saving of relations
:type exclude: Union[Set, Dict] :type exclude: Union[Set, Dict]
:param relation_map: map of relations to follow :param relation_map: map of relations to follow
@ -151,61 +161,53 @@ class Model(ModelRow):
exclude = translate_list_to_dict(exclude) exclude = translate_list_to_dict(exclude)
relation_map = subtract_dict(relation_map, exclude or {}) relation_map = subtract_dict(relation_map, exclude or {})
for related in self.extract_related_names(): if relation_map:
if relation_map and related in relation_map: fields_to_visit = {
value = getattr(self, related) field
if value: for field in self.extract_related_fields()
update_count = await self._update_and_follow( if field.name in relation_map
value=value, }
follow=follow, pre_save = {
save_all=save_all, field
relation_map=self._skip_ellipsis( # type: ignore for field in fields_to_visit
relation_map, related, default_return={} if not field.virtual and not field.is_multi
), }
update_count=update_count,
)
return update_count
@staticmethod update_count = await self._update_relation_list(
async def _update_and_follow( fields_list=pre_save,
value: Union["Model", List["Model"]],
follow: bool,
save_all: bool,
relation_map: Dict,
update_count: int,
) -> int:
"""
Internal method used in save_related to follow related models and update numbers
of updated related instances.
:param value: Model to follow
:type value: Model
:param relation_map: map of relations to follow
:type relation_map: Dict
:param follow: flag to trigger deep save -
by default only directly related models are saved
with follow=True also related models of related models are saved
:type follow: bool
:param update_count: internal parameter for recursive calls -
number of updated instances
:type update_count: int
:return: tuple of update count and visited
:rtype: int
"""
if not isinstance(value, list):
value = [value]
for val in value:
if (not val.saved or save_all) and not val.__pk_only__:
await val.upsert()
update_count += 1
if follow:
update_count = await val.save_related(
follow=follow, follow=follow,
save_all=save_all, save_all=save_all,
relation_map=relation_map, relation_map=relation_map,
update_count=update_count, update_count=update_count,
) )
update_count = await self._upsert_model(
instance=self,
save_all=save_all,
previous_model=previous_model,
relation_field=relation_field,
update_count=update_count,
)
post_save = fields_to_visit - pre_save
update_count = await self._update_relation_list(
fields_list=post_save,
follow=follow,
save_all=save_all,
relation_map=relation_map,
update_count=update_count,
)
else:
update_count = await self._upsert_model(
instance=self,
save_all=save_all,
previous_model=previous_model,
relation_field=relation_field,
update_count=update_count,
)
return update_count return update_count
async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T:

View File

@ -16,7 +16,7 @@ from typing import ( # noqa: I100, I201
) )
import ormar # noqa: I100, I202 import ormar # noqa: I100, I202
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError from ormar.exceptions import ModelPersistenceError, NoMatch, QueryDefinitionError
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation from ormar.relations import Relation
@ -152,6 +152,21 @@ class QuerysetProxy(Generic[T]):
through_model = await model_cls.objects.get(**rel_kwargs) through_model = await model_cls.objects.get(**rel_kwargs)
await through_model.update(**kwargs) await through_model.update(**kwargs)
async def upsert_through_instance(self, child: "T", **kwargs: Any) -> None:
"""
Updates a through model instance in the database for m2m relations if
it already exists, else creates one.
:param kwargs: dict of additional keyword arguments for through instance
:type kwargs: Any
:param child: child model instance
:type child: Model
"""
try:
await self.update_through_instance(child=child, **kwargs)
except NoMatch:
await self.create_through_instance(child=child, **kwargs)
async def delete_through_instance(self, child: "T") -> None: async def delete_through_instance(self, child: "T") -> None:
""" """
Removes through model instance from the database for m2m relations. Removes through model instance from the database for m2m relations.

View File

@ -119,7 +119,7 @@ async def test_saving_many_to_many():
assert count == 0 assert count == 0
count = await hq.save_related(save_all=True) count = await hq.save_related(save_all=True)
assert count == 2 assert count == 3
hq.nicks[0].name = "Kabucha" hq.nicks[0].name = "Kabucha"
hq.nicks[1].name = "Kabucha2" hq.nicks[1].name = "Kabucha2"

View File

@ -0,0 +1,256 @@
from typing import List
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class CringeLevel(ormar.Model):
class Meta:
tablename = "levels"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class NickName(ormar.Model):
class Meta:
tablename = "nicks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
is_lame: bool = ormar.Boolean(nullable=True)
level: CringeLevel = ormar.ForeignKey(CringeLevel)
class NicksHq(ormar.Model):
class Meta:
tablename = "nicks_x_hq"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
new_field: str = ormar.String(max_length=200, nullable=True)
class HQ(ormar.Model):
class Meta:
tablename = "hqs"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
nicks: List[NickName] = ormar.ManyToMany(NickName, through=NicksHq)
class Company(ormar.Model):
class Meta:
tablename = "companies"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
founded: int = ormar.Integer(nullable=True)
hq: HQ = ormar.ForeignKey(HQ, related_name="companies")
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_saving_related_reverse_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {"companies": [{"name": "Banzai"}], "name": "Main"}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 2
hq_check = await HQ.objects.select_related("companies").get()
assert hq_check.pk is not None
assert hq_check.name == "Main"
assert len(hq_check.companies) == 1
assert hq_check.companies[0].name == "Banzai"
assert hq_check.companies[0].pk is not None
@pytest.mark.asyncio
async def test_saving_related_reverse_fk_multiple():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"companies": [{"name": "Banzai"}, {"name": "Yamate"}],
"name": "Main",
}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 3
hq_check = await HQ.objects.select_related("companies").get()
assert hq_check.pk is not None
assert hq_check.name == "Main"
assert len(hq_check.companies) == 2
assert hq_check.companies[0].name == "Banzai"
assert hq_check.companies[0].pk is not None
assert hq_check.companies[1].name == "Yamate"
assert hq_check.companies[1].pk is not None
@pytest.mark.asyncio
async def test_saving_related_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {"hq": {"name": "Main"}, "name": "Banzai"}
comp = Company(**payload)
count = await comp.save_related(follow=True, save_all=True)
assert count == 2
comp_check = await Company.objects.select_related("hq").get()
assert comp_check.pk is not None
assert comp_check.name == "Banzai"
assert comp_check.hq.name == "Main"
assert comp_check.hq.pk is not None
@pytest.mark.asyncio
async def test_saving_many_to_many_wo_through():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{"name": "Bazinga0", "is_lame": False},
{"name": "Bazinga20", "is_lame": True},
],
}
hq = HQ(**payload)
count = await hq.save_related()
assert count == 3
hq_check = await HQ.objects.select_related("nicks").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[1].name == "Bazinga20"
@pytest.mark.asyncio
async def test_saving_many_to_many_with_through():
async with database:
async with database.transaction(force_rollback=True):
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{
"name": "Bazinga0",
"is_lame": False,
"nickshq": {"new_field": "test"},
},
{
"name": "Bazinga20",
"is_lame": True,
"nickshq": {"new_field": "test2"},
},
],
}
hq = HQ(**payload)
count = await hq.save_related()
assert count == 3
hq_check = await HQ.objects.select_related("nicks").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[0].nickshq.new_field == "test"
assert hq_check.nicks[1].name == "Bazinga20"
assert hq_check.nicks[1].nickshq.new_field == "test2"
@pytest.mark.asyncio
async def test_saving_nested_with_m2m_and_rev_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{"name": "Bazinga0", "is_lame": False, "level": {"name": "High"}},
{"name": "Bazinga20", "is_lame": True, "level": {"name": "Low"}},
],
}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 5
hq_check = await HQ.objects.select_related("nicks__level").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[0].level.name == "High"
assert hq_check.nicks[1].name == "Bazinga20"
assert hq_check.nicks[1].level.name == "Low"
@pytest.mark.asyncio
async def test_saving_nested_with_m2m_and_rev_fk_and_through():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"hq": {
"name": "Yoko",
"nicks": [
{
"name": "Bazinga0",
"is_lame": False,
"nickshq": {"new_field": "test"},
"level": {"name": "High"},
},
{
"name": "Bazinga20",
"is_lame": True,
"nickshq": {"new_field": "test2"},
"level": {"name": "Low"},
},
],
},
"name": "Main",
}
company = Company(**payload)
count = await company.save_related(follow=True, save_all=True)
assert count == 6
company_check = await Company.objects.select_related(
"hq__nicks__level"
).get()
assert company_check.pk is not None
assert company_check.name == "Main"
assert company_check.hq.name == "Yoko"
assert len(company_check.hq.nicks) == 2
assert company_check.hq.nicks[0].name == "Bazinga0"
assert company_check.hq.nicks[0].nickshq.new_field == "test"
assert company_check.hq.nicks[0].level.name == "High"
assert company_check.hq.nicks[1].name == "Bazinga20"
assert company_check.hq.nicks[1].level.name == "Low"
assert company_check.hq.nicks[1].nickshq.new_field == "test2"