"""FastAPI + ormar integration test for a one-to-many relation.

``Thing.other_thing`` is a foreign key to ``OtherThing``; ormar exposes the
reverse side as ``OtherThing.things``. The GET endpoints below show several
ways of returning an OtherThing's Things without handing the back-reference
to FastAPI for serialization.
"""
from typing import List, Optional
from uuid import UUID, uuid4

import databases
import pydantic
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient

import ormar
from tests.settings import DATABASE_URL

app = FastAPI()
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
app.state.database = database


@app.on_event("startup")
async def startup() -> None:
    """Connect the shared database if it is not already connected."""
    database_ = app.state.database
    if not database_.is_connected:
        await database_.connect()


@app.on_event("shutdown")
async def shutdown() -> None:
    """Disconnect the shared database if it is still connected."""
    database_ = app.state.database
    if database_.is_connected:
        await database_.disconnect()


@pytest.fixture(autouse=True, scope="module")
def create_test_database():
    """Create all tables before this module's tests and drop them afterwards."""
    engine = sqlalchemy.create_engine(DATABASE_URL)
    metadata.create_all(engine)
    yield
    metadata.drop_all(engine)
    # Release the connection pool held by this throwaway engine.
    engine.dispose()


class BaseMeta(ormar.ModelMeta):
    # Shared Meta so every model in this module uses the same database/metadata.
    database = database
    metadata = metadata


class OtherThing(ormar.Model):
    """The "one" side of the relation; its Things are reachable as ``things``."""

    class Meta(BaseMeta):
        tablename = "other_things"

    id: UUID = ormar.UUID(primary_key=True, default=uuid4)
    name: str = ormar.Text(default="")
    ot_contents: str = ormar.Text(default="")


class Thing(ormar.Model):
    """The "many" side of the relation, pointing at an optional OtherThing."""

    class Meta(BaseMeta):
        tablename = "things"

    id: UUID = ormar.UUID(primary_key=True, default=uuid4)
    name: str = ormar.Text(default="")
    js: pydantic.Json = ormar.JSON(nullable=True)
    other_thing: Optional[OtherThing] = ormar.ForeignKey(OtherThing, nullable=True)


@app.post("/test/1")
async def post_test_1():
    """Create one OtherThing and three Things ("t1".."t3") that reference it."""
    # don't split initialization and attribute assignment
    ot = await OtherThing(ot_contents="otc").save()
    await Thing(other_thing=ot, name="t1").save()
    await Thing(other_thing=ot, name="t2").save()
    await Thing(other_thing=ot, name="t3").save()

    # if you do not care about the returned objects you can even go with
    # bulk_create - all of them are created in one transaction
    # things = [Thing(other_thing=ot, name="t1"),
    #           Thing(other_thing=ot, name="t2"),
    #           Thing(other_thing=ot, name="t3")]
    # await Thing.objects.bulk_create(things)


@app.get("/test/2", response_model=List[Thing])
async def get_test_2():
    """Return an OtherThing's Things, queried through the reverse relation."""
    # if you only query for one use get or first
    ot = await OtherThing.objects.get()
    ts = await ot.things.all()
    # specifically null out the relation on things before return
    for t in ts:
        t.remove(ot, name="other_thing")
    return ts


@app.get("/test/3", response_model=List[Thing])
async def get_test_3():
    """Return an OtherThing's Things loaded eagerly with ``select_related``."""
    ot = await OtherThing.objects.select_related("things").get()
    # exclude the unwanted field while ot is still in scope
    # so that it is not passed on to fastapi
    return [t.dict(exclude={"other_thing"}) for t in ot.things]


@app.get(
    "/test/4",
    response_model=List[Thing],
    response_model_exclude={"other_thing"},
)
async def get_test_4():
    """Return an OtherThing's Things; FastAPI drops the relation field."""
    ot = await OtherThing.objects.get()
    # query from the active (foreign-key) side
    return await Thing.objects.all(other_thing=ot)


@app.get("/get_ot/", response_model=OtherThing)
async def get_ot():
    """Return an OtherThing (the test creates exactly one)."""
    return await OtherThing.objects.get()


# In real life you usually don't fetch some arbitrary OtherThing and then its
# Things; you query for a specific one by some kind of id.
@app.get(
    "/test/5/{thing_id}",
    response_model=List[Thing],
    response_model_exclude={"other_thing"},
)
async def get_test_5(thing_id: UUID):
    """Return all Things belonging to the OtherThing with the given id.

    Note: despite its name, ``thing_id`` is the id of the *OtherThing*; it is
    matched against ``Thing.other_thing.id``. The name is kept because it is
    the public path parameter.
    """
    return await Thing.objects.all(other_thing__id=thing_id)


def test_endpoints():
    """Create data via POST, then check every read endpoint returns 3 Things."""
    with TestClient(app) as client:
        resp = client.post("/test/1")
        assert resp.status_code == 200

        for path in ("/test/2", "/test/3", "/test/4"):
            resp = client.get(path)
            assert resp.status_code == 200, path
            assert len(resp.json()) == 3, path

        ot = OtherThing(**client.get("/get_ot/").json())
        resp5 = client.get(f"/test/5/{ot.id}")
        assert resp5.status_code == 200
        assert len(resp5.json()) == 3