fix many_to_many lazy registration in fastapi cloned models, fixed ForeignKey not treated as subclasses of BaseModels in json schema
This commit is contained in:
@ -58,6 +58,7 @@ def ForeignKey( # noqa CFQ002
|
|||||||
pydantic_only=False,
|
pydantic_only=False,
|
||||||
default=None,
|
default=None,
|
||||||
server_default=None,
|
server_default=None,
|
||||||
|
__pydantic_model__=to,
|
||||||
)
|
)
|
||||||
|
|
||||||
return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
|
return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import ormar
|
|||||||
from ormar.exceptions import RelationshipInstanceError
|
from ormar.exceptions import RelationshipInstanceError
|
||||||
from ormar.fields import BaseField, ManyToManyField
|
from ormar.fields import BaseField, ManyToManyField
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
from ormar.models.metaclass import ModelMeta
|
from ormar.models.metaclass import ModelMeta, expand_reverse_relationships
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -115,6 +115,7 @@ class ModelTableProxy:
|
|||||||
def resolve_relation_name(
|
def resolve_relation_name(
|
||||||
item: Union["NewBaseModel", Type["NewBaseModel"]],
|
item: Union["NewBaseModel", Type["NewBaseModel"]],
|
||||||
related: Union["NewBaseModel", Type["NewBaseModel"]],
|
related: Union["NewBaseModel", Type["NewBaseModel"]],
|
||||||
|
register_missing: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
for name, field in item.Meta.model_fields.items():
|
for name, field in item.Meta.model_fields.items():
|
||||||
if issubclass(field, ForeignKeyField):
|
if issubclass(field, ForeignKeyField):
|
||||||
@ -123,6 +124,11 @@ class ModelTableProxy:
|
|||||||
# so we need to compare Meta too as this one is copied as is
|
# so we need to compare Meta too as this one is copied as is
|
||||||
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
||||||
return name
|
return name
|
||||||
|
# fallback for not registered relation
|
||||||
|
if register_missing:
|
||||||
|
expand_reverse_relationships(related.__class__)
|
||||||
|
return ModelTableProxy.resolve_relation_name(item, related, register_missing=False)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No relation between {item.get_name()} and {related.get_name()}"
|
f"No relation between {item.get_name()} and {related.get_name()}"
|
||||||
) # pragma nocover
|
) # pragma nocover
|
||||||
@ -204,7 +210,7 @@ class ModelTableProxy:
|
|||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
model_name = f"{model.get_name()}__"
|
model_name = f"{model.get_name()}__"
|
||||||
columns = [
|
columns = [
|
||||||
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
|
name[(name.find(model_name) + len(model_name)):] # noqa: E203
|
||||||
for name in fields
|
for name in fields
|
||||||
if f"{model.get_name()}__" in name
|
if f"{model.get_name()}__" in name
|
||||||
]
|
]
|
||||||
|
|||||||
@ -137,7 +137,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
alias = self.get_column_alias(item)
|
alias = self.get_column_alias(item)
|
||||||
if alias in self._orm:
|
if alias in self._orm:
|
||||||
return self._orm.get(alias)
|
return self._orm.get(alias)
|
||||||
return None
|
return None # pragma no cover
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if isinstance(other, NewBaseModel):
|
if isinstance(other, NewBaseModel):
|
||||||
|
|||||||
@ -40,6 +40,8 @@ class RelationsManager:
|
|||||||
to=field.to,
|
to=field.to,
|
||||||
through=getattr(field, "through", None),
|
through=getattr(field, "through", None),
|
||||||
)
|
)
|
||||||
|
if field.name not in self._related_names:
|
||||||
|
self._related_names.append(field.name)
|
||||||
|
|
||||||
def __contains__(self, item: str) -> bool:
|
def __contains__(self, item: str) -> bool:
|
||||||
return item in self._related_names
|
return item in self._related_names
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from ormar.relations.querysetproxy import QuerysetProxy
|
|||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
from ormar.relations import Relation
|
from ormar.relations import Relation, register_missing_relation
|
||||||
from ormar.queryset import QuerySet
|
from ormar.queryset import QuerySet
|
||||||
|
|
||||||
|
|
||||||
@ -72,4 +72,6 @@ class RelationProxy(list):
|
|||||||
if self.relation._type == ormar.RelationType.MULTIPLE:
|
if self.relation._type == ormar.RelationType.MULTIPLE:
|
||||||
await self.queryset_proxy.create_through_instance(item)
|
await self.queryset_proxy.create_through_instance(item)
|
||||||
rel_name = item.resolve_relation_name(item, self._owner)
|
rel_name = item.resolve_relation_name(item, self._owner)
|
||||||
|
if not rel_name in item._orm:
|
||||||
|
item._orm._add_relation(item.Meta.model_fields[rel_name])
|
||||||
setattr(item, rel_name, self._owner)
|
setattr(item, rel_name, self._owner)
|
||||||
|
|||||||
125
tests/test_fastapi_docs.py
Normal file
125
tests/test_fastapi_docs.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import databases
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
import ormar
|
||||||
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
|
app.state.database = database
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup() -> None:
|
||||||
|
database_ = app.state.database
|
||||||
|
if not database_.is_connected:
|
||||||
|
await database_.connect()
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown() -> None:
|
||||||
|
database_ = app.state.database
|
||||||
|
if database_.is_connected:
|
||||||
|
await database_.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
class LocalMeta:
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
|
||||||
|
class Category(ormar.Model):
|
||||||
|
class Meta(LocalMeta):
|
||||||
|
tablename = "categories"
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class ItemsXCategories(ormar.Model):
|
||||||
|
class Meta(LocalMeta):
|
||||||
|
tablename = 'items_x_categories'
|
||||||
|
|
||||||
|
|
||||||
|
class Item(ormar.Model):
|
||||||
|
class Meta(LocalMeta):
|
||||||
|
pass
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=100)
|
||||||
|
categories: ormar.ManyToMany(Category, through=ItemsXCategories)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
|
def create_test_database():
|
||||||
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||||
|
metadata.create_all(engine)
|
||||||
|
yield
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/", response_model=List[Item])
|
||||||
|
async def get_items():
|
||||||
|
items = await Item.objects.select_related("categories").all()
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/items/", response_model=Item)
|
||||||
|
async def create_item(item: Item):
|
||||||
|
await item.save()
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/items/add_category/", response_model=Item)
|
||||||
|
async def create_item(item: Item, category: Category):
|
||||||
|
await item.categories.add(category)
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/categories/", response_model=Category)
|
||||||
|
async def create_category(category: Category):
|
||||||
|
await category.save()
|
||||||
|
return category
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_endpoints():
|
||||||
|
client = TestClient(app)
|
||||||
|
with client as client:
|
||||||
|
response = client.post("/categories/", json={"name": "test cat"})
|
||||||
|
category = response.json()
|
||||||
|
response = client.post("/categories/", json={"name": "test cat2"})
|
||||||
|
category2 = response.json()
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/items/", json={"name": "test", "id": 1}
|
||||||
|
)
|
||||||
|
item = Item(**response.json())
|
||||||
|
assert item.pk is not None
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/items/add_category/", json={"item": item.dict(), "category": category}
|
||||||
|
)
|
||||||
|
item = Item(**response.json())
|
||||||
|
assert len(item.categories) == 1
|
||||||
|
assert item.categories[0].name == 'test cat'
|
||||||
|
|
||||||
|
client.post(
|
||||||
|
"/items/add_category/", json={"item": item.dict(), "category": category2}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/items/")
|
||||||
|
items = [Item(**item) for item in response.json()]
|
||||||
|
assert items[0] == item
|
||||||
|
assert len(items[0].categories) == 2
|
||||||
|
assert items[0].categories[0].name == 'test cat'
|
||||||
|
assert items[0].categories[1].name == 'test cat2'
|
||||||
|
|
||||||
|
response = client.get("/docs/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b'<title>FastAPI - Swagger UI</title>' in response.content
|
||||||
@ -77,13 +77,15 @@ async def create_category(category: Category):
|
|||||||
|
|
||||||
|
|
||||||
@app.put("/items/{item_id}")
|
@app.put("/items/{item_id}")
|
||||||
async def get_item(item_id: int, item: Item):
|
async def update_item(item_id: int, item: Item):
|
||||||
item_db = await Item.objects.get(pk=item_id)
|
item_db = await Item.objects.get(pk=item_id)
|
||||||
return await item_db.update(**item.dict())
|
return await item_db.update(**item.dict())
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/items/{item_id}")
|
@app.delete("/items/{item_id}")
|
||||||
async def delete_item(item_id: int, item: Item):
|
async def delete_item(item_id: int, item: Item = None):
|
||||||
|
if item:
|
||||||
|
return {"deleted_rows": await item.delete()}
|
||||||
item_db = await Item.objects.get(pk=item_id)
|
item_db = await Item.objects.get(pk=item_id)
|
||||||
return {"deleted_rows": await item_db.delete()}
|
return {"deleted_rows": await item_db.delete()}
|
||||||
|
|
||||||
@ -111,8 +113,23 @@ def test_all_endpoints():
|
|||||||
items = [Item(**item) for item in response.json()]
|
items = [Item(**item) for item in response.json()]
|
||||||
assert items[0].name == "New name"
|
assert items[0].name == "New name"
|
||||||
|
|
||||||
response = client.delete(f"/items/{item.pk}", json=item.dict())
|
response = client.delete(f"/items/{item.pk}")
|
||||||
assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
|
assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
|
||||||
response = client.get("/items/")
|
response = client.get("/items/")
|
||||||
items = response.json()
|
items = response.json()
|
||||||
assert len(items) == 0
|
assert len(items) == 0
|
||||||
|
|
||||||
|
client.post(
|
||||||
|
"/items/", json={"name": "test_2", "id": 2, "category": category}
|
||||||
|
)
|
||||||
|
response = client.get("/items/")
|
||||||
|
items = response.json()
|
||||||
|
assert len(items) == 1
|
||||||
|
|
||||||
|
item = Item(**items[0])
|
||||||
|
response = client.delete(f"/items/{item.pk}", json=item.dict())
|
||||||
|
assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
|
||||||
|
|
||||||
|
response = client.get("/docs/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user