add exclude to save_related method, switch to same relation_map from iter

This commit is contained in:
collerek
2021-03-30 16:26:10 +02:00
parent da05e5ba1d
commit 64d3d1b693
6 changed files with 163 additions and 54 deletions

View File

@ -2,19 +2,22 @@
## Features ## Features
* `save_related(follow=False)` now accept also second argument `save_related(follow=False, save_all=False)`. * `Model.save_related(follow=False)` now accept also two additional arguments: `Model.save_related(follow=False, save_all=False, exclude=None)`.
By default so with `save_all=False` `ormar` only upserts models that are no saved (so new or updated ones), * `save_all:bool` -> By default (so with `save_all=False`) `ormar` only upserts models that are not saved (so new or updated ones),
with `save_all=True` all related models are saved, regardless of `saved` status, which might be useful if updated with `save_all=True` all related models are saved, regardless of `saved` status, which might be useful if updated
models comes from api call, so are not changed in backend. models comes from api call, so are not changed in the backend.
* `exclude: Union[Set, Dict, None]` -> set/dict of relations to exclude from save, those relation won't be saved even with `follow=True` and `save_all=True`.
To exclude nested relations pass a nested dictionary like: `exclude={"child":{"sub_child": {"exclude_sub_child_realtion"}}}`. The allowed values follow
the `fields/exclude_fields` (from `QuerySet`) methods schema so when in doubt you can refer to docs in queries -> selecting subset of fields -> fields.
* `dict()` method previously included only directly related models or nested models if they were not nullable and not virtual, * `dict()` method previously included only directly related models or nested models if they were not nullable and not virtual,
now all related models not previosuly visited without loops are included in `dict()`. This should be not breaking now all related models not previously visited without loops are included in `dict()`. This should be not breaking
as just more data will be dumped to dict, but it should not be missing. as just more data will be dumped to dict, but it should not be missing.
## Fixes ## Fixes
* Fix improper relation field resolution in `QuerysetProxy` if fk column has different database alias. * Fix improper relation field resolution in `QuerysetProxy` if fk column has different database alias.
* Fix hitting recursion error with very complicated models structure with loops. * Fix hitting recursion error with very complicated models structure with loops when calling `dict()`.
## Other ## Other

View File

@ -4,7 +4,6 @@ from typing import (
List, List,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Tuple,
TypeVar, TypeVar,
Union, Union,
) )
@ -14,7 +13,7 @@ from ormar.exceptions import ModelPersistenceError, NoMatch
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
from ormar.models.model_row import ModelRow from ormar.models.model_row import ModelRow
from ormar.queryset.utils import subtract_dict, translate_list_to_dict
T = TypeVar("T", bound="Model") T = TypeVar("T", bound="Model")
@ -104,9 +103,10 @@ class Model(ModelRow):
self, self,
follow: bool = False, follow: bool = False,
save_all: bool = False, save_all: bool = False,
visited: Set = None, relation_map: Dict = None,
exclude: Union[Set, Dict] = None,
update_count: int = 0, update_count: int = 0,
) -> int: # noqa: CCR001 ) -> int:
""" """
Triggers a upsert method on all related models Triggers a upsert method on all related models
if the instances are not already saved. if the instances are not already saved.
@ -122,83 +122,86 @@ class Model(ModelRow):
Model A but will never follow into Model C. Model A but will never follow into Model C.
Nested relations of those kind need to be persisted manually. Nested relations of those kind need to be persisted manually.
:param exclude: items to exclude during saving of relations
:type exclude: Union[Set, Dict]
:param relation_map: map of relations to follow
:type relation_map: Dict
:param save_all: flag if all models should be saved or only not saved ones
:type save_all: bool
:param follow: flag to trigger deep save - :param follow: flag to trigger deep save -
by default only directly related models are saved by default only directly related models are saved
with follow=True also related models of related models are saved with follow=True also related models of related models are saved
:type follow: bool :type follow: bool
:param visited: internal parameter for recursive calls - already visited models
:type visited: Set
:param update_count: internal parameter for recursive calls - :param update_count: internal parameter for recursive calls -
number of updated instances number of updated instances
:type update_count: int :type update_count: int
:return: number of updated/saved models :return: number of updated/saved models
:rtype: int :rtype: int
""" """
if not visited: relation_map = (
visited = {self.__class__} relation_map
else: if relation_map is not None
visited = {x for x in visited} else translate_list_to_dict(self._iterate_related_models())
visited.add(self.__class__) )
if exclude and isinstance(exclude, Set):
exclude = translate_list_to_dict(exclude)
relation_map = subtract_dict(relation_map, exclude or {})
for related in self.extract_related_names(): for related in self.extract_related_names():
if ( if relation_map and related in relation_map:
self.Meta.model_fields[related].virtual value = getattr(self, related)
or self.Meta.model_fields[related].is_multi update_count = await self._update_and_follow(
): value=value,
for rel in getattr(self, related):
update_count, visited = await self._update_and_follow(
rel=rel,
follow=follow,
save_all=save_all,
visited=visited,
update_count=update_count,
)
visited.add(self.Meta.model_fields[related].to)
else:
rel = getattr(self, related)
update_count, visited = await self._update_and_follow(
rel=rel,
follow=follow, follow=follow,
save_all=save_all, save_all=save_all,
visited=visited, relation_map=self._skip_ellipsis( # type: ignore
relation_map, related, default_return={}
),
update_count=update_count, update_count=update_count,
) )
visited.add(rel.__class__)
return update_count return update_count
@staticmethod @staticmethod
async def _update_and_follow( async def _update_and_follow(
rel: "Model", follow: bool, save_all: bool, visited: Set, update_count: int value: Union["Model", List["Model"]],
) -> Tuple[int, Set]: follow: bool,
save_all: bool,
relation_map: Dict,
update_count: int,
) -> int:
""" """
Internal method used in save_related to follow related models and update numbers Internal method used in save_related to follow related models and update numbers
of updated related instances. of updated related instances.
:param rel: Model to follow :param value: Model to follow
:type rel: Model :type value: Model
:param relation_map: map of relations to follow
:type relation_map: Dict
:param follow: flag to trigger deep save - :param follow: flag to trigger deep save -
by default only directly related models are saved by default only directly related models are saved
with follow=True also related models of related models are saved with follow=True also related models of related models are saved
:type follow: bool :type follow: bool
:param visited: internal parameter for recursive calls - already visited models
:type visited: Set
:param update_count: internal parameter for recursive calls - :param update_count: internal parameter for recursive calls -
number of updated instances number of updated instances
:type update_count: int :type update_count: int
:return: tuple of update count and visited :return: tuple of update count and visited
:rtype: Tuple[int, Set] :rtype: int
""" """
if follow and rel.__class__ not in visited: if not isinstance(value, list):
update_count = await rel.save_related( value = [value]
follow=follow,
save_all=save_all, for val in value:
visited=visited, if follow:
update_count=update_count, update_count = await val.save_related(
) follow=follow,
if not rel.saved or save_all: save_all=save_all,
await rel.upsert() relation_map=relation_map,
update_count += 1 update_count=update_count,
return update_count, visited )
if not val.saved or save_all:
await val.upsert()
update_count += 1
return update_count
async def update(self: T, **kwargs: Any) -> T: async def update(self: T, **kwargs: Any) -> T:
""" """

View File

