add exclude to save_related method, switch to same relation_map from iter
This commit is contained in:
@ -2,19 +2,22 @@
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
* `save_related(follow=False)` now accept also second argument `save_related(follow=False, save_all=False)`.
|
* `Model.save_related(follow=False)` now accept also two additional arguments: `Model.save_related(follow=False, save_all=False, exclude=None)`.
|
||||||
By default so with `save_all=False` `ormar` only upserts models that are no saved (so new or updated ones),
|
* `save_all:bool` -> By default (so with `save_all=False`) `ormar` only upserts models that are not saved (so new or updated ones),
|
||||||
with `save_all=True` all related models are saved, regardless of `saved` status, which might be useful if updated
|
with `save_all=True` all related models are saved, regardless of `saved` status, which might be useful if updated
|
||||||
models comes from api call, so are not changed in backend.
|
models comes from api call, so are not changed in the backend.
|
||||||
|
* `exclude: Union[Set, Dict, None]` -> set/dict of relations to exclude from save, those relation won't be saved even with `follow=True` and `save_all=True`.
|
||||||
|
To exclude nested relations pass a nested dictionary like: `exclude={"child":{"sub_child": {"exclude_sub_child_realtion"}}}`. The allowed values follow
|
||||||
|
the `fields/exclude_fields` (from `QuerySet`) methods schema so when in doubt you can refer to docs in queries -> selecting subset of fields -> fields.
|
||||||
* `dict()` method previously included only directly related models or nested models if they were not nullable and not virtual,
|
* `dict()` method previously included only directly related models or nested models if they were not nullable and not virtual,
|
||||||
now all related models not previosuly visited without loops are included in `dict()`. This should be not breaking
|
now all related models not previously visited without loops are included in `dict()`. This should be not breaking
|
||||||
as just more data will be dumped to dict, but it should not be missing.
|
as just more data will be dumped to dict, but it should not be missing.
|
||||||
|
|
||||||
|
|
||||||
## Fixes
|
## Fixes
|
||||||
|
|
||||||
* Fix improper relation field resolution in `QuerysetProxy` if fk column has different database alias.
|
* Fix improper relation field resolution in `QuerysetProxy` if fk column has different database alias.
|
||||||
* Fix hitting recursion error with very complicated models structure with loops.
|
* Fix hitting recursion error with very complicated models structure with loops when calling `dict()`.
|
||||||
|
|
||||||
## Other
|
## Other
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Set,
|
Set,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Tuple,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
@ -14,7 +13,7 @@ from ormar.exceptions import ModelPersistenceError, NoMatch
|
|||||||
from ormar.models import NewBaseModel # noqa I100
|
from ormar.models import NewBaseModel # noqa I100
|
||||||
from ormar.models.metaclass import ModelMeta
|
from ormar.models.metaclass import ModelMeta
|
||||||
from ormar.models.model_row import ModelRow
|
from ormar.models.model_row import ModelRow
|
||||||
|
from ormar.queryset.utils import subtract_dict, translate_list_to_dict
|
||||||
|
|
||||||
T = TypeVar("T", bound="Model")
|
T = TypeVar("T", bound="Model")
|
||||||
|
|
||||||
@ -104,9 +103,10 @@ class Model(ModelRow):
|
|||||||
self,
|
self,
|
||||||
follow: bool = False,
|
follow: bool = False,
|
||||||
save_all: bool = False,
|
save_all: bool = False,
|
||||||
visited: Set = None,
|
relation_map: Dict = None,
|
||||||
|
exclude: Union[Set, Dict] = None,
|
||||||
update_count: int = 0,
|
update_count: int = 0,
|
||||||
) -> int: # noqa: CCR001
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Triggers a upsert method on all related models
|
Triggers a upsert method on all related models
|
||||||
if the instances are not already saved.
|
if the instances are not already saved.
|
||||||
@ -122,83 +122,86 @@ class Model(ModelRow):
|
|||||||
Model A but will never follow into Model C.
|
Model A but will never follow into Model C.
|
||||||
Nested relations of those kind need to be persisted manually.
|
Nested relations of those kind need to be persisted manually.
|
||||||
|
|
||||||
|
:param exclude: items to exclude during saving of relations
|
||||||
|
:type exclude: Union[Set, Dict]
|
||||||
|
:param relation_map: map of relations to follow
|
||||||
|
:type relation_map: Dict
|
||||||
|
:param save_all: flag if all models should be saved or only not saved ones
|
||||||
|
:type save_all: bool
|
||||||
:param follow: flag to trigger deep save -
|
:param follow: flag to trigger deep save -
|
||||||
by default only directly related models are saved
|
by default only directly related models are saved
|
||||||
with follow=True also related models of related models are saved
|
with follow=True also related models of related models are saved
|
||||||
:type follow: bool
|
:type follow: bool
|
||||||
:param visited: internal parameter for recursive calls - already visited models
|
|
||||||
:type visited: Set
|
|
||||||
:param update_count: internal parameter for recursive calls -
|
:param update_count: internal parameter for recursive calls -
|
||||||
number of updated instances
|
number of updated instances
|
||||||
:type update_count: int
|
:type update_count: int
|
||||||
:return: number of updated/saved models
|
:return: number of updated/saved models
|
||||||
:rtype: int
|
:rtype: int
|
||||||
"""
|
"""
|
||||||
if not visited:
|
relation_map = (
|
||||||
visited = {self.__class__}
|
relation_map
|
||||||
else:
|
if relation_map is not None
|
||||||
visited = {x for x in visited}
|
else translate_list_to_dict(self._iterate_related_models())
|
||||||
visited.add(self.__class__)
|
)
|
||||||
|
if exclude and isinstance(exclude, Set):
|
||||||
|
exclude = translate_list_to_dict(exclude)
|
||||||
|
relation_map = subtract_dict(relation_map, exclude or {})
|
||||||
|
|
||||||
for related in self.extract_related_names():
|
for related in self.extract_related_names():
|
||||||
if (
|
if relation_map and related in relation_map:
|
||||||
self.Meta.model_fields[related].virtual
|
value = getattr(self, related)
|
||||||
or self.Meta.model_fields[related].is_multi
|
update_count = await self._update_and_follow(
|
||||||
):
|
value=value,
|
||||||
for rel in getattr(self, related):
|
|
||||||
update_count, visited = await self._update_and_follow(
|
|
||||||
rel=rel,
|
|
||||||
follow=follow,
|
|
||||||
save_all=save_all,
|
|
||||||
visited=visited,
|
|
||||||
update_count=update_count,
|
|
||||||
)
|
|
||||||
visited.add(self.Meta.model_fields[related].to)
|
|
||||||
else:
|
|
||||||
rel = getattr(self, related)
|
|
||||||
update_count, visited = await self._update_and_follow(
|
|
||||||
rel=rel,
|
|
||||||
follow=follow,
|
follow=follow,
|
||||||
save_all=save_all,
|
save_all=save_all,
|
||||||
visited=visited,
|
relation_map=self._skip_ellipsis( # type: ignore
|
||||||
|
relation_map, related, default_return={}
|
||||||
|
),
|
||||||
update_count=update_count,
|
update_count=update_count,
|
||||||
)
|
)
|
||||||
visited.add(rel.__class__)
|
|
||||||
return update_count
|
return update_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _update_and_follow(
|
async def _update_and_follow(
|
||||||
rel: "Model", follow: bool, save_all: bool, visited: Set, update_count: int
|
value: Union["Model", List["Model"]],
|
||||||
) -> Tuple[int, Set]:
|
follow: bool,
|
||||||
|
save_all: bool,
|
||||||
|
relation_map: Dict,
|
||||||
|
update_count: int,
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Internal method used in save_related to follow related models and update numbers
|
Internal method used in save_related to follow related models and update numbers
|
||||||
of updated related instances.
|
of updated related instances.
|
||||||
|
|
||||||
:param rel: Model to follow
|
:param value: Model to follow
|
||||||
:type rel: Model
|
:type value: Model
|
||||||
|
:param relation_map: map of relations to follow
|
||||||
|
:type relation_map: Dict
|
||||||
:param follow: flag to trigger deep save -
|
:param follow: flag to trigger deep save -
|
||||||
by default only directly related models are saved
|
by default only directly related models are saved
|
||||||
with follow=True also related models of related models are saved
|
with follow=True also related models of related models are saved
|
||||||
:type follow: bool
|
:type follow: bool
|
||||||
:param visited: internal parameter for recursive calls - already visited models
|
|
||||||
:type visited: Set
|
|
||||||
:param update_count: internal parameter for recursive calls -
|
:param update_count: internal parameter for recursive calls -
|
||||||
number of updated instances
|
number of updated instances
|
||||||
:type update_count: int
|
:type update_count: int
|
||||||
:return: tuple of update count and visited
|
:return: tuple of update count and visited
|
||||||
:rtype: Tuple[int, Set]
|
:rtype: int
|
||||||
"""
|
"""
|
||||||
if follow and rel.__class__ not in visited:
|
if not isinstance(value, list):
|
||||||
update_count = await rel.save_related(
|
value = [value]
|
||||||
follow=follow,
|
|
||||||
save_all=save_all,
|
for val in value:
|
||||||
visited=visited,
|
if follow:
|
||||||
update_count=update_count,
|
update_count = await val.save_related(
|
||||||
)
|
follow=follow,
|
||||||
if not rel.saved or save_all:
|
save_all=save_all,
|
||||||
await rel.upsert()
|
relation_map=relation_map,
|
||||||
update_count += 1
|
update_count=update_count,
|
||||||
return update_count, visited
|
)
|
||||||
|
if not val.saved or save_all:
|
||||||
|
await val.upsert()
|
||||||
|
update_count += 1
|
||||||
|
return update_count
|
||||||
|
|
||||||
async def update(self: T, **kwargs: Any) -> T:
|
async def update(self: T, **kwargs: Any) -> T:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -127,6 +127,41 @@ def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
|
|||||||
return current_dict
|
return current_dict
|
||||||
|
|
||||||
|
|
||||||
|
def subtract_dict(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
|
||||||
|
"""
|
||||||
|
Update one dict with another but with regard for nested keys.
|
||||||
|
|
||||||
|
That way nested sets are unionised, dicts updated and
|
||||||
|
only other values are overwritten.
|
||||||
|
|
||||||
|
:param current_dict: dict to update
|
||||||
|
:type current_dict: Dict[str, ellipsis]
|
||||||
|
:param updating_dict: dict with values to update
|
||||||
|
:type updating_dict: Dict
|
||||||
|
:return: combination of both dicts
|
||||||
|
:rtype: Dict
|
||||||
|
"""
|
||||||
|
if current_dict is Ellipsis:
|
||||||
|
return dict()
|
||||||
|
for key, value in updating_dict.items():
|
||||||
|
old_key = current_dict.get(key, {})
|
||||||
|
new_value: Optional[Union[Dict, Set]] = None
|
||||||
|
if not old_key:
|
||||||
|
continue
|
||||||
|
if isinstance(value, collections.abc.Mapping):
|
||||||
|
if isinstance(old_key, set):
|
||||||
|
old_key = convert_set_to_required_dict(old_key)
|
||||||
|
new_value = subtract_dict(old_key, value)
|
||||||
|
elif isinstance(value, set) and isinstance(old_key, set):
|
||||||
|
new_value = old_key.difference(value)
|
||||||
|
|
||||||
|
if new_value:
|
||||||
|
current_dict[key] = new_value
|
||||||
|
else:
|
||||||
|
current_dict.pop(key, None)
|
||||||
|
return current_dict
|
||||||
|
|
||||||
|
|
||||||
def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict:
|
def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Converts the list into dictionary and later performs special update, where
|
Converts the list into dictionary and later performs special update, where
|
||||||
|
|||||||
@ -91,6 +91,14 @@ async def test_saving_related_fk_rel():
|
|||||||
assert count == 1
|
assert count == 1
|
||||||
assert comp.hq.saved
|
assert comp.hq.saved
|
||||||
|
|
||||||
|
comp.hq.name = "Suburbs 2"
|
||||||
|
assert not comp.hq.saved
|
||||||
|
assert comp.saved
|
||||||
|
|
||||||
|
count = await comp.save_related(exclude={"hq"})
|
||||||
|
assert count == 0
|
||||||
|
assert not comp.hq.saved
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_saving_many_to_many():
|
async def test_saving_many_to_many():
|
||||||
@ -123,6 +131,16 @@ async def test_saving_many_to_many():
|
|||||||
assert hq.nicks[0].saved
|
assert hq.nicks[0].saved
|
||||||
assert hq.nicks[1].saved
|
assert hq.nicks[1].saved
|
||||||
|
|
||||||
|
hq.nicks[0].name = "Kabucha a"
|
||||||
|
hq.nicks[1].name = "Kabucha2 a"
|
||||||
|
assert not hq.nicks[0].saved
|
||||||
|
assert not hq.nicks[1].saved
|
||||||
|
|
||||||
|
count = await hq.save_related(exclude={"nicks": ...})
|
||||||
|
assert count == 0
|
||||||
|
assert not hq.nicks[0].saved
|
||||||
|
assert not hq.nicks[1].saved
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_saving_reversed_relation():
|
async def test_saving_reversed_relation():
|
||||||
@ -211,3 +229,16 @@ async def test_saving_nested():
|
|||||||
assert hq.nicks[0].level.saved
|
assert hq.nicks[0].level.saved
|
||||||
assert hq.nicks[1].saved
|
assert hq.nicks[1].saved
|
||||||
assert hq.nicks[1].level.saved
|
assert hq.nicks[1].level.saved
|
||||||
|
|
||||||
|
hq.nicks[0].level.name = "Low 2"
|
||||||
|
hq.nicks[1].level.name = "Medium 2"
|
||||||
|
assert not hq.nicks[0].level.saved
|
||||||
|
assert not hq.nicks[1].level.saved
|
||||||
|
assert hq.nicks[0].saved
|
||||||
|
assert hq.nicks[1].saved
|
||||||
|
count = await hq.save_related(follow=True, exclude={"nicks": {"level"}})
|
||||||
|
assert count == 0
|
||||||
|
assert hq.nicks[0].saved
|
||||||
|
assert not hq.nicks[0].level.saved
|
||||||
|
assert hq.nicks[1].saved
|
||||||
|
assert not hq.nicks[1].level.saved
|
||||||
|
|||||||
@ -4,7 +4,12 @@ import sqlalchemy
|
|||||||
import ormar
|
import ormar
|
||||||
from ormar.models.mixins import ExcludableMixin
|
from ormar.models.mixins import ExcludableMixin
|
||||||
from ormar.queryset.prefetch_query import sort_models
|
from ormar.queryset.prefetch_query import sort_models
|
||||||
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
|
from ormar.queryset.utils import (
|
||||||
|
subtract_dict,
|
||||||
|
translate_list_to_dict,
|
||||||
|
update_dict_from_list,
|
||||||
|
update,
|
||||||
|
)
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
|
|
||||||
@ -79,6 +84,21 @@ def test_updating_dict_inc_set_with_dict():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_subtracting_dict_inc_set_with_dict():
|
||||||
|
curr_dict = {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": Ellipsis,
|
||||||
|
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
|
||||||
|
}
|
||||||
|
dict_to_update = {
|
||||||
|
"uu": Ellipsis,
|
||||||
|
"bb": {"cc", "dd"},
|
||||||
|
"cc": {"aa": {"xx": {"oo": Ellipsis}}, "bb": Ellipsis},
|
||||||
|
}
|
||||||
|
test = subtract_dict(curr_dict, dict_to_update)
|
||||||
|
assert test == {"aa": Ellipsis, "cc": {"aa": {"yy": Ellipsis}}}
|
||||||
|
|
||||||
|
|
||||||
def test_updating_dict_inc_set_with_dict_inc_set():
|
def test_updating_dict_inc_set_with_dict_inc_set():
|
||||||
curr_dict = {
|
curr_dict = {
|
||||||
"aa": Ellipsis,
|
"aa": Ellipsis,
|
||||||
@ -99,6 +119,23 @@ def test_updating_dict_inc_set_with_dict_inc_set():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_subtracting_dict_inc_set_with_dict_inc_set():
|
||||||
|
curr_dict = {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": Ellipsis,
|
||||||
|
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
|
||||||
|
"dd": {"aa", "bb"},
|
||||||
|
}
|
||||||
|
dict_to_update = {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": {"cc", "dd"},
|
||||||
|
"cc": {"aa": {"xx", "oo", "zz", "ii"}},
|
||||||
|
"dd": {"aa", "bb"},
|
||||||
|
}
|
||||||
|
test = subtract_dict(curr_dict, dict_to_update)
|
||||||
|
assert test == {"cc": {"aa": {"yy"}, "bb": Ellipsis}}
|
||||||
|
|
||||||
|
|
||||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
metadata = sqlalchemy.MetaData()
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user