Files
ormar/tests/test_fastapi/test_nested_saving.py
collerek b1ab0de4d4 Bump supported fastapi versions (#1110)
* Bump supported fastapi version to <=0.97, change all fastapi tests from starlette client to httpx.AsyncClient

* Add lifecycle manager to fastapi tests

* Fix coverage

* Add python 3.11 to test suite, bump version
2023-06-18 18:52:06 +02:00

175 lines
5.1 KiB
Python

import json
from typing import Any, Dict, Optional, Set, Type, Union, cast
import databases
import pytest
import sqlalchemy
from asgi_lifespan import LifespanManager
from fastapi import FastAPI
from httpx import AsyncClient
import ormar
from ormar.queryset.utils import translate_list_to_dict
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
headers = {"content-type": "application/json"}
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class Department(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
department_name: str = ormar.String(max_length=100)
class Course(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
course_name: str = ormar.String(max_length=100)
completed: bool = ormar.Boolean()
department: Optional[Department] = ormar.ForeignKey(Department)
class Student(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(Course)
# create db and tables
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
to_exclude = {
"id": ...,
"courses": {
"__all__": {"id": ..., "students": {"__all__": {"id", "studentcourse"}}}
},
}
exclude_all = {"id": ..., "courses": {"__all__"}}
to_exclude_ormar = {
"id": ...,
"courses": {"id": ..., "students": {"id", "studentcourse"}},
}
def auto_exclude_id_field(to_exclude: Any) -> Union[Dict, Set]:
if isinstance(to_exclude, dict):
for key in to_exclude.keys():
to_exclude[key] = auto_exclude_id_field(to_exclude[key])
to_exclude["id"] = Ellipsis
return to_exclude
else:
return {"id"}
def generate_exclude_for_ids(model: Type[ormar.Model]) -> Dict:
to_exclude_base = translate_list_to_dict(model._iterate_related_models())
return cast(Dict, auto_exclude_id_field(to_exclude=to_exclude_base))
to_exclude_auto = generate_exclude_for_ids(model=Department)
@app.post("/departments/", response_model=Department)
async def create_department(department: Department):
await department.save_related(follow=True, save_all=True)
return department
@app.get("/departments/{department_name}")
async def get_department(department_name: str):
department = await Department.objects.select_all(follow=True).get(
department_name=department_name
)
return department.dict(exclude=to_exclude)
@app.get("/departments/{department_name}/second")
async def get_department_exclude(department_name: str):
department = await Department.objects.select_all(follow=True).get(
department_name=department_name
)
return department.dict(exclude=to_exclude_ormar)
@app.get("/departments/{department_name}/exclude")
async def get_department_exclude_all(department_name: str):
department = await Department.objects.select_all(follow=True).get(
department_name=department_name
)
return department.dict(exclude=exclude_all)
@pytest.mark.asyncio
async def test_saving_related_in_fastapi():
client = AsyncClient(app=app, base_url="http://testserver")
async with client as client, LifespanManager(app):
payload = {
"department_name": "Ormar",
"courses": [
{
"course_name": "basic1",
"completed": True,
"students": [{"name": "Jack"}, {"name": "Abi"}],
},
{
"course_name": "basic2",
"completed": True,
"students": [{"name": "Kate"}, {"name": "Miranda"}],
},
],
}
response = await client.post("/departments/", json=payload, headers=headers)
department = Department(**response.json())
assert department.id is not None
assert len(department.courses) == 2
assert department.department_name == "Ormar"
assert department.courses[0].course_name == "basic1"
assert department.courses[0].completed
assert department.courses[1].course_name == "basic2"
assert department.courses[1].completed
response = await client.get("/departments/Ormar")
response2 = await client.get("/departments/Ormar/second")
assert response.json() == response2.json() == payload
response3 = await client.get("/departments/Ormar/exclude")
assert response3.json() == {"department_name": "Ormar"}