refactor required field in model fields into decorator

This commit is contained in:
collerek
2020-08-11 17:18:05 +02:00
parent 8e19a5b127
commit 704e83fed0
3 changed files with 51 additions and 46 deletions

BIN
.coverage

Binary file not shown.

View File

@ -6,7 +6,6 @@ import orm
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
from pydantic import BaseModel, Json from pydantic import BaseModel, Json
from pydantic.fields import ModelField
import sqlalchemy import sqlalchemy
@ -14,6 +13,27 @@ if TYPE_CHECKING: # pragma no cover
from orm.models import Model from orm.models import Model
class RequiredParams:
def __init__(self, *args: str) -> None:
self._required = list(args)
def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]:
old_init = model_field_class.__init__
model_field_class._old_init = old_init
def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None:
super(instance.__class__, instance).__init__(*args, **kwargs)
for arg in self._required:
if arg not in kwargs:
raise ModelDefinitionError(
f"{instance.__class__.__name__} field requires parameter: {arg}"
)
setattr(instance, arg, kwargs.pop(arg))
model_field_class.__init__ = __init__
return model_field_class
class BaseField: class BaseField:
__type__ = None __type__ = None
@ -51,7 +71,7 @@ class BaseField:
@property @property
def is_required(self) -> bool: def is_required(self) -> bool:
return ( return (
not self.nullable and not self.has_default and not self.is_auto_primary_key not self.nullable and not self.has_default and not self.is_auto_primary_key
) )
@property @property
@ -95,17 +115,10 @@ class BaseField:
return value return value
@RequiredParams("length")
class String(BaseField): class String(BaseField):
__type__ = str __type__ = str
def __init__(self, *args: Any, **kwargs: Any) -> None:
if "length" not in kwargs:
raise ModelDefinitionError(
"Param length is required for String model field."
)
self.length = kwargs.pop("length")
super().__init__(*args, **kwargs)
def get_column_type(self) -> sqlalchemy.Column: def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.String(self.length) return sqlalchemy.String(self.length)
@ -173,18 +186,10 @@ class BigInteger(BaseField):
return sqlalchemy.BigInteger() return sqlalchemy.BigInteger()
@RequiredParams("length", "precision")
class Decimal(BaseField): class Decimal(BaseField):
__type__ = decimal.Decimal __type__ = decimal.Decimal
def __init__(self, *args: Any, **kwargs: Any) -> None:
if "length" not in kwargs or "precision" not in kwargs:
raise ModelDefinitionError(
"Params length and precision are required for Decimal model field."
)
self.length = kwargs.pop("length")
self.precision = kwargs.pop("precision")
super().__init__(*args, **kwargs)
def get_column_type(self) -> sqlalchemy.Column: def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.DECIMAL(self.length, self.precision) return sqlalchemy.DECIMAL(self.length, self.precision)
@ -204,12 +209,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model":
class ForeignKey(BaseField): class ForeignKey(BaseField):
def __init__( def __init__(
self, self,
to: Type["Model"], to: Type["Model"],
name: str = None, name: str = None,
related_name: str = None, related_name: str = None,
nullable: bool = True, nullable: bool = True,
virtual: bool = False, virtual: bool = False,
) -> None: ) -> None:
super().__init__(nullable=nullable, name=name) super().__init__(nullable=nullable, name=name)
self.virtual = virtual self.virtual = virtual
@ -229,7 +234,7 @@ class ForeignKey(BaseField):
return to_column.get_column_type() return to_column.get_column_type()
def expand_relationship( def expand_relationship(
self, value: Any, child: "Model" self, value: Any, child: "Model"
) -> Union["Model", List["Model"]]: ) -> Union["Model", List["Model"]]:
if isinstance(value, orm.models.Model) and not isinstance(value, self.to): if isinstance(value, orm.models.Model) and not isinstance(value, self.to):

View File

