diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 30602ae..16bb23d 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -450,3 +450,24 @@ class ForeignKeyField(BaseField): value.__class__.__name__, cls._construct_model_from_pk )(value, child, to_register) return model + + @classmethod + def get_relation_name(cls) -> str: + """ + Returns name of the relation, which can be a own name or through model + names for m2m models + + :return: result of the check + :rtype: bool + """ + return cls.name + + @classmethod + def get_source_model(cls) -> Type["Model"]: + """ + Returns model from which the relation comes -> either owner or through model + + :return: source model + :rtype: Type["Model"] + """ + return cls.owner diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 2228121..ad7e6d9 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -187,3 +187,26 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro globalns, localns or None, ) + + @classmethod + def get_relation_name(cls) -> str: + """ + Returns name of the relation, which can be a own name or through model + names for m2m models + + :return: result of the check + :rtype: bool + """ + if cls.self_reference and cls.name == cls.self_reference_primary: + return cls.default_source_field_name() + return cls.default_target_field_name() + + @classmethod + def get_source_model(cls) -> Type["Model"]: + """ + Returns model from which the relation comes -> either owner or through model + + :return: source model + :rtype: Type["Model"] + """ + return cls.through diff --git a/ormar/models/mixins/excludable_mixin.py b/ormar/models/mixins/excludable_mixin.py index e474bd4..4b096d9 100644 --- a/ormar/models/mixins/excludable_mixin.py +++ b/ormar/models/mixins/excludable_mixin.py @@ -131,7 +131,9 @@ class ExcludableMixin(RelationMixin): @staticmethod def _populate_pk_column( - model: Type["Model"], columns: List[str], use_alias: bool = False, + model: Union[Type["Model"], Type["ModelRow"]], + columns: List[str], + use_alias: bool = False, ) -> List[str]: """ Adds primary key column/alias (depends on use_alias flag) to list of diff --git a/ormar/models/model_row.py b/ormar/models/model_row.py index 9d33a64..f184bb8 100644 --- a/ormar/models/model_row.py +++ b/ormar/models/model_row.py @@ -4,6 +4,7 @@ from typing import ( List, Optional, Set, + TYPE_CHECKING, Type, TypeVar, Union, @@ -17,20 +18,22 @@ from ormar.models.helpers.models import group_related_list T = TypeVar("T", bound="ModelRow") +if TYPE_CHECKING: + from ormar.fields import ForeignKeyField + class ModelRow(NewBaseModel): @classmethod - def from_row( # noqa CCR001 + def from_row( cls: Type[T], row: sqlalchemy.engine.ResultProxy, + source_model: Type[T], select_related: List = None, related_models: Any = None, - previous_model: Type[T] = None, - source_model: Type[T] = None, - related_name: str = None, + related_field: Type["ForeignKeyField"] = None, fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None, - current_relation_str: str = None, + current_relation_str: str = "", ) -> Optional[T]: """ Model method to convert raw sql row from database into ormar.Model instance. @@ -55,10 +58,8 @@ class ModelRow(NewBaseModel): :type select_related: List :param related_models: list or dict of related models :type related_models: Union[List, Dict] - :param previous_model: internal param for nested models to specify table_prefix - :type previous_model: Model class - :param related_name: internal parameter - name of current nested model - :type related_name: str + :param related_field: field with relation declaration + :type related_field: Type[ForeignKeyField] :param fields: fields and related model fields to include if provided only those are included :type fields: Optional[Union[Dict, Set]] @@ -77,35 +78,12 @@ class ModelRow(NewBaseModel): source_model = cls related_models = group_related_list(select_related) - rel_name2 = related_name - - # TODO: refactor this into field classes? - if ( - previous_model - and related_name - and issubclass( - previous_model.Meta.model_fields[related_name], ManyToManyField + if related_field: + table_prefix = cls.Meta.alias_manager.resolve_relation_alias_after_complex( + source_model=source_model, + relation_str=current_relation_str, + relation_field=related_field, ) - ): - through_field = previous_model.Meta.model_fields[related_name] - if ( - through_field.self_reference - and related_name == through_field.self_reference_primary - ): - rel_name2 = through_field.default_source_field_name() # type: ignore - else: - rel_name2 = through_field.default_target_field_name() # type: ignore - previous_model = through_field.through # type: ignore - - if previous_model and rel_name2: - if current_relation_str and "__" in current_relation_str and source_model: - table_prefix = cls.Meta.alias_manager.resolve_relation_alias( - from_model=source_model, relation_name=current_relation_str - ) - if not table_prefix: - table_prefix = cls.Meta.alias_manager.resolve_relation_alias( - from_model=previous_model, relation_name=rel_name2 - ) item = cls.populate_nested_models_from_row( item=item, @@ -138,11 +116,11 @@ class ModelRow(NewBaseModel): cls, item: dict, row: sqlalchemy.engine.ResultProxy, + source_model: Type[T], related_models: Any, fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None, current_relation_str: str = None, - source_model: Type[T] = None, ) -> dict: """ Traverses structure of related models and populates the nested models @@ -192,8 +170,7 @@ class ModelRow(NewBaseModel): child = model_cls.from_row( row, related_models=remainder, - previous_model=cls, - related_name=related, + related_field=field, fields=fields, exclude_fields=exclude_fields, current_relation_str=relation_str, diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 77ec421..051e695 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -128,6 +128,7 @@ class QuerySet: select_related=self._select_related, fields=self._columns, exclude_fields=self._exclude_columns, + source_model=self.model, ) for row in rows ] diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index c99dfdb..815a4dc 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -9,6 +9,7 @@ from sqlalchemy import text if TYPE_CHECKING: # pragma: no cover from ormar import Model from ormar.models import ModelRow + from ormar.fields import ForeignKeyField def get_table_alias() -> str: @@ -148,3 +149,35 @@ class AliasManager: """ alias = self._aliases_new.get(f"{from_model.get_name()}_{relation_name}", "") return alias + + def resolve_relation_alias_after_complex( + self, + source_model: Union[Type["Model"], Type["ModelRow"]], + relation_str: str, + relation_field: Type["ForeignKeyField"], + ) -> str: + """ + Given source model and relation string returns the alias for this complex + relation if it exists, otherwise fallback to normal relation from a relation + field definition. + + :param relation_field: field with direct relation definition + :type relation_field: Type["ForeignKeyField"] + :param source_model: model with query starts + :type source_model: source Model + :param relation_str: string with relation joins defined + :type relation_str: str + :return: alias of the relation + :rtype: str + """ + alias = "" + if relation_str and "__" in relation_str: + alias = self.resolve_relation_alias( + from_model=source_model, relation_name=relation_str + ) + if not alias: + alias = self.resolve_relation_alias( + from_model=relation_field.get_source_model(), + relation_name=relation_field.get_relation_name(), + ) + return alias diff --git a/tests/test_m2m_through_fields.py b/tests/test_m2m_through_fields.py index 229eac4..cb123fe 100644 --- a/tests/test_m2m_through_fields.py +++ b/tests/test_m2m_through_fields.py @@ -57,18 +57,18 @@ class PostCategory2(ormar.Model): sort_order: int = ormar.Integer(nullable=True) +class Post2(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=200) + categories = ormar.ManyToMany(Category, through=ForwardRef("PostCategory2")) + + @pytest.mark.asyncio async def test_forward_ref_is_updated(): async with database: - - class Post2(ormar.Model): - class Meta(BaseMeta): - pass - - id: int = ormar.Integer(primary_key=True) - title: str = ormar.String(max_length=200) - categories = ormar.ManyToMany(Category, through=ForwardRef("PostCategory2")) - assert Post2.Meta.requires_ref_update Post2.update_forward_refs()