145 lines
3.4 KiB
Python
145 lines
3.4 KiB
Python
import datetime
|
|
import decimal
|
|
|
|
import pydantic
|
|
import sqlalchemy
|
|
|
|
from orm.exceptions import ModelDefinitionError
|
|
|
|
|
|
class BaseField:
|
|
__type__ = None
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
name = kwargs.pop('name', None)
|
|
args = list(args)
|
|
if args:
|
|
if isinstance(args[0], str):
|
|
if name is not None:
|
|
raise ModelDefinitionError(
|
|
'Column name cannot be passed positionally and as a keyword.'
|
|
)
|
|
name = args.pop(0)
|
|
|
|
self.name = name
|
|
self.primary_key = kwargs.pop('primary_key', False)
|
|
self.autoincrement = kwargs.pop('autoincrement', 'auto')
|
|
|
|
self.nullable = kwargs.pop('nullable', not self.primary_key)
|
|
self.default = kwargs.pop('default', None)
|
|
self.server_default = kwargs.pop('server_default', None)
|
|
|
|
self.index = kwargs.pop('index', None)
|
|
self.unique = kwargs.pop('unique', None)
|
|
|
|
def get_column(self, name=None) -> sqlalchemy.Column:
|
|
name = self.name or name
|
|
constraints = self.get_constraints()
|
|
return sqlalchemy.Column(
|
|
name,
|
|
self.get_column_type(),
|
|
*constraints,
|
|
primary_key=self.primary_key,
|
|
autoincrement=self.autoincrement,
|
|
nullable=self.nullable,
|
|
index=self.index,
|
|
unique=self.unique,
|
|
default=self.default,
|
|
server_default=self.server_default
|
|
)
|
|
|
|
def get_column_type(self) -> sqlalchemy.types.TypeEngine:
|
|
raise NotImplementedError() # pragma: no cover
|
|
|
|
def get_constraints(self):
|
|
return []
|
|
|
|
|
|
class String(BaseField):
|
|
__type__ = str
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
assert 'length' in kwargs, 'length is required'
|
|
self.length = kwargs.pop('length')
|
|
super().__init__(*args, **kwargs)
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.String(self.length)
|
|
|
|
|
|
class Integer(BaseField):
|
|
__type__ = int
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.Integer()
|
|
|
|
|
|
class Text(BaseField):
|
|
__type__ = str
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.Text()
|
|
|
|
|
|
class Float(BaseField):
|
|
__type__ = float
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.Float()
|
|
|
|
|
|
class Boolean(BaseField):
|
|
__type__ = bool
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.Boolean()
|
|
|
|
|
|
class DateTime(BaseField):
|
|
__type__ = datetime.datetime
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.DateTime()
|
|
|
|
|
|
class Date(BaseField):
|
|
__type__ = datetime.date
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.Date()
|
|
|
|
|
|
class Time(BaseField):
|
|
__type__ = datetime.time
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.Time()
|
|
|
|
|
|
class JSON(BaseField):
|
|
__type__ = pydantic.Json
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.JSON()
|
|
|
|
|
|
class BigInteger(BaseField):
|
|
__type__ = int
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.BigInteger()
|
|
|
|
|
|
class Decimal(BaseField):
|
|
__type__ = decimal.Decimal
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
assert 'precision' in kwargs, 'precision is required'
|
|
assert 'length' in kwargs, 'length is required'
|
|
self.length = kwargs.pop('length')
|
|
self.precision = kwargs.pop('precision')
|
|
super().__init__(*args, **kwargs)
|
|
|
|
def get_column_type(self):
|
|
return sqlalchemy.DECIMAL(self.length, self.precision)
|