add tests for mysql and postgress, some fixes for those backends

This commit is contained in:
collerek
2020-09-17 13:02:34 +02:00
parent 31096d3f93
commit 1451ec8671
16 changed files with 522 additions and 462 deletions

BIN
.coverage

Binary file not shown.

View File

@ -9,8 +9,6 @@ python:
- "3.7" - "3.7"
- "3.8" - "3.8"
env:
- TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db"
services: services:
- postgresql - postgresql
@ -26,7 +24,9 @@ before_script:
script: script:
- scripts/test.sh - DATABASE_URL=postgresql://localhost/test_database scripts/test.sh
- DATABASE_URL=mysql://localhost/test_database scripts/test.sh
- DATABASE_URL=sqlite:///test.db scripts/test.sh
after_script: after_script:
- codecov - codecov

View File

@ -1,12 +1,18 @@
import itertools import itertools
from typing import Any, List, Tuple, Union from typing import Any, List, Tuple, Union
from databases.backends.postgres import Record
import sqlalchemy import sqlalchemy
import ormar.queryset # noqa I100 import ormar.queryset # noqa I100
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
def group_related_list(list_: List) -> dict: def group_related_list(list_: List) -> dict:
test_dict = dict() test_dict = dict()
@ -92,10 +98,14 @@ class Model(NewBaseModel):
cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str
) -> dict: ) -> dict:
for column in cls.Meta.table.columns: for column in cls.Meta.table.columns:
logging.debug('column to extract:' + column.name)
logging.debug(f'{row.keys()}')
if column.name not in item: if column.name not in item:
item[column.name] = row[ prefixed_name = f'{table_prefix + "_" if table_prefix else ""}{column.name}'
f'{table_prefix + "_" if table_prefix else ""}{column.name}' # databases does not keep aliases in Record for postgres
] source = row._row if isinstance(row, Record) else row
item[column.name] = source[prefixed_name]
return item return item
async def save(self) -> "Model": async def save(self) -> "Model":
@ -106,6 +116,7 @@ class Model(NewBaseModel):
expr = self.Meta.table.insert() expr = self.Meta.table.insert()
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.Meta.database.execute(expr) item_id = await self.Meta.database.execute(expr)
if item_id: # postgress does not return id if it's already there
setattr(self, self.Meta.pkname, item_id) setattr(self, self.Meta.pkname, item_id)
return self return self

View File

@ -9,6 +9,9 @@ from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -187,5 +190,6 @@ class QuerySet:
# Execute the insert, and return a new model instance. # Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs) instance = self.model_cls(**kwargs)
pk = await self.database.execute(expr) pk = await self.database.execute(expr)
if pk:
setattr(instance, self.model_cls.Meta.pkname, pk) setattr(instance, self.model_cls.Meta.pkname, pk)
return instance return instance

View File

@ -8,7 +8,8 @@ from sqlalchemy import text
def get_table_alias() -> str: def get_table_alias() -> str:
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] alias = "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
return alias.lower()
class AliasManager: class AliasManager:

View File

@ -7,8 +7,13 @@ sqlalchemy
# Async database drivers # Async database drivers
aiomysql aiomysql
aiosqlite aiosqlite
aiopg
asyncpg asyncpg
# Sync database drivers for standard tooling around setup/teardown/migrations.
pymysql pymysql
psycopg2-binary
mysqlclient
# Testing # Testing
pytest pytest

View File

@ -1,5 +1,11 @@
import os import os
os.environ['TEST_DATABASE_URLS'] = "sqlite:///test.db" import databases
assert "DATABASE_URL" in os.environ, "DATABASE_URL is not set."
DATABASE_URL = os.environ['DATABASE_URL']
database_url = databases.DatabaseURL(DATABASE_URL)
if database_url.scheme == "postgresql+aiopg": # pragma no cover
DATABASE_URL = str(database_url.replace(driver=None))
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db") DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db")

View File

