improve relation_proxy types

This commit is contained in:
collerek
2021-03-19 17:13:59 +01:00
parent 929e979d37
commit 9c091afe35
6 changed files with 64 additions and 54 deletions

View File

@ -8,7 +8,8 @@ from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField, validate_not_allowed_fields
if TYPE_CHECKING: # pragma no cover
from ormar.models import Model
from ormar.models import Model, T
from ormar.relations.relation_proxy import RelationProxy
if sys.version_info < (3, 7):
ToType = Type["Model"]
@ -58,14 +59,14 @@ def populate_m2m_params_based_on_to_model(
def ManyToMany(
to: "ToType",
to: Type["T"],
through: Optional["ToType"] = None,
*,
name: str = None,
unique: bool = False,
virtual: bool = False,
**kwargs: Any,
):
) -> "RelationProxy[T]":
"""
Despite a name it's a function that returns constructed ManyToManyField.
This function is actually used in model declaration

View File

@ -101,7 +101,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None:
"""
related_name = model_field.get_related_name()
if model_field.is_multi:
model_field.to.Meta.model_fields[related_name] = ManyToMany(
model_field.to.Meta.model_fields[related_name] = ManyToMany( # type: ignore
model_field.owner,
through=model_field.through,
name=related_name,

View File

@ -2,13 +2,13 @@ from _weakref import CallableProxyType
from typing import ( # noqa: I100, I201
Any,
Dict,
List,
Generic, List,
MutableSequence,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Union,
Type, TypeVar, Union,
cast,
)
@ -18,12 +18,14 @@ from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation
from ormar.models import Model
from ormar.models import Model, T
from ormar.queryset import QuerySet
from ormar import RelationType
else:
T = TypeVar("T", bound="Model")
class QuerysetProxy:
class QuerysetProxy(Generic[T]):
"""
Exposes QuerySet methods on relations, but also handles creating and removing
of through Models for m2m relations.
@ -33,16 +35,20 @@ class QuerysetProxy:
relation: "Relation"
def __init__(
self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None
self, relation: "Relation",
to: Type["T"],
type_: "RelationType",
qryset: "QuerySet[T]" = None
) -> None:
self.relation: Relation = relation
self._queryset: Optional["QuerySet"] = qryset
self._queryset: Optional["QuerySet[T]"] = qryset
self.type_: "RelationType" = type_
self._owner: Union[CallableProxyType, "Model"] = self.relation.manager.owner
self.related_field_name = self._owner.Meta.model_fields[
self.relation.field_name
].get_related_name()
self.related_field = self.relation.to.Meta.model_fields[self.related_field_name]
self.to: Type[T] = to
self.related_field = to.Meta.model_fields[self.related_field_name]
self.owner_pk_value = self._owner.pk
self.through_model_name = (
self.related_field.through.get_name()
@ -51,7 +57,7 @@ class QuerysetProxy:
)
@property
def queryset(self) -> "QuerySet":
def queryset(self) -> "QuerySet[T]":
"""
Returns queryset if it's set, AttributeError otherwise.
:return: QuerySet
@ -70,7 +76,7 @@ class QuerysetProxy:
"""
self._queryset = value
def _assign_child_to_parent(self, child: Optional["Model"]) -> None:
def _assign_child_to_parent(self, child: Optional["T"]) -> None:
"""
Registers child in parents RelationManager.
@ -83,7 +89,7 @@ class QuerysetProxy:
setattr(owner, rel_name, child)
def _register_related(
self, child: Union["Model", Sequence[Optional["Model"]]]
self, child: Union["T", Sequence[Optional["T"]]]
) -> None:
"""
Registers child/ children in parents RelationManager.
@ -96,7 +102,7 @@ class QuerysetProxy:
self._assign_child_to_parent(subchild)
else:
assert isinstance(child, ormar.Model)
child = cast("Model", child)
child = cast("T", child)
self._assign_child_to_parent(child)
def _clean_items_on_load(self) -> None:
@ -107,7 +113,7 @@ class QuerysetProxy:
for item in self.relation.related_models[:]:
self.relation.remove(item)
async def create_through_instance(self, child: "Model", **kwargs: Any) -> None:
async def create_through_instance(self, child: "T", **kwargs: Any) -> None:
"""
Crete a through model instance in the database for m2m relations.
@ -129,7 +135,7 @@ class QuerysetProxy:
)
await model_cls(**final_kwargs).save()
async def update_through_instance(self, child: "Model", **kwargs: Any) -> None:
async def update_through_instance(self, child: "T", **kwargs: Any) -> None:
"""
Updates a through model instance in the database for m2m relations.
@ -145,7 +151,7 @@ class QuerysetProxy:
through_model = await model_cls.objects.get(**rel_kwargs)
await through_model.update(**kwargs)
async def delete_through_instance(self, child: "Model") -> None:
async def delete_through_instance(self, child: "T") -> None:
"""
Removes through model instance from the database for m2m relations.
@ -254,7 +260,7 @@ class QuerysetProxy:
)
return await queryset.delete(**kwargs) # type: ignore
async def first(self, **kwargs: Any) -> "Model":
async def first(self, **kwargs: Any) -> "T":
"""
Gets the first row from the db ordered by primary key column ascending.
@ -272,7 +278,7 @@ class QuerysetProxy:
self._register_related(first)
return first
async def get(self, **kwargs: Any) -> "Model":
async def get(self, **kwargs: Any) -> "T":
"""
Get's the first row from the db meeting the criteria set by kwargs.
@ -296,7 +302,7 @@ class QuerysetProxy:
self._register_related(get)
return get
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
async def all(self, **kwargs: Any) -> Sequence[Optional["T"]]: # noqa: A003
"""
Returns all rows from a database for given model for set filter options.
@ -318,7 +324,7 @@ class QuerysetProxy:
self._register_related(all_items)
return all_items
async def create(self, **kwargs: Any) -> "Model":
async def create(self, **kwargs: Any) -> "T":
"""
Creates the model instance, saves it in a database and returns the updates model
(with pk populated if not passed and autoincrement is set).
@ -375,7 +381,7 @@ class QuerysetProxy:
)
return len(children)
async def get_or_create(self, **kwargs: Any) -> "Model":
async def get_or_create(self, **kwargs: Any) -> "T":
"""
Combination of create and get methods.
@ -393,7 +399,7 @@ class QuerysetProxy:
except ormar.NoMatch:
return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model":
async def update_or_create(self, **kwargs: Any) -> "T":
"""
Updates the model, or in case there is no match in database creates a new one.
@ -412,7 +418,7 @@ class QuerysetProxy:
model = await self.queryset.get(pk=kwargs[pk_name])
return await model.update(**kwargs)
def filter(self, *args: Any, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
def filter(self, *args: Any, **kwargs: Any) -> "QuerysetProxy[T]": # noqa: A003, A001
"""
Allows you to filter by any `Model` attribute/field
as well as to fetch instances, with a filter across an FK relationship.
@ -443,9 +449,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy
"""
queryset = self.queryset.filter(*args, **kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_, to=self.to, qryset=queryset)
def exclude(self, *args: Any, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
def exclude(self, *args: Any, **kwargs: Any) -> "QuerysetProxy[T]": # noqa: A003, A001
"""
Works exactly the same as filter and all modifiers (suffixes) are the same,
but returns a *not* condition.
@ -467,9 +473,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy
"""
queryset = self.queryset.exclude(*args, **kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def select_related(self, related: Union[List, str]) -> "QuerysetProxy":
def select_related(self, related: Union[List, str]) -> "QuerysetProxy[T]":
"""
Allows to prefetch related models during the same query.
@ -489,9 +495,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy
"""
queryset = self.queryset.select_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy":
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy[T]":
"""
Allows to prefetch related models during query - but opposite to
`select_related` each subsequent model is fetched in a separate database query.
@ -512,9 +518,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy
"""
queryset = self.queryset.prefetch_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def paginate(self, page: int, page_size: int = 20) -> "QuerysetProxy":
def paginate(self, page: int, page_size: int = 20) -> "QuerysetProxy[T]":
"""
You can paginate the result which is a combination of offset and limit clauses.
Limit is set to page size and offset is set to (page-1) * page_size.
@ -529,9 +535,9 @@ class QuerysetProxy:
:rtype: QuerySet
"""
queryset = self.queryset.paginate(page=page, page_size=page_size)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def limit(self, limit_count: int) -> "QuerysetProxy":
def limit(self, limit_count: int) -> "QuerysetProxy[T]":
"""
You can limit the results to desired number of parent models.
@ -543,9 +549,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy
"""
queryset = self.queryset.limit(limit_count)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def offset(self, offset: int) -> "QuerysetProxy":
def offset(self, offset: int) -> "QuerysetProxy[T]":
"""
You can also offset the results by desired number of main models.
@ -557,9 +563,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy
"""
queryset = self.queryset.offset(offset)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy[T]":
"""
With `fields()` you can select subset of model columns to limit the data load.
@ -605,9 +611,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy
"""
queryset = self.queryset.fields(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy[T]":
"""
With `exclude_fields()` you can select subset of model columns that will
be excluded to limit the data load.
@ -637,9 +643,9 @@ class QuerysetProxy:
:rtype: QuerysetProxy
"""
queryset = self.queryset.exclude_fields(columns=columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy":
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy[T]":
"""
With `order_by()` you can order the results from database based on your
choice of fields.
@ -674,4 +680,4 @@ class QuerysetProxy:
:rtype: QuerysetProxy
"""
queryset = self.queryset.order_by(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
return self.__class__(relation=self.relation, type_=self.type_,to=self.to, qryset=queryset)

View File

@ -63,7 +63,7 @@ class Relation:
self._through = through
self.field_name: str = field_name
self.related_models: Optional[Union[RelationProxy, "Model"]] = (
RelationProxy(relation=self, type_=type_, field_name=field_name)
RelationProxy(relation=self, type_=type_, to=to, field_name=field_name)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None
)
@ -94,6 +94,7 @@ class Relation:
self.related_models = RelationProxy(
relation=self,
type_=self._type,
to=self.to,
field_name=self.field_name,
data_=cleaned_data,
)

View File

@ -1,4 +1,4 @@
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, Generic, Optional, TYPE_CHECKING, Type, TypeVar
import ormar
from ormar.exceptions import NoMatch, RelationshipInstanceError
@ -6,11 +6,14 @@ from ormar.relations.querysetproxy import QuerysetProxy
if TYPE_CHECKING: # pragma no cover
from ormar import Model, RelationType
from ormar.models import T
from ormar.relations import Relation
from ormar.queryset import QuerySet
else:
T = TypeVar("T", bound="Model")
class RelationProxy(list):
class RelationProxy(Generic[T], list):
"""
Proxy of the Relation that is a list with special methods.
"""
@ -19,6 +22,7 @@ class RelationProxy(list):
self,
relation: "Relation",
type_: "RelationType",
to: Type["T"],
field_name: str,
data_: Any = None,
) -> None:
@ -28,7 +32,7 @@ class RelationProxy(list):
self.field_name = field_name
self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy: QuerysetProxy = QuerysetProxy(
relation=self.relation, type_=type_
relation=self.relation, to=to, type_=type_
)
self._related_field_name: Optional[str] = None
@ -48,6 +52,9 @@ class RelationProxy(list):
return self._related_field_name
def __getitem__(self, item) -> "T": # type: ignore
return super().__getitem__(item)
def __getattribute__(self, item: str) -> Any:
"""
Since some QuerySetProxy methods overwrite builtin list methods we
@ -63,7 +70,7 @@ class RelationProxy(list):
return getattr(self.queryset_proxy, item)
return super().__getattribute__(item)
def __getattr__(self, item: str) -> Any:
def __getattr__(self, item: str) -> "T":
"""
Delegates calls for non existing attributes to QuerySetProxy.

View File

@ -86,11 +86,6 @@ async def test_types() -> None:
reveal_type(publishers) # many to many
reveal_type(publishers[0]) # item in m2m list
# getting relation without __getattribute__
to_model = Publisher.Meta.model_fields['authors'].to
reveal_type(
publisher2._extract_related_model_instead_of_field("authors")
) # TODO: wrong
reveal_type(publisher.Meta.model_fields['authors'].to) # TODO: wrong
reveal_type(authors) # reverse many to many # TODO: wrong
reveal_type(book2) # queryset get
reveal_type(books) # queryset all