From a940fcad6eb276fff1fc0c03b797ac195aefe551 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 3 Apr 2021 19:50:48 +0200 Subject: [PATCH] fix merging lists of deeply nested reverse relations --- docs/releases.md | 1 + ormar/models/mixins/merge_mixin.py | 98 ++++++++++++--- .../test_nested_reverse_relations.py | 117 ++++++++++++++++++ 3 files changed, 198 insertions(+), 18 deletions(-) create mode 100644 tests/test_queries/test_nested_reverse_relations.py diff --git a/docs/releases.md b/docs/releases.md index ce9daa7..c10e7c3 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -26,6 +26,7 @@ * Fix bug when two non-relation fields were merged (appended) in query result when they were not relation fields (i.e. JSON) * Fix bug when during translation to dict from list the same relation name is used in chain but leads to different models * Fix bug when bulk_create would try to save also `property_field` decorated methods and `pydantic` fields +* Fix wrong merging of deeply nested chain of reversed relations ## Other diff --git a/ormar/models/mixins/merge_mixin.py b/ormar/models/mixins/merge_mixin.py index e7234e4..524d5e5 100644 --- a/ormar/models/mixins/merge_mixin.py +++ b/ormar/models/mixins/merge_mixin.py @@ -1,7 +1,8 @@ from collections import OrderedDict -from typing import List, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING, cast import ormar +from ormar.queryset.utils import translate_list_to_dict if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -46,13 +47,17 @@ class MergeModelMixin: return merged_rows @classmethod - def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": + def merge_two_instances( + cls, one: "Model", other: "Model", relation_map: Dict = None + ) -> "Model": """ Merges current (other) Model and previous one (one) and returns the current Model instance with data merged from previous one. If needed it's calling itself recurrently and merges also children models. + :param relation_map: map of models relations to follow + :type relation_map: Dict :param one: previous model instance :type one: Model :param other: current model instance @@ -60,23 +65,80 @@ class MergeModelMixin: :return: current Model instance with data merged from previous one. :rtype: Model """ - for field_name, field in one.Meta.model_fields.items(): + relation_map = ( + relation_map + if relation_map is not None + else translate_list_to_dict(one._iterate_related_models()) + ) + for field_name in relation_map: current_field = getattr(one, field_name) - if field.is_relation: - if isinstance(current_field, list): - setattr( - other, field_name, current_field + getattr(other, field_name) - ) - elif ( - isinstance(current_field, ormar.Model) - and current_field.pk == getattr(other, field_name).pk - ): - setattr( - other, - field_name, - cls.merge_two_instances( - current_field, getattr(other, field_name) + other_value = getattr(other, field_name, []) + if isinstance(current_field, list): + value_to_set = cls._merge_items_lists( + field_name=field_name, + current_field=current_field, + other_value=other_value, + relation_map=relation_map, + ) + setattr(other, field_name, value_to_set) + elif ( + isinstance(current_field, ormar.Model) + and current_field.pk == other_value.pk + ): + setattr( + other, + field_name, + cls.merge_two_instances( + current_field, + other_value, + relation_map=one._skip_ellipsis( # type: ignore + relation_map, field_name, default_return=dict() ), - ) + ), + ) other.set_save_status(True) return other + + @classmethod + def _merge_items_lists( + cls, + field_name: str, + current_field: List, + other_value: List, + relation_map: Optional[Dict], + ) -> List: + """ + Takes two list of nested models and process them going deeper + according with the map. + + If model from one's list is in other -> they are merged with relations + to follow passed from map. + + If one's model is not in other it's simply appended to the list. + + :param field_name: name of the current relation field + :type field_name: str + :param current_field: list of nested models from one model + :type current_field: List[Model] + :param other_value: list of nested models from other model + :type other_value: List[Model] + :param relation_map: map of relations to follow + :type relation_map: Dict + :return: merged list of models + :rtype: List[Model] + """ + value_to_set = [x for x in other_value] + for cur_field in current_field: + if cur_field in other_value: + old_value = next((x for x in other_value if x == cur_field), None) + new_val = cls.merge_two_instances( + cur_field, + cast("Model", old_value), + relation_map=cur_field._skip_ellipsis( # type: ignore + relation_map, field_name, default_return=dict() + ), + ) + value_to_set = [x for x in value_to_set if x != cur_field] + [new_val] + else: + value_to_set.append(cur_field) + return value_to_set diff --git a/tests/test_queries/test_nested_reverse_relations.py b/tests/test_queries/test_nested_reverse_relations.py new file mode 100644 index 0000000..14b4544 --- /dev/null +++ b/tests/test_queries/test_nested_reverse_relations.py @@ -0,0 +1,117 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL) +metadata = sqlalchemy.MetaData() + + +class BaseMeta(ormar.ModelMeta): + metadata = metadata + database = database + + +class DataSource(ormar.Model): + class Meta(BaseMeta): + tablename = "datasources" + + source_id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=200, unique=True, index=True) + + +class DataSourceTable(ormar.Model): + class Meta(BaseMeta): + tablename = "datasource_tables" + + datasource_table_id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=200, index=True) + data_source: Optional[DataSource] = ormar.ForeignKey( + DataSource, + name="data_source_id", + related_name="datasource_tables", + ondelete="CASCADE", + ) + + +class DataSourceTableColumn(ormar.Model): + class Meta(BaseMeta): + tablename = "datasource_table_columns" + + datasource_table_column_id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=200, index=True) + data_type: str = ormar.String(max_length=200) + datasource_table: Optional[DataSourceTable] = ormar.ForeignKey( + DataSourceTable, + name="datasource_table_id", + related_name="datasource_table_columns", + ondelete="CASCADE", + ) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_double_nested_reverse_relation(): + async with database: + data_source = await DataSource(name="local").save() + test_tables = [ + { + "name": "test1", + "datasource_table_columns": [ + {"name": "col1", "data_type": "test"}, + {"name": "col2", "data_type": "test2"}, + {"name": "col3", "data_type": "test3"}, + ], + }, + { + "name": "test2", + "datasource_table_columns": [ + {"name": "col4", "data_type": "test"}, + {"name": "col5", "data_type": "test2"}, + {"name": "col6", "data_type": "test3"}, + ], + }, + ] + data_source.datasource_tables = test_tables + await data_source.save_related(save_all=True, follow=True) + + tables = await DataSourceTable.objects.all() + assert len(tables) == 2 + + columns = await DataSourceTableColumn.objects.all() + assert len(columns) == 6 + + data_source = ( + await DataSource.objects.select_related( + "datasource_tables__datasource_table_columns" + ) + .filter(datasource_tables__name__in=["test1", "test2"], name="local") + .get() + ) + assert len(data_source.datasource_tables) == 2 + assert len(data_source.datasource_tables[0].datasource_table_columns) == 3 + assert ( + data_source.datasource_tables[0].datasource_table_columns[0].name == "col1" + ) + assert ( + data_source.datasource_tables[0].datasource_table_columns[2].name == "col3" + ) + assert len(data_source.datasource_tables[1].datasource_table_columns) == 3 + assert ( + data_source.datasource_tables[1].datasource_table_columns[0].name == "col4" + ) + assert ( + data_source.datasource_tables[1].datasource_table_columns[2].name == "col6" + )