Merge pull request #21 from collerek/fix_json_schema

Fix json schema generation
This commit is contained in:
collerek
2020-10-28 00:28:22 +07:00
committed by GitHub
19 changed files with 310 additions and 79 deletions

BIN
.coverage

Binary file not shown.

1
.gitignore vendored
View File

@ -2,6 +2,7 @@ p38venv
.idea .idea
.pytest_cache .pytest_cache
.mypy_cache .mypy_cache
.coverage
*.pyc *.pyc
*.log *.log
test.db test.db

View File

@ -45,7 +45,7 @@ Those models will be used insted of pydantic ones.
Define your desired endpoints, note how `ormar` models are used both Define your desired endpoints, note how `ormar` models are used both
as `response_model` and as a requests parameters. as `response_model` and as a requests parameters.
```python hl_lines="50-77" ```python hl_lines="50-79"
--8<-- "../docs_src/fastapi/docs001.py" --8<-- "../docs_src/fastapi/docs001.py"
``` ```
@ -57,6 +57,23 @@ as `response_model` and as a requests parameters.
## Test the application ## Test the application
### Run fastapi
If you want to run this script and play with fastapi swagger install uvicorn first
`pip install uvicorn`
And launch the fastapi.
`uvicorn <filename_without_extension>:app --reload`
Now you can navigate to your browser (by default fastapi address is `127.0.0.1:8000/docs`) and play with the api.
!!!info
You can read more about running fastapi in [fastapi][fastapi] docs.
### Test with pytest
Here you have a sample test that will prove that everything works as intended. Here you have a sample test that will prove that everything works as intended.
Be sure to create the tables first. If you are using pytest you can use a fixture. Be sure to create the tables first. If you are using pytest you can use a fixture.
@ -109,9 +126,13 @@ def test_all_endpoints():
assert len(items) == 0 assert len(items) == 0
``` ```
!!!tip
If you want to see more test cases and how to test ormar/fastapi see [tests][tests] directory in the github repo
!!!info !!!info
You can read more on testing fastapi in [fastapi][fastapi] docs. You can read more on testing fastapi in [fastapi][fastapi] docs.
[fastapi]: https://fastapi.tiangolo.com/ [fastapi]: https://fastapi.tiangolo.com/
[models]: ./models.md [models]: ./models.md
[database initialization]: ../models/#database-initialization-migrations [database initialization]: ../models/#database-initialization-migrations
[tests]: https://github.com/collerek/ormar/tree/master/tests

View File

