From e697235172b95ff8fa3f7426dcfb94b32bbb13cb Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 21 Feb 2021 17:46:06 +0100 Subject: [PATCH] intorduce relation flags on basefield and simplify imports --- ormar/fields/base.py | 8 ++- ormar/fields/foreign_key.py | 3 +- ormar/fields/many_to_many.py | 2 + ormar/fields/through_field.py | 11 ++-- ormar/models/__init__.py | 4 +- ormar/models/helpers/models.py | 2 +- ormar/models/helpers/pydantic.py | 7 ++- .../helpers/related_names_validation.py | 2 +- ormar/models/helpers/relations.py | 24 ++++---- ormar/models/helpers/sqlalchemy.py | 6 +- ormar/models/metaclass.py | 7 ++- ormar/models/mixins/prefetch_mixin.py | 12 ++-- ormar/models/mixins/relation_mixin.py | 12 +--- ormar/models/model.py | 6 +- ormar/models/model_row.py | 19 +++--- ormar/models/newbasemodel.py | 18 +++--- ormar/queryset/__init__.py | 4 +- ormar/queryset/join.py | 7 +-- ormar/queryset/prefetch_query.py | 11 ++-- ormar/queryset/queryset.py | 55 +++++++++++------- ormar/queryset/utils.py | 3 +- ormar/relations/querysetproxy.py | 25 ++++---- ormar/relations/relation.py | 18 +++--- ormar/relations/relation_manager.py | 25 ++++---- ormar/relations/relation_proxy.py | 4 +- test.db-journal | Bin 0 -> 4616 bytes tests/test_m2m_through_fields.py | 17 ++++-- 27 files changed, 163 insertions(+), 149 deletions(-) create mode 100644 test.db-journal diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 08308d8..95645c0 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -36,9 +36,13 @@ class BaseField(FieldInfo): index: bool unique: bool pydantic_only: bool - virtual: bool = False choices: typing.Sequence + virtual: bool = False # ManyToManyFields and reverse ForeignKeyFields + is_multi: bool = False # ManyToManyField + is_relation: bool = False # ForeignKeyField + subclasses + is_through: bool = False # ThroughFields + owner: Type["Model"] to: Type["Model"] through: Type["Model"] @@ -62,7 +66,7 @@ class BaseField(FieldInfo): :return: result of the check :rtype: bool """ - return not issubclass(cls, ormar.fields.ManyToManyField) and not cls.virtual + return not cls.is_multi and not cls.virtual @classmethod def get_alias(cls) -> str: diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index f6bccfa..8dda14f 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -48,7 +48,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": **{ k: create_dummy_instance(v.to) for k, v in fk.Meta.model_fields.items() - if isinstance(v, ForeignKeyField) and not v.nullable and not v.virtual + if v.is_relation and not v.nullable and not v.virtual }, } return fk(**init_dict) @@ -217,6 +217,7 @@ def ForeignKey( # noqa CFQ002 ondelete=ondelete, owner=owner, self_reference=self_reference, + is_relation=True, ) return type("ForeignKey", (ForeignKeyField, BaseField), namespace) diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index ad7e6d9..2b2b300 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -103,6 +103,8 @@ def ManyToMany( server_default=None, owner=owner, self_reference=self_reference, + is_relation=True, + is_multi=True, ) return type("ManyToMany", (ManyToManyField, BaseField), namespace) diff --git a/ormar/fields/through_field.py b/ormar/fields/through_field.py index 99361c7..e5e4a24 100644 --- a/ormar/fields/through_field.py +++ b/ormar/fields/through_field.py @@ -15,12 +15,7 @@ if TYPE_CHECKING: # pragma no cover def Through( # noqa CFQ002 - to: "ToType", - *, - name: str = None, - related_name: str = None, - virtual: bool = True, - **kwargs: Any, + to: "ToType", *, name: str = None, related_name: str = None, **kwargs: Any, ) -> Any: # TODO: clean docstring """ @@ -52,7 +47,7 @@ def Through( # noqa CFQ002 alias=name, name=kwargs.pop("real_name", None), related_name=related_name, - virtual=virtual, + virtual=True, owner=owner, nullable=False, unique=False, @@ -62,6 +57,8 @@ def Through( # noqa CFQ002 pydantic_only=False, default=None, server_default=None, + is_relation=True, + is_through=True, ) return type("Through", (ThroughField, BaseField), namespace) diff --git a/ormar/models/__init__.py b/ormar/models/__init__.py index 9990c4a..58372b7 100644 --- a/ormar/models/__init__.py +++ b/ormar/models/__init__.py @@ -6,6 +6,6 @@ ass well as vast number of helper functions for pydantic, sqlalchemy and relatio from ormar.models.newbasemodel import NewBaseModel # noqa I100 from ormar.models.model_row import ModelRow # noqa I100 -from ormar.models.model import Model # noqa I100 +from ormar.models.model import Model, T # noqa I100 -__all__ = ["NewBaseModel", "Model", "ModelRow"] +__all__ = ["T", "NewBaseModel", "Model", "ModelRow"] diff --git a/ormar/models/helpers/models.py b/ormar/models/helpers/models.py index 449a920..6d67e91 100644 --- a/ormar/models/helpers/models.py +++ b/ormar/models/helpers/models.py @@ -21,7 +21,7 @@ def is_field_an_forward_ref(field: Type["BaseField"]) -> bool: :return: result of the check :rtype: bool """ - return issubclass(field, ormar.ForeignKeyField) and ( + return field.is_relation and ( field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef ) diff --git a/ormar/models/helpers/pydantic.py b/ormar/models/helpers/pydantic.py index 6a8f8a5..5797598 100644 --- a/ormar/models/helpers/pydantic.py +++ b/ormar/models/helpers/pydantic.py @@ -6,14 +6,15 @@ from pydantic.fields import ModelField from pydantic.utils import lenient_issubclass import ormar # noqa: I100, I202 -from ormar.fields import BaseField, ManyToManyField +from ormar.fields import BaseField if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.fields import ManyToManyField def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type["ManyToManyField"] ) -> None: """ Registers pydantic field on through model that leads to passed model @@ -59,7 +60,7 @@ def get_pydantic_field(field_name: str, model: Type["Model"]) -> "ModelField": def populate_default_pydantic_field_value( - ormar_field: Type[BaseField], field_name: str, attrs: dict + ormar_field: Type["BaseField"], field_name: str, attrs: dict ) -> dict: """ Grabs current value of the ormar Field in class namespace diff --git a/ormar/models/helpers/related_names_validation.py b/ormar/models/helpers/related_names_validation.py index 8bc32c1..56497b2 100644 --- a/ormar/models/helpers/related_names_validation.py +++ b/ormar/models/helpers/related_names_validation.py @@ -25,7 +25,7 @@ def validate_related_names_in_relations( # noqa CCR001 """ already_registered: Dict[str, List[Optional[str]]] = dict() for field in model_fields.values(): - if issubclass(field, ormar.ForeignKeyField): + if field.is_relation: to_name = ( field.to.get_name() if not field.to.__class__ == ForwardRef diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index af9ee61..48b35be 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -1,14 +1,14 @@ -from typing import TYPE_CHECKING, Type +from typing import TYPE_CHECKING, Type, cast import ormar from ormar import ForeignKey, ManyToMany -from ormar.fields import ManyToManyField, Through, ThroughField -from ormar.fields.foreign_key import ForeignKeyField +from ormar.fields import Through from ormar.models.helpers.sqlalchemy import adjust_through_many_to_many_model from ormar.relations import AliasManager if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.fields import ManyToManyField, ForeignKeyField alias_manager = AliasManager() @@ -32,7 +32,7 @@ def register_relation_on_build(field: Type["ForeignKeyField"]) -> None: ) -def register_many_to_many_relation_on_build(field: Type[ManyToManyField]) -> None: +def register_many_to_many_relation_on_build(field: Type["ManyToManyField"]) -> None: """ Registers connection between through model and both sides of the m2m relation. Registration include also reverse relation side to be able to join both sides. @@ -83,10 +83,8 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: """ model_fields = list(model.Meta.model_fields.values()) for model_field in model_fields: - if ( - issubclass(model_field, ForeignKeyField) - and not model_field.has_unresolved_forward_refs() - ): + if model_field.is_relation and not model_field.has_unresolved_forward_refs(): + model_field = cast(Type["ForeignKeyField"], model_field) expand_reverse_relationship(model_field=model_field) @@ -102,7 +100,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None: :type model_field: relation Field """ related_name = model_field.get_related_name() - if issubclass(model_field, ManyToManyField): + if model_field.is_multi: model_field.to.Meta.model_fields[related_name] = ManyToMany( model_field.owner, through=model_field.through, @@ -114,6 +112,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None: self_reference_primary=model_field.self_reference_primary, ) # register foreign keys on through model + model_field = cast(Type["ManyToManyField"], model_field) register_through_shortcut_fields(model_field=model_field) adjust_through_many_to_many_model(model_field=model_field) else: @@ -155,7 +154,7 @@ def register_through_shortcut_fields(model_field: Type["ManyToManyField"]) -> No ) -def register_relation_in_alias_manager(field: Type[ForeignKeyField]) -> None: +def register_relation_in_alias_manager(field: Type["ForeignKeyField"]) -> None: """ Registers the relation (and reverse relation) in alias manager. The m2m relations require registration of through model between @@ -168,11 +167,12 @@ def register_relation_in_alias_manager(field: Type[ForeignKeyField]) -> None: :param field: relation field :type field: ForeignKey or ManyToManyField class """ - if issubclass(field, ManyToManyField): + if field.is_multi: if field.has_unresolved_forward_refs(): return + field = cast(Type["ManyToManyField"], field) register_many_to_many_relation_on_build(field=field) - elif issubclass(field, ForeignKeyField) and not issubclass(field, ThroughField): + elif field.is_relation and not field.is_through: if field.has_unresolved_forward_refs(): return register_relation_on_build(field=field) diff --git a/ormar/models/helpers/sqlalchemy.py b/ormar/models/helpers/sqlalchemy.py index a17e786..b641969 100644 --- a/ormar/models/helpers/sqlalchemy.py +++ b/ormar/models/helpers/sqlalchemy.py @@ -156,11 +156,7 @@ def sqlalchemy_columns_from_model_fields( field.owner = new_model if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) - if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ormar.ManyToManyField) - ): + if not field.pydantic_only and not field.virtual and not field.is_multi: columns.append(field.get_column(field.get_alias())) return pkname, columns diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 4134fd5..304fd19 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -262,7 +262,7 @@ def copy_and_replace_m2m_through_model( new_meta.model_fields = { name: field for name, field in new_meta.model_fields.items() - if not issubclass(field, ForeignKeyField) + if not field.is_relation } _, columns = sqlalchemy_columns_from_model_fields( new_meta.model_fields, copy_through @@ -329,7 +329,8 @@ def copy_data_from_parent_model( # noqa: CCR001 else attrs.get("__name__", "").lower() + "s" ) for field_name, field in base_class.Meta.model_fields.items(): - if issubclass(field, ManyToManyField): + if field.is_multi: + field = cast(Type["ManyToManyField"], field) copy_and_replace_m2m_through_model( field=field, field_name=field_name, @@ -339,7 +340,7 @@ def copy_data_from_parent_model( # noqa: CCR001 meta=meta, ) - elif issubclass(field, ForeignKeyField) and field.related_name: + elif field.is_relation and field.related_name: copy_field = type( # type: ignore field.__name__, (ForeignKeyField, BaseField), dict(field.__dict__) ) diff --git a/ormar/models/mixins/prefetch_mixin.py b/ormar/models/mixins/prefetch_mixin.py index 273dd01..d8ee350 100644 --- a/ormar/models/mixins/prefetch_mixin.py +++ b/ormar/models/mixins/prefetch_mixin.py @@ -1,9 +1,10 @@ -from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type +from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type, cast -import ormar -from ormar.fields.foreign_key import ForeignKeyField from ormar.models.mixins.relation_mixin import RelationMixin +if TYPE_CHECKING: + from ormar.fields import ForeignKeyField, ManyToManyField + class PrefetchQueryMixin(RelationMixin): """ @@ -39,7 +40,8 @@ class PrefetchQueryMixin(RelationMixin): if reverse: field_name = parent_model.Meta.model_fields[related].get_related_name() field = target_model.Meta.model_fields[field_name] - if issubclass(field, ormar.fields.ManyToManyField): + if field.is_multi: + field = cast(Type["ManyToManyField"], field) field_name = field.default_target_field_name() sub_field = field.through.Meta.model_fields[field_name] return field.through, sub_field.get_alias() @@ -87,7 +89,7 @@ class PrefetchQueryMixin(RelationMixin): :return: name of the field :rtype: str """ - if issubclass(target_field, ormar.fields.ManyToManyField): + if target_field.is_multi: return cls.get_name() if target_field.virtual: return target_field.get_related_name() diff --git a/ormar/models/mixins/relation_mixin.py b/ormar/models/mixins/relation_mixin.py index 435fcc8..aebaa20 100644 --- a/ormar/models/mixins/relation_mixin.py +++ b/ormar/models/mixins/relation_mixin.py @@ -1,10 +1,6 @@ import inspect from typing import List, Optional, Set, TYPE_CHECKING -from ormar import ManyToManyField -from ormar.fields import ThroughField -from ormar.fields.foreign_key import ForeignKeyField - class RelationMixin: """ @@ -62,7 +58,7 @@ class RelationMixin: related_fields = set() for name in cls.extract_related_names(): field = cls.Meta.model_fields[name] - if issubclass(field, ManyToManyField): + if field.is_multi: related_fields.add(field.through.get_name(lower=True)) return related_fields @@ -80,11 +76,7 @@ class RelationMixin: related_names = set() for name, field in cls.Meta.model_fields.items(): - if ( - inspect.isclass(field) - and issubclass(field, ForeignKeyField) - and not issubclass(field, ThroughField) - ): + if inspect.isclass(field) and field.is_relation and not field.is_through: related_names.add(name) cls._related_names = related_names diff --git a/ormar/models/model.py b/ormar/models/model.py index 9286da9..c20368d 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -8,7 +8,6 @@ from typing import ( import ormar.queryset # noqa I100 from ormar.exceptions import ModelPersistenceError, NoMatch -from ormar.fields.many_to_many import ManyToManyField from ormar.models import NewBaseModel # noqa I100 from ormar.models.metaclass import ModelMeta from ormar.models.model_row import ModelRow @@ -139,8 +138,9 @@ class Model(ModelRow): visited.add(self.__class__) for related in self.extract_related_names(): - if self.Meta.model_fields[related].virtual or issubclass( - self.Meta.model_fields[related], ManyToManyField + if ( + self.Meta.model_fields[related].virtual + or self.Meta.model_fields[related].is_multi ): for rel in getattr(self, related): update_count, visited = await self._update_and_follow( diff --git a/ormar/models/model_row.py b/ormar/models/model_row.py index f184bb8..f0a4a31 100644 --- a/ormar/models/model_row.py +++ b/ormar/models/model_row.py @@ -8,24 +8,26 @@ from typing import ( Type, TypeVar, Union, + cast, ) import sqlalchemy -from ormar import ManyToManyField # noqa: I202 -from ormar.models import NewBaseModel +from ormar.models import NewBaseModel # noqa: I202 from ormar.models.helpers.models import group_related_list -T = TypeVar("T", bound="ModelRow") if TYPE_CHECKING: from ormar.fields import ForeignKeyField + from ormar.models import T +else: + T = TypeVar("T", bound="ModelRow") class ModelRow(NewBaseModel): @classmethod - def from_row( - cls: Type[T], + def from_row( # noqa: CFQ002 + cls: Type["ModelRow"], row: sqlalchemy.engine.ResultProxy, source_model: Type[T], select_related: List = None, @@ -75,7 +77,7 @@ class ModelRow(NewBaseModel): table_prefix = "" if select_related: - source_model = cls + source_model = cast(Type[T], cls) related_models = group_related_list(select_related) if related_field: @@ -107,7 +109,7 @@ class ModelRow(NewBaseModel): item["__excluded__"] = cls.get_names_to_exclude( fields=fields, exclude_fields=exclude_fields ) - instance = cls(**item) + instance = cast(T, cls(**item)) instance.set_save_status(True) return instance @@ -160,6 +162,7 @@ class ModelRow(NewBaseModel): else related ) field = cls.Meta.model_fields[related] + field = cast(Type["ForeignKeyField"], field) fields = cls.get_included(fields, related) exclude_fields = cls.get_excluded(exclude_fields, related) model_cls = field.to @@ -177,7 +180,7 @@ class ModelRow(NewBaseModel): source_model=source_model, ) item[model_cls.get_column_name_from_alias(related)] = child - if issubclass(field, ManyToManyField) and child: + if field.is_multi and child: # TODO: way to figure out which side should be populated? through_name = cls.Meta.model_fields[related].through.get_name() # for now it's nested dict, should be instance? diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 6975753..cc20807 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -46,15 +46,15 @@ from ormar.relations.alias_manager import AliasManager from ormar.relations.relation_manager import RelationsManager if TYPE_CHECKING: # pragma no cover - from ormar import Model + from ormar.models import Model, T from ormar.signals import SignalEmitter - T = TypeVar("T", bound=Model) - IntStr = Union[int, str] DictStrAny = Dict[str, Any] AbstractSetIntStr = AbstractSet[IntStr] MappingIntStrAny = Mapping[IntStr, Any] +else: + T = TypeVar("T") class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass): @@ -89,7 +89,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass Meta: ModelMeta # noinspection PyMissingConstructor - def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore + def __init__(self: T, *args: Any, **kwargs: Any) -> None: # type: ignore """ Initializer that creates a new ormar Model that is also pydantic Model at the same time. @@ -129,7 +129,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass object.__setattr__( self, "_orm", - RelationsManager(related_fields=self.extract_related_fields(), owner=self,), + RelationsManager( + related_fields=self.extract_related_fields(), owner=cast(T, self), + ), ) pk_only = kwargs.pop("__pk_only__", False) @@ -298,7 +300,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def _extract_related_model_instead_of_field( self, item: str - ) -> Optional[Union["T", Sequence["T"]]]: + ) -> Optional[Union["Model", Sequence["Model"]]]: """ Retrieves the related model/models from RelationshipManager. @@ -755,9 +757,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass :return: value of pk if set :rtype: Optional[int] """ - if target_field.virtual or issubclass( - target_field, ormar.fields.ManyToManyField - ): + if target_field.virtual or target_field.is_multi: return self.pk related_name = target_field.name related_model = getattr(self, related_name) diff --git a/ormar/queryset/__init__.py b/ormar/queryset/__init__.py index 8528b05..ebfab7b 100644 --- a/ormar/queryset/__init__.py +++ b/ormar/queryset/__init__.py @@ -5,6 +5,6 @@ from ormar.queryset.filter_query import FilterQuery from ormar.queryset.limit_query import LimitQuery from ormar.queryset.offset_query import OffsetQuery from ormar.queryset.order_query import OrderQuery -from ormar.queryset.queryset import QuerySet +from ormar.queryset.queryset import QuerySet, T -__all__ = ["QuerySet", "FilterQuery", "LimitQuery", "OffsetQuery", "OrderQuery"] +__all__ = ["T", "QuerySet", "FilterQuery", "LimitQuery", "OffsetQuery", "OrderQuery"] diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 9ce306d..e90e49d 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -15,7 +15,6 @@ import sqlalchemy from sqlalchemy import text from ormar.exceptions import RelationshipInstanceError # noqa I100 -from ormar.fields import BaseField, ManyToManyField # noqa I100 from ormar.relations import AliasManager if TYPE_CHECKING: # pragma no cover @@ -118,7 +117,7 @@ class SqlJoin: :return: list of used aliases, select from, list of aliased columns, sort orders :rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict] """ - if issubclass(self.target_field, ManyToManyField): + if self.target_field.is_multi: self.process_m2m_through_table() self.next_model = self.target_field.to @@ -287,7 +286,7 @@ class SqlJoin: ) pkname_alias = self.next_model.get_column_alias(self.next_model.Meta.pkname) - if not issubclass(self.target_field, ManyToManyField): + if not self.target_field.is_multi: self.get_order_bys( to_table=to_table, pkname_alias=pkname_alias, ) @@ -415,7 +414,7 @@ class SqlJoin: :return: to key and from key :rtype: Tuple[str, str] """ - if issubclass(self.target_field, ManyToManyField): + if self.target_field.is_multi: to_key = self.process_m2m_related_name_change(reverse=True) from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 4c8c6d7..7abf4c6 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -13,14 +13,13 @@ from typing import ( ) import ormar -from ormar.fields import BaseField, ManyToManyField -from ormar.fields.foreign_key import ForeignKeyField from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict if TYPE_CHECKING: # pragma: no cover from ormar import Model + from ormar.fields import ForeignKeyField, BaseField def add_relation_field_to_fields( @@ -316,7 +315,7 @@ class PrefetchQuery: for related in related_to_extract: target_field = model.Meta.model_fields[related] - target_field = cast(Type[ForeignKeyField], target_field) + target_field = cast(Type["ForeignKeyField"], target_field) target_model = target_field.to.get_name() model_id = model.get_relation_model_id(target_field=target_field) @@ -424,9 +423,9 @@ class PrefetchQuery: fields = target_model.get_included(fields, related) exclude_fields = target_model.get_excluded(exclude_fields, related) target_field = target_model.Meta.model_fields[related] - target_field = cast(Type[ForeignKeyField], target_field) + target_field = cast(Type["ForeignKeyField"], target_field) reverse = False - if target_field.virtual or issubclass(target_field, ManyToManyField): + if target_field.virtual or target_field.is_multi: reverse = True parent_model = target_model @@ -522,7 +521,7 @@ class PrefetchQuery: select_related = [] query_target = target_model table_prefix = "" - if issubclass(target_field, ManyToManyField): + if target_field.is_multi: query_target = target_field.through select_related = [target_name] table_prefix = target_field.to.Meta.alias_manager.resolve_relation_alias( diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 051e695..ba55586 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1,4 +1,17 @@ -from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union +from typing import ( + Any, + Dict, + Generic, + List, + Optional, + Sequence, + Set, + TYPE_CHECKING, + Type, + TypeVar, + Union, + cast, +) import databases import sqlalchemy @@ -14,19 +27,21 @@ from ormar.queryset.query import Query from ormar.queryset.utils import update, update_dict_from_list if TYPE_CHECKING: # pragma no cover - from ormar import Model + from ormar.models import T from ormar.models.metaclass import ModelMeta from ormar.relations.querysetproxy import QuerysetProxy +else: + T = TypeVar("T") -class QuerySet: +class QuerySet(Generic[T]): """ Main class to perform database queries, exposed on each model as objects attribute. """ def __init__( # noqa CFQ002 self, - model_cls: Type["Model"] = None, + model_cls: Optional[Type[T]] = None, filter_clauses: List = None, exclude_clauses: List = None, select_related: List = None, @@ -53,7 +68,7 @@ class QuerySet: def __get__( self, instance: Optional[Union["QuerySet", "QuerysetProxy"]], - owner: Union[Type["Model"], Type["QuerysetProxy"]], + owner: Union[Type[T], Type["QuerysetProxy"]], ) -> "QuerySet": if issubclass(owner, ormar.Model): if owner.Meta.requires_ref_update: @@ -62,7 +77,7 @@ class QuerySet: f"ForwardRefs. \nBefore using the model you " f"need to call update_forward_refs()." ) - if issubclass(owner, ormar.Model): + owner = cast(Type[T], owner) return self.__class__(model_cls=owner) return self.__class__() # pragma: no cover @@ -79,7 +94,7 @@ class QuerySet: return self.model_cls.Meta @property - def model(self) -> Type["Model"]: + def model(self) -> Type[T]: """ Shortcut to model class set on QuerySet. @@ -91,8 +106,8 @@ class QuerySet: return self.model_cls async def _prefetch_related_models( - self, models: Sequence[Optional["Model"]], rows: List - ) -> Sequence[Optional["Model"]]: + self, models: Sequence[Optional["T"]], rows: List + ) -> Sequence[Optional["T"]]: """ Performs prefetch query for selected models names. @@ -113,7 +128,7 @@ class QuerySet: ) return await query.prefetch_related(models=models, rows=rows) # type: ignore - def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: + def _process_query_result_rows(self, rows: List) -> Sequence[Optional[T]]: """ Process database rows and initialize ormar Model from each of the rows. @@ -137,7 +152,7 @@ class QuerySet: return result_rows @staticmethod - def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None: + def check_single_result_rows_count(rows: Sequence[Optional[T]]) -> None: """ Verifies if the result has one and only one row. @@ -198,7 +213,7 @@ class QuerySet: limit_raw_sql=self.limit_sql_raw, ) exp = qry.build_select_expression() - print("\n", exp.compile(compile_kwargs={"literal_binds": True})) + # print("\n", exp.compile(compile_kwargs={"literal_binds": True})) return exp def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003 @@ -683,7 +698,7 @@ class QuerySet: limit_raw_sql=limit_raw_sql, ) - async def first(self, **kwargs: Any) -> "Model": + async def first(self, **kwargs: Any) -> T: """ Gets the first row from the db ordered by primary key column ascending. @@ -707,7 +722,7 @@ class QuerySet: self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore - async def get(self, **kwargs: Any) -> "Model": + async def get(self, **kwargs: Any) -> T: """ Get's the first row from the db meeting the criteria set by kwargs. @@ -739,7 +754,7 @@ class QuerySet: self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore - async def get_or_create(self, **kwargs: Any) -> "Model": + async def get_or_create(self, **kwargs: Any) -> T: """ Combination of create and get methods. @@ -757,7 +772,7 @@ class QuerySet: except NoMatch: return await self.create(**kwargs) - async def update_or_create(self, **kwargs: Any) -> "Model": + async def update_or_create(self, **kwargs: Any) -> T: """ Updates the model, or in case there is no match in database creates a new one. @@ -774,7 +789,7 @@ class QuerySet: model = await self.get(pk=kwargs[pk_name]) return await model.update(**kwargs) - async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 + async def all(self, **kwargs: Any) -> Sequence[Optional[T]]: # noqa: A003 """ Returns all rows from a database for given model for set filter options. @@ -798,7 +813,7 @@ class QuerySet: return result_rows - async def create(self, **kwargs: Any) -> "Model": + async def create(self, **kwargs: Any) -> T: """ Creates the model instance, saves it in a database and returns the updates model (with pk populated if not passed and autoincrement is set). @@ -841,7 +856,7 @@ class QuerySet: ) return instance - async def bulk_create(self, objects: List["Model"]) -> None: + async def bulk_create(self, objects: List[T]) -> None: """ Performs a bulk update in one database session to speed up the process. @@ -867,7 +882,7 @@ class QuerySet: objt.set_save_status(True) async def bulk_update( # noqa: CCR001 - self, objects: List["Model"], columns: List[str] = None + self, objects: List[T], columns: List[str] = None ) -> None: """ Performs bulk update in one database session to speed up the process. diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index e2cf33a..f1cbf43 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -12,7 +12,6 @@ from typing import ( Union, ) -from ormar.fields import ManyToManyField if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -236,7 +235,7 @@ def get_relationship_alias_model_and_str( manager = model_cls.Meta.alias_manager for relation in related_parts: related_field = model_cls.Meta.model_fields[relation] - if issubclass(related_field, ManyToManyField): + if related_field.is_multi: previous_model = related_field.through relation = related_field.default_target_field_name() # type: ignore table_prefix = manager.resolve_relation_alias( diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index c23dcb7..031684b 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -1,6 +1,7 @@ from typing import ( Any, Dict, + Generic, List, MutableSequence, Optional, @@ -9,6 +10,7 @@ from typing import ( TYPE_CHECKING, TypeVar, Union, + cast, ) import ormar @@ -16,14 +18,14 @@ from ormar.exceptions import ModelPersistenceError if TYPE_CHECKING: # pragma no cover from ormar.relations import Relation - from ormar.models import Model + from ormar.models import Model, T from ormar.queryset import QuerySet from ormar import RelationType - - T = TypeVar("T", bound=Model) +else: + T = TypeVar("T") -class QuerysetProxy(ormar.QuerySetProtocol): +class QuerysetProxy(Generic[T]): """ Exposes QuerySet methods on relations, but also handles creating and removing of through Models for m2m relations. @@ -47,7 +49,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): self.through_model_name = ( self.related_field.through.get_name() if self.type_ == ormar.RelationType.MULTIPLE - else None + else "" ) @property @@ -94,6 +96,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): self._assign_child_to_parent(subchild) else: assert isinstance(child, ormar.Model) + child = cast(T, child) self._assign_child_to_parent(child) def _clean_items_on_load(self) -> None: @@ -198,7 +201,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): ) return await queryset.delete(**kwargs) # type: ignore - async def first(self, **kwargs: Any) -> "Model": + async def first(self, **kwargs: Any) -> T: """ Gets the first row from the db ordered by primary key column ascending. @@ -216,7 +219,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): self._register_related(first) return first - async def get(self, **kwargs: Any) -> "Model": + async def get(self, **kwargs: Any) -> "T": """ Get's the first row from the db meeting the criteria set by kwargs. @@ -240,7 +243,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): self._register_related(get) return get - async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 + async def all(self, **kwargs: Any) -> Sequence[Optional["T"]]: # noqa: A003 """ Returns all rows from a database for given model for set filter options. @@ -262,7 +265,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): self._register_related(all_items) return all_items - async def create(self, **kwargs: Any) -> "Model": + async def create(self, **kwargs: Any) -> "T": """ Creates the model instance, saves it in a database and returns the updates model (with pk populated if not passed and autoincrement is set). @@ -287,7 +290,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): await self.create_through_instance(created, **through_kwargs) return created - async def get_or_create(self, **kwargs: Any) -> "Model": + async def get_or_create(self, **kwargs: Any) -> "T": """ Combination of create and get methods. @@ -305,7 +308,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): except ormar.NoMatch: return await self.create(**kwargs) - async def update_or_create(self, **kwargs: Any) -> "Model": + async def update_or_create(self, **kwargs: Any) -> "T": """ Updates the model, or in case there is no match in database creates a new one. diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py index 0ae2f59..6d4da36 100644 --- a/ormar/relations/relation.py +++ b/ormar/relations/relation.py @@ -1,17 +1,13 @@ from enum import Enum -from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union +from typing import List, Optional, Set, TYPE_CHECKING, Type, Union import ormar # noqa I100 from ormar.exceptions import RelationshipInstanceError # noqa I100 -from ormar.fields.foreign_key import ForeignKeyField # noqa I100 from ormar.relations.relation_proxy import RelationProxy if TYPE_CHECKING: # pragma no cover - from ormar import Model from ormar.relations import RelationsManager - from ormar.models import NewBaseModel - - T = TypeVar("T", bound=Model) + from ormar.models import Model, NewBaseModel, T class RelationType(Enum): @@ -39,7 +35,7 @@ class Relation: manager: "RelationsManager", type_: RelationType, field_name: str, - to: Type["T"], + to: Type["Model"], through: Type["T"] = None, ) -> None: """ @@ -63,10 +59,10 @@ class Relation: self._owner: "Model" = manager.owner self._type: RelationType = type_ self._to_remove: Set = set() - self.to: Type["T"] = to - self._through: Optional[Type["T"]] = through + self.to: Type["Model"] = to + self._through = through self.field_name: str = field_name - self.related_models: Optional[Union[RelationProxy, "T"]] = ( + self.related_models: Optional[Union[RelationProxy, "Model"]] = ( RelationProxy(relation=self, type_=type_, field_name=field_name) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) else None @@ -161,7 +157,7 @@ class Relation: self.related_models.pop(position) # type: ignore del self._owner.__dict__[relation_name][position] - def get(self) -> Optional[Union[List["T"], "T"]]: + def get(self) -> Optional[Union[List["Model"], "Model"]]: """ Return the related model or models from RelationProxy. diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py index addfcf1..a718b09 100644 --- a/ormar/relations/relation_manager.py +++ b/ormar/relations/relation_manager.py @@ -1,17 +1,12 @@ -from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Type, TypeVar, Union +from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Type, Union from weakref import proxy -from ormar.fields import BaseField, ThroughField -from ormar.fields.foreign_key import ForeignKeyField -from ormar.fields.many_to_many import ManyToManyField from ormar.relations.relation import Relation, RelationType from ormar.relations.utils import get_relations_sides_and_names if TYPE_CHECKING: # pragma no cover - from ormar import Model - from ormar.models import NewBaseModel - - T = TypeVar("T", bound=Model) + from ormar.models import NewBaseModel, T, Model + from ormar.fields import ForeignKeyField, BaseField class RelationsManager: @@ -21,8 +16,8 @@ class RelationsManager: def __init__( self, - related_fields: List[Type[ForeignKeyField]] = None, - owner: "NewBaseModel" = None, + related_fields: List[Type["ForeignKeyField"]] = None, + owner: Optional["T"] = None, ) -> None: self.owner = proxy(owner) self._related_fields = related_fields or [] @@ -31,7 +26,7 @@ class RelationsManager: for field in self._related_fields: self._add_relation(field) - def _get_relation_type(self, field: Type[BaseField]) -> RelationType: + def _get_relation_type(self, field: Type["BaseField"]) -> RelationType: """ Returns type of the relation declared on a field. @@ -40,13 +35,13 @@ class RelationsManager: :return: type of the relation defined on field :rtype: RelationType """ - if issubclass(field, ManyToManyField): + if field.is_multi: return RelationType.MULTIPLE - if issubclass(field, ThroughField): + if field.is_through: return RelationType.THROUGH return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE - def _add_relation(self, field: Type[BaseField]) -> None: + def _add_relation(self, field: Type["BaseField"]) -> None: """ Registers relation in the manager. Adds Relation instance under field.name. @@ -73,7 +68,7 @@ class RelationsManager: """ return item in self._related_names - def get(self, name: str) -> Optional[Union["T", Sequence["T"]]]: + def get(self, name: str) -> Optional[Union["Model", Sequence["Model"]]]: """ Returns the related model/models if relation is set. Actual call is delegated to Relation instance registered under relation name. diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 87bb05f..58d6e9e 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -27,7 +27,9 @@ class RelationProxy(list): self.type_: "RelationType" = type_ self.field_name = field_name self._owner: "Model" = self.relation.manager.owner - self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_) + self.queryset_proxy: QuerysetProxy = QuerysetProxy( + relation=self.relation, type_=type_ + ) self._related_field_name: Optional[str] = None @property diff --git a/test.db-journal b/test.db-journal new file mode 100644 index 0000000000000000000000000000000000000000..f5538646e9b3593b93b04ca5c2a9d379bdd826c2 GIT binary patch literal 4616 zcmZQzK!AcPQ+$9tRt5$pASHm17-bFT5Mbs7`EE42hjSc{`fW4>MnhmU1V%$(Gz3ON aU^E0qLtr!nMnhmU1V%$(Gz1tx<30d0eguF3 literal 0 HcmV?d00001 diff --git a/tests/test_m2m_through_fields.py b/tests/test_m2m_through_fields.py index cb123fe..279d1a8 100644 --- a/tests/test_m2m_through_fields.py +++ b/tests/test_m2m_through_fields.py @@ -1,3 +1,5 @@ +from typing import Any + import databases import pytest import sqlalchemy @@ -19,8 +21,8 @@ class Category(ormar.Model): class Meta(BaseMeta): tablename = "categories" - id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(max_length=40) + id = ormar.Integer(primary_key=True) + name = ormar.String(max_length=40) class PostCategory(ormar.Model): @@ -107,8 +109,12 @@ async def test_setting_additional_fields_on_through_model_in_create(): assert postcat.sort_order == 2 +def process_post(post: Post): + pass + + @pytest.mark.asyncio -async def test_getting_additional_fields_from_queryset(): +async def test_getting_additional_fields_from_queryset() -> Any: async with database: post = await Post(title="Test post").save() await post.categories.create( @@ -122,10 +128,11 @@ async def test_getting_additional_fields_from_queryset(): assert post.categories[0].postcategory.sort_order == 1 assert post.categories[1].postcategory.sort_order == 2 - post = await Post.objects.select_related("categories").get( + post2 = await Post.objects.select_related("categories").get( categories__name="Test category2" ) - assert post.categories[0].postcategory.sort_order == 2 + assert post2.categories[0].postcategory.sort_order == 2 + process_post(post2) # TODO: check/ modify following