diff --git a/.coverage b/.coverage index 52f8587..458f78c 100644 Binary files a/.coverage and b/.coverage differ diff --git a/.travis.yml b/.travis.yml index 416d999..d057e1e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,15 +5,28 @@ dist: xenial cache: pip python: - - "3.6" - - "3.7" - - "3.8" + - "3.6" + - "3.7" + - "3.8" + + +services: + - postgresql + - mysql + install: - - pip install -U -r requirements.txt + - pip install -U -r requirements.txt + +before_script: + - psql -c 'create database test_database;' -U postgres + - echo 'create database test_database;' | mysql + script: - - scripts/test.sh + - DATABASE_URL=postgresql://localhost/test_database scripts/test.sh + - DATABASE_URL=mysql://localhost/test_database scripts/test.sh + - DATABASE_URL=sqlite:///test.db scripts/test.sh after_script: - - codecov \ No newline at end of file + - codecov \ No newline at end of file diff --git a/ormar/models/model.py b/ormar/models/model.py index 43a6742..eeff399 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -2,6 +2,7 @@ import itertools from typing import Any, List, Tuple, Union import sqlalchemy +from databases.backends.postgres import Record import ormar.queryset # noqa I100 from ormar.fields.many_to_many import ManyToManyField @@ -88,14 +89,18 @@ class Model(NewBaseModel): return item @classmethod - def extract_prefixed_table_columns( + def extract_prefixed_table_columns( # noqa CCR001 cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str ) -> dict: for column in cls.Meta.table.columns: if column.name not in item: - item[column.name] = row[ + prefixed_name = ( f'{table_prefix + "_" if table_prefix else ""}{column.name}' - ] + ) + # databases does not keep aliases in Record for postgres + source = row._row if isinstance(row, Record) else row + item[column.name] = source[prefixed_name] + return item async def save(self) -> "Model": @@ -106,7 +111,8 @@ class Model(NewBaseModel): expr = self.Meta.table.insert() expr = expr.values(**self_fields) item_id = await self.Meta.database.execute(expr) - setattr(self, self.Meta.pkname, item_id) + if item_id: # postgress does not return id if it's already there + setattr(self, self.Meta.pkname, item_id) return self async def update(self, **kwargs: Any) -> "Model": diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 79f9ff2..a6c82f8 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -9,7 +9,6 @@ from ormar.queryset import FilterQuery from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query - if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -187,5 +186,6 @@ class QuerySet: # Execute the insert, and return a new model instance. instance = self.model_cls(**kwargs) pk = await self.database.execute(expr) - setattr(instance, self.model_cls.Meta.pkname, pk) + if pk: + setattr(instance, self.model_cls.Meta.pkname, pk) return instance diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index 1af8c22..64f7261 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -8,7 +8,8 @@ from sqlalchemy import text def get_table_alias() -> str: - return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] + alias = "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] + return alias.lower() class AliasManager: diff --git a/requirements.txt b/requirements.txt index 807e704..4bad8bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,20 @@ databases[sqlite] +databases[postgresql] +databases[mysql] pydantic sqlalchemy +# Async database drivers +aiomysql +aiosqlite +aiopg +asyncpg + +# Sync database drivers for standard tooling around setup/teardown/migrations. +pymysql +psycopg2-binary +mysqlclient + # Testing pytest pytest-cov diff --git a/tests/settings.py b/tests/settings.py index 697acb0..4cf7350 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,3 +1,11 @@ import os +import databases + +assert "DATABASE_URL" in os.environ, "DATABASE_URL is not set." + +DATABASE_URL = os.environ['DATABASE_URL'] +database_url = databases.DatabaseURL(DATABASE_URL) +if database_url.scheme == "postgresql+aiopg": # pragma no cover + DATABASE_URL = str(database_url.replace(driver=None)) DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db") diff --git a/tests/test_columns.py b/tests/test_columns.py index 15382b0..b2a7bd0 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -1,4 +1,5 @@ import datetime +import os import databases import pytest diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index f7f2625..fcbf479 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -1,7 +1,7 @@ import databases import sqlalchemy from fastapi import FastAPI -from fastapi.testclient import TestClient +from starlette.testclient import TestClient import ormar from tests.settings import DATABASE_URL @@ -38,18 +38,17 @@ async def create_item(item: Item): return item -client = TestClient(app) - - def test_read_main(): - response = client.post( - "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} - ) - assert response.status_code == 200 - assert response.json() == { - "category": {"id": None, "name": "test cat"}, - "id": 1, - "name": "test", - } - item = Item(**response.json()) - assert item.id == 1 + client = TestClient(app) + with client as client: + response = client.post( + "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} + ) + assert response.status_code == 200 + assert response.json() == { + "category": {"id": None, "name": "test cat"}, + "id": 1, + "name": "test", + } + item = Item(**response.json()) + assert item.id == 1 diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 98a7d5b..75bb1fd 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -78,6 +78,7 @@ class Member(ormar.Model): @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) metadata.create_all(engine) yield metadata.drop_all(engine) @@ -85,8 +86,9 @@ def create_test_database(): @pytest.mark.asyncio async def test_wrong_query_foreign_key_type(): - with pytest.raises(RelationshipInstanceError): - Track(title="The Error", album="wrong_pk_type") + async with database: + with pytest.raises(RelationshipInstanceError): + Track(title="The Error", album="wrong_pk_type") @pytest.mark.asyncio @@ -99,242 +101,252 @@ async def test_setting_explicitly_empty_relation(): @pytest.mark.asyncio async def test_related_name(): async with database: - album = await Album.objects.create(name="Vanilla") - await Cover.objects.create(album=album, title="The cover file") - - assert len(album.cover_pictures) == 1 - + async with database.transaction(force_rollback=True): + album = await Album.objects.create(name="Vanilla") + await Cover.objects.create(album=album, title="The cover file") + assert len(album.cover_pictures) == 1 @pytest.mark.asyncio async def test_model_crud(): async with database: - album = Album(name="Malibu") - await album.save() - track1 = Track(album=album, title="The Bird", position=1) - track2 = Track(album=album, title="Heart don't stand a chance", position=2) - track3 = Track(album=album, title="The Waters", position=3) - await track1.save() - await track2.save() - await track3.save() + async with database.transaction(force_rollback=True): + album = Album(name="Jamaica") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() - track = await Track.objects.get(title="The Bird") - assert track.album.pk == album.pk - assert isinstance(track.album, ormar.Model) - assert track.album.name is None - await track.album.load() - assert track.album.name == "Malibu" + track = await Track.objects.get(title="The Bird") + assert track.album.pk == album.pk + assert isinstance(track.album, ormar.Model) + assert track.album.name is None + await track.album.load() + assert track.album.name == "Jamaica" - assert len(album.tracks) == 3 - assert album.tracks[1].title == "Heart don't stand a chance" + assert len(album.tracks) == 3 + assert album.tracks[1].title == "Heart don't stand a chance" - album1 = await Album.objects.get(name="Malibu") - assert album1.pk == 1 - assert album1.tracks == [] + album1 = await Album.objects.get(name="Jamaica") + assert album1.pk == album.pk + assert album1.tracks == [] - await Track.objects.create( - album={"id": track.album.pk}, title="The Bird2", position=4 - ) + await Track.objects.create( + album={"id": track.album.pk}, title="The Bird2", position=4 + ) @pytest.mark.asyncio async def test_select_related(): async with database: - album = Album(name="Malibu") - await album.save() - track1 = Track(album=album, title="The Bird", position=1) - track2 = Track(album=album, title="Heart don't stand a chance", position=2) - track3 = Track(album=album, title="The Waters", position=3) - await track1.save() - await track2.save() - await track3.save() + async with database.transaction(force_rollback=True): + album = Album(name="Malibu") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() - fantasies = Album(name="Fantasies") - await fantasies.save() - track4 = Track(album=fantasies, title="Help I'm Alive", position=1) - track5 = Track(album=fantasies, title="Sick Muse", position=2) - track6 = Track(album=fantasies, title="Satellite Mind", position=3) - await track4.save() - await track5.save() - await track6.save() + fantasies = Album(name="Fantasies") + await fantasies.save() + track4 = Track(album=fantasies, title="Help I'm Alive", position=1) + track5 = Track(album=fantasies, title="Sick Muse", position=2) + track6 = Track(album=fantasies, title="Satellite Mind", position=3) + await track4.save() + await track5.save() + await track6.save() - track = await Track.objects.select_related("album").get(title="The Bird") - assert track.album.name == "Malibu" + track = await Track.objects.select_related("album").get(title="The Bird") + assert track.album.name == "Malibu" - tracks = await Track.objects.select_related("album").all() - assert len(tracks) == 6 + tracks = await Track.objects.select_related("album").all() + assert len(tracks) == 6 @pytest.mark.asyncio async def test_model_removal_from_relations(): async with database: - album = Album(name="Chichi") - await album.save() - track1 = Track(album=album, title="The Birdman", position=1) - track2 = Track(album=album, title="Superman", position=2) - track3 = Track(album=album, title="Wonder Woman", position=3) - await track1.save() - await track2.save() - await track3.save() + async with database.transaction(force_rollback=True): + album = Album(name="Chichi") + await album.save() + track1 = Track(album=album, title="The Birdman", position=1) + track2 = Track(album=album, title="Superman", position=2) + track3 = Track(album=album, title="Wonder Woman", position=3) + await track1.save() + await track2.save() + await track3.save() - assert len(album.tracks) == 3 - await album.tracks.remove(track1) - assert len(album.tracks) == 2 - assert track1.album is None + assert len(album.tracks) == 3 + await album.tracks.remove(track1) + assert len(album.tracks) == 2 + assert track1.album is None - await track1.update() - track1 = await Track.objects.get(title="The Birdman") - assert track1.album is None + await track1.update() + track1 = await Track.objects.get(title="The Birdman") + assert track1.album is None - await album.tracks.add(track1) - assert len(album.tracks) == 3 - assert track1.album == album + await album.tracks.add(track1) + assert len(album.tracks) == 3 + assert track1.album == album - await track1.update() - track1 = await Track.objects.select_related("album__tracks").get( - title="The Birdman" - ) - album = await Album.objects.select_related("tracks").get(name="Chichi") - assert track1.album == album + await track1.update() + track1 = await Track.objects.select_related("album__tracks").get( + title="The Birdman" + ) + album = await Album.objects.select_related("tracks").get(name="Chichi") + assert track1.album == album - track1.remove(album) - assert track1.album is None - assert len(album.tracks) == 2 + track1.remove(album) + assert track1.album is None + assert len(album.tracks) == 2 + + track2.remove(album) + assert track2.album is None + assert len(album.tracks) == 1 - track2.remove(album) - assert track2.album is None - assert len(album.tracks) == 1 @pytest.mark.asyncio async def test_fk_filter(): async with database: - malibu = Album(name="Malibu%") - await malibu.save() - await Track.objects.create(album=malibu, title="The Bird", position=1) - await Track.objects.create( - album=malibu, title="Heart don't stand a chance", position=2 - ) - await Track.objects.create(album=malibu, title="The Waters", position=3) + async with database.transaction(force_rollback=True): + malibu = Album(name="Malibu%") + await malibu.save() + await Track.objects.create(album=malibu, title="The Bird", position=1) + await Track.objects.create( + album=malibu, title="Heart don't stand a chance", position=2 + ) + await Track.objects.create(album=malibu, title="The Waters", position=3) - fantasies = await Album.objects.create(name="Fantasies") - await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) - await Track.objects.create(album=fantasies, title="Sick Muse", position=2) - await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + fantasies = await Album.objects.create(name="Fantasies") + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) - tracks = ( - await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() - ) - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Fantasies" + tracks = ( + await Track.objects.select_related("album") + .filter(album__name="Fantasies") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" - tracks = ( - await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() - ) - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Fantasies" + tracks = ( + await Track.objects.select_related("album") + .filter(album__name__icontains="fan") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" - tracks = await Track.objects.filter(album__name__contains="fan").all() - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Fantasies" + tracks = await Track.objects.filter(album__name__contains="Fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" - tracks = await Track.objects.filter(album__name__contains="Malibu%").all() - assert len(tracks) == 3 + tracks = await Track.objects.filter(album__name__contains="Malibu%").all() + assert len(tracks) == 3 - tracks = await Track.objects.filter(album=malibu).select_related("album").all() - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Malibu%" + tracks = await Track.objects.filter(album=malibu).select_related("album").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" - tracks = await Track.objects.select_related("album").all(album=malibu) - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Malibu%" + tracks = await Track.objects.select_related("album").all(album=malibu) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" @pytest.mark.asyncio async def test_multiple_fk(): async with database: - acme = await Organisation.objects.create(ident="ACME Ltd") - red_team = await Team.objects.create(org=acme, name="Red Team") - blue_team = await Team.objects.create(org=acme, name="Blue Team") - await Member.objects.create(team=red_team, email="a@example.org") - await Member.objects.create(team=red_team, email="b@example.org") - await Member.objects.create(team=blue_team, email="c@example.org") - await Member.objects.create(team=blue_team, email="d@example.org") + async with database.transaction(force_rollback=True): + acme = await Organisation.objects.create(ident="ACME Ltd") + red_team = await Team.objects.create(org=acme, name="Red Team") + blue_team = await Team.objects.create(org=acme, name="Blue Team") + await Member.objects.create(team=red_team, email="a@example.org") + await Member.objects.create(team=red_team, email="b@example.org") + await Member.objects.create(team=blue_team, email="c@example.org") + await Member.objects.create(team=blue_team, email="d@example.org") - other = await Organisation.objects.create(ident="Other ltd") - team = await Team.objects.create(org=other, name="Green Team") - await Member.objects.create(team=team, email="e@example.org") + other = await Organisation.objects.create(ident="Other ltd") + team = await Team.objects.create(org=other, name="Green Team") + await Member.objects.create(team=team, email="e@example.org") - members = ( - await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() - ) - assert len(members) == 4 - for member in members: - assert member.team.org.ident == "ACME Ltd" + members = ( + await Member.objects.select_related("team__org") + .filter(team__org__ident="ACME Ltd") + .all() + ) + assert len(members) == 4 + for member in members: + assert member.team.org.ident == "ACME Ltd" @pytest.mark.asyncio async def test_pk_filter(): async with database: - fantasies = await Album.objects.create(name="Test") - await Track.objects.create(album=fantasies, title="Test1", position=1) - await Track.objects.create(album=fantasies, title="Test2", position=2) - await Track.objects.create(album=fantasies, title="Test3", position=3) - tracks = await Track.objects.select_related("album").filter(pk=1).all() - assert len(tracks) == 1 + async with database.transaction(force_rollback=True): + fantasies = await Album.objects.create(name="Test") + track = await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + tracks = await Track.objects.select_related("album").filter(pk=track.pk).all() + assert len(tracks) == 1 - tracks = ( - await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() - ) - assert len(tracks) == 1 + tracks = ( + await Track.objects.select_related("album") + .filter(position=2, album__name="Test") + .all() + ) + assert len(tracks) == 1 @pytest.mark.asyncio async def test_limit_and_offset(): async with database: - fantasies = await Album.objects.create(name="Limitless") - await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) - await Track.objects.create(album=fantasies, title="Sample2", position=2) - await Track.objects.create(album=fantasies, title="Sample3", position=3) + async with database.transaction(force_rollback=True): + fantasies = await Album.objects.create(name="Limitless") + await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) + await Track.objects.create(album=fantasies, title="Sample2", position=2) + await Track.objects.create(album=fantasies, title="Sample3", position=3) - tracks = await Track.objects.limit(1).all() - assert len(tracks) == 1 - assert tracks[0].title == "Sample" + tracks = await Track.objects.limit(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample" - tracks = await Track.objects.limit(1).offset(1).all() - assert len(tracks) == 1 - assert tracks[0].title == "Sample2" + tracks = await Track.objects.limit(1).offset(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample2" @pytest.mark.asyncio async def test_get_exceptions(): async with database: - fantasies = await Album.objects.create(name="Test") + async with database.transaction(force_rollback=True): + fantasies = await Album.objects.create(name="Test") - with pytest.raises(NoMatch): - await Album.objects.get(name="Test2") + with pytest.raises(NoMatch): + await Album.objects.get(name="Test2") - await Track.objects.create(album=fantasies, title="Test1", position=1) - await Track.objects.create(album=fantasies, title="Test2", position=2) - await Track.objects.create(album=fantasies, title="Test3", position=3) - with pytest.raises(MultipleMatches): - await Track.objects.select_related("album").get(album=fantasies) + await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + with pytest.raises(MultipleMatches): + await Track.objects.select_related("album").get(album=fantasies) @pytest.mark.asyncio async def test_wrong_model_passed_as_fk(): - with pytest.raises(RelationshipInstanceError): - org = await Organisation.objects.create(ident="ACME Ltd") - await Track.objects.create(album=org, title="Test1", position=1) + async with database: + async with database.transaction(force_rollback=True): + with pytest.raises(RelationshipInstanceError): + org = await Organisation.objects.create(ident="ACME Ltd") + await Track.objects.create(album=org, title="Test1", position=1) diff --git a/tests/test_many_to_many.py b/tests/test_many_to_many.py index a2963e8..4111458 100644 --- a/tests/test_many_to_many.py +++ b/tests/test_many_to_many.py @@ -1,3 +1,5 @@ +import asyncio + import databases import pytest import sqlalchemy @@ -50,8 +52,15 @@ class Post(ormar.Model): author: ormar.ForeignKey(Author) +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + @pytest.fixture(autouse=True, scope="module") -def create_test_database(): +async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) metadata.create_all(engine) yield @@ -61,118 +70,124 @@ def create_test_database(): @pytest.fixture(scope="function") async def cleanup(): yield - await PostCategory.objects.delete() - await Post.objects.delete() - await Category.objects.delete() - await Author.objects.delete() + async with database: + await PostCategory.objects.delete() + await Post.objects.delete() + await Category.objects.delete() + await Author.objects.delete() @pytest.mark.asyncio async def test_assigning_related_objects(cleanup): - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = await Post.objects.create(title="Hello, M2M", author=guido) - news = await Category.objects.create(name="News") + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") - # Add a category to a post. - await post.categories.add(news) - # or from the other end: - await news.posts.add(post) + # Add a category to a post. + await post.categories.add(news) + # or from the other end: + await news.posts.add(post) - # Creating related object from instance: - await post.categories.create(name="Tips") - assert len(post.categories) == 2 + # Creating related object from instance: + await post.categories.create(name="Tips") + assert len(post.categories) == 2 - post_categories = await post.categories.all() - assert len(post_categories) == 2 + post_categories = await post.categories.all() + assert len(post_categories) == 2 @pytest.mark.asyncio async def test_quering_of_the_m2m_models(cleanup): - # orm can do this already. - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = await Post.objects.create(title="Hello, M2M", author=guido) - news = await Category.objects.create(name="News") - # tl;dr: `post.categories` exposes the QuerySet API. + async with database: + # orm can do this already. + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + # tl;dr: `post.categories` exposes the QuerySet API. - await post.categories.add(news) + await post.categories.add(news) - post_categories = await post.categories.all() - assert len(post_categories) == 1 + post_categories = await post.categories.all() + assert len(post_categories) == 1 - assert news == await post.categories.get(name="News") + assert news == await post.categories.get(name="News") - num_posts = await news.posts.count() - assert num_posts == 1 + num_posts = await news.posts.count() + assert num_posts == 1 - posts_about_m2m = await news.posts.filter(title__contains="M2M").all() - assert len(posts_about_m2m) == 1 - assert posts_about_m2m[0] == post - posts_about_python = await Post.objects.filter(categories__name="python").all() - assert len(posts_about_python) == 0 + posts_about_m2m = await news.posts.filter(title__contains="M2M").all() + assert len(posts_about_m2m) == 1 + assert posts_about_m2m[0] == post + posts_about_python = await Post.objects.filter(categories__name="python").all() + assert len(posts_about_python) == 0 - # Traversal of relationships: which categories has Guido contributed to? - category = await Category.objects.filter(posts__author=guido).get() - assert category == news - # or: - category2 = await Category.objects.filter(posts__author__first_name="Guido").get() - assert category2 == news + # Traversal of relationships: which categories has Guido contributed to? + category = await Category.objects.filter(posts__author=guido).get() + assert category == news + # or: + category2 = await Category.objects.filter(posts__author__first_name="Guido").get() + assert category2 == news @pytest.mark.asyncio async def test_removal_of_the_relations(cleanup): - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = await Post.objects.create(title="Hello, M2M", author=guido) - news = await Category.objects.create(name="News") - await post.categories.add(news) - assert len(await post.categories.all()) == 1 - await post.categories.remove(news) - assert len(await post.categories.all()) == 0 - # or: - await news.posts.add(post) - assert len(await news.posts.all()) == 1 - await news.posts.remove(post) - assert len(await news.posts.all()) == 0 + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + await post.categories.add(news) + assert len(await post.categories.all()) == 1 + await post.categories.remove(news) + assert len(await post.categories.all()) == 0 + # or: + await news.posts.add(post) + assert len(await news.posts.all()) == 1 + await news.posts.remove(post) + assert len(await news.posts.all()) == 0 - # Remove all related objects: - await post.categories.add(news) - await post.categories.clear() - assert len(await post.categories.all()) == 0 + # Remove all related objects: + await post.categories.add(news) + await post.categories.clear() + assert len(await post.categories.all()) == 0 - # post would also lose 'news' category when running: - await post.categories.add(news) - await news.delete() - assert len(await post.categories.all()) == 0 + # post would also lose 'news' category when running: + await post.categories.add(news) + await news.delete() + assert len(await post.categories.all()) == 0 @pytest.mark.asyncio async def test_selecting_related(cleanup): - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = await Post.objects.create(title="Hello, M2M", author=guido) - news = await Category.objects.create(name="News") - recent = await Category.objects.create(name="Recent") - await post.categories.add(news) - await post.categories.add(recent) - assert len(await post.categories.all()) == 2 - # Loads categories and posts (2 queries) and perform the join in Python. - categories = await Category.objects.select_related("posts").all() - # No extra queries needed => no more `await`s required. - for category in categories: - assert category.posts[0] == post + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + recent = await Category.objects.create(name="Recent") + await post.categories.add(news) + await post.categories.add(recent) + assert len(await post.categories.all()) == 2 + # Loads categories and posts (2 queries) and perform the join in Python. + categories = await Category.objects.select_related("posts").all() + # No extra queries needed => no more `await`s required. + for category in categories: + assert category.posts[0] == post - news_posts = await news.posts.select_related("author").all() - assert news_posts[0].author == guido + news_posts = await news.posts.select_related("author").all() + assert news_posts[0].author == guido - assert (await post.categories.limit(1).all())[0] == news - assert (await post.categories.offset(1).limit(1).all())[0] == recent + assert (await post.categories.limit(1).all())[0] == news + assert (await post.categories.offset(1).limit(1).all())[0] == recent - assert await post.categories.first() == news + assert await post.categories.first() == news - assert await post.categories.exists() + assert await post.categories.exists() @pytest.mark.asyncio async def test_selecting_related_fail_without_saving(cleanup): - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = Post(title="Hello, M2M", author=guido) - with pytest.raises(RelationshipInstanceError): - await post.categories.all() + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = Post(title="Hello, M2M", author=guido) + with pytest.raises(RelationshipInstanceError): + await post.categories.all() diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index ab2845c..c374585 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -112,7 +112,6 @@ def test_sqlalchemy_table_is_created(example): def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example3" @@ -123,7 +122,6 @@ def test_no_pk_in_model_definition(): def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example3" @@ -135,7 +133,6 @@ def test_two_pks_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example4" @@ -146,7 +143,6 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example5" @@ -157,7 +153,6 @@ def test_decimal_error_in_model_definition(): def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example6" diff --git a/tests/test_models.py b/tests/test_models.py index 758c1bd..79381fe 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,3 +1,5 @@ +import asyncio + import databases import pydantic import pytest @@ -43,9 +45,17 @@ class Product(ormar.Model): in_stock: ormar.Boolean(default=False) +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + @pytest.fixture(autouse=True, scope="module") -def create_test_database(): +async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) metadata.create_all(engine) yield metadata.drop_all(engine) @@ -69,166 +79,177 @@ def test_model_pk(): @pytest.mark.asyncio async def test_json_column(): async with database: - await JsonSample.objects.create(test_json=dict(aa=12)) - await JsonSample.objects.create(test_json='{"aa": 12}') + async with database.transaction(force_rollback=True): + await JsonSample.objects.create(test_json=dict(aa=12)) + await JsonSample.objects.create(test_json='{"aa": 12}') - items = await JsonSample.objects.all() - assert len(items) == 2 - assert items[0].test_json == dict(aa=12) - assert items[1].test_json == dict(aa=12) + items = await JsonSample.objects.all() + assert len(items) == 2 + assert items[0].test_json == dict(aa=12) + assert items[1].test_json == dict(aa=12) @pytest.mark.asyncio async def test_model_crud(): async with database: - users = await User.objects.all() - assert users == [] + async with database.transaction(force_rollback=True): + users = await User.objects.all() + assert users == [] - user = await User.objects.create(name="Tom") - users = await User.objects.all() - assert user.name == "Tom" - assert user.pk is not None - assert users == [user] + user = await User.objects.create(name="Tom") + users = await User.objects.all() + assert user.name == "Tom" + assert user.pk is not None + assert users == [user] - lookup = await User.objects.get() - assert lookup == user + lookup = await User.objects.get() + assert lookup == user - await user.update(name="Jane") - users = await User.objects.all() - assert user.name == "Jane" - assert user.pk is not None - assert users == [user] + await user.update(name="Jane") + users = await User.objects.all() + assert user.name == "Jane" + assert user.pk is not None + assert users == [user] - await user.delete() - users = await User.objects.all() - assert users == [] + await user.delete() + users = await User.objects.all() + assert users == [] @pytest.mark.asyncio async def test_model_get(): async with database: - with pytest.raises(ormar.NoMatch): - await User.objects.get() + async with database.transaction(force_rollback=True): + with pytest.raises(ormar.NoMatch): + await User.objects.get() - user = await User.objects.create(name="Tom") - lookup = await User.objects.get() - assert lookup == user + user = await User.objects.create(name="Tom") + lookup = await User.objects.get() + assert lookup == user - user = await User.objects.create(name="Jane") - with pytest.raises(ormar.MultipleMatches): - await User.objects.get() + user = await User.objects.create(name="Jane") + with pytest.raises(ormar.MultipleMatches): + await User.objects.get() - same_user = await User.objects.get(pk=user.id) - assert same_user.id == user.id - assert same_user.pk == user.pk + same_user = await User.objects.get(pk=user.id) + assert same_user.id == user.id + assert same_user.pk == user.pk @pytest.mark.asyncio async def test_model_filter(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") - await User.objects.create(name="Lucy") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") - user = await User.objects.get(name="Lucy") - assert user.name == "Lucy" + user = await User.objects.get(name="Lucy") + assert user.name == "Lucy" - with pytest.raises(ormar.NoMatch): - await User.objects.get(name="Jim") + with pytest.raises(ormar.NoMatch): + await User.objects.get(name="Jim") - await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) - await Product.objects.create(name="Dress", rating=4) - await Product.objects.create(name="Coat", rating=3, in_stock=True) + await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) + await Product.objects.create(name="Dress", rating=4) + await Product.objects.create(name="Coat", rating=3, in_stock=True) - product = await Product.objects.get(name__iexact="t-shirt", rating=5) - assert product.pk is not None - assert product.name == "T-Shirt" - assert product.rating == 5 + product = await Product.objects.get(name__iexact="t-shirt", rating=5) + assert product.pk is not None + assert product.name == "T-Shirt" + assert product.rating == 5 - products = await Product.objects.all(rating__gte=2, in_stock=True) - assert len(products) == 2 + products = await Product.objects.all(rating__gte=2, in_stock=True) + assert len(products) == 2 - products = await Product.objects.all(name__icontains="T") - assert len(products) == 2 + products = await Product.objects.all(name__icontains="T") + assert len(products) == 2 - # Test escaping % character from icontains, contains, and iexact - await Product.objects.create(name="100%-Cotton", rating=3) - await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) - await Product.objects.create(name="Cotton-100%", rating=3) - products = Product.objects.filter(name__iexact="100%-cotton") - assert await products.count() == 1 + # Test escaping % character from icontains, contains, and iexact + await Product.objects.create(name="100%-Cotton", rating=3) + await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) + await Product.objects.create(name="Cotton-100%", rating=3) + products = Product.objects.filter(name__iexact="100%-cotton") + assert await products.count() == 1 - products = Product.objects.filter(name__contains="%") - assert await products.count() == 3 + products = Product.objects.filter(name__contains="%") + assert await products.count() == 3 - products = Product.objects.filter(name__icontains="%") - assert await products.count() == 3 + products = Product.objects.filter(name__icontains="%") + assert await products.count() == 3 @pytest.mark.asyncio async def test_wrong_query_contains_model(): - with pytest.raises(QueryDefinitionError): - product = Product(name="90%-Cotton", rating=2) - await Product.objects.filter(name__contains=product).count() + async with database: + with pytest.raises(QueryDefinitionError): + product = Product(name="90%-Cotton", rating=2) + await Product.objects.filter(name__contains=product).count() @pytest.mark.asyncio async def test_model_exists(): async with database: - await User.objects.create(name="Tom") - assert await User.objects.filter(name="Tom").exists() is True - assert await User.objects.filter(name="Jane").exists() is False + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + assert await User.objects.filter(name="Tom").exists() is True + assert await User.objects.filter(name="Jane").exists() is False @pytest.mark.asyncio async def test_model_count(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") - await User.objects.create(name="Lucy") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") - assert await User.objects.count() == 3 - assert await User.objects.filter(name__icontains="T").count() == 1 + assert await User.objects.count() == 3 + assert await User.objects.filter(name__icontains="T").count() == 1 @pytest.mark.asyncio async def test_model_limit(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") - await User.objects.create(name="Lucy") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") - assert len(await User.objects.limit(2).all()) == 2 + assert len(await User.objects.limit(2).all()) == 2 @pytest.mark.asyncio async def test_model_limit_with_filter(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Tom") - await User.objects.create(name="Tom") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") - assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 + assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 @pytest.mark.asyncio async def test_offset(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") - users = await User.objects.offset(1).limit(1).all() - assert users[0].name == "Jane" + users = await User.objects.offset(1).limit(1).all() + assert users[0].name == "Jane" @pytest.mark.asyncio async def test_model_first(): async with database: - tom = await User.objects.create(name="Tom") - jane = await User.objects.create(name="Jane") + async with database.transaction(force_rollback=True): + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") - assert await User.objects.first() == tom - assert await User.objects.first(name="Jane") == jane - assert await User.objects.filter(name="Jane").first() == jane - with pytest.raises(NoMatch): - await User.objects.filter(name="Lucy").first() + assert await User.objects.first() == tom + assert await User.objects.first(name="Jane") == jane + assert await User.objects.filter(name="Jane").first() == jane + with pytest.raises(NoMatch): + await User.objects.filter(name="Lucy").first() diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index e01ce44..f0b9b88 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -112,7 +112,7 @@ def test_all_endpoints(): assert items[0].name == "New name" response = client.delete(f"/items/{item.pk}", json=item.dict()) - assert response.json().get("deleted_rows") == 1 + assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__" response = client.get("/items/") items = response.json() assert len(items) == 0 diff --git a/tests/test_more_same_table_joins.py b/tests/test_more_same_table_joins.py index 3718bc0..5c4ac3d 100644 --- a/tests/test_more_same_table_joins.py +++ b/tests/test_more_same_table_joins.py @@ -78,6 +78,11 @@ async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) metadata.drop_all(engine) metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +async def create_data(): department = await Department.objects.create(id=1, name="Math Department") department2 = await Department.objects.create(id=2, name="Law Department") class1 = await SchoolClass.objects.create(name="Math") @@ -88,13 +93,11 @@ async def create_test_database(): await Student.objects.create(name="Judy", category=category2, schoolclass=class1) await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - yield - metadata.drop_all(engine) - @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: + await create_data() classes = await SchoolClass.objects.select_related( ["teachers__category__department", "students"] ).all() diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 5b8ffd0..9eda9d7 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -78,6 +78,11 @@ async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) metadata.drop_all(engine) metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +async def create_data(): department = await Department.objects.create(id=1, name="Math Department") department2 = await Department.objects.create(id=2, name="Law Department") class1 = await SchoolClass.objects.create(name="Math", department=department) @@ -88,52 +93,56 @@ async def create_test_database(): await Student.objects.create(name="Judy", category=category2, schoolclass=class1) await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - yield - metadata.drop_all(engine) @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: - classes = await SchoolClass.objects.select_related( - ["teachers__category", "students"] - ).all() - assert classes[0].name == "Math" - assert classes[0].students[0].name == "Jane" + async with database.transaction(force_rollback=True): + await create_data() + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[0].name == "Jane" - assert len(classes[0].dict().get("students")) == 2 + assert len(classes[0].dict().get("students")) == 2 - # since it's going from schoolclass => teacher => schoolclass (same class) department is already populated - assert classes[0].students[0].schoolclass.name == "Math" - assert classes[0].students[0].schoolclass.department.name is None - await classes[0].students[0].schoolclass.department.load() - assert classes[0].students[0].schoolclass.department.name == "Math Department" + # since it's going from schoolclass => teacher => schoolclass (same class) department is already populated + assert classes[0].students[0].schoolclass.name == "Math" + assert classes[0].students[0].schoolclass.department.name is None + await classes[0].students[0].schoolclass.department.load() + assert classes[0].students[0].schoolclass.department.name == "Math Department" - await classes[1].students[0].schoolclass.department.load() - assert classes[1].students[0].schoolclass.department.name == "Law Department" + await classes[1].students[0].schoolclass.department.load() + assert classes[1].students[0].schoolclass.department.name == "Law Department" @pytest.mark.asyncio async def test_right_tables_join(): async with database: - classes = await SchoolClass.objects.select_related( - ["teachers__category", "students"] - ).all() - assert classes[0].teachers[0].category.name == "Domestic" + async with database.transaction(force_rollback=True): + await create_data() + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].teachers[0].category.name == "Domestic" - assert classes[0].students[0].category.name is None - await classes[0].students[0].category.load() - assert classes[0].students[0].category.name == "Foreign" + assert classes[0].students[0].category.name is None + await classes[0].students[0].category.load() + assert classes[0].students[0].category.name == "Foreign" @pytest.mark.asyncio async def test_multiple_reverse_related_objects(): async with database: - classes = await SchoolClass.objects.select_related( - ["teachers__category", "students__category"] - ).all() - assert classes[0].name == "Math" - assert classes[0].students[1].name == "Judy" - assert classes[0].students[0].category.name == "Foreign" - assert classes[0].students[1].category.name == "Domestic" - assert classes[0].teachers[0].category.name == "Domestic" + async with database.transaction(force_rollback=True): + await create_data() + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students__category"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[1].name == "Judy" + assert classes[0].students[0].category.name == "Foreign" + assert classes[0].students[1].category.name == "Domestic" + assert classes[0].teachers[0].category.name == "Domestic"