improve relation_proxy types

This commit is contained in:
collerek
2021-03-19 17:13:59 +01:00
parent 929e979d37
commit 9c091afe35
6 changed files with 64 additions and 54 deletions

View File

@ -8,7 +8,8 @@ from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField, validate_not_allowed_fields from ormar.fields.foreign_key import ForeignKeyField, validate_not_allowed_fields
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model, T
from ormar.relations.relation_proxy import RelationProxy
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
ToType = Type["Model"] ToType = Type["Model"]
@ -58,14 +59,14 @@ def populate_m2m_params_based_on_to_model(
def ManyToMany( def ManyToMany(
to: "ToType", to: Type["T"],
through: Optional["ToType"] = None, through: Optional["ToType"] = None,
*, *,
name: str = None, name: str = None,
unique: bool = False, unique: bool = False,
virtual: bool = False, virtual: bool = False,
**kwargs: Any, **kwargs: Any,
): ) -> "RelationProxy[T]":
""" """
Despite a name it's a function that returns constructed ManyToManyField. Despite a name it's a function that returns constructed ManyToManyField.
This function is actually used in model declaration This function is actually used in model declaration

View File

@ -101,7 +101,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None:
""" """
related_name = model_field.get_related_name() related_name = model_field.get_related_name()
if model_field.is_multi: if model_field.is_multi:
model_field.to.Meta.model_fields[related_name] = ManyToMany( model_field.to.Meta.model_fields[related_name] = ManyToMany( # type: ignore
model_field.owner, model_field.owner,
through=model_field.through, through=model_field.through,
name=related_name, name=related_name,

View File

@ -2,13 +2,13 @@ from _weakref import CallableProxyType
from typing import ( # noqa: I100, I201 from typing import ( # noqa: I100, I201
Any, Any,
Dict, Dict,
List, Generic, List,
MutableSequence, MutableSequence,
Optional, Optional,
Sequence, Sequence,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Union, Type, TypeVar, Union,
cast, cast,
) )
@ -18,12 +18,14 @@ from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation from ormar.relations import Relation
from ormar.models import Model from ormar.models import Model, T
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
from ormar import RelationType from ormar import RelationType
else:
T = TypeVar("T", bound="Model")
class QuerysetProxy: class QuerysetProxy(Generic[T]):
""" """
Exposes QuerySet methods on relations, but also handles creating and removing Exposes QuerySet methods on relations, but also handles creating and removing
of through Models for m2m relations. of through Models for m2m relations.
@ -33,16 +35,20 @@ class QuerysetProxy:
relation: "Relation" relation: "Relation"
def __init__( def __init__(
self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None self, relation: "Relation",
to: Type["T"],
type_: "RelationType",
qryset: "QuerySet[T]" = None
) -> None: ) -> None:
self.relation: Relation = relation self.relation: Relation = relation
self._queryset: Optional["QuerySet"] = qryset self._queryset: Optional["QuerySet[T]"] = qryset
self.type_: "RelationType" = type_ self.type_: "RelationType" = type_
self._owner: Union[CallableProxyType, "Model"] = self.relation.manager.owner self._owner: Union[CallableProxyType, "Model"] = self.relation.manager.owner
self.related_field_name = self._owner.Meta.model_fields[ self.related_field_name = self._owner.Meta.model_fields[
self.relation.field_name self.relation.field_name
].get_related_name() ].get_related_name()
self.related_field = self.relation.to.Meta.model_fields[self.related_field_name] self.to: Type[T] = to
self.related_field = to.Meta.model_fields[self.related_field_name]
self.owner_pk_value = self._owner.pk self.owner_pk_value = self._owner.pk
self.through_model_name = ( self.through_model_name = (
self.related_field.through.get_name() self.related_field.through.get_name()
@ -51,7 +57,7 @@ class QuerysetProxy:
) )
@property @property
def queryset(self) -> "QuerySet": def queryset(self) -> "QuerySet[T]":
""" """
Returns queryset if it's set, AttributeError otherwise. Returns queryset if it's set, AttributeError otherwise.
:return: QuerySet :return: QuerySet
@ -70,7 +76,7 @@ class QuerysetProxy:
""" """
self._queryset = value self._queryset = value
def _assign_child_to_parent(self, child: Optional["Model"]) -> None: def _assign_child_to_parent(self, child: Optional["T"]) -> None:
""" """
Registers child in parents RelationManager. Registers child in parents RelationManager.
@ -83,7 +89,7 @@ class QuerysetProxy:
setattr(owner, rel_name, child) setattr(owner, rel_name, child)
def _register_related( def _register_related(
self, child: Union["Model", Sequence[Optional["Model"]]] self, child: Union["T", Sequence[Optional["T"]]]
) -> None: ) -> None:
""" """
Registers child/ children in parents RelationManager. Registers child/ children in parents RelationManager.
@ -96,7 +102,7 @@ class QuerysetProxy:
self._assign_child_to_parent(subchild) self._assign_child_to_parent(subchild)
else: else:
assert isinstance(child, ormar.Model) assert isinstance(child, ormar.Model)
child = cast("Model", child) child = cast("T", child)
self._assign_child_to_parent(child) self._assign_child_to_parent(child)
def _clean_items_on_load(self) -> None: def _clean_items_on_load(self) -> None:
@ -107,7 +113,7 @@ class QuerysetProxy:
for item in self.relation.related_models[:]: for item in self.relation.related_models[:]:
self.relation.remove(item) self.relation.remove(item)
async def create_through_instance(self, child: "Model", **kwargs: Any) -> None: async def create_through_instance(self, child: "T", **kwargs: Any) -> None:
""" """
Crete a through model instance in the database for m2m relations. Crete a through model instance in the database for m2m relations.
@ -129,7 +135,7 @@ class QuerysetProxy:
) )
await model_cls(**final_kwargs).save() await model_cls(**final_kwargs).save()
async def update_through_instance(self, child: "Model", **kwargs: Any) -> None: async def update_through_instance(self, child: "T", **kwargs: Any) -> None:
""" """
Updates a through model instance in the database for m2m relations. Updates a through model instance in the database for m2m relations.
@ -145,7 +151,7 @@ class QuerysetProxy:
through_model = await model_cls.objects.get(**rel_kwargs) through_model = await model_cls.objects.get(**rel_kwargs)
await through_model.update(**kwargs) await through_model.update(**kwargs)
async def delete_through_instance(self, child: "Model") -> None: async def delete_through_instance(self, child: "T") -> None:
""" """
Removes through model instance from the database for m2m relations. Removes through model instance from the database for m2m relations.
@ -254,7 +260,7 @@ class QuerysetProxy:
) )
return await queryset.delete(**kwargs) # type: ignore return await queryset.delete(**kwargs) # type: ignore
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> "T":
""" """
Gets the first row from the db ordered by primary key column ascending. Gets the first row from the db ordered by primary key column ascending.
@ -272,7 +278,7 @@ class QuerysetProxy:
self._register_related(first) self._register_related(first)
return first return first
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "T":
""" """
Get's the first row from the db meeting the criteria set by kwargs. Get's the first row from the db meeting the criteria set by kwargs.
@ -296,7 +302,7 @@ class QuerysetProxy:
self._register_related(get) self._register_related(get)
return get return get
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> Sequence[Optional["T"]]: # noqa: A003
""" """
Returns all rows from a database for given model for set filter options. Returns all rows from a database for given model for set filter options.
@ -318,7 +324,7 @@ class QuerysetProxy:
self._register_related(all_items) self._register_related(all_items)
return all_items return all_items
async def create(self, **kwargs: Any) -> "Model": async def create(self, **kwargs: Any) -> "T":
""" """
Creates the model instance, saves it in a database and returns the updates model Creates the model instance, saves it in a database and returns the updates model
(with pk populated if not passed and autoincrement is set). (with pk populated if not passed and autoincrement is set).
@ -375,7 +381,7 @@ class QuerysetProxy:
) )
return len(children) return len(children)
async def get_or_create(self, **kwargs: Any) -> "Model": async def get_or_create(self, **kwargs: Any) -> "T":
""" """
Combination of create and get methods. Combination of create and get methods.
@ -393,7 +399,7 @@ class QuerysetProxy:
except ormar.NoMatch: except ormar.NoMatch:
return await self.create(**kwargs) return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model": async def update_or_create(self, **kwargs: Any) -> "T":
""" """
Updates the model, or in case there is no match in database creates a new one. Updates the model, or in case there is no match in database creates a new one.
@ -412,7 +418,7 @@ class QuerysetProxy:
model = await self.queryset.get(pk=kwargs[pk_name]) model = await self.queryset.get(pk=kwargs[pk_name])
return await model.update(**kwargs) return await model.update(**kwargs)
def filter(self, *args: Any, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 def filter(self, *args: Any, **kwargs: Any) -> "QuerysetProxy[T]": # noqa: A003, A001
""" """
Allows you to filter by any `Model` attribute/field Allows you to filter by any `Model` attribute/field
as well as to fetch instances, with a filter across an FK relationship. as well as to fetch instances, with a filter across an FK relationship.
@ -443,9 +449,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.filter(*args, **kwargs) queryset = self.queryset.filter(*args, **kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_, to=self.to, qryset=queryset)
def exclude(self, *args: Any, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 def exclude(self, *args: Any, **kwargs: Any) -> "QuerysetProxy[T]": # noqa: A003, A001
""" """
Works exactly the same as filter and all modifiers (suffixes) are the same, Works exactly the same as filter and all modifiers (suffixes) are the same,
but returns a *not* condition. but returns a *not* condition.
@ -467,9 +473,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.exclude(*args, **kwargs) queryset = self.queryset.exclude(*args, **kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def select_related(self, related: Union[List, str]) -> "QuerysetProxy": def select_related(self, related: Union[List, str]) -> "QuerysetProxy[T]":
""" """
Allows to prefetch related models during the same query. Allows to prefetch related models during the same query.
@ -489,9 +495,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.select_related(related) queryset = self.queryset.select_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy": def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy[T]":
""" """
Allows to prefetch related models during query - but opposite to Allows to prefetch related models during query - but opposite to
`select_related` each subsequent model is fetched in a separate database query. `select_related` each subsequent model is fetched in a separate database query.
@ -512,9 +518,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.prefetch_related(related) queryset = self.queryset.prefetch_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def paginate(self, page: int, page_size: int = 20) -> "QuerysetProxy": def paginate(self, page: int, page_size: int = 20) -> "QuerysetProxy[T]":
""" """
You can paginate the result which is a combination of offset and limit clauses. You can paginate the result which is a combination of offset and limit clauses.
Limit is set to page size and offset is set to (page-1) * page_size. Limit is set to page size and offset is set to (page-1) * page_size.
@ -529,9 +535,9 @@ class QuerysetProxy:
:rtype: QuerySet :rtype: QuerySet
""" """
queryset = self.queryset.paginate(page=page, page_size=page_size) queryset = self.queryset.paginate(page=page, page_size=page_size)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def limit(self, limit_count: int) -> "QuerysetProxy": def limit(self, limit_count: int) -> "QuerysetProxy[T]":
""" """
You can limit the results to desired number of parent models. You can limit the results to desired number of parent models.
@ -543,9 +549,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.limit(limit_count) queryset = self.queryset.limit(limit_count)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def offset(self, offset: int) -> "QuerysetProxy": def offset(self, offset: int) -> "QuerysetProxy[T]":
""" """
You can also offset the results by desired number of main models. You can also offset the results by desired number of main models.
@ -557,9 +563,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.offset(offset) queryset = self.queryset.offset(offset)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy[T]":
""" """
With `fields()` you can select subset of model columns to limit the data load. With `fields()` you can select subset of model columns to limit the data load.
@ -605,9 +611,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.fields(columns) queryset = self.queryset.fields(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy[T]":
""" """
With `exclude_fields()` you can select subset of model columns that will With `exclude_fields()` you can select subset of model columns that will
be excluded to limit the data load. be excluded to limit the data load.
@ -637,9 +643,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.exclude_fields(columns=columns) queryset = self.queryset.exclude_fields(columns=columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy": def order_by(self, columns: Union[List, str]) -> "QuerysetProxy[T]":
""" """
With `order_by()` you can order the results from database based on your With `order_by()` you can order the results from database based on your
choice of fields. choice of fields.
@ -674,4 +680,4 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.order_by(columns) queryset = self.queryset.order_by(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)

View File

@ -63,7 +63,7 @@ class Relation:
self._through = through self._through = through
self.field_name: str = field_name self.field_name: str = field_name
self.related_models: Optional[Union[RelationProxy, "Model"]] = ( self.related_models: Optional[Union[RelationProxy, "Model"]] = (
RelationProxy(relation=self, type_=type_, field_name=field_name) RelationProxy(relation=self, type_=type_, to=to, field_name=field_name)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None else None
) )
@ -94,6 +94,7 @@ class Relation:
self.related_models = RelationProxy( self.related_models = RelationProxy(
relation=self, relation=self,
type_=self._type, type_=self._type,
to=self.to,
field_name=self.field_name, field_name=self.field_name,
data_=cleaned_data, data_=cleaned_data,
) )

View File

@ -1,4 +1,4 @@
from typing import Any, Optional, TYPE_CHECKING from typing import Any, Generic, Optional, TYPE_CHECKING, Type, TypeVar
import ormar import ormar
from ormar.exceptions import NoMatch, RelationshipInstanceError from ormar.exceptions import NoMatch, RelationshipInstanceError
@ -6,11 +6,14 @@ from ormar.relations.querysetproxy import QuerysetProxy
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model, RelationType from ormar import Model, RelationType
from ormar.models import T
from ormar.relations import Relation from ormar.relations import Relation
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
else:
T = TypeVar("T", bound="Model")
class RelationProxy(list): class RelationProxy(Generic[T], list):
""" """
Proxy of the Relation that is a list with special methods. Proxy of the Relation that is a list with special methods.
""" """
@ -19,6 +22,7 @@ class RelationProxy(list):
self, self,
relation: "Relation", relation: "Relation",
type_: "RelationType", type_: "RelationType",
to: Type["T"],
field_name: str, field_name: str,
data_: Any = None, data_: Any = None,
) -> None: ) -> None:
@ -28,7 +32,7 @@ class RelationProxy(list):
self.field_name = field_name self.field_name = field_name
self._owner: "Model" = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy: QuerysetProxy = QuerysetProxy( self.queryset_proxy: QuerysetProxy = QuerysetProxy(
relation=self.relation, type_=type_ relation=self.relation, to=to, type_=type_
) )
self._related_field_name: Optional[str] = None self._related_field_name: Optional[str] = None
@ -48,6 +52,9 @@ class RelationProxy(list):
return self._related_field_name return self._related_field_name
def __getitem__(self, item) -> "T": # type: ignore
return super().__getitem__(item)
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
""" """
Since some QuerySetProxy methods overwrite builtin list methods we Since some QuerySetProxy methods overwrite builtin list methods we
@ -63,7 +70,7 @@ class RelationProxy(list):
return getattr(self.queryset_proxy, item) return getattr(self.queryset_proxy, item)
return super().__getattribute__(item) return super().__getattribute__(item)
def __getattr__(self, item: str) -> Any: def __getattr__(self, item: str) -> "T":
""" """
Delegates calls for non existing attributes to QuerySetProxy. Delegates calls for non existing attributes to QuerySetProxy.

View File

@ -86,11 +86,6 @@ async def test_types() -> None:
reveal_type(publishers) # many to many reveal_type(publishers) # many to many
reveal_type(publishers[0]) # item in m2m list reveal_type(publishers[0]) # item in m2m list
# getting relation without __getattribute__ # getting relation without __getattribute__
to_model = Publisher.Meta.model_fields['authors'].to
reveal_type(
publisher2._extract_related_model_instead_of_field("authors")
) # TODO: wrong
reveal_type(publisher.Meta.model_fields['authors'].to) # TODO: wrong
reveal_type(authors) # reverse many to many # TODO: wrong reveal_type(authors) # reverse many to many # TODO: wrong
reveal_type(book2) # queryset get reveal_type(book2) # queryset get
reveal_type(books) # queryset all reveal_type(books) # queryset all