expose querysetproxy on reverse of foreignkey (virtual fk), add additional methods from queryset to querysetproxy
This commit is contained in:
@ -212,7 +212,7 @@ You can use special filter suffix to change the filter operands:
|
||||
|
||||
* exact - like `album__name__exact='Malibu'` (exact match)
|
||||
* iexact - like `album__name__iexact='malibu'` (exact match case insensitive)
|
||||
* contains - like `album__name__conatins='Mal'` (sql like)
|
||||
* contains - like `album__name__contains='Mal'` (sql like)
|
||||
* icontains - like `album__name__icontains='mal'` (sql like case insensitive)
|
||||
* in - like `album__name__in=['Malibu', 'Barclay']` (sql in)
|
||||
* gt - like `position__gt=3` (sql >)
|
||||
|
||||
@ -6,6 +6,10 @@ class ModelDefinitionError(AsyncOrmException):
|
||||
pass
|
||||
|
||||
|
||||
class ModelError(AsyncOrmException):
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotSet(AsyncOrmException):
|
||||
pass
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
MutableSequence,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
@ -22,6 +23,7 @@ import sqlalchemy
|
||||
from pydantic import BaseModel
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar.exceptions import ModelError
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.models.excludable import Excludable
|
||||
@ -93,6 +95,7 @@ class NewBaseModel(
|
||||
if "pk" in kwargs:
|
||||
kwargs[self.Meta.pkname] = kwargs.pop("pk")
|
||||
# build the models to set them and validate but don't register
|
||||
try:
|
||||
new_kwargs = {
|
||||
k: self._convert_json(
|
||||
k,
|
||||
@ -103,6 +106,10 @@ class NewBaseModel(
|
||||
)
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
except KeyError as e:
|
||||
raise ModelError(
|
||||
f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}"
|
||||
)
|
||||
|
||||
values, fields_set, validation_error = pydantic.validate_model(
|
||||
self, new_kwargs # type: ignore
|
||||
@ -249,7 +256,9 @@ class NewBaseModel(
|
||||
|
||||
@staticmethod
|
||||
def _extract_nested_models_from_list(
|
||||
models: List, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None],
|
||||
models: MutableSequence,
|
||||
include: Union[Set, Dict, None],
|
||||
exclude: Union[Set, Dict, None],
|
||||
) -> List:
|
||||
result = []
|
||||
for model in models:
|
||||
@ -282,7 +291,7 @@ class NewBaseModel(
|
||||
if self.Meta.model_fields[field].virtual and nested:
|
||||
continue
|
||||
nested_model = getattr(self, field)
|
||||
if isinstance(nested_model, list):
|
||||
if isinstance(nested_model, MutableSequence):
|
||||
dict_instance[field] = self._extract_nested_models_from_list(
|
||||
models=nested_model,
|
||||
include=self._skip_ellipsis(include, field),
|
||||
@ -308,7 +317,7 @@ class NewBaseModel(
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
nested: bool = False
|
||||
nested: bool = False,
|
||||
) -> "DictStrAny": # noqa: A003'
|
||||
dict_instance = super().dict(
|
||||
include=include,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union
|
||||
|
||||
try:
|
||||
from typing import Protocol
|
||||
@ -6,14 +6,21 @@ except ImportError: # pragma: nocover
|
||||
from typing_extensions import Protocol # type: ignore
|
||||
|
||||
if TYPE_CHECKING: # noqa: C901; #pragma nocover
|
||||
from ormar import QuerySet, Model
|
||||
from ormar import Model
|
||||
from ormar.relations.querysetproxy import QuerysetProxy
|
||||
|
||||
|
||||
class QuerySetProtocol(Protocol): # pragma: nocover
|
||||
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003, A001
|
||||
def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
|
||||
...
|
||||
|
||||
def select_related(self, related: Union[List, str]) -> "QuerySet":
|
||||
def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
|
||||
...
|
||||
|
||||
def select_related(self, related: Union[List, str]) -> "QuerysetProxy":
|
||||
...
|
||||
|
||||
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy":
|
||||
...
|
||||
|
||||
async def exists(self) -> bool:
|
||||
@ -25,10 +32,10 @@ class QuerySetProtocol(Protocol): # pragma: nocover
|
||||
async def clear(self) -> int:
|
||||
...
|
||||
|
||||
def limit(self, limit_count: int) -> "QuerySet":
|
||||
def limit(self, limit_count: int) -> "QuerysetProxy":
|
||||
...
|
||||
|
||||
def offset(self, offset: int) -> "QuerySet":
|
||||
def offset(self, offset: int) -> "QuerysetProxy":
|
||||
...
|
||||
|
||||
async def first(self, **kwargs: Any) -> "Model":
|
||||
@ -44,3 +51,18 @@ class QuerySetProtocol(Protocol): # pragma: nocover
|
||||
|
||||
async def create(self, **kwargs: Any) -> "Model":
|
||||
...
|
||||
|
||||
async def get_or_create(self, **kwargs: Any) -> "Model":
|
||||
...
|
||||
|
||||
async def update_or_create(self, **kwargs: Any) -> "Model":
|
||||
...
|
||||
|
||||
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
|
||||
...
|
||||
|
||||
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
|
||||
...
|
||||
|
||||
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy":
|
||||
...
|
||||
|
||||
@ -1,4 +1,15 @@
|
||||
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
MutableSequence,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import ormar
|
||||
|
||||
@ -6,6 +17,7 @@ if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.relations import Relation
|
||||
from ormar.models import Model
|
||||
from ormar.queryset import QuerySet
|
||||
from ormar import RelationType
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
@ -14,9 +26,17 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
relation: "Relation"
|
||||
|
||||
def __init__(self, relation: "Relation") -> None:
|
||||
def __init__(
|
||||
self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None
|
||||
) -> None:
|
||||
self.relation: Relation = relation
|
||||
self._queryset: Optional["QuerySet"] = None
|
||||
self._queryset: Optional["QuerySet"] = qryset
|
||||
self.type_: "RelationType" = type_
|
||||
self._owner: "Model" = self.relation.manager.owner
|
||||
self.related_field = self._owner.resolve_relation_field(
|
||||
self.relation.to, self._owner
|
||||
)
|
||||
self.owner_pk_value = self._owner.pk
|
||||
|
||||
@property
|
||||
def queryset(self) -> "QuerySet":
|
||||
@ -30,7 +50,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
|
||||
def _assign_child_to_parent(self, child: Optional["T"]) -> None:
|
||||
if child:
|
||||
owner = self.relation._owner
|
||||
owner = self._owner
|
||||
rel_name = owner.resolve_relation_name(owner, child)
|
||||
setattr(owner, rel_name, child)
|
||||
|
||||
@ -42,27 +62,26 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
assert isinstance(child, ormar.Model)
|
||||
self._assign_child_to_parent(child)
|
||||
|
||||
def _clean_items_on_load(self) -> None:
|
||||
if isinstance(self.relation.related_models, MutableSequence):
|
||||
for item in self.relation.related_models[:]:
|
||||
self.relation.remove(item)
|
||||
|
||||
async def create_through_instance(self, child: "T") -> None:
|
||||
queryset = ormar.QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self.relation._owner.get_name()
|
||||
owner_column = self._owner.get_name()
|
||||
child_column = child.get_name()
|
||||
kwargs = {owner_column: self.relation._owner, child_column: child}
|
||||
kwargs = {owner_column: self._owner, child_column: child}
|
||||
await queryset.create(**kwargs)
|
||||
|
||||
async def delete_through_instance(self, child: "T") -> None:
|
||||
queryset = ormar.QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self.relation._owner.get_name()
|
||||
owner_column = self._owner.get_name()
|
||||
child_column = child.get_name()
|
||||
kwargs = {owner_column: self.relation._owner, child_column: child}
|
||||
kwargs = {owner_column: self._owner, child_column: child}
|
||||
link_instance = await queryset.filter(**kwargs).get() # type: ignore
|
||||
await link_instance.delete()
|
||||
|
||||
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||
return self.queryset.filter(**kwargs)
|
||||
|
||||
def select_related(self, related: Union[List, str]) -> "QuerySet":
|
||||
return self.queryset.select_related(related)
|
||||
|
||||
async def exists(self) -> bool:
|
||||
return await self.queryset.exists()
|
||||
|
||||
@ -70,17 +89,16 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
return await self.queryset.count()
|
||||
|
||||
async def clear(self) -> int:
|
||||
if self.type_ == ormar.RelationType.MULTIPLE:
|
||||
queryset = ormar.QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self.relation._owner.get_name()
|
||||
kwargs = {owner_column: self.relation._owner}
|
||||
owner_column = self._owner.get_name()
|
||||
else:
|
||||
queryset = ormar.QuerySet(model_cls=self.relation.to)
|
||||
owner_column = self.related_field.name
|
||||
kwargs = {owner_column: self._owner}
|
||||
self._clean_items_on_load()
|
||||
return await queryset.delete(**kwargs) # type: ignore
|
||||
|
||||
def limit(self, limit_count: int) -> "QuerySet":
|
||||
return self.queryset.limit(limit_count)
|
||||
|
||||
def offset(self, offset: int) -> "QuerySet":
|
||||
return self.queryset.offset(offset)
|
||||
|
||||
async def first(self, **kwargs: Any) -> "Model":
|
||||
first = await self.queryset.first(**kwargs)
|
||||
self._register_related(first)
|
||||
@ -88,16 +106,72 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
|
||||
async def get(self, **kwargs: Any) -> "Model":
|
||||
get = await self.queryset.get(**kwargs)
|
||||
self._clean_items_on_load()
|
||||
self._register_related(get)
|
||||
return get
|
||||
|
||||
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
|
||||
all_items = await self.queryset.all(**kwargs)
|
||||
self._clean_items_on_load()
|
||||
self._register_related(all_items)
|
||||
return all_items
|
||||
|
||||
async def create(self, **kwargs: Any) -> "Model":
|
||||
create = await self.queryset.create(**kwargs)
|
||||
self._register_related(create)
|
||||
await self.create_through_instance(create)
|
||||
return create
|
||||
if self.type_ == ormar.RelationType.REVERSE:
|
||||
kwargs[self.related_field.name] = self._owner
|
||||
created = await self.queryset.create(**kwargs)
|
||||
self._register_related(created)
|
||||
if self.type_ == ormar.RelationType.MULTIPLE:
|
||||
await self.create_through_instance(created)
|
||||
return created
|
||||
|
||||
async def get_or_create(self, **kwargs: Any) -> "Model":
|
||||
try:
|
||||
return await self.get(**kwargs)
|
||||
except ormar.NoMatch:
|
||||
return await self.create(**kwargs)
|
||||
|
||||
async def update_or_create(self, **kwargs: Any) -> "Model":
|
||||
pk_name = self.queryset.model_meta.pkname
|
||||
if "pk" in kwargs:
|
||||
kwargs[pk_name] = kwargs.pop("pk")
|
||||
if pk_name not in kwargs or kwargs.get(pk_name) is None:
|
||||
return await self.create(**kwargs)
|
||||
model = await self.queryset.get(pk=kwargs[pk_name])
|
||||
return await model.update(**kwargs)
|
||||
|
||||
def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
|
||||
queryset = self.queryset.filter(**kwargs)
|
||||
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
|
||||
|
||||
def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
|
||||
queryset = self.queryset.exclude(**kwargs)
|
||||
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
|
||||
|
||||
def select_related(self, related: Union[List, str]) -> "QuerysetProxy":
|
||||
queryset = self.queryset.select_related(related)
|
||||
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
|
||||
|
||||
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy":
|
||||
queryset = self.queryset.prefetch_related(related)
|
||||
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
|
||||
|
||||
def limit(self, limit_count: int) -> "QuerysetProxy":
|
||||
queryset = self.queryset.limit(limit_count)
|
||||
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
|
||||
|
||||
def offset(self, offset: int) -> "QuerysetProxy":
|
||||
queryset = self.queryset.offset(offset)
|
||||
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
|
||||
|
||||
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
|
||||
queryset = self.queryset.fields(columns)
|
||||
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
|
||||
|
||||
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
|
||||
queryset = self.queryset.exclude_fields(columns=columns)
|
||||
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
|
||||
|
||||
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy":
|
||||
queryset = self.queryset.order_by(columns)
|
||||
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
|
||||
|
||||
@ -34,7 +34,7 @@ class Relation:
|
||||
self.to: Type["T"] = to
|
||||
self.through: Optional[Type["T"]] = through
|
||||
self.related_models: Optional[Union[RelationProxy, "T"]] = (
|
||||
RelationProxy(relation=self)
|
||||
RelationProxy(relation=self, type_=type_)
|
||||
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
|
||||
else None
|
||||
)
|
||||
|
||||
@ -65,8 +65,6 @@ class RelationsManager:
|
||||
|
||||
parent_relation = parent._orm._get(child_name)
|
||||
if parent_relation:
|
||||
# print('missing', child_name)
|
||||
# parent_relation = register_missing_relation(parent, child, child_name)
|
||||
parent_relation.add(child) # type: ignore
|
||||
|
||||
child_relation = child._orm._get(to_name)
|
||||
|
||||
@ -5,17 +5,18 @@ from ormar.exceptions import NoMatch, RelationshipInstanceError
|
||||
from ormar.relations.querysetproxy import QuerysetProxy
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
from ormar import Model, RelationType
|
||||
from ormar.relations import Relation
|
||||
from ormar.queryset import QuerySet
|
||||
|
||||
|
||||
class RelationProxy(list):
|
||||
def __init__(self, relation: "Relation") -> None:
|
||||
super(RelationProxy, self).__init__()
|
||||
self.relation: Relation = relation
|
||||
def __init__(self, relation: "Relation", type_: "RelationType") -> None:
|
||||
super().__init__()
|
||||
self.relation: "Relation" = relation
|
||||
self.type_: "RelationType" = type_
|
||||
self._owner: "Model" = self.relation.manager.owner
|
||||
self.queryset_proxy = QuerysetProxy(relation=self.relation)
|
||||
self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_)
|
||||
|
||||
def __getattribute__(self, item: str) -> Any:
|
||||
if item in ["count", "clear"]:
|
||||
@ -38,17 +39,19 @@ class RelationProxy(list):
|
||||
)
|
||||
|
||||
def _set_queryset(self) -> "QuerySet":
|
||||
owner_table = self.relation._owner.Meta.tablename
|
||||
pkname = self.relation._owner.get_column_alias(self.relation._owner.Meta.pkname)
|
||||
pk_value = self.relation._owner.pk
|
||||
related_field = self._owner.resolve_relation_field(
|
||||
self.relation.to, self._owner
|
||||
)
|
||||
pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
|
||||
pk_value = self._owner.pk
|
||||
if not pk_value:
|
||||
raise RelationshipInstanceError(
|
||||
"You cannot query many to many relationship on unsaved model."
|
||||
"You cannot query relationships from unsaved model."
|
||||
)
|
||||
kwargs = {f"{owner_table}__{pkname}": pk_value}
|
||||
kwargs = {f"{related_field.get_alias()}__{pkname}": pk_value}
|
||||
queryset = (
|
||||
ormar.QuerySet(model_cls=self.relation.to)
|
||||
.select_related(owner_table)
|
||||
.select_related(related_field.name)
|
||||
.filter(**kwargs)
|
||||
)
|
||||
return queryset
|
||||
@ -67,14 +70,21 @@ class RelationProxy(list):
|
||||
f"{self._owner.get_name()} does not have relation {rel_name}"
|
||||
)
|
||||
relation.remove(self._owner)
|
||||
if self.relation._type == ormar.RelationType.MULTIPLE:
|
||||
self.relation.remove(item)
|
||||
if self.type_ == ormar.RelationType.MULTIPLE:
|
||||
await self.queryset_proxy.delete_through_instance(item)
|
||||
|
||||
def append(self, item: "Model") -> None:
|
||||
super().append(item)
|
||||
else:
|
||||
setattr(item, rel_name, None)
|
||||
await item.update()
|
||||
|
||||
async def add(self, item: "Model") -> None:
|
||||
if self.relation._type == ormar.RelationType.MULTIPLE:
|
||||
if self.type_ == ormar.RelationType.MULTIPLE:
|
||||
await self.queryset_proxy.create_through_instance(item)
|
||||
rel_name = item.resolve_relation_name(item, self._owner)
|
||||
setattr(item, rel_name, self._owner)
|
||||
else:
|
||||
related_field = self._owner.resolve_relation_field(
|
||||
self.relation.to, self._owner
|
||||
)
|
||||
setattr(item, related_field.name, self._owner)
|
||||
await item.update()
|
||||
|
||||
@ -9,7 +9,7 @@ import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import QueryDefinitionError, NoMatch
|
||||
from ormar.exceptions import QueryDefinitionError, NoMatch, ModelError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
@ -117,6 +117,11 @@ def test_model_class():
|
||||
assert isinstance(User.Meta.table, sqlalchemy.Table)
|
||||
|
||||
|
||||
def test_wrong_field_name():
|
||||
with pytest.raises(ModelError):
|
||||
User(non_existing_pk=1)
|
||||
|
||||
|
||||
def test_model_pk():
|
||||
user = User(pk=1)
|
||||
assert user.pk == 1
|
||||
|
||||
182
tests/test_queryproxy_on_m2m_models.py
Normal file
182
tests/test_queryproxy_on_m2m_models.py
Normal file
@ -0,0 +1,182 @@
|
||||
import asyncio
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class Subject(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "subjects"
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=80)
|
||||
|
||||
|
||||
class Author(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "authors"
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
first_name: str = ormar.String(max_length=80)
|
||||
last_name: str = ormar.String(max_length=80)
|
||||
|
||||
|
||||
class Category(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "categories"
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=40)
|
||||
sort_order: int = ormar.Integer(nullable=True)
|
||||
subject: Optional[Subject] = ormar.ForeignKey(Subject)
|
||||
|
||||
|
||||
class PostCategory(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "posts_categories"
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
|
||||
class Post(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "posts"
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
title: str = ormar.String(max_length=200)
|
||||
categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany(
|
||||
Category, through=PostCategory
|
||||
)
|
||||
author: Optional[Author] = ormar.ForeignKey(Author)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
async def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queryset_methods():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
guido = await Author.objects.create(
|
||||
first_name="Guido", last_name="Van Rossum"
|
||||
)
|
||||
subject = await Subject(name="Random").save()
|
||||
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||
news = await Category.objects.create(
|
||||
name="News", sort_order=1, subject=subject
|
||||
)
|
||||
breaking = await Category.objects.create(
|
||||
name="Breaking", sort_order=3, subject=subject
|
||||
)
|
||||
|
||||
# Add a category to a post.
|
||||
await post.categories.add(news)
|
||||
await post.categories.add(breaking)
|
||||
|
||||
category = await post.categories.get_or_create(name="News")
|
||||
assert category == news
|
||||
assert len(post.categories) == 1
|
||||
|
||||
category = await post.categories.get_or_create(name="Breaking News")
|
||||
assert category != breaking
|
||||
assert category.pk is not None
|
||||
assert len(post.categories) == 2
|
||||
|
||||
await post.categories.update_or_create(pk=category.pk, name="Urgent News")
|
||||
assert len(post.categories) == 2
|
||||
cat = await post.categories.get_or_create(name="Urgent News")
|
||||
assert cat.pk == category.pk
|
||||
assert len(post.categories) == 1
|
||||
|
||||
await post.categories.remove(cat)
|
||||
await cat.delete()
|
||||
|
||||
assert len(post.categories) == 0
|
||||
|
||||
category = await post.categories.update_or_create(
|
||||
name="Weather News", sort_order=2, subject=subject
|
||||
)
|
||||
assert category.pk is not None
|
||||
assert category.posts[0] == post
|
||||
|
||||
assert len(post.categories) == 1
|
||||
|
||||
categories = await post.categories.all()
|
||||
assert len(categories) == 3 == len(post.categories)
|
||||
|
||||
assert await post.categories.exists()
|
||||
assert 3 == await post.categories.count()
|
||||
|
||||
categories = await post.categories.limit(2).all()
|
||||
assert len(categories) == 2 == len(post.categories)
|
||||
|
||||
categories2 = await post.categories.limit(2).offset(1).all()
|
||||
assert len(categories2) == 2 == len(post.categories)
|
||||
assert categories != categories2
|
||||
|
||||
categories = await post.categories.order_by("-sort_order").all()
|
||||
assert len(categories) == 3 == len(post.categories)
|
||||
assert post.categories[2].name == "News"
|
||||
assert post.categories[0].name == "Breaking"
|
||||
|
||||
categories = await post.categories.exclude(name__icontains="news").all()
|
||||
assert len(categories) == 1 == len(post.categories)
|
||||
assert post.categories[0].name == "Breaking"
|
||||
|
||||
categories = (
|
||||
await post.categories.filter(name__icontains="news")
|
||||
.order_by("-name")
|
||||
.all()
|
||||
)
|
||||
assert len(categories) == 2 == len(post.categories)
|
||||
assert post.categories[0].name == "Weather News"
|
||||
assert post.categories[1].name == "News"
|
||||
|
||||
categories = await post.categories.fields("name").all()
|
||||
assert len(categories) == 3 == len(post.categories)
|
||||
for cat in post.categories:
|
||||
assert cat.sort_order is None
|
||||
|
||||
categories = await post.categories.exclude_fields("sort_order").all()
|
||||
assert len(categories) == 3 == len(post.categories)
|
||||
for cat in post.categories:
|
||||
assert cat.sort_order is None
|
||||
assert cat.subject.name is None
|
||||
|
||||
categories = await post.categories.select_related("subject").all()
|
||||
assert len(categories) == 3 == len(post.categories)
|
||||
for cat in post.categories:
|
||||
assert cat.subject.name is not None
|
||||
|
||||
categories = await post.categories.prefetch_related("subject").all()
|
||||
assert len(categories) == 3 == len(post.categories)
|
||||
for cat in post.categories:
|
||||
assert cat.subject.name is not None
|
||||
233
tests/test_reverse_fk_queryset.py
Normal file
233
tests/test_reverse_fk_queryset.py
Normal file
@ -0,0 +1,233 @@
|
||||
from typing import Optional
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class Album(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "albums"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100)
|
||||
is_best_seller: bool = ormar.Boolean(default=False)
|
||||
|
||||
|
||||
class Writer(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "writers"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100)
|
||||
|
||||
|
||||
class Track(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "tracks"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
album: Optional[Album] = ormar.ForeignKey(Album)
|
||||
title: str = ormar.String(max_length=100)
|
||||
position: int = ormar.Integer()
|
||||
play_count: int = ormar.Integer(nullable=True)
|
||||
written_by: Optional[Writer] = ormar.ForeignKey(Writer)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@pytest.mark.asyncio
|
||||
async def sample_data():
|
||||
album = await Album(name="Malibu").save()
|
||||
writer1 = await Writer.objects.create(name="John")
|
||||
writer2 = await Writer.objects.create(name="Sue")
|
||||
track1 = await Track(
|
||||
album=album, title="The Bird", position=1, play_count=30, written_by=writer1
|
||||
).save()
|
||||
track2 = await Track(
|
||||
album=album,
|
||||
title="Heart don't stand a chance",
|
||||
position=2,
|
||||
play_count=20,
|
||||
written_by=writer2,
|
||||
).save()
|
||||
tracks3 = await Track(
|
||||
album=album, title="The Waters", position=3, play_count=10, written_by=writer1
|
||||
).save()
|
||||
return album, [track1, track2, tracks3]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.drop_all(engine)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quering_by_reverse_fk(sample_data):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
track1 = sample_data[1][0]
|
||||
album = await Album.objects.first()
|
||||
|
||||
assert await album.tracks.exists()
|
||||
assert await album.tracks.count() == 3
|
||||
|
||||
track = await album.tracks.get_or_create(
|
||||
title="The Bird", position=1, play_count=30
|
||||
)
|
||||
assert track == track1
|
||||
assert len(album.tracks) == 1
|
||||
|
||||
track = await album.tracks.get_or_create(
|
||||
title="The Bird2", position=4, play_count=5
|
||||
)
|
||||
assert track != track1
|
||||
assert track.pk is not None
|
||||
assert len(album.tracks) == 2
|
||||
|
||||
await album.tracks.update_or_create(pk=track.pk, play_count=50)
|
||||
assert len(album.tracks) == 2
|
||||
track = await album.tracks.get_or_create(title="The Bird2")
|
||||
assert track.play_count == 50
|
||||
assert len(album.tracks) == 1
|
||||
|
||||
await album.tracks.remove(track)
|
||||
assert track.album is None
|
||||
await track.delete()
|
||||
|
||||
assert len(album.tracks) == 0
|
||||
|
||||
track6 = await album.tracks.update_or_create(
|
||||
title="The Bird3", position=4, play_count=5
|
||||
)
|
||||
assert track6.pk is not None
|
||||
assert track6.play_count == 5
|
||||
|
||||
assert len(album.tracks) == 1
|
||||
|
||||
await album.tracks.remove(track6)
|
||||
assert track6.album is None
|
||||
await track6.delete()
|
||||
|
||||
assert len(album.tracks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_getting(sample_data):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
album = sample_data[0]
|
||||
track1 = await album.tracks.fields(["album", "title", "position"]).get(
|
||||
title="The Bird"
|
||||
)
|
||||
track2 = await album.tracks.exclude_fields("play_count").get(
|
||||
title="The Bird"
|
||||
)
|
||||
for track in [track1, track2]:
|
||||
assert track.title == "The Bird"
|
||||
assert track.album == album
|
||||
assert track.play_count is None
|
||||
|
||||
assert len(album.tracks) == 1
|
||||
|
||||
tracks = await album.tracks.all()
|
||||
assert len(tracks) == 3
|
||||
|
||||
assert len(album.tracks) == 3
|
||||
|
||||
tracks = await album.tracks.order_by("play_count").all()
|
||||
assert len(tracks) == 3
|
||||
assert tracks[0].title == "The Waters"
|
||||
assert tracks[2].title == "The Bird"
|
||||
|
||||
assert len(album.tracks) == 3
|
||||
|
||||
track = await album.tracks.create(
|
||||
title="The Bird Fly Away", position=4, play_count=10
|
||||
)
|
||||
assert track.title == "The Bird Fly Away"
|
||||
assert track.position == 4
|
||||
assert track.album == album
|
||||
|
||||
assert len(album.tracks) == 4
|
||||
|
||||
tracks = await album.tracks.all()
|
||||
assert len(tracks) == 4
|
||||
|
||||
tracks = await album.tracks.limit(2).all()
|
||||
assert len(tracks) == 2
|
||||
|
||||
tracks2 = await album.tracks.limit(2).offset(2).all()
|
||||
assert len(tracks2) == 2
|
||||
assert tracks != tracks2
|
||||
|
||||
tracks3 = await album.tracks.filter(play_count__lt=15).all()
|
||||
assert len(tracks3) == 2
|
||||
|
||||
tracks4 = await album.tracks.exclude(play_count__lt=15).all()
|
||||
assert len(tracks4) == 2
|
||||
assert tracks3 != tracks4
|
||||
|
||||
assert len(album.tracks) == 2
|
||||
|
||||
await album.tracks.clear()
|
||||
tracks = await album.tracks.all()
|
||||
assert len(tracks) == 0
|
||||
assert len(album.tracks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loading_related(sample_data):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
album = sample_data[0]
|
||||
tracks = await album.tracks.select_related("written_by").all()
|
||||
assert len(tracks) == 3
|
||||
assert len(album.tracks) == 3
|
||||
for track in tracks:
|
||||
assert track.written_by is not None
|
||||
|
||||
tracks = await album.tracks.prefetch_related("written_by").all()
|
||||
assert len(tracks) == 3
|
||||
assert len(album.tracks) == 3
|
||||
for track in tracks:
|
||||
assert track.written_by is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adding_removing(sample_data):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
album = sample_data[0]
|
||||
track_new = await Track(title="Rainbow", position=5, play_count=300).save()
|
||||
await album.tracks.add(track_new)
|
||||
assert track_new.album == album
|
||||
assert len(album.tracks) == 4
|
||||
|
||||
track_check = await Track.objects.get(title="Rainbow")
|
||||
assert track_check.album == album
|
||||
|
||||
track_test = await Track.objects.get(title="Rainbow")
|
||||
assert track_test.album == album
|
||||
|
||||
await album.tracks.remove(track_new)
|
||||
assert track_new.album is None
|
||||
assert len(album.tracks) == 3
|
||||
|
||||
track_test = await Track.objects.get(title="Rainbow")
|
||||
assert track_test.album is None
|
||||
Reference in New Issue
Block a user