@ -8,10 +8,6 @@ import sqlalchemy
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set."
DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")]
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
@ -38,32 +34,15 @@ class Example(ormar.Model):
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): def create_test_database():
for url in DATABASE_URLS: engine = sqlalchemy.create_engine(DATABASE_URL)
database_url = databases.DatabaseURL(url)
if database_url.scheme == "mysql":
url = str(database_url.replace(driver="pymysql"))
elif database_url.scheme == "postgresql+aiopg":
url = str(database_url.replace(driver=None))
engine = sqlalchemy.create_engine(url)
metadata.create_all(engine) metadata.create_all(engine)
yield yield
for url in DATABASE_URLS:
database_url = databases.DatabaseURL(url)
if database_url.scheme == "mysql":
url = str(database_url.replace(driver="pymysql"))
elif database_url.scheme == "postgresql+aiopg":
url = str(database_url.replace(driver=None))
engine = sqlalchemy.create_engine(url)
metadata.drop_all(engine) metadata.drop_all(engine)
@pytest.mark.parametrize("database_url", DATABASE_URLS)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(database_url): async def test_model_crud():
async with databases.Database(database_url) as database: async with database:
async with database.transaction(force_rollback=True):
Example.Meta.database = database
example = Example() example = Example()
await example.save() await example.save()

View File

@ -1,7 +1,7 @@
import databases import databases
import sqlalchemy import sqlalchemy
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from starlette.testclient import TestClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -38,10 +38,9 @@ async def create_item(item: Item):
return item return item
client = TestClient(app)
def test_read_main(): def test_read_main():
client = TestClient(app)
with client as client:
response = client.post( response = client.post(
"/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}}
) )

View File

