WIP - further work and refactoring

This commit is contained in:
collerek
2021-02-17 13:51:38 +01:00
parent 3fd231cf3c
commit 5d40fb6bff
7 changed files with 107 additions and 50 deletions

View File

@ -450,3 +450,24 @@ class ForeignKeyField(BaseField):
value.__class__.__name__, cls._construct_model_from_pk value.__class__.__name__, cls._construct_model_from_pk
)(value, child, to_register) )(value, child, to_register)
return model return model
@classmethod
def get_relation_name(cls) -> str:
"""
Returns name of the relation, which can be a own name or through model
names for m2m models
:return: result of the check
:rtype: bool
"""
return cls.name
@classmethod
def get_source_model(cls) -> Type["Model"]:
"""
Returns model from which the relation comes -> either owner or through model
:return: source model
:rtype: Type["Model"]
"""
return cls.owner

View File

@ -187,3 +187,26 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
globalns, globalns,
localns or None, localns or None,
) )
@classmethod
def get_relation_name(cls) -> str:
"""
Returns name of the relation, which can be a own name or through model
names for m2m models
:return: result of the check
:rtype: bool
"""
if cls.self_reference and cls.name == cls.self_reference_primary:
return cls.default_source_field_name()
return cls.default_target_field_name()
@classmethod
def get_source_model(cls) -> Type["Model"]:
"""
Returns model from which the relation comes -> either owner or through model
:return: source model
:rtype: Type["Model"]
"""
return cls.through

View File

@ -131,7 +131,9 @@ class ExcludableMixin(RelationMixin):
@staticmethod @staticmethod
def _populate_pk_column( def _populate_pk_column(
model: Type["Model"], columns: List[str], use_alias: bool = False, model: Union[Type["Model"], Type["ModelRow"]],
columns: List[str],
use_alias: bool = False,
) -> List[str]: ) -> List[str]:
""" """
Adds primary key column/alias (depends on use_alias flag) to list of Adds primary key column/alias (depends on use_alias flag) to list of

View File

@ -4,6 +4,7 @@ from typing import (
List, List,
Optional, Optional,
Set, Set,
TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -17,20 +18,22 @@ from ormar.models.helpers.models import group_related_list
T = TypeVar("T", bound="ModelRow") T = TypeVar("T", bound="ModelRow")
if TYPE_CHECKING:
from ormar.fields import ForeignKeyField
class ModelRow(NewBaseModel): class ModelRow(NewBaseModel):
@classmethod @classmethod
def from_row( # noqa CCR001 def from_row(
cls: Type[T], cls: Type[T],
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
source_model: Type[T],
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
previous_model: Type[T] = None, related_field: Type["ForeignKeyField"] = None,
source_model: Type[T] = None,
related_name: str = None,
fields: Optional[Union[Dict, Set]] = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None,
current_relation_str: str = None, current_relation_str: str = "",
) -> Optional[T]: ) -> Optional[T]:
""" """
Model method to convert raw sql row from database into ormar.Model instance. Model method to convert raw sql row from database into ormar.Model instance.
@ -55,10 +58,8 @@ class ModelRow(NewBaseModel):
:type select_related: List :type select_related: List
:param related_models: list or dict of related models :param related_models: list or dict of related models
:type related_models: Union[List, Dict] :type related_models: Union[List, Dict]
:param previous_model: internal param for nested models to specify table_prefix :param related_field: field with relation declaration
:type previous_model: Model class :type related_field: Type[ForeignKeyField]
:param related_name: internal parameter - name of current nested model
:type related_name: str
:param fields: fields and related model fields to include :param fields: fields and related model fields to include
if provided only those are included if provided only those are included
:type fields: Optional[Union[Dict, Set]] :type fields: Optional[Union[Dict, Set]]
@ -77,35 +78,12 @@ class ModelRow(NewBaseModel):
source_model = cls source_model = cls
related_models = group_related_list(select_related) related_models = group_related_list(select_related)
rel_name2 = related_name if related_field:
table_prefix = cls.Meta.alias_manager.resolve_relation_alias_after_complex(
# TODO: refactor this into field classes? source_model=source_model,
if ( relation_str=current_relation_str,
previous_model relation_field=related_field,
and related_name
and issubclass(
previous_model.Meta.model_fields[related_name], ManyToManyField
) )
):
through_field = previous_model.Meta.model_fields[related_name]
if (
through_field.self_reference
and related_name == through_field.self_reference_primary
):
rel_name2 = through_field.default_source_field_name() # type: ignore
else:
rel_name2 = through_field.default_target_field_name() # type: ignore
previous_model = through_field.through # type: ignore
if previous_model and rel_name2:
if current_relation_str and "__" in current_relation_str and source_model:
table_prefix = cls.Meta.alias_manager.resolve_relation_alias(
from_model=source_model, relation_name=current_relation_str
)
if not table_prefix:
table_prefix = cls.Meta.alias_manager.resolve_relation_alias(
from_model=previous_model, relation_name=rel_name2
)
item = cls.populate_nested_models_from_row( item = cls.populate_nested_models_from_row(
item=item, item=item,
@ -138,11 +116,11 @@ class ModelRow(NewBaseModel):
cls, cls,
item: dict, item: dict,
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
source_model: Type[T],
related_models: Any, related_models: Any,
fields: Optional[Union[Dict, Set]] = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None,
current_relation_str: str = None, current_relation_str: str = None,
source_model: Type[T] = None,
) -> dict: ) -> dict:
""" """
Traverses structure of related models and populates the nested models Traverses structure of related models and populates the nested models
@ -192,8 +170,7 @@ class ModelRow(NewBaseModel):
child = model_cls.from_row( child = model_cls.from_row(
row, row,
related_models=remainder, related_models=remainder,
previous_model=cls, related_field=field,
related_name=related,
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
current_relation_str=relation_str, current_relation_str=relation_str,

View File

@ -128,6 +128,7 @@ class QuerySet:
select_related=self._select_related, select_related=self._select_related,
fields=self._columns, fields=self._columns,
exclude_fields=self._exclude_columns, exclude_fields=self._exclude_columns,
source_model=self.model,
) )
for row in rows for row in rows
] ]

