From d7355b8c9bb2c5cfeb3d00d1f4dbd383602c9360 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 17:49:01 +0200 Subject: [PATCH] more checks for table and pydantic model creation --- .coverage | Bin 53248 -> 53248 bytes orm/__init__.py | 17 +++++++++++++++++ orm/fields.py | 9 +++++++++ orm/models.py | 34 ++++++++++++++++++++++------------ tests/test_columns.py | 29 +++++++++++++++++++++++++++-- 5 files changed, 75 insertions(+), 14 deletions(-) diff --git a/.coverage b/.coverage index 8755b3052e06a04c4b76a6414d65aa0ae7b36dcc..78c44140c6601a6f7c8fb9b99c03e62bea51dd59 100644 GIT binary patch delta 132 zcmZozz}&Ead4sV&ySbH#g_WW4W-I+O0enjt_Q@8;jvdEfs}zB&7B+UA=%X0y-MF>^2j krG%Kj{bv+n0jDH{h+RcIjGx>e^*;p7k z71*Noect^0>*;&dn}7DRI{*NOmo4Q0 diff --git a/orm/__init__.py b/orm/__init__.py index e69de29..5270355 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -0,0 +1,17 @@ +from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float +from orm.models import Model + +__all__ = [ + "Integer", + "BigInteger", + "Boolean", + "Time", + "Text", + "String", + "JSON", + "DateTime", + "Date", + "Decimal", + "Float", + "Model" +] diff --git a/orm/fields.py b/orm/fields.py index 1b64610..3f33a81 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,5 +1,6 @@ import datetime import decimal +from typing import Any import pydantic import sqlalchemy @@ -10,6 +11,10 @@ from orm.exceptions import ModelDefinitionError class BaseField: __type__ = None + def __new__(cls, *args, **kwargs): + cls.__annotations__ = {} + return super().__new__(cls) + def __init__(self, *args, **kwargs): name = kwargs.pop('name', None) args = list(args) @@ -32,6 +37,10 @@ class BaseField: self.index = kwargs.pop('index', None) self.unique = kwargs.pop('unique', None) + self.pydantic_only = kwargs.pop('pydantic_only', False) + if self.pydantic_only and self.primary_key: + raise ModelDefinitionError('Primary key column cannot be pydantic only.') + def get_column(self, name=None) -> sqlalchemy.Column: name = self.name or name constraints = self.get_constraints() diff --git a/orm/models.py b/orm/models.py index 50f2eb9..b446885 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import sqlalchemy from pydantic import create_model @@ -33,20 +33,29 @@ class ModelMetaclass(type): pkname = None columns = [] - for field_name, field in new_model.__dict__.items(): - if isinstance(field, BaseField): + for field_name, field in attrs.items(): + if isinstance(field, BaseField) and not field.pydantic_only: if field.primary_key: pkname = field_name columns.append(field.get_column(field_name)) - pydantic_fields = parse_pydantic_field_from_model_fields(new_model.__dict__) + # sqlalchemy table creation + attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) + attrs['__columns__'] = columns + attrs['__pkname__'] = pkname - new_model.__table__ = sqlalchemy.Table(tablename, metadata, *columns) - new_model.__columns__ = columns - new_model.__pkname__ = pkname - new_model.__pydantic_fields__ = pydantic_fields - new_model.__pydantic_model__ = create_model(name, **pydantic_fields) - new_model.__fields__ = new_model.__pydantic_model__.__fields__ + # pydantic model creation + pydantic_fields = parse_pydantic_field_from_model_fields(attrs) + pydantic_model = create_model(name, **pydantic_fields) + attrs['__pydantic_fields__'] = pydantic_fields + attrs['__pydantic_model__'] = pydantic_model + attrs['__fields__'] = pydantic_model.__fields__ + attrs['__signature__'] = pydantic_model.__signature__ + attrs['__annotations__'] = pydantic_model.__annotations__ + + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) return new_model @@ -62,9 +71,10 @@ class Model(metaclass=ModelMetaclass): def __setattr__(self, key, value): if key in self.__fields__: setattr(self.values, key, value) - super().__setattr__(key, value) + else: + super().__setattr__(key, value) - def __getattribute__(self, item): + def __getattribute__(self, item) -> Any: if item != '__fields__' and item in self.__fields__: return getattr(self.values, item) return super().__getattribute__(item) diff --git a/tests/test_columns.py b/tests/test_columns.py index 981d9b7..7719a0c 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -1,4 +1,5 @@ import datetime +from typing import ClassVar import pydantic import pytest @@ -27,6 +28,10 @@ class ExampleModel(Model): test_decimal = fields.Decimal(length=10, precision=2) +fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float', + 'test_bigint', 'test_json'] + + class ExampleModel2(Model): __tablename__ = "example2" __metadata__ = metadata @@ -66,10 +71,30 @@ def test_primary_key_access_and_setting(): assert example.test == 2 -def test_wrong_model_definition(): +def test_pydantic_model_is_created(): + example = ExampleModel(pk=1, test_string='test', test_bool=True) + assert issubclass(example.values.__class__, pydantic.BaseModel) + assert all([field in example.values.__fields__ for field in fields_to_check]) + assert example.values.test == 1 + + +def test_sqlalchemy_table_is_created(): + example = ExampleModel(pk=1, test_string='test', test_bool=True) + assert issubclass(example.__table__.__class__, sqlalchemy.Table) + assert all([field in example.__table__.columns for field in fields_to_check]) + + +def test_double_column_name_in_model_definition(): with pytest.raises(ModelDefinitionError): class ExampleModel2(Model): __tablename__ = "example3" __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True) test_string = fields.String('test_string2', name='test_string2', length=250) + + +def test_setting_pk_column_as_pydantic_only_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example4" + __metadata__ = metadata + test = fields.Integer(name='test12', primary_key=True, pydantic_only=True)