Merge pull request #6 from collerek/test_also_mysql_and_postgress

Test also mysql and postgress
This commit is contained in:
collerek
2020-09-17 18:19:26 +07:00
committed by GitHub
16 changed files with 508 additions and 412 deletions

BIN
.coverage

Binary file not shown.

View File

@ -5,15 +5,28 @@ dist: xenial
cache: pip cache: pip
python: python:
- "3.6" - "3.6"
- "3.7" - "3.7"
- "3.8" - "3.8"
services:
- postgresql
- mysql
install: install:
- pip install -U -r requirements.txt - pip install -U -r requirements.txt
before_script:
- psql -c 'create database test_database;' -U postgres
- echo 'create database test_database;' | mysql
script: script:
- scripts/test.sh - DATABASE_URL=postgresql://localhost/test_database scripts/test.sh
- DATABASE_URL=mysql://localhost/test_database scripts/test.sh
- DATABASE_URL=sqlite:///test.db scripts/test.sh
after_script: after_script:
- codecov - codecov

View File

@ -2,6 +2,7 @@ import itertools
from typing import Any, List, Tuple, Union from typing import Any, List, Tuple, Union
import sqlalchemy import sqlalchemy
from databases.backends.postgres import Record
import ormar.queryset # noqa I100 import ormar.queryset # noqa I100
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
@ -88,14 +89,18 @@ class Model(NewBaseModel):
return item return item
@classmethod @classmethod
def extract_prefixed_table_columns( def extract_prefixed_table_columns( # noqa CCR001
cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str
) -> dict: ) -> dict:
for column in cls.Meta.table.columns: for column in cls.Meta.table.columns:
if column.name not in item: if column.name not in item:
item[column.name] = row[ prefixed_name = (
f'{table_prefix + "_" if table_prefix else ""}{column.name}' f'{table_prefix + "_" if table_prefix else ""}{column.name}'
] )
# databases does not keep aliases in Record for postgres
source = row._row if isinstance(row, Record) else row
item[column.name] = source[prefixed_name]
return item return item
async def save(self) -> "Model": async def save(self) -> "Model":
@ -106,7 +111,8 @@ class Model(NewBaseModel):
expr = self.Meta.table.insert() expr = self.Meta.table.insert()
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.Meta.database.execute(expr) item_id = await self.Meta.database.execute(expr)
setattr(self, self.Meta.pkname, item_id) if item_id: # postgress does not return id if it's already there
setattr(self, self.Meta.pkname, item_id)
return self return self
async def update(self, **kwargs: Any) -> "Model": async def update(self, **kwargs: Any) -> "Model":

View File

@ -9,7 +9,6 @@ from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -187,5 +186,6 @@ class QuerySet:
# Execute the insert, and return a new model instance. # Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs) instance = self.model_cls(**kwargs)
pk = await self.database.execute(expr) pk = await self.database.execute(expr)
setattr(instance, self.model_cls.Meta.pkname, pk) if pk:
setattr(instance, self.model_cls.Meta.pkname, pk)
return instance return instance

View File

@ -8,7 +8,8 @@ from sqlalchemy import text
def get_table_alias() -> str: def get_table_alias() -> str:
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] alias = "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
return alias.lower()
class AliasManager: class AliasManager:

View File

@ -1,7 +1,20 @@
databases[sqlite] databases[sqlite]
databases[postgresql]
databases[mysql]
pydantic pydantic
sqlalchemy sqlalchemy
# Async database drivers
aiomysql
aiosqlite
aiopg
asyncpg
# Sync database drivers for standard tooling around setup/teardown/migrations.
pymysql
psycopg2-binary
mysqlclient
# Testing # Testing
pytest pytest
pytest-cov pytest-cov

View File

@ -1,3 +1,11 @@
import os import os
import databases
assert "DATABASE_URL" in os.environ, "DATABASE_URL is not set."
DATABASE_URL = os.environ['DATABASE_URL']
database_url = databases.DatabaseURL(DATABASE_URL)
if database_url.scheme == "postgresql+aiopg": # pragma no cover
DATABASE_URL = str(database_url.replace(driver=None))
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db") DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db")

View File

@ -1,4 +1,5 @@
import datetime import datetime
import os
import databases import databases
import pytest import pytest

View File

@ -1,7 +1,7 @@
import databases import databases
import sqlalchemy import sqlalchemy
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from starlette.testclient import TestClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -38,18 +38,17 @@ async def create_item(item: Item):
return item return item
client = TestClient(app)
def test_read_main(): def test_read_main():
response = client.post( client = TestClient(app)
"/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} with client as client:
) response = client.post(
assert response.status_code == 200 "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}}
assert response.json() == { )
"category": {"id": None, "name": "test cat"}, assert response.status_code == 200
"id": 1, assert response.json() == {
"name": "test", "category": {"id": None, "name": "test cat"},
} "id": 1,
item = Item(**response.json()) "name": "test",
assert item.id == 1 }
item = Item(**response.json())
assert item.id == 1

