From b643c884ac7896efe5389154e5925246f360d9ae Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 24 May 2021 18:04:41 +0200 Subject: [PATCH] add fastapi tests for get_pydantic --- ormar/models/mixins/pydantic_mixin.py | 23 +++- .../test_excludes_with_get_pydantic.py | 117 ++++++++++++++++++ .../test_geting_the_pydantic_models.py | 47 ++++++- 3 files changed, 181 insertions(+), 6 deletions(-) create mode 100644 tests/test_fastapi/test_excludes_with_get_pydantic.py diff --git a/ormar/models/mixins/pydantic_mixin.py b/ormar/models/mixins/pydantic_mixin.py index 76af068..d9d5d7c 100644 --- a/ormar/models/mixins/pydantic_mixin.py +++ b/ormar/models/mixins/pydantic_mixin.py @@ -1,4 +1,15 @@ -from typing import Any, Callable, Dict, List, Set, TYPE_CHECKING, Type, Union, cast +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Type, + Union, + cast, +) import pydantic from pydantic.fields import ModelField @@ -78,6 +89,7 @@ class PydanticMixin(RelationMixin): relation_map: Dict[str, Any], ) -> Any: field = cls.Meta.model_fields[name] + target: Any = None if field.is_relation and name in relation_map: # type: ignore target = field.to._convert_ormar_to_pydantic( include=cls._skip_ellipsis(include, name), @@ -87,9 +99,10 @@ class PydanticMixin(RelationMixin): ), ) if field.is_multi or field.virtual: - return List[target] # type: ignore - return target + target = List[target] # type: ignore elif not field.is_relation: defaults[name] = cls.__fields__[name].field_info - return field.__type__ - return None + target = field.__type__ + if target is not None and field.nullable: + target = Optional[target] + return target diff --git a/tests/test_fastapi/test_excludes_with_get_pydantic.py b/tests/test_fastapi/test_excludes_with_get_pydantic.py new file mode 100644 index 0000000..88e0c23 --- /dev/null +++ b/tests/test_fastapi/test_excludes_with_get_pydantic.py @@ -0,0 +1,117 @@ +from typing import Type, TypeVar + +import pytest +import sqlalchemy +from fastapi import FastAPI +from pydantic import BaseModel +from starlette.testclient import TestClient + +from tests.settings import DATABASE_URL +from tests.test_inheritance_and_pydantic_generation.test_geting_the_pydantic_models import ( # type: ignore + Category, + Item, + SelfRef, + MutualA, + MutualB, + database, + metadata, +) + +app = FastAPI() +app.state.database = database + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@app.post("/categories/", response_model=Category) +async def create_category(category: Category.get_pydantic(exclude={"id"})): # type: ignore + return await Category(**category.dict()).save() + + +@app.post( + "/selfrefs/", + response_model=SelfRef.get_pydantic(exclude={"parent", "children__name"}), +) +async def create_selfref(selfref: SelfRef.get_pydantic(exclude={"children__name"})): # type: ignore + selfr = SelfRef(**selfref.dict()) + await selfr.save() + if selfr.children: + for child in selfr.children: + await child.upsert() + return selfr + + +@app.get("/selfrefs/{ref_id}/") +async def get_selfref(ref_id: int): + selfr = await SelfRef.objects.select_related("children").get(id=ref_id) + return selfr + + +def test_read_main(): + client = TestClient(app) + with client as client: + test_category = dict(name="Foo", id=12) + response = client.post("/categories/", json=test_category) + assert response.status_code == 200 + cat = Category(**response.json()) + assert cat.name == "Foo" + assert cat.id == 1 + assert cat.items == [] + + test_selfref = dict(id=1, name="test") + test_selfref2 = dict(id=2, name="test2", parent={"id": 1}) + test_selfref3 = dict(id=3, name="test3", children=[{"id": 1}]) + + response = client.post("/selfrefs/", json=test_selfref) + assert response.status_code == 200 + self_ref = SelfRef(**response.json()) + assert self_ref.id == 1 + assert self_ref.name == "test" + assert self_ref.parent is None + assert self_ref.children == [] + + response = client.post("/selfrefs/", json=test_selfref2) + assert response.status_code == 200 + self_ref = SelfRef(**response.json()) + assert self_ref.id == 2 + assert self_ref.name == "test2" + assert self_ref.parent is None + assert self_ref.children == [] + + response = client.post("/selfrefs/", json=test_selfref3) + assert response.status_code == 200 + self_ref = SelfRef(**response.json()) + assert self_ref.id == 3 + assert self_ref.name == "test3" + assert self_ref.parent is None + assert self_ref.children[0].dict() == {"id": 1} + + response = client.get("/selfrefs/3/") + assert response.status_code == 200 + check_children = SelfRef(**response.json()) + assert check_children.children[0].dict() == { + "children": [], + "id": 1, + "name": "test", + "parent": {"id": 3, "name": "test3"}, + } diff --git a/tests/test_inheritance_and_pydantic_generation/test_geting_the_pydantic_models.py b/tests/test_inheritance_and_pydantic_generation/test_geting_the_pydantic_models.py index 11623a0..06401b3 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_geting_the_pydantic_models.py +++ b/tests/test_inheritance_and_pydantic_generation/test_geting_the_pydantic_models.py @@ -23,7 +23,7 @@ class SelfRef(ormar.Model): tablename = "self_refs" id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(max_length=100) + name: str = ormar.String(max_length=100, default="selfref") parent = ormar.ForeignKey(ForwardRef("SelfRef"), related_name="children") @@ -47,6 +47,25 @@ class Item(ormar.Model): category: Optional[Category] = ormar.ForeignKey(Category, nullable=True) +class MutualA(ormar.Model): + class Meta(BaseMeta): + tablename = "mutual_a" + + id: int = ormar.Integer(primary_key=True) + mutual_b = ormar.ForeignKey(ForwardRef("MutualB"), related_name="mutuals_a") + + +class MutualB(ormar.Model): + class Meta(BaseMeta): + tablename = "mutual_b" + + id: int = ormar.Integer(primary_key=True) + mutual_a = ormar.ForeignKey(MutualA, related_name="mutuals_b") + + +MutualA.update_forward_refs() + + def test_getting_pydantic_model(): PydanticCategory = Category.get_pydantic() assert issubclass(PydanticCategory, pydantic.BaseModel) @@ -79,6 +98,10 @@ def test_initializing_pydantic_model(): cat = PydanticCategory(**data) assert cat.dict() == data + data = {"id": 1, "name": "test"} + cat = PydanticCategory(**data) + assert cat.dict() == {**data, "items": None} + def test_getting_pydantic_model_include(): PydanticCategory = Category.get_pydantic(include={"id", "name"}) @@ -155,3 +178,25 @@ def test_getting_pydantic_model_self_ref(): InnerSelf = PydanticSelfRef.__fields__["parent"].type_ assert len(InnerSelf.__fields__) == 2 assert set(InnerSelf.__fields__.keys()) == {"id", "name"} + + InnerSelf2 = PydanticSelfRef.__fields__["children"].type_ + assert len(InnerSelf2.__fields__) == 2 + assert set(InnerSelf2.__fields__.keys()) == {"id", "name"} + + assert InnerSelf2 != InnerSelf + + +def test_getting_pydantic_model_mutual_rels(): + MutualAPydantic = MutualA.get_pydantic() + assert len(MutualAPydantic.__fields__) == 3 + assert set(MutualAPydantic.__fields__.keys()) == {"id", "mutual_b", "mutuals_b"} + + MutualB1 = MutualAPydantic.__fields__["mutual_b"].type_ + MutualB2 = MutualAPydantic.__fields__["mutuals_b"].type_ + assert MutualB1 != MutualB2 + + assert len(MutualB1.__fields__) == 1 + assert "id" in MutualB1.__fields__ + + assert len(MutualB2.__fields__) == 1 + assert "id" in MutualB2.__fields__