diff --git a/README.md b/README.md index 33e057d..8183bbb 100644 --- a/README.md +++ b/README.md @@ -679,6 +679,7 @@ Signals allow to trigger your function for a given event on a given Model. * `post_relation_add` * `pre_relation_remove` * `post_relation_remove` +* `post_bulk_update` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ diff --git a/docs/signals.md b/docs/signals.md index bc11238..3a85d52 100644 --- a/docs/signals.md +++ b/docs/signals.md @@ -232,6 +232,11 @@ Send for `Model.relation_name.remove()` method for `ManyToMany` relations and re `sender` - sender class, `instance` - instance to which related model is added, `child` - model being added, `relation_name` - name of the relation to which child is added. +### post_bulk_update + +`post_bulk_update(sender: Type["Model"], instances: List["Model"], **kwargs), +Send for `Model.objects.bulk_update(List[objects])` method. + ## Defining your own signals diff --git a/ormar/__init__.py b/ormar/__init__.py index 21ab87e..fd52f2b 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,6 +30,7 @@ from ormar.decorators import ( # noqa: I100 post_relation_remove, post_save, post_update, + post_bulk_update, pre_delete, pre_relation_add, pre_relation_remove, @@ -113,6 +114,7 @@ __all__ = [ "RelationProtocol", "ModelMeta", "property_field", + "post_bulk_update", "post_delete", "post_save", "post_update", diff --git a/ormar/decorators/__init__.py b/ormar/decorators/__init__.py index ec320a8..c6f42d0 100644 --- a/ormar/decorators/__init__.py +++ b/ormar/decorators/__init__.py @@ -9,6 +9,7 @@ Currently only: """ from ormar.decorators.property_field import property_field from ormar.decorators.signals import ( + post_bulk_update, post_delete, post_relation_add, post_relation_remove, @@ -23,6 +24,7 @@ from ormar.decorators.signals import ( __all__ = [ "property_field", + "post_bulk_update", "post_delete", "post_save", "post_update", diff --git a/ormar/decorators/signals.py b/ormar/decorators/signals.py index 468fc0c..cdb9f45 100644 --- a/ormar/decorators/signals.py +++ b/ormar/decorators/signals.py @@ -1,4 +1,4 @@ -from typing import Callable, List, TYPE_CHECKING, Type, Union +from typing import Callable, List, Type, TYPE_CHECKING, Union if TYPE_CHECKING: # pragma: no cover from ormar import Model @@ -171,3 +171,18 @@ def post_relation_remove( :rtype: Callable """ return receiver(signal="post_relation_remove", senders=senders) + + +def post_bulk_update( + senders: Union[Type["Model"], List[Type["Model"]]] +) -> Callable: + """ + Connect given function to all senders for post_bulk_update signal. + + :param senders: one or a list of "Model" classes + that should have the signal receiver registered + :type senders: Union[Type["Model"], List[Type["Model"]]] + :return: returns the original function untouched + :rtype: Callable + """ + return receiver(signal="post_bulk_update", senders=senders) diff --git a/ormar/exceptions.py b/ormar/exceptions.py index 094e272..8fd6bab 100644 --- a/ormar/exceptions.py +++ b/ormar/exceptions.py @@ -81,3 +81,10 @@ class SignalDefinitionError(AsyncOrmException): """ pass + + +class ModelListEmptyError(AsyncOrmException): + """ + Raised for objects is empty when bulk_update + """ + pass diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 2bee725..cc1dede 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -159,6 +159,7 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001 signals.post_relation_add = Signal() signals.pre_relation_remove = Signal() signals.post_relation_remove = Signal() + signals.post_bulk_update = Signal() new_model.Meta.signals = signals diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index e55f043..60efa37 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -56,6 +56,21 @@ class SavePrepareMixin(RelationMixin, AliasMixin): new_kwargs = cls.translate_columns_to_aliases(new_kwargs) return new_kwargs + @classmethod + def prepare_model_to_update(cls, new_kwargs: dict) -> dict: + """ + Combines all preparation methods before updating. + :param new_kwargs: dictionary of model that is about to be saved + :type new_kwargs: Dict[str, str] + :return: dictionary of model that is about to be updated + :rtype: Dict[str, str] + """ + new_kwargs = cls.parse_non_db_fields(new_kwargs) + new_kwargs = cls.substitute_models_with_pks(new_kwargs) + new_kwargs = cls.reconvert_str_to_bytes(new_kwargs) + new_kwargs = cls.translate_columns_to_aliases(new_kwargs) + return new_kwargs + @classmethod def _remove_not_ormar_fields(cls, new_kwargs: dict) -> dict: """ diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 1782567..95cba43 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -29,7 +29,10 @@ except ImportError: # pragma: no cover import ormar # noqa I100 from ormar import MultipleMatches, NoMatch -from ormar.exceptions import ModelPersistenceError, QueryDefinitionError +from ormar.exceptions import ( + ModelPersistenceError, QueryDefinitionError, + ModelListEmptyError +) from ormar.queryset import FieldAccessor, FilterQuery, SelectAction from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.clause import FilterGroup, QueryClause @@ -1049,7 +1052,7 @@ class QuerySet(Generic[T]): async def bulk_create(self, objects: List["T"]) -> None: """ - Performs a bulk update in one database session to speed up the process. + Performs a bulk create in one database session to speed up the process. Allows you to create multiple objects at once. @@ -1060,17 +1063,15 @@ class QuerySet(Generic[T]): :param objects: list of ormar models already initialized and ready to save. :type objects: List[Model] """ - ready_objects = [] - for objt in objects: - new_kwargs = objt.dict() - new_kwargs = objt.prepare_model_to_save(new_kwargs) - ready_objects.append(new_kwargs) + ready_objects = [ + obj.prepare_model_to_save(obj.dict()) + for obj in objects + ] + expr = self.table.insert().values(ready_objects) + await self.database.execute(expr) - expr = self.table.insert() - await self.database.execute_many(expr, ready_objects) - - for objt in objects: - objt.set_save_status(True) + for obj in objects: + obj.set_save_status(True) async def bulk_update( # noqa: CCR001 self, objects: List["T"], columns: List[str] = None @@ -1078,7 +1079,7 @@ class QuerySet(Generic[T]): """ Performs bulk update in one database session to speed up the process. - Allows to update multiple instance at once. + Allows you to update multiple instance at once. All `Models` passed need to have primary key column populated. @@ -1092,6 +1093,9 @@ class QuerySet(Generic[T]): :param columns: list of columns to update :type columns: List[str] """ + if not objects: + raise ModelListEmptyError("Bulk update objects are empty!") + ready_objects = [] pk_name = self.model_meta.pkname if not columns: @@ -1106,19 +1110,17 @@ class QuerySet(Generic[T]): columns = [self.model.get_column_alias(k) for k in columns] - for objt in objects: - new_kwargs = objt.dict() - if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None: + for obj in objects: + new_kwargs = obj.dict() + if new_kwargs.get(pk_name) is None: raise ModelPersistenceError( "You cannot update unsaved objects. " f"{self.model.__name__} has to have {pk_name} filled." ) - new_kwargs = self.model.parse_non_db_fields(new_kwargs) - new_kwargs = self.model.substitute_models_with_pks(new_kwargs) - new_kwargs = self.model.reconvert_str_to_bytes(new_kwargs) - new_kwargs = self.model.translate_columns_to_aliases(new_kwargs) - new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} - ready_objects.append(new_kwargs) + new_kwargs = obj.prepare_model_to_update(new_kwargs) + ready_objects.append({ + "new_" + k: v for k, v in new_kwargs.items() if k in columns + }) pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name)) pk_column_name = self.model.get_column_alias(pk_name) @@ -1138,5 +1140,10 @@ class QuerySet(Generic[T]): expr = str(expr) await self.database.execute_many(expr, ready_objects) - for objt in objects: - objt.set_save_status(True) + for obj in objects: + obj.set_save_status(True) + + await cast(Type["Model"], self.model_cls).Meta.signals.post_bulk_update.send( + sender=self.model_cls, instances=objects # type: ignore + ) + diff --git a/ormar/signals/signal.py b/ormar/signals/signal.py index e2c5275..d5b4a5c 100644 --- a/ormar/signals/signal.py +++ b/ormar/signals/signal.py @@ -1,6 +1,6 @@ import asyncio import inspect -from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union +from typing import Any, Callable, Dict, Tuple, Type, TYPE_CHECKING, Union from ormar.exceptions import SignalDefinitionError @@ -45,7 +45,7 @@ class Signal: """ def __init__(self) -> None: - self._receivers: List[Tuple[Union[int, Tuple[int, int]], Callable]] = [] + self._receivers: Dict[Union[int, Tuple[int, int]], Callable] = {} def connect(self, receiver: Callable) -> None: """ @@ -63,8 +63,8 @@ class Signal: "Signal receivers must accept **kwargs argument." ) new_receiver_key = make_id(receiver) - if not any(rec_id == new_receiver_key for rec_id, _ in self._receivers): - self._receivers.append((new_receiver_key, receiver)) + if new_receiver_key not in self._receivers: + self._receivers[new_receiver_key] = receiver def disconnect(self, receiver: Callable) -> bool: """ @@ -75,15 +75,10 @@ class Signal: :return: flag if receiver was removed :rtype: bool """ - removed = False new_receiver_key = make_id(receiver) - for ind, rec in enumerate(self._receivers): - rec_id, _ = rec - if rec_id == new_receiver_key: - removed = True - del self._receivers[ind] - break - return removed + receiver_func: Union[Callable, None] = self._receivers.pop( + new_receiver_key, None) + return True if receiver_func is not None else False async def send(self, sender: Type["Model"], **kwargs: Any) -> None: """ @@ -93,28 +88,22 @@ class Signal: :param kwargs: arguments passed to receivers :type kwargs: Any """ - receivers = [] - for receiver in self._receivers: - _, receiver_func = receiver - receivers.append(receiver_func(sender=sender, **kwargs)) + receivers = [ + receiver_func(sender=sender, **kwargs) + for receiver_func in self._receivers.values() + ] await asyncio.gather(*receivers) -class SignalEmitter: +class SignalEmitter(dict): """ Emitter that registers the signals in internal dictionary. If signal with given name does not exist it's auto added on access. """ - - if TYPE_CHECKING: # pragma: no cover - signals: Dict[str, Signal] - - def __init__(self) -> None: - object.__setattr__(self, "signals", dict()) - def __getattr__(self, item: str) -> Signal: - return self.signals.setdefault(item, Signal()) + return self.setdefault(item, Signal()) - def __setattr__(self, key: str, value: Any) -> None: - signals = object.__getattribute__(self, "signals") - signals[key] = value + def __setattr__(self, key: str, value: Signal) -> None: + if not isinstance(value, Signal): + raise SignalDefinitionError(f"{value} is not valid signal") + self[key] = value diff --git a/tests/test_queries/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py index c7db39f..e7a9543 100644 --- a/tests/test_queries/test_queryset_level_methods.py +++ b/tests/test_queries/test_queryset_level_methods.py @@ -5,7 +5,10 @@ import pytest import sqlalchemy import ormar -from ormar.exceptions import ModelPersistenceError, QueryDefinitionError +from ormar.exceptions import ( + ModelPersistenceError, QueryDefinitionError, + ModelListEmptyError +) from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -309,3 +312,6 @@ async def test_bulk_update_not_saved_objts(): Note(text="Call Mum.", category=category), ] ) + + with pytest.raises(ModelListEmptyError): + await Note.objects.bulk_update([]) diff --git a/tests/test_signals/test_signals.py b/tests/test_signals/test_signals.py index 4832682..1002a2f 100644 --- a/tests/test_signals/test_signals.py +++ b/tests/test_signals/test_signals.py @@ -6,7 +6,11 @@ import pytest import sqlalchemy import ormar -from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update +from ormar import ( + post_bulk_update, post_delete, post_save, post_update, + pre_delete, pre_save, pre_update +) +from ormar.signals import SignalEmitter from ormar.exceptions import SignalDefinitionError from tests.settings import DATABASE_URL @@ -77,6 +81,12 @@ def test_passing_callable_without_kwargs(): pass +def test_invalid_signal(): + emitter = SignalEmitter() + with pytest.raises(SignalDefinitionError): + emitter.save = 1 + + @pytest.mark.asyncio async def test_signal_functions(cleanup): async with database: @@ -124,6 +134,14 @@ async def test_signal_functions(cleanup): event_log=instance.json(), ).save() + @post_bulk_update(Album) + async def after_bulk_update(sender, instances, **kwargs): + for it in instances: + await AuditLog( + event_type=f"BULK_POST_UPDATE_{sender.get_name()}", + event_log=it.json(), + ).save() + album = await Album.objects.create(name="Venice") audits = await AuditLog.objects.all() @@ -176,6 +194,19 @@ async def test_signal_functions(cleanup): album.signals.pre_save.disconnect(before_save) album.signals.post_save.disconnect(after_save) + albums = await Album.objects.all() + assert len(albums) + + for album in albums: + album.play_count = 1 + + await Album.objects.bulk_update(albums) + + cnt = await AuditLog.objects.filter(event_type__contains="BULK_POST").count() + assert cnt == len(albums) + + album.signals.bulk_post_update.disconnect(after_bulk_update) + @pytest.mark.asyncio async def test_multiple_signals(cleanup):