diff --git a/docs/releases.md b/docs/releases.md index 69e725b..b9becc1 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -13,6 +13,17 @@ * even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by`) * the relation won't be populated in `dict()` and `json()` * you cannot pass the nested related objects when populating from `dict()` or `json()` (also through `fastapi`). It will be either ignored or raise error depending on `extra` setting in pydantic `Config`. +* `Model.save_related()` now can save whole data tree in once [#148](https://github.com/collerek/ormar/discussions/148) + meaning: + * it knows if it should save main `Model` or related `Model` first to preserve the relation + * it saves main `Model` if + * it's not `saved`, + * has no `pk` value + * or `save_all=True` flag is set + + in those cases you don't have to split save into two calls (`save()` and `save_related()`) + * it supports also `ManyToMany` relations + * it supports also optional `Through` model values for m2m relations ## 🐛 Fixes diff --git a/ormar/models/mixins/relation_mixin.py b/ormar/models/mixins/relation_mixin.py index 43de0be..6a71382 100644 --- a/ormar/models/mixins/relation_mixin.py +++ b/ormar/models/mixins/relation_mixin.py @@ -4,9 +4,10 @@ from typing import ( Optional, Set, TYPE_CHECKING, + cast, ) -from ormar import BaseField +from ormar import BaseField, ForeignKeyField from ormar.models.traversible import NodeList @@ -39,7 +40,7 @@ class RelationMixin: return self_fields @classmethod - def extract_related_fields(cls) -> List: + def extract_related_fields(cls) -> List["ForeignKeyField"]: """ Returns List of ormar Fields for all relations declared on a model. List is cached in cls._related_fields for quicker access. @@ -52,7 +53,7 @@ class RelationMixin: related_fields = [] for name in cls.extract_related_names().union(cls.extract_through_names()): - related_fields.append(cls.Meta.model_fields[name]) + related_fields.append(cast("ForeignKeyField", cls.Meta.model_fields[name])) cls._related_fields = related_fields return related_fields diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 5ade49f..52826b7 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -1,5 +1,5 @@ import uuid -from typing import Dict, Optional, Set, TYPE_CHECKING +from typing import Callable, Collection, Dict, Optional, Set, TYPE_CHECKING, cast import ormar from ormar.exceptions import ModelPersistenceError @@ -7,6 +7,9 @@ from ormar.models.helpers.validation import validate_choices from ormar.models.mixins import AliasMixin from ormar.models.mixins.relation_mixin import RelationMixin +if TYPE_CHECKING: # pragma: no cover + from ormar import ForeignKeyField, Model + class SavePrepareMixin(RelationMixin, AliasMixin): """ @@ -15,6 +18,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): if TYPE_CHECKING: # pragma: nocover _choices_fields: Optional[Set] + _skip_ellipsis: Callable @classmethod def prepare_model_to_save(cls, new_kwargs: dict) -> dict: @@ -170,3 +174,130 @@ class SavePrepareMixin(RelationMixin, AliasMixin): if field_name in new_kwargs and field_name in cls._choices_fields: validate_choices(field=field, value=new_kwargs.get(field_name)) return new_kwargs + + @staticmethod + async def _upsert_model( + instance: "Model", + save_all: bool, + previous_model: Optional["Model"], + relation_field: Optional["ForeignKeyField"], + update_count: int, + ) -> int: + """ + Method updates given instance if: + + * instance is not saved or + * instance have no pk or + * save_all=True flag is set + + and instance is not __pk_only__. + + If relation leading to instance is a ManyToMany also the through model is saved + + :param instance: current model to upsert + :type instance: Model + :param save_all: flag if all models should be saved or only not saved ones + :type save_all: bool + :param relation_field: field with relation + :type relation_field: Optional[ForeignKeyField] + :param previous_model: previous model from which method came + :type previous_model: Model + :param update_count: no of updated models + :type update_count: int + :return: no of updated models + :rtype: int + """ + if ( + save_all or not instance.pk or not instance.saved + ) and not instance.__pk_only__: + await instance.upsert() + if relation_field and relation_field.is_multi: + await instance._upsert_through_model( + instance=instance, + relation_field=relation_field, + previous_model=cast("Model", previous_model), + ) + update_count += 1 + return update_count + + @staticmethod + async def _upsert_through_model( + instance: "Model", + previous_model: "Model", + relation_field: Optional["ForeignKeyField"], + ) -> None: + """ + Upsert through model for m2m relation. + + :param instance: current model to upsert + :type instance: Model + :param relation_field: field with relation + :type relation_field: Optional[ForeignKeyField] + :param previous_model: previous model from which method came + :type previous_model: Model + """ + through_name = previous_model.Meta.model_fields[ + relation_field.name + ].through.get_name() + through = getattr(instance, through_name) + if through: + through_dict = through.dict(exclude=through.extract_related_names()) + else: + through_dict = {} + await getattr( + previous_model, relation_field.name + ).queryset_proxy.upsert_through_instance(instance, **through_dict) + + async def _update_relation_list( + self, + fields_list: Collection["ForeignKeyField"], + follow: bool, + save_all: bool, + relation_map: Dict, + update_count: int, + ) -> int: + """ + Internal method used in save_related to follow deeper from + related models and update numbers of updated related instances. + + :type save_all: flag if all models should be saved + :type save_all: bool + :param fields_list: list of ormar fields to follow and save + :type fields_list: Collection["ForeignKeyField"] + :param relation_map: map of relations to follow + :type relation_map: Dict + :param follow: flag to trigger deep save - + by default only directly related models are saved + with follow=True also related models of related models are saved + :type follow: bool + :param update_count: internal parameter for recursive calls - + number of updated instances + :type update_count: int + :return: tuple of update count and visited + :rtype: int + """ + for field in fields_list: + value = getattr(self, field.name) or [] + if not isinstance(value, list): + value = [value] + for val in value: + if follow: + update_count = await val.save_related( + follow=follow, + save_all=save_all, + relation_map=self._skip_ellipsis( # type: ignore + relation_map, field.name, default_return={} + ), + update_count=update_count, + previous_model=self, + relation_field=field, + ) + else: + update_count = await val._upsert_model( + instance=val, + save_all=save_all, + previous_model=self, + relation_field=field, + update_count=update_count, + ) + return update_count diff --git a/ormar/models/model.py b/ormar/models/model.py index 5e27c69..7c5f834 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -2,6 +2,7 @@ from typing import ( Any, Dict, List, + Optional, Set, TYPE_CHECKING, TypeVar, @@ -17,6 +18,9 @@ from ormar.queryset.utils import subtract_dict, translate_list_to_dict T = TypeVar("T", bound="Model") +if TYPE_CHECKING: # pragma: no cover + from ormar import ForeignKeyField + class Model(ModelRow): __abstract__ = False @@ -110,6 +114,8 @@ class Model(ModelRow): relation_map: Dict = None, exclude: Union[Set, Dict] = None, update_count: int = 0, + previous_model: "Model" = None, + relation_field: Optional["ForeignKeyField"] = None, ) -> int: """ Triggers a upsert method on all related models @@ -126,6 +132,10 @@ class Model(ModelRow): Model A but will never follow into Model C. Nested relations of those kind need to be persisted manually. + :param relation_field: field with relation leading to this model + :type relation_field: Optional[ForeignKeyField] + :param previous_model: previous model from which method came + :type previous_model: Model :param exclude: items to exclude during saving of relations :type exclude: Union[Set, Dict] :param relation_map: map of relations to follow @@ -151,61 +161,53 @@ class Model(ModelRow): exclude = translate_list_to_dict(exclude) relation_map = subtract_dict(relation_map, exclude or {}) - for related in self.extract_related_names(): - if relation_map and related in relation_map: - value = getattr(self, related) - if value: - update_count = await self._update_and_follow( - value=value, - follow=follow, - save_all=save_all, - relation_map=self._skip_ellipsis( # type: ignore - relation_map, related, default_return={} - ), - update_count=update_count, - ) - return update_count + if relation_map: + fields_to_visit = { + field + for field in self.extract_related_fields() + if field.name in relation_map + } + pre_save = { + field + for field in fields_to_visit + if not field.virtual and not field.is_multi + } - @staticmethod - async def _update_and_follow( - value: Union["Model", List["Model"]], - follow: bool, - save_all: bool, - relation_map: Dict, - update_count: int, - ) -> int: - """ - Internal method used in save_related to follow related models and update numbers - of updated related instances. + update_count = await self._update_relation_list( + fields_list=pre_save, + follow=follow, + save_all=save_all, + relation_map=relation_map, + update_count=update_count, + ) - :param value: Model to follow - :type value: Model - :param relation_map: map of relations to follow - :type relation_map: Dict - :param follow: flag to trigger deep save - - by default only directly related models are saved - with follow=True also related models of related models are saved - :type follow: bool - :param update_count: internal parameter for recursive calls - - number of updated instances - :type update_count: int - :return: tuple of update count and visited - :rtype: int - """ - if not isinstance(value, list): - value = [value] + update_count = await self._upsert_model( + instance=self, + save_all=save_all, + previous_model=previous_model, + relation_field=relation_field, + update_count=update_count, + ) + + post_save = fields_to_visit - pre_save + + update_count = await self._update_relation_list( + fields_list=post_save, + follow=follow, + save_all=save_all, + relation_map=relation_map, + update_count=update_count, + ) + + else: + update_count = await self._upsert_model( + instance=self, + save_all=save_all, + previous_model=previous_model, + relation_field=relation_field, + update_count=update_count, + ) - for val in value: - if (not val.saved or save_all) and not val.__pk_only__: - await val.upsert() - update_count += 1 - if follow: - update_count = await val.save_related( - follow=follow, - save_all=save_all, - relation_map=relation_map, - update_count=update_count, - ) return update_count async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 4578414..f60f263 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -16,7 +16,7 @@ from typing import ( # noqa: I100, I201 ) import ormar # noqa: I100, I202 -from ormar.exceptions import ModelPersistenceError, QueryDefinitionError +from ormar.exceptions import ModelPersistenceError, NoMatch, QueryDefinitionError if TYPE_CHECKING: # pragma no cover from ormar.relations import Relation @@ -152,6 +152,21 @@ class QuerysetProxy(Generic[T]): through_model = await model_cls.objects.get(**rel_kwargs) await through_model.update(**kwargs) + async def upsert_through_instance(self, child: "T", **kwargs: Any) -> None: + """ + Updates a through model instance in the database for m2m relations if + it already exists, else creates one. + + :param kwargs: dict of additional keyword arguments for through instance + :type kwargs: Any + :param child: child model instance + :type child: Model + """ + try: + await self.update_through_instance(child=child, **kwargs) + except NoMatch: + await self.create_through_instance(child=child, **kwargs) + async def delete_through_instance(self, child: "T") -> None: """ Removes through model instance from the database for m2m relations. diff --git a/tests/test_model_methods/test_save_related.py b/tests/test_model_methods/test_save_related.py index 774d167..207c53d 100644 --- a/tests/test_model_methods/test_save_related.py +++ b/tests/test_model_methods/test_save_related.py @@ -119,7 +119,7 @@ async def test_saving_many_to_many(): assert count == 0 count = await hq.save_related(save_all=True) - assert count == 2 + assert count == 3 hq.nicks[0].name = "Kabucha" hq.nicks[1].name = "Kabucha2" diff --git a/tests/test_model_methods/test_save_related_from_dict.py b/tests/test_model_methods/test_save_related_from_dict.py new file mode 100644 index 0000000..a545092 --- /dev/null +++ b/tests/test_model_methods/test_save_related_from_dict.py @@ -0,0 +1,256 @@ +from typing import List + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class CringeLevel(ormar.Model): + class Meta: + tablename = "levels" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class NickName(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + level: CringeLevel = ormar.ForeignKey(CringeLevel) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + new_field: str = ormar.String(max_length=200, nullable=True) + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickName] = ormar.ManyToMany(NickName, through=NicksHq) + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="company_name") + founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ, related_name="companies") + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_saving_related_reverse_fk(): + async with database: + async with database.transaction(force_rollback=True): + payload = {"companies": [{"name": "Banzai"}], "name": "Main"} + hq = HQ(**payload) + count = await hq.save_related(follow=True, save_all=True) + assert count == 2 + + hq_check = await HQ.objects.select_related("companies").get() + assert hq_check.pk is not None + assert hq_check.name == "Main" + assert len(hq_check.companies) == 1 + assert hq_check.companies[0].name == "Banzai" + assert hq_check.companies[0].pk is not None + + +@pytest.mark.asyncio +async def test_saving_related_reverse_fk_multiple(): + async with database: + async with database.transaction(force_rollback=True): + payload = { + "companies": [{"name": "Banzai"}, {"name": "Yamate"}], + "name": "Main", + } + hq = HQ(**payload) + count = await hq.save_related(follow=True, save_all=True) + assert count == 3 + + hq_check = await HQ.objects.select_related("companies").get() + assert hq_check.pk is not None + assert hq_check.name == "Main" + assert len(hq_check.companies) == 2 + assert hq_check.companies[0].name == "Banzai" + assert hq_check.companies[0].pk is not None + assert hq_check.companies[1].name == "Yamate" + assert hq_check.companies[1].pk is not None + + +@pytest.mark.asyncio +async def test_saving_related_fk(): + async with database: + async with database.transaction(force_rollback=True): + payload = {"hq": {"name": "Main"}, "name": "Banzai"} + comp = Company(**payload) + count = await comp.save_related(follow=True, save_all=True) + assert count == 2 + + comp_check = await Company.objects.select_related("hq").get() + assert comp_check.pk is not None + assert comp_check.name == "Banzai" + assert comp_check.hq.name == "Main" + assert comp_check.hq.pk is not None + + +@pytest.mark.asyncio +async def test_saving_many_to_many_wo_through(): + async with database: + async with database.transaction(force_rollback=True): + payload = { + "name": "Main", + "nicks": [ + {"name": "Bazinga0", "is_lame": False}, + {"name": "Bazinga20", "is_lame": True}, + ], + } + + hq = HQ(**payload) + count = await hq.save_related() + assert count == 3 + + hq_check = await HQ.objects.select_related("nicks").get() + assert hq_check.pk is not None + assert len(hq_check.nicks) == 2 + assert hq_check.nicks[0].name == "Bazinga0" + assert hq_check.nicks[1].name == "Bazinga20" + + +@pytest.mark.asyncio +async def test_saving_many_to_many_with_through(): + async with database: + async with database.transaction(force_rollback=True): + async with database.transaction(force_rollback=True): + payload = { + "name": "Main", + "nicks": [ + { + "name": "Bazinga0", + "is_lame": False, + "nickshq": {"new_field": "test"}, + }, + { + "name": "Bazinga20", + "is_lame": True, + "nickshq": {"new_field": "test2"}, + }, + ], + } + + hq = HQ(**payload) + count = await hq.save_related() + assert count == 3 + + hq_check = await HQ.objects.select_related("nicks").get() + assert hq_check.pk is not None + assert len(hq_check.nicks) == 2 + assert hq_check.nicks[0].name == "Bazinga0" + assert hq_check.nicks[0].nickshq.new_field == "test" + assert hq_check.nicks[1].name == "Bazinga20" + assert hq_check.nicks[1].nickshq.new_field == "test2" + + +@pytest.mark.asyncio +async def test_saving_nested_with_m2m_and_rev_fk(): + async with database: + async with database.transaction(force_rollback=True): + payload = { + "name": "Main", + "nicks": [ + {"name": "Bazinga0", "is_lame": False, "level": {"name": "High"}}, + {"name": "Bazinga20", "is_lame": True, "level": {"name": "Low"}}, + ], + } + + hq = HQ(**payload) + count = await hq.save_related(follow=True, save_all=True) + assert count == 5 + + hq_check = await HQ.objects.select_related("nicks__level").get() + assert hq_check.pk is not None + assert len(hq_check.nicks) == 2 + assert hq_check.nicks[0].name == "Bazinga0" + assert hq_check.nicks[0].level.name == "High" + assert hq_check.nicks[1].name == "Bazinga20" + assert hq_check.nicks[1].level.name == "Low" + + +@pytest.mark.asyncio +async def test_saving_nested_with_m2m_and_rev_fk_and_through(): + async with database: + async with database.transaction(force_rollback=True): + payload = { + "hq": { + "name": "Yoko", + "nicks": [ + { + "name": "Bazinga0", + "is_lame": False, + "nickshq": {"new_field": "test"}, + "level": {"name": "High"}, + }, + { + "name": "Bazinga20", + "is_lame": True, + "nickshq": {"new_field": "test2"}, + "level": {"name": "Low"}, + }, + ], + }, + "name": "Main", + } + + company = Company(**payload) + count = await company.save_related(follow=True, save_all=True) + assert count == 6 + + company_check = await Company.objects.select_related( + "hq__nicks__level" + ).get() + assert company_check.pk is not None + assert company_check.name == "Main" + assert company_check.hq.name == "Yoko" + assert len(company_check.hq.nicks) == 2 + assert company_check.hq.nicks[0].name == "Bazinga0" + assert company_check.hq.nicks[0].nickshq.new_field == "test" + assert company_check.hq.nicks[0].level.name == "High" + assert company_check.hq.nicks[1].name == "Bazinga20" + assert company_check.hq.nicks[1].level.name == "Low" + assert company_check.hq.nicks[1].nickshq.new_field == "test2"