refactor fields into classes

Author: collerek
Date:   2020-08-23 16:14:04 +02:00
parent 806fe9b63e
commit 348a3d90dc
5 changed files with 262 additions and 369 deletions

BIN  .coverage
Binary file not shown.


@ -23,7 +23,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
 def ForeignKey(
-    to: "Model",
+    to: Type["Model"],
     *,
     name: str = None,
     unique: bool = False,
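
Note: the change above is annotation-only. ForeignKey already received the related model class (not an instance), and Type["Model"] now states that. A rough declaration sketch, reusing the Department/SchoolClass models from the test fixture later in this commit; the column set and the exact declaration/Meta syntax at this ormar revision are assumptions:

import databases
import ormar
import sqlalchemy

database = databases.Database("sqlite:///test.db")  # placeholder URL
metadata = sqlalchemy.MetaData()


class Department(ormar.Model):
    class Meta:
        database = database
        metadata = metadata

    id = ormar.Integer(primary_key=True)
    name = ormar.String(max_length=100)


class SchoolClass(ormar.Model):
    class Meta:
        database = database
        metadata = metadata

    id = ormar.Integer(primary_key=True)
    name = ormar.String(max_length=100)
    # the argument is the model class itself, matching the new Type["Model"] hint
    department = ormar.ForeignKey(Department)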


@ -1,11 +1,9 @@
 import datetime
 import decimal
-import re
 from typing import Any, Optional, Type

 import pydantic
 import sqlalchemy
-from pydantic import Json

 from ormar import ModelDefinitionError  # noqa I101
 from ormar.fields.base import BaseField  # noqa I101
@ -19,325 +17,223 @@ def is_field_nullable(
     return nullable


-def String(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    allow_blank: bool = False,
-    strip_whitespace: bool = False,
-    min_length: int = None,
-    max_length: int = None,
-    curtail_length: int = None,
-    regex: str = None,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[str]:
-    if max_length is None or max_length <= 0:
-        raise ModelDefinitionError("Parameter max_length is required for field String")
-
-    namespace = dict(
-        __type__=str,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        allow_blank=allow_blank,
-        strip_whitespace=strip_whitespace,
-        min_length=min_length,
-        max_length=max_length,
-        curtail_length=curtail_length,
-        regex=regex and re.compile(regex),
-        column_type=sqlalchemy.String(length=max_length),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=False,
-    )
-    return type("String", (pydantic.ConstrainedStr, BaseField), namespace)
+class ModelFieldFactory:
+    _bases = None
+    _type = None
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
+        cls.validate(**kwargs)
+
+        default = kwargs.pop("default", None)
+        server_default = kwargs.pop("server_default", None)
+        nullable = kwargs.pop("nullable", None)
+
+        namespace = dict(
+            __type__=cls._type,
+            name=kwargs.pop("name", None),
+            primary_key=kwargs.pop("primary_key", False),
+            default=default,
+            server_default=server_default,
+            nullable=is_field_nullable(nullable, default, server_default),
+            index=kwargs.pop("index", False),
+            unique=kwargs.pop("unique", False),
+            pydantic_only=kwargs.pop("pydantic_only", False),
+            autoincrement=kwargs.pop("autoincrement", False),
+            column_type=cls.get_column_type(**kwargs),
+            **kwargs
+        )
+
+        return type(cls.__name__, cls._bases, namespace)
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:  # pragma no cover
+        return None
+
+    @classmethod
+    def validate(cls, **kwargs: Any) -> None:  # pragma no cover
+        pass
+
+
+class String(ModelFieldFactory):
+    _bases = (pydantic.ConstrainedStr, BaseField)
+    _type = str
+
+    def __new__(
+        cls,
+        *,
+        allow_blank: bool = False,
+        strip_whitespace: bool = False,
+        min_length: int = None,
+        max_length: int = None,
+        curtail_length: int = None,
+        regex: str = None,
+        **kwargs: Any
+    ) -> Type[str]:
+        kwargs = {
+            **kwargs,
+            **{
+                k: v
+                for k, v in locals().items()
+                if k not in ["cls", "__class__", "kwargs"]
+            },
+        }
+        return super().__new__(cls, **kwargs)
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.String(length=kwargs.get("max_length"))
+
+    @classmethod
+    def validate(cls, **kwargs: Any) -> None:
+        max_length = kwargs.get("max_length", None)
+        if max_length is None or max_length <= 0:
+            raise ModelDefinitionError(
+                "Parameter max_length is required for field String"
+            )
-
-
-def Integer(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    autoincrement: bool = None,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    minimum: int = None,
-    maximum: int = None,
-    multiple_of: int = None,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[int]:
-    namespace = dict(
-        __type__=int,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        ge=minimum,
-        le=maximum,
-        multiple_of=multiple_of,
-        column_type=sqlalchemy.Integer(),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=autoincrement if autoincrement is not None else primary_key,
-    )
-    return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace)
+
+
+class Integer(ModelFieldFactory):
+    _bases = (pydantic.ConstrainedInt, BaseField)
+    _type = int
+
+    def __new__(
+        cls,
+        *,
+        minimum: int = None,
+        maximum: int = None,
+        multiple_of: int = None,
+        **kwargs: Any
+    ) -> Type[int]:
+        autoincrement = kwargs.pop("autoincrement", None)
+        autoincrement = (
+            autoincrement
+            if autoincrement is not None
+            else kwargs.get("primary_key", False)
+        )
+        kwargs = {
+            **kwargs,
+            **{
+                k: v
+                for k, v in locals().items()
+                if k not in ["cls", "__class__", "kwargs"]
+            },
+        }
+        return super().__new__(cls, **kwargs)
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.Integer()
-
-
-def Text(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    allow_blank: bool = False,
-    strip_whitespace: bool = False,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[str]:
-    namespace = dict(
-        __type__=str,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        allow_blank=allow_blank,
-        strip_whitespace=strip_whitespace,
-        column_type=sqlalchemy.Text(),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=False,
-    )
-    return type("Text", (pydantic.ConstrainedStr, BaseField), namespace)
+
+
+class Text(ModelFieldFactory):
+    _bases = (pydantic.ConstrainedStr, BaseField)
+    _type = str
+
+    def __new__(
+        cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any
+    ) -> Type[str]:
+        kwargs = {
+            **kwargs,
+            **{
+                k: v
+                for k, v in locals().items()
+                if k not in ["cls", "__class__", "kwargs"]
+            },
+        }
+        return super().__new__(cls, **kwargs)
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.Text()
-
-
-def Float(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    minimum: float = None,
-    maximum: float = None,
-    multiple_of: int = None,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[int]:
-    namespace = dict(
-        __type__=float,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        ge=minimum,
-        le=maximum,
-        multiple_of=multiple_of,
-        column_type=sqlalchemy.Float(),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=False,
-    )
-    return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace)
+
+
+class Float(ModelFieldFactory):
+    _bases = (pydantic.ConstrainedFloat, BaseField)
+    _type = float
+
+    def __new__(
+        cls,
+        *,
+        minimum: float = None,
+        maximum: float = None,
+        multiple_of: int = None,
+        **kwargs: Any
+    ) -> Type[int]:
+        kwargs = {
+            **kwargs,
+            **{
+                k: v
+                for k, v in locals().items()
+                if k not in ["cls", "__class__", "kwargs"]
+            },
+        }
+        return super().__new__(cls, **kwargs)
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.Float()
-
-
-def Boolean(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[bool]:
-    namespace = dict(
-        __type__=bool,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        column_type=sqlalchemy.Boolean(),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=False,
-    )
-    return type("Boolean", (int, BaseField), namespace)
+
+
+class Boolean(ModelFieldFactory):
+    _bases = (int, BaseField)
+    _type = bool
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.Boolean()
-
-
-def DateTime(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[datetime.datetime]:
-    namespace = dict(
-        __type__=datetime.datetime,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        column_type=sqlalchemy.DateTime(),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=False,
-    )
-    return type("DateTime", (datetime.datetime, BaseField), namespace)
+
+
+class DateTime(ModelFieldFactory):
+    _bases = (datetime.datetime, BaseField)
+    _type = datetime.datetime
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.DateTime()
-
-
-def Date(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[datetime.date]:
-    namespace = dict(
-        __type__=datetime.date,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        column_type=sqlalchemy.Date(),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=False,
-    )
-    return type("Date", (datetime.date, BaseField), namespace)
+
+
+class Date(ModelFieldFactory):
+    _bases = (datetime.date, BaseField)
+    _type = datetime.date
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.Date()
-
-
-def Time(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[datetime.time]:
-    namespace = dict(
-        __type__=datetime.time,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        column_type=sqlalchemy.Time(),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=False,
-    )
-    return type("Time", (datetime.time, BaseField), namespace)
+
+
+class Time(ModelFieldFactory):
+    _bases = (datetime.time, BaseField)
+    _type = datetime.time
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.Time()
-
-
-def JSON(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[Json]:
-    namespace = dict(
-        __type__=pydantic.Json,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        column_type=sqlalchemy.JSON(),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=False,
-    )
-    return type("JSON", (pydantic.Json, BaseField), namespace)
+
+
+class JSON(ModelFieldFactory):
+    _bases = (pydantic.Json, BaseField)
+    _type = pydantic.Json
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.JSON()
-
-
-def BigInteger(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    autoincrement: bool = None,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    minimum: int = None,
-    maximum: int = None,
-    multiple_of: int = None,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[int]:
-    namespace = dict(
-        __type__=int,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        ge=minimum,
-        le=maximum,
-        multiple_of=multiple_of,
-        column_type=sqlalchemy.BigInteger(),
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=autoincrement if autoincrement is not None else primary_key,
-    )
-    return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace)
+
+
+class BigInteger(Integer):
+    _bases = (pydantic.ConstrainedInt, BaseField)
+    _type = int
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        return sqlalchemy.BigInteger()
-
-
-def Decimal(
-    *,
-    name: str = None,
-    primary_key: bool = False,
-    nullable: bool = None,
-    index: bool = False,
-    unique: bool = False,
-    minimum: float = None,
-    maximum: float = None,
-    multiple_of: int = None,
+
+
+class Decimal(ModelFieldFactory):
+    _bases = (pydantic.ConstrainedDecimal, BaseField)
+    _type = decimal.Decimal
+
+    def __new__(
+        cls,
+        *,
+        minimum: float = None,
+        maximum: float = None,
+        multiple_of: int = None,
@ -345,33 +241,29 @@ def Decimal(
-    scale: int = None,
-    max_digits: int = None,
-    decimal_places: int = None,
-    pydantic_only: bool = False,
-    default: Any = None,
-    server_default: Any = None,
-) -> Type[decimal.Decimal]:
-    if precision is None or precision < 0 or scale is None or scale < 0:
-        raise ModelDefinitionError(
-            "Parameters scale and precision are required for field Decimal"
-        )
-    namespace = dict(
-        __type__=decimal.Decimal,
-        name=name,
-        primary_key=primary_key,
-        nullable=is_field_nullable(nullable, default, server_default),
-        index=index,
-        unique=unique,
-        ge=minimum,
-        le=maximum,
-        multiple_of=multiple_of,
-        column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale),
-        precision=precision,
-        scale=scale,
-        max_digits=max_digits,
-        decimal_places=decimal_places,
-        pydantic_only=pydantic_only,
-        default=default,
-        server_default=server_default,
-        autoincrement=False,
-    )
-    return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace)
+        scale: int = None,
+        max_digits: int = None,
+        decimal_places: int = None,
+        **kwargs: Any
+    ) -> Type[decimal.Decimal]:
+        kwargs = {
+            **kwargs,
+            **{
+                k: v
+                for k, v in locals().items()
+                if k not in ["cls", "__class__", "kwargs"]
+            },
+        }
+        return super().__new__(cls, **kwargs)
+
+    @classmethod
+    def get_column_type(cls, **kwargs: Any) -> Any:
+        precision = kwargs.get("precision")
+        scale = kwargs.get("scale")
+        return sqlalchemy.DECIMAL(precision=precision, scale=scale)
+
+    @classmethod
+    def validate(cls, **kwargs: Any) -> None:
+        precision = kwargs.get("precision")
+        scale = kwargs.get("scale")
+        if precision is None or precision < 0 or scale is None or scale < 0:
+            raise ModelDefinitionError(
+                "Parameters scale and precision are required for field Decimal"
+            )


@ -240,10 +240,9 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
             k: v for k, v in self_fields.items() if k in self.Meta.table.columns
         }
         for field in self._extract_related_names():
+            target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
             if getattr(self, field) is not None:
-                self_fields[field] = getattr(
-                    getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname
-                )
+                self_fields[field] = getattr(getattr(self, field), target_pk_name)
         return self_fields

     @classmethod
@ -259,17 +258,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
     @classmethod
     def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
         for field in one.Meta.model_fields.keys():
-            if isinstance(getattr(one, field), list) and not isinstance(
-                getattr(one, field), ormar.Model
+            current_field = getattr(one, field)
+            if isinstance(current_field, list) and not isinstance(
+                current_field, ormar.Model
             ):
-                setattr(other, field, getattr(one, field) + getattr(other, field))
-            elif isinstance(getattr(one, field), ormar.Model):
-                if getattr(one, field).pk == getattr(other, field).pk:
-                    setattr(
-                        other,
-                        field,
-                        cls.merge_two_instances(
-                            getattr(one, field), getattr(other, field)
-                        ),
-                    )
+                setattr(other, field, current_field + getattr(other, field))
+            elif (
+                isinstance(current_field, ormar.Model)
+                and current_field.pk == getattr(other, field).pk
+            ):
+                setattr(
+                    other,
+                    field,
+                    cls.merge_two_instances(current_field, getattr(other, field)),
+                )
         return other
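
Note: both hunks above concern related models on the instance side. The first collapses a populated relation to the target's primary-key value when extracting database fields; the second folds two query results describing the same row into one. A rough usage sketch of the merge, reusing the hypothetical Department/SchoolClass declarations from the ForeignKey note earlier; constructor behaviour for relation fields at this revision is an assumption, so treat it as a sketch:

# Two instances of the same row, e.g. assembled from two joined result rows.
one = SchoolClass(id=1, name="Math", department=Department(id=1, name="Math Department"))
other = SchoolClass(id=1, name="Math", department=Department(id=1, name="Math Department"))

merged = SchoolClass.merge_two_instances(one, other)

# merge_two_instances mutates and returns `other`:
# - list relations are concatenated (one's list + other's list),
# - a nested model whose pk matches is merged recursively instead of replaced.
assert merged is other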


@ -76,6 +76,7 @@ def event_loop():
 @pytest.fixture(autouse=True, scope="module")
 async def create_test_database():
     engine = sqlalchemy.create_engine(DATABASE_URL)
+    metadata.drop_all(engine)
     metadata.create_all(engine)
     department = await Department.objects.create(id=1, name="Math Department")
     class1 = await SchoolClass.objects.create(name="Math", department=department)
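
Note: the added metadata.drop_all(engine) makes the module-scoped fixture start from a clean schema, so leftovers from an earlier interrupted run (including the fixed id=1 department created right below) cannot collide with the seeded rows. The general shape of such a fixture is roughly the following; the names are placeholders and the yield/teardown part is an assumption for completeness, not something shown in the hunk:

import pytest
import sqlalchemy

DATABASE_URL = "sqlite:///test.db"  # placeholder; the test module defines its own
metadata = sqlalchemy.MetaData()    # placeholder; shared with the models under test


@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
    engine = sqlalchemy.create_engine(DATABASE_URL)
    metadata.drop_all(engine)   # drop leftovers from any earlier, interrupted run
    metadata.create_all(engine)
    # ... seed the module's fixture rows here ...
    yield
    metadata.drop_all(engine)   # leave a clean database behind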