From 9c091afe35621eacc2cae9f6bf26b2e75288995d Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 19 Mar 2021 17:13:59 +0100 Subject: [PATCH] improve relation_proxy types --- ormar/fields/many_to_many.py | 7 +-- ormar/models/helpers/relations.py | 2 +- ormar/relations/querysetproxy.py | 86 +++++++++++++++++-------------- ormar/relations/relation.py | 3 +- ormar/relations/relation_proxy.py | 15 ++++-- tests/test_types.py | 5 -- 6 files changed, 64 insertions(+), 54 deletions(-) diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 0ad11a5..a0989e0 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -8,7 +8,8 @@ from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField, validate_not_allowed_fields if TYPE_CHECKING: # pragma no cover - from ormar.models import Model + from ormar.models import Model, T + from ormar.relations.relation_proxy import RelationProxy if sys.version_info < (3, 7): ToType = Type["Model"] @@ -58,14 +59,14 @@ def populate_m2m_params_based_on_to_model( def ManyToMany( - to: "ToType", + to: Type["T"], through: Optional["ToType"] = None, *, name: str = None, unique: bool = False, virtual: bool = False, **kwargs: Any, -): +) -> "RelationProxy[T]": """ Despite a name it's a function that returns constructed ManyToManyField. This function is actually used in model declaration diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index 2eb01c2..29cebe6 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -101,7 +101,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None: """ related_name = model_field.get_related_name() if model_field.is_multi: - model_field.to.Meta.model_fields[related_name] = ManyToMany( + model_field.to.Meta.model_fields[related_name] = ManyToMany( # type: ignore model_field.owner, through=model_field.through, name=related_name, diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index d6543c8..a97c52a 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -2,13 +2,13 @@ from _weakref import CallableProxyType from typing import ( # noqa: I100, I201 Any, Dict, - List, + Generic, List, MutableSequence, Optional, Sequence, Set, TYPE_CHECKING, - Union, + Type, TypeVar, Union, cast, ) @@ -18,12 +18,14 @@ from ormar.exceptions import ModelPersistenceError, QueryDefinitionError if TYPE_CHECKING: # pragma no cover from ormar.relations import Relation - from ormar.models import Model + from ormar.models import Model, T from ormar.queryset import QuerySet from ormar import RelationType +else: + T = TypeVar("T", bound="Model") -class QuerysetProxy: +class QuerysetProxy(Generic[T]): """ Exposes QuerySet methods on relations, but also handles creating and removing of through Models for m2m relations. @@ -33,16 +35,20 @@ class QuerysetProxy: relation: "Relation" def __init__( - self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None + self, relation: "Relation", + to: Type["T"], + type_: "RelationType", + qryset: "QuerySet[T]" = None ) -> None: self.relation: Relation = relation - self._queryset: Optional["QuerySet"] = qryset + self._queryset: Optional["QuerySet[T]"] = qryset self.type_: "RelationType" = type_ self._owner: Union[CallableProxyType, "Model"] = self.relation.manager.owner self.related_field_name = self._owner.Meta.model_fields[ self.relation.field_name ].get_related_name() - self.related_field = self.relation.to.Meta.model_fields[self.related_field_name] + self.to: Type[T] = to + self.related_field = to.Meta.model_fields[self.related_field_name] self.owner_pk_value = self._owner.pk self.through_model_name = ( self.related_field.through.get_name() @@ -51,7 +57,7 @@ class QuerysetProxy: ) @property - def queryset(self) -> "QuerySet": + def queryset(self) -> "QuerySet[T]": """ Returns queryset if it's set, AttributeError otherwise. :return: QuerySet @@ -70,7 +76,7 @@ class QuerysetProxy: """ self._queryset = value - def _assign_child_to_parent(self, child: Optional["Model"]) -> None: + def _assign_child_to_parent(self, child: Optional["T"]) -> None: """ Registers child in parents RelationManager. @@ -83,7 +89,7 @@ class QuerysetProxy: setattr(owner, rel_name, child) def _register_related( - self, child: Union["Model", Sequence[Optional["Model"]]] + self, child: Union["T", Sequence[Optional["T"]]] ) -> None: """ Registers child/ children in parents RelationManager. @@ -96,7 +102,7 @@ class QuerysetProxy: self._assign_child_to_parent(subchild) else: assert isinstance(child, ormar.Model) - child = cast("Model", child) + child = cast("T", child) self._assign_child_to_parent(child) def _clean_items_on_load(self) -> None: @@ -107,7 +113,7 @@ class QuerysetProxy: for item in self.relation.related_models[:]: self.relation.remove(item) - async def create_through_instance(self, child: "Model", **kwargs: Any) -> None: + async def create_through_instance(self, child: "T", **kwargs: Any) -> None: """ Crete a through model instance in the database for m2m relations. @@ -129,7 +135,7 @@ class QuerysetProxy: ) await model_cls(**final_kwargs).save() - async def update_through_instance(self, child: "Model", **kwargs: Any) -> None: + async def update_through_instance(self, child: "T", **kwargs: Any) -> None: """ Updates a through model instance in the database for m2m relations. @@ -145,7 +151,7 @@ class QuerysetProxy: through_model = await model_cls.objects.get(**rel_kwargs) await through_model.update(**kwargs) - async def delete_through_instance(self, child: "Model") -> None: + async def delete_through_instance(self, child: "T") -> None: """ Removes through model instance from the database for m2m relations. @@ -254,7 +260,7 @@ class QuerysetProxy: ) return await queryset.delete(**kwargs) # type: ignore - async def first(self, **kwargs: Any) -> "Model": + async def first(self, **kwargs: Any) -> "T": """ Gets the first row from the db ordered by primary key column ascending. @@ -272,7 +278,7 @@ class QuerysetProxy: self._register_related(first) return first - async def get(self, **kwargs: Any) -> "Model": + async def get(self, **kwargs: Any) -> "T": """ Get's the first row from the db meeting the criteria set by kwargs. @@ -296,7 +302,7 @@ class QuerysetProxy: self._register_related(get) return get - async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 + async def all(self, **kwargs: Any) -> Sequence[Optional["T"]]: # noqa: A003 """ Returns all rows from a database for given model for set filter options. @@ -318,7 +324,7 @@ class QuerysetProxy: self._register_related(all_items) return all_items - async def create(self, **kwargs: Any) -> "Model": + async def create(self, **kwargs: Any) -> "T": """ Creates the model instance, saves it in a database and returns the updates model (with pk populated if not passed and autoincrement is set). @@ -375,7 +381,7 @@ class QuerysetProxy: ) return len(children) - async def get_or_create(self, **kwargs: Any) -> "Model": + async def get_or_create(self, **kwargs: Any) -> "T": """ Combination of create and get methods. @@ -393,7 +399,7 @@ class QuerysetProxy: except ormar.NoMatch: return await self.create(**kwargs) - async def update_or_create(self, **kwargs: Any) -> "Model": + async def update_or_create(self, **kwargs: Any) -> "T": """ Updates the model, or in case there is no match in database creates a new one. @@ -412,7 +418,7 @@ class QuerysetProxy: model = await self.queryset.get(pk=kwargs[pk_name]) return await model.update(**kwargs) - def filter(self, *args: Any, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 + def filter(self, *args: Any, **kwargs: Any) -> "QuerysetProxy[T]": # noqa: A003, A001 """ Allows you to filter by any `Model` attribute/field as well as to fetch instances, with a filter across an FK relationship. @@ -443,9 +449,9 @@ class QuerysetProxy: :rtype: QuerysetProxy """ queryset = self.queryset.filter(*args, **kwargs) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_, to=self.to, qryset=queryset) - def exclude(self, *args: Any, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 + def exclude(self, *args: Any, **kwargs: Any) -> "QuerysetProxy[T]": # noqa: A003, A001 """ Works exactly the same as filter and all modifiers (suffixes) are the same, but returns a *not* condition. @@ -467,9 +473,9 @@ class QuerysetProxy: :rtype: QuerysetProxy """ queryset = self.queryset.exclude(*args, **kwargs) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset) - def select_related(self, related: Union[List, str]) -> "QuerysetProxy": + def select_related(self, related: Union[List, str]) -> "QuerysetProxy[T]": """ Allows to prefetch related models during the same query. @@ -489,9 +495,9 @@ class QuerysetProxy: :rtype: QuerysetProxy """ queryset = self.queryset.select_related(related) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset) - def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy": + def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy[T]": """ Allows to prefetch related models during query - but opposite to `select_related` each subsequent model is fetched in a separate database query. @@ -512,9 +518,9 @@ class QuerysetProxy: :rtype: QuerysetProxy """ queryset = self.queryset.prefetch_related(related) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset) - def paginate(self, page: int, page_size: int = 20) -> "QuerysetProxy": + def paginate(self, page: int, page_size: int = 20) -> "QuerysetProxy[T]": """ You can paginate the result which is a combination of offset and limit clauses. Limit is set to page size and offset is set to (page-1) * page_size. @@ -529,9 +535,9 @@ class QuerysetProxy: :rtype: QuerySet """ queryset = self.queryset.paginate(page=page, page_size=page_size) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset) - def limit(self, limit_count: int) -> "QuerysetProxy": + def limit(self, limit_count: int) -> "QuerysetProxy[T]": """ You can limit the results to desired number of parent models. @@ -543,9 +549,9 @@ class QuerysetProxy: :rtype: QuerysetProxy """ queryset = self.queryset.limit(limit_count) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset) - def offset(self, offset: int) -> "QuerysetProxy": + def offset(self, offset: int) -> "QuerysetProxy[T]": """ You can also offset the results by desired number of main models. @@ -557,9 +563,9 @@ class QuerysetProxy: :rtype: QuerysetProxy """ queryset = self.queryset.offset(offset) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset) - def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy[T]": """ With `fields()` you can select subset of model columns to limit the data load. @@ -605,9 +611,9 @@ class QuerysetProxy: :rtype: QuerysetProxy """ queryset = self.queryset.fields(columns) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset) - def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy[T]": """ With `exclude_fields()` you can select subset of model columns that will be excluded to limit the data load. @@ -637,9 +643,9 @@ class QuerysetProxy: :rtype: QuerysetProxy """ queryset = self.queryset.exclude_fields(columns=columns) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset) - def order_by(self, columns: Union[List, str]) -> "QuerysetProxy": + def order_by(self, columns: Union[List, str]) -> "QuerysetProxy[T]": """ With `order_by()` you can order the results from database based on your choice of fields. @@ -674,4 +680,4 @@ class QuerysetProxy: :rtype: QuerysetProxy """ queryset = self.queryset.order_by(columns) - return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset) diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py index bb7abd1..3d30423 100644 --- a/ormar/relations/relation.py +++ b/ormar/relations/relation.py @@ -63,7 +63,7 @@ class Relation: self._through = through self.field_name: str = field_name self.related_models: Optional[Union[RelationProxy, "Model"]] = ( - RelationProxy(relation=self, type_=type_, field_name=field_name) + RelationProxy(relation=self, type_=type_, to=to, field_name=field_name) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) else None ) @@ -94,6 +94,7 @@ class Relation: self.related_models = RelationProxy( relation=self, type_=self._type, + to=self.to, field_name=self.field_name, data_=cleaned_data, ) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 20932b8..c2b39a1 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Generic, Optional, TYPE_CHECKING, Type, TypeVar import ormar from ormar.exceptions import NoMatch, RelationshipInstanceError @@ -6,11 +6,14 @@ from ormar.relations.querysetproxy import QuerysetProxy if TYPE_CHECKING: # pragma no cover from ormar import Model, RelationType + from ormar.models import T from ormar.relations import Relation from ormar.queryset import QuerySet +else: + T = TypeVar("T", bound="Model") -class RelationProxy(list): +class RelationProxy(Generic[T], list): """ Proxy of the Relation that is a list with special methods. """ @@ -19,6 +22,7 @@ class RelationProxy(list): self, relation: "Relation", type_: "RelationType", + to: Type["T"], field_name: str, data_: Any = None, ) -> None: @@ -28,7 +32,7 @@ class RelationProxy(list): self.field_name = field_name self._owner: "Model" = self.relation.manager.owner self.queryset_proxy: QuerysetProxy = QuerysetProxy( - relation=self.relation, type_=type_ + relation=self.relation, to=to, type_=type_ ) self._related_field_name: Optional[str] = None @@ -48,6 +52,9 @@ class RelationProxy(list): return self._related_field_name + def __getitem__(self, item) -> "T": # type: ignore + return super().__getitem__(item) + def __getattribute__(self, item: str) -> Any: """ Since some QuerySetProxy methods overwrite builtin list methods we @@ -63,7 +70,7 @@ class RelationProxy(list): return getattr(self.queryset_proxy, item) return super().__getattribute__(item) - def __getattr__(self, item: str) -> Any: + def __getattr__(self, item: str) -> "T": """ Delegates calls for non existing attributes to QuerySetProxy. diff --git a/tests/test_types.py b/tests/test_types.py index e06daf1..29adfc1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -86,11 +86,6 @@ async def test_types() -> None: reveal_type(publishers) # many to many reveal_type(publishers[0]) # item in m2m list # getting relation without __getattribute__ - to_model = Publisher.Meta.model_fields['authors'].to - reveal_type( - publisher2._extract_related_model_instead_of_field("authors") - ) # TODO: wrong - reveal_type(publisher.Meta.model_fields['authors'].to) # TODO: wrong reveal_type(authors) # reverse many to many # TODO: wrong reveal_type(book2) # queryset get reveal_type(books) # queryset all