@ -78,6 +78,7 @@ class Member(ormar.Model):
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
yield yield
metadata.drop_all(engine) metadata.drop_all(engine)
@ -85,6 +86,7 @@ def create_test_database():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wrong_query_foreign_key_type(): async def test_wrong_query_foreign_key_type():
async with database:
with pytest.raises(RelationshipInstanceError): with pytest.raises(RelationshipInstanceError):
Track(title="The Error", album="wrong_pk_type") Track(title="The Error", album="wrong_pk_type")
@ -99,16 +101,16 @@ async def test_setting_explicitly_empty_relation():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_related_name(): async def test_related_name():
async with database: async with database:
async with database.transaction(force_rollback=True):
album = await Album.objects.create(name="Vanilla") album = await Album.objects.create(name="Vanilla")
await Cover.objects.create(album=album, title="The cover file") await Cover.objects.create(album=album, title="The cover file")
assert len(album.cover_pictures) == 1 assert len(album.cover_pictures) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(): async def test_model_crud():
async with database: async with database:
album = Album(name="Malibu") async with database.transaction(force_rollback=True):
album = Album(name="Jamaica")
await album.save() await album.save()
track1 = Track(album=album, title="The Bird", position=1) track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2) track2 = Track(album=album, title="Heart don't stand a chance", position=2)
@ -122,13 +124,13 @@ async def test_model_crud():
assert isinstance(track.album, ormar.Model) assert isinstance(track.album, ormar.Model)
assert track.album.name is None assert track.album.name is None
await track.album.load() await track.album.load()
assert track.album.name == "Malibu" assert track.album.name == "Jamaica"
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert album.tracks[1].title == "Heart don't stand a chance" assert album.tracks[1].title == "Heart don't stand a chance"
album1 = await Album.objects.get(name="Malibu") album1 = await Album.objects.get(name="Jamaica")
assert album1.pk == 1 assert album1.pk == album.pk
assert album1.tracks == [] assert album1.tracks == []
await Track.objects.create( await Track.objects.create(
@ -139,6 +141,7 @@ async def test_model_crud():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_select_related(): async def test_select_related():
async with database: async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Malibu") album = Album(name="Malibu")
await album.save() await album.save()
track1 = Track(album=album, title="The Bird", position=1) track1 = Track(album=album, title="The Bird", position=1)
@ -167,6 +170,7 @@ async def test_select_related():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_removal_from_relations(): async def test_model_removal_from_relations():
async with database: async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Chichi") album = Album(name="Chichi")
await album.save() await album.save()
track1 = Track(album=album, title="The Birdman", position=1) track1 = Track(album=album, title="The Birdman", position=1)
@ -205,9 +209,11 @@ async def test_model_removal_from_relations():
assert len(album.tracks) == 1 assert len(album.tracks) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fk_filter(): async def test_fk_filter():
async with database: async with database:
async with database.transaction(force_rollback=True):
malibu = Album(name="Malibu%") malibu = Album(name="Malibu%")
await malibu.save() await malibu.save()
await Track.objects.create(album=malibu, title="The Bird", position=1) await Track.objects.create(album=malibu, title="The Bird", position=1)
@ -239,7 +245,7 @@ async def test_fk_filter():
for track in tracks: for track in tracks:
assert track.album.name == "Fantasies" assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="fan").all() tracks = await Track.objects.filter(album__name__contains="Fan").all()
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
assert track.album.name == "Fantasies" assert track.album.name == "Fantasies"
@ -261,6 +267,7 @@ async def test_fk_filter():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_fk(): async def test_multiple_fk():
async with database: async with database:
async with database.transaction(force_rollback=True):
acme = await Organisation.objects.create(ident="ACME Ltd") acme = await Organisation.objects.create(ident="ACME Ltd")
red_team = await Team.objects.create(org=acme, name="Red Team") red_team = await Team.objects.create(org=acme, name="Red Team")
blue_team = await Team.objects.create(org=acme, name="Blue Team") blue_team = await Team.objects.create(org=acme, name="Blue Team")
@ -286,11 +293,12 @@ async def test_multiple_fk():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pk_filter(): async def test_pk_filter():
async with database: async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test") fantasies = await Album.objects.create(name="Test")
await Track.objects.create(album=fantasies, title="Test1", position=1) track = await Track.objects.create(album=fantasies, title="Test1", position=1)
await Track.objects.create(album=fantasies, title="Test2", position=2) await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3) await Track.objects.create(album=fantasies, title="Test3", position=3)
tracks = await Track.objects.select_related("album").filter(pk=1).all() tracks = await Track.objects.select_related("album").filter(pk=track.pk).all()
assert len(tracks) == 1 assert len(tracks) == 1
tracks = ( tracks = (
@ -304,6 +312,7 @@ async def test_pk_filter():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_limit_and_offset(): async def test_limit_and_offset():
async with database: async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Limitless") fantasies = await Album.objects.create(name="Limitless")
await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) await Track.objects.create(id=None, album=fantasies, title="Sample", position=1)
await Track.objects.create(album=fantasies, title="Sample2", position=2) await Track.objects.create(album=fantasies, title="Sample2", position=2)
@ -321,6 +330,7 @@ async def test_limit_and_offset():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_exceptions(): async def test_get_exceptions():
async with database: async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test") fantasies = await Album.objects.create(name="Test")
with pytest.raises(NoMatch): with pytest.raises(NoMatch):
@ -335,6 +345,8 @@ async def test_get_exceptions():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wrong_model_passed_as_fk(): async def test_wrong_model_passed_as_fk():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(RelationshipInstanceError): with pytest.raises(RelationshipInstanceError):
org = await Organisation.objects.create(ident="ACME Ltd") org = await Organisation.objects.create(ident="ACME Ltd")
await Track.objects.create(album=org, title="Test1", position=1) await Track.objects.create(album=org, title="Test1", position=1)

View File

@ -1,3 +1,5 @@
import asyncio
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
@ -50,8 +52,15 @@ class Post(ormar.Model):
author: ormar.ForeignKey(Author) author: ormar.ForeignKey(Author)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine) metadata.create_all(engine)
yield yield
@ -61,6 +70,7 @@ def create_test_database():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def cleanup(): async def cleanup():
yield yield
async with database:
await PostCategory.objects.delete() await PostCategory.objects.delete()
await Post.objects.delete() await Post.objects.delete()
await Category.objects.delete() await Category.objects.delete()
@ -69,6 +79,7 @@ async def cleanup():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_assigning_related_objects(cleanup): async def test_assigning_related_objects(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido) post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News") news = await Category.objects.create(name="News")
@ -88,6 +99,7 @@ async def test_assigning_related_objects(cleanup):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_quering_of_the_m2m_models(cleanup): async def test_quering_of_the_m2m_models(cleanup):
async with database:
# orm can do this already. # orm can do this already.
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido) post = await Post.objects.create(title="Hello, M2M", author=guido)
@ -120,6 +132,7 @@ async def test_quering_of_the_m2m_models(cleanup):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_removal_of_the_relations(cleanup): async def test_removal_of_the_relations(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido) post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News") news = await Category.objects.create(name="News")
@ -146,6 +159,7 @@ async def test_removal_of_the_relations(cleanup):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_selecting_related(cleanup): async def test_selecting_related(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido) post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News") news = await Category.objects.create(name="News")
@ -172,6 +186,7 @@ async def test_selecting_related(cleanup):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_selecting_related_fail_without_saving(cleanup): async def test_selecting_related_fail_without_saving(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = Post(title="Hello, M2M", author=guido) post = Post(title="Hello, M2M", author=guido)
with pytest.raises(RelationshipInstanceError): with pytest.raises(RelationshipInstanceError):

View File

@ -112,7 +112,6 @@ def test_sqlalchemy_table_is_created(example):
def test_no_pk_in_model_definition(): def test_no_pk_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example3" tablename = "example3"
@ -123,7 +122,6 @@ def test_no_pk_in_model_definition():
def test_two_pks_in_model_definition(): def test_two_pks_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example3" tablename = "example3"
@ -135,7 +133,6 @@ def test_two_pks_in_model_definition():
def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example4" tablename = "example4"
@ -146,7 +143,6 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition():
def test_decimal_error_in_model_definition(): def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example5" tablename = "example5"
@ -157,7 +153,6 @@ def test_decimal_error_in_model_definition():
def test_string_error_in_model_definition(): def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example6" tablename = "example6"

View File

@ -1,3 +1,5 @@
import asyncio
import databases import databases
import pydantic import pydantic
import pytest import pytest
@ -43,9 +45,17 @@ class Product(ormar.Model):
in_stock: ormar.Boolean(default=False) in_stock: ormar.Boolean(default=False)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
yield yield
metadata.drop_all(engine) metadata.drop_all(engine)
@ -69,6 +79,7 @@ def test_model_pk():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_json_column(): async def test_json_column():
async with database: async with database:
async with database.transaction(force_rollback=True):
await JsonSample.objects.create(test_json=dict(aa=12)) await JsonSample.objects.create(test_json=dict(aa=12))
await JsonSample.objects.create(test_json='{"aa": 12}') await JsonSample.objects.create(test_json='{"aa": 12}')
@ -81,6 +92,7 @@ async def test_json_column():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(): async def test_model_crud():
async with database: async with database:
async with database.transaction(force_rollback=True):
users = await User.objects.all() users = await User.objects.all()
assert users == [] assert users == []
@ -107,6 +119,7 @@ async def test_model_crud():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_get(): async def test_model_get():
async with database: async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(ormar.NoMatch): with pytest.raises(ormar.NoMatch):
await User.objects.get() await User.objects.get()
@ -126,6 +139,7 @@ async def test_model_get():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_filter(): async def test_model_filter():
async with database: async with database:
async with database.transaction(force_rollback=True):
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
await User.objects.create(name="Jane") await User.objects.create(name="Jane")
await User.objects.create(name="Lucy") await User.objects.create(name="Lucy")
@ -167,6 +181,7 @@ async def test_model_filter():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wrong_query_contains_model(): async def test_wrong_query_contains_model():
async with database:
with pytest.raises(QueryDefinitionError): with pytest.raises(QueryDefinitionError):
product = Product(name="90%-Cotton", rating=2) product = Product(name="90%-Cotton", rating=2)
await Product.objects.filter(name__contains=product).count() await Product.objects.filter(name__contains=product).count()
@ -175,6 +190,7 @@ async def test_wrong_query_contains_model():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_exists(): async def test_model_exists():
async with database: async with database:
async with database.transaction(force_rollback=True):
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
assert await User.objects.filter(name="Tom").exists() is True assert await User.objects.filter(name="Tom").exists() is True
assert await User.objects.filter(name="Jane").exists() is False assert await User.objects.filter(name="Jane").exists() is False
@ -183,6 +199,7 @@ async def test_model_exists():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_count(): async def test_model_count():
async with database: async with database:
async with database.transaction(force_rollback=True):
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
await User.objects.create(name="Jane") await User.objects.create(name="Jane")
await User.objects.create(name="Lucy") await User.objects.create(name="Lucy")
@ -194,6 +211,7 @@ async def test_model_count():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_limit(): async def test_model_limit():
async with database: async with database:
async with database.transaction(force_rollback=True):
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
await User.objects.create(name="Jane") await User.objects.create(name="Jane")
await User.objects.create(name="Lucy") await User.objects.create(name="Lucy")
@ -204,6 +222,7 @@ async def test_model_limit():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_limit_with_filter(): async def test_model_limit_with_filter():
async with database: async with database:
async with database.transaction(force_rollback=True):
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
@ -214,6 +233,7 @@ async def test_model_limit_with_filter():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_offset(): async def test_offset():
async with database: async with database:
async with database.transaction(force_rollback=True):
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
await User.objects.create(name="Jane") await User.objects.create(name="Jane")
@ -224,6 +244,7 @@ async def test_offset():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_first(): async def test_model_first():
async with database: async with database:
async with database.transaction(force_rollback=True):
tom = await User.objects.create(name="Tom") tom = await User.objects.create(name="Tom")
jane = await User.objects.create(name="Jane") jane = await User.objects.create(name="Jane")

View File

@ -112,7 +112,7 @@ def test_all_endpoints():
assert items[0].name == "New name" assert items[0].name == "New name"
response = client.delete(f"/items/{item.pk}", json=item.dict()) response = client.delete(f"/items/{item.pk}", json=item.dict())
assert response.json().get("deleted_rows") == 1 assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
response = client.get("/items/") response = client.get("/items/")
items = response.json() items = response.json()
assert len(items) == 0 assert len(items) == 0

View File

@ -78,6 +78,11 @@ async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine) metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
yield
metadata.drop_all(engine)
async def create_data():
department = await Department.objects.create(id=1, name="Math Department") department = await Department.objects.create(id=1, name="Math Department")
department2 = await Department.objects.create(id=2, name="Law Department") department2 = await Department.objects.create(id=2, name="Law Department")
class1 = await SchoolClass.objects.create(name="Math") class1 = await SchoolClass.objects.create(name="Math")
@ -88,13 +93,11 @@ async def create_test_database():
await Student.objects.create(name="Judy", category=category2, schoolclass=class1) await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(): async def test_model_multiple_instances_of_same_table_in_schema():
async with database: async with database:
await create_data()
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category__department", "students"] ["teachers__category__department", "students"]
).all() ).all()

