Files
ormar/tests/test_fastapi/test_nested_saving.py

107 lines
3.2 KiB
Python

import json
from typing import List, Optional
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class Department(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
department_name: str = ormar.String(max_length=100)
class Course(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
course_name: str = ormar.String(max_length=100)
completed: bool = ormar.Boolean()
department: Optional[Department] = ormar.ForeignKey(Department)
# create db and tables
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@app.post("/DepartmentWithCourses/", response_model=Department)
async def create_department(department: Department):
# there is no save all - you need to split into save and save_related
await department.save()
await department.save_related(follow=True, save_all=True)
return department
@app.get("/DepartmentsAll/", response_model=List[Department])
async def get_Courses():
# if you don't provide default name it related model name + s so courses not course
departmentall = await Department.objects.select_related("courses").all()
return departmentall
def test_saving_related_in_fastapi():
client = TestClient(app)
with client as client:
payload = {
"department_name": "Ormar",
"courses": [
{"course_name": "basic1", "completed": True},
{"course_name": "basic2", "completed": True},
],
}
response = client.post("/DepartmentWithCourses/", data=json.dumps(payload))
department = Department(**response.json())
assert department.id is not None
assert len(department.courses) == 2
assert department.department_name == "Ormar"
assert department.courses[0].course_name == "basic1"
assert department.courses[0].completed
assert department.courses[1].course_name == "basic2"
assert department.courses[1].completed
response = client.get("/DepartmentsAll/")
departments = [Department(**x) for x in response.json()]
assert departments[0].id is not None
assert len(departments[0].courses) == 2
assert departments[0].department_name == "Ormar"
assert departments[0].courses[0].course_name == "basic1"
assert departments[0].courses[0].completed
assert departments[0].courses[1].course_name == "basic2"
assert departments[0].courses[1].completed