From eb99f28431e0c86fc277b200c6731676777662bf Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 4 Aug 2020 21:37:25 +0200 Subject: [PATCH] added hack to pass as pydantic model in fastapi, tests for fastapi --- .coverage | Bin 53248 -> 53248 bytes orm/__init__.py | 1 + orm/helpers.py | 3 +- orm/models.py | 41 ++++++-- orm/relations.py | 5 +- requirements.txt | 3 +- tests/test_fastapi_usage.py | 42 ++++++++ tests/test_models.py | 200 ++++++++++++++++++++++++++++++++++++ 8 files changed, 282 insertions(+), 13 deletions(-) create mode 100644 tests/test_fastapi_usage.py diff --git a/.coverage b/.coverage index ba7b5f4a6b067d03e366db7864d285fedeb1619e..997d62288b92e3b8a186eb775b7686e374522c00 100644 GIT binary patch delta 667 zcmZozz}&Ead4rKYhmoO`v6+>r@n$RiQv!0lysH`bukbJ7ui_8lSL6H5cbRWFUm>3j zA2aV2-qo801v+_I)p?m2LMQo4P4eSr;bCS-oa8SydA~pRWCK4g7H(#SQUsq>I6fsc zEi*5(Br`uxub|SHi$h7stxH4OaK{ObI{yleQb@crSx%C~}lDPIwvEguW-RiNv-cqc!K)nMh-V@_m* z*<;|x$;xBFT#6)s?k;W<=3KZm$X#5f%=s_@40myAGe<)uf$oa4QIz4~`^3Qilm8X} z1OEN|TlrV;C-Fz|`|@k@3-f&fy6O_2ojeZXO{N%_UuYrKv6fY$qxOp_1v65wgl(5t(C=} z*x5LM96k;%7j{-g&PERQfA98uw*R($_wU=em+yYeJ8twX?|X6mz2Bdozx#G~xlQ`H z^Y84<|NCD1Zuh!+8-u~m{Ofn#yeoU}UjM(ke!o078^~xjulafP+vPb}fgB;$w6rvN z79h)kC2w>6-M8;1pXrxozV)Af@~eI+86N(14E(?Ozwp21f5v~G{|5g>{!{#i`S$=l zzYgf(K7MwPiHtlvo43yQSCHf7-w#s$jsGM68~*3~kNEEbwOr;u$A6Ol2>*VdhPnK# N%)nH|JNd|Y8vw&}><$0` delta 464 zcmZozz}&Ead4rKYhoOa)iLsTD`DQEqQvxzPybl@pukbJ7ui_8lSL6H5cbRWFUlE@r z9~bY#&4L04d0EwYm>EJRJNk)DKIY5M!p+Q(2;?YE{^iTX!o|!`3g&3~i7;|bw(~P) z)#6}ghz5$uY(C*%WuiA%r6Zz{V$&?8#fChCm%PH2p5pn;wqc|fBw%opYK0= zW@&G4&#uG?6m{dAY|t+o!Nvhn$-(8w&dSKy$i!Y(e($IJ@7s6(zTNinUB&if$?EOD zciKNVKmT_5{7-W%?RWpb^Cx@LzVDIugBcS(-CuY2?z_J)-q*i>|DBne4QP}Eo7eoj z`tACYxAq$`zxvNU`BlFZ(79U~_9W4gQP#r}z)^@8RDHbnbNi&0FXD pE6DKh9|0-)#{ZH34gYigNBnnz3NQ1Y<3GuN1gLBY|Kua*Z2 bool: - # expr = self.build_select_expression() - # expr = sqlalchemy.exists(expr).select() - # return await self.database.fetch_val(expr) + async def exists(self) -> bool: + expr = self.build_select_expression() + expr = sqlalchemy.exists(expr).select() + return await self.database.fetch_val(expr) + + async def count(self) -> int: + expr = self.build_select_expression().alias("subquery_for_count") + expr = sqlalchemy.func.count().select().select_from(expr) + return await self.database.fetch_val(expr) def limit(self, limit_count: int): return self.__class__( @@ -196,6 +202,14 @@ class QuerySet: offset=offset ) + async def first(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).first() + + rows = await self.limit(1).all() + if rows: + return rows[0] + async def get(self, **kwargs): if kwargs: return await self.filter(**kwargs).get() @@ -287,7 +301,6 @@ class ModelMetaclass(type): attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__) attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) - attrs['__model_fields__'] = model_fields new_model = super().__new__( # type: ignore @@ -297,7 +310,7 @@ class ModelMetaclass(type): return new_model -class Model(metaclass=ModelMetaclass): +class Model(tuple, metaclass=ModelMetaclass): __abstract__ = True objects = QuerySet() @@ -338,9 +351,11 @@ class Model(metaclass=ModelMetaclass): except TypeError: # pragma no cover pass return item - return super().__getattribute__(key) + def __eq__(self, other): + return self.values.dict() == other.values.dict() + def __repr__(self): # pragma no cover return self.values.__repr__() @@ -380,6 +395,18 @@ class Model(metaclass=ModelMetaclass): return cls(**item) + @classmethod + def validate(cls: Type['Model'], value: Any) -> 'Model': # pragma no cover + return cls.__pydantic_model__.validate(cls.__pydantic_model__.__class__, value) + + @classmethod + def __get_validators__(cls): # pragma no cover + yield cls.__pydantic_model__.validate + + @classmethod + def schema(cls, by_alias: bool = True): # pragma no cover + return cls.__pydantic_model__.schame(cls.__pydantic_model__, by_alias=by_alias) + def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json diff --git a/orm/relations.py b/orm/relations.py index 1bd70cf..8e5f3dd 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -1,6 +1,6 @@ from typing import Dict, Union, List -from sqlalchemy import text +from orm.exceptions import RelationshipNotFound class Relationship: @@ -41,5 +41,4 @@ class RelationshipManager: if rel == name: if relations and relations[0].fk_side == 'parent': return relations[0].child - else: - return [rela.child for rela in relations] + return [rela.child for rela in relations] diff --git a/requirements.txt b/requirements.txt index f6280a7..db5ae51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ sqlalchemy pytest pytest-cov codecov -pytest-asyncio \ No newline at end of file +pytest-asyncio +fastapi \ No newline at end of file diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py new file mode 100644 index 0000000..ae647a0 --- /dev/null +++ b/tests/test_fastapi_usage.py @@ -0,0 +1,42 @@ +import json +from typing import Optional + +import databases +import pydantic +import sqlalchemy +from fastapi import FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() + +import orm +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Item(orm.Model): + __tablename__ = "users" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +@app.post("/items/", response_model=Item) +async def create_item(item: Item): + return item + + +client = TestClient(app) + + +def test_read_main(): + response = client.post("/items/", json={'name': 'test', 'id': 1}) + print(response.json()) + assert response.status_code == 200 + assert response.json() == {'name': 'test', 'id': 1} + item = Item(**response.json()) + assert item.id == 1 diff --git a/tests/test_models.py b/tests/test_models.py index e69de29..cf12c3e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -0,0 +1,200 @@ +import databases +import pytest +import sqlalchemy + +import orm +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class User(orm.Model): + __tablename__ = "users" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Product(orm.Model): + __tablename__ = "product" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + rating = orm.Integer(minimum=1, maximum=5) + in_stock = orm.Boolean(default=False) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_model_class(): + assert list(User.__model_fields__.keys()) == ["id", "name"] + assert isinstance(User.__model_fields__["id"], orm.Integer) + assert User.__model_fields__["id"].primary_key is True + assert isinstance(User.__model_fields__["name"], orm.String) + assert User.__model_fields__["name"].length == 100 + assert isinstance(User.__table__, sqlalchemy.Table) + + +def test_model_pk(): + user = User(pk=1) + assert user.pk == 1 + assert user.id == 1 + + +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + users = await User.objects.all() + assert users == [] + + user = await User.objects.create(name="Tom") + users = await User.objects.all() + assert user.name == "Tom" + assert user.pk is not None + assert users == [user] + + lookup = await User.objects.get() + assert lookup == user + + await user.update(name="Jane") + users = await User.objects.all() + assert user.name == "Jane" + assert user.pk is not None + assert users == [user] + + await user.delete() + users = await User.objects.all() + assert users == [] + + +@pytest.mark.asyncio +async def test_model_get(): + async with database: + with pytest.raises(orm.NoMatch): + await User.objects.get() + + user = await User.objects.create(name="Tom") + lookup = await User.objects.get() + assert lookup == user + + user = await User.objects.create(name="Jane") + with pytest.raises(orm.MultipleMatches): + await User.objects.get() + + same_user = await User.objects.get(pk=user.id) + assert same_user.id == user.id + assert same_user.pk == user.pk + + +@pytest.mark.asyncio +async def test_model_filter(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + user = await User.objects.get(name="Lucy") + assert user.name == "Lucy" + + with pytest.raises(orm.NoMatch): + await User.objects.get(name="Jim") + + await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) + await Product.objects.create(name="Dress", rating=4) + await Product.objects.create(name="Coat", rating=3, in_stock=True) + + product = await Product.objects.get(name__iexact="t-shirt", rating=5) + assert product.pk is not None + assert product.name == "T-Shirt" + assert product.rating == 5 + + products = await Product.objects.all(rating__gte=2, in_stock=True) + assert len(products) == 2 + + products = await Product.objects.all(name__icontains="T") + assert len(products) == 2 + + # Test escaping % character from icontains, contains, and iexact + await Product.objects.create(name="100%-Cotton", rating=3) + await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) + await Product.objects.create(name="Cotton-100%", rating=3) + products = Product.objects.filter(name__iexact="100%-cotton") + assert await products.count() == 1 + + products = Product.objects.filter(name__contains="%") + assert await products.count() == 3 + + products = Product.objects.filter(name__icontains="%") + assert await products.count() == 3 + + +@pytest.mark.asyncio +async def test_model_exists(): + async with database: + await User.objects.create(name="Tom") + assert await User.objects.filter(name="Tom").exists() is True + assert await User.objects.filter(name="Jane").exists() is False + + +@pytest.mark.asyncio +async def test_model_count(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + assert await User.objects.count() == 3 + assert await User.objects.filter(name__icontains="T").count() == 1 + + +@pytest.mark.asyncio +async def test_model_limit(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + assert len(await User.objects.limit(2).all()) == 2 + + +@pytest.mark.asyncio +async def test_model_limit_with_filter(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + + assert len(await User.objects.limit(2).filter(name__iexact='Tom').all()) == 2 + + +@pytest.mark.asyncio +async def test_offset(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + + users = await User.objects.offset(1).limit(1).all() + assert users[0].name == 'Jane' + + +@pytest.mark.asyncio +async def test_model_first(): + async with database: + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") + + assert await User.objects.first() == tom + assert await User.objects.first(name="Jane") == jane + assert await User.objects.filter(name="Jane").first() == jane + assert await User.objects.filter(name="Lucy").first() is None