add bulk_create and bulk_update and tests

This commit is contained in:
collerek
2020-09-25 13:26:10 +02:00
parent 41c38a5bd6
commit a5abf2a403
5 changed files with 284 additions and 32 deletions

View File

@ -21,13 +21,13 @@ class ModelTableProxy:
raise NotImplementedError # pragma no cover
def _extract_own_model_fields(self) -> dict:
related_names = self._extract_related_names()
related_names = self.extract_related_names()
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields
@classmethod
def extract_db_own_fields(cls) -> set:
related_names = cls._extract_related_names()
def extract_db_own_fields(cls) -> Set:
related_names = cls.extract_related_names()
self_fields = {
name for name in cls.Meta.model_fields.keys() if name not in related_names
}
@ -35,7 +35,7 @@ class ModelTableProxy:
@classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
for field in cls._extract_related_names():
for field in cls.extract_related_names():
field_value = model_dict.get(field, None)
if field_value is not None:
target_field = cls.Meta.model_fields[field]
@ -47,7 +47,7 @@ class ModelTableProxy:
return model_dict
@classmethod
def _extract_related_names(cls) -> Set:
def extract_related_names(cls) -> Set:
related_names = set()
for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
@ -69,7 +69,7 @@ class ModelTableProxy:
@classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested:
return cls._extract_related_names()
return cls.extract_related_names()
related_names = set()
for name, field in cls.Meta.model_fields.items():
if (

View File

@ -92,7 +92,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "__fields_set__", fields_set)
# register the related models after initialization
for related in self._extract_related_names():
for related in self.extract_related_names():
self.Meta.model_fields[related].expand_relationship(
kwargs.get(related), self, to_register=True
)
@ -119,7 +119,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def __getattribute__(self, item: str) -> Any:
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
return object.__getattribute__(self, item)
if item != "_extract_related_names" and item in self._extract_related_names():
if item != "extract_related_names" and item in self.extract_related_names():
return self._extract_related_model_instead_of_field(item)
if item == "pk":
return self.__dict__.get(self.Meta.pkname, None)
@ -186,7 +186,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
for field in self._extract_related_names():
for field in self.extract_related_names():
nested_model = getattr(self, field)
if self.Meta.model_fields[field].virtual and nested:

View File

@ -247,3 +247,52 @@ class QuerySet:
if pk and isinstance(pk, self.model_cls.pk_type()):
setattr(instance, self.model_cls.Meta.pkname, pk)
return instance
async def bulk_create(self, objects: List["Model"]) -> None:
ready_objects = []
for objt in objects:
new_kwargs = objt.dict()
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
ready_objects.append(new_kwargs)
expr = self.table.insert()
await self.database.execute_many(expr, ready_objects)
async def bulk_update(
self, objects: List["Model"], columns: List[str] = None
) -> None:
ready_expressions = []
pk_name = self.model_cls.Meta.pkname
if not columns:
columns = self.model_cls.extract_db_own_fields().union(
self.model_cls.extract_related_names()
)
if pk_name not in columns:
columns.append(pk_name)
for objt in objects:
new_kwargs = objt.dict()
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
raise QueryDefinitionError(
"You cannot update unsaved objects. "
f"{self.model_cls.__name__} has to have {pk_name} filled."
)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
new_kwargs = {k: v for k, v in new_kwargs.items() if k in columns}
expr = self.table.update().values(
**{k: v for k, v in new_kwargs.items() if k != pk_name}
)
pk_column = self.model_cls.Meta.table.c.get(pk_name)
expr = expr.where(pk_column == new_kwargs.get(pk_name))
ready_expressions.append(expr)
# databases does not bind params for where clause and values separately
# no way to pass one dict with both uses
# so we need to resort to lower connection api
async with self.model_cls.Meta.database.connection() as connection:
for single_query in ready_expressions:
await connection.execute(single_query)