@ -127,6 +127,41 @@ def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
return current_dict return current_dict
def subtract_dict(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
"""
Update one dict with another but with regard for nested keys.
That way nested sets are unionised, dicts updated and
only other values are overwritten.
:param current_dict: dict to update
:type current_dict: Dict[str, ellipsis]
:param updating_dict: dict with values to update
:type updating_dict: Dict
:return: combination of both dicts
:rtype: Dict
"""
if current_dict is Ellipsis:
return dict()
for key, value in updating_dict.items():
old_key = current_dict.get(key, {})
new_value: Optional[Union[Dict, Set]] = None
if not old_key:
continue
if isinstance(value, collections.abc.Mapping):
if isinstance(old_key, set):
old_key = convert_set_to_required_dict(old_key)
new_value = subtract_dict(old_key, value)
elif isinstance(value, set) and isinstance(old_key, set):
new_value = old_key.difference(value)
if new_value:
current_dict[key] = new_value
else:
current_dict.pop(key, None)
return current_dict
def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict: def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict:
""" """
Converts the list into dictionary and later performs special update, where Converts the list into dictionary and later performs special update, where

View File

@ -91,6 +91,14 @@ async def test_saving_related_fk_rel():
assert count == 1 assert count == 1
assert comp.hq.saved assert comp.hq.saved
comp.hq.name = "Suburbs 2"
assert not comp.hq.saved
assert comp.saved
count = await comp.save_related(exclude={"hq"})
assert count == 0
assert not comp.hq.saved
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_saving_many_to_many(): async def test_saving_many_to_many():
@ -123,6 +131,16 @@ async def test_saving_many_to_many():
assert hq.nicks[0].saved assert hq.nicks[0].saved
assert hq.nicks[1].saved assert hq.nicks[1].saved
hq.nicks[0].name = "Kabucha a"
hq.nicks[1].name = "Kabucha2 a"
assert not hq.nicks[0].saved
assert not hq.nicks[1].saved
count = await hq.save_related(exclude={"nicks": ...})
assert count == 0
assert not hq.nicks[0].saved
assert not hq.nicks[1].saved
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_saving_reversed_relation(): async def test_saving_reversed_relation():
@ -211,3 +229,16 @@ async def test_saving_nested():
assert hq.nicks[0].level.saved assert hq.nicks[0].level.saved
assert hq.nicks[1].saved assert hq.nicks[1].saved
assert hq.nicks[1].level.saved assert hq.nicks[1].level.saved
hq.nicks[0].level.name = "Low 2"
hq.nicks[1].level.name = "Medium 2"
assert not hq.nicks[0].level.saved
assert not hq.nicks[1].level.saved
assert hq.nicks[0].saved
assert hq.nicks[1].saved
count = await hq.save_related(follow=True, exclude={"nicks": {"level"}})
assert count == 0
assert hq.nicks[0].saved
assert not hq.nicks[0].level.saved
assert hq.nicks[1].saved
assert not hq.nicks[1].level.saved

View File

@ -4,7 +4,12 @@ import sqlalchemy
import ormar import ormar
from ormar.models.mixins import ExcludableMixin from ormar.models.mixins import ExcludableMixin
from ormar.queryset.prefetch_query import sort_models from ormar.queryset.prefetch_query import sort_models
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update from ormar.queryset.utils import (
subtract_dict,
translate_list_to_dict,
update_dict_from_list,
update,
)
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -79,6 +84,21 @@ def test_updating_dict_inc_set_with_dict():
} }
def test_subtracting_dict_inc_set_with_dict():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
}
dict_to_update = {
"uu": Ellipsis,
"bb": {"cc", "dd"},
"cc": {"aa": {"xx": {"oo": Ellipsis}}, "bb": Ellipsis},
}
test = subtract_dict(curr_dict, dict_to_update)
assert test == {"aa": Ellipsis, "cc": {"aa": {"yy": Ellipsis}}}
def test_updating_dict_inc_set_with_dict_inc_set(): def test_updating_dict_inc_set_with_dict_inc_set():
curr_dict = { curr_dict = {
"aa": Ellipsis, "aa": Ellipsis,
@ -99,6 +119,23 @@ def test_updating_dict_inc_set_with_dict_inc_set():
} }
def test_subtracting_dict_inc_set_with_dict_inc_set():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
"dd": {"aa", "bb"},
}
dict_to_update = {
"aa": Ellipsis,
"bb": {"cc", "dd"},
"cc": {"aa": {"xx", "oo", "zz", "ii"}},
"dd": {"aa", "bb"},
}
test = subtract_dict(curr_dict, dict_to_update)
assert test == {"cc": {"aa": {"yy"}, "bb": Ellipsis}}
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()