added basic save, update, load and delate methods

This commit is contained in:
collerek
2020-08-03 19:59:04 +02:00
parent d7355b8c9b
commit e0bb7e2cda
10 changed files with 316 additions and 106 deletions

BIN
.coverage

Binary file not shown.

1
.gitignore vendored
View File

@ -3,3 +3,4 @@ p38venv
.pytest_cache
*.pyc
*.log
test.db

View File

@ -4,3 +4,15 @@ class AsyncOrmException(Exception):
class ModelDefinitionError(AsyncOrmException):
pass
class ModelNotSet(AsyncOrmException):
pass
class MultipleResults(AsyncOrmException):
pass
class RelationshipNotFound(AsyncOrmException):
pass

View File

@ -1,6 +1,6 @@
import datetime
import decimal
from typing import Any
from typing import Optional, List
import pydantic
import sqlalchemy
@ -11,11 +11,7 @@ from orm.exceptions import ModelDefinitionError
class BaseField:
__type__ = None
def __new__(cls, *args, **kwargs):
cls.__annotations__ = {}
return super().__new__(cls)
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
name = kwargs.pop('name', None)
args = list(args)
if args:
@ -28,7 +24,7 @@ class BaseField:
self.name = name
self.primary_key = kwargs.pop('primary_key', False)
self.autoincrement = kwargs.pop('autoincrement', 'auto')
self.autoincrement = kwargs.pop('autoincrement', self.primary_key)
self.nullable = kwargs.pop('nullable', not self.primary_key)
self.default = kwargs.pop('default', None)
@ -41,7 +37,7 @@ class BaseField:
if self.pydantic_only and self.primary_key:
raise ModelDefinitionError('Primary key column cannot be pydantic only.')
def get_column(self, name=None) -> sqlalchemy.Column:
def get_column(self, name: str = None) -> sqlalchemy.Column:
name = self.name or name
constraints = self.get_constraints()
return sqlalchemy.Column(
@ -60,7 +56,7 @@ class BaseField:
def get_column_type(self) -> sqlalchemy.types.TypeEngine:
raise NotImplementedError() # pragma: no cover
def get_constraints(self):
def get_constraints(self) -> Optional[List]:
return []

27
orm/helpers.py Normal file
View File

@ -0,0 +1,27 @@
from typing import Union, Set, Dict # pragma no cover
class Excludable: # pragma no cover
@staticmethod
def get_excluded(exclude: Union[Set, Dict, None], key: str = None):
# print(f'checking excluded for {key}', exclude)
if isinstance(exclude, dict):
if isinstance(exclude.get(key, {}), dict) and '__all__' in exclude.get(key, {}).keys():
return exclude.get(key).get('__all__')
return exclude.get(key, {})
return exclude
@staticmethod
def is_excluded(exclude: Union[Set, Dict, None], key: str = None):
if exclude is None:
return False
to_exclude = Excludable.get_excluded(exclude, key)
# print(f'to exclude for current key = {key}', to_exclude)
if isinstance(to_exclude, Set):
return key in to_exclude
elif to_exclude is ...:
return True
else:
return False

View File

@ -1,15 +1,21 @@
from __future__ import annotations
import json
from typing import Any
from typing import Set, Dict
import pydantic
import sqlalchemy
from pydantic import create_model
from pydantic import BaseConfig, create_model
from orm.exceptions import ModelDefinitionError
from orm.fields import BaseField
def parse_pydantic_field_from_model_fields(object_dict: dict):
pydantic_fields = {field_name: (
base_field.__type__,
... if (not base_field.nullable and not base_field.default) else (
... if (not base_field.nullable and not base_field.default and not base_field.primary_key) else (
base_field.default() if callable(base_field.default) else base_field.default)
)
for field_name, base_field in object_dict.items()
@ -33,26 +39,37 @@ class ModelMetaclass(type):
pkname = None
columns = []
model_fields = {}
for field_name, field in attrs.items():
if isinstance(field, BaseField) and not field.pydantic_only:
if field.primary_key:
pkname = field_name
columns.append(field.get_column(field_name))
if isinstance(field, BaseField):
model_fields[field_name] = field
if not field.pydantic_only:
if field.primary_key:
pkname = field_name
columns.append(field.get_column(field_name))
# sqlalchemy table creation
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
attrs['__columns__'] = columns
attrs['__pkname__'] = pkname
if not pkname:
raise ModelDefinitionError(
'Table has to have a primary key.'
)
# pydantic model creation
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
pydantic_model = create_model(name, **pydantic_fields)
config = type('Config', (BaseConfig,), {'orm_mode': True})
pydantic_model = create_model(name, __config__=config, **pydantic_fields)
attrs['__pydantic_fields__'] = pydantic_fields
attrs['__pydantic_model__'] = pydantic_model
attrs['__fields__'] = pydantic_model.__fields__
attrs['__signature__'] = pydantic_model.__signature__
attrs['__annotations__'] = pydantic_model.__annotations__
attrs['__model_fields__'] = model_fields
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
@ -63,21 +80,36 @@ class ModelMetaclass(type):
class Model(metaclass=ModelMetaclass):
__abstract__ = True
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk")
self.values = self.__pydantic_model__(**kwargs)
def __setattr__(self, key, value):
def __setattr__(self, key: str, value: Any) -> None:
if key in self.__fields__:
if self.is_conversion_to_json_needed(key) and not isinstance(value, str):
try:
value = json.dumps(value)
except TypeError: # pragma no cover
pass
setattr(self.values, key, value)
else:
super().__setattr__(key, value)
def __getattribute__(self, item) -> Any:
if item != '__fields__' and item in self.__fields__:
return getattr(self.values, item)
return super().__getattribute__(item)
def __getattribute__(self, key: str) -> Any:
if key != '__fields__' and key in self.__fields__:
item = getattr(self.values, key)
if self.is_conversion_to_json_needed(key) and isinstance(item, str):
try:
item = json.loads(item)
except TypeError: # pragma no cover
pass
return item
return super().__getattribute__(key)
def is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
@property
def pk(self):
@ -86,3 +118,69 @@ class Model(metaclass=ModelMetaclass):
@pk.setter
def pk(self, value):
setattr(self.values, self.__pkname__, value)
@property
def pk_column(self) -> sqlalchemy.Column:
return self.__table__.primary_key.columns.values()[0]
def dict(self) -> Dict:
return self.values.dict()
def from_dict(self, value_dict: Dict) -> None:
for key, value in value_dict.items():
setattr(self, key, value)
def extract_own_model_fields(self) -> Dict:
related_names = self.extract_related_names()
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields
@classmethod
def extract_related_names(cls) -> Set:
related_names = set()
# for name, field in cls.__fields__.items():
# if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
# related_names.add(name)
# elif field.sub_fields and any(
# [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]):
# related_names.add(name)
return related_names
def extract_model_db_fields(self) -> Dict:
self_fields = self.extract_own_model_fields()
self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns}
return self_fields
async def save(self) -> int:
self_fields = self.extract_model_db_fields()
if self.__model_fields__.get(self.__pkname__).autoincrement:
self_fields.pop(self.__pkname__, None)
expr = self.__table__.insert()
expr = expr.values(**self_fields)
item_id = await self.__database__.execute(expr)
setattr(self, 'pk', item_id)
return item_id
async def update(self, **kwargs: Any) -> int:
if kwargs:
new_values = {**self.dict(), **kwargs}
self.from_dict(new_values)
self_fields = self.extract_model_db_fields()
self_fields.pop(self.__pkname__)
expr = self.__table__.update().values(**self_fields).where(
self.pk_column == getattr(self, self.__pkname__))
result = await self.__database__.execute(expr)
return result
async def delete(self) -> int:
expr = self.__table__.delete()
expr = expr.where(self.pk_column == (getattr(self, self.__pkname__)))
result = await self.__database__.execute(expr)
return result
async def load(self) -> Model:
expr = self.__table__.select().where(self.pk_column == self.pk)
row = await self.__database__.fetch_one(expr)
self.from_dict(dict(row))
return self

View File

@ -6,3 +6,4 @@ sqlalchemy
pytest
pytest-cov
codecov
pytest-asyncio

3
tests/settings.py Normal file
View File

@ -0,0 +1,3 @@
import os
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db")

View File

@ -1,100 +1,58 @@
import datetime
from typing import ClassVar
import pydantic
import databases
import pytest
import sqlalchemy
import orm.fields as fields
from orm.exceptions import ModelDefinitionError
from orm.models import Model
import orm
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class ExampleModel(Model):
def time():
return datetime.datetime.now().time()
class Example(orm.Model):
__tablename__ = "example"
__metadata__ = metadata
test = fields.Integer(primary_key=True)
test_string = fields.String(length=250)
test_text = fields.Text(default='')
test_bool = fields.Boolean(nullable=False)
test_float = fields.Float()
test_datetime = fields.DateTime(default=datetime.datetime.now)
test_date = fields.Date(default=datetime.date.today)
test_time = fields.Time(default=datetime.time)
test_json = fields.JSON(default={})
test_bigint = fields.BigInteger(default=0)
test_decimal = fields.Decimal(length=10, precision=2)
__database__ = database
id = orm.Integer(primary_key=True)
created = orm.DateTime(default=datetime.datetime.now)
created_day = orm.Date(default=datetime.date.today)
created_time = orm.Time(default=time)
description = orm.Text(nullable=True)
value = orm.Float(nullable=True)
data = orm.JSON(default={})
fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float',
'test_bigint', 'test_json']
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
class ExampleModel2(Model):
__tablename__ = "example2"
__metadata__ = metadata
test = fields.Integer(name='test12', primary_key=True)
test_string = fields.String('test_string2', length=250)
@pytest.mark.asyncio
async def test_model_crud():
async with database:
example = Example()
await example.save()
await example.load()
assert example.created.year == datetime.datetime.now().year
assert example.created_day == datetime.date.today()
assert example.description is None
assert example.value is None
assert example.data == {}
def test_not_nullable_field_is_required():
with pytest.raises(pydantic.error_wrappers.ValidationError):
ExampleModel(test=1, test_string='test')
await example.update(data={"foo": 123}, value=123.456)
await example.load()
assert example.value == 123.456
assert example.data == {"foo": 123}
def test_model_attribute_access():
example = ExampleModel(test=1, test_string='test', test_bool=True)
assert example.test == 1
assert example.test_string == 'test'
assert example.test_datetime.year == datetime.datetime.now().year
assert example.test_date == datetime.date.today()
assert example.test_text == ''
assert example.test_float is None
assert example.test_bigint == 0
assert example.test_json == {}
example.test = 12
assert example.test == 12
example.new_attr = 12
assert 'new_attr' in example.__dict__
def test_primary_key_access_and_setting():
example = ExampleModel(pk=1, test_string='test', test_bool=True)
assert example.pk == 1
example.pk = 2
assert example.pk == 2
assert example.test == 2
def test_pydantic_model_is_created():
example = ExampleModel(pk=1, test_string='test', test_bool=True)
assert issubclass(example.values.__class__, pydantic.BaseModel)
assert all([field in example.values.__fields__ for field in fields_to_check])
assert example.values.test == 1
def test_sqlalchemy_table_is_created():
example = ExampleModel(pk=1, test_string='test', test_bool=True)
assert issubclass(example.__table__.__class__, sqlalchemy.Table)
assert all([field in example.__table__.columns for field in fields_to_check])
def test_double_column_name_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
__tablename__ = "example3"
__metadata__ = metadata
test_string = fields.String('test_string2', name='test_string2', length=250)
def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
__tablename__ = "example4"
__metadata__ = metadata
test = fields.Integer(name='test12', primary_key=True, pydantic_only=True)
await example.delete()

View File

@ -0,0 +1,114 @@
import datetime
from typing import ClassVar
import pydantic
import pytest
import sqlalchemy
import orm.fields as fields
from orm.exceptions import ModelDefinitionError
from orm.models import Model
metadata = sqlalchemy.MetaData()
class ExampleModel(Model):
__tablename__ = "example"
__metadata__ = metadata
test = fields.Integer(primary_key=True)
test_string = fields.String(length=250)
test_text = fields.Text(default='')
test_bool = fields.Boolean(nullable=False)
test_float = fields.Float()
test_datetime = fields.DateTime(default=datetime.datetime.now)
test_date = fields.Date(default=datetime.date.today)
test_time = fields.Time(default=datetime.time)
test_json = fields.JSON(default={})
test_bigint = fields.BigInteger(default=0)
test_decimal = fields.Decimal(length=10, precision=2)
fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float',
'test_bigint', 'test_json']
class ExampleModel2(Model):
__tablename__ = "example2"
__metadata__ = metadata
test = fields.Integer(name='test12', primary_key=True)
test_string = fields.String('test_string2', length=250)
@pytest.fixture()
def example():
return ExampleModel(pk=1, test_string='test', test_bool=True)
def test_not_nullable_field_is_required():
with pytest.raises(pydantic.error_wrappers.ValidationError):
ExampleModel(test=1, test_string='test')
def test_model_attribute_access(example):
assert example.test == 1
assert example.test_string == 'test'
assert example.test_datetime.year == datetime.datetime.now().year
assert example.test_date == datetime.date.today()
assert example.test_text == ''
assert example.test_float is None
assert example.test_bigint == 0
assert example.test_json == {}
example.test = 12
assert example.test == 12
example.new_attr = 12
assert 'new_attr' in example.__dict__
def test_primary_key_access_and_setting(example):
assert example.pk == 1
example.pk = 2
assert example.pk == 2
assert example.test == 2
def test_pydantic_model_is_created(example):
assert issubclass(example.values.__class__, pydantic.BaseModel)
assert all([field in example.values.__fields__ for field in fields_to_check])
assert example.values.test == 1
def test_sqlalchemy_table_is_created(example):
assert issubclass(example.__table__.__class__, sqlalchemy.Table)
assert all([field in example.__table__.columns for field in fields_to_check])
def test_double_column_name_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
__tablename__ = "example3"
__metadata__ = metadata
test_string = fields.String('test_string2', name='test_string2', length=250)
def test_no_pk_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
__tablename__ = "example3"
__metadata__ = metadata
test_string = fields.String(name='test_string2', length=250)
def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
__tablename__ = "example4"
__metadata__ = metadata
test = fields.Integer(name='test12', primary_key=True, pydantic_only=True)
def test_json_conversion_in_model():
with pytest.raises(pydantic.ValidationError):
ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True)