@ -1,3 +1,10 @@
# 0.3.9
* Fix json schema generation as of [#19][#19]
* Fix for not initialized ManyToMany relations in fastapi copies of ormar.Models
* Update docs in regard of fastapi use
* Add tests to verify fastapi/docs proper generation
# 0.3.8 # 0.3.8
* Added possibility to provide alternative database column names with name parameter to all fields. * Added possibility to provide alternative database column names with name parameter to all fields.
@ -43,3 +50,6 @@ Add queryset level methods
# 0.3.0 # 0.3.0
* Added ManyToMany field and support for many to many relations * Added ManyToMany field and support for many to many relations
[#19]: https://github.com/collerek/ormar/issues/19

View File

@ -72,6 +72,8 @@ async def get_item(item_id: int, item: Item):
@app.delete("/items/{item_id}") @app.delete("/items/{item_id}")
async def delete_item(item_id: int, item: Item): async def delete_item(item_id: int, item: Item = None):
if item:
return {"deleted_rows": await item.delete()}
item_db = await Item.objects.get(pk=item_id) item_db = await Item.objects.get(pk=item_id)
return {"deleted_rows": await item_db.delete()} return {"deleted_rows": await item_db.delete()}

View File

@ -12,9 +12,9 @@ nav:
- Release Notes: releases.md - Release Notes: releases.md
repo_name: collerek/ormar repo_name: collerek/ormar
repo_url: https://github.com/collerek/ormar repo_url: https://github.com/collerek/ormar
google_analytics: #google_analytics:
- UA-72514911-3 # - UA-72514911-3
- auto # - auto
theme: theme:
name: material name: material
highlightjs: true highlightjs: true

View File

@ -28,7 +28,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.3.8" __version__ = "0.3.9"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -58,6 +58,7 @@ def ForeignKey( # noqa CFQ002
pydantic_only=False, pydantic_only=False,
default=None, default=None,
server_default=None, server_default=None,
__pydantic_model__=to,
) )
return type("ForeignKey", (ForeignKeyField, BaseField), namespace) return type("ForeignKey", (ForeignKeyField, BaseField), namespace)

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Type from typing import Dict, TYPE_CHECKING, Type
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
@ -6,6 +6,8 @@ from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model
REF_PREFIX = "#/components/schemas/"
def ManyToMany( def ManyToMany(
to: Type["Model"], to: Type["Model"],
@ -31,6 +33,9 @@ def ManyToMany(
pydantic_only=False, pydantic_only=False,
default=None, default=None,
server_default=None, server_default=None,
__pydantic_model__=to,
# __origin__=List,
# __args__=[Optional[to]]
) )
return type("ManyToMany", (ManyToManyField, BaseField), namespace) return type("ManyToMany", (ManyToManyField, BaseField), namespace)
@ -38,3 +43,10 @@ def ManyToMany(
class ManyToManyField(ForeignKeyField): class ManyToManyField(ForeignKeyField):
through: Type["Model"] through: Type["Model"]
@classmethod
def __modify_schema__(cls, field_schema: Dict) -> None:
field_schema["type"] = "array"
field_schema["title"] = cls.name.title()
field_schema["definitions"] = {f"{cls.to.__name__}": cls.to.schema()}
field_schema["items"] = {"$ref": f"{REF_PREFIX}{cls.to.__name__}"}

View File

@ -145,7 +145,7 @@ class Model(NewBaseModel):
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
self_fields = self.objects._translate_columns_to_aliases(self_fields) self_fields = self.translate_columns_to_aliases(self_fields)
expr = self.Meta.table.update().values(**self_fields) expr = self.Meta.table.update().values(**self_fields)
expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname))
@ -166,6 +166,6 @@ class Model(NewBaseModel):
"Instance was deleted from database and cannot be refreshed" "Instance was deleted from database and cannot be refreshed"
) )
kwargs = dict(row) kwargs = dict(row)
kwargs = self.objects._translate_aliases_to_columns(kwargs) kwargs = self.translate_aliases_to_columns(kwargs)
self.from_dict(kwargs) self.from_dict(kwargs)
return self return self

View File

@ -5,7 +5,7 @@ import ormar
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
from ormar.fields import BaseField, ManyToManyField from ormar.fields import BaseField, ManyToManyField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta, expand_reverse_relationships
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -112,9 +112,10 @@ class ModelTableProxy:
return self_fields return self_fields
@staticmethod @staticmethod
def resolve_relation_name( def resolve_relation_name( # noqa CCR001
item: Union["NewBaseModel", Type["NewBaseModel"]], item: Union["NewBaseModel", Type["NewBaseModel"]],
related: Union["NewBaseModel", Type["NewBaseModel"]], related: Union["NewBaseModel", Type["NewBaseModel"]],
register_missing: bool = True,
) -> str: ) -> str:
for name, field in item.Meta.model_fields.items(): for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField): if issubclass(field, ForeignKeyField):
@ -123,6 +124,13 @@ class ModelTableProxy:
# so we need to compare Meta too as this one is copied as is # so we need to compare Meta too as this one is copied as is
if field.to == related.__class__ or field.to.Meta == related.Meta: if field.to == related.__class__ or field.to.Meta == related.Meta:
return name return name
# fallback for not registered relation
if register_missing:
expand_reverse_relationships(related.__class__) # type: ignore
return ModelTableProxy.resolve_relation_name(
item, related, register_missing=False
)
raise ValueError( raise ValueError(
f"No relation between {item.get_name()} and {related.get_name()}" f"No relation between {item.get_name()} and {related.get_name()}"
) # pragma nocover ) # pragma nocover
@ -140,6 +148,24 @@ class ModelTableProxy:
) )
return to_field return to_field
@classmethod
def translate_columns_to_aliases(cls, new_kwargs: Dict) -> Dict:
for field_name, field in cls.Meta.model_fields.items():
if (
field_name in new_kwargs
and field.name is not None
and field.name != field_name
):
new_kwargs[field.name] = new_kwargs.pop(field_name)
return new_kwargs
@classmethod
def translate_aliases_to_columns(cls, new_kwargs: Dict) -> Dict:
for field_name, field in cls.Meta.model_fields.items():
if field.name in new_kwargs and field.name != field_name:
new_kwargs[field_name] = new_kwargs.pop(field.name)
return new_kwargs
@classmethod @classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows: List["Model"] = [] merged_rows: List["Model"] = []

