From e0bb7e2cda0f50e9e2b7763dd95b10dc4fb32d1f Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 19:59:04 +0200 Subject: [PATCH] added basic save, update, load and delate methods --- .coverage | Bin 53248 -> 53248 bytes .gitignore | 3 +- orm/exceptions.py | 12 ++++ orm/fields.py | 14 ++-- orm/helpers.py | 27 +++++++ orm/models.py | 124 +++++++++++++++++++++++++++++---- requirements.txt | 3 +- tests/settings.py | 3 + tests/test_columns.py | 122 +++++++++++--------------------- tests/test_model_definition.py | 114 ++++++++++++++++++++++++++++++ 10 files changed, 316 insertions(+), 106 deletions(-) create mode 100644 orm/helpers.py create mode 100644 tests/settings.py create mode 100644 tests/test_model_definition.py diff --git a/.coverage b/.coverage index 78c44140c6601a6f7c8fb9b99c03e62bea51dd59..5710bb4dace601b758eec3329f156f463a9c90a0 100644 GIT binary patch delta 526 zcmZozz}&Ead4sV&yQP(>g_W_zW-I+u0%Dwe1q}RG_?Pfk@dxp%@%`hw#5a?#V6&h= z7$2)TCo@Cn`1#H~KsU;Nb^ zSTvYJ8PTlcv|ug;iQ%-4Lz_7orhttH$vXDU|KnH{ML78D8Tfzlzv6$ue~f<*|6Km5 z{PjTF()sH}I9M1twK&SA|DXSJ&gc8jo>|)4+q1K=14a4Rxn$XZv<=&?+3(N1`B!{= zzgqSDUn}q4eJA_k@Ata<>>R8>X(86MG&TtqAj^m)@8sWbPX5av)!+C(^1tDK&i{!2E>Op1p!%Ksy3D}X f;$;R#4=*z?3V4~>Km;p@U;z=#Ac6@*Faikxx)Pta delta 273 zcmZozz}&Ead4sV&ySbH#g_WW4W-I+u0s?G&{}}kM@Gs%7;t%3iiofc|M{=-F9k}y=G*)~j#W{BjejWv|4;s>{I~fJ@b3UiP2|rPWMg6E zRAGzS_j&Vg%lgTCU*Bb9V+9KFv8G9~0BIAJyqkYt=Y9V_`R44iX`656n9V+0$IQVD zloDe8_McIR3CL1oDx3a){?9r8bMMv6n=jAG$im6V$Hc|Jz`(|TgMt4y{}=wZ{LlFB z^WOlv=LrAi8}oG(fbMw&lKRH~k^c? None: name = kwargs.pop('name', None) args = list(args) if args: @@ -28,7 +24,7 @@ class BaseField: self.name = name self.primary_key = kwargs.pop('primary_key', False) - self.autoincrement = kwargs.pop('autoincrement', 'auto') + self.autoincrement = kwargs.pop('autoincrement', self.primary_key) self.nullable = kwargs.pop('nullable', not self.primary_key) self.default = kwargs.pop('default', None) @@ -41,7 +37,7 @@ class BaseField: if self.pydantic_only and self.primary_key: raise ModelDefinitionError('Primary key column cannot be pydantic only.') - def get_column(self, name=None) -> sqlalchemy.Column: + def get_column(self, name: str = None) -> sqlalchemy.Column: name = self.name or name constraints = self.get_constraints() return sqlalchemy.Column( @@ -60,7 +56,7 @@ class BaseField: def get_column_type(self) -> sqlalchemy.types.TypeEngine: raise NotImplementedError() # pragma: no cover - def get_constraints(self): + def get_constraints(self) -> Optional[List]: return [] diff --git a/orm/helpers.py b/orm/helpers.py new file mode 100644 index 0000000..6e3d254 --- /dev/null +++ b/orm/helpers.py @@ -0,0 +1,27 @@ +from typing import Union, Set, Dict # pragma no cover + + +class Excludable: # pragma no cover + + @staticmethod + def get_excluded(exclude: Union[Set, Dict, None], key: str = None): + # print(f'checking excluded for {key}', exclude) + if isinstance(exclude, dict): + if isinstance(exclude.get(key, {}), dict) and '__all__' in exclude.get(key, {}).keys(): + return exclude.get(key).get('__all__') + return exclude.get(key, {}) + return exclude + + @staticmethod + def is_excluded(exclude: Union[Set, Dict, None], key: str = None): + if exclude is None: + return False + to_exclude = Excludable.get_excluded(exclude, key) + # print(f'to exclude for current key = {key}', to_exclude) + + if isinstance(to_exclude, Set): + return key in to_exclude + elif to_exclude is ...: + return True + else: + return False diff --git a/orm/models.py b/orm/models.py index b446885..19224aa 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,15 +1,21 @@ +from __future__ import annotations + +import json from typing import Any +from typing import Set, Dict +import pydantic import sqlalchemy -from pydantic import create_model +from pydantic import BaseConfig, create_model +from orm.exceptions import ModelDefinitionError from orm.fields import BaseField def parse_pydantic_field_from_model_fields(object_dict: dict): pydantic_fields = {field_name: ( base_field.__type__, - ... if (not base_field.nullable and not base_field.default) else ( + ... if (not base_field.nullable and not base_field.default and not base_field.primary_key) else ( base_field.default() if callable(base_field.default) else base_field.default) ) for field_name, base_field in object_dict.items() @@ -33,26 +39,37 @@ class ModelMetaclass(type): pkname = None columns = [] + model_fields = {} for field_name, field in attrs.items(): - if isinstance(field, BaseField) and not field.pydantic_only: - if field.primary_key: - pkname = field_name - columns.append(field.get_column(field_name)) + if isinstance(field, BaseField): + model_fields[field_name] = field + if not field.pydantic_only: + if field.primary_key: + pkname = field_name + columns.append(field.get_column(field_name)) # sqlalchemy table creation attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) attrs['__columns__'] = columns attrs['__pkname__'] = pkname + if not pkname: + raise ModelDefinitionError( + 'Table has to have a primary key.' + ) + # pydantic model creation pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model(name, **pydantic_fields) + config = type('Config', (BaseConfig,), {'orm_mode': True}) + pydantic_model = create_model(name, __config__=config, **pydantic_fields) attrs['__pydantic_fields__'] = pydantic_fields attrs['__pydantic_model__'] = pydantic_model attrs['__fields__'] = pydantic_model.__fields__ attrs['__signature__'] = pydantic_model.__signature__ attrs['__annotations__'] = pydantic_model.__annotations__ + attrs['__model_fields__'] = model_fields + new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) @@ -63,21 +80,36 @@ class ModelMetaclass(type): class Model(metaclass=ModelMetaclass): __abstract__ = True - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") self.values = self.__pydantic_model__(**kwargs) - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: if key in self.__fields__: + if self.is_conversion_to_json_needed(key) and not isinstance(value, str): + try: + value = json.dumps(value) + except TypeError: # pragma no cover + pass setattr(self.values, key, value) else: super().__setattr__(key, value) - def __getattribute__(self, item) -> Any: - if item != '__fields__' and item in self.__fields__: - return getattr(self.values, item) - return super().__getattribute__(item) + def __getattribute__(self, key: str) -> Any: + if key != '__fields__' and key in self.__fields__: + item = getattr(self.values, key) + if self.is_conversion_to_json_needed(key) and isinstance(item, str): + try: + item = json.loads(item) + except TypeError: # pragma no cover + pass + return item + + return super().__getattribute__(key) + + def is_conversion_to_json_needed(self, column_name: str) -> bool: + return self.__model_fields__.get(column_name).__type__ == pydantic.Json @property def pk(self): @@ -86,3 +118,69 @@ class Model(metaclass=ModelMetaclass): @pk.setter def pk(self, value): setattr(self.values, self.__pkname__, value) + + @property + def pk_column(self) -> sqlalchemy.Column: + return self.__table__.primary_key.columns.values()[0] + + def dict(self) -> Dict: + return self.values.dict() + + def from_dict(self, value_dict: Dict) -> None: + for key, value in value_dict.items(): + setattr(self, key, value) + + def extract_own_model_fields(self) -> Dict: + related_names = self.extract_related_names() + self_fields = {k: v for k, v in self.dict().items() if k not in related_names} + return self_fields + + @classmethod + def extract_related_names(cls) -> Set: + related_names = set() + # for name, field in cls.__fields__.items(): + # if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): + # related_names.add(name) + # elif field.sub_fields and any( + # [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]): + # related_names.add(name) + return related_names + + def extract_model_db_fields(self) -> Dict: + self_fields = self.extract_own_model_fields() + self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns} + return self_fields + + async def save(self) -> int: + self_fields = self.extract_model_db_fields() + if self.__model_fields__.get(self.__pkname__).autoincrement: + self_fields.pop(self.__pkname__, None) + expr = self.__table__.insert() + expr = expr.values(**self_fields) + item_id = await self.__database__.execute(expr) + setattr(self, 'pk', item_id) + return item_id + + async def update(self, **kwargs: Any) -> int: + if kwargs: + new_values = {**self.dict(), **kwargs} + self.from_dict(new_values) + + self_fields = self.extract_model_db_fields() + self_fields.pop(self.__pkname__) + expr = self.__table__.update().values(**self_fields).where( + self.pk_column == getattr(self, self.__pkname__)) + result = await self.__database__.execute(expr) + return result + + async def delete(self) -> int: + expr = self.__table__.delete() + expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) + result = await self.__database__.execute(expr) + return result + + async def load(self) -> Model: + expr = self.__table__.select().where(self.pk_column == self.pk) + row = await self.__database__.fetch_one(expr) + self.from_dict(dict(row)) + return self diff --git a/requirements.txt b/requirements.txt index 590566c..f6280a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ sqlalchemy # Testing pytest pytest-cov -codecov \ No newline at end of file +codecov +pytest-asyncio \ No newline at end of file diff --git a/tests/settings.py b/tests/settings.py new file mode 100644 index 0000000..697acb0 --- /dev/null +++ b/tests/settings.py @@ -0,0 +1,3 @@ +import os + +DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db") diff --git a/tests/test_columns.py b/tests/test_columns.py index 7719a0c..edee116 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -1,100 +1,58 @@ import datetime -from typing import ClassVar -import pydantic +import databases import pytest import sqlalchemy -import orm.fields as fields -from orm.exceptions import ModelDefinitionError -from orm.models import Model +import orm +from tests.settings import DATABASE_URL +database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() -class ExampleModel(Model): +def time(): + return datetime.datetime.now().time() + + +class Example(orm.Model): __tablename__ = "example" __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) - test_text = fields.Text(default='') - test_bool = fields.Boolean(nullable=False) - test_float = fields.Float() - test_datetime = fields.DateTime(default=datetime.datetime.now) - test_date = fields.Date(default=datetime.date.today) - test_time = fields.Time(default=datetime.time) - test_json = fields.JSON(default={}) - test_bigint = fields.BigInteger(default=0) - test_decimal = fields.Decimal(length=10, precision=2) + __database__ = database + + id = orm.Integer(primary_key=True) + created = orm.DateTime(default=datetime.datetime.now) + created_day = orm.Date(default=datetime.date.today) + created_time = orm.Time(default=time) + description = orm.Text(nullable=True) + value = orm.Float(nullable=True) + data = orm.JSON(default={}) -fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float', - 'test_bigint', 'test_json'] +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) -class ExampleModel2(Model): - __tablename__ = "example2" - __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True) - test_string = fields.String('test_string2', length=250) +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + example = Example() + await example.save() + await example.load() + assert example.created.year == datetime.datetime.now().year + assert example.created_day == datetime.date.today() + assert example.description is None + assert example.value is None + assert example.data == {} -def test_not_nullable_field_is_required(): - with pytest.raises(pydantic.error_wrappers.ValidationError): - ExampleModel(test=1, test_string='test') + await example.update(data={"foo": 123}, value=123.456) + await example.load() + assert example.value == 123.456 + assert example.data == {"foo": 123} - -def test_model_attribute_access(): - example = ExampleModel(test=1, test_string='test', test_bool=True) - assert example.test == 1 - assert example.test_string == 'test' - assert example.test_datetime.year == datetime.datetime.now().year - assert example.test_date == datetime.date.today() - assert example.test_text == '' - assert example.test_float is None - assert example.test_bigint == 0 - assert example.test_json == {} - - example.test = 12 - assert example.test == 12 - - example.new_attr = 12 - assert 'new_attr' in example.__dict__ - - -def test_primary_key_access_and_setting(): - example = ExampleModel(pk=1, test_string='test', test_bool=True) - assert example.pk == 1 - example.pk = 2 - - assert example.pk == 2 - assert example.test == 2 - - -def test_pydantic_model_is_created(): - example = ExampleModel(pk=1, test_string='test', test_bool=True) - assert issubclass(example.values.__class__, pydantic.BaseModel) - assert all([field in example.values.__fields__ for field in fields_to_check]) - assert example.values.test == 1 - - -def test_sqlalchemy_table_is_created(): - example = ExampleModel(pk=1, test_string='test', test_bool=True) - assert issubclass(example.__table__.__class__, sqlalchemy.Table) - assert all([field in example.__table__.columns for field in fields_to_check]) - - -def test_double_column_name_in_model_definition(): - with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - test_string = fields.String('test_string2', name='test_string2', length=250) - - -def test_setting_pk_column_as_pydantic_only_in_model_definition(): - with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) + await example.delete() diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py new file mode 100644 index 0000000..f06f141 --- /dev/null +++ b/tests/test_model_definition.py @@ -0,0 +1,114 @@ +import datetime +from typing import ClassVar + +import pydantic +import pytest +import sqlalchemy + +import orm.fields as fields +from orm.exceptions import ModelDefinitionError +from orm.models import Model + +metadata = sqlalchemy.MetaData() + + +class ExampleModel(Model): + __tablename__ = "example" + __metadata__ = metadata + test = fields.Integer(primary_key=True) + test_string = fields.String(length=250) + test_text = fields.Text(default='') + test_bool = fields.Boolean(nullable=False) + test_float = fields.Float() + test_datetime = fields.DateTime(default=datetime.datetime.now) + test_date = fields.Date(default=datetime.date.today) + test_time = fields.Time(default=datetime.time) + test_json = fields.JSON(default={}) + test_bigint = fields.BigInteger(default=0) + test_decimal = fields.Decimal(length=10, precision=2) + + +fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float', + 'test_bigint', 'test_json'] + + +class ExampleModel2(Model): + __tablename__ = "example2" + __metadata__ = metadata + test = fields.Integer(name='test12', primary_key=True) + test_string = fields.String('test_string2', length=250) + + +@pytest.fixture() +def example(): + return ExampleModel(pk=1, test_string='test', test_bool=True) + + +def test_not_nullable_field_is_required(): + with pytest.raises(pydantic.error_wrappers.ValidationError): + ExampleModel(test=1, test_string='test') + + +def test_model_attribute_access(example): + assert example.test == 1 + assert example.test_string == 'test' + assert example.test_datetime.year == datetime.datetime.now().year + assert example.test_date == datetime.date.today() + assert example.test_text == '' + assert example.test_float is None + assert example.test_bigint == 0 + assert example.test_json == {} + + example.test = 12 + assert example.test == 12 + + example.new_attr = 12 + assert 'new_attr' in example.__dict__ + + +def test_primary_key_access_and_setting(example): + assert example.pk == 1 + example.pk = 2 + + assert example.pk == 2 + assert example.test == 2 + + +def test_pydantic_model_is_created(example): + assert issubclass(example.values.__class__, pydantic.BaseModel) + assert all([field in example.values.__fields__ for field in fields_to_check]) + assert example.values.test == 1 + + +def test_sqlalchemy_table_is_created(example): + assert issubclass(example.__table__.__class__, sqlalchemy.Table) + assert all([field in example.__table__.columns for field in fields_to_check]) + + +def test_double_column_name_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example3" + __metadata__ = metadata + test_string = fields.String('test_string2', name='test_string2', length=250) + + +def test_no_pk_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example3" + __metadata__ = metadata + test_string = fields.String(name='test_string2', length=250) + + +def test_setting_pk_column_as_pydantic_only_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example4" + __metadata__ = metadata + test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) + + +def test_json_conversion_in_model(): + with pytest.raises(pydantic.ValidationError): + ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True)