View File

@ -78,6 +78,7 @@ class Member(ormar.Model):
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
yield yield
metadata.drop_all(engine) metadata.drop_all(engine)
@ -85,8 +86,9 @@ def create_test_database():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wrong_query_foreign_key_type(): async def test_wrong_query_foreign_key_type():
with pytest.raises(RelationshipInstanceError): async with database:
Track(title="The Error", album="wrong_pk_type") with pytest.raises(RelationshipInstanceError):
Track(title="The Error", album="wrong_pk_type")
@pytest.mark.asyncio @pytest.mark.asyncio
@ -99,242 +101,252 @@ async def test_setting_explicitly_empty_relation():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_related_name(): async def test_related_name():
async with database: async with database:
album = await Album.objects.create(name="Vanilla") async with database.transaction(force_rollback=True):
await Cover.objects.create(album=album, title="The cover file") album = await Album.objects.create(name="Vanilla")
await Cover.objects.create(album=album, title="The cover file")
assert len(album.cover_pictures) == 1 assert len(album.cover_pictures) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(): async def test_model_crud():
async with database: async with database:
album = Album(name="Malibu") async with database.transaction(force_rollback=True):
await album.save() album = Album(name="Jamaica")
track1 = Track(album=album, title="The Bird", position=1) await album.save()
track2 = Track(album=album, title="Heart don't stand a chance", position=2) track1 = Track(album=album, title="The Bird", position=1)
track3 = Track(album=album, title="The Waters", position=3) track2 = Track(album=album, title="Heart don't stand a chance", position=2)
await track1.save() track3 = Track(album=album, title="The Waters", position=3)
await track2.save() await track1.save()
await track3.save() await track2.save()
await track3.save()
track = await Track.objects.get(title="The Bird") track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk assert track.album.pk == album.pk
assert isinstance(track.album, ormar.Model) assert isinstance(track.album, ormar.Model)
assert track.album.name is None assert track.album.name is None
await track.album.load() await track.album.load()
assert track.album.name == "Malibu" assert track.album.name == "Jamaica"
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert album.tracks[1].title == "Heart don't stand a chance" assert album.tracks[1].title == "Heart don't stand a chance"
album1 = await Album.objects.get(name="Malibu") album1 = await Album.objects.get(name="Jamaica")
assert album1.pk == 1 assert album1.pk == album.pk
assert album1.tracks == [] assert album1.tracks == []
await Track.objects.create( await Track.objects.create(
album={"id": track.album.pk}, title="The Bird2", position=4 album={"id": track.album.pk}, title="The Bird2", position=4
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_select_related(): async def test_select_related():
async with database: async with database:
album = Album(name="Malibu") async with database.transaction(force_rollback=True):
await album.save() album = Album(name="Malibu")
track1 = Track(album=album, title="The Bird", position=1) await album.save()
track2 = Track(album=album, title="Heart don't stand a chance", position=2) track1 = Track(album=album, title="The Bird", position=1)
track3 = Track(album=album, title="The Waters", position=3) track2 = Track(album=album, title="Heart don't stand a chance", position=2)
await track1.save() track3 = Track(album=album, title="The Waters", position=3)
await track2.save() await track1.save()
await track3.save() await track2.save()
await track3.save()
fantasies = Album(name="Fantasies") fantasies = Album(name="Fantasies")
await fantasies.save() await fantasies.save()
track4 = Track(album=fantasies, title="Help I'm Alive", position=1) track4 = Track(album=fantasies, title="Help I'm Alive", position=1)
track5 = Track(album=fantasies, title="Sick Muse", position=2) track5 = Track(album=fantasies, title="Sick Muse", position=2)
track6 = Track(album=fantasies, title="Satellite Mind", position=3) track6 = Track(album=fantasies, title="Satellite Mind", position=3)
await track4.save() await track4.save()
await track5.save() await track5.save()
await track6.save() await track6.save()
track = await Track.objects.select_related("album").get(title="The Bird") track = await Track.objects.select_related("album").get(title="The Bird")
assert track.album.name == "Malibu" assert track.album.name == "Malibu"
tracks = await Track.objects.select_related("album").all() tracks = await Track.objects.select_related("album").all()
assert len(tracks) == 6 assert len(tracks) == 6
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_removal_from_relations(): async def test_model_removal_from_relations():
async with database: async with database:
album = Album(name="Chichi") async with database.transaction(force_rollback=True):
await album.save() album = Album(name="Chichi")
track1 = Track(album=album, title="The Birdman", position=1) await album.save()
track2 = Track(album=album, title="Superman", position=2) track1 = Track(album=album, title="The Birdman", position=1)
track3 = Track(album=album, title="Wonder Woman", position=3) track2 = Track(album=album, title="Superman", position=2)
await track1.save() track3 = Track(album=album, title="Wonder Woman", position=3)
await track2.save() await track1.save()
await track3.save() await track2.save()
await track3.save()
assert len(album.tracks) == 3 assert len(album.tracks) == 3
await album.tracks.remove(track1) await album.tracks.remove(track1)
assert len(album.tracks) == 2 assert len(album.tracks) == 2
assert track1.album is None assert track1.album is None
await track1.update() await track1.update()
track1 = await Track.objects.get(title="The Birdman") track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None assert track1.album is None
await album.tracks.add(track1) await album.tracks.add(track1)
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert track1.album == album assert track1.album == album
await track1.update() await track1.update()
track1 = await Track.objects.select_related("album__tracks").get( track1 = await Track.objects.select_related("album__tracks").get(
title="The Birdman" title="The Birdman"
) )
album = await Album.objects.select_related("tracks").get(name="Chichi") album = await Album.objects.select_related("tracks").get(name="Chichi")
assert track1.album == album assert track1.album == album
track1.remove(album) track1.remove(album)
assert track1.album is None assert track1.album is None
assert len(album.tracks) == 2 assert len(album.tracks) == 2
track2.remove(album)
assert track2.album is None
assert len(album.tracks) == 1
track2.remove(album)
assert track2.album is None
assert len(album.tracks) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fk_filter(): async def test_fk_filter():
async with database: async with database:
malibu = Album(name="Malibu%") async with database.transaction(force_rollback=True):
await malibu.save() malibu = Album(name="Malibu%")
await Track.objects.create(album=malibu, title="The Bird", position=1) await malibu.save()
await Track.objects.create( await Track.objects.create(album=malibu, title="The Bird", position=1)
album=malibu, title="Heart don't stand a chance", position=2 await Track.objects.create(
) album=malibu, title="Heart don't stand a chance", position=2
await Track.objects.create(album=malibu, title="The Waters", position=3) )
await Track.objects.create(album=malibu, title="The Waters", position=3)
fantasies = await Album.objects.create(name="Fantasies") fantasies = await Album.objects.create(name="Fantasies")
await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2) await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(album__name="Fantasies") .filter(album__name="Fantasies")
.all() .all()
) )
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
assert track.album.name == "Fantasies" assert track.album.name == "Fantasies"
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(album__name__icontains="fan") .filter(album__name__icontains="fan")
.all() .all()
) )
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
assert track.album.name == "Fantasies" assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="fan").all() tracks = await Track.objects.filter(album__name__contains="Fan").all()
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
assert track.album.name == "Fantasies" assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Malibu%").all() tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
assert len(tracks) == 3 assert len(tracks) == 3
tracks = await Track.objects.filter(album=malibu).select_related("album").all() tracks = await Track.objects.filter(album=malibu).select_related("album").all()
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
assert track.album.name == "Malibu%" assert track.album.name == "Malibu%"
tracks = await Track.objects.select_related("album").all(album=malibu) tracks = await Track.objects.select_related("album").all(album=malibu)
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
assert track.album.name == "Malibu%" assert track.album.name == "Malibu%"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_fk(): async def test_multiple_fk():
async with database: async with database:
acme = await Organisation.objects.create(ident="ACME Ltd") async with database.transaction(force_rollback=True):
red_team = await Team.objects.create(org=acme, name="Red Team") acme = await Organisation.objects.create(ident="ACME Ltd")
blue_team = await Team.objects.create(org=acme, name="Blue Team") red_team = await Team.objects.create(org=acme, name="Red Team")
await Member.objects.create(team=red_team, email="a@example.org") blue_team = await Team.objects.create(org=acme, name="Blue Team")
await Member.objects.create(team=red_team, email="b@example.org") await Member.objects.create(team=red_team, email="a@example.org")
await Member.objects.create(team=blue_team, email="c@example.org") await Member.objects.create(team=red_team, email="b@example.org")
await Member.objects.create(team=blue_team, email="d@example.org") await Member.objects.create(team=blue_team, email="c@example.org")
await Member.objects.create(team=blue_team, email="d@example.org")
other = await Organisation.objects.create(ident="Other ltd") other = await Organisation.objects.create(ident="Other ltd")
team = await Team.objects.create(org=other, name="Green Team") team = await Team.objects.create(org=other, name="Green Team")
await Member.objects.create(team=team, email="e@example.org") await Member.objects.create(team=team, email="e@example.org")
members = ( members = (
await Member.objects.select_related("team__org") await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd") .filter(team__org__ident="ACME Ltd")
.all() .all()
) )
assert len(members) == 4 assert len(members) == 4
for member in members: for member in members:
assert member.team.org.ident == "ACME Ltd" assert member.team.org.ident == "ACME Ltd"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pk_filter(): async def test_pk_filter():
async with database: async with database:
fantasies = await Album.objects.create(name="Test") async with database.transaction(force_rollback=True):
await Track.objects.create(album=fantasies, title="Test1", position=1) fantasies = await Album.objects.create(name="Test")
await Track.objects.create(album=fantasies, title="Test2", position=2) track = await Track.objects.create(album=fantasies, title="Test1", position=1)
await Track.objects.create(album=fantasies, title="Test3", position=3) await Track.objects.create(album=fantasies, title="Test2", position=2)
tracks = await Track.objects.select_related("album").filter(pk=1).all() await Track.objects.create(album=fantasies, title="Test3", position=3)
assert len(tracks) == 1 tracks = await Track.objects.select_related("album").filter(pk=track.pk).all()
assert len(tracks) == 1
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(position=2, album__name="Test") .filter(position=2, album__name="Test")
.all() .all()
) )
assert len(tracks) == 1 assert len(tracks) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_limit_and_offset(): async def test_limit_and_offset():
async with database: async with database:
fantasies = await Album.objects.create(name="Limitless") async with database.transaction(force_rollback=True):
await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) fantasies = await Album.objects.create(name="Limitless")
await Track.objects.create(album=fantasies, title="Sample2", position=2) await Track.objects.create(id=None, album=fantasies, title="Sample", position=1)
await Track.objects.create(album=fantasies, title="Sample3", position=3) await Track.objects.create(album=fantasies, title="Sample2", position=2)
await Track.objects.create(album=fantasies, title="Sample3", position=3)
tracks = await Track.objects.limit(1).all() tracks = await Track.objects.limit(1).all()
assert len(tracks) == 1 assert len(tracks) == 1
assert tracks[0].title == "Sample" assert tracks[0].title == "Sample"
tracks = await Track.objects.limit(1).offset(1).all() tracks = await Track.objects.limit(1).offset(1).all()
assert len(tracks) == 1 assert len(tracks) == 1
assert tracks[0].title == "Sample2" assert tracks[0].title == "Sample2"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_exceptions(): async def test_get_exceptions():
async with database: async with database:
fantasies = await Album.objects.create(name="Test") async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test")
with pytest.raises(NoMatch): with pytest.raises(NoMatch):
await Album.objects.get(name="Test2") await Album.objects.get(name="Test2")
await Track.objects.create(album=fantasies, title="Test1", position=1) await Track.objects.create(album=fantasies, title="Test1", position=1)
await Track.objects.create(album=fantasies, title="Test2", position=2) await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3) await Track.objects.create(album=fantasies, title="Test3", position=3)
with pytest.raises(MultipleMatches): with pytest.raises(MultipleMatches):
await Track.objects.select_related("album").get(album=fantasies) await Track.objects.select_related("album").get(album=fantasies)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wrong_model_passed_as_fk(): async def test_wrong_model_passed_as_fk():
with pytest.raises(RelationshipInstanceError): async with database:
org = await Organisation.objects.create(ident="ACME Ltd") async with database.transaction(force_rollback=True):
await Track.objects.create(album=org, title="Test1", position=1) with pytest.raises(RelationshipInstanceError):
org = await Organisation.objects.create(ident="ACME Ltd")
await Track.objects.create(album=org, title="Test1", position=1)

