diff --git a/ormar/__init__.py b/ormar/__init__.py index e5e9262..f3cdedb 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.4.1" +__version__ = "0.4.2" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 26334b4..d2ffb44 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,6 +1,7 @@ from typing import Any, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy +from pydantic import BaseModel, create_model from sqlalchemy import UniqueConstraint import ormar # noqa I101 @@ -9,6 +10,7 @@ from ormar.fields.base import BaseField if TYPE_CHECKING: # pragma no cover from ormar.models import Model, NewBaseModel + from ormar.fields import ManyToManyField def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": @@ -23,6 +25,15 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": return fk(**init_dict) +def create_dummy_model( + base_model: Type["Model"], + pk_field: Type[Union[BaseField, "ForeignKeyField", "ManyToManyField"]], +) -> Type["BaseModel"]: + fields = {f"{pk_field.name}": (pk_field.__type__, None)} + dummy_model = create_model(f"PkOnly{base_model.get_name(lower=False)}", **fields) # type: ignore + return dummy_model + + class UniqueColumns(UniqueConstraint): pass @@ -40,10 +51,11 @@ def ForeignKey( # noqa CFQ002 ) -> Any: fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname) to_field = to.Meta.model_fields[to.Meta.pkname] + pk_only_model = create_dummy_model(to, to_field) __type__ = ( - Union[to_field.__type__, to] + Union[to_field.__type__, to, pk_only_model] if not nullable - else Optional[Union[to_field.__type__, to]] + else Optional[Union[to_field.__type__, to, pk_only_model]] ) namespace = dict( __type__=__type__, diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index fc3b1f7..45dab4a 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -1,3 +1,4 @@ +import asyncio from typing import List, Optional import databases @@ -64,6 +65,12 @@ async def get_items(): return items +@app.get("/items/raw/", response_model=List[Item]) +async def get_raw_items(): + items = await Item.objects.all() + return items + + @app.post("/items/", response_model=Item) async def create_item(item: Item): await item.save() @@ -76,6 +83,12 @@ async def create_category(category: Category): return category +@app.get("/items/{item_id}") +async def get_item(item_id: int): + item = await Item.objects.get(pk=item_id) + return item + + @app.put("/items/{item_id}") async def update_item(item_id: int, item: Item): item_db = await Item.objects.get(pk=item_id) @@ -113,6 +126,19 @@ def test_all_endpoints(): items = [Item(**item) for item in response.json()] assert items[0].name == "New name" + response = client.get("/items/raw/") + items = [Item(**item) for item in response.json()] + assert items[0].name == "New name" + assert items[0].category.name is None + + loop = asyncio.get_event_loop() + loop.run_until_complete(items[0].category.load()) + assert items[0].category.name is not None + + response = client.get(f"/items/{item.pk}") + new_item = Item(**response.json()) + assert new_item == item + response = client.delete(f"/items/{item.pk}") assert response.json().get("deleted_rows", "__UNDEFINED__") != "__UNDEFINED__" response = client.get("/items/")