From 704e83fed0e6ca75cb5b4c859b55648900b76550 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 17:18:05 +0200 Subject: [PATCH] refactor required field in model fields into decorator --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 57 +++++++++++++++++++++++++++----------------------- orm/models.py | 40 +++++++++++++++++------------------ 3 files changed, 51 insertions(+), 46 deletions(-) diff --git a/.coverage b/.coverage index 04d6ff895dc3fdaa51adcd2b47292b0028b57483..0c8e0091151cb63f7feb13a3d4d9a61ed1db424a 100644 GIT binary patch delta 116 zcmV-)0E_>CpaX!Q1F$MD2RAw~GdeLfvoSB#P#$*x5BU%358e;c52+7t4<8Q*4*(AG z4$2O&4vr3Yvk?$e4x@OE0Rg>}d5JPKik1ar3JTQp> delta 115 zcmV-(0F3{DpaX!Q1F$MD2Q@k}HaajcvoSB#P#$;y5BU%358e;c52_Dv4 None: + self._required = list(args) + + def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: + old_init = model_field_class.__init__ + model_field_class._old_init = old_init + + def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None: + super(instance.__class__, instance).__init__(*args, **kwargs) + for arg in self._required: + if arg not in kwargs: + raise ModelDefinitionError( + f"{instance.__class__.__name__} field requires parameter: {arg}" + ) + setattr(instance, arg, kwargs.pop(arg)) + + model_field_class.__init__ = __init__ + return model_field_class + + class BaseField: __type__ = None @@ -51,7 +71,7 @@ class BaseField: @property def is_required(self) -> bool: return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key + not self.nullable and not self.has_default and not self.is_auto_primary_key ) @property @@ -95,17 +115,10 @@ class BaseField: return value +@RequiredParams("length") class String(BaseField): __type__ = str - def __init__(self, *args: Any, **kwargs: Any) -> None: - if "length" not in kwargs: - raise ModelDefinitionError( - "Param length is required for String model field." - ) - self.length = kwargs.pop("length") - super().__init__(*args, **kwargs) - def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.String(self.length) @@ -173,18 +186,10 @@ class BigInteger(BaseField): return sqlalchemy.BigInteger() +@RequiredParams("length", "precision") class Decimal(BaseField): __type__ = decimal.Decimal - def __init__(self, *args: Any, **kwargs: Any) -> None: - if "length" not in kwargs or "precision" not in kwargs: - raise ModelDefinitionError( - "Params length and precision are required for Decimal model field." - ) - self.length = kwargs.pop("length") - self.precision = kwargs.pop("precision") - super().__init__(*args, **kwargs) - def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.DECIMAL(self.length, self.precision) @@ -204,12 +209,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -229,7 +234,7 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Union["Model", List["Model"]]: if isinstance(value, orm.models.Model) and not isinstance(value, self.to): diff --git a/orm/models.py b/orm/models.py index f7a6e6a..94047d2 100644 --- a/orm/models.py +++ b/orm/models.py @@ -6,7 +6,6 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from typing import Callable, Dict, Set import databases -from pydantic.fields import ModelField import orm.queryset as qry from orm.exceptions import ModelDefinitionError @@ -15,6 +14,7 @@ from orm.relations import RelationshipManager import pydantic from pydantic import BaseConfig, BaseModel, create_model +from pydantic.fields import ModelField import sqlalchemy @@ -42,21 +42,21 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> ) -def expand_reverse_relationships(model: Type["Model"]): - for field_name, model_field in model.__model_fields__.items(): +def expand_reverse_relationships(model: Type["Model"]) -> None: + for model_field in model.__model_fields__.values(): if isinstance(model_field, ForeignKey): - child_model_name = model_field.related_name or model.__name__.lower() + 's' + child_model_name = model_field.related_name or model.__name__.lower() + "s" parent_model = model_field.to child = model if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ): register_reverse_model_fields(parent_model, child, child_model_name) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: model.__fields__[child_model_name] = ModelField( name=child_model_name, @@ -70,7 +70,7 @@ def register_reverse_model_fields( def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] @@ -198,9 +198,9 @@ class FakePydantic(list, metaclass=ModelMetaclass): item = getattr(self.values, key, None) if ( - item is not None - and self._is_conversion_to_json_needed(key) - and isinstance(item, str) + item is not None + and self._is_conversion_to_json_needed(key) + and isinstance(item, str) ): try: item = json.loads(item) @@ -216,7 +216,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): if self.__class__ != other.__class__: # pragma no cover return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk + self.values is not None and other.values is not None and self.pk == other.pk ) def __repr__(self) -> str: # pragma no cover @@ -272,7 +272,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): related_names = set() for name, field in cls.__fields__.items(): if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel + field.type_, pydantic.BaseModel ): related_names.add(name) return related_names @@ -304,7 +304,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): for field in one.__model_fields__.keys(): # print(field, one.dict(), other.dict()) if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), Model + getattr(one, field), Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), Model): @@ -326,10 +326,10 @@ class Model(FakePydantic): @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, ) -> "Model": item = {} @@ -387,8 +387,8 @@ class Model(FakePydantic): self_fields.pop(self.__pkname__) expr = ( self.__table__.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) + .values(**self_fields) + .where(self.pk_column == getattr(self, self.__pkname__)) ) result = await self.__database__.execute(expr) return result