expose querysetproxy on reverse of foreignkey (virtual fk), add additional methods from queryset to querysetproxy

This commit is contained in:
collerek
2020-12-01 08:27:08 +01:00
parent b939a02ce0
commit 61da7b4418
11 changed files with 605 additions and 68 deletions

View File

@ -212,7 +212,7 @@ You can use special filter suffix to change the filter operands:
* exact - like `album__name__exact='Malibu'` (exact match) * exact - like `album__name__exact='Malibu'` (exact match)
* iexact - like `album__name__iexact='malibu'` (exact match case insensitive) * iexact - like `album__name__iexact='malibu'` (exact match case insensitive)
* contains - like `album__name__conatins='Mal'` (sql like) * contains - like `album__name__contains='Mal'` (sql like)
* icontains - like `album__name__icontains='mal'` (sql like case insensitive) * icontains - like `album__name__icontains='mal'` (sql like case insensitive)
* in - like `album__name__in=['Malibu', 'Barclay']` (sql in) * in - like `album__name__in=['Malibu', 'Barclay']` (sql in)
* gt - like `position__gt=3` (sql >) * gt - like `position__gt=3` (sql >)

View File

@ -6,6 +6,10 @@ class ModelDefinitionError(AsyncOrmException):
pass pass
class ModelError(AsyncOrmException):
pass
class ModelNotSet(AsyncOrmException): class ModelNotSet(AsyncOrmException):
pass pass

View File

