171 lines
4.9 KiB
Python
171 lines
4.9 KiB
Python
import json
|
|
from typing import Any, Dict, Optional, Set, Type, Union, cast
|
|
|
|
import databases
|
|
import pytest
|
|
import sqlalchemy
|
|
from fastapi import FastAPI
|
|
from starlette.testclient import TestClient
|
|
|
|
import ormar
|
|
from ormar.queryset.utils import translate_list_to_dict
|
|
from tests.settings import DATABASE_URL
|
|
|
|
app = FastAPI()
|
|
metadata = sqlalchemy.MetaData()
|
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
|
app.state.database = database
|
|
|
|
|
|
@app.on_event("startup")
|
|
async def startup() -> None:
|
|
database_ = app.state.database
|
|
if not database_.is_connected:
|
|
await database_.connect()
|
|
|
|
|
|
@app.on_event("shutdown")
|
|
async def shutdown() -> None:
|
|
database_ = app.state.database
|
|
if database_.is_connected:
|
|
await database_.disconnect()
|
|
|
|
|
|
class Department(ormar.Model):
|
|
class Meta:
|
|
database = database
|
|
metadata = metadata
|
|
|
|
id: int = ormar.Integer(primary_key=True)
|
|
department_name: str = ormar.String(max_length=100)
|
|
|
|
|
|
class Course(ormar.Model):
|
|
class Meta:
|
|
database = database
|
|
metadata = metadata
|
|
|
|
id: int = ormar.Integer(primary_key=True)
|
|
course_name: str = ormar.String(max_length=100)
|
|
completed: bool = ormar.Boolean()
|
|
department: Optional[Department] = ormar.ForeignKey(Department)
|
|
|
|
|
|
class Student(ormar.Model):
|
|
class Meta:
|
|
database = database
|
|
metadata = metadata
|
|
|
|
id: int = ormar.Integer(primary_key=True)
|
|
name: str = ormar.String(max_length=100)
|
|
courses = ormar.ManyToMany(Course)
|
|
|
|
|
|
# create db and tables
|
|
@pytest.fixture(autouse=True, scope="module")
|
|
def create_test_database():
|
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
|
metadata.create_all(engine)
|
|
yield
|
|
metadata.drop_all(engine)
|
|
|
|
|
|
to_exclude = {
|
|
"id": ...,
|
|
"courses": {
|
|
"__all__": {"id": ..., "students": {"__all__": {"id", "studentcourse"}}}
|
|
},
|
|
}
|
|
|
|
exclude_all = {"id": ..., "courses": {"__all__"}}
|
|
|
|
to_exclude_ormar = {
|
|
"id": ...,
|
|
"courses": {"id": ..., "students": {"id", "studentcourse"}},
|
|
}
|
|
|
|
|
|
def auto_exclude_id_field(to_exclude: Any) -> Union[Dict, Set]:
|
|
if isinstance(to_exclude, dict):
|
|
for key in to_exclude.keys():
|
|
to_exclude[key] = auto_exclude_id_field(to_exclude[key])
|
|
to_exclude["id"] = Ellipsis
|
|
return to_exclude
|
|
else:
|
|
return {"id"}
|
|
|
|
|
|
def generate_exclude_for_ids(model: Type[ormar.Model]) -> Dict:
|
|
to_exclude_base = translate_list_to_dict(model._iterate_related_models())
|
|
return cast(Dict, auto_exclude_id_field(to_exclude=to_exclude_base))
|
|
|
|
|
|
to_exclude_auto = generate_exclude_for_ids(model=Department)
|
|
|
|
|
|
@app.post("/departments/", response_model=Department)
|
|
async def create_department(department: Department):
|
|
await department.save_related(follow=True, save_all=True)
|
|
return department
|
|
|
|
|
|
@app.get("/departments/{department_name}")
|
|
async def get_department(department_name: str):
|
|
department = await Department.objects.select_all(follow=True).get(
|
|
department_name=department_name
|
|
)
|
|
return department.dict(exclude=to_exclude)
|
|
|
|
|
|
@app.get("/departments/{department_name}/second")
|
|
async def get_department_exclude(department_name: str):
|
|
department = await Department.objects.select_all(follow=True).get(
|
|
department_name=department_name
|
|
)
|
|
return department.dict(exclude=to_exclude_ormar)
|
|
|
|
|
|
@app.get("/departments/{department_name}/exclude")
|
|
async def get_department_exclude_all(department_name: str):
|
|
department = await Department.objects.select_all(follow=True).get(
|
|
department_name=department_name
|
|
)
|
|
return department.dict(exclude=exclude_all)
|
|
|
|
|
|
def test_saving_related_in_fastapi():
|
|
client = TestClient(app)
|
|
with client as client:
|
|
payload = {
|
|
"department_name": "Ormar",
|
|
"courses": [
|
|
{
|
|
"course_name": "basic1",
|
|
"completed": True,
|
|
"students": [{"name": "Jack"}, {"name": "Abi"}],
|
|
},
|
|
{
|
|
"course_name": "basic2",
|
|
"completed": True,
|
|
"students": [{"name": "Kate"}, {"name": "Miranda"}],
|
|
},
|
|
],
|
|
}
|
|
response = client.post("/departments/", data=json.dumps(payload))
|
|
department = Department(**response.json())
|
|
|
|
assert department.id is not None
|
|
assert len(department.courses) == 2
|
|
assert department.department_name == "Ormar"
|
|
assert department.courses[0].course_name == "basic1"
|
|
assert department.courses[0].completed
|
|
assert department.courses[1].course_name == "basic2"
|
|
assert department.courses[1].completed
|
|
|
|
response = client.get("/departments/Ormar")
|
|
response2 = client.get("/departments/Ormar/second")
|
|
assert response.json() == response2.json() == payload
|
|
|
|
response3 = client.get("/departments/Ormar/exclude")
|
|
assert response3.json() == {"department_name": "Ormar"}
|