expose querysetproxy on reverse of foreignkey (virtual fk), add additional methods from queryset to querysetproxy

This commit is contained in:
collerek
2020-12-01 08:27:08 +01:00
parent b939a02ce0
commit 61da7b4418
11 changed files with 605 additions and 68 deletions

View File

@ -6,6 +6,10 @@ class ModelDefinitionError(AsyncOrmException):
pass
class ModelError(AsyncOrmException):
pass
class ModelNotSet(AsyncOrmException):
pass

View File

@ -7,6 +7,7 @@ from typing import (
Dict,
List,
Mapping,
MutableSequence,
Optional,
Sequence,
Set,
@ -22,6 +23,7 @@ import sqlalchemy
from pydantic import BaseModel
import ormar # noqa I100
from ormar.exceptions import ModelError
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.excludable import Excludable
@ -93,16 +95,21 @@ class NewBaseModel(
if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk")
# build the models to set them and validate but don't register
new_kwargs = {
k: self._convert_json(
k,
self.Meta.model_fields[k].expand_relationship(
v, self, to_register=False
),
"dumps",
try:
new_kwargs = {
k: self._convert_json(
k,
self.Meta.model_fields[k].expand_relationship(
v, self, to_register=False
),
"dumps",
)
for k, v in kwargs.items()
}
except KeyError as e:
raise ModelError(
f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}"
)
for k, v in kwargs.items()
}
values, fields_set, validation_error = pydantic.validate_model(
self, new_kwargs # type: ignore
@ -249,7 +256,9 @@ class NewBaseModel(
@staticmethod
def _extract_nested_models_from_list(
models: List, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None],
models: MutableSequence,
include: Union[Set, Dict, None],
exclude: Union[Set, Dict, None],
) -> List:
result = []
for model in models:
@ -282,7 +291,7 @@ class NewBaseModel(
if self.Meta.model_fields[field].virtual and nested:
continue
nested_model = getattr(self, field)
if isinstance(nested_model, list):
if isinstance(nested_model, MutableSequence):
dict_instance[field] = self._extract_nested_models_from_list(
models=nested_model,
include=self._skip_ellipsis(include, field),
@ -308,7 +317,7 @@ class NewBaseModel(
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
nested: bool = False
nested: bool = False,
) -> "DictStrAny": # noqa: A003'
dict_instance = super().dict(
include=include,

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union
try:
from typing import Protocol
@ -6,14 +6,21 @@ except ImportError: # pragma: nocover
from typing_extensions import Protocol # type: ignore
if TYPE_CHECKING: # noqa: C901; #pragma nocover
from ormar import QuerySet, Model
from ormar import Model
from ormar.relations.querysetproxy import QuerysetProxy
class QuerySetProtocol(Protocol): # pragma: nocover
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003, A001
def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
...
def select_related(self, related: Union[List, str]) -> "QuerySet":
def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
...
def select_related(self, related: Union[List, str]) -> "QuerysetProxy":
...
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy":
...
async def exists(self) -> bool:
@ -25,10 +32,10 @@ class QuerySetProtocol(Protocol): # pragma: nocover
async def clear(self) -> int:
...
def limit(self, limit_count: int) -> "QuerySet":
def limit(self, limit_count: int) -> "QuerysetProxy":
...
def offset(self, offset: int) -> "QuerySet":
def offset(self, offset: int) -> "QuerysetProxy":
...
async def first(self, **kwargs: Any) -> "Model":
@ -44,3 +51,18 @@ class QuerySetProtocol(Protocol): # pragma: nocover
async def create(self, **kwargs: Any) -> "Model":
...
async def get_or_create(self, **kwargs: Any) -> "Model":
...
async def update_or_create(self, **kwargs: Any) -> "Model":
...
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
...
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
...
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy":
...

View File

@ -1,4 +1,15 @@
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union
from typing import (
Any,
Dict,
List,
MutableSequence,
Optional,
Sequence,
Set,
TYPE_CHECKING,
TypeVar,
Union,
)
import ormar
@ -6,6 +17,7 @@ if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation
from ormar.models import Model
from ormar.queryset import QuerySet
from ormar import RelationType
T = TypeVar("T", bound=Model)
@ -14,9 +26,17 @@ class QuerysetProxy(ormar.QuerySetProtocol):
if TYPE_CHECKING: # pragma no cover
relation: "Relation"
def __init__(self, relation: "Relation") -> None:
def __init__(
self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None
) -> None:
self.relation: Relation = relation
self._queryset: Optional["QuerySet"] = None
self._queryset: Optional["QuerySet"] = qryset
self.type_: "RelationType" = type_
self._owner: "Model" = self.relation.manager.owner
self.related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
self.owner_pk_value = self._owner.pk
@property
def queryset(self) -> "QuerySet":
@ -30,7 +50,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
def _assign_child_to_parent(self, child: Optional["T"]) -> None:
if child:
owner = self.relation._owner
owner = self._owner
rel_name = owner.resolve_relation_name(owner, child)
setattr(owner, rel_name, child)
@ -42,27 +62,26 @@ class QuerysetProxy(ormar.QuerySetProtocol):
assert isinstance(child, ormar.Model)
self._assign_child_to_parent(child)
def _clean_items_on_load(self) -> None:
if isinstance(self.relation.related_models, MutableSequence):
for item in self.relation.related_models[:]:
self.relation.remove(item)
async def create_through_instance(self, child: "T") -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
owner_column = self._owner.get_name()
child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child}
kwargs = {owner_column: self._owner, child_column: child}
await queryset.create(**kwargs)
async def delete_through_instance(self, child: "T") -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
owner_column = self._owner.get_name()
child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child}
kwargs = {owner_column: self._owner, child_column: child}
link_instance = await queryset.filter(**kwargs).get() # type: ignore
await link_instance.delete()
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
return self.queryset.filter(**kwargs)
def select_related(self, related: Union[List, str]) -> "QuerySet":
return self.queryset.select_related(related)
async def exists(self) -> bool:
return await self.queryset.exists()
@ -70,17 +89,16 @@ class QuerysetProxy(ormar.QuerySetProtocol):
return await self.queryset.count()
async def clear(self) -> int:
queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
kwargs = {owner_column: self.relation._owner}
if self.type_ == ormar.RelationType.MULTIPLE:
queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self._owner.get_name()
else:
queryset = ormar.QuerySet(model_cls=self.relation.to)
owner_column = self.related_field.name
kwargs = {owner_column: self._owner}
self._clean_items_on_load()
return await queryset.delete(**kwargs) # type: ignore
def limit(self, limit_count: int) -> "QuerySet":
return self.queryset.limit(limit_count)
def offset(self, offset: int) -> "QuerySet":
return self.queryset.offset(offset)
async def first(self, **kwargs: Any) -> "Model":
first = await self.queryset.first(**kwargs)
self._register_related(first)
@ -88,16 +106,72 @@ class QuerysetProxy(ormar.QuerySetProtocol):
async def get(self, **kwargs: Any) -> "Model":
get = await self.queryset.get(**kwargs)
self._clean_items_on_load()
self._register_related(get)
return get
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
all_items = await self.queryset.all(**kwargs)
self._clean_items_on_load()
self._register_related(all_items)
return all_items
async def create(self, **kwargs: Any) -> "Model":
create = await self.queryset.create(**kwargs)
self._register_related(create)
await self.create_through_instance(create)
return create
if self.type_ == ormar.RelationType.REVERSE:
kwargs[self.related_field.name] = self._owner
created = await self.queryset.create(**kwargs)
self._register_related(created)
if self.type_ == ormar.RelationType.MULTIPLE:
await self.create_through_instance(created)
return created
async def get_or_create(self, **kwargs: Any) -> "Model":
try:
return await self.get(**kwargs)
except ormar.NoMatch:
return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model":
pk_name = self.queryset.model_meta.pkname
if "pk" in kwargs:
kwargs[pk_name] = kwargs.pop("pk")
if pk_name not in kwargs or kwargs.get(pk_name) is None:
return await self.create(**kwargs)
model = await self.queryset.get(pk=kwargs[pk_name])
return await model.update(**kwargs)
def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
queryset = self.queryset.filter(**kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
queryset = self.queryset.exclude(**kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def select_related(self, related: Union[List, str]) -> "QuerysetProxy":
queryset = self.queryset.select_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy":
queryset = self.queryset.prefetch_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def limit(self, limit_count: int) -> "QuerysetProxy":
queryset = self.queryset.limit(limit_count)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def offset(self, offset: int) -> "QuerysetProxy":
queryset = self.queryset.offset(offset)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
queryset = self.queryset.fields(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
queryset = self.queryset.exclude_fields(columns=columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy":
queryset = self.queryset.order_by(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)

View File

@ -34,7 +34,7 @@ class Relation:
self.to: Type["T"] = to
self.through: Optional[Type["T"]] = through
self.related_models: Optional[Union[RelationProxy, "T"]] = (
RelationProxy(relation=self)
RelationProxy(relation=self, type_=type_)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None
)

View File

@ -65,8 +65,6 @@ class RelationsManager:
parent_relation = parent._orm._get(child_name)
if parent_relation:
# print('missing', child_name)
# parent_relation = register_missing_relation(parent, child, child_name)
parent_relation.add(child) # type: ignore
child_relation = child._orm._get(to_name)

View File

@ -5,17 +5,18 @@ from ormar.exceptions import NoMatch, RelationshipInstanceError
from ormar.relations.querysetproxy import QuerysetProxy
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar import Model, RelationType
from ormar.relations import Relation
from ormar.queryset import QuerySet
class RelationProxy(list):
def __init__(self, relation: "Relation") -> None:
super(RelationProxy, self).__init__()
self.relation: Relation = relation
def __init__(self, relation: "Relation", type_: "RelationType") -> None:
super().__init__()
self.relation: "Relation" = relation
self.type_: "RelationType" = type_
self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation)
self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_)
def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]:
@ -38,17 +39,19 @@ class RelationProxy(list):
)
def _set_queryset(self) -> "QuerySet":
owner_table = self.relation._owner.Meta.tablename
pkname = self.relation._owner.get_column_alias(self.relation._owner.Meta.pkname)
pk_value = self.relation._owner.pk
related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
pk_value = self._owner.pk
if not pk_value:
raise RelationshipInstanceError(
"You cannot query many to many relationship on unsaved model."
"You cannot query relationships from unsaved model."
)
kwargs = {f"{owner_table}__{pkname}": pk_value}
kwargs = {f"{related_field.get_alias()}__{pkname}": pk_value}
queryset = (
ormar.QuerySet(model_cls=self.relation.to)
.select_related(owner_table)
.select_related(related_field.name)
.filter(**kwargs)
)
return queryset
@ -67,14 +70,21 @@ class RelationProxy(list):
f"{self._owner.get_name()} does not have relation {rel_name}"
)
relation.remove(self._owner)
if self.relation._type == ormar.RelationType.MULTIPLE:
self.relation.remove(item)
if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.delete_through_instance(item)
def append(self, item: "Model") -> None:
super().append(item)
else:
setattr(item, rel_name, None)
await item.update()
async def add(self, item: "Model") -> None:
if self.relation._type == ormar.RelationType.MULTIPLE:
if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
else:
related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
setattr(item, related_field.name, self._owner)
await item.update()