fix many_to_many lazy registration in fastapi cloned models, fixed ForeignKey not treated as subclasses of BaseModels in json schema

This commit is contained in:
collerek
2020-10-27 13:49:07 +01:00
parent 36300f9056
commit d3091c404f
8 changed files with 232 additions and 79 deletions

View File

@ -58,6 +58,7 @@ def ForeignKey( # noqa CFQ002
pydantic_only=False,
default=None,
server_default=None,
__pydantic_model__=to,
)
return type("ForeignKey", (ForeignKeyField, BaseField), namespace)

View File

@ -5,7 +5,7 @@ import ormar
from ormar.exceptions import RelationshipInstanceError
from ormar.fields import BaseField, ManyToManyField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta
from ormar.models.metaclass import ModelMeta, expand_reverse_relationships
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -115,6 +115,7 @@ class ModelTableProxy:
def resolve_relation_name(
item: Union["NewBaseModel", Type["NewBaseModel"]],
related: Union["NewBaseModel", Type["NewBaseModel"]],
register_missing: bool = True
) -> str:
for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField):
@ -123,6 +124,11 @@ class ModelTableProxy:
# so we need to compare Meta too as this one is copied as is
if field.to == related.__class__ or field.to.Meta == related.Meta:
return name
# fallback for not registered relation
if register_missing:
expand_reverse_relationships(related.__class__)
return ModelTableProxy.resolve_relation_name(item, related, register_missing=False)
raise ValueError(
f"No relation between {item.get_name()} and {related.get_name()}"
) # pragma nocover

View File

@ -137,7 +137,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
alias = self.get_column_alias(item)
if alias in self._orm:
return self._orm.get(alias)
return None
return None # pragma no cover
def __eq__(self, other: object) -> bool:
if isinstance(other, NewBaseModel):

View File

@ -40,6 +40,8 @@ class RelationsManager:
to=field.to,
through=getattr(field, "through", None),
)
if field.name not in self._related_names:
self._related_names.append(field.name)
def __contains__(self, item: str) -> bool:
return item in self._related_names

View File

@ -6,7 +6,7 @@ from ormar.relations.querysetproxy import QuerysetProxy
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.relations import Relation
from ormar.relations import Relation, register_missing_relation
from ormar.queryset import QuerySet
@ -72,4 +72,6 @@ class RelationProxy(list):
if self.relation._type == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner)
if not rel_name in item._orm:
item._orm._add_relation(item.Meta.model_fields[rel_name])
setattr(item, rel_name, self._owner)

125
tests/test_fastapi_docs.py Normal file
View File

@ -0,0 +1,125 @@
from typing import List
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class LocalMeta:
metadata = metadata
database = database
class Category(ormar.Model):
class Meta(LocalMeta):
tablename = "categories"
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
class ItemsXCategories(ormar.Model):
class Meta(LocalMeta):
tablename = 'items_x_categories'
class Item(ormar.Model):
class Meta(LocalMeta):
pass
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
categories: ormar.ManyToMany(Category, through=ItemsXCategories)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@app.get("/items/", response_model=List[Item])
async def get_items():
items = await Item.objects.select_related("categories").all()
return items
@app.post("/items/", response_model=Item)
async def create_item(item: Item):
await item.save()
return item
@app.post("/items/add_category/", response_model=Item)
async def create_item(item: Item, category: Category):
await item.categories.add(category)
return item
@app.post("/categories/", response_model=Category)
async def create_category(category: Category):
await category.save()
return category
def test_all_endpoints():
client = TestClient(app)
with client as client:
response = client.post("/categories/", json={"name": "test cat"})
category = response.json()
response = client.post("/categories/", json={"name": "test cat2"})
category2 = response.json()
response = client.post(
"/items/", json={"name": "test", "id": 1}
)
item = Item(**response.json())
assert item.pk is not None
response = client.post(
"/items/add_category/", json={"item": item.dict(), "category": category}
)
item = Item(**response.json())
assert len(item.categories) == 1
assert item.categories[0].name == 'test cat'
client.post(
"/items/add_category/", json={"item": item.dict(), "category": category2}
)
response = client.get("/items/")
items = [Item(**item) for item in response.json()]
assert items[0] == item
assert len(items[0].categories) == 2
assert items[0].categories[0].name == 'test cat'
assert items[0].categories[1].name == 'test cat2'
response = client.get("/docs/")
assert response.status_code == 200
assert b'<title>FastAPI - Swagger UI</title>' in response.content

View File

@ -77,13 +77,15 @@ async def create_category(category: Category):
@app.put("/items/{item_id}")
async def get_item(item_id: int, item: Item):
async def update_item(item_id: int, item: Item):
item_db = await Item.objects.get(pk=item_id)
return await item_db.update(**item.dict())
@app.delete("/items/{item_id}")
async def delete_item(item_id: int, item: Item):
async def delete_item(item_id: int, item: Item = None):
if item:
return {"deleted_rows": await item.delete()}
item_db = await Item.objects.get(pk=item_id)
return {"deleted_rows": await item_db.delete()}
@ -111,8 +113,23 @@ def test_all_endpoints():
items = [Item(**item) for item in response.json()]
assert items[0].name == "New name"
response = client.delete(f"/items/{item.pk}", json=item.dict())
response = client.delete(f"/items/{item.pk}")
assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
response = client.get("/items/")
items = response.json()
assert len(items) == 0
client.post(
"/items/", json={"name": "test_2", "id": 2, "category": category}
)
response = client.get("/items/")
items = response.json()
assert len(items) == 1
item = Item(**items[0])
response = client.delete(f"/items/{item.pk}", json=item.dict())
assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
response = client.get("/docs/")
assert response.status_code == 200