View File

@ -78,6 +78,11 @@ async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine) metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
yield
metadata.drop_all(engine)
async def create_data():
department = await Department.objects.create(id=1, name="Math Department") department = await Department.objects.create(id=1, name="Math Department")
department2 = await Department.objects.create(id=2, name="Law Department") department2 = await Department.objects.create(id=2, name="Law Department")
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)
@ -88,13 +93,13 @@ async def create_test_database():
await Student.objects.create(name="Judy", category=category2, schoolclass=class1) await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(): async def test_model_multiple_instances_of_same_table_in_schema():
async with database: async with database:
async with database.transaction(force_rollback=True):
await create_data()
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students"] ["teachers__category", "students"]
).all() ).all()
@ -116,6 +121,8 @@ async def test_model_multiple_instances_of_same_table_in_schema():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_right_tables_join(): async def test_right_tables_join():
async with database: async with database:
async with database.transaction(force_rollback=True):
await create_data()
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students"] ["teachers__category", "students"]
).all() ).all()
@ -129,6 +136,8 @@ async def test_right_tables_join():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_reverse_related_objects(): async def test_multiple_reverse_related_objects():
async with database: async with database:
async with database.transaction(force_rollback=True):
await create_data()
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students__category"] ["teachers__category", "students__category"]
).all() ).all()