diff --git a/docs/releases.md b/docs/releases.md index 98f1234..d874436 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -2,19 +2,22 @@ ## Features -* `save_related(follow=False)` now accept also second argument `save_related(follow=False, save_all=False)`. - By default so with `save_all=False` `ormar` only upserts models that are no saved (so new or updated ones), +* `Model.save_related(follow=False)` now accept also two additional arguments: `Model.save_related(follow=False, save_all=False, exclude=None)`. + * `save_all:bool` -> By default (so with `save_all=False`) `ormar` only upserts models that are not saved (so new or updated ones), with `save_all=True` all related models are saved, regardless of `saved` status, which might be useful if updated - models comes from api call, so are not changed in backend. + models comes from api call, so are not changed in the backend. + * `exclude: Union[Set, Dict, None]` -> set/dict of relations to exclude from save, those relation won't be saved even with `follow=True` and `save_all=True`. + To exclude nested relations pass a nested dictionary like: `exclude={"child":{"sub_child": {"exclude_sub_child_realtion"}}}`. The allowed values follow + the `fields/exclude_fields` (from `QuerySet`) methods schema so when in doubt you can refer to docs in queries -> selecting subset of fields -> fields. * `dict()` method previously included only directly related models or nested models if they were not nullable and not virtual, - now all related models not previosuly visited without loops are included in `dict()`. This should be not breaking + now all related models not previously visited without loops are included in `dict()`. This should be not breaking as just more data will be dumped to dict, but it should not be missing. ## Fixes * Fix improper relation field resolution in `QuerysetProxy` if fk column has different database alias. -* Fix hitting recursion error with very complicated models structure with loops. +* Fix hitting recursion error with very complicated models structure with loops when calling `dict()`. ## Other diff --git a/ormar/models/model.py b/ormar/models/model.py index cf4b653..641aaa5 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -4,7 +4,6 @@ from typing import ( List, Set, TYPE_CHECKING, - Tuple, TypeVar, Union, ) @@ -14,7 +13,7 @@ from ormar.exceptions import ModelPersistenceError, NoMatch from ormar.models import NewBaseModel # noqa I100 from ormar.models.metaclass import ModelMeta from ormar.models.model_row import ModelRow - +from ormar.queryset.utils import subtract_dict, translate_list_to_dict T = TypeVar("T", bound="Model") @@ -104,9 +103,10 @@ class Model(ModelRow): self, follow: bool = False, save_all: bool = False, - visited: Set = None, + relation_map: Dict = None, + exclude: Union[Set, Dict] = None, update_count: int = 0, - ) -> int: # noqa: CCR001 + ) -> int: """ Triggers a upsert method on all related models if the instances are not already saved. @@ -122,83 +122,86 @@ class Model(ModelRow): Model A but will never follow into Model C. Nested relations of those kind need to be persisted manually. + :param exclude: items to exclude during saving of relations + :type exclude: Union[Set, Dict] + :param relation_map: map of relations to follow + :type relation_map: Dict + :param save_all: flag if all models should be saved or only not saved ones + :type save_all: bool :param follow: flag to trigger deep save - by default only directly related models are saved with follow=True also related models of related models are saved :type follow: bool - :param visited: internal parameter for recursive calls - already visited models - :type visited: Set :param update_count: internal parameter for recursive calls - number of updated instances :type update_count: int :return: number of updated/saved models :rtype: int """ - if not visited: - visited = {self.__class__} - else: - visited = {x for x in visited} - visited.add(self.__class__) + relation_map = ( + relation_map + if relation_map is not None + else translate_list_to_dict(self._iterate_related_models()) + ) + if exclude and isinstance(exclude, Set): + exclude = translate_list_to_dict(exclude) + relation_map = subtract_dict(relation_map, exclude or {}) for related in self.extract_related_names(): - if ( - self.Meta.model_fields[related].virtual - or self.Meta.model_fields[related].is_multi - ): - for rel in getattr(self, related): - update_count, visited = await self._update_and_follow( - rel=rel, - follow=follow, - save_all=save_all, - visited=visited, - update_count=update_count, - ) - visited.add(self.Meta.model_fields[related].to) - else: - rel = getattr(self, related) - update_count, visited = await self._update_and_follow( - rel=rel, + if relation_map and related in relation_map: + value = getattr(self, related) + update_count = await self._update_and_follow( + value=value, follow=follow, save_all=save_all, - visited=visited, + relation_map=self._skip_ellipsis( # type: ignore + relation_map, related, default_return={} + ), update_count=update_count, ) - visited.add(rel.__class__) return update_count @staticmethod async def _update_and_follow( - rel: "Model", follow: bool, save_all: bool, visited: Set, update_count: int - ) -> Tuple[int, Set]: + value: Union["Model", List["Model"]], + follow: bool, + save_all: bool, + relation_map: Dict, + update_count: int, + ) -> int: """ Internal method used in save_related to follow related models and update numbers of updated related instances. - :param rel: Model to follow - :type rel: Model + :param value: Model to follow + :type value: Model + :param relation_map: map of relations to follow + :type relation_map: Dict :param follow: flag to trigger deep save - by default only directly related models are saved with follow=True also related models of related models are saved :type follow: bool - :param visited: internal parameter for recursive calls - already visited models - :type visited: Set :param update_count: internal parameter for recursive calls - number of updated instances :type update_count: int :return: tuple of update count and visited - :rtype: Tuple[int, Set] + :rtype: int """ - if follow and rel.__class__ not in visited: - update_count = await rel.save_related( - follow=follow, - save_all=save_all, - visited=visited, - update_count=update_count, - ) - if not rel.saved or save_all: - await rel.upsert() - update_count += 1 - return update_count, visited + if not isinstance(value, list): + value = [value] + + for val in value: + if follow: + update_count = await val.save_related( + follow=follow, + save_all=save_all, + relation_map=relation_map, + update_count=update_count, + ) + if not val.saved or save_all: + await val.upsert() + update_count += 1 + return update_count async def update(self: T, **kwargs: Any) -> T: """ diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index f5f00ac..5653f55 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -127,6 +127,41 @@ def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001 return current_dict +def subtract_dict(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001 + """ + Update one dict with another but with regard for nested keys. + + That way nested sets are unionised, dicts updated and + only other values are overwritten. + + :param current_dict: dict to update + :type current_dict: Dict[str, ellipsis] + :param updating_dict: dict with values to update + :type updating_dict: Dict + :return: combination of both dicts + :rtype: Dict + """ + if current_dict is Ellipsis: + return dict() + for key, value in updating_dict.items(): + old_key = current_dict.get(key, {}) + new_value: Optional[Union[Dict, Set]] = None + if not old_key: + continue + if isinstance(value, collections.abc.Mapping): + if isinstance(old_key, set): + old_key = convert_set_to_required_dict(old_key) + new_value = subtract_dict(old_key, value) + elif isinstance(value, set) and isinstance(old_key, set): + new_value = old_key.difference(value) + + if new_value: + current_dict[key] = new_value + else: + current_dict.pop(key, None) + return current_dict + + def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict: """ Converts the list into dictionary and later performs special update, where diff --git a/tests/test_model_methods/test_save_related.py b/tests/test_model_methods/test_save_related.py index 8609353..774d167 100644 --- a/tests/test_model_methods/test_save_related.py +++ b/tests/test_model_methods/test_save_related.py @@ -91,6 +91,14 @@ async def test_saving_related_fk_rel(): assert count == 1 assert comp.hq.saved + comp.hq.name = "Suburbs 2" + assert not comp.hq.saved + assert comp.saved + + count = await comp.save_related(exclude={"hq"}) + assert count == 0 + assert not comp.hq.saved + @pytest.mark.asyncio async def test_saving_many_to_many(): @@ -123,6 +131,16 @@ async def test_saving_many_to_many(): assert hq.nicks[0].saved assert hq.nicks[1].saved + hq.nicks[0].name = "Kabucha a" + hq.nicks[1].name = "Kabucha2 a" + assert not hq.nicks[0].saved + assert not hq.nicks[1].saved + + count = await hq.save_related(exclude={"nicks": ...}) + assert count == 0 + assert not hq.nicks[0].saved + assert not hq.nicks[1].saved + @pytest.mark.asyncio async def test_saving_reversed_relation(): @@ -211,3 +229,16 @@ async def test_saving_nested(): assert hq.nicks[0].level.saved assert hq.nicks[1].saved assert hq.nicks[1].level.saved + + hq.nicks[0].level.name = "Low 2" + hq.nicks[1].level.name = "Medium 2" + assert not hq.nicks[0].level.saved + assert not hq.nicks[1].level.saved + assert hq.nicks[0].saved + assert hq.nicks[1].saved + count = await hq.save_related(follow=True, exclude={"nicks": {"level"}}) + assert count == 0 + assert hq.nicks[0].saved + assert not hq.nicks[0].level.saved + assert hq.nicks[1].saved + assert not hq.nicks[1].level.saved diff --git a/tests/test_model_methods/test_saving_related.py b/tests/test_relations/test_saving_related.py similarity index 100% rename from tests/test_model_methods/test_saving_related.py rename to tests/test_relations/test_saving_related.py diff --git a/tests/test_utils/test_queryset_utils.py b/tests/test_utils/test_queryset_utils.py index cd96dc8..8fc9487 100644 --- a/tests/test_utils/test_queryset_utils.py +++ b/tests/test_utils/test_queryset_utils.py @@ -4,7 +4,12 @@ import sqlalchemy import ormar from ormar.models.mixins import ExcludableMixin from ormar.queryset.prefetch_query import sort_models -from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update +from ormar.queryset.utils import ( + subtract_dict, + translate_list_to_dict, + update_dict_from_list, + update, +) from tests.settings import DATABASE_URL @@ -79,6 +84,21 @@ def test_updating_dict_inc_set_with_dict(): } +def test_subtracting_dict_inc_set_with_dict(): + curr_dict = { + "aa": Ellipsis, + "bb": Ellipsis, + "cc": {"aa": {"xx", "yy"}, "bb": Ellipsis}, + } + dict_to_update = { + "uu": Ellipsis, + "bb": {"cc", "dd"}, + "cc": {"aa": {"xx": {"oo": Ellipsis}}, "bb": Ellipsis}, + } + test = subtract_dict(curr_dict, dict_to_update) + assert test == {"aa": Ellipsis, "cc": {"aa": {"yy": Ellipsis}}} + + def test_updating_dict_inc_set_with_dict_inc_set(): curr_dict = { "aa": Ellipsis, @@ -99,6 +119,23 @@ def test_updating_dict_inc_set_with_dict_inc_set(): } +def test_subtracting_dict_inc_set_with_dict_inc_set(): + curr_dict = { + "aa": Ellipsis, + "bb": Ellipsis, + "cc": {"aa": {"xx", "yy"}, "bb": Ellipsis}, + "dd": {"aa", "bb"}, + } + dict_to_update = { + "aa": Ellipsis, + "bb": {"cc", "dd"}, + "cc": {"aa": {"xx", "oo", "zz", "ii"}}, + "dd": {"aa", "bb"}, + } + test = subtract_dict(curr_dict, dict_to_update) + assert test == {"cc": {"aa": {"yy"}, "bb": Ellipsis}} + + database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData()