diff --git a/docs/queries.md b/docs/queries.md index c848bc1..17b23ce 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -212,7 +212,7 @@ You can use special filter suffix to change the filter operands: * exact - like `album__name__exact='Malibu'` (exact match) * iexact - like `album__name__iexact='malibu'` (exact match case insensitive) -* contains - like `album__name__conatins='Mal'` (sql like) +* contains - like `album__name__contains='Mal'` (sql like) * icontains - like `album__name__icontains='mal'` (sql like case insensitive) * in - like `album__name__in=['Malibu', 'Barclay']` (sql in) * gt - like `position__gt=3` (sql >) diff --git a/ormar/exceptions.py b/ormar/exceptions.py index 0800a81..0ec0c8e 100644 --- a/ormar/exceptions.py +++ b/ormar/exceptions.py @@ -6,6 +6,10 @@ class ModelDefinitionError(AsyncOrmException): pass +class ModelError(AsyncOrmException): + pass + + class ModelNotSet(AsyncOrmException): pass diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 8a9d518..05c3112 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -7,6 +7,7 @@ from typing import ( Dict, List, Mapping, + MutableSequence, Optional, Sequence, Set, @@ -22,6 +23,7 @@ import sqlalchemy from pydantic import BaseModel import ormar # noqa I100 +from ormar.exceptions import ModelError from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField from ormar.models.excludable import Excludable @@ -93,16 +95,21 @@ class NewBaseModel( if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") # build the models to set them and validate but don't register - new_kwargs = { - k: self._convert_json( - k, - self.Meta.model_fields[k].expand_relationship( - v, self, to_register=False - ), - "dumps", + try: + new_kwargs = { + k: self._convert_json( + k, + self.Meta.model_fields[k].expand_relationship( + v, self, to_register=False + ), + "dumps", + ) + for k, v in kwargs.items() + } + except KeyError as e: + raise ModelError( + f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}" ) - for k, v in kwargs.items() - } values, fields_set, validation_error = pydantic.validate_model( self, new_kwargs # type: ignore @@ -249,7 +256,9 @@ class NewBaseModel( @staticmethod def _extract_nested_models_from_list( - models: List, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None], + models: MutableSequence, + include: Union[Set, Dict, None], + exclude: Union[Set, Dict, None], ) -> List: result = [] for model in models: @@ -282,7 +291,7 @@ class NewBaseModel( if self.Meta.model_fields[field].virtual and nested: continue nested_model = getattr(self, field) - if isinstance(nested_model, list): + if isinstance(nested_model, MutableSequence): dict_instance[field] = self._extract_nested_models_from_list( models=nested_model, include=self._skip_ellipsis(include, field), @@ -308,7 +317,7 @@ class NewBaseModel( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - nested: bool = False + nested: bool = False, ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, diff --git a/ormar/protocols/queryset_protocol.py b/ormar/protocols/queryset_protocol.py index 1320c2a..7eb7092 100644 --- a/ormar/protocols/queryset_protocol.py +++ b/ormar/protocols/queryset_protocol.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union try: from typing import Protocol @@ -6,14 +6,21 @@ except ImportError: # pragma: nocover from typing_extensions import Protocol # type: ignore if TYPE_CHECKING: # noqa: C901; #pragma nocover - from ormar import QuerySet, Model + from ormar import Model + from ormar.relations.querysetproxy import QuerysetProxy class QuerySetProtocol(Protocol): # pragma: nocover - def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003, A001 + def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 ... - def select_related(self, related: Union[List, str]) -> "QuerySet": + def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 + ... + + def select_related(self, related: Union[List, str]) -> "QuerysetProxy": + ... + + def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy": ... async def exists(self) -> bool: @@ -25,10 +32,10 @@ class QuerySetProtocol(Protocol): # pragma: nocover async def clear(self) -> int: ... - def limit(self, limit_count: int) -> "QuerySet": + def limit(self, limit_count: int) -> "QuerysetProxy": ... - def offset(self, offset: int) -> "QuerySet": + def offset(self, offset: int) -> "QuerysetProxy": ... async def first(self, **kwargs: Any) -> "Model": @@ -44,3 +51,18 @@ class QuerySetProtocol(Protocol): # pragma: nocover async def create(self, **kwargs: Any) -> "Model": ... + + async def get_or_create(self, **kwargs: Any) -> "Model": + ... + + async def update_or_create(self, **kwargs: Any) -> "Model": + ... + + def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + ... + + def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + ... + + def order_by(self, columns: Union[List, str]) -> "QuerysetProxy": + ... diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 8e2d6b4..208b4d1 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -1,4 +1,15 @@ -from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union +from typing import ( + Any, + Dict, + List, + MutableSequence, + Optional, + Sequence, + Set, + TYPE_CHECKING, + TypeVar, + Union, +) import ormar @@ -6,6 +17,7 @@ if TYPE_CHECKING: # pragma no cover from ormar.relations import Relation from ormar.models import Model from ormar.queryset import QuerySet + from ormar import RelationType T = TypeVar("T", bound=Model) @@ -14,9 +26,17 @@ class QuerysetProxy(ormar.QuerySetProtocol): if TYPE_CHECKING: # pragma no cover relation: "Relation" - def __init__(self, relation: "Relation") -> None: + def __init__( + self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None + ) -> None: self.relation: Relation = relation - self._queryset: Optional["QuerySet"] = None + self._queryset: Optional["QuerySet"] = qryset + self.type_: "RelationType" = type_ + self._owner: "Model" = self.relation.manager.owner + self.related_field = self._owner.resolve_relation_field( + self.relation.to, self._owner + ) + self.owner_pk_value = self._owner.pk @property def queryset(self) -> "QuerySet": @@ -30,7 +50,7 @@ class QuerysetProxy(ormar.QuerySetProtocol): def _assign_child_to_parent(self, child: Optional["T"]) -> None: if child: - owner = self.relation._owner + owner = self._owner rel_name = owner.resolve_relation_name(owner, child) setattr(owner, rel_name, child) @@ -42,27 +62,26 @@ class QuerysetProxy(ormar.QuerySetProtocol): assert isinstance(child, ormar.Model) self._assign_child_to_parent(child) + def _clean_items_on_load(self) -> None: + if isinstance(self.relation.related_models, MutableSequence): + for item in self.relation.related_models[:]: + self.relation.remove(item) + async def create_through_instance(self, child: "T") -> None: queryset = ormar.QuerySet(model_cls=self.relation.through) - owner_column = self.relation._owner.get_name() + owner_column = self._owner.get_name() child_column = child.get_name() - kwargs = {owner_column: self.relation._owner, child_column: child} + kwargs = {owner_column: self._owner, child_column: child} await queryset.create(**kwargs) async def delete_through_instance(self, child: "T") -> None: queryset = ormar.QuerySet(model_cls=self.relation.through) - owner_column = self.relation._owner.get_name() + owner_column = self._owner.get_name() child_column = child.get_name() - kwargs = {owner_column: self.relation._owner, child_column: child} + kwargs = {owner_column: self._owner, child_column: child} link_instance = await queryset.filter(**kwargs).get() # type: ignore await link_instance.delete() - def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 - return self.queryset.filter(**kwargs) - - def select_related(self, related: Union[List, str]) -> "QuerySet": - return self.queryset.select_related(related) - async def exists(self) -> bool: return await self.queryset.exists() @@ -70,17 +89,16 @@ class QuerysetProxy(ormar.QuerySetProtocol): return await self.queryset.count() async def clear(self) -> int: - queryset = ormar.QuerySet(model_cls=self.relation.through) - owner_column = self.relation._owner.get_name() - kwargs = {owner_column: self.relation._owner} + if self.type_ == ormar.RelationType.MULTIPLE: + queryset = ormar.QuerySet(model_cls=self.relation.through) + owner_column = self._owner.get_name() + else: + queryset = ormar.QuerySet(model_cls=self.relation.to) + owner_column = self.related_field.name + kwargs = {owner_column: self._owner} + self._clean_items_on_load() return await queryset.delete(**kwargs) # type: ignore - def limit(self, limit_count: int) -> "QuerySet": - return self.queryset.limit(limit_count) - - def offset(self, offset: int) -> "QuerySet": - return self.queryset.offset(offset) - async def first(self, **kwargs: Any) -> "Model": first = await self.queryset.first(**kwargs) self._register_related(first) @@ -88,16 +106,72 @@ class QuerysetProxy(ormar.QuerySetProtocol): async def get(self, **kwargs: Any) -> "Model": get = await self.queryset.get(**kwargs) + self._clean_items_on_load() self._register_related(get) return get async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 all_items = await self.queryset.all(**kwargs) + self._clean_items_on_load() self._register_related(all_items) return all_items async def create(self, **kwargs: Any) -> "Model": - create = await self.queryset.create(**kwargs) - self._register_related(create) - await self.create_through_instance(create) - return create + if self.type_ == ormar.RelationType.REVERSE: + kwargs[self.related_field.name] = self._owner + created = await self.queryset.create(**kwargs) + self._register_related(created) + if self.type_ == ormar.RelationType.MULTIPLE: + await self.create_through_instance(created) + return created + + async def get_or_create(self, **kwargs: Any) -> "Model": + try: + return await self.get(**kwargs) + except ormar.NoMatch: + return await self.create(**kwargs) + + async def update_or_create(self, **kwargs: Any) -> "Model": + pk_name = self.queryset.model_meta.pkname + if "pk" in kwargs: + kwargs[pk_name] = kwargs.pop("pk") + if pk_name not in kwargs or kwargs.get(pk_name) is None: + return await self.create(**kwargs) + model = await self.queryset.get(pk=kwargs[pk_name]) + return await model.update(**kwargs) + + def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 + queryset = self.queryset.filter(**kwargs) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 + queryset = self.queryset.exclude(**kwargs) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def select_related(self, related: Union[List, str]) -> "QuerysetProxy": + queryset = self.queryset.select_related(related) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy": + queryset = self.queryset.prefetch_related(related) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def limit(self, limit_count: int) -> "QuerysetProxy": + queryset = self.queryset.limit(limit_count) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def offset(self, offset: int) -> "QuerysetProxy": + queryset = self.queryset.offset(offset) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + queryset = self.queryset.fields(columns) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + queryset = self.queryset.exclude_fields(columns=columns) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def order_by(self, columns: Union[List, str]) -> "QuerysetProxy": + queryset = self.queryset.order_by(columns) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py index e09f00c..b0183c3 100644 --- a/ormar/relations/relation.py +++ b/ormar/relations/relation.py @@ -34,7 +34,7 @@ class Relation: self.to: Type["T"] = to self.through: Optional[Type["T"]] = through self.related_models: Optional[Union[RelationProxy, "T"]] = ( - RelationProxy(relation=self) + RelationProxy(relation=self, type_=type_) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) else None ) diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py index dfe1ee8..81183f4 100644 --- a/ormar/relations/relation_manager.py +++ b/ormar/relations/relation_manager.py @@ -65,8 +65,6 @@ class RelationsManager: parent_relation = parent._orm._get(child_name) if parent_relation: - # print('missing', child_name) - # parent_relation = register_missing_relation(parent, child, child_name) parent_relation.add(child) # type: ignore child_relation = child._orm._get(to_name) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index f03252e..25c7b1c 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -5,17 +5,18 @@ from ormar.exceptions import NoMatch, RelationshipInstanceError from ormar.relations.querysetproxy import QuerysetProxy if TYPE_CHECKING: # pragma no cover - from ormar import Model + from ormar import Model, RelationType from ormar.relations import Relation from ormar.queryset import QuerySet class RelationProxy(list): - def __init__(self, relation: "Relation") -> None: - super(RelationProxy, self).__init__() - self.relation: Relation = relation + def __init__(self, relation: "Relation", type_: "RelationType") -> None: + super().__init__() + self.relation: "Relation" = relation + self.type_: "RelationType" = type_ self._owner: "Model" = self.relation.manager.owner - self.queryset_proxy = QuerysetProxy(relation=self.relation) + self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_) def __getattribute__(self, item: str) -> Any: if item in ["count", "clear"]: @@ -38,17 +39,19 @@ class RelationProxy(list): ) def _set_queryset(self) -> "QuerySet": - owner_table = self.relation._owner.Meta.tablename - pkname = self.relation._owner.get_column_alias(self.relation._owner.Meta.pkname) - pk_value = self.relation._owner.pk + related_field = self._owner.resolve_relation_field( + self.relation.to, self._owner + ) + pkname = self._owner.get_column_alias(self._owner.Meta.pkname) + pk_value = self._owner.pk if not pk_value: raise RelationshipInstanceError( - "You cannot query many to many relationship on unsaved model." + "You cannot query relationships from unsaved model." ) - kwargs = {f"{owner_table}__{pkname}": pk_value} + kwargs = {f"{related_field.get_alias()}__{pkname}": pk_value} queryset = ( ormar.QuerySet(model_cls=self.relation.to) - .select_related(owner_table) + .select_related(related_field.name) .filter(**kwargs) ) return queryset @@ -67,14 +70,21 @@ class RelationProxy(list): f"{self._owner.get_name()} does not have relation {rel_name}" ) relation.remove(self._owner) - if self.relation._type == ormar.RelationType.MULTIPLE: + self.relation.remove(item) + if self.type_ == ormar.RelationType.MULTIPLE: await self.queryset_proxy.delete_through_instance(item) - - def append(self, item: "Model") -> None: - super().append(item) + else: + setattr(item, rel_name, None) + await item.update() async def add(self, item: "Model") -> None: - if self.relation._type == ormar.RelationType.MULTIPLE: + if self.type_ == ormar.RelationType.MULTIPLE: await self.queryset_proxy.create_through_instance(item) - rel_name = item.resolve_relation_name(item, self._owner) - setattr(item, rel_name, self._owner) + rel_name = item.resolve_relation_name(item, self._owner) + setattr(item, rel_name, self._owner) + else: + related_field = self._owner.resolve_relation_field( + self.relation.to, self._owner + ) + setattr(item, related_field.name, self._owner) + await item.update() diff --git a/tests/test_models.py b/tests/test_models.py index 3c9ffc8..24a81ad 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ import pytest import sqlalchemy import ormar -from ormar.exceptions import QueryDefinitionError, NoMatch +from ormar.exceptions import QueryDefinitionError, NoMatch, ModelError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -117,6 +117,11 @@ def test_model_class(): assert isinstance(User.Meta.table, sqlalchemy.Table) +def test_wrong_field_name(): + with pytest.raises(ModelError): + User(non_existing_pk=1) + + def test_model_pk(): user = User(pk=1) assert user.pk == 1 diff --git a/tests/test_queryproxy_on_m2m_models.py b/tests/test_queryproxy_on_m2m_models.py new file mode 100644 index 0000000..d33aa5d --- /dev/null +++ b/tests/test_queryproxy_on_m2m_models.py @@ -0,0 +1,182 @@ +import asyncio +from typing import List, Optional, Union + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Subject(ormar.Model): + class Meta: + tablename = "subjects" + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=80) + + +class Author(ormar.Model): + class Meta: + tablename = "authors" + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + first_name: str = ormar.String(max_length=80) + last_name: str = ormar.String(max_length=80) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=40) + sort_order: int = ormar.Integer(nullable=True) + subject: Optional[Subject] = ormar.ForeignKey(Subject) + + +class PostCategory(ormar.Model): + class Meta: + tablename = "posts_categories" + database = database + metadata = metadata + + +class Post(ormar.Model): + class Meta: + tablename = "posts" + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=200) + categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany( + Category, through=PostCategory + ) + author: Optional[Author] = ormar.ForeignKey(Author) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_queryset_methods(): + async with database: + async with database.transaction(force_rollback=True): + guido = await Author.objects.create( + first_name="Guido", last_name="Van Rossum" + ) + subject = await Subject(name="Random").save() + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create( + name="News", sort_order=1, subject=subject + ) + breaking = await Category.objects.create( + name="Breaking", sort_order=3, subject=subject + ) + + # Add a category to a post. + await post.categories.add(news) + await post.categories.add(breaking) + + category = await post.categories.get_or_create(name="News") + assert category == news + assert len(post.categories) == 1 + + category = await post.categories.get_or_create(name="Breaking News") + assert category != breaking + assert category.pk is not None + assert len(post.categories) == 2 + + await post.categories.update_or_create(pk=category.pk, name="Urgent News") + assert len(post.categories) == 2 + cat = await post.categories.get_or_create(name="Urgent News") + assert cat.pk == category.pk + assert len(post.categories) == 1 + + await post.categories.remove(cat) + await cat.delete() + + assert len(post.categories) == 0 + + category = await post.categories.update_or_create( + name="Weather News", sort_order=2, subject=subject + ) + assert category.pk is not None + assert category.posts[0] == post + + assert len(post.categories) == 1 + + categories = await post.categories.all() + assert len(categories) == 3 == len(post.categories) + + assert await post.categories.exists() + assert 3 == await post.categories.count() + + categories = await post.categories.limit(2).all() + assert len(categories) == 2 == len(post.categories) + + categories2 = await post.categories.limit(2).offset(1).all() + assert len(categories2) == 2 == len(post.categories) + assert categories != categories2 + + categories = await post.categories.order_by("-sort_order").all() + assert len(categories) == 3 == len(post.categories) + assert post.categories[2].name == "News" + assert post.categories[0].name == "Breaking" + + categories = await post.categories.exclude(name__icontains="news").all() + assert len(categories) == 1 == len(post.categories) + assert post.categories[0].name == "Breaking" + + categories = ( + await post.categories.filter(name__icontains="news") + .order_by("-name") + .all() + ) + assert len(categories) == 2 == len(post.categories) + assert post.categories[0].name == "Weather News" + assert post.categories[1].name == "News" + + categories = await post.categories.fields("name").all() + assert len(categories) == 3 == len(post.categories) + for cat in post.categories: + assert cat.sort_order is None + + categories = await post.categories.exclude_fields("sort_order").all() + assert len(categories) == 3 == len(post.categories) + for cat in post.categories: + assert cat.sort_order is None + assert cat.subject.name is None + + categories = await post.categories.select_related("subject").all() + assert len(categories) == 3 == len(post.categories) + for cat in post.categories: + assert cat.subject.name is not None + + categories = await post.categories.prefetch_related("subject").all() + assert len(categories) == 3 == len(post.categories) + for cat in post.categories: + assert cat.subject.name is not None diff --git a/tests/test_reverse_fk_queryset.py b/tests/test_reverse_fk_queryset.py new file mode 100644 index 0000000..0eac5d5 --- /dev/null +++ b/tests/test_reverse_fk_queryset.py @@ -0,0 +1,233 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + is_best_seller: bool = ormar.Boolean(default=False) + + +class Writer(ormar.Model): + class Meta: + tablename = "writers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Track(ormar.Model): + class Meta: + tablename = "tracks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + album: Optional[Album] = ormar.ForeignKey(Album) + title: str = ormar.String(max_length=100) + position: int = ormar.Integer() + play_count: int = ormar.Integer(nullable=True) + written_by: Optional[Writer] = ormar.ForeignKey(Writer) + + +@pytest.fixture(autouse=True) +@pytest.mark.asyncio +async def sample_data(): + album = await Album(name="Malibu").save() + writer1 = await Writer.objects.create(name="John") + writer2 = await Writer.objects.create(name="Sue") + track1 = await Track( + album=album, title="The Bird", position=1, play_count=30, written_by=writer1 + ).save() + track2 = await Track( + album=album, + title="Heart don't stand a chance", + position=2, + play_count=20, + written_by=writer2, + ).save() + tracks3 = await Track( + album=album, title="The Waters", position=3, play_count=10, written_by=writer1 + ).save() + return album, [track1, track2, tracks3] + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_quering_by_reverse_fk(sample_data): + async with database: + async with database.transaction(force_rollback=True): + track1 = sample_data[1][0] + album = await Album.objects.first() + + assert await album.tracks.exists() + assert await album.tracks.count() == 3 + + track = await album.tracks.get_or_create( + title="The Bird", position=1, play_count=30 + ) + assert track == track1 + assert len(album.tracks) == 1 + + track = await album.tracks.get_or_create( + title="The Bird2", position=4, play_count=5 + ) + assert track != track1 + assert track.pk is not None + assert len(album.tracks) == 2 + + await album.tracks.update_or_create(pk=track.pk, play_count=50) + assert len(album.tracks) == 2 + track = await album.tracks.get_or_create(title="The Bird2") + assert track.play_count == 50 + assert len(album.tracks) == 1 + + await album.tracks.remove(track) + assert track.album is None + await track.delete() + + assert len(album.tracks) == 0 + + track6 = await album.tracks.update_or_create( + title="The Bird3", position=4, play_count=5 + ) + assert track6.pk is not None + assert track6.play_count == 5 + + assert len(album.tracks) == 1 + + await album.tracks.remove(track6) + assert track6.album is None + await track6.delete() + + assert len(album.tracks) == 0 + + +@pytest.mark.asyncio +async def test_getting(sample_data): + async with database: + async with database.transaction(force_rollback=True): + album = sample_data[0] + track1 = await album.tracks.fields(["album", "title", "position"]).get( + title="The Bird" + ) + track2 = await album.tracks.exclude_fields("play_count").get( + title="The Bird" + ) + for track in [track1, track2]: + assert track.title == "The Bird" + assert track.album == album + assert track.play_count is None + + assert len(album.tracks) == 1 + + tracks = await album.tracks.all() + assert len(tracks) == 3 + + assert len(album.tracks) == 3 + + tracks = await album.tracks.order_by("play_count").all() + assert len(tracks) == 3 + assert tracks[0].title == "The Waters" + assert tracks[2].title == "The Bird" + + assert len(album.tracks) == 3 + + track = await album.tracks.create( + title="The Bird Fly Away", position=4, play_count=10 + ) + assert track.title == "The Bird Fly Away" + assert track.position == 4 + assert track.album == album + + assert len(album.tracks) == 4 + + tracks = await album.tracks.all() + assert len(tracks) == 4 + + tracks = await album.tracks.limit(2).all() + assert len(tracks) == 2 + + tracks2 = await album.tracks.limit(2).offset(2).all() + assert len(tracks2) == 2 + assert tracks != tracks2 + + tracks3 = await album.tracks.filter(play_count__lt=15).all() + assert len(tracks3) == 2 + + tracks4 = await album.tracks.exclude(play_count__lt=15).all() + assert len(tracks4) == 2 + assert tracks3 != tracks4 + + assert len(album.tracks) == 2 + + await album.tracks.clear() + tracks = await album.tracks.all() + assert len(tracks) == 0 + assert len(album.tracks) == 0 + + +@pytest.mark.asyncio +async def test_loading_related(sample_data): + async with database: + async with database.transaction(force_rollback=True): + album = sample_data[0] + tracks = await album.tracks.select_related("written_by").all() + assert len(tracks) == 3 + assert len(album.tracks) == 3 + for track in tracks: + assert track.written_by is not None + + tracks = await album.tracks.prefetch_related("written_by").all() + assert len(tracks) == 3 + assert len(album.tracks) == 3 + for track in tracks: + assert track.written_by is not None + + +@pytest.mark.asyncio +async def test_adding_removing(sample_data): + async with database: + async with database.transaction(force_rollback=True): + album = sample_data[0] + track_new = await Track(title="Rainbow", position=5, play_count=300).save() + await album.tracks.add(track_new) + assert track_new.album == album + assert len(album.tracks) == 4 + + track_check = await Track.objects.get(title="Rainbow") + assert track_check.album == album + + track_test = await Track.objects.get(title="Rainbow") + assert track_test.album == album + + await album.tracks.remove(track_new) + assert track_new.album is None + assert len(album.tracks) == 3 + + track_test = await Track.objects.get(title="Rainbow") + assert track_test.album is None