From 1451ec86715a527bd830a19fbea7fdd751ef8de2 Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 17 Sep 2020 13:02:34 +0200 Subject: [PATCH] add tests for mysql and postgress, some fixes for those backends --- .coverage | Bin 53248 -> 53248 bytes .travis.yml | 6 +- ormar/models/model.py | 47 ++-- ormar/queryset/queryset.py | 6 +- ormar/relations/alias_manager.py | 3 +- requirements.txt | 5 + tests/settings.py | 8 +- tests/test_columns.py | 57 ++--- tests/test_fastapi_usage.py | 29 ++- tests/test_foreign_keys.py | 354 ++++++++++++++-------------- tests/test_many_to_many.py | 177 +++++++------- tests/test_model_definition.py | 5 - tests/test_models.py | 207 ++++++++-------- tests/test_more_reallife_fastapi.py | 2 +- tests/test_more_same_table_joins.py | 9 +- tests/test_same_table_joins.py | 69 +++--- 16 files changed, 522 insertions(+), 462 deletions(-) diff --git a/.coverage b/.coverage index befb8dd377852865d47b94d1d03288bc4a8f417b..458f78cb3fe329f127a13728c3a244d4965409a5 100644 GIT binary patch delta 551 zcmWlUO=uHA7>3`;Zgyw$Gc)<^rVZO9)M$$|u~iyti8+YaOZCuGPU)qmQYv07qB{W( z9tuJ|w4jH2^C$#CJ?Nt1p?J}QNDLY*f#SuYupp({Ss3`9?|r|&kvp2)(VkW-;j=Td z(`P#8Dj%=a1o$7m&wKm>f5xA1$#3#SKF*8W=A0Ysh<#_T*<-fNHrOq;!kTQJ)!8W~ zItlS9?q)}HVoi>Gwy-h^;ERd50!V?Thm}4t?4Nj0g<-TnJ!p2z$|*IS^MAnQw!(t# z3r>l(gnL8`DWNnpu}#~J#-=_`i}9>d^4XUyK?&MA&jcMA|4oNW5Q!_1wC)lsoBsAL zUZ!SBxgzEK%O@4gCcBmTA;C0JN6kIbux|pl;kBKFa_foiz+Lx9UYUVY1LPfentFK= zbc^muUrUNsH3bss1<_3&EaAJl$pTjQm{#9o}28v1f;lt7d$ReITC_*L;jQXK^dNUYuUTIIi7WknPW z8jPirWmClkbM?j=fxMdAahtoEK?-+Mo}h__Kg71~C#XF`%pSS6>! hKTGh{y#B*=$2oxQFkQ)x-dsBTie(#4r+0s>aSZFKdT{^% delta 545 zcmWlSZ)g%>9LKl$=QQW@oPVZnO|4iBBIc+@Yt(MN=}p8|y$FmY5z&jngPBNN*ZHQm zjgT;s;a>GF41|S>UG!qYuw-uvJnLd*7Y#j1xXW(d)8nh(=lhRea@~+zH@un-nNA#^ zo0-kegtnJvjRUX)f5HZ=!aRHdSr~_x;1L*rZfJu;pbP4t0G@*<;1Re7Zh_0-95@Na zfidqOH|;6%v4!gjcSY%aSU~aK9Zqg(o-yN~b!tpGoKi(yZ*h!xGx)EiUbl^6-#4o4 zP)PPe9~K9E3W*qsk$;pphZh?Y-lNZ$oG5Ov0wWvj9&C+!ieFWRi|{NQ=}$(KcX^cI zL|(kM&y5Wa=l)V<+xrO@7Njen`<)ohIP0N#tsGC}V)TVxnHF{_ak&@y%)_7q#|w^n zle;f0v$A$Ox=W*C>;bQz6<=TZ#M}I2e^1sY6{IMwgnBhXsZ#NczE>zQiFk*A{XcMv zB|sm*iW?I4KG|2GWf+5Vc<k9b9#B)o&I4cXsns2K~);C5n@3Xs^)C$Q#XcI#{MX zhpf1M-U=I+2%1U8~8_E}{`Ff27S8+yR!UNdn&`BHD?SQ0{aZW!se| zO|5B*(jRUv(QXuC(uXEm6oWQ1F8a*j20DgPgI*I>)23fL7ZsF6-%FaVs}<&9%>axz T2pGdRGy%0WZa&{womlt}7(naL diff --git a/.travis.yml b/.travis.yml index 95a55f6..d057e1e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,8 +9,6 @@ python: - "3.7" - "3.8" -env: - - TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db" services: - postgresql @@ -26,7 +24,9 @@ before_script: script: - - scripts/test.sh + - DATABASE_URL=postgresql://localhost/test_database scripts/test.sh + - DATABASE_URL=mysql://localhost/test_database scripts/test.sh + - DATABASE_URL=sqlite:///test.db scripts/test.sh after_script: - codecov \ No newline at end of file diff --git a/ormar/models/model.py b/ormar/models/model.py index 43a6742..9edfe15 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,12 +1,18 @@ import itertools from typing import Any, List, Tuple, Union +from databases.backends.postgres import Record import sqlalchemy import ormar.queryset # noqa I100 from ormar.fields.many_to_many import ManyToManyField from ormar.models import NewBaseModel # noqa I100 +import logging +import sys + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + def group_related_list(list_: List) -> dict: test_dict = dict() @@ -28,11 +34,11 @@ class Model(NewBaseModel): @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - related_models: Any = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + related_models: Any = None, + previous_table: str = None, ) -> Union["Model", Tuple["Model", dict]]: item = {} @@ -43,9 +49,9 @@ class Model(NewBaseModel): # breakpoint() if ( - previous_table - and previous_table in cls.Meta.model_fields - and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) + previous_table + and previous_table in cls.Meta.model_fields + and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) ): previous_table = cls.Meta.model_fields[ previous_table @@ -66,11 +72,11 @@ class Model(NewBaseModel): @classmethod def populate_nested_models_from_row( - cls, - item: dict, - row: sqlalchemy.engine.ResultProxy, - related_models: Any, - previous_table: sqlalchemy.Table, + cls, + item: dict, + row: sqlalchemy.engine.ResultProxy, + related_models: Any, + previous_table: sqlalchemy.Table, ) -> dict: for related in related_models: if isinstance(related_models, dict) and related_models[related]: @@ -89,13 +95,17 @@ class Model(NewBaseModel): @classmethod def extract_prefixed_table_columns( - cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str + cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str ) -> dict: for column in cls.Meta.table.columns: + logging.debug('column to extract:' + column.name) + logging.debug(f'{row.keys()}') if column.name not in item: - item[column.name] = row[ - f'{table_prefix + "_" if table_prefix else ""}{column.name}' - ] + prefixed_name = f'{table_prefix + "_" if table_prefix else ""}{column.name}' + # databases does not keep aliases in Record for postgres + source = row._row if isinstance(row, Record) else row + item[column.name] = source[prefixed_name] + return item async def save(self) -> "Model": @@ -106,7 +116,8 @@ class Model(NewBaseModel): expr = self.Meta.table.insert() expr = expr.values(**self_fields) item_id = await self.Meta.database.execute(expr) - setattr(self, self.Meta.pkname, item_id) + if item_id: # postgress does not return id if it's already there + setattr(self, self.Meta.pkname, item_id) return self async def update(self, **kwargs: Any) -> "Model": diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 79f9ff2..6a46736 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -9,6 +9,9 @@ from ormar.queryset import FilterQuery from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query +import logging +import sys +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -187,5 +190,6 @@ class QuerySet: # Execute the insert, and return a new model instance. instance = self.model_cls(**kwargs) pk = await self.database.execute(expr) - setattr(instance, self.model_cls.Meta.pkname, pk) + if pk: + setattr(instance, self.model_cls.Meta.pkname, pk) return instance diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index 1af8c22..64f7261 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -8,7 +8,8 @@ from sqlalchemy import text def get_table_alias() -> str: - return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] + alias = "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] + return alias.lower() class AliasManager: diff --git a/requirements.txt b/requirements.txt index a9d7c03..4bad8bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,8 +7,13 @@ sqlalchemy # Async database drivers aiomysql aiosqlite +aiopg asyncpg + +# Sync database drivers for standard tooling around setup/teardown/migrations. pymysql +psycopg2-binary +mysqlclient # Testing pytest diff --git a/tests/settings.py b/tests/settings.py index 3b01be1..4cf7350 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,5 +1,11 @@ import os -os.environ['TEST_DATABASE_URLS'] = "sqlite:///test.db" +import databases +assert "DATABASE_URL" in os.environ, "DATABASE_URL is not set." + +DATABASE_URL = os.environ['DATABASE_URL'] +database_url = databases.DatabaseURL(DATABASE_URL) +if database_url.scheme == "postgresql+aiopg": # pragma no cover + DATABASE_URL = str(database_url.replace(driver=None)) DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db") diff --git a/tests/test_columns.py b/tests/test_columns.py index 07898cd..b2a7bd0 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -8,10 +8,6 @@ import sqlalchemy import ormar from tests.settings import DATABASE_URL -assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set." - -DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] - database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() @@ -38,45 +34,28 @@ class Example(ormar.Model): @pytest.fixture(autouse=True, scope="module") def create_test_database(): - for url in DATABASE_URLS: - database_url = databases.DatabaseURL(url) - if database_url.scheme == "mysql": - url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": - url = str(database_url.replace(driver=None)) - engine = sqlalchemy.create_engine(url) - metadata.create_all(engine) - + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) yield - for url in DATABASE_URLS: - database_url = databases.DatabaseURL(url) - if database_url.scheme == "mysql": - url = str(database_url.replace(driver="pymysql")) - elif database_url.scheme == "postgresql+aiopg": - url = str(database_url.replace(driver=None)) - engine = sqlalchemy.create_engine(url) - metadata.drop_all(engine) + metadata.drop_all(engine) -@pytest.mark.parametrize("database_url", DATABASE_URLS) @pytest.mark.asyncio -async def test_model_crud(database_url): - async with databases.Database(database_url) as database: - async with database.transaction(force_rollback=True): - Example.Meta.database = database - example = Example() - await example.save() +async def test_model_crud(): + async with database: + example = Example() + await example.save() - await example.load() - assert example.created.year == datetime.datetime.now().year - assert example.created_day == datetime.date.today() - assert example.description is None - assert example.value is None - assert example.data == {} + await example.load() + assert example.created.year == datetime.datetime.now().year + assert example.created_day == datetime.date.today() + assert example.description is None + assert example.value is None + assert example.data == {} - await example.update(data={"foo": 123}, value=123.456) - await example.load() - assert example.value == 123.456 - assert example.data == {"foo": 123} + await example.update(data={"foo": 123}, value=123.456) + await example.load() + assert example.value == 123.456 + assert example.data == {"foo": 123} - await example.delete() + await example.delete() diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index f7f2625..fcbf479 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -1,7 +1,7 @@ import databases import sqlalchemy from fastapi import FastAPI -from fastapi.testclient import TestClient +from starlette.testclient import TestClient import ormar from tests.settings import DATABASE_URL @@ -38,18 +38,17 @@ async def create_item(item: Item): return item -client = TestClient(app) - - def test_read_main(): - response = client.post( - "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} - ) - assert response.status_code == 200 - assert response.json() == { - "category": {"id": None, "name": "test cat"}, - "id": 1, - "name": "test", - } - item = Item(**response.json()) - assert item.id == 1 + client = TestClient(app) + with client as client: + response = client.post( + "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} + ) + assert response.status_code == 200 + assert response.json() == { + "category": {"id": None, "name": "test cat"}, + "id": 1, + "name": "test", + } + item = Item(**response.json()) + assert item.id == 1 diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 98a7d5b..75bb1fd 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -78,6 +78,7 @@ class Member(ormar.Model): @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) metadata.create_all(engine) yield metadata.drop_all(engine) @@ -85,8 +86,9 @@ def create_test_database(): @pytest.mark.asyncio async def test_wrong_query_foreign_key_type(): - with pytest.raises(RelationshipInstanceError): - Track(title="The Error", album="wrong_pk_type") + async with database: + with pytest.raises(RelationshipInstanceError): + Track(title="The Error", album="wrong_pk_type") @pytest.mark.asyncio @@ -99,242 +101,252 @@ async def test_setting_explicitly_empty_relation(): @pytest.mark.asyncio async def test_related_name(): async with database: - album = await Album.objects.create(name="Vanilla") - await Cover.objects.create(album=album, title="The cover file") - - assert len(album.cover_pictures) == 1 - + async with database.transaction(force_rollback=True): + album = await Album.objects.create(name="Vanilla") + await Cover.objects.create(album=album, title="The cover file") + assert len(album.cover_pictures) == 1 @pytest.mark.asyncio async def test_model_crud(): async with database: - album = Album(name="Malibu") - await album.save() - track1 = Track(album=album, title="The Bird", position=1) - track2 = Track(album=album, title="Heart don't stand a chance", position=2) - track3 = Track(album=album, title="The Waters", position=3) - await track1.save() - await track2.save() - await track3.save() + async with database.transaction(force_rollback=True): + album = Album(name="Jamaica") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() - track = await Track.objects.get(title="The Bird") - assert track.album.pk == album.pk - assert isinstance(track.album, ormar.Model) - assert track.album.name is None - await track.album.load() - assert track.album.name == "Malibu" + track = await Track.objects.get(title="The Bird") + assert track.album.pk == album.pk + assert isinstance(track.album, ormar.Model) + assert track.album.name is None + await track.album.load() + assert track.album.name == "Jamaica" - assert len(album.tracks) == 3 - assert album.tracks[1].title == "Heart don't stand a chance" + assert len(album.tracks) == 3 + assert album.tracks[1].title == "Heart don't stand a chance" - album1 = await Album.objects.get(name="Malibu") - assert album1.pk == 1 - assert album1.tracks == [] + album1 = await Album.objects.get(name="Jamaica") + assert album1.pk == album.pk + assert album1.tracks == [] - await Track.objects.create( - album={"id": track.album.pk}, title="The Bird2", position=4 - ) + await Track.objects.create( + album={"id": track.album.pk}, title="The Bird2", position=4 + ) @pytest.mark.asyncio async def test_select_related(): async with database: - album = Album(name="Malibu") - await album.save() - track1 = Track(album=album, title="The Bird", position=1) - track2 = Track(album=album, title="Heart don't stand a chance", position=2) - track3 = Track(album=album, title="The Waters", position=3) - await track1.save() - await track2.save() - await track3.save() + async with database.transaction(force_rollback=True): + album = Album(name="Malibu") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() - fantasies = Album(name="Fantasies") - await fantasies.save() - track4 = Track(album=fantasies, title="Help I'm Alive", position=1) - track5 = Track(album=fantasies, title="Sick Muse", position=2) - track6 = Track(album=fantasies, title="Satellite Mind", position=3) - await track4.save() - await track5.save() - await track6.save() + fantasies = Album(name="Fantasies") + await fantasies.save() + track4 = Track(album=fantasies, title="Help I'm Alive", position=1) + track5 = Track(album=fantasies, title="Sick Muse", position=2) + track6 = Track(album=fantasies, title="Satellite Mind", position=3) + await track4.save() + await track5.save() + await track6.save() - track = await Track.objects.select_related("album").get(title="The Bird") - assert track.album.name == "Malibu" + track = await Track.objects.select_related("album").get(title="The Bird") + assert track.album.name == "Malibu" - tracks = await Track.objects.select_related("album").all() - assert len(tracks) == 6 + tracks = await Track.objects.select_related("album").all() + assert len(tracks) == 6 @pytest.mark.asyncio async def test_model_removal_from_relations(): async with database: - album = Album(name="Chichi") - await album.save() - track1 = Track(album=album, title="The Birdman", position=1) - track2 = Track(album=album, title="Superman", position=2) - track3 = Track(album=album, title="Wonder Woman", position=3) - await track1.save() - await track2.save() - await track3.save() + async with database.transaction(force_rollback=True): + album = Album(name="Chichi") + await album.save() + track1 = Track(album=album, title="The Birdman", position=1) + track2 = Track(album=album, title="Superman", position=2) + track3 = Track(album=album, title="Wonder Woman", position=3) + await track1.save() + await track2.save() + await track3.save() - assert len(album.tracks) == 3 - await album.tracks.remove(track1) - assert len(album.tracks) == 2 - assert track1.album is None + assert len(album.tracks) == 3 + await album.tracks.remove(track1) + assert len(album.tracks) == 2 + assert track1.album is None - await track1.update() - track1 = await Track.objects.get(title="The Birdman") - assert track1.album is None + await track1.update() + track1 = await Track.objects.get(title="The Birdman") + assert track1.album is None - await album.tracks.add(track1) - assert len(album.tracks) == 3 - assert track1.album == album + await album.tracks.add(track1) + assert len(album.tracks) == 3 + assert track1.album == album - await track1.update() - track1 = await Track.objects.select_related("album__tracks").get( - title="The Birdman" - ) - album = await Album.objects.select_related("tracks").get(name="Chichi") - assert track1.album == album + await track1.update() + track1 = await Track.objects.select_related("album__tracks").get( + title="The Birdman" + ) + album = await Album.objects.select_related("tracks").get(name="Chichi") + assert track1.album == album - track1.remove(album) - assert track1.album is None - assert len(album.tracks) == 2 + track1.remove(album) + assert track1.album is None + assert len(album.tracks) == 2 + + track2.remove(album) + assert track2.album is None + assert len(album.tracks) == 1 - track2.remove(album) - assert track2.album is None - assert len(album.tracks) == 1 @pytest.mark.asyncio async def test_fk_filter(): async with database: - malibu = Album(name="Malibu%") - await malibu.save() - await Track.objects.create(album=malibu, title="The Bird", position=1) - await Track.objects.create( - album=malibu, title="Heart don't stand a chance", position=2 - ) - await Track.objects.create(album=malibu, title="The Waters", position=3) + async with database.transaction(force_rollback=True): + malibu = Album(name="Malibu%") + await malibu.save() + await Track.objects.create(album=malibu, title="The Bird", position=1) + await Track.objects.create( + album=malibu, title="Heart don't stand a chance", position=2 + ) + await Track.objects.create(album=malibu, title="The Waters", position=3) - fantasies = await Album.objects.create(name="Fantasies") - await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) - await Track.objects.create(album=fantasies, title="Sick Muse", position=2) - await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + fantasies = await Album.objects.create(name="Fantasies") + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) - tracks = ( - await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() - ) - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Fantasies" + tracks = ( + await Track.objects.select_related("album") + .filter(album__name="Fantasies") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" - tracks = ( - await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() - ) - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Fantasies" + tracks = ( + await Track.objects.select_related("album") + .filter(album__name__icontains="fan") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" - tracks = await Track.objects.filter(album__name__contains="fan").all() - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Fantasies" + tracks = await Track.objects.filter(album__name__contains="Fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" - tracks = await Track.objects.filter(album__name__contains="Malibu%").all() - assert len(tracks) == 3 + tracks = await Track.objects.filter(album__name__contains="Malibu%").all() + assert len(tracks) == 3 - tracks = await Track.objects.filter(album=malibu).select_related("album").all() - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Malibu%" + tracks = await Track.objects.filter(album=malibu).select_related("album").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" - tracks = await Track.objects.select_related("album").all(album=malibu) - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Malibu%" + tracks = await Track.objects.select_related("album").all(album=malibu) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" @pytest.mark.asyncio async def test_multiple_fk(): async with database: - acme = await Organisation.objects.create(ident="ACME Ltd") - red_team = await Team.objects.create(org=acme, name="Red Team") - blue_team = await Team.objects.create(org=acme, name="Blue Team") - await Member.objects.create(team=red_team, email="a@example.org") - await Member.objects.create(team=red_team, email="b@example.org") - await Member.objects.create(team=blue_team, email="c@example.org") - await Member.objects.create(team=blue_team, email="d@example.org") + async with database.transaction(force_rollback=True): + acme = await Organisation.objects.create(ident="ACME Ltd") + red_team = await Team.objects.create(org=acme, name="Red Team") + blue_team = await Team.objects.create(org=acme, name="Blue Team") + await Member.objects.create(team=red_team, email="a@example.org") + await Member.objects.create(team=red_team, email="b@example.org") + await Member.objects.create(team=blue_team, email="c@example.org") + await Member.objects.create(team=blue_team, email="d@example.org") - other = await Organisation.objects.create(ident="Other ltd") - team = await Team.objects.create(org=other, name="Green Team") - await Member.objects.create(team=team, email="e@example.org") + other = await Organisation.objects.create(ident="Other ltd") + team = await Team.objects.create(org=other, name="Green Team") + await Member.objects.create(team=team, email="e@example.org") - members = ( - await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() - ) - assert len(members) == 4 - for member in members: - assert member.team.org.ident == "ACME Ltd" + members = ( + await Member.objects.select_related("team__org") + .filter(team__org__ident="ACME Ltd") + .all() + ) + assert len(members) == 4 + for member in members: + assert member.team.org.ident == "ACME Ltd" @pytest.mark.asyncio async def test_pk_filter(): async with database: - fantasies = await Album.objects.create(name="Test") - await Track.objects.create(album=fantasies, title="Test1", position=1) - await Track.objects.create(album=fantasies, title="Test2", position=2) - await Track.objects.create(album=fantasies, title="Test3", position=3) - tracks = await Track.objects.select_related("album").filter(pk=1).all() - assert len(tracks) == 1 + async with database.transaction(force_rollback=True): + fantasies = await Album.objects.create(name="Test") + track = await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + tracks = await Track.objects.select_related("album").filter(pk=track.pk).all() + assert len(tracks) == 1 - tracks = ( - await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() - ) - assert len(tracks) == 1 + tracks = ( + await Track.objects.select_related("album") + .filter(position=2, album__name="Test") + .all() + ) + assert len(tracks) == 1 @pytest.mark.asyncio async def test_limit_and_offset(): async with database: - fantasies = await Album.objects.create(name="Limitless") - await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) - await Track.objects.create(album=fantasies, title="Sample2", position=2) - await Track.objects.create(album=fantasies, title="Sample3", position=3) + async with database.transaction(force_rollback=True): + fantasies = await Album.objects.create(name="Limitless") + await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) + await Track.objects.create(album=fantasies, title="Sample2", position=2) + await Track.objects.create(album=fantasies, title="Sample3", position=3) - tracks = await Track.objects.limit(1).all() - assert len(tracks) == 1 - assert tracks[0].title == "Sample" + tracks = await Track.objects.limit(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample" - tracks = await Track.objects.limit(1).offset(1).all() - assert len(tracks) == 1 - assert tracks[0].title == "Sample2" + tracks = await Track.objects.limit(1).offset(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample2" @pytest.mark.asyncio async def test_get_exceptions(): async with database: - fantasies = await Album.objects.create(name="Test") + async with database.transaction(force_rollback=True): + fantasies = await Album.objects.create(name="Test") - with pytest.raises(NoMatch): - await Album.objects.get(name="Test2") + with pytest.raises(NoMatch): + await Album.objects.get(name="Test2") - await Track.objects.create(album=fantasies, title="Test1", position=1) - await Track.objects.create(album=fantasies, title="Test2", position=2) - await Track.objects.create(album=fantasies, title="Test3", position=3) - with pytest.raises(MultipleMatches): - await Track.objects.select_related("album").get(album=fantasies) + await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + with pytest.raises(MultipleMatches): + await Track.objects.select_related("album").get(album=fantasies) @pytest.mark.asyncio async def test_wrong_model_passed_as_fk(): - with pytest.raises(RelationshipInstanceError): - org = await Organisation.objects.create(ident="ACME Ltd") - await Track.objects.create(album=org, title="Test1", position=1) + async with database: + async with database.transaction(force_rollback=True): + with pytest.raises(RelationshipInstanceError): + org = await Organisation.objects.create(ident="ACME Ltd") + await Track.objects.create(album=org, title="Test1", position=1) diff --git a/tests/test_many_to_many.py b/tests/test_many_to_many.py index a2963e8..4111458 100644 --- a/tests/test_many_to_many.py +++ b/tests/test_many_to_many.py @@ -1,3 +1,5 @@ +import asyncio + import databases import pytest import sqlalchemy @@ -50,8 +52,15 @@ class Post(ormar.Model): author: ormar.ForeignKey(Author) +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + @pytest.fixture(autouse=True, scope="module") -def create_test_database(): +async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) metadata.create_all(engine) yield @@ -61,118 +70,124 @@ def create_test_database(): @pytest.fixture(scope="function") async def cleanup(): yield - await PostCategory.objects.delete() - await Post.objects.delete() - await Category.objects.delete() - await Author.objects.delete() + async with database: + await PostCategory.objects.delete() + await Post.objects.delete() + await Category.objects.delete() + await Author.objects.delete() @pytest.mark.asyncio async def test_assigning_related_objects(cleanup): - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = await Post.objects.create(title="Hello, M2M", author=guido) - news = await Category.objects.create(name="News") + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") - # Add a category to a post. - await post.categories.add(news) - # or from the other end: - await news.posts.add(post) + # Add a category to a post. + await post.categories.add(news) + # or from the other end: + await news.posts.add(post) - # Creating related object from instance: - await post.categories.create(name="Tips") - assert len(post.categories) == 2 + # Creating related object from instance: + await post.categories.create(name="Tips") + assert len(post.categories) == 2 - post_categories = await post.categories.all() - assert len(post_categories) == 2 + post_categories = await post.categories.all() + assert len(post_categories) == 2 @pytest.mark.asyncio async def test_quering_of_the_m2m_models(cleanup): - # orm can do this already. - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = await Post.objects.create(title="Hello, M2M", author=guido) - news = await Category.objects.create(name="News") - # tl;dr: `post.categories` exposes the QuerySet API. + async with database: + # orm can do this already. + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + # tl;dr: `post.categories` exposes the QuerySet API. - await post.categories.add(news) + await post.categories.add(news) - post_categories = await post.categories.all() - assert len(post_categories) == 1 + post_categories = await post.categories.all() + assert len(post_categories) == 1 - assert news == await post.categories.get(name="News") + assert news == await post.categories.get(name="News") - num_posts = await news.posts.count() - assert num_posts == 1 + num_posts = await news.posts.count() + assert num_posts == 1 - posts_about_m2m = await news.posts.filter(title__contains="M2M").all() - assert len(posts_about_m2m) == 1 - assert posts_about_m2m[0] == post - posts_about_python = await Post.objects.filter(categories__name="python").all() - assert len(posts_about_python) == 0 + posts_about_m2m = await news.posts.filter(title__contains="M2M").all() + assert len(posts_about_m2m) == 1 + assert posts_about_m2m[0] == post + posts_about_python = await Post.objects.filter(categories__name="python").all() + assert len(posts_about_python) == 0 - # Traversal of relationships: which categories has Guido contributed to? - category = await Category.objects.filter(posts__author=guido).get() - assert category == news - # or: - category2 = await Category.objects.filter(posts__author__first_name="Guido").get() - assert category2 == news + # Traversal of relationships: which categories has Guido contributed to? + category = await Category.objects.filter(posts__author=guido).get() + assert category == news + # or: + category2 = await Category.objects.filter(posts__author__first_name="Guido").get() + assert category2 == news @pytest.mark.asyncio async def test_removal_of_the_relations(cleanup): - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = await Post.objects.create(title="Hello, M2M", author=guido) - news = await Category.objects.create(name="News") - await post.categories.add(news) - assert len(await post.categories.all()) == 1 - await post.categories.remove(news) - assert len(await post.categories.all()) == 0 - # or: - await news.posts.add(post) - assert len(await news.posts.all()) == 1 - await news.posts.remove(post) - assert len(await news.posts.all()) == 0 + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + await post.categories.add(news) + assert len(await post.categories.all()) == 1 + await post.categories.remove(news) + assert len(await post.categories.all()) == 0 + # or: + await news.posts.add(post) + assert len(await news.posts.all()) == 1 + await news.posts.remove(post) + assert len(await news.posts.all()) == 0 - # Remove all related objects: - await post.categories.add(news) - await post.categories.clear() - assert len(await post.categories.all()) == 0 + # Remove all related objects: + await post.categories.add(news) + await post.categories.clear() + assert len(await post.categories.all()) == 0 - # post would also lose 'news' category when running: - await post.categories.add(news) - await news.delete() - assert len(await post.categories.all()) == 0 + # post would also lose 'news' category when running: + await post.categories.add(news) + await news.delete() + assert len(await post.categories.all()) == 0 @pytest.mark.asyncio async def test_selecting_related(cleanup): - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = await Post.objects.create(title="Hello, M2M", author=guido) - news = await Category.objects.create(name="News") - recent = await Category.objects.create(name="Recent") - await post.categories.add(news) - await post.categories.add(recent) - assert len(await post.categories.all()) == 2 - # Loads categories and posts (2 queries) and perform the join in Python. - categories = await Category.objects.select_related("posts").all() - # No extra queries needed => no more `await`s required. - for category in categories: - assert category.posts[0] == post + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + recent = await Category.objects.create(name="Recent") + await post.categories.add(news) + await post.categories.add(recent) + assert len(await post.categories.all()) == 2 + # Loads categories and posts (2 queries) and perform the join in Python. + categories = await Category.objects.select_related("posts").all() + # No extra queries needed => no more `await`s required. + for category in categories: + assert category.posts[0] == post - news_posts = await news.posts.select_related("author").all() - assert news_posts[0].author == guido + news_posts = await news.posts.select_related("author").all() + assert news_posts[0].author == guido - assert (await post.categories.limit(1).all())[0] == news - assert (await post.categories.offset(1).limit(1).all())[0] == recent + assert (await post.categories.limit(1).all())[0] == news + assert (await post.categories.offset(1).limit(1).all())[0] == recent - assert await post.categories.first() == news + assert await post.categories.first() == news - assert await post.categories.exists() + assert await post.categories.exists() @pytest.mark.asyncio async def test_selecting_related_fail_without_saving(cleanup): - guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") - post = Post(title="Hello, M2M", author=guido) - with pytest.raises(RelationshipInstanceError): - await post.categories.all() + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = Post(title="Hello, M2M", author=guido) + with pytest.raises(RelationshipInstanceError): + await post.categories.all() diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index ab2845c..c374585 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -112,7 +112,6 @@ def test_sqlalchemy_table_is_created(example): def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example3" @@ -123,7 +122,6 @@ def test_no_pk_in_model_definition(): def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example3" @@ -135,7 +133,6 @@ def test_two_pks_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example4" @@ -146,7 +143,6 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example5" @@ -157,7 +153,6 @@ def test_decimal_error_in_model_definition(): def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): class Meta: tablename = "example6" diff --git a/tests/test_models.py b/tests/test_models.py index 758c1bd..79381fe 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,3 +1,5 @@ +import asyncio + import databases import pydantic import pytest @@ -43,9 +45,17 @@ class Product(ormar.Model): in_stock: ormar.Boolean(default=False) +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + @pytest.fixture(autouse=True, scope="module") -def create_test_database(): +async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) metadata.create_all(engine) yield metadata.drop_all(engine) @@ -69,166 +79,177 @@ def test_model_pk(): @pytest.mark.asyncio async def test_json_column(): async with database: - await JsonSample.objects.create(test_json=dict(aa=12)) - await JsonSample.objects.create(test_json='{"aa": 12}') + async with database.transaction(force_rollback=True): + await JsonSample.objects.create(test_json=dict(aa=12)) + await JsonSample.objects.create(test_json='{"aa": 12}') - items = await JsonSample.objects.all() - assert len(items) == 2 - assert items[0].test_json == dict(aa=12) - assert items[1].test_json == dict(aa=12) + items = await JsonSample.objects.all() + assert len(items) == 2 + assert items[0].test_json == dict(aa=12) + assert items[1].test_json == dict(aa=12) @pytest.mark.asyncio async def test_model_crud(): async with database: - users = await User.objects.all() - assert users == [] + async with database.transaction(force_rollback=True): + users = await User.objects.all() + assert users == [] - user = await User.objects.create(name="Tom") - users = await User.objects.all() - assert user.name == "Tom" - assert user.pk is not None - assert users == [user] + user = await User.objects.create(name="Tom") + users = await User.objects.all() + assert user.name == "Tom" + assert user.pk is not None + assert users == [user] - lookup = await User.objects.get() - assert lookup == user + lookup = await User.objects.get() + assert lookup == user - await user.update(name="Jane") - users = await User.objects.all() - assert user.name == "Jane" - assert user.pk is not None - assert users == [user] + await user.update(name="Jane") + users = await User.objects.all() + assert user.name == "Jane" + assert user.pk is not None + assert users == [user] - await user.delete() - users = await User.objects.all() - assert users == [] + await user.delete() + users = await User.objects.all() + assert users == [] @pytest.mark.asyncio async def test_model_get(): async with database: - with pytest.raises(ormar.NoMatch): - await User.objects.get() + async with database.transaction(force_rollback=True): + with pytest.raises(ormar.NoMatch): + await User.objects.get() - user = await User.objects.create(name="Tom") - lookup = await User.objects.get() - assert lookup == user + user = await User.objects.create(name="Tom") + lookup = await User.objects.get() + assert lookup == user - user = await User.objects.create(name="Jane") - with pytest.raises(ormar.MultipleMatches): - await User.objects.get() + user = await User.objects.create(name="Jane") + with pytest.raises(ormar.MultipleMatches): + await User.objects.get() - same_user = await User.objects.get(pk=user.id) - assert same_user.id == user.id - assert same_user.pk == user.pk + same_user = await User.objects.get(pk=user.id) + assert same_user.id == user.id + assert same_user.pk == user.pk @pytest.mark.asyncio async def test_model_filter(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") - await User.objects.create(name="Lucy") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") - user = await User.objects.get(name="Lucy") - assert user.name == "Lucy" + user = await User.objects.get(name="Lucy") + assert user.name == "Lucy" - with pytest.raises(ormar.NoMatch): - await User.objects.get(name="Jim") + with pytest.raises(ormar.NoMatch): + await User.objects.get(name="Jim") - await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) - await Product.objects.create(name="Dress", rating=4) - await Product.objects.create(name="Coat", rating=3, in_stock=True) + await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) + await Product.objects.create(name="Dress", rating=4) + await Product.objects.create(name="Coat", rating=3, in_stock=True) - product = await Product.objects.get(name__iexact="t-shirt", rating=5) - assert product.pk is not None - assert product.name == "T-Shirt" - assert product.rating == 5 + product = await Product.objects.get(name__iexact="t-shirt", rating=5) + assert product.pk is not None + assert product.name == "T-Shirt" + assert product.rating == 5 - products = await Product.objects.all(rating__gte=2, in_stock=True) - assert len(products) == 2 + products = await Product.objects.all(rating__gte=2, in_stock=True) + assert len(products) == 2 - products = await Product.objects.all(name__icontains="T") - assert len(products) == 2 + products = await Product.objects.all(name__icontains="T") + assert len(products) == 2 - # Test escaping % character from icontains, contains, and iexact - await Product.objects.create(name="100%-Cotton", rating=3) - await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) - await Product.objects.create(name="Cotton-100%", rating=3) - products = Product.objects.filter(name__iexact="100%-cotton") - assert await products.count() == 1 + # Test escaping % character from icontains, contains, and iexact + await Product.objects.create(name="100%-Cotton", rating=3) + await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) + await Product.objects.create(name="Cotton-100%", rating=3) + products = Product.objects.filter(name__iexact="100%-cotton") + assert await products.count() == 1 - products = Product.objects.filter(name__contains="%") - assert await products.count() == 3 + products = Product.objects.filter(name__contains="%") + assert await products.count() == 3 - products = Product.objects.filter(name__icontains="%") - assert await products.count() == 3 + products = Product.objects.filter(name__icontains="%") + assert await products.count() == 3 @pytest.mark.asyncio async def test_wrong_query_contains_model(): - with pytest.raises(QueryDefinitionError): - product = Product(name="90%-Cotton", rating=2) - await Product.objects.filter(name__contains=product).count() + async with database: + with pytest.raises(QueryDefinitionError): + product = Product(name="90%-Cotton", rating=2) + await Product.objects.filter(name__contains=product).count() @pytest.mark.asyncio async def test_model_exists(): async with database: - await User.objects.create(name="Tom") - assert await User.objects.filter(name="Tom").exists() is True - assert await User.objects.filter(name="Jane").exists() is False + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + assert await User.objects.filter(name="Tom").exists() is True + assert await User.objects.filter(name="Jane").exists() is False @pytest.mark.asyncio async def test_model_count(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") - await User.objects.create(name="Lucy") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") - assert await User.objects.count() == 3 - assert await User.objects.filter(name__icontains="T").count() == 1 + assert await User.objects.count() == 3 + assert await User.objects.filter(name__icontains="T").count() == 1 @pytest.mark.asyncio async def test_model_limit(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") - await User.objects.create(name="Lucy") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") - assert len(await User.objects.limit(2).all()) == 2 + assert len(await User.objects.limit(2).all()) == 2 @pytest.mark.asyncio async def test_model_limit_with_filter(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Tom") - await User.objects.create(name="Tom") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") - assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 + assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 @pytest.mark.asyncio async def test_offset(): async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") + async with database.transaction(force_rollback=True): + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") - users = await User.objects.offset(1).limit(1).all() - assert users[0].name == "Jane" + users = await User.objects.offset(1).limit(1).all() + assert users[0].name == "Jane" @pytest.mark.asyncio async def test_model_first(): async with database: - tom = await User.objects.create(name="Tom") - jane = await User.objects.create(name="Jane") + async with database.transaction(force_rollback=True): + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") - assert await User.objects.first() == tom - assert await User.objects.first(name="Jane") == jane - assert await User.objects.filter(name="Jane").first() == jane - with pytest.raises(NoMatch): - await User.objects.filter(name="Lucy").first() + assert await User.objects.first() == tom + assert await User.objects.first(name="Jane") == jane + assert await User.objects.filter(name="Jane").first() == jane + with pytest.raises(NoMatch): + await User.objects.filter(name="Lucy").first() diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index e01ce44..f0b9b88 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -112,7 +112,7 @@ def test_all_endpoints(): assert items[0].name == "New name" response = client.delete(f"/items/{item.pk}", json=item.dict()) - assert response.json().get("deleted_rows") == 1 + assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__" response = client.get("/items/") items = response.json() assert len(items) == 0 diff --git a/tests/test_more_same_table_joins.py b/tests/test_more_same_table_joins.py index 3718bc0..5c4ac3d 100644 --- a/tests/test_more_same_table_joins.py +++ b/tests/test_more_same_table_joins.py @@ -78,6 +78,11 @@ async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) metadata.drop_all(engine) metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +async def create_data(): department = await Department.objects.create(id=1, name="Math Department") department2 = await Department.objects.create(id=2, name="Law Department") class1 = await SchoolClass.objects.create(name="Math") @@ -88,13 +93,11 @@ async def create_test_database(): await Student.objects.create(name="Judy", category=category2, schoolclass=class1) await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - yield - metadata.drop_all(engine) - @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: + await create_data() classes = await SchoolClass.objects.select_related( ["teachers__category__department", "students"] ).all() diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 5b8ffd0..9eda9d7 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -78,6 +78,11 @@ async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) metadata.drop_all(engine) metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +async def create_data(): department = await Department.objects.create(id=1, name="Math Department") department2 = await Department.objects.create(id=2, name="Law Department") class1 = await SchoolClass.objects.create(name="Math", department=department) @@ -88,52 +93,56 @@ async def create_test_database(): await Student.objects.create(name="Judy", category=category2, schoolclass=class1) await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - yield - metadata.drop_all(engine) @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: - classes = await SchoolClass.objects.select_related( - ["teachers__category", "students"] - ).all() - assert classes[0].name == "Math" - assert classes[0].students[0].name == "Jane" + async with database.transaction(force_rollback=True): + await create_data() + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[0].name == "Jane" - assert len(classes[0].dict().get("students")) == 2 + assert len(classes[0].dict().get("students")) == 2 - # since it's going from schoolclass => teacher => schoolclass (same class) department is already populated - assert classes[0].students[0].schoolclass.name == "Math" - assert classes[0].students[0].schoolclass.department.name is None - await classes[0].students[0].schoolclass.department.load() - assert classes[0].students[0].schoolclass.department.name == "Math Department" + # since it's going from schoolclass => teacher => schoolclass (same class) department is already populated + assert classes[0].students[0].schoolclass.name == "Math" + assert classes[0].students[0].schoolclass.department.name is None + await classes[0].students[0].schoolclass.department.load() + assert classes[0].students[0].schoolclass.department.name == "Math Department" - await classes[1].students[0].schoolclass.department.load() - assert classes[1].students[0].schoolclass.department.name == "Law Department" + await classes[1].students[0].schoolclass.department.load() + assert classes[1].students[0].schoolclass.department.name == "Law Department" @pytest.mark.asyncio async def test_right_tables_join(): async with database: - classes = await SchoolClass.objects.select_related( - ["teachers__category", "students"] - ).all() - assert classes[0].teachers[0].category.name == "Domestic" + async with database.transaction(force_rollback=True): + await create_data() + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].teachers[0].category.name == "Domestic" - assert classes[0].students[0].category.name is None - await classes[0].students[0].category.load() - assert classes[0].students[0].category.name == "Foreign" + assert classes[0].students[0].category.name is None + await classes[0].students[0].category.load() + assert classes[0].students[0].category.name == "Foreign" @pytest.mark.asyncio async def test_multiple_reverse_related_objects(): async with database: - classes = await SchoolClass.objects.select_related( - ["teachers__category", "students__category"] - ).all() - assert classes[0].name == "Math" - assert classes[0].students[1].name == "Judy" - assert classes[0].students[0].category.name == "Foreign" - assert classes[0].students[1].category.name == "Domestic" - assert classes[0].teachers[0].category.name == "Domestic" + async with database.transaction(force_rollback=True): + await create_data() + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students__category"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[1].name == "Judy" + assert classes[0].students[0].category.name == "Foreign" + assert classes[0].students[1].category.name == "Domestic" + assert classes[0].teachers[0].category.name == "Domestic"