Bump supported fastapi versions (#1110)

* Bump supported fastapi version to <=0.97, change all fastapi tests from starlette client to httpx.AsyncClient

* Add lifecycle manager to fastapi tests

* Fix coverage

* Add python 3.11 to test suite, bump version
This commit is contained in:
collerek
2023-06-18 18:52:06 +02:00
committed by GitHub
parent e72e40dd6c
commit b1ab0de4d4
27 changed files with 733 additions and 587 deletions

View File

@ -17,7 +17,7 @@ jobs:
if: github.event_name == 'push' || github.event.pull_request.head.repo.full_name != 'collerek/ormar' if: github.event_name == 'push' || github.event.pull_request.head.repo.full_name != 'collerek/ormar'
strategy: strategy:
matrix: matrix:
python-version: [3.7, 3.8, 3.9, "3.10"] python-version: [3.7, 3.8, 3.9, "3.10", 3.11]
fail-fast: false fail-fast: false
services: services:
mysql: mysql:
@ -49,7 +49,7 @@ jobs:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install poetry==1.4.1 python -m pip install poetry==1.4.2
poetry install --extras "all" poetry install --extras "all"
env: env:
POETRY_VIRTUALENVS_CREATE: false POETRY_VIRTUALENVS_CREATE: false

View File

@ -21,7 +21,7 @@ async def test_initializing_models(aio_benchmark, num_models: int):
] ]
assert len(authors) == num_models assert len(authors) == num_models
initialize_models(num_models) await initialize_models(num_models)
@pytest.mark.parametrize("num_models", [10, 20, 40]) @pytest.mark.parametrize("num_models", [10, 20, 40])
@ -30,7 +30,7 @@ async def test_initializing_models_with_related_models(aio_benchmark, num_models
async def initialize_models_with_related_models( async def initialize_models_with_related_models(
author: Author, publisher: Publisher, num_models: int author: Author, publisher: Publisher, num_models: int
): ):
books = [ _ = [
Book( Book(
author=author, author=author,
publisher=publisher, publisher=publisher,
@ -43,6 +43,6 @@ async def test_initializing_models_with_related_models(aio_benchmark, num_models
author = await Author(name="Author", score=10).save() author = await Author(name="Author", score=10).save()
publisher = await Publisher(name="Publisher", prestige=random.randint(0, 10)).save() publisher = await Publisher(name="Publisher", prestige=random.randint(0, 10)).save()
ids = initialize_models_with_related_models( _ = initialize_models_with_related_models(
author=author, publisher=publisher, num_models=num_models author=author, publisher=publisher, num_models=num_models
) )

View File

@ -18,7 +18,7 @@ async def test_updating_models_individually(
@aio_benchmark @aio_benchmark
async def update(authors: List[Author]): async def update(authors: List[Author]):
for author in authors: for author in authors:
a = await author.update( _ = await author.update(
name="".join(random.sample(string.ascii_letters, 5)) name="".join(random.sample(string.ascii_letters, 5))
) )

View File

@ -18,7 +18,6 @@ from typing import (
Union, Union,
cast, cast,
) )
import functools
import databases import databases
import pydantic import pydantic
@ -242,6 +241,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
:param new_hash: The hash to update to :param new_hash: The hash to update to
:type new_hash: int :type new_hash: int
""" """
def _update_cache(relations: List[Relation], recurse: bool = True) -> None: def _update_cache(relations: List[Relation], recurse: bool = True) -> None:
for relation in relations: for relation in relations:
relation_proxy = relation.get() relation_proxy = relation.get()
@ -249,7 +249,10 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if hasattr(relation_proxy, "update_cache"): if hasattr(relation_proxy, "update_cache"):
relation_proxy.update_cache(prev_hash, new_hash) # type: ignore relation_proxy.update_cache(prev_hash, new_hash) # type: ignore
elif recurse and hasattr(relation_proxy, "_orm"): elif recurse and hasattr(relation_proxy, "_orm"):
_update_cache(relation_proxy._orm._relations.values(), recurse=False) # type: ignore _update_cache(
relation_proxy._orm._relations.values(), # type: ignore
recurse=False,
)
_update_cache(list(self._orm._relations.values())) _update_cache(list(self._orm._relations.values()))

View File

@ -133,8 +133,8 @@ class Relation(Generic[T]):
return None return None
else: else:
# We need to clear the weakrefs that don't point to anything anymore # We need to clear the weakrefs that don't point to anything anymore
# There's an assumption here that if some of the related models went out of scope, # There's an assumption here that if some of the related models
# then they all did, so we can just check the first one # went out of scope, then they all did, so we can just check the first one
try: try:
self.related_models[0].__repr__.__self__ self.related_models[0].__repr__.__self__
return self.related_models.index(child) return self.related_models.index(child)

View File

@ -145,7 +145,8 @@ class RelationProxy(Generic[T], List[T]):
""" """
item = self[index] item = self[index]
# Try to delete it, but do it the long way if weakly-referenced thing doesn't exist # Try to delete it, but do it a long way
# if weakly-referenced thing doesn't exist
try: try:
self._relation_cache.pop(item.__hash__()) self._relation_cache.pop(item.__hash__())
except ReferenceError: except ReferenceError:

812
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -3,8 +3,8 @@ name = "ormar"
[tool.poetry] [tool.poetry]
name = "ormar" name = "ormar"
version = "0.12.1" version = "0.12.2"
description = "A simple async ORM with fastapi in mind and pydantic validation." description = "An async ORM with fastapi in mind and pydantic validation."
authors = ["Radosław Drążkiewicz <collerek@gmail.com>"] authors = ["Radosław Drążkiewicz <collerek@gmail.com>"]
license = "MIT" license = "MIT"
readme = "README.md" readme = "README.md"
@ -56,6 +56,7 @@ psycopg2-binary = { version = "^2.9.1", optional = true }
mysqlclient = { version = "^2.1.0", optional = true } mysqlclient = { version = "^2.1.0", optional = true }
PyMySQL = { version = ">=0.9", optional = true } PyMySQL = { version = ">=0.9", optional = true }
[tool.poetry.dependencies.orjson] [tool.poetry.dependencies.orjson]
version = ">=3.6.4" version = ">=3.6.4"
optional = true optional = true
@ -75,7 +76,7 @@ pytest = "^7.3.1"
pytest-cov = "^4.0.0" pytest-cov = "^4.0.0"
codecov = "^2.1.13" codecov = "^2.1.13"
pytest-asyncio = "^0.21.0" pytest-asyncio = "^0.21.0"
fastapi = ">=0.70.1,<0.86" fastapi = ">=0.70.1,<=0.97"
flake8 = "^3.9.2" flake8 = "^3.9.2"
flake8-black = "^0.3.6" flake8-black = "^0.3.6"
flake8-bugbear = "^23.3.12" flake8-bugbear = "^23.3.12"
@ -137,6 +138,10 @@ all = [
"cryptography", "cryptography",
] ]
[tool.poetry.group.dev.dependencies]
httpx = "^0.24.1"
asgi-lifespan = "^2.1.0"
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@ -6,8 +6,9 @@ import databases
import pydantic import pydantic
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from ormar import post_save, property_field from ormar import post_save, property_field
@ -153,34 +154,35 @@ async def create_user7(user: RandomModel):
return await user.save() return await user.save()
def test_excluding_fields_in_endpoints(): @pytest.mark.asyncio
client = TestClient(app) async def test_excluding_fields_in_endpoints():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
user = { user = {
"email": "test@domain.com", "email": "test@domain.com",
"password": "^*^%A*DA*IAAA", "password": "^*^%A*DA*IAAA",
"first_name": "John", "first_name": "John",
"last_name": "Doe", "last_name": "Doe",
} }
response = client.post("/users/", json=user) response = await client.post("/users/", json=user)
created_user = User(**response.json()) created_user = User(**response.json())
assert created_user.pk is not None assert created_user.pk is not None
assert created_user.password is None assert created_user.password is None
user2 = {"email": "test@domain.com", "first_name": "John", "last_name": "Doe"} user2 = {"email": "test@domain.com", "first_name": "John", "last_name": "Doe"}
response = client.post("/users/", json=user2) response = await client.post("/users/", json=user2)
created_user = User(**response.json()) created_user = User(**response.json())
assert created_user.pk is not None assert created_user.pk is not None
assert created_user.password is None assert created_user.password is None
response = client.post("/users2/", json=user) response = await client.post("/users2/", json=user)
created_user2 = User(**response.json()) created_user2 = User(**response.json())
assert created_user2.pk is not None assert created_user2.pk is not None
assert created_user2.password is None assert created_user2.password is None
# response has only 3 fields from UserBase # response has only 3 fields from UserBase
response = client.post("/users3/", json=user) response = await client.post("/users3/", json=user)
assert list(response.json().keys()) == ["email", "first_name", "last_name"] assert list(response.json().keys()) == ["email", "first_name", "last_name"]
timestamp = datetime.datetime.now() timestamp = datetime.datetime.now()
@ -192,7 +194,7 @@ def test_excluding_fields_in_endpoints():
"last_name": "Doe", "last_name": "Doe",
"timestamp": str(timestamp), "timestamp": str(timestamp),
} }
response = client.post("/users4/", json=user3) response = await client.post("/users4/", json=user3)
assert list(response.json().keys()) == [ assert list(response.json().keys()) == [
"id", "id",
"email", "email",
@ -209,7 +211,7 @@ def test_excluding_fields_in_endpoints():
assert isinstance(user_instance.timestamp, datetime.datetime) assert isinstance(user_instance.timestamp, datetime.datetime)
assert user_instance.timestamp == timestamp assert user_instance.timestamp == timestamp
response = client.post("/users4/", json=user3) response = await client.post("/users4/", json=user3)
assert list(response.json().keys()) == [ assert list(response.json().keys()) == [
"id", "id",
"email", "email",
@ -226,11 +228,12 @@ def test_excluding_fields_in_endpoints():
) )
def test_adding_fields_in_endpoints(): @pytest.mark.asyncio
client = TestClient(app) async def test_adding_fields_in_endpoints():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
user3 = {"last_name": "Test", "full_name": "deleted"} user3 = {"last_name": "Test", "full_name": "deleted"}
response = client.post("/random/", json=user3) response = await client.post("/random/", json=user3)
assert list(response.json().keys()) == [ assert list(response.json().keys()) == [
"id", "id",
"password", "password",
@ -242,7 +245,7 @@ def test_adding_fields_in_endpoints():
assert response.json().get("full_name") == "John Test" assert response.json().get("full_name") == "John Test"
user3 = {"last_name": "Test"} user3 = {"last_name": "Test"}
response = client.post("/random/", json=user3) response = await client.post("/random/", json=user3)
assert list(response.json().keys()) == [ assert list(response.json().keys()) == [
"id", "id",
"password", "password",
@ -254,11 +257,12 @@ def test_adding_fields_in_endpoints():
assert response.json().get("full_name") == "John Test" assert response.json().get("full_name") == "John Test"
def test_adding_fields_in_endpoints2(): @pytest.mark.asyncio
client = TestClient(app) async def test_adding_fields_in_endpoints2():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
user3 = {"last_name": "Test"} user3 = {"last_name": "Test"}
response = client.post("/random2/", json=user3) response = await client.post("/random2/", json=user3)
assert list(response.json().keys()) == [ assert list(response.json().keys()) == [
"id", "id",
"password", "password",
@ -270,7 +274,8 @@ def test_adding_fields_in_endpoints2():
assert response.json().get("full_name") == "John Test" assert response.json().get("full_name") == "John Test"
def test_excluding_property_field_in_endpoints2(): @pytest.mark.asyncio
async def test_excluding_property_field_in_endpoints2():
dummy_registry = {} dummy_registry = {}
@ -278,10 +283,10 @@ def test_excluding_property_field_in_endpoints2():
async def after_save(sender, instance, **kwargs): async def after_save(sender, instance, **kwargs):
dummy_registry[instance.pk] = instance.dict() dummy_registry[instance.pk] = instance.dict()
client = TestClient(app) client = AsyncClient(app=app, base_url="http://testserver")
with client as client: async with client as client, LifespanManager(app):
user3 = {"last_name": "Test"} user3 = {"last_name": "Test"}
response = client.post("/random3/", json=user3) response = await client.post("/random3/", json=user3)
assert list(response.json().keys()) == [ assert list(response.json().keys()) == [
"id", "id",
"password", "password",

View File

@ -6,8 +6,9 @@ from typing import List
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -78,16 +79,17 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
def test_read_main(): @pytest.mark.asyncio
client = TestClient(app) async def test_read_main():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
response = client.post( async with client as client, LifespanManager(app):
response = await client.post(
"/things", "/things",
data=json.dumps({"bt": base64.b64encode(blob3).decode()}), json={"bt": base64.b64encode(blob3).decode()},
headers=headers, headers=headers,
) )
assert response.status_code == 200 assert response.status_code == 200
response = client.get("/things") response = await client.get("/things")
assert response.json()[0]["bt"] == base64.b64encode(blob3).decode() assert response.json()[0]["bt"] == base64.b64encode(blob3).decode()
thing = BinaryThing(**response.json()[0]) thing = BinaryThing(**response.json()[0])
assert thing.__dict__["bt"] == blob3 assert thing.__dict__["bt"] == blob3

View File

@ -7,8 +7,9 @@ import databases
import pydantic import pydantic
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -93,16 +94,17 @@ async def create_item(item: Organisation):
return item return item
def test_all_endpoints(): @pytest.mark.asyncio
client = TestClient(app) async def test_all_endpoints():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
response = client.post( async with client as client, LifespanManager(app):
response = await client.post(
"/items/", "/items/",
json={"id": 1, "ident": "", "priority": 4, "expire_date": "2022-05-01"}, json={"id": 1, "ident": "", "priority": 4, "expire_date": "2022-05-01"},
) )
assert response.status_code == 422 assert response.status_code == 422
response = client.post( response = await client.post(
"/items/", "/items/",
json={ json={
"id": 1, "id": 1,
@ -124,7 +126,7 @@ def test_all_endpoints():
assert response.status_code == 200 assert response.status_code == 200
item = Organisation(**response.json()) item = Organisation(**response.json())
assert item.pk is not None assert item.pk is not None
response = client.get("/docs/") response = await client.get("/docs")
assert response.status_code == 200 assert response.status_code == 200
assert b"<title>FastAPI - Swagger UI</title>" in response.content assert b"<title>FastAPI - Swagger UI</title>" in response.content

View File

@ -2,9 +2,11 @@ from typing import Optional
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import databases import databases
import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
@ -60,10 +62,11 @@ async def get_cb2(): # pragma: no cover
return None return None
def test_all_endpoints(): @pytest.mark.asyncio
client = TestClient(app) async def test_all_endpoints():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
response = client.get("/openapi.json") async with client as client, LifespanManager(app):
response = await client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
schema = response.json() schema = response.json()
components = schema["components"]["schemas"] components = schema["components"]["schemas"]

View File

@ -1,7 +1,8 @@
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
from tests.test_inheritance_and_pydantic_generation.test_geting_pydantic_models import ( from tests.test_inheritance_and_pydantic_generation.test_geting_pydantic_models import (
@ -68,11 +69,12 @@ async def get_selfref(ref_id: int):
return selfr return selfr
def test_read_main(): @pytest.mark.asyncio
client = TestClient(app) async def test_read_main():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
test_category = dict(name="Foo", id=12) test_category = dict(name="Foo", id=12)
response = client.post("/categories/", json=test_category) response = await client.post("/categories/", json=test_category)
assert response.status_code == 200 assert response.status_code == 200
cat = Category(**response.json()) cat = Category(**response.json())
assert cat.name == "Foo" assert cat.name == "Foo"
@ -83,7 +85,7 @@ def test_read_main():
test_selfref2 = dict(name="test2", parent={"id": 1}) test_selfref2 = dict(name="test2", parent={"id": 1})
test_selfref3 = dict(name="test3", children=[{"name": "aaa"}]) test_selfref3 = dict(name="test3", children=[{"name": "aaa"}])
response = client.post("/selfrefs/", json=test_selfref) response = await client.post("/selfrefs/", json=test_selfref)
assert response.status_code == 200 assert response.status_code == 200
self_ref = SelfRef(**response.json()) self_ref = SelfRef(**response.json())
assert self_ref.id == 1 assert self_ref.id == 1
@ -91,7 +93,7 @@ def test_read_main():
assert self_ref.parent is None assert self_ref.parent is None
assert self_ref.children == [] assert self_ref.children == []
response = client.post("/selfrefs/", json=test_selfref2) response = await client.post("/selfrefs/", json=test_selfref2)
assert response.status_code == 200 assert response.status_code == 200
self_ref = SelfRef(**response.json()) self_ref = SelfRef(**response.json())
assert self_ref.id == 2 assert self_ref.id == 2
@ -99,7 +101,7 @@ def test_read_main():
assert self_ref.parent is None assert self_ref.parent is None
assert self_ref.children == [] assert self_ref.children == []
response = client.post("/selfrefs/", json=test_selfref3) response = await client.post("/selfrefs/", json=test_selfref3)
assert response.status_code == 200 assert response.status_code == 200
self_ref = SelfRef(**response.json()) self_ref = SelfRef(**response.json())
assert self_ref.id == 3 assert self_ref.id == 3
@ -107,7 +109,7 @@ def test_read_main():
assert self_ref.parent is None assert self_ref.parent is None
assert self_ref.children[0].dict() == {"id": 4} assert self_ref.children[0].dict() == {"id": 4}
response = client.get("/selfrefs/3/") response = await client.get("/selfrefs/3/")
assert response.status_code == 200 assert response.status_code == 200
check_children = SelfRef(**response.json()) check_children = SelfRef(**response.json())
assert check_children.children[0].dict() == { assert check_children.children[0].dict() == {
@ -117,7 +119,7 @@ def test_read_main():
"parent": {"id": 3, "name": "test3"}, "parent": {"id": 3, "name": "test3"},
} }
response = client.get("/selfrefs/2/") response = await client.get("/selfrefs/2/")
assert response.status_code == 200 assert response.status_code == 200
check_children = SelfRef(**response.json()) check_children = SelfRef(**response.json())
assert check_children.dict() == { assert check_children.dict() == {

View File

@ -3,8 +3,9 @@ from typing import List
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -103,26 +104,27 @@ async def get_item_excl(item_id: int):
return item return item
def test_all_endpoints(): @pytest.mark.asyncio
client = TestClient(app) async def test_all_endpoints():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
item = { item = {
"name": "test", "name": "test",
"categories": [{"name": "test cat"}, {"name": "test cat2"}], "categories": [{"name": "test cat"}, {"name": "test cat2"}],
} }
response = client.post("/items/", json=item) response = await client.post("/items/", json=item)
item_check = Item(**response.json()) item_check = Item(**response.json())
assert item_check.id is not None assert item_check.id is not None
assert item_check.categories[0].id is not None assert item_check.categories[0].id is not None
no_pk_item = client.get(f"/items/{item_check.id}", json=item).json() no_pk_item = (await client.get(f"/items/{item_check.id}")).json()
assert no_pk_item == item assert no_pk_item == item
no_pk_item2 = client.get(f"/items/fex/{item_check.id}", json=item).json() no_pk_item2 = (await client.get(f"/items/fex/{item_check.id}")).json()
assert no_pk_item2 == item assert no_pk_item2 == item
no_pk_category = client.get( no_pk_category = (
f"/categories/{item_check.categories[0].id}", json=item await client.get(f"/categories/{item_check.categories[0].id}")
).json() ).json()
assert no_pk_category == { assert no_pk_category == {
"items": [ "items": [
@ -134,8 +136,8 @@ def test_all_endpoints():
"name": "test cat", "name": "test cat",
} }
no_through_category = client.get( no_through_category = (
f"/categories/nt/{item_check.categories[0].id}", json=item await client.get(f"/categories/nt/{item_check.categories[0].id}")
).json() ).json()
assert no_through_category == { assert no_through_category == {
"id": 1, "id": 1,
@ -143,7 +145,7 @@ def test_all_endpoints():
"name": "test cat", "name": "test cat",
} }
no_through_category = client.get( no_through_category = (
f"/categories/ntp/{item_check.categories[0].id}", json=item await client.get(f"/categories/ntp/{item_check.categories[0].id}")
).json() ).json()
assert no_through_category == {"items": [{"name": "test"}], "name": "test cat"} assert no_through_category == {"items": [{"name": "test"}], "name": "test cat"}

View File

@ -3,8 +3,9 @@ import json
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from ormar import Extra from ormar import Extra
@ -53,11 +54,12 @@ async def create_item(item: Item):
return await item.save() return await item.save()
def test_extra_parameters_in_request(): @pytest.mark.asyncio
client = TestClient(app) async def test_extra_parameters_in_request():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
data = {"name": "Name", "extraname": "to ignore"} data = {"name": "Name", "extraname": "to ignore"}
resp = client.post("item/", data=json.dumps(data)) resp = await client.post("item/", json=data)
assert resp.status_code == 200 assert resp.status_code == 200
assert "name" in resp.json() assert "name" in resp.json()
assert resp.json().get("name") == "Name" assert resp.json().get("name") == "Name"

View File

@ -5,8 +5,9 @@ import databases
import pydantic import pydantic
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -98,37 +99,38 @@ async def create_category(category: Category):
return category return category
def test_all_endpoints(): @pytest.mark.asyncio
client = TestClient(app) async def test_all_endpoints():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
response = client.post("/categories/", json={"name": "test cat"}) async with client as client, LifespanManager(app):
response = await client.post("/categories/", json={"name": "test cat"})
category = response.json() category = response.json()
response = client.post("/categories/", json={"name": "test cat2"}) response = await client.post("/categories/", json={"name": "test cat2"})
category2 = response.json() category2 = response.json()
response = client.post("/items/", json={"name": "test", "id": 1}) response = await client.post("/items/", json={"name": "test", "id": 1})
item = Item(**response.json()) item = Item(**response.json())
assert item.pk is not None assert item.pk is not None
response = client.post( response = await client.post(
"/items/add_category/", json={"item": item.dict(), "category": category} "/items/add_category/", json={"item": item.dict(), "category": category}
) )
item = Item(**response.json()) item = Item(**response.json())
assert len(item.categories) == 1 assert len(item.categories) == 1
assert item.categories[0].name == "test cat" assert item.categories[0].name == "test cat"
client.post( await client.post(
"/items/add_category/", json={"item": item.dict(), "category": category2} "/items/add_category/", json={"item": item.dict(), "category": category2}
) )
response = client.get("/items/") response = await client.get("/items/")
items = [Item(**item) for item in response.json()] items = [Item(**item) for item in response.json()]
assert items[0] == item assert items[0] == item
assert len(items[0].categories) == 2 assert len(items[0].categories) == 2
assert items[0].categories[0].name == "test cat" assert items[0].categories[0].name == "test cat"
assert items[0].categories[1].name == "test cat2" assert items[0].categories[1].name == "test cat2"
response = client.get("/docs/") response = await client.get("/docs")
assert response.status_code == 200 assert response.status_code == 200
assert b"<title>FastAPI - Swagger UI</title>" in response.content assert b"<title>FastAPI - Swagger UI</title>" in response.content

View File

@ -1,9 +1,11 @@
from typing import Optional from typing import Optional
import databases import databases
import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -40,10 +42,11 @@ async def create_item(item: Item):
return item return item
def test_read_main(): @pytest.mark.asyncio
client = TestClient(app) async def test_read_main():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
response = client.post( async with client as client, LifespanManager(app):
response = await client.post(
"/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}}
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@ -3,8 +3,9 @@ from typing import List
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
from tests.test_inheritance_and_pydantic_generation.test_inheritance_concrete import ( # type: ignore from tests.test_inheritance_and_pydantic_generation.test_inheritance_concrete import ( # type: ignore
@ -118,13 +119,14 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
def test_read_main(): @pytest.mark.asyncio
client = TestClient(app) async def test_read_main():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max") test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max")
test_subject = dict(name="Bar") test_subject = dict(name="Bar")
response = client.post("/categories/", json=test_category) response = await client.post("/categories/", json=test_category)
assert response.status_code == 200 assert response.status_code == 200
cat = Category(**response.json()) cat = Category(**response.json())
assert cat.name == "Foo" assert cat.name == "Foo"
@ -140,7 +142,7 @@ def test_read_main():
"%Y-%m-%d %H:%M:%S.%f" "%Y-%m-%d %H:%M:%S.%f"
) )
test_subject["category"] = cat_dict test_subject["category"] = cat_dict
response = client.post("/subjects/", json=test_subject) response = await client.post("/subjects/", json=test_subject)
assert response.status_code == 200 assert response.status_code == 200
sub = Subject(**response.json()) sub = Subject(**response.json())
assert sub.name == "Bar" assert sub.name == "Bar"
@ -148,11 +150,12 @@ def test_read_main():
assert isinstance(sub.updated_date, datetime.datetime) assert isinstance(sub.updated_date, datetime.datetime)
def test_inheritance_with_relation(): @pytest.mark.asyncio
client = TestClient(app) async def test_inheritance_with_relation():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
sam = Person(**client.post("/persons/", json={"name": "Sam"}).json()) async with client as client, LifespanManager(app):
joe = Person(**client.post("/persons/", json={"name": "Joe"}).json()) sam = Person(**(await client.post("/persons/", json={"name": "Sam"})).json())
joe = Person(**(await client.post("/persons/", json={"name": "Joe"})).json())
truck_dict = dict( truck_dict = dict(
name="Shelby wanna be", name="Shelby wanna be",
@ -163,8 +166,8 @@ def test_inheritance_with_relation():
bus_dict = dict( bus_dict = dict(
name="Unicorn", max_persons=50, owner=sam.dict(), co_owner=joe.dict() name="Unicorn", max_persons=50, owner=sam.dict(), co_owner=joe.dict()
) )
unicorn = Bus(**client.post("/buses/", json=bus_dict).json()) unicorn = Bus(**(await client.post("/buses/", json=bus_dict)).json())
shelby = Truck(**client.post("/trucks/", json=truck_dict).json()) shelby = Truck(**(await client.post("/trucks/", json=truck_dict)).json())
assert shelby.name == "Shelby wanna be" assert shelby.name == "Shelby wanna be"
assert shelby.owner.name == "Sam" assert shelby.owner.name == "Sam"
@ -178,36 +181,43 @@ def test_inheritance_with_relation():
assert unicorn.co_owner.name == "Joe" assert unicorn.co_owner.name == "Joe"
assert unicorn.max_persons == 50 assert unicorn.max_persons == 50
unicorn2 = Bus(**client.get(f"/buses/{unicorn.pk}").json()) unicorn2 = Bus(**(await client.get(f"/buses/{unicorn.pk}")).json())
assert unicorn2.name == "Unicorn" assert unicorn2.name == "Unicorn"
assert unicorn2.owner == sam assert unicorn2.owner == sam
assert unicorn2.owner.name == "Sam" assert unicorn2.owner.name == "Sam"
assert unicorn2.co_owner.name == "Joe" assert unicorn2.co_owner.name == "Joe"
assert unicorn2.max_persons == 50 assert unicorn2.max_persons == 50
buses = [Bus(**x) for x in client.get("/buses/").json()] buses = [Bus(**x) for x in (await client.get("/buses/")).json()]
assert len(buses) == 1 assert len(buses) == 1
assert buses[0].name == "Unicorn" assert buses[0].name == "Unicorn"
def test_inheritance_with_m2m_relation(): @pytest.mark.asyncio
client = TestClient(app) async def test_inheritance_with_m2m_relation():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
sam = Person(**client.post("/persons/", json={"name": "Sam"}).json()) async with client as client, LifespanManager(app):
joe = Person(**client.post("/persons/", json={"name": "Joe"}).json()) sam = Person(**(await client.post("/persons/", json={"name": "Sam"})).json())
alex = Person(**client.post("/persons/", json={"name": "Alex"}).json()) joe = Person(**(await client.post("/persons/", json={"name": "Joe"})).json())
alex = Person(**(await client.post("/persons/", json={"name": "Alex"})).json())
truck_dict = dict(name="Shelby wanna be", max_capacity=2000, owner=sam.dict()) truck_dict = dict(name="Shelby wanna be", max_capacity=2000, owner=sam.dict())
bus_dict = dict(name="Unicorn", max_persons=80, owner=sam.dict()) bus_dict = dict(name="Unicorn", max_persons=80, owner=sam.dict())
unicorn = Bus2(**client.post("/buses2/", json=bus_dict).json()) unicorn = Bus2(**(await client.post("/buses2/", json=bus_dict)).json())
shelby = Truck2(**client.post("/trucks2/", json=truck_dict).json()) shelby = Truck2(**(await client.post("/trucks2/", json=truck_dict)).json())
unicorn = Bus2( unicorn = Bus2(
**client.post(f"/buses2/{unicorn.pk}/add_coowner/", json=joe.dict()).json() **(
await client.post(f"/buses2/{unicorn.pk}/add_coowner/", json=joe.dict())
).json()
) )
unicorn = Bus2( unicorn = Bus2(
**client.post(f"/buses2/{unicorn.pk}/add_coowner/", json=alex.dict()).json() **(
await client.post(
f"/buses2/{unicorn.pk}/add_coowner/", json=alex.dict()
)
).json()
) )
assert shelby.name == "Shelby wanna be" assert shelby.name == "Shelby wanna be"
@ -222,10 +232,12 @@ def test_inheritance_with_m2m_relation():
assert unicorn.co_owners[1] == alex assert unicorn.co_owners[1] == alex
assert unicorn.max_persons == 80 assert unicorn.max_persons == 80
client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=alex.dict()) await client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=alex.dict())
shelby = Truck2( shelby = Truck2(
**client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=joe.dict()).json() **(
await client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=joe.dict())
).json()
) )
assert shelby.name == "Shelby wanna be" assert shelby.name == "Shelby wanna be"
@ -235,6 +247,6 @@ def test_inheritance_with_m2m_relation():
assert shelby.co_owners[1] == joe assert shelby.co_owners[1] == joe
assert shelby.max_capacity == 2000 assert shelby.max_capacity == 2000
buses = [Bus2(**x) for x in client.get("/buses2/").json()] buses = [Bus2(**x) for x in (await client.get("/buses2/")).json()]
assert len(buses) == 1 assert len(buses) == 1
assert buses[0].name == "Unicorn" assert buses[0].name == "Unicorn"

View File

@ -2,8 +2,9 @@ import datetime
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
from tests.test_inheritance_and_pydantic_generation.test_inheritance_mixins import Category, Subject, metadata, db as database # type: ignore from tests.test_inheritance_and_pydantic_generation.test_inheritance_mixins import Category, Subject, metadata, db as database # type: ignore
@ -45,13 +46,14 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
def test_read_main(): @pytest.mark.asyncio
client = TestClient(app) async def test_read_main():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max") test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max")
test_subject = dict(name="Bar") test_subject = dict(name="Bar")
response = client.post("/categories/", json=test_category) response = await client.post("/categories/", json=test_category)
assert response.status_code == 200 assert response.status_code == 200
cat = Category(**response.json()) cat = Category(**response.json())
assert cat.name == "Foo" assert cat.name == "Foo"
@ -67,7 +69,7 @@ def test_read_main():
"%Y-%m-%d %H:%M:%S.%f" "%Y-%m-%d %H:%M:%S.%f"
) )
test_subject["category"] = cat_dict test_subject["category"] = cat_dict
response = client.post("/subjects/", json=test_subject) response = await client.post("/subjects/", json=test_subject)
assert response.status_code == 200 assert response.status_code == 200
sub = Subject(**response.json()) sub = Subject(**response.json())
assert sub.name == "Bar" assert sub.name == "Bar"

View File

@ -6,8 +6,9 @@ import databases
import pydantic import pydantic
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -127,10 +128,11 @@ async def test_setting_values_after_init():
assert '["thing1"]' in (await Thing.objects.get(id=t1.id)).json() assert '["thing1"]' in (await Thing.objects.get(id=t1.id)).json()
def test_read_main(): @pytest.mark.asyncio
client = TestClient(app) async def test_read_main():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
response = client.get("/things_with_sample") async with client as client, LifespanManager(app):
response = await client.get("/things_with_sample")
assert response.status_code == 200 assert response.status_code == 200
# check if raw response not double encoded # check if raw response not double encoded
@ -142,11 +144,13 @@ def test_read_main():
assert resp[1].get("js") == ["asdf", "asdf", "bobby", "nigel"] assert resp[1].get("js") == ["asdf", "asdf", "bobby", "nigel"]
# create a new one # create a new one
response = client.post("/things", json={"js": ["test", "test2"], "name": "c"}) response = await client.post(
"/things", json={"js": ["test", "test2"], "name": "c"}
)
assert response.json().get("js") == ["test", "test2"] assert response.json().get("js") == ["test", "test2"]
# get all with new one # get all with new one
response = client.get("/things") response = await client.get("/things")
assert response.status_code == 200 assert response.status_code == 200
assert '["test","test2"]' in response.text assert '["test","test2"]' in response.text
resp = response.json() resp = response.json()
@ -154,26 +158,26 @@ def test_read_main():
assert resp[1].get("js") == ["asdf", "asdf", "bobby", "nigel"] assert resp[1].get("js") == ["asdf", "asdf", "bobby", "nigel"]
assert resp[2].get("js") == ["test", "test2"] assert resp[2].get("js") == ["test", "test2"]
response = client.get("/things_with_sample_after_init") response = await client.get("/things_with_sample_after_init")
assert response.status_code == 200 assert response.status_code == 200
resp = response.json() resp = response.json()
assert resp.get("js") == ["js", "set", "after", "constructor"] assert resp.get("js") == ["js", "set", "after", "constructor"]
# test new with after constructor # test new with after constructor
response = client.get("/things") response = await client.get("/things")
resp = response.json() resp = response.json()
assert resp[0].get("js") == ["lemon", "raspberry", "lime", "pumice"] assert resp[0].get("js") == ["lemon", "raspberry", "lime", "pumice"]
assert resp[1].get("js") == ["asdf", "asdf", "bobby", "nigel"] assert resp[1].get("js") == ["asdf", "asdf", "bobby", "nigel"]
assert resp[2].get("js") == ["test", "test2"] assert resp[2].get("js") == ["test", "test2"]
assert resp[3].get("js") == ["js", "set", "after", "constructor"] assert resp[3].get("js") == ["js", "set", "after", "constructor"]
response = client.put("/update_thing", json=resp[3]) response = await client.put("/update_thing", json=resp[3])
assert response.status_code == 200 assert response.status_code == 200
resp = response.json() resp = response.json()
assert resp.get("js") == ["js", "set", "after", "update"] assert resp.get("js") == ["js", "set", "after", "update"]
# test new with after constructor # test new with after constructor
response = client.get("/things_untyped") response = await client.get("/things_untyped")
resp = response.json() resp = response.json()
assert resp[0].get("js") == ["lemon", "raspberry", "lime", "pumice"] assert resp[0].get("js") == ["lemon", "raspberry", "lime", "pumice"]
assert resp[1].get("js") == ["asdf", "asdf", "bobby", "nigel"] assert resp[1].get("js") == ["asdf", "asdf", "bobby", "nigel"]

View File

@ -3,10 +3,11 @@ from typing import List, Optional
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from pydantic.schema import ForwardRef from pydantic.schema import ForwardRef
from starlette import status from starlette import status
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
@ -90,9 +91,10 @@ async def create_country(country: Country): # if this is ormar
return result return result
def test_payload(): @pytest.mark.asyncio
client = TestClient(app) async def test_payload():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
payload = { payload = {
"name": "Thailand", "name": "Thailand",
"iso2": "TH", "iso2": "TH",
@ -101,7 +103,9 @@ def test_payload():
"demonym": "Thai", "demonym": "Thai",
"native_name": "Thailand", "native_name": "Thailand",
} }
resp = client.post("/", json=payload, headers={"application-type": "json"}) resp = await client.post(
"/", json=payload, headers={"application-type": "json"}
)
# print(resp.content) # print(resp.content)
assert resp.status_code == 201 assert resp.status_code == 201

View File

@ -1,11 +1,11 @@
import asyncio
from typing import List, Optional from typing import List, Optional
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -59,25 +59,25 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
@app.get("/items/", response_model=List[Item]) @app.get("/items", response_model=List[Item])
async def get_items(): async def get_items():
items = await Item.objects.select_related("category").all() items = await Item.objects.select_related("category").all()
return items return items
@app.get("/items/raw/", response_model=List[Item]) @app.get("/items/raw", response_model=List[Item])
async def get_raw_items(): async def get_raw_items():
items = await Item.objects.all() items = await Item.objects.all()
return items return items
@app.post("/items/", response_model=Item) @app.post("/items", response_model=Item)
async def create_item(item: Item): async def create_item(item: Item):
await item.save() await item.save()
return item return item
@app.post("/categories/", response_model=Category) @app.post("/categories", response_model=Category)
async def create_category(category: Category): async def create_category(category: Category):
await category.save() await category.save()
return category return category
@ -96,59 +96,60 @@ async def update_item(item_id: int, item: Item):
@app.delete("/items/{item_id}") @app.delete("/items/{item_id}")
async def delete_item(item_id: int, item: Item = None): async def delete_item(item_id: int):
if item:
return {"deleted_rows": await item.delete()}
item_db = await Item.objects.get(pk=item_id) item_db = await Item.objects.get(pk=item_id)
return {"deleted_rows": await item_db.delete()} return {"deleted_rows": await item_db.delete()}
def test_all_endpoints(): @pytest.mark.asyncio
client = TestClient(app) async def test_all_endpoints():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
response = client.post("/categories/", json={"name": "test cat"}) async with client as client, LifespanManager(app):
response = await client.post("/categories", json={"name": "test cat"})
category = response.json() category = response.json()
response = client.post( response = await client.post(
"/items/", json={"name": "test", "id": 1, "category": category} "/items", json={"name": "test", "id": 1, "category": category}
) )
item = Item(**response.json()) item = Item(**response.json())
assert item.pk is not None assert item.pk is not None
response = client.get("/items/") response = await client.get("/items")
items = [Item(**item) for item in response.json()] items = [Item(**item) for item in response.json()]
assert items[0] == item assert items[0] == item
item.name = "New name" item.name = "New name"
response = client.put(f"/items/{item.pk}", json=item.dict()) response = await client.put(f"/items/{item.pk}", json=item.dict())
assert response.json() == item.dict() assert response.json() == item.dict()
response = client.get("/items/") response = await client.get("/items")
items = [Item(**item) for item in response.json()] items = [Item(**item) for item in response.json()]
assert items[0].name == "New name" assert items[0].name == "New name"
response = client.get("/items/raw/") response = await client.get("/items/raw")
items = [Item(**item) for item in response.json()] items = [Item(**item) for item in response.json()]
assert items[0].name == "New name" assert items[0].name == "New name"
assert items[0].category.name is None assert items[0].category.name is None
response = client.get(f"/items/{item.pk}") response = await client.get(f"/items/{item.pk}")
new_item = Item(**response.json()) new_item = Item(**response.json())
assert new_item == item assert new_item == item
response = client.delete(f"/items/{item.pk}") response = await client.delete(f"/items/{item.pk}")
assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__" assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
response = client.get("/items/") response = await client.get("/items")
items = response.json() items = response.json()
assert len(items) == 0 assert len(items) == 0
client.post("/items/", json={"name": "test_2", "id": 2, "category": category}) await client.post(
response = client.get("/items/") "/items", json={"name": "test_2", "id": 2, "category": category}
)
response = await client.get("/items")
items = response.json() items = response.json()
assert len(items) == 1 assert len(items) == 1
item = Item(**items[0]) item = Item(**items[0])
response = client.delete(f"/items/{item.pk}", json=item.dict()) response = await client.delete(f"/items/{item.pk}")
assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__" assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
response = client.get("/docs/") response = await client.get("/docs")
assert response.status_code == 200 assert response.status_code == 200

View File

@ -4,8 +4,9 @@ from typing import Any, Dict, Optional, Set, Type, Union, cast
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from ormar.queryset.utils import translate_list_to_dict from ormar.queryset.utils import translate_list_to_dict
@ -135,9 +136,10 @@ async def get_department_exclude_all(department_name: str):
return department.dict(exclude=exclude_all) return department.dict(exclude=exclude_all)
def test_saving_related_in_fastapi(): @pytest.mark.asyncio
client = TestClient(app) async def test_saving_related_in_fastapi():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
payload = { payload = {
"department_name": "Ormar", "department_name": "Ormar",
"courses": [ "courses": [
@ -153,9 +155,7 @@ def test_saving_related_in_fastapi():
}, },
], ],
} }
response = client.post( response = await client.post("/departments/", json=payload, headers=headers)
"/departments/", data=json.dumps(payload), headers=headers
)
department = Department(**response.json()) department = Department(**response.json())
assert department.id is not None assert department.id is not None
@ -166,9 +166,9 @@ def test_saving_related_in_fastapi():
assert department.courses[1].course_name == "basic2" assert department.courses[1].course_name == "basic2"
assert department.courses[1].completed assert department.courses[1].completed
response = client.get("/departments/Ormar") response = await client.get("/departments/Ormar")
response2 = client.get("/departments/Ormar/second") response2 = await client.get("/departments/Ormar/second")
assert response.json() == response2.json() == payload assert response.json() == response2.json() == payload
response3 = client.get("/departments/Ormar/exclude") response3 = await client.get("/departments/Ormar/exclude")
assert response3.json() == {"department_name": "Ormar"} assert response3.json() == {"department_name": "Ormar"}

View File

@ -1,14 +1,15 @@
import json import json
from datetime import datetime
import uuid import uuid
from datetime import datetime
from typing import List from typing import List
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from httpx import AsyncClient
from pydantic import BaseModel, Json from pydantic import BaseModel, Json
from starlette.testclient import TestClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -121,9 +122,10 @@ async def create_quiz_lol(
return await quiz.save() return await quiz.save()
def test_quiz_creation(): @pytest.mark.asyncio
client = TestClient(app=router) async def test_quiz_creation():
with client as client: client = AsyncClient(app=router, base_url="http://testserver")
async with client as client, LifespanManager(router):
payload = { payload = {
"title": "Some test question", "title": "Some test question",
"description": "A description", "description": "A description",
@ -145,5 +147,5 @@ def test_quiz_creation():
}, },
], ],
} }
response = client.post("/create", data=json.dumps(payload)) response = await client.post("/create", json=payload)
assert response.status_code == 200 assert response.status_code == 200

View File

@ -4,8 +4,9 @@ import databases
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -94,10 +95,11 @@ async def get_book_with_author_by_id(book_id: int):
return book return book
def test_related_with_defaults(sample_data): @pytest.mark.asyncio
client = TestClient(app) async def test_related_with_defaults(sample_data):
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
response = client.get("/books/1") async with client as client, LifespanManager(app):
response = await client.get("/books/1")
assert response.json() == { assert response.json() == {
"author": {"id": 1}, "author": {"id": 1},
"id": 1, "id": 1,
@ -105,7 +107,7 @@ def test_related_with_defaults(sample_data):
"year": 2021, "year": 2021,
} }
response = client.get("/books_with_author/1") response = await client.get("/books_with_author/1")
assert response.json() == { assert response.json() == {
"author": { "author": {
"books": [ "books": [

View File

@ -1,11 +1,11 @@
import json
from typing import List, Optional from typing import List, Optional
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -101,30 +101,31 @@ async def get_posts():
return posts return posts
def test_queries(): @pytest.mark.asyncio
client = TestClient(app) async def test_queries():
with client as client: client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
right_category = {"name": "Test category"} right_category = {"name": "Test category"}
wrong_category = {"name": "Test category2", "posts": [{"title": "Test Post"}]} wrong_category = {"name": "Test category2", "posts": [{"title": "Test Post"}]}
# cannot add posts if skipped, will be ignored (with extra=ignore by default) # cannot add posts if skipped, will be ignored (with extra=ignore by default)
response = client.post( response = await client.post(
"/categories/", data=json.dumps(wrong_category), headers=headers "/categories/", json=wrong_category, headers=headers
) )
assert response.status_code == 200 assert response.status_code == 200
response = client.get("/categories/") response = await client.get("/categories/")
assert response.status_code == 200 assert response.status_code == 200
assert not "posts" in response.json() assert not "posts" in response.json()
categories = [Category(**x) for x in response.json()] categories = [Category(**x) for x in response.json()]
assert categories[0] is not None assert categories[0] is not None
assert categories[0].name == "Test category2" assert categories[0].name == "Test category2"
response = client.post( response = await client.post(
"/categories/", data=json.dumps(right_category), headers=headers "/categories/", json=right_category, headers=headers
) )
assert response.status_code == 200 assert response.status_code == 200
response = client.get("/categories/") response = await client.get("/categories/")
assert response.status_code == 200 assert response.status_code == 200
categories = [Category(**x) for x in response.json()] categories = [Category(**x) for x in response.json()]
assert categories[1] is not None assert categories[1] is not None
@ -135,11 +136,11 @@ def test_queries():
"author": {"first_name": "John", "last_name": "Smith"}, "author": {"first_name": "John", "last_name": "Smith"},
"categories": [{"name": "New cat"}], "categories": [{"name": "New cat"}],
} }
response = client.post("/posts/", data=json.dumps(right_post), headers=headers) response = await client.post("/posts/", json=right_post, headers=headers)
assert response.status_code == 200 assert response.status_code == 200
Category.__config__.extra = "allow" Category.__config__.extra = "allow"
response = client.get("/posts/") response = await client.get("/posts/")
assert response.status_code == 200 assert response.status_code == 200
posts = [Post(**x) for x in response.json()] posts = [Post(**x) for x in response.json()]
assert posts[0].title == "ok post" assert posts[0].title == "ok post"
@ -150,5 +151,5 @@ def test_queries():
# cannot add posts if skipped, will be error with extra forbid # cannot add posts if skipped, will be error with extra forbid
Category.__config__.extra = "forbid" Category.__config__.extra = "forbid"
response = client.post("/categories/", data=json.dumps(wrong_category)) response = await client.post("/categories/", json=wrong_category)
assert response.status_code == 422 assert response.status_code == 422

View File

@ -5,8 +5,9 @@ import databases
import pydantic import pydantic
import pytest import pytest
import sqlalchemy import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI from fastapi import FastAPI
from starlette.testclient import TestClient from httpx import AsyncClient
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -133,29 +134,30 @@ async def get_weakref():
return ts return ts
def test_endpoints(): @pytest.mark.asyncio
client = TestClient(app) async def test_endpoints():
with client: client = AsyncClient(app=app, base_url="http://testserver")
resp = client.post("/test/1") async with client, LifespanManager(app):
resp = await client.post("/test/1")
assert resp.status_code == 200 assert resp.status_code == 200
resp2 = client.get("/test/2") resp2 = await client.get("/test/2")
assert resp2.status_code == 200 assert resp2.status_code == 200
assert len(resp2.json()) == 3 assert len(resp2.json()) == 3
resp3 = client.get("/test/3") resp3 = await client.get("/test/3")
assert resp3.status_code == 200 assert resp3.status_code == 200
assert len(resp3.json()) == 3 assert len(resp3.json()) == 3
resp4 = client.get("/test/4") resp4 = await client.get("/test/4")
assert resp4.status_code == 200 assert resp4.status_code == 200
assert len(resp4.json()) == 3 assert len(resp4.json()) == 3
ot = OtherThing(**client.get("/get_ot/").json()) ot = OtherThing(**(await client.get("/get_ot/")).json())
resp5 = client.get(f"/test/5/{ot.id}") resp5 = await client.get(f"/test/5/{ot.id}")
assert resp5.status_code == 200 assert resp5.status_code == 200
assert len(resp5.json()) == 3 assert len(resp5.json()) == 3
resp6 = client.get("/test/error") resp6 = await client.get("/test/error")
assert resp6.status_code == 200 assert resp6.status_code == 200
assert len(resp6.json()) == 3 assert len(resp6.json()) == 3