fix for sample data in tests

This commit is contained in:
collerek
2020-12-01 08:34:26 +01:00
parent 61da7b4418
commit 4c4e6248b0

View File

@ -46,9 +46,7 @@ class Track(ormar.Model):
written_by: Optional[Writer] = ormar.ForeignKey(Writer) written_by: Optional[Writer] = ormar.ForeignKey(Writer)
@pytest.fixture(autouse=True) async def get_sample_data():
@pytest.mark.asyncio
async def sample_data():
album = await Album(name="Malibu").save() album = await Album(name="Malibu").save()
writer1 = await Writer.objects.create(name="John") writer1 = await Writer.objects.create(name="John")
writer2 = await Writer.objects.create(name="Sue") writer2 = await Writer.objects.create(name="Sue")
@ -78,9 +76,10 @@ def create_test_database():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_quering_by_reverse_fk(sample_data): async def test_quering_by_reverse_fk():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
sample_data = await get_sample_data()
track1 = sample_data[1][0] track1 = sample_data[1][0]
album = await Album.objects.first() album = await Album.objects.first()
@ -128,9 +127,10 @@ async def test_quering_by_reverse_fk(sample_data):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_getting(sample_data): async def test_getting():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
sample_data = await get_sample_data()
album = sample_data[0] album = sample_data[0]
track1 = await album.tracks.fields(["album", "title", "position"]).get( track1 = await album.tracks.fields(["album", "title", "position"]).get(
title="The Bird" title="The Bird"
@ -192,9 +192,10 @@ async def test_getting(sample_data):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_loading_related(sample_data): async def test_loading_related():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
sample_data = await get_sample_data()
album = sample_data[0] album = sample_data[0]
tracks = await album.tracks.select_related("written_by").all() tracks = await album.tracks.select_related("written_by").all()
assert len(tracks) == 3 assert len(tracks) == 3
@ -210,9 +211,10 @@ async def test_loading_related(sample_data):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_adding_removing(sample_data): async def test_adding_removing():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
sample_data = await get_sample_data()
album = sample_data[0] album = sample_data[0]
track_new = await Track(title="Rainbow", position=5, play_count=300).save() track_new = await Track(title="Rainbow", position=5, play_count=300).save()
await album.tracks.add(track_new) await album.tracks.add(track_new)