add save status and tests

This commit is contained in:
collerek
2020-11-13 16:21:12 +01:00
parent b0cf916531
commit 1f67da3a5c
3 changed files with 55 additions and 51 deletions

View File

@ -42,13 +42,13 @@ class Model(NewBaseModel):
@classmethod @classmethod
def from_row( # noqa CCR001 def from_row( # noqa CCR001
cls: Type[T], cls: Type[T],
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
previous_table: str = None, previous_table: str = None,
fields: Optional[Union[Dict, Set]] = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None,
) -> Optional[T]: ) -> Optional[T]:
item: Dict[str, Any] = {} item: Dict[str, Any] = {}
@ -58,9 +58,9 @@ class Model(NewBaseModel):
related_models = group_related_list(select_related) related_models = group_related_list(select_related)
if ( if (
previous_table previous_table
and previous_table in cls.Meta.model_fields and previous_table in cls.Meta.model_fields
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
): ):
previous_table = cls.Meta.model_fields[ previous_table = cls.Meta.model_fields[
previous_table previous_table
@ -90,8 +90,9 @@ class Model(NewBaseModel):
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
) )
instance: Optional[T] = None
if item.get(cls.Meta.pkname, None) is not None: if item.get(cls.Meta.pkname, None) is not None:
instance: Optional[T] = cls(**item) instance = cls(**item)
instance.set_save_status(True) instance.set_save_status(True)
else: else:
instance = None instance = None
@ -100,13 +101,13 @@ class Model(NewBaseModel):
@classmethod @classmethod
def populate_nested_models_from_row( # noqa: CFQ002 def populate_nested_models_from_row( # noqa: CFQ002
cls, cls,
item: dict, item: dict,
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
related_models: Any, related_models: Any,
previous_table: sqlalchemy.Table, previous_table: sqlalchemy.Table,
fields: Optional[Union[Dict, Set]] = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None,
) -> dict: ) -> dict:
for related in related_models: for related in related_models:
if isinstance(related_models, dict) and related_models[related]: if isinstance(related_models, dict) and related_models[related]:
@ -140,12 +141,12 @@ class Model(NewBaseModel):
@classmethod @classmethod
def extract_prefixed_table_columns( # noqa CCR001 def extract_prefixed_table_columns( # noqa CCR001
cls, cls,
item: dict, item: dict,
row: sqlalchemy.engine.result.ResultProxy, row: sqlalchemy.engine.result.ResultProxy,
table_prefix: str, table_prefix: str,
fields: Optional[Union[Dict, Set]] = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None,
) -> dict: ) -> dict:
# databases does not keep aliases in Record for postgres, change to raw row # databases does not keep aliases in Record for postgres, change to raw row

View File

@ -144,12 +144,12 @@ class NewBaseModel(
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ( if item in (
"_orm_id", "_orm_id",
"_orm_saved", "_orm_saved",
"_orm", "_orm",
"__fields__", "__fields__",
"_related_names", "_related_names",
"_props", "_props",
): ):
return object.__getattribute__(self, item) return object.__getattribute__(self, item)
if item == "pk": if item == "pk":
@ -163,7 +163,7 @@ class NewBaseModel(
return super().__getattribute__(item) return super().__getattribute__(item)
def _extract_related_model_instead_of_field( def _extract_related_model_instead_of_field(
self, item: str self, item: str
) -> Optional[Union["T", Sequence["T"]]]: ) -> Optional[Union["T", Sequence["T"]]]:
# alias = self.get_column_alias(item) # alias = self.get_column_alias(item)
if item in self._orm: if item in self._orm:
@ -177,9 +177,9 @@ class NewBaseModel(
def __same__(self, other: "NewBaseModel") -> bool: def __same__(self, other: "NewBaseModel") -> bool:
return ( return (
self._orm_id == other._orm_id self._orm_id == other._orm_id
or self.dict() == other.dict() or self.dict() == other.dict()
or (self.pk == other.pk and self.pk is not None) or (self.pk == other.pk and self.pk is not None)
) )
@classmethod @classmethod
@ -209,9 +209,9 @@ class NewBaseModel(
@classmethod @classmethod
def get_properties( def get_properties(
cls, cls,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
) -> List[str]: ) -> List[str]:
if isinstance(cls._props, list): if isinstance(cls._props, list):
props = cls._props props = cls._props
@ -220,7 +220,7 @@ class NewBaseModel(
prop prop
for prop in dir(cls) for prop in dir(cls)
if isinstance(getattr(cls, prop), property) if isinstance(getattr(cls, prop), property)
and prop not in ("__values__", "__fields__", "fields", "pk_column") and prop not in ("__values__", "__fields__", "fields", "pk_column")
] ]
cls._props = props cls._props = props
if include: if include:
@ -230,16 +230,16 @@ class NewBaseModel(
return props return props
def dict( # noqa A003 def dict( # noqa A003
self, self,
*, *,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False, by_alias: bool = False,
skip_defaults: bool = None, skip_defaults: bool = None,
exclude_unset: bool = False, exclude_unset: bool = False,
exclude_defaults: bool = False, exclude_defaults: bool = False,
exclude_none: bool = False, exclude_none: bool = False,
nested: bool = False nested: bool = False
) -> "DictStrAny": # noqa: A003' ) -> "DictStrAny": # noqa: A003'
dict_instance = super().dict( dict_instance = super().dict(
include=include, include=include,

View File

@ -358,8 +358,11 @@ class QuerySet:
instance.pk = pk instance.pk = pk
# refresh server side defaults # refresh server side defaults
if any(field.server_default is not None if any(
for name, field in self.model.Meta.model_fields.items() if name not in kwargs): field.server_default is not None
for name, field in self.model.Meta.model_fields.items()
if name not in kwargs
):
instance = await instance.load() instance = await instance.load()
instance.set_save_status(True) instance.set_save_status(True)
return instance return instance
@ -377,7 +380,7 @@ class QuerySet:
for objt in objects: for objt in objects:
objt.set_save_status(True) objt.set_save_status(True)
async def bulk_update( async def bulk_update( # noqa: CCR001
self, objects: List["Model"], columns: List[str] = None self, objects: List["Model"], columns: List[str] = None
) -> None: ) -> None:
ready_objects = [] ready_objects = []