@ -7,6 +7,7 @@ from typing import (
Dict, Dict,
List, List,
Mapping, Mapping,
MutableSequence,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -22,6 +23,7 @@ import sqlalchemy
from pydantic import BaseModel from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar.exceptions import ModelError
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.excludable import Excludable from ormar.models.excludable import Excludable
@ -93,16 +95,21 @@ class NewBaseModel(
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs[self.Meta.pkname] = kwargs.pop("pk")
# build the models to set them and validate but don't register # build the models to set them and validate but don't register
new_kwargs = { try:
k: self._convert_json( new_kwargs = {
k, k: self._convert_json(
self.Meta.model_fields[k].expand_relationship( k,
v, self, to_register=False self.Meta.model_fields[k].expand_relationship(
), v, self, to_register=False
"dumps", ),
"dumps",
)
for k, v in kwargs.items()
}
except KeyError as e:
raise ModelError(
f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}"
) )
for k, v in kwargs.items()
}
values, fields_set, validation_error = pydantic.validate_model( values, fields_set, validation_error = pydantic.validate_model(
self, new_kwargs # type: ignore self, new_kwargs # type: ignore
@ -249,7 +256,9 @@ class NewBaseModel(
@staticmethod @staticmethod
def _extract_nested_models_from_list( def _extract_nested_models_from_list(
models: List, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None], models: MutableSequence,
include: Union[Set, Dict, None],
exclude: Union[Set, Dict, None],
) -> List: ) -> List:
result = [] result = []
for model in models: for model in models:
@ -282,7 +291,7 @@ class NewBaseModel(
if self.Meta.model_fields[field].virtual and nested: if self.Meta.model_fields[field].virtual and nested:
continue continue
nested_model = getattr(self, field) nested_model = getattr(self, field)
if isinstance(nested_model, list): if isinstance(nested_model, MutableSequence):
dict_instance[field] = self._extract_nested_models_from_list( dict_instance[field] = self._extract_nested_models_from_list(
models=nested_model, models=nested_model,
include=self._skip_ellipsis(include, field), include=self._skip_ellipsis(include, field),
@ -308,7 +317,7 @@ class NewBaseModel(
exclude_unset: bool = False, exclude_unset: bool = False,
exclude_defaults: bool = False, exclude_defaults: bool = False,
exclude_none: bool = False, exclude_none: bool = False,
nested: bool = False nested: bool = False,
) -> "DictStrAny": # noqa: A003' ) -> "DictStrAny": # noqa: A003'
dict_instance = super().dict( dict_instance = super().dict(
include=include, include=include,

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Union from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union
try: try:
from typing import Protocol from typing import Protocol
@ -6,14 +6,21 @@ except ImportError: # pragma: nocover
from typing_extensions import Protocol # type: ignore from typing_extensions import Protocol # type: ignore
if TYPE_CHECKING: # noqa: C901; #pragma nocover if TYPE_CHECKING: # noqa: C901; #pragma nocover
from ormar import QuerySet, Model from ormar import Model
from ormar.relations.querysetproxy import QuerysetProxy
class QuerySetProtocol(Protocol): # pragma: nocover class QuerySetProtocol(Protocol): # pragma: nocover
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003, A001 def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
... ...
def select_related(self, related: Union[List, str]) -> "QuerySet": def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
...
def select_related(self, related: Union[List, str]) -> "QuerysetProxy":
...
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy":
... ...
async def exists(self) -> bool: async def exists(self) -> bool:
@ -25,10 +32,10 @@ class QuerySetProtocol(Protocol): # pragma: nocover
async def clear(self) -> int: async def clear(self) -> int:
... ...
def limit(self, limit_count: int) -> "QuerySet": def limit(self, limit_count: int) -> "QuerysetProxy":
... ...
def offset(self, offset: int) -> "QuerySet": def offset(self, offset: int) -> "QuerysetProxy":
... ...
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> "Model":
@ -44,3 +51,18 @@ class QuerySetProtocol(Protocol): # pragma: nocover
async def create(self, **kwargs: Any) -> "Model": async def create(self, **kwargs: Any) -> "Model":
... ...
async def get_or_create(self, **kwargs: Any) -> "Model":
...
async def update_or_create(self, **kwargs: Any) -> "Model":
...
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
...
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
...
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy":
...

View File

@ -1,4 +1,15 @@
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union from typing import (
Any,
Dict,
List,
MutableSequence,
Optional,
Sequence,
Set,
TYPE_CHECKING,
TypeVar,
Union,
)
import ormar import ormar
@ -6,6 +17,7 @@ if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation from ormar.relations import Relation
from ormar.models import Model from ormar.models import Model
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
from ormar import RelationType
T = TypeVar("T", bound=Model) T = TypeVar("T", bound=Model)
@ -14,9 +26,17 @@ class QuerysetProxy(ormar.QuerySetProtocol):
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
relation: "Relation" relation: "Relation"
def __init__(self, relation: "Relation") -> None: def __init__(
self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None
) -> None:
self.relation: Relation = relation self.relation: Relation = relation
self._queryset: Optional["QuerySet"] = None self._queryset: Optional["QuerySet"] = qryset
self.type_: "RelationType" = type_
self._owner: "Model" = self.relation.manager.owner
self.related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
self.owner_pk_value = self._owner.pk
@property @property
def queryset(self) -> "QuerySet": def queryset(self) -> "QuerySet":
@ -30,7 +50,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
def _assign_child_to_parent(self, child: Optional["T"]) -> None: def _assign_child_to_parent(self, child: Optional["T"]) -> None:
if child: if child:
owner = self.relation._owner owner = self._owner
rel_name = owner.resolve_relation_name(owner, child) rel_name = owner.resolve_relation_name(owner, child)
setattr(owner, rel_name, child) setattr(owner, rel_name, child)
@ -42,27 +62,26 @@ class QuerysetProxy(ormar.QuerySetProtocol):
assert isinstance(child, ormar.Model) assert isinstance(child, ormar.Model)
self._assign_child_to_parent(child) self._assign_child_to_parent(child)
def _clean_items_on_load(self) -> None:
if isinstance(self.relation.related_models, MutableSequence):
for item in self.relation.related_models[:]:
self.relation.remove(item)
async def create_through_instance(self, child: "T") -> None: async def create_through_instance(self, child: "T") -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through) queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name() owner_column = self._owner.get_name()
child_column = child.get_name() child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child} kwargs = {owner_column: self._owner, child_column: child}
await queryset.create(**kwargs) await queryset.create(**kwargs)
async def delete_through_instance(self, child: "T") -> None: async def delete_through_instance(self, child: "T") -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through) queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name() owner_column = self._owner.get_name()
child_column = child.get_name() child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child} kwargs = {owner_column: self._owner, child_column: child}
link_instance = await queryset.filter(**kwargs).get() # type: ignore link_instance = await queryset.filter(**kwargs).get() # type: ignore
await link_instance.delete() await link_instance.delete()
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
return self.queryset.filter(**kwargs)
def select_related(self, related: Union[List, str]) -> "QuerySet":
return self.queryset.select_related(related)
async def exists(self) -> bool: async def exists(self) -> bool:
return await self.queryset.exists() return await self.queryset.exists()
@ -70,17 +89,16 @@ class QuerysetProxy(ormar.QuerySetProtocol):
return await self.queryset.count() return await self.queryset.count()
async def clear(self) -> int: async def clear(self) -> int:
queryset = ormar.QuerySet(model_cls=self.relation.through) if self.type_ == ormar.RelationType.MULTIPLE:
owner_column = self.relation._owner.get_name() queryset = ormar.QuerySet(model_cls=self.relation.through)
kwargs = {owner_column: self.relation._owner} owner_column = self._owner.get_name()
else:
queryset = ormar.QuerySet(model_cls=self.relation.to)
owner_column = self.related_field.name
kwargs = {owner_column: self._owner}
self._clean_items_on_load()
return await queryset.delete(**kwargs) # type: ignore return await queryset.delete(**kwargs) # type: ignore
def limit(self, limit_count: int) -> "QuerySet":
return self.queryset.limit(limit_count)
def offset(self, offset: int) -> "QuerySet":
return self.queryset.offset(offset)
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> "Model":
first = await self.queryset.first(**kwargs) first = await self.queryset.first(**kwargs)
self._register_related(first) self._register_related(first)
@ -88,16 +106,72 @@ class QuerysetProxy(ormar.QuerySetProtocol):
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "Model":
get = await self.queryset.get(**kwargs) get = await self.queryset.get(**kwargs)
self._clean_items_on_load()
self._register_related(get) self._register_related(get)
return get return get
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
all_items = await self.queryset.all(**kwargs) all_items = await self.queryset.all(**kwargs)
self._clean_items_on_load()
self._register_related(all_items) self._register_related(all_items)
return all_items return all_items
async def create(self, **kwargs: Any) -> "Model": async def create(self, **kwargs: Any) -> "Model":
create = await self.queryset.create(**kwargs) if self.type_ == ormar.RelationType.REVERSE:
self._register_related(create) kwargs[self.related_field.name] = self._owner
await self.create_through_instance(create) created = await self.queryset.create(**kwargs)
return create self._register_related(created)
if self.type_ == ormar.RelationType.MULTIPLE:
await self.create_through_instance(created)
return created
async def get_or_create(self, **kwargs: Any) -> "Model":
try:
return await self.get(**kwargs)
except ormar.NoMatch:
return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model":
pk_name = self.queryset.model_meta.pkname
if "pk" in kwargs:
kwargs[pk_name] = kwargs.pop("pk")
if pk_name not in kwargs or kwargs.get(pk_name) is None:
return await self.create(**kwargs)
model = await self.queryset.get(pk=kwargs[pk_name])
return await model.update(**kwargs)
def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
queryset = self.queryset.filter(**kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
queryset = self.queryset.exclude(**kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def select_related(self, related: Union[List, str]) -> "QuerysetProxy":
queryset = self.queryset.select_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy":
queryset = self.queryset.prefetch_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def limit(self, limit_count: int) -> "QuerysetProxy":
queryset = self.queryset.limit(limit_count)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def offset(self, offset: int) -> "QuerysetProxy":
queryset = self.queryset.offset(offset)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
queryset = self.queryset.fields(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
queryset = self.queryset.exclude_fields(columns=columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy":
queryset = self.queryset.order_by(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)

View File

@ -34,7 +34,7 @@ class Relation:
self.to: Type["T"] = to self.to: Type["T"] = to
self.through: Optional[Type["T"]] = through self.through: Optional[Type["T"]] = through
self.related_models: Optional[Union[RelationProxy, "T"]] = ( self.related_models: Optional[Union[RelationProxy, "T"]] = (
RelationProxy(relation=self) RelationProxy(relation=self, type_=type_)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None else None
) )

View File

@ -65,8 +65,6 @@ class RelationsManager:
parent_relation = parent._orm._get(child_name) parent_relation = parent._orm._get(child_name)
if parent_relation: if parent_relation:
# print('missing', child_name)
# parent_relation = register_missing_relation(parent, child, child_name)
parent_relation.add(child) # type: ignore parent_relation.add(child) # type: ignore
child_relation = child._orm._get(to_name) child_relation = child._orm._get(to_name)

View File

@ -5,17 +5,18 @@ from ormar.exceptions import NoMatch, RelationshipInstanceError
from ormar.relations.querysetproxy import QuerysetProxy from ormar.relations.querysetproxy import QuerysetProxy
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model, RelationType
from ormar.relations import Relation from ormar.relations import Relation
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
class RelationProxy(list): class RelationProxy(list):
def __init__(self, relation: "Relation") -> None: def __init__(self, relation: "Relation", type_: "RelationType") -> None:
super(RelationProxy, self).__init__() super().__init__()
self.relation: Relation = relation self.relation: "Relation" = relation
self.type_: "RelationType" = type_
self._owner: "Model" = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation) self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_)
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]: if item in ["count", "clear"]:
@ -38,17 +39,19 @@ class RelationProxy(list):
) )
def _set_queryset(self) -> "QuerySet": def _set_queryset(self) -> "QuerySet":
owner_table = self.relation._owner.Meta.tablename related_field = self._owner.resolve_relation_field(
pkname = self.relation._owner.get_column_alias(self.relation._owner.Meta.pkname) self.relation.to, self._owner
pk_value = self.relation._owner.pk )
pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
pk_value = self._owner.pk
if not pk_value: if not pk_value:
raise RelationshipInstanceError( raise RelationshipInstanceError(
"You cannot query many to many relationship on unsaved model." "You cannot query relationships from unsaved model."
) )
kwargs = {f"{owner_table}__{pkname}": pk_value} kwargs = {f"{related_field.get_alias()}__{pkname}": pk_value}
queryset = ( queryset = (
ormar.QuerySet(model_cls=self.relation.to) ormar.QuerySet(model_cls=self.relation.to)
.select_related(owner_table) .select_related(related_field.name)
.filter(**kwargs) .filter(**kwargs)
) )
return queryset return queryset
@ -67,14 +70,21 @@ class RelationProxy(list):
f"{self._owner.get_name()} does not have relation {rel_name}" f"{self._owner.get_name()} does not have relation {rel_name}"
) )
relation.remove(self._owner) relation.remove(self._owner)
if self.relation._type == ormar.RelationType.MULTIPLE: self.relation.remove(item)
if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.delete_through_instance(item) await self.queryset_proxy.delete_through_instance(item)
else:
def append(self, item: "Model") -> None: setattr(item, rel_name, None)
super().append(item) await item.update()
async def add(self, item: "Model") -> None: async def add(self, item: "Model") -> None:
if self.relation._type == ormar.RelationType.MULTIPLE: if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item) await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner) rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner) setattr(item, rel_name, self._owner)
else:
related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
setattr(item, related_field.name, self._owner)
await item.update()

