refactor fields into classes

This commit is contained in:
collerek
2020-08-23 16:14:04 +02:00
parent 806fe9b63e
commit 348a3d90dc
5 changed files with 262 additions and 369 deletions

BIN
.coverage

Binary file not shown.

View File

@ -23,7 +23,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
def ForeignKey(
to: "Model",
to: Type["Model"],
*,
name: str = None,
unique: bool = False,

View File

@ -1,11 +1,9 @@
import datetime
import decimal
import re
from typing import Any, Optional, Type
import pydantic
import sqlalchemy
from pydantic import Json
from ormar import ModelDefinitionError # noqa I101
from ormar.fields.base import BaseField # noqa I101
@ -19,359 +17,253 @@ def is_field_nullable(
return nullable
def String(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
allow_blank: bool = False,
strip_whitespace: bool = False,
min_length: int = None,
max_length: int = None,
curtail_length: int = None,
regex: str = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[str]:
if max_length is None or max_length <= 0:
raise ModelDefinitionError("Parameter max_length is required for field String")
class ModelFieldFactory:
_bases = None
_type = None
namespace = dict(
__type__=str,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
allow_blank=allow_blank,
strip_whitespace=strip_whitespace,
min_length=min_length,
max_length=max_length,
curtail_length=curtail_length,
regex=regex and re.compile(regex),
column_type=sqlalchemy.String(length=max_length),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False,
)
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
cls.validate(**kwargs)
return type("String", (pydantic.ConstrainedStr, BaseField), namespace)
default = kwargs.pop("default", None)
server_default = kwargs.pop("server_default", None)
nullable = kwargs.pop("nullable", None)
def Integer(
*,
name: str = None,
primary_key: bool = False,
autoincrement: bool = None,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[int]:
namespace = dict(
__type__=int,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
ge=minimum,
le=maximum,
multiple_of=multiple_of,
column_type=sqlalchemy.Integer(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key,
)
return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace)
def Text(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
allow_blank: bool = False,
strip_whitespace: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[str]:
namespace = dict(
__type__=str,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
allow_blank=allow_blank,
strip_whitespace=strip_whitespace,
column_type=sqlalchemy.Text(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False,
)
return type("Text", (pydantic.ConstrainedStr, BaseField), namespace)
def Float(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[int]:
namespace = dict(
__type__=float,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
ge=minimum,
le=maximum,
multiple_of=multiple_of,
column_type=sqlalchemy.Float(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False,
)
return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace)
def Boolean(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[bool]:
namespace = dict(
__type__=bool,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.Boolean(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False,
)
return type("Boolean", (int, BaseField), namespace)
def DateTime(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[datetime.datetime]:
namespace = dict(
__type__=datetime.datetime,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.DateTime(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False,
)
return type("DateTime", (datetime.datetime, BaseField), namespace)
def Date(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[datetime.date]:
namespace = dict(
__type__=datetime.date,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.Date(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False,
)
return type("Date", (datetime.date, BaseField), namespace)
def Time(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[datetime.time]:
namespace = dict(
__type__=datetime.time,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.Time(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False,
)
return type("Time", (datetime.time, BaseField), namespace)
def JSON(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[Json]:
namespace = dict(
__type__=pydantic.Json,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.JSON(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False,
)
return type("JSON", (pydantic.Json, BaseField), namespace)
def BigInteger(
*,
name: str = None,
primary_key: bool = False,
autoincrement: bool = None,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[int]:
namespace = dict(
__type__=int,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
ge=minimum,
le=maximum,
multiple_of=multiple_of,
column_type=sqlalchemy.BigInteger(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key,
)
return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace)
def Decimal(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
precision: int = None,
scale: int = None,
max_digits: int = None,
decimal_places: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[decimal.Decimal]:
if precision is None or precision < 0 or scale is None or scale < 0:
raise ModelDefinitionError(
"Parameters scale and precision are required for field Decimal"
namespace = dict(
__type__=cls._type,
name=kwargs.pop("name", None),
primary_key=kwargs.pop("primary_key", False),
default=default,
server_default=server_default,
nullable=is_field_nullable(nullable, default, server_default),
index=kwargs.pop("index", False),
unique=kwargs.pop("unique", False),
pydantic_only=kwargs.pop("pydantic_only", False),
autoincrement=kwargs.pop("autoincrement", False),
column_type=cls.get_column_type(**kwargs),
**kwargs
)
return type(cls.__name__, cls._bases, namespace)
namespace = dict(
__type__=decimal.Decimal,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
ge=minimum,
le=maximum,
multiple_of=multiple_of,
column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale),
precision=precision,
scale=scale,
max_digits=max_digits,
decimal_places=decimal_places,
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False,
)
return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any: # pragma no cover
return None
@classmethod
def validate(cls, **kwargs: Any) -> None: # pragma no cover
pass
class String(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str
def __new__(
cls,
*,
allow_blank: bool = False,
strip_whitespace: bool = False,
min_length: int = None,
max_length: int = None,
curtail_length: int = None,
regex: str = None,
**kwargs: Any
) -> Type[str]:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.String(length=kwargs.get("max_length"))
@classmethod
def validate(cls, **kwargs: Any) -> None:
max_length = kwargs.get("max_length", None)
if max_length is None or max_length <= 0:
raise ModelDefinitionError(
"Parameter max_length is required for field String"
)
class Integer(ModelFieldFactory):
_bases = (pydantic.ConstrainedInt, BaseField)
_type = int
def __new__(
cls,
*,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
**kwargs: Any
) -> Type[int]:
autoincrement = kwargs.pop("autoincrement", None)
autoincrement = (
autoincrement
if autoincrement is not None
else kwargs.get("primary_key", False)
)
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Integer()
class Text(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str
def __new__(
cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any
) -> Type[str]:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Text()
class Float(ModelFieldFactory):
_bases = (pydantic.ConstrainedFloat, BaseField)
_type = float
def __new__(
cls,
*,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
**kwargs: Any
) -> Type[int]:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Float()
class Boolean(ModelFieldFactory):
_bases = (int, BaseField)
_type = bool
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Boolean()
class DateTime(ModelFieldFactory):
_bases = (datetime.datetime, BaseField)
_type = datetime.datetime
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.DateTime()
class Date(ModelFieldFactory):
_bases = (datetime.date, BaseField)
_type = datetime.date
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Date()
class Time(ModelFieldFactory):
_bases = (datetime.time, BaseField)
_type = datetime.time
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Time()
class JSON(ModelFieldFactory):
_bases = (pydantic.Json, BaseField)
_type = pydantic.Json
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.JSON()
class BigInteger(Integer):
_bases = (pydantic.ConstrainedInt, BaseField)
_type = int
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.BigInteger()
class Decimal(ModelFieldFactory):
_bases = (pydantic.ConstrainedDecimal, BaseField)
_type = decimal.Decimal
def __new__(
cls,
*,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
precision: int = None,
scale: int = None,
max_digits: int = None,
decimal_places: int = None,
**kwargs: Any
) -> Type[decimal.Decimal]:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
precision = kwargs.get("precision")
scale = kwargs.get("scale")
return sqlalchemy.DECIMAL(precision=precision, scale=scale)
@classmethod
def validate(cls, **kwargs: Any) -> None:
precision = kwargs.get("precision")
scale = kwargs.get("scale")
if precision is None or precision < 0 or scale is None or scale < 0:
raise ModelDefinitionError(
"Parameters scale and precision are required for field Decimal"
)

View File

@ -240,10 +240,9 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
}
for field in self._extract_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
if getattr(self, field) is not None:
self_fields[field] = getattr(
getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname
)
self_fields[field] = getattr(getattr(self, field), target_pk_name)
return self_fields
@classmethod
@ -259,17 +258,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
@classmethod
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.Meta.model_fields.keys():
if isinstance(getattr(one, field), list) and not isinstance(
getattr(one, field), ormar.Model
current_field = getattr(one, field)
if isinstance(current_field, list) and not isinstance(
current_field, ormar.Model
):
setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), ormar.Model):
if getattr(one, field).pk == getattr(other, field).pk:
setattr(
other,
field,
cls.merge_two_instances(
getattr(one, field), getattr(other, field)
),
)
setattr(other, field, current_field + getattr(other, field))
elif (
isinstance(current_field, ormar.Model)
and current_field.pk == getattr(other, field).pk
):
setattr(
other,
field,
cls.merge_two_instances(current_field, getattr(other, field)),
)
return other

View File

@ -76,6 +76,7 @@ def event_loop():
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
department = await Department.objects.create(id=1, name="Math Department")
class1 = await SchoolClass.objects.create(name="Math", department=department)