modify save_related to be able to save whole tree from dict - including reverse fk and m2m relations - with correct order of saving

This commit is contained in:
collerek
2021-04-12 17:39:42 +02:00
parent 6780c9de8a
commit 854b27947a
7 changed files with 474 additions and 58 deletions

View File

@ -13,6 +13,17 @@
* even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by`)
* the relation won't be populated in `dict()` and `json()`
* you cannot pass the nested related objects when populating from `dict()` or `json()` (also through `fastapi`). It will be either ignored or raise error depending on `extra` setting in pydantic `Config`.
* `Model.save_related()` now can save whole data tree in once [#148](https://github.com/collerek/ormar/discussions/148)
meaning:
* it knows if it should save main `Model` or related `Model` first to preserve the relation
* it saves main `Model` if
* it's not `saved`,
* has no `pk` value
* or `save_all=True` flag is set
in those cases you don't have to split save into two calls (`save()` and `save_related()`)
* it supports also `ManyToMany` relations
* it supports also optional `Through` model values for m2m relations
## 🐛 Fixes

View File

@ -4,9 +4,10 @@ from typing import (
Optional,
Set,
TYPE_CHECKING,
cast,
)
from ormar import BaseField
from ormar import BaseField, ForeignKeyField
from ormar.models.traversible import NodeList
@ -39,7 +40,7 @@ class RelationMixin:
return self_fields
@classmethod
def extract_related_fields(cls) -> List:
def extract_related_fields(cls) -> List["ForeignKeyField"]:
"""
Returns List of ormar Fields for all relations declared on a model.
List is cached in cls._related_fields for quicker access.
@ -52,7 +53,7 @@ class RelationMixin:
related_fields = []
for name in cls.extract_related_names().union(cls.extract_through_names()):
related_fields.append(cls.Meta.model_fields[name])
related_fields.append(cast("ForeignKeyField", cls.Meta.model_fields[name]))
cls._related_fields = related_fields
return related_fields

View File

@ -1,5 +1,5 @@
import uuid
from typing import Dict, Optional, Set, TYPE_CHECKING
from typing import Callable, Collection, Dict, Optional, Set, TYPE_CHECKING, cast
import ormar
from ormar.exceptions import ModelPersistenceError
@ -7,6 +7,9 @@ from ormar.models.helpers.validation import validate_choices
from ormar.models.mixins import AliasMixin
from ormar.models.mixins.relation_mixin import RelationMixin
if TYPE_CHECKING: # pragma: no cover
from ormar import ForeignKeyField, Model
class SavePrepareMixin(RelationMixin, AliasMixin):
"""
@ -15,6 +18,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if TYPE_CHECKING: # pragma: nocover
_choices_fields: Optional[Set]
_skip_ellipsis: Callable
@classmethod
def prepare_model_to_save(cls, new_kwargs: dict) -> dict:
@ -170,3 +174,130 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if field_name in new_kwargs and field_name in cls._choices_fields:
validate_choices(field=field, value=new_kwargs.get(field_name))
return new_kwargs
@staticmethod
async def _upsert_model(
instance: "Model",
save_all: bool,
previous_model: Optional["Model"],
relation_field: Optional["ForeignKeyField"],
update_count: int,
) -> int:
"""
Method updates given instance if:
* instance is not saved or
* instance have no pk or
* save_all=True flag is set
and instance is not __pk_only__.
If relation leading to instance is a ManyToMany also the through model is saved
:param instance: current model to upsert
:type instance: Model
:param save_all: flag if all models should be saved or only not saved ones
:type save_all: bool
:param relation_field: field with relation
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
:param update_count: no of updated models
:type update_count: int
:return: no of updated models
:rtype: int
"""
if (
save_all or not instance.pk or not instance.saved
) and not instance.__pk_only__:
await instance.upsert()
if relation_field and relation_field.is_multi:
await instance._upsert_through_model(
instance=instance,
relation_field=relation_field,
previous_model=cast("Model", previous_model),
)
update_count += 1
return update_count
@staticmethod
async def _upsert_through_model(
instance: "Model",
previous_model: "Model",
relation_field: Optional["ForeignKeyField"],
) -> None:
"""
Upsert through model for m2m relation.
:param instance: current model to upsert
:type instance: Model
:param relation_field: field with relation
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
"""
through_name = previous_model.Meta.model_fields[
relation_field.name
].through.get_name()
through = getattr(instance, through_name)
if through:
through_dict = through.dict(exclude=through.extract_related_names())
else:
through_dict = {}
await getattr(
previous_model, relation_field.name
).queryset_proxy.upsert_through_instance(instance, **through_dict)
async def _update_relation_list(
self,
fields_list: Collection["ForeignKeyField"],
follow: bool,
save_all: bool,
relation_map: Dict,
update_count: int,
) -> int:
"""
Internal method used in save_related to follow deeper from
related models and update numbers of updated related instances.
:type save_all: flag if all models should be saved
:type save_all: bool
:param fields_list: list of ormar fields to follow and save
:type fields_list: Collection["ForeignKeyField"]
:param relation_map: map of relations to follow
:type relation_map: Dict
:param follow: flag to trigger deep save -
by default only directly related models are saved
with follow=True also related models of related models are saved
:type follow: bool
:param update_count: internal parameter for recursive calls -
number of updated instances
:type update_count: int
:return: tuple of update count and visited
:rtype: int
"""
for field in fields_list:
value = getattr(self, field.name) or []
if not isinstance(value, list):
value = [value]
for val in value:
if follow:
update_count = await val.save_related(
follow=follow,
save_all=save_all,
relation_map=self._skip_ellipsis( # type: ignore
relation_map, field.name, default_return={}
),
update_count=update_count,
previous_model=self,
relation_field=field,
)
else:
update_count = await val._upsert_model(
instance=val,
save_all=save_all,
previous_model=self,
relation_field=field,
update_count=update_count,
)
return update_count

View File

@ -2,6 +2,7 @@ from typing import (
Any,
Dict,
List,
Optional,
Set,
TYPE_CHECKING,
TypeVar,
@ -17,6 +18,9 @@ from ormar.queryset.utils import subtract_dict, translate_list_to_dict
T = TypeVar("T", bound="Model")
if TYPE_CHECKING: # pragma: no cover
from ormar import ForeignKeyField
class Model(ModelRow):
__abstract__ = False
@ -110,6 +114,8 @@ class Model(ModelRow):
relation_map: Dict = None,
exclude: Union[Set, Dict] = None,
update_count: int = 0,
previous_model: "Model" = None,
relation_field: Optional["ForeignKeyField"] = None,
) -> int:
"""
Triggers a upsert method on all related models
@ -126,6 +132,10 @@ class Model(ModelRow):
Model A but will never follow into Model C.
Nested relations of those kind need to be persisted manually.
:param relation_field: field with relation leading to this model
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
:param exclude: items to exclude during saving of relations
:type exclude: Union[Set, Dict]
:param relation_map: map of relations to follow
@ -151,61 +161,53 @@ class Model(ModelRow):
exclude = translate_list_to_dict(exclude)
relation_map = subtract_dict(relation_map, exclude or {})
for related in self.extract_related_names():
if relation_map and related in relation_map:
value = getattr(self, related)
if value:
update_count = await self._update_and_follow(
value=value,
follow=follow,
save_all=save_all,
relation_map=self._skip_ellipsis( # type: ignore
relation_map, related, default_return={}
),
update_count=update_count,
)
return update_count
if relation_map:
fields_to_visit = {
field
for field in self.extract_related_fields()
if field.name in relation_map
}
pre_save = {
field
for field in fields_to_visit
if not field.virtual and not field.is_multi
}
@staticmethod
async def _update_and_follow(
value: Union["Model", List["Model"]],
follow: bool,
save_all: bool,
relation_map: Dict,
update_count: int,
) -> int:
"""
Internal method used in save_related to follow related models and update numbers
of updated related instances.
update_count = await self._update_relation_list(
fields_list=pre_save,
follow=follow,
save_all=save_all,
relation_map=relation_map,
update_count=update_count,
)
:param value: Model to follow
:type value: Model
:param relation_map: map of relations to follow
:type relation_map: Dict
:param follow: flag to trigger deep save -
by default only directly related models are saved
with follow=True also related models of related models are saved
:type follow: bool
:param update_count: internal parameter for recursive calls -
number of updated instances
:type update_count: int
:return: tuple of update count and visited
:rtype: int
"""
if not isinstance(value, list):
value = [value]
update_count = await self._upsert_model(
instance=self,
save_all=save_all,
previous_model=previous_model,
relation_field=relation_field,
update_count=update_count,
)
post_save = fields_to_visit - pre_save
update_count = await self._update_relation_list(
fields_list=post_save,
follow=follow,
save_all=save_all,
relation_map=relation_map,
update_count=update_count,
)
else:
update_count = await self._upsert_model(
instance=self,
save_all=save_all,
previous_model=previous_model,
relation_field=relation_field,
update_count=update_count,
)
for val in value:
if (not val.saved or save_all) and not val.__pk_only__:
await val.upsert()
update_count += 1
if follow:
update_count = await val.save_related(
follow=follow,
save_all=save_all,
relation_map=relation_map,
update_count=update_count,
)
return update_count
async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T:

View File

@ -16,7 +16,7 @@ from typing import ( # noqa: I100, I201
)
import ormar # noqa: I100, I202
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
from ormar.exceptions import ModelPersistenceError, NoMatch, QueryDefinitionError
if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation
@ -152,6 +152,21 @@ class QuerysetProxy(Generic[T]):
through_model = await model_cls.objects.get(**rel_kwargs)
await through_model.update(**kwargs)
async def upsert_through_instance(self, child: "T", **kwargs: Any) -> None:
"""
Updates a through model instance in the database for m2m relations if
it already exists, else creates one.
:param kwargs: dict of additional keyword arguments for through instance
:type kwargs: Any
:param child: child model instance
:type child: Model
"""
try:
await self.update_through_instance(child=child, **kwargs)
except NoMatch:
await self.create_through_instance(child=child, **kwargs)
async def delete_through_instance(self, child: "T") -> None:
"""
Removes through model instance from the database for m2m relations.

View File

@ -119,7 +119,7 @@ async def test_saving_many_to_many():
assert count == 0
count = await hq.save_related(save_all=True)
assert count == 2
assert count == 3
hq.nicks[0].name = "Kabucha"
hq.nicks[1].name = "Kabucha2"

View File

@ -0,0 +1,256 @@
from typing import List
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class CringeLevel(ormar.Model):
class Meta:
tablename = "levels"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class NickName(ormar.Model):
class Meta:
tablename = "nicks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
is_lame: bool = ormar.Boolean(nullable=True)
level: CringeLevel = ormar.ForeignKey(CringeLevel)
class NicksHq(ormar.Model):
class Meta:
tablename = "nicks_x_hq"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
new_field: str = ormar.String(max_length=200, nullable=True)
class HQ(ormar.Model):
class Meta:
tablename = "hqs"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
nicks: List[NickName] = ormar.ManyToMany(NickName, through=NicksHq)
class Company(ormar.Model):
class Meta:
tablename = "companies"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
founded: int = ormar.Integer(nullable=True)
hq: HQ = ormar.ForeignKey(HQ, related_name="companies")
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_saving_related_reverse_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {"companies": [{"name": "Banzai"}], "name": "Main"}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 2
hq_check = await HQ.objects.select_related("companies").get()
assert hq_check.pk is not None
assert hq_check.name == "Main"
assert len(hq_check.companies) == 1
assert hq_check.companies[0].name == "Banzai"
assert hq_check.companies[0].pk is not None
@pytest.mark.asyncio
async def test_saving_related_reverse_fk_multiple():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"companies": [{"name": "Banzai"}, {"name": "Yamate"}],
"name": "Main",
}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 3
hq_check = await HQ.objects.select_related("companies").get()
assert hq_check.pk is not None
assert hq_check.name == "Main"
assert len(hq_check.companies) == 2
assert hq_check.companies[0].name == "Banzai"
assert hq_check.companies[0].pk is not None
assert hq_check.companies[1].name == "Yamate"
assert hq_check.companies[1].pk is not None
@pytest.mark.asyncio
async def test_saving_related_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {"hq": {"name": "Main"}, "name": "Banzai"}
comp = Company(**payload)
count = await comp.save_related(follow=True, save_all=True)
assert count == 2
comp_check = await Company.objects.select_related("hq").get()
assert comp_check.pk is not None
assert comp_check.name == "Banzai"
assert comp_check.hq.name == "Main"
assert comp_check.hq.pk is not None
@pytest.mark.asyncio
async def test_saving_many_to_many_wo_through():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{"name": "Bazinga0", "is_lame": False},
{"name": "Bazinga20", "is_lame": True},
],
}
hq = HQ(**payload)
count = await hq.save_related()
assert count == 3
hq_check = await HQ.objects.select_related("nicks").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[1].name == "Bazinga20"
@pytest.mark.asyncio
async def test_saving_many_to_many_with_through():
async with database:
async with database.transaction(force_rollback=True):
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{
"name": "Bazinga0",
"is_lame": False,
"nickshq": {"new_field": "test"},
},
{
"name": "Bazinga20",
"is_lame": True,
"nickshq": {"new_field": "test2"},
},
],
}
hq = HQ(**payload)
count = await hq.save_related()
assert count == 3
hq_check = await HQ.objects.select_related("nicks").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[0].nickshq.new_field == "test"
assert hq_check.nicks[1].name == "Bazinga20"
assert hq_check.nicks[1].nickshq.new_field == "test2"
@pytest.mark.asyncio
async def test_saving_nested_with_m2m_and_rev_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{"name": "Bazinga0", "is_lame": False, "level": {"name": "High"}},
{"name": "Bazinga20", "is_lame": True, "level": {"name": "Low"}},
],
}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 5
hq_check = await HQ.objects.select_related("nicks__level").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[0].level.name == "High"
assert hq_check.nicks[1].name == "Bazinga20"
assert hq_check.nicks[1].level.name == "Low"
@pytest.mark.asyncio
async def test_saving_nested_with_m2m_and_rev_fk_and_through():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"hq": {
"name": "Yoko",
"nicks": [
{
"name": "Bazinga0",
"is_lame": False,
"nickshq": {"new_field": "test"},
"level": {"name": "High"},
},
{
"name": "Bazinga20",
"is_lame": True,
"nickshq": {"new_field": "test2"},
"level": {"name": "Low"},
},
],
},
"name": "Main",
}
company = Company(**payload)
count = await company.save_related(follow=True, save_all=True)
assert count == 6
company_check = await Company.objects.select_related(
"hq__nicks__level"
).get()
assert company_check.pk is not None
assert company_check.name == "Main"
assert company_check.hq.name == "Yoko"
assert len(company_check.hq.nicks) == 2
assert company_check.hq.nicks[0].name == "Bazinga0"
assert company_check.hq.nicks[0].nickshq.new_field == "test"
assert company_check.hq.nicks[0].level.name == "High"
assert company_check.hq.nicks[1].name == "Bazinga20"
assert company_check.hq.nicks[1].level.name == "Low"
assert company_check.hq.nicks[1].nickshq.new_field == "test2"