View File

@ -1,3 +1,5 @@
import asyncio
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
@ -50,8 +52,15 @@ class Post(ormar.Model):
author: ormar.ForeignKey(Author) author: ormar.ForeignKey(Author)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine) metadata.create_all(engine)
yield yield
@ -61,118 +70,124 @@ def create_test_database():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def cleanup(): async def cleanup():
yield yield
await PostCategory.objects.delete() async with database:
await Post.objects.delete() await PostCategory.objects.delete()
await Category.objects.delete() await Post.objects.delete()
await Author.objects.delete() await Category.objects.delete()
await Author.objects.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_assigning_related_objects(cleanup): async def test_assigning_related_objects(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") async with database:
post = await Post.objects.create(title="Hello, M2M", author=guido) guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
news = await Category.objects.create(name="News") post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
# Add a category to a post. # Add a category to a post.
await post.categories.add(news) await post.categories.add(news)
# or from the other end: # or from the other end:
await news.posts.add(post) await news.posts.add(post)
# Creating related object from instance: # Creating related object from instance:
await post.categories.create(name="Tips") await post.categories.create(name="Tips")
assert len(post.categories) == 2 assert len(post.categories) == 2
post_categories = await post.categories.all() post_categories = await post.categories.all()
assert len(post_categories) == 2 assert len(post_categories) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_quering_of_the_m2m_models(cleanup): async def test_quering_of_the_m2m_models(cleanup):
# orm can do this already. async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") # orm can do this already.
post = await Post.objects.create(title="Hello, M2M", author=guido) guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
news = await Category.objects.create(name="News") post = await Post.objects.create(title="Hello, M2M", author=guido)
# tl;dr: `post.categories` exposes the QuerySet API. news = await Category.objects.create(name="News")
# tl;dr: `post.categories` exposes the QuerySet API.
await post.categories.add(news) await post.categories.add(news)
post_categories = await post.categories.all() post_categories = await post.categories.all()
assert len(post_categories) == 1 assert len(post_categories) == 1
assert news == await post.categories.get(name="News") assert news == await post.categories.get(name="News")
num_posts = await news.posts.count() num_posts = await news.posts.count()
assert num_posts == 1 assert num_posts == 1
posts_about_m2m = await news.posts.filter(title__contains="M2M").all() posts_about_m2m = await news.posts.filter(title__contains="M2M").all()
assert len(posts_about_m2m) == 1 assert len(posts_about_m2m) == 1
assert posts_about_m2m[0] == post assert posts_about_m2m[0] == post
posts_about_python = await Post.objects.filter(categories__name="python").all() posts_about_python = await Post.objects.filter(categories__name="python").all()
assert len(posts_about_python) == 0 assert len(posts_about_python) == 0
# Traversal of relationships: which categories has Guido contributed to? # Traversal of relationships: which categories has Guido contributed to?
category = await Category.objects.filter(posts__author=guido).get() category = await Category.objects.filter(posts__author=guido).get()
assert category == news assert category == news
# or: # or:
category2 = await Category.objects.filter(posts__author__first_name="Guido").get() category2 = await Category.objects.filter(posts__author__first_name="Guido").get()
assert category2 == news assert category2 == news
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_removal_of_the_relations(cleanup): async def test_removal_of_the_relations(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") async with database:
post = await Post.objects.create(title="Hello, M2M", author=guido) guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
news = await Category.objects.create(name="News") post = await Post.objects.create(title="Hello, M2M", author=guido)
await post.categories.add(news) news = await Category.objects.create(name="News")
assert len(await post.categories.all()) == 1 await post.categories.add(news)
await post.categories.remove(news) assert len(await post.categories.all()) == 1
assert len(await post.categories.all()) == 0 await post.categories.remove(news)
# or: assert len(await post.categories.all()) == 0
await news.posts.add(post) # or:
assert len(await news.posts.all()) == 1 await news.posts.add(post)
await news.posts.remove(post) assert len(await news.posts.all()) == 1
assert len(await news.posts.all()) == 0 await news.posts.remove(post)
assert len(await news.posts.all()) == 0
# Remove all related objects: # Remove all related objects:
await post.categories.add(news) await post.categories.add(news)
await post.categories.clear() await post.categories.clear()
assert len(await post.categories.all()) == 0 assert len(await post.categories.all()) == 0
# post would also lose 'news' category when running: # post would also lose 'news' category when running:
await post.categories.add(news) await post.categories.add(news)
await news.delete() await news.delete()
assert len(await post.categories.all()) == 0 assert len(await post.categories.all()) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_selecting_related(cleanup): async def test_selecting_related(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") async with database:
post = await Post.objects.create(title="Hello, M2M", author=guido) guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
news = await Category.objects.create(name="News") post = await Post.objects.create(title="Hello, M2M", author=guido)
recent = await Category.objects.create(name="Recent") news = await Category.objects.create(name="News")
await post.categories.add(news) recent = await Category.objects.create(name="Recent")
await post.categories.add(recent) await post.categories.add(news)
assert len(await post.categories.all()) == 2 await post.categories.add(recent)
# Loads categories and posts (2 queries) and perform the join in Python. assert len(await post.categories.all()) == 2
categories = await Category.objects.select_related("posts").all() # Loads categories and posts (2 queries) and perform the join in Python.
# No extra queries needed => no more `await`s required. categories = await Category.objects.select_related("posts").all()
for category in categories: # No extra queries needed => no more `await`s required.
assert category.posts[0] == post for category in categories:
assert category.posts[0] == post
news_posts = await news.posts.select_related("author").all() news_posts = await news.posts.select_related("author").all()
assert news_posts[0].author == guido assert news_posts[0].author == guido
assert (await post.categories.limit(1).all())[0] == news assert (await post.categories.limit(1).all())[0] == news
assert (await post.categories.offset(1).limit(1).all())[0] == recent assert (await post.categories.offset(1).limit(1).all())[0] == recent
assert await post.categories.first() == news assert await post.categories.first() == news
assert await post.categories.exists() assert await post.categories.exists()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_selecting_related_fail_without_saving(cleanup): async def test_selecting_related_fail_without_saving(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") async with database:
post = Post(title="Hello, M2M", author=guido) guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
with pytest.raises(RelationshipInstanceError): post = Post(title="Hello, M2M", author=guido)
await post.categories.all() with pytest.raises(RelationshipInstanceError):
await post.categories.all()

View File

@ -112,7 +112,6 @@ def test_sqlalchemy_table_is_created(example):
def test_no_pk_in_model_definition(): def test_no_pk_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example3" tablename = "example3"
@ -123,7 +122,6 @@ def test_no_pk_in_model_definition():
def test_two_pks_in_model_definition(): def test_two_pks_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example3" tablename = "example3"
@ -135,7 +133,6 @@ def test_two_pks_in_model_definition():
def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example4" tablename = "example4"
@ -146,7 +143,6 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition():
def test_decimal_error_in_model_definition(): def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example5" tablename = "example5"
@ -157,7 +153,6 @@ def test_decimal_error_in_model_definition():
def test_string_error_in_model_definition(): def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example6" tablename = "example6"

View File

@ -1,3 +1,5 @@
import asyncio
import databases import databases
import pydantic import pydantic
import pytest import pytest
@ -43,9 +45,17 @@ class Product(ormar.Model):
in_stock: ormar.Boolean(default=False) in_stock: ormar.Boolean(default=False)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
yield yield
metadata.drop_all(engine) metadata.drop_all(engine)
@ -69,166 +79,177 @@ def test_model_pk():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_json_column(): async def test_json_column():
async with database: async with database:
await JsonSample.objects.create(test_json=dict(aa=12)) async with database.transaction(force_rollback=True):
await JsonSample.objects.create(test_json='{"aa": 12}') await JsonSample.objects.create(test_json=dict(aa=12))
await JsonSample.objects.create(test_json='{"aa": 12}')
items = await JsonSample.objects.all() items = await JsonSample.objects.all()
assert len(items) == 2 assert len(items) == 2
assert items[0].test_json == dict(aa=12) assert items[0].test_json == dict(aa=12)
assert items[1].test_json == dict(aa=12) assert items[1].test_json == dict(aa=12)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(): async def test_model_crud():
async with database: async with database:
users = await User.objects.all() async with database.transaction(force_rollback=True):
assert users == [] users = await User.objects.all()
assert users == []
user = await User.objects.create(name="Tom") user = await User.objects.create(name="Tom")
users = await User.objects.all() users = await User.objects.all()
assert user.name == "Tom" assert user.name == "Tom"
assert user.pk is not None assert user.pk is not None
assert users == [user] assert users == [user]
lookup = await User.objects.get() lookup = await User.objects.get()
assert lookup == user assert lookup == user
await user.update(name="Jane") await user.update(name="Jane")
users = await User.objects.all() users = await User.objects.all()
assert user.name == "Jane" assert user.name == "Jane"
assert user.pk is not None assert user.pk is not None
assert users == [user] assert users == [user]
await user.delete() await user.delete()
users = await User.objects.all() users = await User.objects.all()
assert users == [] assert users == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_get(): async def test_model_get():
async with database: async with database:
with pytest.raises(ormar.NoMatch): async with database.transaction(force_rollback=True):
await User.objects.get() with pytest.raises(ormar.NoMatch):
await User.objects.get()
user = await User.objects.create(name="Tom") user = await User.objects.create(name="Tom")
lookup = await User.objects.get() lookup = await User.objects.get()
assert lookup == user assert lookup == user
user = await User.objects.create(name="Jane") user = await User.objects.create(name="Jane")
with pytest.raises(ormar.MultipleMatches): with pytest.raises(ormar.MultipleMatches):
await User.objects.get() await User.objects.get()
same_user = await User.objects.get(pk=user.id) same_user = await User.objects.get(pk=user.id)
assert same_user.id == user.id assert same_user.id == user.id
assert same_user.pk == user.pk assert same_user.pk == user.pk
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_filter(): async def test_model_filter():
async with database: async with database:
await User.objects.create(name="Tom") async with database.transaction(force_rollback=True):
await User.objects.create(name="Jane") await User.objects.create(name="Tom")
await User.objects.create(name="Lucy") await User.objects.create(name="Jane")
await User.objects.create(name="Lucy")
user = await User.objects.get(name="Lucy") user = await User.objects.get(name="Lucy")
assert user.name == "Lucy" assert user.name == "Lucy"
with pytest.raises(ormar.NoMatch): with pytest.raises(ormar.NoMatch):
await User.objects.get(name="Jim") await User.objects.get(name="Jim")
await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) await Product.objects.create(name="T-Shirt", rating=5, in_stock=True)
await Product.objects.create(name="Dress", rating=4) await Product.objects.create(name="Dress", rating=4)
await Product.objects.create(name="Coat", rating=3, in_stock=True) await Product.objects.create(name="Coat", rating=3, in_stock=True)
product = await Product.objects.get(name__iexact="t-shirt", rating=5) product = await Product.objects.get(name__iexact="t-shirt", rating=5)
assert product.pk is not None assert product.pk is not None
assert product.name == "T-Shirt" assert product.name == "T-Shirt"
assert product.rating == 5 assert product.rating == 5
products = await Product.objects.all(rating__gte=2, in_stock=True) products = await Product.objects.all(rating__gte=2, in_stock=True)
assert len(products) == 2 assert len(products) == 2
products = await Product.objects.all(name__icontains="T") products = await Product.objects.all(name__icontains="T")
assert len(products) == 2 assert len(products) == 2
# Test escaping % character from icontains, contains, and iexact # Test escaping % character from icontains, contains, and iexact
await Product.objects.create(name="100%-Cotton", rating=3) await Product.objects.create(name="100%-Cotton", rating=3)
await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) await Product.objects.create(name="Cotton-100%-Egyptian", rating=3)
await Product.objects.create(name="Cotton-100%", rating=3) await Product.objects.create(name="Cotton-100%", rating=3)
products = Product.objects.filter(name__iexact="100%-cotton") products = Product.objects.filter(name__iexact="100%-cotton")
assert await products.count() == 1 assert await products.count() == 1
products = Product.objects.filter(name__contains="%") products = Product.objects.filter(name__contains="%")
assert await products.count() == 3 assert await products.count() == 3
products = Product.objects.filter(name__icontains="%") products = Product.objects.filter(name__icontains="%")
assert await products.count() == 3 assert await products.count() == 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wrong_query_contains_model(): async def test_wrong_query_contains_model():
with pytest.raises(QueryDefinitionError): async with database:
product = Product(name="90%-Cotton", rating=2) with pytest.raises(QueryDefinitionError):
await Product.objects.filter(name__contains=product).count() product = Product(name="90%-Cotton", rating=2)
await Product.objects.filter(name__contains=product).count()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_exists(): async def test_model_exists():
async with database: async with database:
await User.objects.create(name="Tom") async with database.transaction(force_rollback=True):
assert await User.objects.filter(name="Tom").exists() is True await User.objects.create(name="Tom")
assert await User.objects.filter(name="Jane").exists() is False assert await User.objects.filter(name="Tom").exists() is True
assert await User.objects.filter(name="Jane").exists() is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_count(): async def test_model_count():
async with database: async with database:
await User.objects.create(name="Tom") async with database.transaction(force_rollback=True):
await User.objects.create(name="Jane") await User.objects.create(name="Tom")
await User.objects.create(name="Lucy") await User.objects.create(name="Jane")
await User.objects.create(name="Lucy")
assert await User.objects.count() == 3 assert await User.objects.count() == 3
assert await User.objects.filter(name__icontains="T").count() == 1 assert await User.objects.filter(name__icontains="T").count() == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_limit(): async def test_model_limit():
async with database: async with database:
await User.objects.create(name="Tom") async with database.transaction(force_rollback=True):
await User.objects.create(name="Jane") await User.objects.create(name="Tom")
await User.objects.create(name="Lucy") await User.objects.create(name="Jane")
await User.objects.create(name="Lucy")
assert len(await User.objects.limit(2).all()) == 2 assert len(await User.objects.limit(2).all()) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_limit_with_filter(): async def test_model_limit_with_filter():
async with database: async with database:
await User.objects.create(name="Tom") async with database.transaction(force_rollback=True):
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
await User.objects.create(name="Tom")
assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_offset(): async def test_offset():
async with database: async with database:
await User.objects.create(name="Tom") async with database.transaction(force_rollback=True):
await User.objects.create(name="Jane") await User.objects.create(name="Tom")
await User.objects.create(name="Jane")
users = await User.objects.offset(1).limit(1).all() users = await User.objects.offset(1).limit(1).all()
assert users[0].name == "Jane" assert users[0].name == "Jane"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_first(): async def test_model_first():
async with database: async with database:
tom = await User.objects.create(name="Tom") async with database.transaction(force_rollback=True):
jane = await User.objects.create(name="Jane") tom = await User.objects.create(name="Tom")
jane = await User.objects.create(name="Jane")
assert await User.objects.first() == tom assert await User.objects.first() == tom
assert await User.objects.first(name="Jane") == jane assert await User.objects.first(name="Jane") == jane
assert await User.objects.filter(name="Jane").first() == jane assert await User.objects.filter(name="Jane").first() == jane
with pytest.raises(NoMatch): with pytest.raises(NoMatch):
await User.objects.filter(name="Lucy").first() await User.objects.filter(name="Lucy").first()

View File

@ -112,7 +112,7 @@ def test_all_endpoints():
assert items[0].name == "New name" assert items[0].name == "New name"
response = client.delete(f"/items/{item.pk}", json=item.dict()) response = client.delete(f"/items/{item.pk}", json=item.dict())
assert response.json().get("deleted_rows") == 1 assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
response = client.get("/items/") response = client.get("/items/")
items = response.json() items = response.json()
assert len(items) == 0 assert len(items) == 0

View File

@ -78,6 +78,11 @@ async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine) metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
yield
metadata.drop_all(engine)
async def create_data():
department = await Department.objects.create(id=1, name="Math Department") department = await Department.objects.create(id=1, name="Math Department")
department2 = await Department.objects.create(id=2, name="Law Department") department2 = await Department.objects.create(id=2, name="Law Department")
class1 = await SchoolClass.objects.create(name="Math") class1 = await SchoolClass.objects.create(name="Math")
@ -88,13 +93,11 @@ async def create_test_database():
await Student.objects.create(name="Judy", category=category2, schoolclass=class1) await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(): async def test_model_multiple_instances_of_same_table_in_schema():
async with database: async with database:
await create_data()
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category__department", "students"] ["teachers__category__department", "students"]
).all() ).all()