View File

@ -9,6 +9,7 @@ from sqlalchemy import text
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ormar import Model from ormar import Model
from ormar.models import ModelRow from ormar.models import ModelRow
from ormar.fields import ForeignKeyField
def get_table_alias() -> str: def get_table_alias() -> str:
@ -148,3 +149,35 @@ class AliasManager:
""" """
alias = self._aliases_new.get(f"{from_model.get_name()}_{relation_name}", "") alias = self._aliases_new.get(f"{from_model.get_name()}_{relation_name}", "")
return alias return alias
def resolve_relation_alias_after_complex(
self,
source_model: Union[Type["Model"], Type["ModelRow"]],
relation_str: str,
relation_field: Type["ForeignKeyField"],
) -> str:
"""
Given source model and relation string returns the alias for this complex
relation if it exists, otherwise fallback to normal relation from a relation
field definition.
:param relation_field: field with direct relation definition
:type relation_field: Type["ForeignKeyField"]
:param source_model: model with query starts
:type source_model: source Model
:param relation_str: string with relation joins defined
:type relation_str: str
:return: alias of the relation
:rtype: str
"""
alias = ""
if relation_str and "__" in relation_str:
alias = self.resolve_relation_alias(
from_model=source_model, relation_name=relation_str
)
if not alias:
alias = self.resolve_relation_alias(
from_model=relation_field.get_source_model(),
relation_name=relation_field.get_relation_name(),
)
return alias

View File

@ -57,18 +57,18 @@ class PostCategory2(ormar.Model):
sort_order: int = ormar.Integer(nullable=True) sort_order: int = ormar.Integer(nullable=True)
class Post2(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories = ormar.ManyToMany(Category, through=ForwardRef("PostCategory2"))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_forward_ref_is_updated(): async def test_forward_ref_is_updated():
async with database: async with database:
class Post2(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories = ormar.ManyToMany(Category, through=ForwardRef("PostCategory2"))
assert Post2.Meta.requires_ref_update assert Post2.Meta.requires_ref_update
Post2.update_forward_refs() Post2.update_forward_refs()