View File

@ -100,7 +100,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
) )
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
if name in self.__slots__: if name in ("_orm_id", "_orm_saved", "_orm"):
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
elif name == "pk": elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
@ -137,7 +137,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
alias = self.get_column_alias(item) alias = self.get_column_alias(item)
if alias in self._orm: if alias in self._orm:
return self._orm.get(alias) return self._orm.get(alias)
return None return None # pragma no cover
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if isinstance(other, NewBaseModel): if isinstance(other, NewBaseModel):

View File

@ -74,7 +74,7 @@ class QuerySet:
new_kwargs = self._remove_pk_from_kwargs(new_kwargs) new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs) new_kwargs = self._populate_default_values(new_kwargs)
new_kwargs = self._translate_columns_to_aliases(new_kwargs) new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
return new_kwargs return new_kwargs
def _populate_default_values(self, new_kwargs: dict) -> dict: def _populate_default_values(self, new_kwargs: dict) -> dict:
@ -83,22 +83,6 @@ class QuerySet:
new_kwargs[field_name] = field.get_default() new_kwargs[field_name] = field.get_default()
return new_kwargs return new_kwargs
def _translate_columns_to_aliases(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_meta.model_fields.items():
if (
field_name in new_kwargs
and field.name is not None
and field.name != field_name
):
new_kwargs[field.name] = new_kwargs.pop(field_name)
return new_kwargs
def _translate_aliases_to_columns(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_meta.model_fields.items():
if field.name in new_kwargs and field.name != field_name:
new_kwargs[field_name] = new_kwargs.pop(field.name)
return new_kwargs
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
pkname = self.model_meta.pkname pkname = self.model_meta.pkname
pk = self.model_meta.model_fields[pkname] pk = self.model_meta.model_fields[pkname]
@ -207,7 +191,7 @@ class QuerySet:
async def update(self, each: bool = False, **kwargs: Any) -> int: async def update(self, each: bool = False, **kwargs: Any) -> int:
self_fields = self.model.extract_db_own_fields() self_fields = self.model.extract_db_own_fields()
updates = {k: v for k, v in kwargs.items() if k in self_fields} updates = {k: v for k, v in kwargs.items() if k in self_fields}
updates = self._translate_columns_to_aliases(updates) updates = self.model.translate_columns_to_aliases(updates)
if not each and not self.filter_clauses: if not each and not self.filter_clauses:
raise QueryDefinitionError( raise QueryDefinitionError(
"You cannot update without filtering the queryset first. " "You cannot update without filtering the queryset first. "
@ -353,7 +337,7 @@ class QuerySet:
f"{self.model.__name__} has to have {pk_name} filled." f"{self.model.__name__} has to have {pk_name} filled."
) )
new_kwargs = self.model.substitute_models_with_pks(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._translate_columns_to_aliases(new_kwargs) new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
ready_objects.append(new_kwargs) ready_objects.append(new_kwargs)

View File

@ -40,6 +40,8 @@ class RelationsManager:
to=field.to, to=field.to,
through=getattr(field, "through", None), through=getattr(field, "through", None),
) )
if field.name not in self._related_names:
self._related_names.append(field.name)
def __contains__(self, item: str) -> bool: def __contains__(self, item: str) -> bool:
return item in self._related_names return item in self._related_names

View File

@ -72,4 +72,6 @@ class RelationProxy(list):
if self.relation._type == ormar.RelationType.MULTIPLE: if self.relation._type == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item) await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner) rel_name = item.resolve_relation_name(item, self._owner)
if rel_name not in item._orm:
item._orm._add_relation(item.Meta.model_fields[rel_name])
setattr(item, rel_name, self._owner) setattr(item, rel_name, self._owner)

