refactor required field in model fields into decorator

This commit is contained in:
collerek
2020-08-11 17:18:05 +02:00
parent 8e19a5b127
commit 704e83fed0
3 changed files with 51 additions and 46 deletions

BIN
.coverage

Binary file not shown.

View File

@ -6,7 +6,6 @@ import orm
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
from pydantic import BaseModel, Json from pydantic import BaseModel, Json
from pydantic.fields import ModelField
import sqlalchemy import sqlalchemy
@ -14,6 +13,27 @@ if TYPE_CHECKING: # pragma no cover
from orm.models import Model from orm.models import Model
class RequiredParams:
def __init__(self, *args: str) -> None:
self._required = list(args)
def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]:
old_init = model_field_class.__init__
model_field_class._old_init = old_init
def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None:
super(instance.__class__, instance).__init__(*args, **kwargs)
for arg in self._required:
if arg not in kwargs:
raise ModelDefinitionError(
f"{instance.__class__.__name__} field requires parameter: {arg}"
)
setattr(instance, arg, kwargs.pop(arg))
model_field_class.__init__ = __init__
return model_field_class
class BaseField: class BaseField:
__type__ = None __type__ = None
@ -95,17 +115,10 @@ class BaseField:
return value return value
@RequiredParams("length")
class String(BaseField): class String(BaseField):
__type__ = str __type__ = str
def __init__(self, *args: Any, **kwargs: Any) -> None:
if "length" not in kwargs:
raise ModelDefinitionError(
"Param length is required for String model field."
)
self.length = kwargs.pop("length")
super().__init__(*args, **kwargs)
def get_column_type(self) -> sqlalchemy.Column: def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.String(self.length) return sqlalchemy.String(self.length)
@ -173,18 +186,10 @@ class BigInteger(BaseField):
return sqlalchemy.BigInteger() return sqlalchemy.BigInteger()
@RequiredParams("length", "precision")
class Decimal(BaseField): class Decimal(BaseField):
__type__ = decimal.Decimal __type__ = decimal.Decimal
def __init__(self, *args: Any, **kwargs: Any) -> None:
if "length" not in kwargs or "precision" not in kwargs:
raise ModelDefinitionError(
"Params length and precision are required for Decimal model field."
)
self.length = kwargs.pop("length")
self.precision = kwargs.pop("precision")
super().__init__(*args, **kwargs)
def get_column_type(self) -> sqlalchemy.Column: def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.DECIMAL(self.length, self.precision) return sqlalchemy.DECIMAL(self.length, self.precision)

View File

@ -6,7 +6,6 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import Callable, Dict, Set from typing import Callable, Dict, Set
import databases import databases
from pydantic.fields import ModelField
import orm.queryset as qry import orm.queryset as qry
from orm.exceptions import ModelDefinitionError from orm.exceptions import ModelDefinitionError
@ -15,6 +14,7 @@ from orm.relations import RelationshipManager
import pydantic import pydantic
from pydantic import BaseConfig, BaseModel, create_model from pydantic import BaseConfig, BaseModel, create_model
from pydantic.fields import ModelField
import sqlalchemy import sqlalchemy
@ -42,10 +42,10 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) ->
) )
def expand_reverse_relationships(model: Type["Model"]): def expand_reverse_relationships(model: Type["Model"]) -> None:
for field_name, model_field in model.__model_fields__.items(): for model_field in model.__model_fields__.values():
if isinstance(model_field, ForeignKey): if isinstance(model_field, ForeignKey):
child_model_name = model_field.related_name or model.__name__.lower() + 's' child_model_name = model_field.related_name or model.__name__.lower() + "s"
parent_model = model_field.to parent_model = model_field.to
child = model child = model
if ( if (