View File

@ -9,7 +9,7 @@ import pytest
import sqlalchemy import sqlalchemy
import ormar import ormar
from ormar.exceptions import QueryDefinitionError, NoMatch from ormar.exceptions import QueryDefinitionError, NoMatch, ModelError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -117,6 +117,11 @@ def test_model_class():
assert isinstance(User.Meta.table, sqlalchemy.Table) assert isinstance(User.Meta.table, sqlalchemy.Table)
def test_wrong_field_name():
with pytest.raises(ModelError):
User(non_existing_pk=1)
def test_model_pk(): def test_model_pk():
user = User(pk=1) user = User(pk=1)
assert user.pk == 1 assert user.pk == 1

View File

@ -0,0 +1,182 @@
import asyncio
from typing import List, Optional, Union
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Subject(ormar.Model):
class Meta:
tablename = "subjects"
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=80)
class Author(ormar.Model):
class Meta:
tablename = "authors"
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
first_name: str = ormar.String(max_length=80)
last_name: str = ormar.String(max_length=80)
class Category(ormar.Model):
class Meta:
tablename = "categories"
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40)
sort_order: int = ormar.Integer(nullable=True)
subject: Optional[Subject] = ormar.ForeignKey(Subject)
class PostCategory(ormar.Model):
class Meta:
tablename = "posts_categories"
database = database
metadata = metadata
class Post(ormar.Model):
class Meta:
tablename = "posts"
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany(
Category, through=PostCategory
)
author: Optional[Author] = ormar.ForeignKey(Author)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_queryset_methods():
async with database:
async with database.transaction(force_rollback=True):
guido = await Author.objects.create(
first_name="Guido", last_name="Van Rossum"
)
subject = await Subject(name="Random").save()
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(
name="News", sort_order=1, subject=subject
)
breaking = await Category.objects.create(
name="Breaking", sort_order=3, subject=subject
)
# Add a category to a post.
await post.categories.add(news)
await post.categories.add(breaking)
category = await post.categories.get_or_create(name="News")
assert category == news
assert len(post.categories) == 1
category = await post.categories.get_or_create(name="Breaking News")
assert category != breaking
assert category.pk is not None
assert len(post.categories) == 2
await post.categories.update_or_create(pk=category.pk, name="Urgent News")
assert len(post.categories) == 2
cat = await post.categories.get_or_create(name="Urgent News")
assert cat.pk == category.pk
assert len(post.categories) == 1
await post.categories.remove(cat)
await cat.delete()
assert len(post.categories) == 0
category = await post.categories.update_or_create(
name="Weather News", sort_order=2, subject=subject
)
assert category.pk is not None
assert category.posts[0] == post
assert len(post.categories) == 1
categories = await post.categories.all()
assert len(categories) == 3 == len(post.categories)
assert await post.categories.exists()
assert 3 == await post.categories.count()
categories = await post.categories.limit(2).all()
assert len(categories) == 2 == len(post.categories)
categories2 = await post.categories.limit(2).offset(1).all()
assert len(categories2) == 2 == len(post.categories)
assert categories != categories2
categories = await post.categories.order_by("-sort_order").all()
assert len(categories) == 3 == len(post.categories)
assert post.categories[2].name == "News"
assert post.categories[0].name == "Breaking"
categories = await post.categories.exclude(name__icontains="news").all()
assert len(categories) == 1 == len(post.categories)
assert post.categories[0].name == "Breaking"
categories = (
await post.categories.filter(name__icontains="news")
.order_by("-name")
.all()
)
assert len(categories) == 2 == len(post.categories)
assert post.categories[0].name == "Weather News"
assert post.categories[1].name == "News"
categories = await post.categories.fields("name").all()
assert len(categories) == 3 == len(post.categories)
for cat in post.categories:
assert cat.sort_order is None
categories = await post.categories.exclude_fields("sort_order").all()
assert len(categories) == 3 == len(post.categories)
for cat in post.categories:
assert cat.sort_order is None
assert cat.subject.name is None
categories = await post.categories.select_related("subject").all()
assert len(categories) == 3 == len(post.categories)
for cat in post.categories:
assert cat.subject.name is not None
categories = await post.categories.prefetch_related("subject").all()
assert len(categories) == 3 == len(post.categories)
for cat in post.categories:
assert cat.subject.name is not None