View File

@ -15,10 +15,10 @@ class Child(ormar.Model):
metadata = metadata metadata = metadata
database = database database = database
id: ormar.Integer(name='child_id', primary_key=True) id: ormar.Integer(name="child_id", primary_key=True)
first_name: ormar.String(name='fname', max_length=100) first_name: ormar.String(name="fname", max_length=100)
last_name: ormar.String(name='lname', max_length=100) last_name: ormar.String(name="lname", max_length=100)
born_year: ormar.Integer(name='year_born', nullable=True) born_year: ormar.Integer(name="year_born", nullable=True)
class ArtistChildren(ormar.Model): class ArtistChildren(ormar.Model):
@ -34,10 +34,10 @@ class Artist(ormar.Model):
metadata = metadata metadata = metadata
database = database database = database
id: ormar.Integer(name='artist_id', primary_key=True) id: ormar.Integer(name="artist_id", primary_key=True)
first_name: ormar.String(name='fname', max_length=100) first_name: ormar.String(name="fname", max_length=100)
last_name: ormar.String(name='lname', max_length=100) last_name: ormar.String(name="lname", max_length=100)
born_year: ormar.Integer(name='year') born_year: ormar.Integer(name="year")
children: ormar.ManyToMany(Child, through=ArtistChildren) children: ormar.ManyToMany(Child, through=ArtistChildren)
@ -47,9 +47,9 @@ class Album(ormar.Model):
metadata = metadata metadata = metadata
database = database database = database
id: ormar.Integer(name='album_id', primary_key=True) id: ormar.Integer(name="album_id", primary_key=True)
name: ormar.String(name='album_name', max_length=100) name: ormar.String(name="album_name", max_length=100)
artist: ormar.ForeignKey(Artist, name='artist_id') artist: ormar.ForeignKey(Artist, name="artist_id")
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -62,70 +62,87 @@ def create_test_database():
def test_table_structure(): def test_table_structure():
assert 'album_id' in [x.name for x in Album.Meta.table.columns] assert "album_id" in [x.name for x in Album.Meta.table.columns]
assert 'album_name' in [x.name for x in Album.Meta.table.columns] assert "album_name" in [x.name for x in Album.Meta.table.columns]
assert 'fname' in [x.name for x in Artist.Meta.table.columns] assert "fname" in [x.name for x in Artist.Meta.table.columns]
assert 'lname' in [x.name for x in Artist.Meta.table.columns] assert "lname" in [x.name for x in Artist.Meta.table.columns]
assert 'year' in [x.name for x in Artist.Meta.table.columns] assert "year" in [x.name for x in Artist.Meta.table.columns]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_working_with_aliases(): async def test_working_with_aliases():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
artist = await Artist.objects.create(first_name='Ted', last_name='Mosbey', born_year=1975) artist = await Artist.objects.create(
first_name="Ted", last_name="Mosbey", born_year=1975
)
await Album.objects.create(name="Aunt Robin", artist=artist) await Album.objects.create(name="Aunt Robin", artist=artist)
await artist.children.create(first_name='Son', last_name='1', born_year=1990) await artist.children.create(
await artist.children.create(first_name='Son', last_name='2', born_year=1995) first_name="Son", last_name="1", born_year=1990
)
await artist.children.create(
first_name="Son", last_name="2", born_year=1995
)
album = await Album.objects.select_related('artist').first() album = await Album.objects.select_related("artist").first()
assert album.artist.last_name == 'Mosbey' assert album.artist.last_name == "Mosbey"
assert album.artist.id is not None assert album.artist.id is not None
assert album.artist.first_name == 'Ted' assert album.artist.first_name == "Ted"
assert album.artist.born_year == 1975 assert album.artist.born_year == 1975
assert album.name == 'Aunt Robin' assert album.name == "Aunt Robin"
artist = await Artist.objects.select_related('children').get() artist = await Artist.objects.select_related("children").get()
assert len(artist.children) == 2 assert len(artist.children) == 2
assert artist.children[0].first_name == 'Son' assert artist.children[0].first_name == "Son"
assert artist.children[1].last_name == '2' assert artist.children[1].last_name == "2"
await artist.update(last_name='Bundy') await artist.update(last_name="Bundy")
await Artist.objects.filter(pk=artist.pk).update(born_year=1974) await Artist.objects.filter(pk=artist.pk).update(born_year=1974)
artist = await Artist.objects.select_related('children').get() artist = await Artist.objects.select_related("children").get()
assert artist.last_name == 'Bundy' assert artist.last_name == "Bundy"
assert artist.born_year == 1974 assert artist.born_year == 1974
artist = await Artist.objects.select_related('children').fields( artist = (
['first_name', 'last_name', 'born_year', 'child__first_name', 'child__last_name']).get() await Artist.objects.select_related("children")
.fields(
[
"first_name",
"last_name",
"born_year",
"child__first_name",
"child__last_name",
]
)
.get()
)
assert artist.children[0].born_year is None assert artist.children[0].born_year is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_operations_and_fields(): async def test_bulk_operations_and_fields():
async with database: async with database:
d1 = Child(first_name='Daughter', last_name='1', born_year=1990) d1 = Child(first_name="Daughter", last_name="1", born_year=1990)
d2 = Child(first_name='Daughter', last_name='2', born_year=1991) d2 = Child(first_name="Daughter", last_name="2", born_year=1991)
await Child.objects.bulk_create([d1, d2]) await Child.objects.bulk_create([d1, d2])
children = await Child.objects.filter(first_name='Daughter').all() children = await Child.objects.filter(first_name="Daughter").all()
assert len(children) == 2 assert len(children) == 2
assert children[0].last_name == '1' assert children[0].last_name == "1"
for child in children: for child in children:
child.born_year = child.born_year - 100 child.born_year = child.born_year - 100
await Child.objects.bulk_update(children) await Child.objects.bulk_update(children)
children = await Child.objects.filter(first_name='Daughter').all() children = await Child.objects.filter(first_name="Daughter").all()
assert len(children) == 2 assert len(children) == 2
assert children[0].born_year == 1890 assert children[0].born_year == 1890
children = await Child.objects.fields(['first_name', 'last_name']).all() children = await Child.objects.fields(["first_name", "last_name"]).all()
assert len(children) == 2 assert len(children) == 2
for child in children: for child in children:
assert child.born_year is None assert child.born_year is None
@ -140,17 +157,21 @@ async def test_bulk_operations_and_fields():
async def test_working_with_aliases_get_or_create(): async def test_working_with_aliases_get_or_create():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
artist = await Artist.objects.get_or_create(first_name='Teddy', last_name='Bear', born_year=2020) artist = await Artist.objects.get_or_create(
first_name="Teddy", last_name="Bear", born_year=2020
)
assert artist.pk is not None assert artist.pk is not None
artist2 = await Artist.objects.get_or_create(first_name='Teddy', last_name='Bear', born_year=2020) artist2 = await Artist.objects.get_or_create(
first_name="Teddy", last_name="Bear", born_year=2020
)
assert artist == artist2 assert artist == artist2
art3 = artist2.dict() art3 = artist2.dict()
art3['born_year'] = 2019 art3["born_year"] = 2019
await Artist.objects.update_or_create(**art3) await Artist.objects.update_or_create(**art3)
artist3 = await Artist.objects.get(last_name='Bear') artist3 = await Artist.objects.get(last_name="Bear")
assert artist3.born_year == 2019 assert artist3.born_year == 2019
artists = await Artist.objects.all() artists = await Artist.objects.all()

