add skip_reverse parameter, add links to related libs, fix weakref error, fix through error with extra=forbid
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
@ -59,6 +60,12 @@ async def get_bus(item_id: int):
|
||||
return bus
|
||||
|
||||
|
||||
@app.get("/buses/", response_model=List[Bus])
|
||||
async def get_buses():
|
||||
buses = await Bus.objects.select_related(["owner", "co_owner"]).all()
|
||||
return buses
|
||||
|
||||
|
||||
@app.post("/trucks/", response_model=Truck)
|
||||
async def create_truck(truck: Truck):
|
||||
await truck.save()
|
||||
@ -84,6 +91,12 @@ async def add_bus_coowner(item_id: int, person: Person):
|
||||
return bus
|
||||
|
||||
|
||||
@app.get("/buses2/", response_model=List[Bus2])
|
||||
async def get_buses2():
|
||||
buses = await Bus2.objects.select_related(["owner", "co_owners"]).all()
|
||||
return buses
|
||||
|
||||
|
||||
@app.post("/trucks2/", response_model=Truck2)
|
||||
async def create_truck2(truck: Truck2):
|
||||
await truck.save()
|
||||
@ -172,6 +185,10 @@ def test_inheritance_with_relation():
|
||||
assert unicorn2.co_owner.name == "Joe"
|
||||
assert unicorn2.max_persons == 50
|
||||
|
||||
buses = [Bus(**x) for x in client.get("/buses/").json()]
|
||||
assert len(buses) == 1
|
||||
assert buses[0].name == "Unicorn"
|
||||
|
||||
|
||||
def test_inheritance_with_m2m_relation():
|
||||
client = TestClient(app)
|
||||
@ -217,3 +234,7 @@ def test_inheritance_with_m2m_relation():
|
||||
assert shelby.co_owners[0] == alex
|
||||
assert shelby.co_owners[1] == joe
|
||||
assert shelby.max_capacity == 2000
|
||||
|
||||
buses = [Bus2(**x) for x in client.get("/buses2/").json()]
|
||||
assert len(buses) == 1
|
||||
assert buses[0].name == "Unicorn"
|
||||
|
||||
106
tests/test_fastapi/test_nested_saving.py
Normal file
106
tests/test_fastapi/test_nested_saving.py
Normal file
@ -0,0 +1,106 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from fastapi import FastAPI
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
import ormar
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
app = FastAPI()
|
||||
metadata = sqlalchemy.MetaData()
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
app.state.database = database
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup() -> None:
|
||||
database_ = app.state.database
|
||||
if not database_.is_connected:
|
||||
await database_.connect()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown() -> None:
|
||||
database_ = app.state.database
|
||||
if database_.is_connected:
|
||||
await database_.disconnect()
|
||||
|
||||
|
||||
class Department(ormar.Model):
|
||||
class Meta:
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
department_name: str = ormar.String(max_length=100)
|
||||
|
||||
|
||||
class Course(ormar.Model):
|
||||
class Meta:
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
course_name: str = ormar.String(max_length=100)
|
||||
completed: bool = ormar.Boolean()
|
||||
department: Optional[Department] = ormar.ForeignKey(Department)
|
||||
|
||||
|
||||
# create db and tables
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@app.post("/DepartmentWithCourses/", response_model=Department)
|
||||
async def create_department(department: Department):
|
||||
# there is no save all - you need to split into save and save_related
|
||||
await department.save()
|
||||
await department.save_related(follow=True, save_all=True)
|
||||
return department
|
||||
|
||||
|
||||
@app.get("/DepartmentsAll/", response_model=List[Department])
|
||||
async def get_Courses():
|
||||
# if you don't provide default name it related model name + s so courses not course
|
||||
departmentall = await Department.objects.select_related("courses").all()
|
||||
return departmentall
|
||||
|
||||
|
||||
def test_saving_related_in_fastapi():
|
||||
client = TestClient(app)
|
||||
with client as client:
|
||||
payload = {
|
||||
"department_name": "Ormar",
|
||||
"courses": [
|
||||
{"course_name": "basic1", "completed": True},
|
||||
{"course_name": "basic2", "completed": True},
|
||||
],
|
||||
}
|
||||
response = client.post("/DepartmentWithCourses/", data=json.dumps(payload))
|
||||
department = Department(**response.json())
|
||||
|
||||
assert department.id is not None
|
||||
assert len(department.courses) == 2
|
||||
assert department.department_name == "Ormar"
|
||||
assert department.courses[0].course_name == "basic1"
|
||||
assert department.courses[0].completed
|
||||
assert department.courses[1].course_name == "basic2"
|
||||
assert department.courses[1].completed
|
||||
|
||||
response = client.get("/DepartmentsAll/")
|
||||
departments = [Department(**x) for x in response.json()]
|
||||
assert departments[0].id is not None
|
||||
assert len(departments[0].courses) == 2
|
||||
assert departments[0].department_name == "Ormar"
|
||||
assert departments[0].courses[0].course_name == "basic1"
|
||||
assert departments[0].courses[0].completed
|
||||
assert departments[0].courses[1].course_name == "basic2"
|
||||
assert departments[0].courses[1].completed
|
||||
148
tests/test_fastapi/test_skip_reverse_models.py
Normal file
148
tests/test_fastapi/test_skip_reverse_models.py
Normal file
@ -0,0 +1,148 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from fastapi import FastAPI
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
import ormar
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
app = FastAPI()
|
||||
metadata = sqlalchemy.MetaData()
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
app.state.database = database
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup() -> None:
|
||||
database_ = app.state.database
|
||||
if not database_.is_connected:
|
||||
await database_.connect()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown() -> None:
|
||||
database_ = app.state.database
|
||||
if database_.is_connected:
|
||||
await database_.disconnect()
|
||||
|
||||
|
||||
class BaseMeta(ormar.ModelMeta):
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
|
||||
class Author(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
pass
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
first_name: str = ormar.String(max_length=80)
|
||||
last_name: str = ormar.String(max_length=80)
|
||||
|
||||
|
||||
class Category(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "categories"
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=40)
|
||||
|
||||
|
||||
class Post(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
pass
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
title: str = ormar.String(max_length=200)
|
||||
categories = ormar.ManyToMany(Category, skip_reverse=True)
|
||||
author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@app.post("/categories/", response_model=Category)
|
||||
async def create_category(category: Category):
|
||||
await category.save()
|
||||
await category.save_related(follow=True, save_all=True)
|
||||
return category
|
||||
|
||||
|
||||
@app.post("/posts/", response_model=Post)
|
||||
async def create_post(post: Post):
|
||||
if post.author:
|
||||
await post.author.save()
|
||||
await post.save()
|
||||
await post.save_related(follow=True, save_all=True)
|
||||
for category in [cat for cat in post.categories]:
|
||||
await post.categories.add(category)
|
||||
return post
|
||||
|
||||
|
||||
@app.get("/categories/", response_model=List[Category])
|
||||
async def get_categories():
|
||||
return await Category.objects.select_related("posts").all()
|
||||
|
||||
|
||||
@app.get("/posts/", response_model=List[Post])
|
||||
async def get_posts():
|
||||
posts = await Post.objects.select_related(["categories", "author"]).all()
|
||||
return posts
|
||||
|
||||
|
||||
def test_queries():
|
||||
client = TestClient(app)
|
||||
with client as client:
|
||||
right_category = {"name": "Test category"}
|
||||
wrong_category = {"name": "Test category2", "posts": [{"title": "Test Post"}]}
|
||||
|
||||
# cannot add posts if skipped, will be ignored (with extra=ignore by default)
|
||||
response = client.post("/categories/", data=json.dumps(wrong_category))
|
||||
assert response.status_code == 200
|
||||
response = client.get("/categories/")
|
||||
assert response.status_code == 200
|
||||
assert not "posts" in response.json()
|
||||
categories = [Category(**x) for x in response.json()]
|
||||
assert categories[0] is not None
|
||||
assert categories[0].name == "Test category2"
|
||||
|
||||
response = client.post("/categories/", data=json.dumps(right_category))
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get("/categories/")
|
||||
assert response.status_code == 200
|
||||
categories = [Category(**x) for x in response.json()]
|
||||
assert categories[1] is not None
|
||||
assert categories[1].name == "Test category"
|
||||
|
||||
right_post = {
|
||||
"title": "ok post",
|
||||
"author": {"first_name": "John", "last_name": "Smith"},
|
||||
"categories": [{"name": "New cat"}],
|
||||
}
|
||||
response = client.post("/posts/", data=json.dumps(right_post))
|
||||
assert response.status_code == 200
|
||||
|
||||
Category.__config__.extra = "allow"
|
||||
response = client.get("/posts/")
|
||||
assert response.status_code == 200
|
||||
posts = [Post(**x) for x in response.json()]
|
||||
assert posts[0].title == "ok post"
|
||||
assert posts[0].author.first_name == "John"
|
||||
assert posts[0].categories[0].name == "New cat"
|
||||
|
||||
wrong_category = {"name": "Test category3", "posts": [{"title": "Test Post"}]}
|
||||
|
||||
# cannot add posts if skipped, will be error with extra forbid
|
||||
Category.__config__.extra = "forbid"
|
||||
response = client.post("/categories/", data=json.dumps(wrong_category))
|
||||
assert response.status_code == 422
|
||||
@ -123,6 +123,16 @@ async def get_test_5(thing_id: UUID):
|
||||
return await Thing.objects.all(other_thing__id=thing_id)
|
||||
|
||||
|
||||
@app.get(
|
||||
"/test/error", response_model=List[Thing], response_model_exclude={"other_thing"}
|
||||
)
|
||||
async def get_weakref():
|
||||
ots = await OtherThing.objects.all()
|
||||
ot = ots[0]
|
||||
ts = await ot.things.all()
|
||||
return ts
|
||||
|
||||
|
||||
def test_endpoints():
|
||||
client = TestClient(app)
|
||||
with client:
|
||||
@ -145,3 +155,7 @@ def test_endpoints():
|
||||
resp5 = client.get(f"/test/5/{ot.id}")
|
||||
assert resp5.status_code == 200
|
||||
assert len(resp5.json()) == 3
|
||||
|
||||
resp6 = client.get("/test/error")
|
||||
assert resp6.status_code == 200
|
||||
assert len(resp6.json()) == 3
|
||||
|
||||
Reference in New Issue
Block a user