from typing import List, Optional, Union import ormar import pytest from ormar.exceptions import QueryDefinitionError from tests.lifespan import init_tests from tests.settings import create_config base_ormar_config = create_config() class Subject(ormar.Model): ormar_config = base_ormar_config.copy(tablename="subjects") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=80) class Author(ormar.Model): ormar_config = base_ormar_config.copy(tablename="authors") id: int = ormar.Integer(primary_key=True) first_name: str = ormar.String(max_length=80) last_name: str = ormar.String(max_length=80) class Category(ormar.Model): ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=40) sort_order: int = ormar.Integer(nullable=True) subject: Optional[Subject] = ormar.ForeignKey(Subject) class PostCategory(ormar.Model): ormar_config = base_ormar_config.copy(tablename="posts_categories") class Post(ormar.Model): ormar_config = base_ormar_config.copy(tablename="posts") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=200) categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany( Category, through=PostCategory ) author: Optional[Author] = ormar.ForeignKey(Author) create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_queryset_methods(): async with base_ormar_config.database: async with base_ormar_config.database.transaction(force_rollback=True): guido = await Author.objects.create( first_name="Guido", last_name="Van Rossum" ) subject = await Subject(name="Random").save() post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create( name="News", sort_order=1, subject=subject ) breaking = await Category.objects.create( name="Breaking", sort_order=3, subject=subject ) # Add a category to a post. await post.categories.add(news) await post.categories.add(breaking) category, created = await post.categories.get_or_create(name="News") assert category == news assert len(post.categories) == 1 assert created is False category, created = await post.categories.get_or_create( name="Breaking News" ) assert category != breaking assert category.pk is not None assert len(post.categories) == 2 assert created is True await post.categories.update_or_create(pk=category.pk, name="Urgent News") assert len(post.categories) == 2 cat, created = await post.categories.get_or_create(name="Urgent News") assert cat.pk == category.pk assert len(post.categories) == 1 assert created is False await post.categories.remove(cat) await cat.delete() assert len(post.categories) == 0 category = await post.categories.update_or_create( name="Weather News", sort_order=2, subject=subject ) assert category.pk is not None assert category.posts[0] == post assert len(post.categories) == 1 categories = await post.categories.all() assert len(categories) == 3 == len(post.categories) assert await post.categories.exists() assert 3 == await post.categories.count() categories = await post.categories.limit(2).all() assert len(categories) == 2 == len(post.categories) categories2 = await post.categories.limit(2).offset(1).all() assert len(categories2) == 2 == len(post.categories) assert categories != categories2 categories = await post.categories.order_by("-sort_order").all() assert len(categories) == 3 == len(post.categories) assert post.categories[2].name == "News" assert post.categories[0].name == "Breaking" categories = await post.categories.exclude(name__icontains="news").all() assert len(categories) == 1 == len(post.categories) assert post.categories[0].name == "Breaking" categories = ( await post.categories.filter(name__icontains="news") .order_by("-name") .all() ) assert len(categories) == 2 == len(post.categories) assert post.categories[0].name == "Weather News" assert post.categories[1].name == "News" categories = await post.categories.fields("name").all() assert len(categories) == 3 == len(post.categories) for cat in post.categories: assert cat.sort_order is None categories = await post.categories.exclude_fields("sort_order").all() assert len(categories) == 3 == len(post.categories) for cat in post.categories: assert cat.sort_order is None assert cat.subject.name is None categories = await post.categories.select_related("subject").all() assert len(categories) == 3 == len(post.categories) for cat in post.categories: assert cat.subject.name is not None categories = await post.categories.prefetch_related("subject").all() assert len(categories) == 3 == len(post.categories) for cat in post.categories: assert cat.subject.name is not None @pytest.mark.asyncio async def test_queryset_update(): async with base_ormar_config.database: async with base_ormar_config.database.transaction(force_rollback=True): guido = await Author.objects.create( first_name="Guido", last_name="Van Rossum" ) subject = await Subject(name="Random").save() post = await Post.objects.create(title="Hello, M2M", author=guido) await post.categories.create(name="News", sort_order=1, subject=subject) await post.categories.create(name="Breaking", sort_order=3, subject=subject) await post.categories.order_by("sort_order").all() assert len(post.categories) == 2 assert post.categories[0].sort_order == 1 assert post.categories[0].name == "News" assert post.categories[1].sort_order == 3 assert post.categories[1].name == "Breaking" updated = await post.categories.update(each=True, name="Test") assert updated == 2 await post.categories.order_by("sort_order").all() assert len(post.categories) == 2 assert post.categories[0].name == "Test" assert post.categories[1].name == "Test" updated = await post.categories.filter(sort_order=3).update(name="Test 2") assert updated == 1 await post.categories.order_by("sort_order").all() assert len(post.categories) == 2 assert post.categories[0].name == "Test" assert post.categories[1].name == "Test 2" with pytest.raises(QueryDefinitionError): await post.categories.update(name="Test WRONG")