135
tests/test_fastapi_docs.py Normal file
View File

@ -0,0 +1,135 @@
from typing import List
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class LocalMeta:
metadata = metadata
database = database
class Category(ormar.Model):
class Meta(LocalMeta):
tablename = "categories"
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
class ItemsXCategories(ormar.Model):
class Meta(LocalMeta):
tablename = "items_x_categories"
class Item(ormar.Model):
class Meta(LocalMeta):
pass
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
categories: ormar.ManyToMany(Category, through=ItemsXCategories)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@app.get("/items/", response_model=List[Item])
async def get_items():
items = await Item.objects.select_related("categories").all()
return items
@app.post("/items/", response_model=Item)
async def create_item(item: Item):
await item.save()
return item
@app.post("/items/add_category/", response_model=Item)
async def create_item(item: Item, category: Category):
await item.categories.add(category)
return item
@app.post("/categories/", response_model=Category)
async def create_category(category: Category):
await category.save()
return category
def test_all_endpoints():
client = TestClient(app)
with client as client:
response = client.post("/categories/", json={"name": "test cat"})
category = response.json()
response = client.post("/categories/", json={"name": "test cat2"})
category2 = response.json()
response = client.post("/items/", json={"name": "test", "id": 1})
item = Item(**response.json())
assert item.pk is not None
response = client.post(
"/items/add_category/", json={"item": item.dict(), "category": category}
)
item = Item(**response.json())
assert len(item.categories) == 1
assert item.categories[0].name == "test cat"
client.post(
"/items/add_category/", json={"item": item.dict(), "category": category2}
)
response = client.get("/items/")
items = [Item(**item) for item in response.json()]
assert items[0] == item
assert len(items[0].categories) == 2
assert items[0].categories[0].name == "test cat"
assert items[0].categories[1].name == "test cat2"
response = client.get("/docs/")
assert response.status_code == 200
assert b"<title>FastAPI - Swagger UI</title>" in response.content
def test_schema_modification():
schema = Item.schema()
assert schema["properties"]["categories"]["type"] == "array"
assert schema["properties"]["categories"]["title"] == "Categories"
def test_schema_gen():
schema = app.openapi()
assert "Category" in schema["components"]["schemas"]
assert "Item" in schema["components"]["schemas"]

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import uuid
from datetime import datetime from datetime import datetime
from typing import List from typing import List
@ -6,7 +7,6 @@ import databases
import pydantic import pydantic
import pytest import pytest
import sqlalchemy import sqlalchemy
import uuid
import ormar import ormar
from ormar.exceptions import QueryDefinitionError, NoMatch from ormar.exceptions import QueryDefinitionError, NoMatch