View File

@ -0,0 +1,233 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
is_best_seller: bool = ormar.Boolean(default=False)
class Writer(ormar.Model):
class Meta:
tablename = "writers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Track(ormar.Model):
class Meta:
tablename = "tracks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100)
position: int = ormar.Integer()
play_count: int = ormar.Integer(nullable=True)
written_by: Optional[Writer] = ormar.ForeignKey(Writer)
@pytest.fixture(autouse=True)
@pytest.mark.asyncio
async def sample_data():
album = await Album(name="Malibu").save()
writer1 = await Writer.objects.create(name="John")
writer2 = await Writer.objects.create(name="Sue")
track1 = await Track(
album=album, title="The Bird", position=1, play_count=30, written_by=writer1
).save()
track2 = await Track(
album=album,
title="Heart don't stand a chance",
position=2,
play_count=20,
written_by=writer2,
).save()
tracks3 = await Track(
album=album, title="The Waters", position=3, play_count=10, written_by=writer1
).save()
return album, [track1, track2, tracks3]
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_quering_by_reverse_fk(sample_data):
async with database:
async with database.transaction(force_rollback=True):
track1 = sample_data[1][0]
album = await Album.objects.first()
assert await album.tracks.exists()
assert await album.tracks.count() == 3
track = await album.tracks.get_or_create(
title="The Bird", position=1, play_count=30
)
assert track == track1
assert len(album.tracks) == 1
track = await album.tracks.get_or_create(
title="The Bird2", position=4, play_count=5
)
assert track != track1
assert track.pk is not None
assert len(album.tracks) == 2
await album.tracks.update_or_create(pk=track.pk, play_count=50)
assert len(album.tracks) == 2
track = await album.tracks.get_or_create(title="The Bird2")
assert track.play_count == 50
assert len(album.tracks) == 1
await album.tracks.remove(track)
assert track.album is None
await track.delete()
assert len(album.tracks) == 0
track6 = await album.tracks.update_or_create(
title="The Bird3", position=4, play_count=5
)
assert track6.pk is not None
assert track6.play_count == 5
assert len(album.tracks) == 1
await album.tracks.remove(track6)
assert track6.album is None
await track6.delete()
assert len(album.tracks) == 0
@pytest.mark.asyncio
async def test_getting(sample_data):
async with database:
async with database.transaction(force_rollback=True):
album = sample_data[0]
track1 = await album.tracks.fields(["album", "title", "position"]).get(
title="The Bird"
)
track2 = await album.tracks.exclude_fields("play_count").get(
title="The Bird"
)
for track in [track1, track2]:
assert track.title == "The Bird"
assert track.album == album
assert track.play_count is None
assert len(album.tracks) == 1
tracks = await album.tracks.all()
assert len(tracks) == 3
assert len(album.tracks) == 3
tracks = await album.tracks.order_by("play_count").all()
assert len(tracks) == 3
assert tracks[0].title == "The Waters"
assert tracks[2].title == "The Bird"
assert len(album.tracks) == 3
track = await album.tracks.create(
title="The Bird Fly Away", position=4, play_count=10
)
assert track.title == "The Bird Fly Away"
assert track.position == 4
assert track.album == album
assert len(album.tracks) == 4
tracks = await album.tracks.all()
assert len(tracks) == 4
tracks = await album.tracks.limit(2).all()
assert len(tracks) == 2
tracks2 = await album.tracks.limit(2).offset(2).all()
assert len(tracks2) == 2
assert tracks != tracks2
tracks3 = await album.tracks.filter(play_count__lt=15).all()
assert len(tracks3) == 2
tracks4 = await album.tracks.exclude(play_count__lt=15).all()
assert len(tracks4) == 2
assert tracks3 != tracks4
assert len(album.tracks) == 2
await album.tracks.clear()
tracks = await album.tracks.all()
assert len(tracks) == 0
assert len(album.tracks) == 0
@pytest.mark.asyncio
async def test_loading_related(sample_data):
async with database:
async with database.transaction(force_rollback=True):
album = sample_data[0]
tracks = await album.tracks.select_related("written_by").all()
assert len(tracks) == 3
assert len(album.tracks) == 3
for track in tracks:
assert track.written_by is not None
tracks = await album.tracks.prefetch_related("written_by").all()
assert len(tracks) == 3
assert len(album.tracks) == 3
for track in tracks:
assert track.written_by is not None
@pytest.mark.asyncio
async def test_adding_removing(sample_data):
async with database:
async with database.transaction(force_rollback=True):
album = sample_data[0]
track_new = await Track(title="Rainbow", position=5, play_count=300).save()
await album.tracks.add(track_new)
assert track_new.album == album
assert len(album.tracks) == 4
track_check = await Track.objects.get(title="Rainbow")
assert track_check.album == album
track_test = await Track.objects.get(title="Rainbow")
assert track_test.album == album
await album.tracks.remove(track_new)
assert track_new.album is None
assert len(album.tracks) == 3
track_test = await Track.objects.get(title="Rainbow")
assert track_test.album is None