View File

@ -78,6 +78,11 @@ async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine) metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
yield
metadata.drop_all(engine)
async def create_data():
department = await Department.objects.create(id=1, name="Math Department") department = await Department.objects.create(id=1, name="Math Department")
department2 = await Department.objects.create(id=2, name="Law Department") department2 = await Department.objects.create(id=2, name="Law Department")
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)
@ -88,52 +93,56 @@ async def create_test_database():
await Student.objects.create(name="Judy", category=category2, schoolclass=class1) await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(): async def test_model_multiple_instances_of_same_table_in_schema():
async with database: async with database:
classes = await SchoolClass.objects.select_related( async with database.transaction(force_rollback=True):
["teachers__category", "students"] await create_data()
).all() classes = await SchoolClass.objects.select_related(
assert classes[0].name == "Math" ["teachers__category", "students"]
assert classes[0].students[0].name == "Jane" ).all()
assert classes[0].name == "Math"
assert classes[0].students[0].name == "Jane"
assert len(classes[0].dict().get("students")) == 2 assert len(classes[0].dict().get("students")) == 2
# since it's going from schoolclass => teacher => schoolclass (same class) department is already populated # since it's going from schoolclass => teacher => schoolclass (same class) department is already populated
assert classes[0].students[0].schoolclass.name == "Math" assert classes[0].students[0].schoolclass.name == "Math"
assert classes[0].students[0].schoolclass.department.name is None assert classes[0].students[0].schoolclass.department.name is None
await classes[0].students[0].schoolclass.department.load() await classes[0].students[0].schoolclass.department.load()
assert classes[0].students[0].schoolclass.department.name == "Math Department" assert classes[0].students[0].schoolclass.department.name == "Math Department"
await classes[1].students[0].schoolclass.department.load() await classes[1].students[0].schoolclass.department.load()
assert classes[1].students[0].schoolclass.department.name == "Law Department" assert classes[1].students[0].schoolclass.department.name == "Law Department"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_right_tables_join(): async def test_right_tables_join():
async with database: async with database:
classes = await SchoolClass.objects.select_related( async with database.transaction(force_rollback=True):
["teachers__category", "students"] await create_data()
).all() classes = await SchoolClass.objects.select_related(
assert classes[0].teachers[0].category.name == "Domestic" ["teachers__category", "students"]
).all()
assert classes[0].teachers[0].category.name == "Domestic"
assert classes[0].students[0].category.name is None assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load() await classes[0].students[0].category.load()
assert classes[0].students[0].category.name == "Foreign" assert classes[0].students[0].category.name == "Foreign"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_reverse_related_objects(): async def test_multiple_reverse_related_objects():
async with database: async with database:
classes = await SchoolClass.objects.select_related( async with database.transaction(force_rollback=True):
["teachers__category", "students__category"] await create_data()
).all() classes = await SchoolClass.objects.select_related(
assert classes[0].name == "Math" ["teachers__category", "students__category"]
assert classes[0].students[1].name == "Judy" ).all()
assert classes[0].students[0].category.name == "Foreign" assert classes[0].name == "Math"
assert classes[0].students[1].category.name == "Domestic" assert classes[0].students[1].name == "Judy"
assert classes[0].teachers[0].category.name == "Domestic" assert classes[0].students[0].category.name == "Foreign"
assert classes[0].students[1].category.name == "Domestic"
assert classes[0].teachers[0].category.name == "Domestic"