refactor fields into classes
This commit is contained in:
@ -23,7 +23,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
|
|||||||
|
|
||||||
|
|
||||||
def ForeignKey(
|
def ForeignKey(
|
||||||
to: "Model",
|
to: Type["Model"],
|
||||||
*,
|
*,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
unique: bool = False,
|
unique: bool = False,
|
||||||
|
|||||||
@ -1,11 +1,9 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import decimal
|
import decimal
|
||||||
import re
|
|
||||||
from typing import Any, Optional, Type
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from pydantic import Json
|
|
||||||
|
|
||||||
from ormar import ModelDefinitionError # noqa I101
|
from ormar import ModelDefinitionError # noqa I101
|
||||||
from ormar.fields.base import BaseField # noqa I101
|
from ormar.fields.base import BaseField # noqa I101
|
||||||
@ -19,325 +17,223 @@ def is_field_nullable(
|
|||||||
return nullable
|
return nullable
|
||||||
|
|
||||||
|
|
||||||
def String(
|
class ModelFieldFactory:
|
||||||
|
_bases = None
|
||||||
|
_type = None
|
||||||
|
|
||||||
|
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
|
||||||
|
cls.validate(**kwargs)
|
||||||
|
|
||||||
|
default = kwargs.pop("default", None)
|
||||||
|
server_default = kwargs.pop("server_default", None)
|
||||||
|
nullable = kwargs.pop("nullable", None)
|
||||||
|
|
||||||
|
namespace = dict(
|
||||||
|
__type__=cls._type,
|
||||||
|
name=kwargs.pop("name", None),
|
||||||
|
primary_key=kwargs.pop("primary_key", False),
|
||||||
|
default=default,
|
||||||
|
server_default=server_default,
|
||||||
|
nullable=is_field_nullable(nullable, default, server_default),
|
||||||
|
index=kwargs.pop("index", False),
|
||||||
|
unique=kwargs.pop("unique", False),
|
||||||
|
pydantic_only=kwargs.pop("pydantic_only", False),
|
||||||
|
autoincrement=kwargs.pop("autoincrement", False),
|
||||||
|
column_type=cls.get_column_type(**kwargs),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return type(cls.__name__, cls._bases, namespace)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any: # pragma no cover
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, **kwargs: Any) -> None: # pragma no cover
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class String(ModelFieldFactory):
|
||||||
|
_bases = (pydantic.ConstrainedStr, BaseField)
|
||||||
|
_type = str
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
*,
|
*,
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
allow_blank: bool = False,
|
allow_blank: bool = False,
|
||||||
strip_whitespace: bool = False,
|
strip_whitespace: bool = False,
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
curtail_length: int = None,
|
curtail_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
pydantic_only: bool = False,
|
**kwargs: Any
|
||||||
default: Any = None,
|
) -> Type[str]:
|
||||||
server_default: Any = None,
|
kwargs = {
|
||||||
) -> Type[str]:
|
**kwargs,
|
||||||
if max_length is None or max_length <= 0:
|
**{
|
||||||
raise ModelDefinitionError("Parameter max_length is required for field String")
|
k: v
|
||||||
|
for k, v in locals().items()
|
||||||
|
if k not in ["cls", "__class__", "kwargs"]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return super().__new__(cls, **kwargs)
|
||||||
|
|
||||||
namespace = dict(
|
@classmethod
|
||||||
__type__=str,
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
name=name,
|
return sqlalchemy.String(length=kwargs.get("max_length"))
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
@classmethod
|
||||||
index=index,
|
def validate(cls, **kwargs: Any) -> None:
|
||||||
unique=unique,
|
max_length = kwargs.get("max_length", None)
|
||||||
allow_blank=allow_blank,
|
if max_length is None or max_length <= 0:
|
||||||
strip_whitespace=strip_whitespace,
|
raise ModelDefinitionError(
|
||||||
min_length=min_length,
|
"Parameter max_length is required for field String"
|
||||||
max_length=max_length,
|
|
||||||
curtail_length=curtail_length,
|
|
||||||
regex=regex and re.compile(regex),
|
|
||||||
column_type=sqlalchemy.String(length=max_length),
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return type("String", (pydantic.ConstrainedStr, BaseField), namespace)
|
|
||||||
|
|
||||||
|
class Integer(ModelFieldFactory):
|
||||||
|
_bases = (pydantic.ConstrainedInt, BaseField)
|
||||||
|
_type = int
|
||||||
|
|
||||||
def Integer(
|
def __new__(
|
||||||
|
cls,
|
||||||
*,
|
*,
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
autoincrement: bool = None,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
minimum: int = None,
|
minimum: int = None,
|
||||||
maximum: int = None,
|
maximum: int = None,
|
||||||
multiple_of: int = None,
|
multiple_of: int = None,
|
||||||
pydantic_only: bool = False,
|
**kwargs: Any
|
||||||
default: Any = None,
|
) -> Type[int]:
|
||||||
server_default: Any = None,
|
autoincrement = kwargs.pop("autoincrement", None)
|
||||||
) -> Type[int]:
|
autoincrement = (
|
||||||
namespace = dict(
|
autoincrement
|
||||||
__type__=int,
|
if autoincrement is not None
|
||||||
name=name,
|
else kwargs.get("primary_key", False)
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
|
||||||
index=index,
|
|
||||||
unique=unique,
|
|
||||||
ge=minimum,
|
|
||||||
le=maximum,
|
|
||||||
multiple_of=multiple_of,
|
|
||||||
column_type=sqlalchemy.Integer(),
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=autoincrement if autoincrement is not None else primary_key,
|
|
||||||
)
|
)
|
||||||
return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace)
|
kwargs = {
|
||||||
|
**kwargs,
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in locals().items()
|
||||||
|
if k not in ["cls", "__class__", "kwargs"]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return super().__new__(cls, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
|
return sqlalchemy.Integer()
|
||||||
|
|
||||||
|
|
||||||
def Text(
|
class Text(ModelFieldFactory):
|
||||||
|
_bases = (pydantic.ConstrainedStr, BaseField)
|
||||||
|
_type = str
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any
|
||||||
|
) -> Type[str]:
|
||||||
|
kwargs = {
|
||||||
|
**kwargs,
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in locals().items()
|
||||||
|
if k not in ["cls", "__class__", "kwargs"]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return super().__new__(cls, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
|
return sqlalchemy.Text()
|
||||||
|
|
||||||
|
|
||||||
|
class Float(ModelFieldFactory):
|
||||||
|
_bases = (pydantic.ConstrainedFloat, BaseField)
|
||||||
|
_type = float
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
*,
|
*,
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
allow_blank: bool = False,
|
|
||||||
strip_whitespace: bool = False,
|
|
||||||
pydantic_only: bool = False,
|
|
||||||
default: Any = None,
|
|
||||||
server_default: Any = None,
|
|
||||||
) -> Type[str]:
|
|
||||||
namespace = dict(
|
|
||||||
__type__=str,
|
|
||||||
name=name,
|
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
|
||||||
index=index,
|
|
||||||
unique=unique,
|
|
||||||
allow_blank=allow_blank,
|
|
||||||
strip_whitespace=strip_whitespace,
|
|
||||||
column_type=sqlalchemy.Text(),
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return type("Text", (pydantic.ConstrainedStr, BaseField), namespace)
|
|
||||||
|
|
||||||
|
|
||||||
def Float(
|
|
||||||
*,
|
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
minimum: float = None,
|
minimum: float = None,
|
||||||
maximum: float = None,
|
maximum: float = None,
|
||||||
multiple_of: int = None,
|
multiple_of: int = None,
|
||||||
pydantic_only: bool = False,
|
**kwargs: Any
|
||||||
default: Any = None,
|
) -> Type[int]:
|
||||||
server_default: Any = None,
|
kwargs = {
|
||||||
) -> Type[int]:
|
**kwargs,
|
||||||
namespace = dict(
|
**{
|
||||||
__type__=float,
|
k: v
|
||||||
name=name,
|
for k, v in locals().items()
|
||||||
primary_key=primary_key,
|
if k not in ["cls", "__class__", "kwargs"]
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
},
|
||||||
index=index,
|
}
|
||||||
unique=unique,
|
return super().__new__(cls, **kwargs)
|
||||||
ge=minimum,
|
|
||||||
le=maximum,
|
@classmethod
|
||||||
multiple_of=multiple_of,
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
column_type=sqlalchemy.Float(),
|
return sqlalchemy.Float()
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=False,
|
|
||||||
)
|
|
||||||
return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace)
|
|
||||||
|
|
||||||
|
|
||||||
def Boolean(
|
class Boolean(ModelFieldFactory):
|
||||||
|
_bases = (int, BaseField)
|
||||||
|
_type = bool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
|
return sqlalchemy.Boolean()
|
||||||
|
|
||||||
|
|
||||||
|
class DateTime(ModelFieldFactory):
|
||||||
|
_bases = (datetime.datetime, BaseField)
|
||||||
|
_type = datetime.datetime
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
|
return sqlalchemy.DateTime()
|
||||||
|
|
||||||
|
|
||||||
|
class Date(ModelFieldFactory):
|
||||||
|
_bases = (datetime.date, BaseField)
|
||||||
|
_type = datetime.date
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
|
return sqlalchemy.Date()
|
||||||
|
|
||||||
|
|
||||||
|
class Time(ModelFieldFactory):
|
||||||
|
_bases = (datetime.time, BaseField)
|
||||||
|
_type = datetime.time
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
|
return sqlalchemy.Time()
|
||||||
|
|
||||||
|
|
||||||
|
class JSON(ModelFieldFactory):
|
||||||
|
_bases = (pydantic.Json, BaseField)
|
||||||
|
_type = pydantic.Json
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
|
return sqlalchemy.JSON()
|
||||||
|
|
||||||
|
|
||||||
|
class BigInteger(Integer):
|
||||||
|
_bases = (pydantic.ConstrainedInt, BaseField)
|
||||||
|
_type = int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
|
return sqlalchemy.BigInteger()
|
||||||
|
|
||||||
|
|
||||||
|
class Decimal(ModelFieldFactory):
|
||||||
|
_bases = (pydantic.ConstrainedDecimal, BaseField)
|
||||||
|
_type = decimal.Decimal
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
*,
|
*,
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
pydantic_only: bool = False,
|
|
||||||
default: Any = None,
|
|
||||||
server_default: Any = None,
|
|
||||||
) -> Type[bool]:
|
|
||||||
namespace = dict(
|
|
||||||
__type__=bool,
|
|
||||||
name=name,
|
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
|
||||||
index=index,
|
|
||||||
unique=unique,
|
|
||||||
column_type=sqlalchemy.Boolean(),
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=False,
|
|
||||||
)
|
|
||||||
return type("Boolean", (int, BaseField), namespace)
|
|
||||||
|
|
||||||
|
|
||||||
def DateTime(
|
|
||||||
*,
|
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
pydantic_only: bool = False,
|
|
||||||
default: Any = None,
|
|
||||||
server_default: Any = None,
|
|
||||||
) -> Type[datetime.datetime]:
|
|
||||||
namespace = dict(
|
|
||||||
__type__=datetime.datetime,
|
|
||||||
name=name,
|
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
|
||||||
index=index,
|
|
||||||
unique=unique,
|
|
||||||
column_type=sqlalchemy.DateTime(),
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=False,
|
|
||||||
)
|
|
||||||
return type("DateTime", (datetime.datetime, BaseField), namespace)
|
|
||||||
|
|
||||||
|
|
||||||
def Date(
|
|
||||||
*,
|
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
pydantic_only: bool = False,
|
|
||||||
default: Any = None,
|
|
||||||
server_default: Any = None,
|
|
||||||
) -> Type[datetime.date]:
|
|
||||||
namespace = dict(
|
|
||||||
__type__=datetime.date,
|
|
||||||
name=name,
|
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
|
||||||
index=index,
|
|
||||||
unique=unique,
|
|
||||||
column_type=sqlalchemy.Date(),
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=False,
|
|
||||||
)
|
|
||||||
return type("Date", (datetime.date, BaseField), namespace)
|
|
||||||
|
|
||||||
|
|
||||||
def Time(
|
|
||||||
*,
|
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
pydantic_only: bool = False,
|
|
||||||
default: Any = None,
|
|
||||||
server_default: Any = None,
|
|
||||||
) -> Type[datetime.time]:
|
|
||||||
namespace = dict(
|
|
||||||
__type__=datetime.time,
|
|
||||||
name=name,
|
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
|
||||||
index=index,
|
|
||||||
unique=unique,
|
|
||||||
column_type=sqlalchemy.Time(),
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=False,
|
|
||||||
)
|
|
||||||
return type("Time", (datetime.time, BaseField), namespace)
|
|
||||||
|
|
||||||
|
|
||||||
def JSON(
|
|
||||||
*,
|
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
pydantic_only: bool = False,
|
|
||||||
default: Any = None,
|
|
||||||
server_default: Any = None,
|
|
||||||
) -> Type[Json]:
|
|
||||||
namespace = dict(
|
|
||||||
__type__=pydantic.Json,
|
|
||||||
name=name,
|
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
|
||||||
index=index,
|
|
||||||
unique=unique,
|
|
||||||
column_type=sqlalchemy.JSON(),
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return type("JSON", (pydantic.Json, BaseField), namespace)
|
|
||||||
|
|
||||||
|
|
||||||
def BigInteger(
|
|
||||||
*,
|
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
autoincrement: bool = None,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
minimum: int = None,
|
|
||||||
maximum: int = None,
|
|
||||||
multiple_of: int = None,
|
|
||||||
pydantic_only: bool = False,
|
|
||||||
default: Any = None,
|
|
||||||
server_default: Any = None,
|
|
||||||
) -> Type[int]:
|
|
||||||
namespace = dict(
|
|
||||||
__type__=int,
|
|
||||||
name=name,
|
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
|
||||||
index=index,
|
|
||||||
unique=unique,
|
|
||||||
ge=minimum,
|
|
||||||
le=maximum,
|
|
||||||
multiple_of=multiple_of,
|
|
||||||
column_type=sqlalchemy.BigInteger(),
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=autoincrement if autoincrement is not None else primary_key,
|
|
||||||
)
|
|
||||||
return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace)
|
|
||||||
|
|
||||||
|
|
||||||
def Decimal(
|
|
||||||
*,
|
|
||||||
name: str = None,
|
|
||||||
primary_key: bool = False,
|
|
||||||
nullable: bool = None,
|
|
||||||
index: bool = False,
|
|
||||||
unique: bool = False,
|
|
||||||
minimum: float = None,
|
minimum: float = None,
|
||||||
maximum: float = None,
|
maximum: float = None,
|
||||||
multiple_of: int = None,
|
multiple_of: int = None,
|
||||||
@ -345,33 +241,29 @@ def Decimal(
|
|||||||
scale: int = None,
|
scale: int = None,
|
||||||
max_digits: int = None,
|
max_digits: int = None,
|
||||||
decimal_places: int = None,
|
decimal_places: int = None,
|
||||||
pydantic_only: bool = False,
|
**kwargs: Any
|
||||||
default: Any = None,
|
) -> Type[decimal.Decimal]:
|
||||||
server_default: Any = None,
|
kwargs = {
|
||||||
) -> Type[decimal.Decimal]:
|
**kwargs,
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in locals().items()
|
||||||
|
if k not in ["cls", "__class__", "kwargs"]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return super().__new__(cls, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||||
|
precision = kwargs.get("precision")
|
||||||
|
scale = kwargs.get("scale")
|
||||||
|
return sqlalchemy.DECIMAL(precision=precision, scale=scale)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, **kwargs: Any) -> None:
|
||||||
|
precision = kwargs.get("precision")
|
||||||
|
scale = kwargs.get("scale")
|
||||||
if precision is None or precision < 0 or scale is None or scale < 0:
|
if precision is None or precision < 0 or scale is None or scale < 0:
|
||||||
raise ModelDefinitionError(
|
raise ModelDefinitionError(
|
||||||
"Parameters scale and precision are required for field Decimal"
|
"Parameters scale and precision are required for field Decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
namespace = dict(
|
|
||||||
__type__=decimal.Decimal,
|
|
||||||
name=name,
|
|
||||||
primary_key=primary_key,
|
|
||||||
nullable=is_field_nullable(nullable, default, server_default),
|
|
||||||
index=index,
|
|
||||||
unique=unique,
|
|
||||||
ge=minimum,
|
|
||||||
le=maximum,
|
|
||||||
multiple_of=multiple_of,
|
|
||||||
column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale),
|
|
||||||
precision=precision,
|
|
||||||
scale=scale,
|
|
||||||
max_digits=max_digits,
|
|
||||||
decimal_places=decimal_places,
|
|
||||||
pydantic_only=pydantic_only,
|
|
||||||
default=default,
|
|
||||||
server_default=server_default,
|
|
||||||
autoincrement=False,
|
|
||||||
)
|
|
||||||
return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace)
|
|
||||||
|
|||||||
@ -240,10 +240,9 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
|
|||||||
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
|
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
|
||||||
}
|
}
|
||||||
for field in self._extract_related_names():
|
for field in self._extract_related_names():
|
||||||
|
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
|
||||||
if getattr(self, field) is not None:
|
if getattr(self, field) is not None:
|
||||||
self_fields[field] = getattr(
|
self_fields[field] = getattr(getattr(self, field), target_pk_name)
|
||||||
getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname
|
|
||||||
)
|
|
||||||
return self_fields
|
return self_fields
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -259,17 +258,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
|
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
|
||||||
for field in one.Meta.model_fields.keys():
|
for field in one.Meta.model_fields.keys():
|
||||||
if isinstance(getattr(one, field), list) and not isinstance(
|
current_field = getattr(one, field)
|
||||||
getattr(one, field), ormar.Model
|
if isinstance(current_field, list) and not isinstance(
|
||||||
|
current_field, ormar.Model
|
||||||
|
):
|
||||||
|
setattr(other, field, current_field + getattr(other, field))
|
||||||
|
elif (
|
||||||
|
isinstance(current_field, ormar.Model)
|
||||||
|
and current_field.pk == getattr(other, field).pk
|
||||||
):
|
):
|
||||||
setattr(other, field, getattr(one, field) + getattr(other, field))
|
|
||||||
elif isinstance(getattr(one, field), ormar.Model):
|
|
||||||
if getattr(one, field).pk == getattr(other, field).pk:
|
|
||||||
setattr(
|
setattr(
|
||||||
other,
|
other,
|
||||||
field,
|
field,
|
||||||
cls.merge_two_instances(
|
cls.merge_two_instances(current_field, getattr(other, field)),
|
||||||
getattr(one, field), getattr(other, field)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return other
|
return other
|
||||||
|
|||||||
@ -76,6 +76,7 @@ def event_loop():
|
|||||||
@pytest.fixture(autouse=True, scope="module")
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
async def create_test_database():
|
async def create_test_database():
|
||||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||||
|
metadata.drop_all(engine)
|
||||||
metadata.create_all(engine)
|
metadata.create_all(engine)
|
||||||
department = await Department.objects.create(id=1, name="Math Department")
|
department = await Department.objects.create(id=1, name="Math Department")
|
||||||
class1 = await SchoolClass.objects.create(name="Math", department=department)
|
class1 = await SchoolClass.objects.create(name="Math", department=department)
|
||||||
|
|||||||
Reference in New Issue
Block a user