some cleanup and refactoring

This commit is contained in:
collerek
2020-12-06 08:23:57 +01:00
parent 1d4a074c2c
commit 9838547c4f
2 changed files with 21 additions and 18 deletions

View File

@ -47,11 +47,7 @@ if TYPE_CHECKING: # pragma no cover
class NewBaseModel( class NewBaseModel(
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
): ):
__slots__ = ( __slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column")
"_orm_id",
"_orm_saved",
"_orm",
)
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, Type[BaseField]] __model_fields__: Dict[str, Type[BaseField]]
@ -75,6 +71,7 @@ class NewBaseModel(
def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
object.__setattr__(self, "_orm_saved", False) object.__setattr__(self, "_orm_saved", False)
object.__setattr__(self, "_pk_column", None)
object.__setattr__( object.__setattr__(
self, self,
"_orm", "_orm",
@ -94,13 +91,8 @@ class NewBaseModel(
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs[self.Meta.pkname] = kwargs.pop("pk")
# remove property fields values from validation
kwargs = {
k: v
for k, v in kwargs.items()
if k not in object.__getattribute__(self, "Meta").property_fields
}
# build the models to set them and validate but don't register # build the models to set them and validate but don't register
# also remove property fields values from validation
try: try:
new_kwargs: Dict[str, Any] = { new_kwargs: Dict[str, Any] = {
k: self._convert_json( k: self._convert_json(
@ -111,14 +103,15 @@ class NewBaseModel(
"dumps", "dumps",
) )
for k, v in kwargs.items() for k, v in kwargs.items()
if k not in object.__getattribute__(self, "Meta").property_fields
} }
except KeyError as e: except KeyError as e:
raise ModelError( raise ModelError(
f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}" f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}"
) )
# explicitly set None to excluded fields with default # explicitly set None to excluded fields
# as pydantic populates them with default # as pydantic populates them with default if set
for field_to_nullify in excluded: for field_to_nullify in excluded:
new_kwargs[field_to_nullify] = None new_kwargs[field_to_nullify] = None
@ -195,7 +188,8 @@ class NewBaseModel(
return ( return (
self._orm_id == other._orm_id self._orm_id == other._orm_id
or (self.pk == other.pk and self.pk is not None) or (self.pk == other.pk and self.pk is not None)
or self.dict() == other.dict() or self.dict(exclude=self.extract_related_names())
== other.dict(exclude=other.extract_related_names())
) )
@classmethod @classmethod
@ -207,7 +201,12 @@ class NewBaseModel(
@property @property
def pk_column(self) -> sqlalchemy.Column: def pk_column(self) -> sqlalchemy.Column:
return self.Meta.table.primary_key.columns.values()[0] if object.__getattribute__(self, "_pk_column") is not None:
return object.__getattribute__(self, "_pk_column")
pk_columns = self.Meta.table.primary_key.columns.values()
pk_col = pk_columns[0]
object.__setattr__(self, "_pk_column", pk_col)
return pk_col
@property @property
def saved(self) -> bool: def saved(self) -> bool:

View File

@ -219,7 +219,7 @@ def test_excluding_fields_in_endpoints():
assert isinstance(user_instance.timestamp, datetime.datetime) assert isinstance(user_instance.timestamp, datetime.datetime)
assert user_instance.timestamp == timestamp assert user_instance.timestamp == timestamp
response = client.post("/users4/", json=user) response = client.post("/users4/", json=user3)
assert list(response.json().keys()) == [ assert list(response.json().keys()) == [
"id", "id",
"email", "email",
@ -228,8 +228,12 @@ def test_excluding_fields_in_endpoints():
"category", "category",
"timestamp", "timestamp",
] ]
assert response.json().get("timestamp") != str(timestamp).replace(" ", "T") assert (
assert response.json().get("timestamp") is not None datetime.datetime.strptime(
response.json().get("timestamp"), "%Y-%m-%dT%H:%M:%S.%f"
)
== timestamp
)
def test_adding_fields_in_endpoints(): def test_adding_fields_in_endpoints():