From 844ecae8f91b5ebeadfdd3341a8e59ef55c9a123 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 29 Mar 2021 17:07:01 +0200 Subject: [PATCH] fix recursion limit for complicated models structures with many loops --- docs/releases.md | 1 + mkdocs.yml | 1 + ormar/models/mixins/relation_mixin.py | 66 +++------- ormar/models/traversible.py | 118 ++++++++++++++++++ pydoc-markdown.yml | 3 + tests/test_deep_relations_select_all.py | 158 ++++++++++++++++++++++++ 6 files changed, 297 insertions(+), 50 deletions(-) create mode 100644 ormar/models/traversible.py create mode 100644 tests/test_deep_relations_select_all.py diff --git a/docs/releases.md b/docs/releases.md index 210bb93..f88a694 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -10,6 +10,7 @@ ## Fixes * Fix improper relation field resolution in `QuerysetProxy` if fk column has different database alias. +* Fix hitting recursion error with very complicated models structure with loops. ## Other diff --git a/mkdocs.yml b/mkdocs.yml index 735c732..5432018 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -59,6 +59,7 @@ nav: - Model Table Proxy: api/models/model-table-proxy.md - Model Metaclass: api/models/model-metaclass.md - Excludable Items: api/models/excludable-items.md + - Traversible: api/models/traversible.md - Fields: - Base Field: api/fields/base-field.md - Model Fields: api/fields/model-fields.md diff --git a/ormar/models/mixins/relation_mixin.py b/ormar/models/mixins/relation_mixin.py index 2a20dcc..ee303e0 100644 --- a/ormar/models/mixins/relation_mixin.py +++ b/ormar/models/mixins/relation_mixin.py @@ -4,11 +4,10 @@ from typing import ( Optional, Set, TYPE_CHECKING, - Type, - Union, ) from ormar import BaseField +from ormar.models.traversible import NodeList class RelationMixin: @@ -17,7 +16,7 @@ class RelationMixin: """ if TYPE_CHECKING: # pragma no cover - from ormar import ModelMeta, Model + from ormar import ModelMeta Meta: ModelMeta _related_names: Optional[Set] @@ -135,61 +134,37 @@ class RelationMixin: @classmethod def _iterate_related_models( # noqa: CCR001 - cls, - visited: Set[str] = None, - source_visited: Set[str] = None, - source_relation: str = None, - source_model: Union[Type["Model"], Type["RelationMixin"]] = None, + cls, node_list: NodeList = None, source_relation: str = None ) -> List[str]: """ Iterates related models recursively to extract relation strings of nested not visited models. - :param visited: set of already visited models - :type visited: Set[str] - :param source_relation: name of the current relation - :type source_relation: str - :param source_model: model from which relation comes in nested relations - :type source_model: Type["Model"] :return: list of relation strings to be passed to select_related :rtype: List[str] """ - source_visited = source_visited or cls._populate_source_model_prefixes() + if not node_list: + node_list = NodeList() + current_node = node_list.add(node_class=cls) + else: + current_node = node_list[-1] relations = cls.extract_related_names() processed_relations = [] for relation in relations: - target_model = cls.Meta.model_fields[relation].to - if cls._is_reverse_side_of_same_relation(source_model, target_model): - continue - if target_model not in source_visited or not source_model: + if not current_node.visited(relation): + target_model = cls.Meta.model_fields[relation].to + node_list.add( + node_class=target_model, + relation_name=relation, + parent_node=current_node, + ) deep_relations = target_model._iterate_related_models( - visited=visited, - source_visited=source_visited, - source_relation=relation, - source_model=cls, + source_relation=relation, node_list=node_list ) processed_relations.extend(deep_relations) - else: - processed_relations.append(relation) return cls._get_final_relations(processed_relations, source_relation) - @staticmethod - def _is_reverse_side_of_same_relation( - source_model: Optional[Union[Type["Model"], Type["RelationMixin"]]], - target_model: Type["Model"], - ) -> bool: - """ - Alias to check if source model is the same as target - :param source_model: source model - relation comes from it - :type source_model: Type["Model"] - :param target_model: target model - relation leads to it - :type target_model: Type["Model"] - :return: result of the check - :rtype: bool - """ - return bool(source_model and target_model == source_model) - @staticmethod def _get_final_relations( processed_relations: List, source_relation: Optional[str] @@ -212,12 +187,3 @@ class RelationMixin: else: final_relations = [source_relation] if source_relation else [] return final_relations - - @classmethod - def _populate_source_model_prefixes(cls) -> Set: - relations = cls.extract_related_names() - visited = {cls} - for relation in relations: - target_model = cls.Meta.model_fields[relation].to - visited.add(target_model) - return visited diff --git a/ormar/models/traversible.py b/ormar/models/traversible.py new file mode 100644 index 0000000..3a90bb9 --- /dev/null +++ b/ormar/models/traversible.py @@ -0,0 +1,118 @@ +from typing import Any, List, Optional, TYPE_CHECKING, Type + +if TYPE_CHECKING: # pragma no cover + from ormar.models.mixins.relation_mixin import RelationMixin + + +class NodeList: + """ + Helper class that helps with iterating nested models + """ + + def __init__(self) -> None: + self.node_list: List["Node"] = [] + + def __getitem__(self, item: Any) -> Any: + return self.node_list.__getitem__(item) + + def add( + self, + node_class: Type["RelationMixin"], + relation_name: str = None, + parent_node: "Node" = None, + ) -> "Node": + """ + Adds new Node or returns the existing one + + :param node_class: Model in current node + :type node_class: ormar.models.metaclass.ModelMetaclass + :param relation_name: name of the current relation + :type relation_name: str + :param parent_node: parent node + :type parent_node: Optional[Node] + :return: returns new or already existing node + :rtype: Node + """ + existing_node = self.find( + relation_name=relation_name, node_class=node_class, parent_node=parent_node + ) + if not existing_node: + current_node = Node( + node_class=node_class, + relation_name=relation_name, + parent_node=parent_node, + ) + self.node_list.append(current_node) + return current_node + return existing_node # pragma: no cover + + def find( + self, + node_class: Type["RelationMixin"], + relation_name: Optional[str] = None, + parent_node: "Node" = None, + ) -> Optional["Node"]: + """ + Searches for existing node with given parameters + + :param node_class: Model in current node + :type node_class: ormar.models.metaclass.ModelMetaclass + :param relation_name: name of the current relation + :type relation_name: str + :param parent_node: parent node + :type parent_node: Optional[Node] + :return: returns already existing node or None + :rtype: Optional[Node] + """ + for node in self.node_list: + if ( + node.node_class == node_class + and node.parent_node == parent_node + and node.relation_name == relation_name + ): + return node # pragma: no cover + return None + + +class Node: + def __init__( + self, + node_class: Type["RelationMixin"], + relation_name: str = None, + parent_node: "Node" = None, + ) -> None: + self.relation_name = relation_name + self.node_class = node_class + self.parent_node = parent_node + self.visited_children: List["Node"] = [] + if self.parent_node: + self.parent_node.visited_children.append(self) + + def __repr__(self) -> str: # pragma: no cover + return ( + f"{self.node_class.get_name(lower=False)}, " + f"relation:{self.relation_name}, " + f"parent: {self.parent_node}" + ) + + def visited(self, relation_name: str) -> bool: + """ + Checks if given relation was already visited. + + Relation was visited if it's name is in current node children. + + Relation was visited if one of the parent node had the same Model class + + :param relation_name: name of relation + :type relation_name: str + :return: result of the check + :rtype: bool + """ + target_model = self.node_class.Meta.model_fields[relation_name].to + if self.parent_node: + node = self + while node.parent_node: + node = node.parent_node + if node.node_class == target_model: + return True + return False diff --git a/pydoc-markdown.yml b/pydoc-markdown.yml index 6f0188f..7243a2a 100644 --- a/pydoc-markdown.yml +++ b/pydoc-markdown.yml @@ -30,6 +30,9 @@ renderer: - title: Excludable Items contents: - models.excludable.* + - title: Traversible + contents: + - models.traversible.* - title: Model Table Proxy contents: - models.modelproxy.* diff --git a/tests/test_deep_relations_select_all.py b/tests/test_deep_relations_select_all.py new file mode 100644 index 0000000..948b81f --- /dev/null +++ b/tests/test_deep_relations_select_all.py @@ -0,0 +1,158 @@ +import databases +import pytest +from sqlalchemy import func + +import ormar +import sqlalchemy +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Chart(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "charts" + database = database + metadata = metadata + + chart_id = ormar.Integer(primary_key=True, autoincrement=True) + name = ormar.String(max_length=200, unique=True, index=True) + query_text = ormar.Text() + datasets = ormar.JSON() + layout = ormar.JSON() + data_config = ormar.JSON() + created_date = ormar.DateTime(server_default=func.now()) + library = ormar.String(max_length=200, default="plotly") + used_filters = ormar.JSON() + + +class Report(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "reports" + database = database + metadata = metadata + + report_id = ormar.Integer(primary_key=True, autoincrement=True) + name = ormar.String(max_length=200, unique=True, index=True) + filters_position = ormar.String(max_length=200) + created_date = ormar.DateTime(server_default=func.now()) + + +class Language(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "languages" + database = database + metadata = metadata + + language_id = ormar.Integer(primary_key=True, autoincrement=True) + code = ormar.String(max_length=5) + name = ormar.String(max_length=200) + + +class TranslationNode(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "translation_nodes" + database = database + metadata = metadata + + node_id = ormar.Integer(primary_key=True, autoincrement=True) + node_type = ormar.String(max_length=200) + + +class Translation(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "translations" + database = database + metadata = metadata + + translation_id = ormar.Integer(primary_key=True, autoincrement=True) + node_id = ormar.ForeignKey(TranslationNode, related_name="translations") + language = ormar.ForeignKey(Language, name="language_id") + value = ormar.String(max_length=500) + + +class Filter(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "filters" + database = database + metadata = metadata + + filter_id = ormar.Integer(primary_key=True, autoincrement=True) + name = ormar.String(max_length=200, unique=True, index=True) + label = ormar.String(max_length=200) + query_text = ormar.Text() + allow_multiselect = ormar.Boolean(default=True) + created_date = ormar.DateTime(server_default=func.now()) + is_dynamic = ormar.Boolean(default=True) + is_date = ormar.Boolean(default=False) + translation = ormar.ForeignKey(TranslationNode, name="translation_node_id") + + +class FilterValue(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "filter_values" + database = database + metadata = metadata + + value_id = ormar.Integer(primary_key=True, autoincrement=True) + value = ormar.String(max_length=300) + label = ormar.String(max_length=300) + filter = ormar.ForeignKey(Filter, name="filter_id", related_name="values") + translation = ormar.ForeignKey(TranslationNode, name="translation_node_id") + + +class FilterXReport(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "filters_x_reports" + database = database + metadata = metadata + + filter_x_report_id = ormar.Integer(primary_key=True) + filter = ormar.ForeignKey(Filter, name="filter_id", related_name="reports") + report = ormar.ForeignKey(Report, name="report_id", related_name="filters") + sort_order = ormar.Integer() + default_value = ormar.Text() + is_visible = ormar.Boolean() + + +class ChartXReport(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "charts_x_reports" + database = database + metadata = metadata + + chart_x_report_id = ormar.Integer(primary_key=True) + chart = ormar.ForeignKey(Chart, name="chart_id", related_name="reports") + report = ormar.ForeignKey(Report, name="report_id", related_name="charts") + sort_order = ormar.Integer() + width = ormar.Integer() + + +class ChartColumn(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "charts_columns" + database = database + metadata = metadata + + column_id = ormar.Integer(primary_key=True, autoincrement=True) + chart = ormar.ForeignKey(Chart, name="chart_id", related_name="columns") + column_name = ormar.String(max_length=200) + column_type = ormar.String(max_length=200) + translation = ormar.ForeignKey(TranslationNode, name="translation_node_id") + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_saving_related_fk_rel(): + async with database: + async with database.transaction(force_rollback=True): + await Report.objects.select_all(follow=True).all()