From 347c056c306e65bc114622f70fecf22786d4295d Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 30 Mar 2021 17:12:56 +0200 Subject: [PATCH] fix dict subtracting with dict and set --- ormar/queryset/utils.py | 22 +++++++++----- tests/test_utils/test_queryset_utils.py | 38 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index 5653f55..3ed7f93 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -141,19 +141,27 @@ def subtract_dict(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR00 :return: combination of both dicts :rtype: Dict """ - if current_dict is Ellipsis: - return dict() for key, value in updating_dict.items(): old_key = current_dict.get(key, {}) new_value: Optional[Union[Dict, Set]] = None if not old_key: continue - if isinstance(value, collections.abc.Mapping): - if isinstance(old_key, set): - old_key = convert_set_to_required_dict(old_key) - new_value = subtract_dict(old_key, value) - elif isinstance(value, set) and isinstance(old_key, set): + if isinstance(value, set) and isinstance(old_key, set): new_value = old_key.difference(value) + elif isinstance(value, (set, collections.abc.Mapping)) and isinstance( + old_key, (set, collections.abc.Mapping) + ): + value = ( + convert_set_to_required_dict(value) + if not isinstance(value, collections.abc.Mapping) + else value + ) + old_key = ( + convert_set_to_required_dict(old_key) + if not isinstance(old_key, collections.abc.Mapping) + else old_key + ) + new_value = subtract_dict(old_key, value) if new_value: current_dict[key] = new_value diff --git a/tests/test_utils/test_queryset_utils.py b/tests/test_utils/test_queryset_utils.py index 8fc9487..fe76092 100644 --- a/tests/test_utils/test_queryset_utils.py +++ b/tests/test_utils/test_queryset_utils.py @@ -136,6 +136,44 @@ def test_subtracting_dict_inc_set_with_dict_inc_set(): assert test == {"cc": {"aa": {"yy"}, "bb": Ellipsis}} +def test_subtracting_with_set_and_dict(): + curr_dict = { + "translation": { + "filters": { + "values": Ellipsis, + "reports": {"report": {"charts": {"chart": Ellipsis}}}, + }, + "translations": {"language": Ellipsis}, + "filtervalues": { + "filter": {"reports": {"report": {"charts": {"chart": Ellipsis}}}} + }, + }, + "chart": { + "reports": { + "report": { + "filters": { + "filter": { + "translation": { + "translations": {"language": Ellipsis}, + "filtervalues": Ellipsis, + }, + "values": { + "translation": {"translations": {"language": Ellipsis}} + }, + } + } + } + } + }, + } + dict_to_update = { + "chart": Ellipsis, + "translation": {"filters", "filtervalues", "chartcolumns"}, + } + test = subtract_dict(curr_dict, dict_to_update) + assert test == {"translation": {"translations": {"language": Ellipsis}}} + + database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData()