add skip_reverse parameter, add links to related libs, fix weakref error, fix through error with extra=forbid

This commit is contained in:
collerek
2021-04-11 18:43:23 +02:00
parent e553885221
commit b3b1c156b5
19 changed files with 675 additions and 48 deletions

View File

@ -1,4 +1,5 @@
import datetime
from typing import List
import pytest
import sqlalchemy
@ -59,6 +60,12 @@ async def get_bus(item_id: int):
return bus
@app.get("/buses/", response_model=List[Bus])
async def get_buses():
buses = await Bus.objects.select_related(["owner", "co_owner"]).all()
return buses
@app.post("/trucks/", response_model=Truck)
async def create_truck(truck: Truck):
await truck.save()
@ -84,6 +91,12 @@ async def add_bus_coowner(item_id: int, person: Person):
return bus
@app.get("/buses2/", response_model=List[Bus2])
async def get_buses2():
buses = await Bus2.objects.select_related(["owner", "co_owners"]).all()
return buses
@app.post("/trucks2/", response_model=Truck2)
async def create_truck2(truck: Truck2):
await truck.save()
@ -172,6 +185,10 @@ def test_inheritance_with_relation():
assert unicorn2.co_owner.name == "Joe"
assert unicorn2.max_persons == 50
buses = [Bus(**x) for x in client.get("/buses/").json()]
assert len(buses) == 1
assert buses[0].name == "Unicorn"
def test_inheritance_with_m2m_relation():
client = TestClient(app)
@ -217,3 +234,7 @@ def test_inheritance_with_m2m_relation():
assert shelby.co_owners[0] == alex
assert shelby.co_owners[1] == joe
assert shelby.max_capacity == 2000
buses = [Bus2(**x) for x in client.get("/buses2/").json()]
assert len(buses) == 1
assert buses[0].name == "Unicorn"

View File

@ -0,0 +1,106 @@
import json
from typing import List, Optional
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class Department(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
department_name: str = ormar.String(max_length=100)
class Course(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
course_name: str = ormar.String(max_length=100)
completed: bool = ormar.Boolean()
department: Optional[Department] = ormar.ForeignKey(Department)
# create db and tables
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@app.post("/DepartmentWithCourses/", response_model=Department)
async def create_department(department: Department):
# there is no save all - you need to split into save and save_related
await department.save()
await department.save_related(follow=True, save_all=True)
return department
@app.get("/DepartmentsAll/", response_model=List[Department])
async def get_Courses():
# if you don't provide default name it related model name + s so courses not course
departmentall = await Department.objects.select_related("courses").all()
return departmentall
def test_saving_related_in_fastapi():
client = TestClient(app)
with client as client:
payload = {
"department_name": "Ormar",
"courses": [
{"course_name": "basic1", "completed": True},
{"course_name": "basic2", "completed": True},
],
}
response = client.post("/DepartmentWithCourses/", data=json.dumps(payload))
department = Department(**response.json())
assert department.id is not None
assert len(department.courses) == 2
assert department.department_name == "Ormar"
assert department.courses[0].course_name == "basic1"
assert department.courses[0].completed
assert department.courses[1].course_name == "basic2"
assert department.courses[1].completed
response = client.get("/DepartmentsAll/")
departments = [Department(**x) for x in response.json()]
assert departments[0].id is not None
assert len(departments[0].courses) == 2
assert departments[0].department_name == "Ormar"
assert departments[0].courses[0].course_name == "basic1"
assert departments[0].courses[0].completed
assert departments[0].courses[1].course_name == "basic2"
assert departments[0].courses[1].completed

View File

@ -0,0 +1,148 @@
import json
from typing import List, Optional
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class BaseMeta(ormar.ModelMeta):
database = database
metadata = metadata
class Author(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
first_name: str = ormar.String(max_length=80)
last_name: str = ormar.String(max_length=80)
class Category(ormar.Model):
class Meta(BaseMeta):
tablename = "categories"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40)
class Post(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories = ormar.ManyToMany(Category, skip_reverse=True)
author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@app.post("/categories/", response_model=Category)
async def create_category(category: Category):
await category.save()
await category.save_related(follow=True, save_all=True)
return category
@app.post("/posts/", response_model=Post)
async def create_post(post: Post):
if post.author:
await post.author.save()
await post.save()
await post.save_related(follow=True, save_all=True)
for category in [cat for cat in post.categories]:
await post.categories.add(category)
return post
@app.get("/categories/", response_model=List[Category])
async def get_categories():
return await Category.objects.select_related("posts").all()
@app.get("/posts/", response_model=List[Post])
async def get_posts():
posts = await Post.objects.select_related(["categories", "author"]).all()
return posts
def test_queries():
client = TestClient(app)
with client as client:
right_category = {"name": "Test category"}
wrong_category = {"name": "Test category2", "posts": [{"title": "Test Post"}]}
# cannot add posts if skipped, will be ignored (with extra=ignore by default)
response = client.post("/categories/", data=json.dumps(wrong_category))
assert response.status_code == 200
response = client.get("/categories/")
assert response.status_code == 200
assert not "posts" in response.json()
categories = [Category(**x) for x in response.json()]
assert categories[0] is not None
assert categories[0].name == "Test category2"
response = client.post("/categories/", data=json.dumps(right_category))
assert response.status_code == 200
response = client.get("/categories/")
assert response.status_code == 200
categories = [Category(**x) for x in response.json()]
assert categories[1] is not None
assert categories[1].name == "Test category"
right_post = {
"title": "ok post",
"author": {"first_name": "John", "last_name": "Smith"},
"categories": [{"name": "New cat"}],
}
response = client.post("/posts/", data=json.dumps(right_post))
assert response.status_code == 200
Category.__config__.extra = "allow"
response = client.get("/posts/")
assert response.status_code == 200
posts = [Post(**x) for x in response.json()]
assert posts[0].title == "ok post"
assert posts[0].author.first_name == "John"
assert posts[0].categories[0].name == "New cat"
wrong_category = {"name": "Test category3", "posts": [{"title": "Test Post"}]}
# cannot add posts if skipped, will be error with extra forbid
Category.__config__.extra = "forbid"
response = client.post("/categories/", data=json.dumps(wrong_category))
assert response.status_code == 422

View File

@ -123,6 +123,16 @@ async def get_test_5(thing_id: UUID):
return await Thing.objects.all(other_thing__id=thing_id)
@app.get(
"/test/error", response_model=List[Thing], response_model_exclude={"other_thing"}
)
async def get_weakref():
ots = await OtherThing.objects.all()
ot = ots[0]
ts = await ot.things.all()
return ts
def test_endpoints():
client = TestClient(app)
with client:
@ -145,3 +155,7 @@ def test_endpoints():
resp5 = client.get(f"/test/5/{ot.id}")
assert resp5.status_code == 200
assert len(resp5.json()) == 3
resp6 = client.get("/test/error")
assert resp6.status_code == 200
assert len(resp6.json()) == 3

View File

@ -0,0 +1,223 @@
from typing import List, Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
database = database
metadata = metadata
class Author(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
first_name: str = ormar.String(max_length=80)
last_name: str = ormar.String(max_length=80)
class Category(ormar.Model):
class Meta(BaseMeta):
tablename = "categories"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40)
class Post(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories: Optional[List[Category]] = ormar.ManyToMany(Category, skip_reverse=True)
author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture(scope="function")
async def cleanup():
yield
async with database:
PostCategory = Post.Meta.model_fields["categories"].through
await PostCategory.objects.delete(each=True)
await Post.objects.delete(each=True)
await Category.objects.delete(each=True)
await Author.objects.delete(each=True)
def test_model_definition():
category = Category(name="Test")
author = Author(first_name="Test", last_name="Author")
post = Post(title="Test Post", author=author)
post.categories = category
assert post.categories[0] == category
assert post.author == author
with pytest.raises(AttributeError):
assert author.posts
with pytest.raises(AttributeError):
assert category.posts
assert "posts" not in category._orm
@pytest.mark.asyncio
async def test_assigning_related_objects(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
# Add a category to a post.
await post.categories.add(news)
# other way is disabled
with pytest.raises(AttributeError):
await news.posts.add(post)
assert await post.categories.get_or_none(name="no exist") is None
assert await post.categories.get_or_none(name="News") == news
# Creating columns object from instance:
await post.categories.create(name="Tips")
assert len(post.categories) == 2
post_categories = await post.categories.all()
assert len(post_categories) == 2
category = await Category.objects.select_related("posts").get(name="News")
with pytest.raises(AttributeError):
assert category.posts
@pytest.mark.asyncio
async def test_quering_of_related_model_works_but_no_result(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
await post.categories.add(news)
post_categories = await post.categories.all()
assert len(post_categories) == 1
assert "posts" not in post.dict().get("categories", [])[0]
assert news == await post.categories.get(name="News")
posts_about_python = await Post.objects.filter(categories__name="python").all()
assert len(posts_about_python) == 0
# relation not in dict
category = (
await Category.objects.select_related("posts")
.filter(posts__author=guido)
.get()
)
assert category == news
assert "posts" not in category.dict()
# relation not in json
category2 = (
await Category.objects.select_related("posts")
.filter(posts__author__first_name="Guido")
.get()
)
assert category2 == news
assert "posts" not in category2.json()
assert "posts" not in Category.schema().get("properties")
@pytest.mark.asyncio
async def test_removal_of_the_relations(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
await post.categories.add(news)
assert len(await post.categories.all()) == 1
await post.categories.remove(news)
assert len(await post.categories.all()) == 0
with pytest.raises(AttributeError):
await news.posts.add(post)
with pytest.raises(AttributeError):
await news.posts.remove(post)
await post.categories.add(news)
await post.categories.clear()
assert len(await post.categories.all()) == 0
await post.categories.add(news)
await news.delete()
assert len(await post.categories.all()) == 0
@pytest.mark.asyncio
async def test_selecting_related(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
guido2 = await Author.objects.create(
first_name="Guido2", last_name="Van Rossum"
)
post = await Post.objects.create(title="Hello, M2M", author=guido)
post2 = await Post.objects.create(title="Bye, M2M", author=guido2)
news = await Category.objects.create(name="News")
recent = await Category.objects.create(name="Recent")
await post.categories.add(news)
await post.categories.add(recent)
await post2.categories.add(recent)
assert len(await post.categories.all()) == 2
assert (await post.categories.limit(1).all())[0] == news
assert (await post.categories.offset(1).limit(1).all())[0] == recent
assert await post.categories.first() == news
assert await post.categories.exists()
# still can order
categories = (
await Category.objects.select_related("posts")
.order_by("posts__title")
.all()
)
assert categories[0].name == "Recent"
assert categories[1].name == "News"
# still can filter
categories = await Category.objects.filter(posts__title="Bye, M2M").all()
assert categories[0].name == "Recent"
assert len(categories) == 1
# same for reverse fk
authors = (
await Author.objects.select_related("posts").order_by("posts__title").all()
)
assert authors[0].first_name == "Guido2"
assert authors[1].first_name == "Guido"
authors = await Author.objects.filter(posts__title="Bye, M2M").all()
assert authors[0].first_name == "Guido2"
assert len(authors) == 1