View File

@ -77,13 +77,15 @@ async def create_category(category: Category):
@app.put("/items/{item_id}") @app.put("/items/{item_id}")
async def get_item(item_id: int, item: Item): async def update_item(item_id: int, item: Item):
item_db = await Item.objects.get(pk=item_id) item_db = await Item.objects.get(pk=item_id)
return await item_db.update(**item.dict()) return await item_db.update(**item.dict())
@app.delete("/items/{item_id}") @app.delete("/items/{item_id}")
async def delete_item(item_id: int, item: Item): async def delete_item(item_id: int, item: Item = None):
if item:
return {"deleted_rows": await item.delete()}
item_db = await Item.objects.get(pk=item_id) item_db = await Item.objects.get(pk=item_id)
return {"deleted_rows": await item_db.delete()} return {"deleted_rows": await item_db.delete()}
@ -111,8 +113,20 @@ def test_all_endpoints():
items = [Item(**item) for item in response.json()] items = [Item(**item) for item in response.json()]
assert items[0].name == "New name" assert items[0].name == "New name"
response = client.delete(f"/items/{item.pk}", json=item.dict()) response = client.delete(f"/items/{item.pk}")
assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__" assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
response = client.get("/items/") response = client.get("/items/")
items = response.json() items = response.json()
assert len(items) == 0 assert len(items) == 0
client.post("/items/", json={"name": "test_2", "id": 2, "category": category})
response = client.get("/items/")
items = response.json()
assert len(items) == 1
item = Item(**items[0])
response = client.delete(f"/items/{item.pk}", json=item.dict())
assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__"
response = client.get("/docs/")
assert response.status_code == 200