diff --git a/.coverage b/.coverage index ba7b5f4..997d622 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/__init__.py b/orm/__init__.py index 9c652ae..a0acd43 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -1,6 +1,7 @@ from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float, \ ForeignKey from orm.models import Model +from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch, ModelNotSet __all__ = [ "Integer", diff --git a/orm/helpers.py b/orm/helpers.py index 6e3d254..f8d7cfb 100644 --- a/orm/helpers.py +++ b/orm/helpers.py @@ -23,5 +23,4 @@ class Excludable: # pragma no cover return key in to_exclude elif to_exclude is ...: return True - else: - return False + return False diff --git a/orm/models.py b/orm/models.py index e19474c..7eca3b9 100644 --- a/orm/models.py +++ b/orm/models.py @@ -2,6 +2,7 @@ import copy import inspect import json import uuid +from abc import ABCMeta from typing import Any, List, Type from typing import Set, Dict @@ -173,10 +174,15 @@ class QuerySet: offset=self.query_offset ) - # async def exists(self) -> bool: - # expr = self.build_select_expression() - # expr = sqlalchemy.exists(expr).select() - # return await self.database.fetch_val(expr) + async def exists(self) -> bool: + expr = self.build_select_expression() + expr = sqlalchemy.exists(expr).select() + return await self.database.fetch_val(expr) + + async def count(self) -> int: + expr = self.build_select_expression().alias("subquery_for_count") + expr = sqlalchemy.func.count().select().select_from(expr) + return await self.database.fetch_val(expr) def limit(self, limit_count: int): return self.__class__( @@ -196,6 +202,14 @@ class QuerySet: offset=offset ) + async def first(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).first() + + rows = await self.limit(1).all() + if rows: + return rows[0] + async def get(self, **kwargs): if kwargs: return await self.filter(**kwargs).get() @@ -287,7 +301,6 @@ class ModelMetaclass(type): attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__) attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) - attrs['__model_fields__'] = model_fields new_model = super().__new__( # type: ignore @@ -297,7 +310,7 @@ class ModelMetaclass(type): return new_model -class Model(metaclass=ModelMetaclass): +class Model(tuple, metaclass=ModelMetaclass): __abstract__ = True objects = QuerySet() @@ -338,9 +351,11 @@ class Model(metaclass=ModelMetaclass): except TypeError: # pragma no cover pass return item - return super().__getattribute__(key) + def __eq__(self, other): + return self.values.dict() == other.values.dict() + def __repr__(self): # pragma no cover return self.values.__repr__() @@ -380,6 +395,18 @@ class Model(metaclass=ModelMetaclass): return cls(**item) + @classmethod + def validate(cls: Type['Model'], value: Any) -> 'Model': # pragma no cover + return cls.__pydantic_model__.validate(cls.__pydantic_model__.__class__, value) + + @classmethod + def __get_validators__(cls): # pragma no cover + yield cls.__pydantic_model__.validate + + @classmethod + def schema(cls, by_alias: bool = True): # pragma no cover + return cls.__pydantic_model__.schame(cls.__pydantic_model__, by_alias=by_alias) + def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json diff --git a/orm/relations.py b/orm/relations.py index 1bd70cf..8e5f3dd 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -1,6 +1,6 @@ from typing import Dict, Union, List -from sqlalchemy import text +from orm.exceptions import RelationshipNotFound class Relationship: @@ -41,5 +41,4 @@ class RelationshipManager: if rel == name: if relations and relations[0].fk_side == 'parent': return relations[0].child - else: - return [rela.child for rela in relations] + return [rela.child for rela in relations] diff --git a/requirements.txt b/requirements.txt index f6280a7..db5ae51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ sqlalchemy pytest pytest-cov codecov -pytest-asyncio \ No newline at end of file +pytest-asyncio +fastapi \ No newline at end of file diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py new file mode 100644 index 0000000..ae647a0 --- /dev/null +++ b/tests/test_fastapi_usage.py @@ -0,0 +1,42 @@ +import json +from typing import Optional + +import databases +import pydantic +import sqlalchemy +from fastapi import FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() + +import orm +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Item(orm.Model): + __tablename__ = "users" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +@app.post("/items/", response_model=Item) +async def create_item(item: Item): + return item + + +client = TestClient(app) + + +def test_read_main(): + response = client.post("/items/", json={'name': 'test', 'id': 1}) + print(response.json()) + assert response.status_code == 200 + assert response.json() == {'name': 'test', 'id': 1} + item = Item(**response.json()) + assert item.id == 1 diff --git a/tests/test_models.py b/tests/test_models.py index e69de29..cf12c3e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -0,0 +1,200 @@ +import databases +import pytest +import sqlalchemy + +import orm +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class User(orm.Model): + __tablename__ = "users" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Product(orm.Model): + __tablename__ = "product" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + rating = orm.Integer(minimum=1, maximum=5) + in_stock = orm.Boolean(default=False) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_model_class(): + assert list(User.__model_fields__.keys()) == ["id", "name"] + assert isinstance(User.__model_fields__["id"], orm.Integer) + assert User.__model_fields__["id"].primary_key is True + assert isinstance(User.__model_fields__["name"], orm.String) + assert User.__model_fields__["name"].length == 100 + assert isinstance(User.__table__, sqlalchemy.Table) + + +def test_model_pk(): + user = User(pk=1) + assert user.pk == 1 + assert user.id == 1 + + +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + users = await User.objects.all() + assert users == [] + + user = await User.objects.create(name="Tom") + users = await User.objects.all() + assert user.name == "Tom" + assert user.pk is not None + assert users == [user] + + lookup = await User.objects.get() + assert lookup == user + + await user.update(name="Jane") + users = await User.objects.all() + assert user.name == "Jane" + assert user.pk is not None + assert users == [user] + + await user.delete() + users = await User.objects.all() + assert users == [] + + +@pytest.mark.asyncio +async def test_model_get(): + async with database: + with pytest.raises(orm.NoMatch): + await User.objects.get() + + user = await User.objects.create(name="Tom") + lookup = await User.objects.get() + assert lookup == user + + user = await User.objects.create(name="Jane") + with pytest.raises(orm.MultipleMatches): + await User.objects.get() + + same_user = await User.objects.get(pk=user.id) + assert same_user.id == user.id + assert same_user.pk == user.pk + + +@pytest.mark.asyncio +async def test_model_filter(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + user = await User.objects.get(name="Lucy") + assert user.name == "Lucy" + + with pytest.raises(orm.NoMatch): + await User.objects.get(name="Jim") + + await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) + await Product.objects.create(name="Dress", rating=4) + await Product.objects.create(name="Coat", rating=3, in_stock=True) + + product = await Product.objects.get(name__iexact="t-shirt", rating=5) + assert product.pk is not None + assert product.name == "T-Shirt" + assert product.rating == 5 + + products = await Product.objects.all(rating__gte=2, in_stock=True) + assert len(products) == 2 + + products = await Product.objects.all(name__icontains="T") + assert len(products) == 2 + + # Test escaping % character from icontains, contains, and iexact + await Product.objects.create(name="100%-Cotton", rating=3) + await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) + await Product.objects.create(name="Cotton-100%", rating=3) + products = Product.objects.filter(name__iexact="100%-cotton") + assert await products.count() == 1 + + products = Product.objects.filter(name__contains="%") + assert await products.count() == 3 + + products = Product.objects.filter(name__icontains="%") + assert await products.count() == 3 + + +@pytest.mark.asyncio +async def test_model_exists(): + async with database: + await User.objects.create(name="Tom") + assert await User.objects.filter(name="Tom").exists() is True + assert await User.objects.filter(name="Jane").exists() is False + + +@pytest.mark.asyncio +async def test_model_count(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + assert await User.objects.count() == 3 + assert await User.objects.filter(name__icontains="T").count() == 1 + + +@pytest.mark.asyncio +async def test_model_limit(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + assert len(await User.objects.limit(2).all()) == 2 + + +@pytest.mark.asyncio +async def test_model_limit_with_filter(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + + assert len(await User.objects.limit(2).filter(name__iexact='Tom').all()) == 2 + + +@pytest.mark.asyncio +async def test_offset(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + + users = await User.objects.offset(1).limit(1).all() + assert users[0].name == 'Jane' + + +@pytest.mark.asyncio +async def test_model_first(): + async with database: + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") + + assert await User.objects.first() == tom + assert await User.objects.first(name="Jane") == jane + assert await User.objects.filter(name="Jane").first() == jane + assert await User.objects.filter(name="Lucy").first() is None