refactor fields into classes
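
This commit reworks the field factories below from standalone functions into ModelFieldFactory subclasses; the way a field is declared does not change, since calling a factory still returns a field class built with type(). A minimal sketch of that call style (the values are illustrative, and ormar.String / ormar.Integer are assumed to stay re-exported at the package level as before):

import ormar

# Each factory call returns a new class that mixes the matching pydantic
# constrained type with ormar's BaseField; the keyword arguments end up as
# attributes of the generated class.
name_field = ormar.String(max_length=100)
print(name_field.__name__)    # "String"
print(name_field.max_length)  # 100

id_field = ormar.Integer(primary_key=True)
print(id_field.primary_key)   # True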

@@ -23,7 +23,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":


def ForeignKey(
    to: "Model",
    to: Type["Model"],
    *,
    name: str = None,
    unique: bool = False,

@@ -1,11 +1,9 @@
import datetime
import decimal
import re
from typing import Any, Optional, Type

import pydantic
import sqlalchemy
from pydantic import Json

from ormar import ModelDefinitionError  # noqa I101
from ormar.fields.base import BaseField  # noqa I101
@@ -19,359 +17,253 @@ def is_field_nullable(
    return nullable


def String(
    *,
    name: str = None,
    primary_key: bool = False,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    allow_blank: bool = False,
    strip_whitespace: bool = False,
    min_length: int = None,
    max_length: int = None,
    curtail_length: int = None,
    regex: str = None,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[str]:
    if max_length is None or max_length <= 0:
        raise ModelDefinitionError("Parameter max_length is required for field String")
class ModelFieldFactory:
    _bases = None
    _type = None

    namespace = dict(
        __type__=str,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        allow_blank=allow_blank,
        strip_whitespace=strip_whitespace,
        min_length=min_length,
        max_length=max_length,
        curtail_length=curtail_length,
        regex=regex and re.compile(regex),
        column_type=sqlalchemy.String(length=max_length),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=False,
    )
    def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
        cls.validate(**kwargs)

    return type("String", (pydantic.ConstrainedStr, BaseField), namespace)
        default = kwargs.pop("default", None)
        server_default = kwargs.pop("server_default", None)
        nullable = kwargs.pop("nullable", None)


def Integer(
    *,
    name: str = None,
    primary_key: bool = False,
    autoincrement: bool = None,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    minimum: int = None,
    maximum: int = None,
    multiple_of: int = None,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[int]:
    namespace = dict(
        __type__=int,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        ge=minimum,
        le=maximum,
        multiple_of=multiple_of,
        column_type=sqlalchemy.Integer(),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=autoincrement if autoincrement is not None else primary_key,
    )
    return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace)


def Text(
    *,
    name: str = None,
    primary_key: bool = False,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    allow_blank: bool = False,
    strip_whitespace: bool = False,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[str]:
    namespace = dict(
        __type__=str,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        allow_blank=allow_blank,
        strip_whitespace=strip_whitespace,
        column_type=sqlalchemy.Text(),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=False,
    )

    return type("Text", (pydantic.ConstrainedStr, BaseField), namespace)


def Float(
    *,
    name: str = None,
    primary_key: bool = False,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    minimum: float = None,
    maximum: float = None,
    multiple_of: int = None,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[int]:
    namespace = dict(
        __type__=float,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        ge=minimum,
        le=maximum,
        multiple_of=multiple_of,
        column_type=sqlalchemy.Float(),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=False,
    )
    return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace)


def Boolean(
    *,
    name: str = None,
    primary_key: bool = False,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[bool]:
    namespace = dict(
        __type__=bool,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        column_type=sqlalchemy.Boolean(),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=False,
    )
    return type("Boolean", (int, BaseField), namespace)


def DateTime(
    *,
    name: str = None,
    primary_key: bool = False,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[datetime.datetime]:
    namespace = dict(
        __type__=datetime.datetime,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        column_type=sqlalchemy.DateTime(),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=False,
    )
    return type("DateTime", (datetime.datetime, BaseField), namespace)


def Date(
    *,
    name: str = None,
    primary_key: bool = False,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[datetime.date]:
    namespace = dict(
        __type__=datetime.date,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        column_type=sqlalchemy.Date(),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=False,
    )
    return type("Date", (datetime.date, BaseField), namespace)


def Time(
    *,
    name: str = None,
    primary_key: bool = False,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[datetime.time]:
    namespace = dict(
        __type__=datetime.time,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        column_type=sqlalchemy.Time(),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=False,
    )
    return type("Time", (datetime.time, BaseField), namespace)


def JSON(
    *,
    name: str = None,
    primary_key: bool = False,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[Json]:
    namespace = dict(
        __type__=pydantic.Json,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        column_type=sqlalchemy.JSON(),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=False,
    )

    return type("JSON", (pydantic.Json, BaseField), namespace)


def BigInteger(
    *,
    name: str = None,
    primary_key: bool = False,
    autoincrement: bool = None,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    minimum: int = None,
    maximum: int = None,
    multiple_of: int = None,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[int]:
    namespace = dict(
        __type__=int,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        ge=minimum,
        le=maximum,
        multiple_of=multiple_of,
        column_type=sqlalchemy.BigInteger(),
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=autoincrement if autoincrement is not None else primary_key,
    )
    return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace)


def Decimal(
    *,
    name: str = None,
    primary_key: bool = False,
    nullable: bool = None,
    index: bool = False,
    unique: bool = False,
    minimum: float = None,
    maximum: float = None,
    multiple_of: int = None,
    precision: int = None,
    scale: int = None,
    max_digits: int = None,
    decimal_places: int = None,
    pydantic_only: bool = False,
    default: Any = None,
    server_default: Any = None,
) -> Type[decimal.Decimal]:
    if precision is None or precision < 0 or scale is None or scale < 0:
        raise ModelDefinitionError(
            "Parameters scale and precision are required for field Decimal"
        namespace = dict(
            __type__=cls._type,
            name=kwargs.pop("name", None),
            primary_key=kwargs.pop("primary_key", False),
            default=default,
            server_default=server_default,
            nullable=is_field_nullable(nullable, default, server_default),
            index=kwargs.pop("index", False),
            unique=kwargs.pop("unique", False),
            pydantic_only=kwargs.pop("pydantic_only", False),
            autoincrement=kwargs.pop("autoincrement", False),
            column_type=cls.get_column_type(**kwargs),
            **kwargs
        )
        return type(cls.__name__, cls._bases, namespace)

    namespace = dict(
        __type__=decimal.Decimal,
        name=name,
        primary_key=primary_key,
        nullable=is_field_nullable(nullable, default, server_default),
        index=index,
        unique=unique,
        ge=minimum,
        le=maximum,
        multiple_of=multiple_of,
        column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale),
        precision=precision,
        scale=scale,
        max_digits=max_digits,
        decimal_places=decimal_places,
        pydantic_only=pydantic_only,
        default=default,
        server_default=server_default,
        autoincrement=False,
    )
    return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace)
    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:  # pragma no cover
        return None

    @classmethod
    def validate(cls, **kwargs: Any) -> None:  # pragma no cover
        pass


class String(ModelFieldFactory):
    _bases = (pydantic.ConstrainedStr, BaseField)
    _type = str

    def __new__(
        cls,
        *,
        allow_blank: bool = False,
        strip_whitespace: bool = False,
        min_length: int = None,
        max_length: int = None,
        curtail_length: int = None,
        regex: str = None,
        **kwargs: Any
    ) -> Type[str]:
        kwargs = {
            **kwargs,
            **{
                k: v
                for k, v in locals().items()
                if k not in ["cls", "__class__", "kwargs"]
            },
        }
        return super().__new__(cls, **kwargs)

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.String(length=kwargs.get("max_length"))

    @classmethod
    def validate(cls, **kwargs: Any) -> None:
        max_length = kwargs.get("max_length", None)
        if max_length is None or max_length <= 0:
            raise ModelDefinitionError(
                "Parameter max_length is required for field String"
            )
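
The max_length check that the old String() function performed inline now lives in the validate() hook, which ModelFieldFactory.__new__ runs before building the field class. A small sketch of the failure path (assuming the factory stays reachable as ormar.String):

import ormar
from ormar import ModelDefinitionError

try:
    ormar.String()  # no max_length given
except ModelDefinitionError as err:
    print(err)  # Parameter max_length is required for field String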


class Integer(ModelFieldFactory):
    _bases = (pydantic.ConstrainedInt, BaseField)
    _type = int

    def __new__(
        cls,
        *,
        minimum: int = None,
        maximum: int = None,
        multiple_of: int = None,
        **kwargs: Any
    ) -> Type[int]:
        autoincrement = kwargs.pop("autoincrement", None)
        autoincrement = (
            autoincrement
            if autoincrement is not None
            else kwargs.get("primary_key", False)
        )
        kwargs = {
            **kwargs,
            **{
                k: v
                for k, v in locals().items()
                if k not in ["cls", "__class__", "kwargs"]
            },
        }
        return super().__new__(cls, **kwargs)

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.Integer()
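
Integer keeps the old behaviour of tying autoincrement to primary_key when the caller does not set it explicitly; the recomputed value is folded back into kwargs before the base __new__ runs. For example (illustrative values only):

import ormar

pk_field = ormar.Integer(primary_key=True)
print(pk_field.autoincrement)     # True - inherited from primary_key

plain_field = ormar.Integer()
print(plain_field.autoincrement)  # False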


class Text(ModelFieldFactory):
    _bases = (pydantic.ConstrainedStr, BaseField)
    _type = str

    def __new__(
        cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any
    ) -> Type[str]:
        kwargs = {
            **kwargs,
            **{
                k: v
                for k, v in locals().items()
                if k not in ["cls", "__class__", "kwargs"]
            },
        }
        return super().__new__(cls, **kwargs)

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.Text()


class Float(ModelFieldFactory):
    _bases = (pydantic.ConstrainedFloat, BaseField)
    _type = float

    def __new__(
        cls,
        *,
        minimum: float = None,
        maximum: float = None,
        multiple_of: int = None,
        **kwargs: Any
    ) -> Type[int]:
        kwargs = {
            **kwargs,
            **{
                k: v
                for k, v in locals().items()
                if k not in ["cls", "__class__", "kwargs"]
            },
        }
        return super().__new__(cls, **kwargs)

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.Float()


class Boolean(ModelFieldFactory):
    _bases = (int, BaseField)
    _type = bool

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.Boolean()


class DateTime(ModelFieldFactory):
    _bases = (datetime.datetime, BaseField)
    _type = datetime.datetime

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.DateTime()


class Date(ModelFieldFactory):
    _bases = (datetime.date, BaseField)
    _type = datetime.date

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.Date()


class Time(ModelFieldFactory):
    _bases = (datetime.time, BaseField)
    _type = datetime.time

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.Time()


class JSON(ModelFieldFactory):
    _bases = (pydantic.Json, BaseField)
    _type = pydantic.Json

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.JSON()


class BigInteger(Integer):
    _bases = (pydantic.ConstrainedInt, BaseField)
    _type = int

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        return sqlalchemy.BigInteger()


class Decimal(ModelFieldFactory):
    _bases = (pydantic.ConstrainedDecimal, BaseField)
    _type = decimal.Decimal

    def __new__(
        cls,
        *,
        minimum: float = None,
        maximum: float = None,
        multiple_of: int = None,
        precision: int = None,
        scale: int = None,
        max_digits: int = None,
        decimal_places: int = None,
        **kwargs: Any
    ) -> Type[decimal.Decimal]:
        kwargs = {
            **kwargs,
            **{
                k: v
                for k, v in locals().items()
                if k not in ["cls", "__class__", "kwargs"]
            },
        }
        return super().__new__(cls, **kwargs)

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        precision = kwargs.get("precision")
        scale = kwargs.get("scale")
        return sqlalchemy.DECIMAL(precision=precision, scale=scale)

    @classmethod
    def validate(cls, **kwargs: Any) -> None:
        precision = kwargs.get("precision")
        scale = kwargs.get("scale")
        if precision is None or precision < 0 or scale is None or scale < 0:
            raise ModelDefinitionError(
                "Parameters scale and precision are required for field Decimal"
            )
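
With the shared logic pulled into ModelFieldFactory, adding a new field type is mostly a matter of declaring _bases and _type and overriding get_column_type() (plus validate() when extra checks are needed), just like the built-in classes above. A sketch of what that could look like; the SmallInteger name and the module path for ModelFieldFactory are illustrative assumptions, not part of this commit:

from typing import Any

import pydantic
import sqlalchemy

from ormar.fields.base import BaseField
from ormar.fields.model_fields import ModelFieldFactory  # assumed module path


class SmallInteger(ModelFieldFactory):
    _bases = (pydantic.ConstrainedInt, BaseField)
    _type = int

    @classmethod
    def get_column_type(cls, **kwargs: Any) -> Any:
        # Map the field to a smaller SQL integer column type.
        return sqlalchemy.SmallInteger()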

@@ -240,10 +240,9 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
            k: v for k, v in self_fields.items() if k in self.Meta.table.columns
        }
        for field in self._extract_related_names():
            target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
            if getattr(self, field) is not None:
                self_fields[field] = getattr(
                    getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname
                )
                self_fields[field] = getattr(getattr(self, field), target_pk_name)
        return self_fields

    @classmethod
@@ -259,17 +258,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
    @classmethod
    def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
        for field in one.Meta.model_fields.keys():
            if isinstance(getattr(one, field), list) and not isinstance(
                getattr(one, field), ormar.Model
            current_field = getattr(one, field)
            if isinstance(current_field, list) and not isinstance(
                current_field, ormar.Model
            ):
                setattr(other, field, getattr(one, field) + getattr(other, field))
            elif isinstance(getattr(one, field), ormar.Model):
                if getattr(one, field).pk == getattr(other, field).pk:
                    setattr(
                        other,
                        field,
                        cls.merge_two_instances(
                            getattr(one, field), getattr(other, field)
                        ),
                    )
                setattr(other, field, current_field + getattr(other, field))
            elif (
                isinstance(current_field, ormar.Model)
                and current_field.pk == getattr(other, field).pk
            ):
                setattr(
                    other,
                    field,
                    cls.merge_two_instances(current_field, getattr(other, field)),
                )
        return other