@ -6,7 +6,6 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import Callable, Dict, Set from typing import Callable, Dict, Set
import databases import databases
from pydantic.fields import ModelField
import orm.queryset as qry import orm.queryset as qry
from orm.exceptions import ModelDefinitionError from orm.exceptions import ModelDefinitionError
@ -15,6 +14,7 @@ from orm.relations import RelationshipManager
import pydantic import pydantic
from pydantic import BaseConfig, BaseModel, create_model from pydantic import BaseConfig, BaseModel, create_model
from pydantic.fields import ModelField
import sqlalchemy import sqlalchemy
@ -42,21 +42,21 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) ->
) )
def expand_reverse_relationships(model: Type["Model"]): def expand_reverse_relationships(model: Type["Model"]) -> None:
for field_name, model_field in model.__model_fields__.items(): for model_field in model.__model_fields__.values():
if isinstance(model_field, ForeignKey): if isinstance(model_field, ForeignKey):
child_model_name = model_field.related_name or model.__name__.lower() + 's' child_model_name = model_field.related_name or model.__name__.lower() + "s"
parent_model = model_field.to parent_model = model_field.to
child = model child = model
if ( if (
child_model_name not in parent_model.__fields__ child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__ and child.get_name() not in parent_model.__fields__
): ):
register_reverse_model_fields(parent_model, child, child_model_name) register_reverse_model_fields(parent_model, child, child_model_name)
def register_reverse_model_fields( def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str model: Type["Model"], child: Type["Model"], child_model_name: str
) -> None: ) -> None:
model.__fields__[child_model_name] = ModelField( model.__fields__[child_model_name] = ModelField(
name=child_model_name, name=child_model_name,
@ -70,7 +70,7 @@ def register_reverse_model_fields(
def sqlalchemy_columns_from_model_fields( def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, table_name: str name: str, object_dict: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
pkname: Optional[str] = None pkname: Optional[str] = None
columns: List[sqlalchemy.Column] = [] columns: List[sqlalchemy.Column] = []
@ -198,9 +198,9 @@ class FakePydantic(list, metaclass=ModelMetaclass):
item = getattr(self.values, key, None) item = getattr(self.values, key, None)
if ( if (
item is not None item is not None
and self._is_conversion_to_json_needed(key) and self._is_conversion_to_json_needed(key)
and isinstance(item, str) and isinstance(item, str)
): ):
try: try:
item = json.loads(item) item = json.loads(item)
@ -216,7 +216,7 @@ class FakePydantic(list, metaclass=ModelMetaclass):
if self.__class__ != other.__class__: # pragma no cover if self.__class__ != other.__class__: # pragma no cover
return False return False
return self._orm_id == other._orm_id or ( return self._orm_id == other._orm_id or (
self.values is not None and other.values is not None and self.pk == other.pk self.values is not None and other.values is not None and self.pk == other.pk
) )
def __repr__(self) -> str: # pragma no cover def __repr__(self) -> str: # pragma no cover
@ -272,7 +272,7 @@ class FakePydantic(list, metaclass=ModelMetaclass):
related_names = set() related_names = set()
for name, field in cls.__fields__.items(): for name, field in cls.__fields__.items():
if inspect.isclass(field.type_) and issubclass( if inspect.isclass(field.type_) and issubclass(
field.type_, pydantic.BaseModel field.type_, pydantic.BaseModel
): ):
related_names.add(name) related_names.add(name)
return related_names return related_names
@ -304,7 +304,7 @@ class FakePydantic(list, metaclass=ModelMetaclass):
for field in one.__model_fields__.keys(): for field in one.__model_fields__.keys():
# print(field, one.dict(), other.dict()) # print(field, one.dict(), other.dict())
if isinstance(getattr(one, field), list) and not isinstance( if isinstance(getattr(one, field), list) and not isinstance(
getattr(one, field), Model getattr(one, field), Model
): ):
setattr(other, field, getattr(one, field) + getattr(other, field)) setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), Model): elif isinstance(getattr(one, field), Model):
@ -326,10 +326,10 @@ class Model(FakePydantic):
@classmethod @classmethod
def from_row( def from_row(
cls, cls,
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
select_related: List = None, select_related: List = None,
previous_table: str = None, previous_table: str = None,
) -> "Model": ) -> "Model":
item = {} item = {}
@ -387,8 +387,8 @@ class Model(FakePydantic):
self_fields.pop(self.__pkname__) self_fields.pop(self.__pkname__)
expr = ( expr = (
self.__table__.update() self.__table__.update()
.values(**self_fields) .values(**self_fields)
.where(self.pk_column == getattr(self, self.__pkname__)) .where(self.pk_column == getattr(self, self.__pkname__))
) )
result = await self.__database__.execute(expr) result = await self.__database__.execute(expr)
return result return result