import json from typing import Any, Dict, Optional, Set, Type, Union, cast import databases import pytest import sqlalchemy from fastapi import FastAPI from starlette.testclient import TestClient import ormar from ormar.queryset.utils import translate_list_to_dict from tests.settings import DATABASE_URL app = FastAPI() metadata = sqlalchemy.MetaData() database = databases.Database(DATABASE_URL, force_rollback=True) app.state.database = database @app.on_event("startup") async def startup() -> None: database_ = app.state.database if not database_.is_connected: await database_.connect() @app.on_event("shutdown") async def shutdown() -> None: database_ = app.state.database if database_.is_connected: await database_.disconnect() class Department(ormar.Model): class Meta: database = database metadata = metadata id: int = ormar.Integer(primary_key=True) department_name: str = ormar.String(max_length=100) class Course(ormar.Model): class Meta: database = database metadata = metadata id: int = ormar.Integer(primary_key=True) course_name: str = ormar.String(max_length=100) completed: bool = ormar.Boolean() department: Optional[Department] = ormar.ForeignKey(Department) class Student(ormar.Model): class Meta: database = database metadata = metadata id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) courses = ormar.ManyToMany(Course) # create db and tables @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) metadata.create_all(engine) yield metadata.drop_all(engine) to_exclude = { "id": ..., "courses": { "__all__": {"id": ..., "students": {"__all__": {"id", "studentcourse"}}} }, } exclude_all = {"id": ..., "courses": {"__all__"}} to_exclude_ormar = { "id": ..., "courses": {"id": ..., "students": {"id", "studentcourse"}}, } def auto_exclude_id_field(to_exclude: Any) -> Union[Dict, Set]: if isinstance(to_exclude, dict): for key in to_exclude.keys(): to_exclude[key] = auto_exclude_id_field(to_exclude[key]) to_exclude["id"] = Ellipsis return to_exclude else: return {"id"} def generate_exclude_for_ids(model: Type[ormar.Model]) -> Dict: to_exclude_base = translate_list_to_dict(model._iterate_related_models()) return cast(Dict, auto_exclude_id_field(to_exclude=to_exclude_base)) to_exclude_auto = generate_exclude_for_ids(model=Department) @app.post("/departments/", response_model=Department) async def create_department(department: Department): await department.save_related(follow=True, save_all=True) return department @app.get("/departments/{department_name}") async def get_department(department_name: str): department = await Department.objects.select_all(follow=True).get( department_name=department_name ) return department.dict(exclude=to_exclude) @app.get("/departments/{department_name}/second") async def get_department_exclude(department_name: str): department = await Department.objects.select_all(follow=True).get( department_name=department_name ) return department.dict(exclude=to_exclude_ormar) @app.get("/departments/{department_name}/exclude") async def get_department_exclude_all(department_name: str): department = await Department.objects.select_all(follow=True).get( department_name=department_name ) return department.dict(exclude=exclude_all) def test_saving_related_in_fastapi(): client = TestClient(app) with client as client: payload = { "department_name": "Ormar", "courses": [ { "course_name": "basic1", "completed": True, "students": [{"name": "Jack"}, {"name": "Abi"}], }, { "course_name": "basic2", "completed": True, "students": [{"name": "Kate"}, {"name": "Miranda"}], }, ], } response = client.post("/departments/", data=json.dumps(payload)) department = Department(**response.json()) assert department.id is not None assert len(department.courses) == 2 assert department.department_name == "Ormar" assert department.courses[0].course_name == "basic1" assert department.courses[0].completed assert department.courses[1].course_name == "basic2" assert department.courses[1].completed response = client.get("/departments/Ormar") response2 = client.get("/departments/Ormar/second") assert response.json() == response2.json() == payload response3 = client.get("/departments/Ormar/exclude") assert response3.json() == {"department_name": "Ormar"}