From bca6b6eef053e2611330ab38471bbcad0e4aa8f3 Mon Sep 17 00:00:00 2001 From: huangsong Date: Fri, 14 Jan 2022 18:26:11 +0800 Subject: [PATCH] add bulk_post_update: signals --- README.md | 1 + docs/signals.md | 5 ++ ormar/__init__.py | 2 + ormar/decorators/__init__.py | 2 + ormar/decorators/signals.py | 30 +++++++-- ormar/exceptions.py | 7 ++ ormar/models/metaclass.py | 1 + ormar/queryset/queryset.py | 56 +++++----------- .../test_queryset_level_methods.py | 66 ++----------------- tests/test_signals/test_signals.py | 26 +++++++- 10 files changed, 92 insertions(+), 104 deletions(-) diff --git a/README.md b/README.md index 33e057d..8183bbb 100644 --- a/README.md +++ b/README.md @@ -679,6 +679,7 @@ Signals allow to trigger your function for a given event on a given Model. * `post_relation_add` * `pre_relation_remove` * `post_relation_remove` +* `post_bulk_update` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ diff --git a/docs/signals.md b/docs/signals.md index bc11238..3a85d52 100644 --- a/docs/signals.md +++ b/docs/signals.md @@ -232,6 +232,11 @@ Send for `Model.relation_name.remove()` method for `ManyToMany` relations and re `sender` - sender class, `instance` - instance to which related model is added, `child` - model being added, `relation_name` - name of the relation to which child is added. +### post_bulk_update + +`post_bulk_update(sender: Type["Model"], instances: List["Model"], **kwargs), +Send for `Model.objects.bulk_update(List[objects])` method. + ## Defining your own signals diff --git a/ormar/__init__.py b/ormar/__init__.py index 21ab87e..fd52f2b 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,6 +30,7 @@ from ormar.decorators import ( # noqa: I100 post_relation_remove, post_save, post_update, + post_bulk_update, pre_delete, pre_relation_add, pre_relation_remove, @@ -113,6 +114,7 @@ __all__ = [ "RelationProtocol", "ModelMeta", "property_field", + "post_bulk_update", "post_delete", "post_save", "post_update", diff --git a/ormar/decorators/__init__.py b/ormar/decorators/__init__.py index ec320a8..c6f42d0 100644 --- a/ormar/decorators/__init__.py +++ b/ormar/decorators/__init__.py @@ -9,6 +9,7 @@ Currently only: """ from ormar.decorators.property_field import property_field from ormar.decorators.signals import ( + post_bulk_update, post_delete, post_relation_add, post_relation_remove, @@ -23,6 +24,7 @@ from ormar.decorators.signals import ( __all__ = [ "property_field", + "post_bulk_update", "post_delete", "post_save", "post_update", diff --git a/ormar/decorators/signals.py b/ormar/decorators/signals.py index 54e97e5..64bbf33 100644 --- a/ormar/decorators/signals.py +++ b/ormar/decorators/signals.py @@ -54,7 +54,8 @@ def post_save(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: return receiver(signal="post_save", senders=senders) -def post_update(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: +def post_update(senders: Union[Type["Model"], List[Type["Model"]]] +) -> Callable: """ Connect given function to all senders for post_update signal. @@ -67,7 +68,8 @@ def post_update(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: return receiver(signal="post_update", senders=senders) -def post_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: +def post_delete(senders: Union[Type["Model"], List[Type["Model"]]] +) -> Callable: """ Connect given function to all senders for post_delete signal. @@ -119,7 +121,8 @@ def pre_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: return receiver(signal="pre_delete", senders=senders) -def pre_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: +def pre_relation_add(senders: Union[Type["Model"], List[Type["Model"]]] +) -> Callable: """ Connect given function to all senders for pre_relation_add signal. @@ -132,7 +135,8 @@ def pre_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Call return receiver(signal="pre_relation_add", senders=senders) -def post_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: +def post_relation_add(senders: Union[Type["Model"], List[Type["Model"]]] +) -> Callable: """ Connect given function to all senders for post_relation_add signal. @@ -145,7 +149,8 @@ def post_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Cal return receiver(signal="post_relation_add", senders=senders) -def pre_relation_remove(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: +def pre_relation_remove(senders: Union[Type["Model"], List[Type["Model"]]] +) -> Callable: """ Connect given function to all senders for pre_relation_remove signal. @@ -171,3 +176,18 @@ def post_relation_remove( :rtype: Callable """ return receiver(signal="post_relation_remove", senders=senders) + + +def post_bulk_update( + senders: Union[Type["Model"], List[Type["Model"]]] +) -> Callable: + """ + Connect given function to all senders for post_bulk_update signal. + + :param senders: one or a list of "Model" classes + that should have the signal receiver registered + :type senders: Union[Type["Model"], List[Type["Model"]]] + :return: returns the original function untouched + :rtype: Callable + """ + return receiver(signal="post_bulk_update", senders=senders) diff --git a/ormar/exceptions.py b/ormar/exceptions.py index 094e272..8fd6bab 100644 --- a/ormar/exceptions.py +++ b/ormar/exceptions.py @@ -81,3 +81,10 @@ class SignalDefinitionError(AsyncOrmException): """ pass + + +class ModelListEmptyError(AsyncOrmException): + """ + Raised for objects is empty when bulk_update + """ + pass diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 2bee725..cc1dede 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -159,6 +159,7 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001 signals.post_relation_add = Signal() signals.pre_relation_remove = Signal() signals.post_relation_remove = Signal() + signals.post_bulk_update = Signal() new_model.Meta.signals = signals diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 7073f63..4237a8b 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -21,7 +21,10 @@ from sqlalchemy import bindparam import ormar # noqa I100 from ormar import MultipleMatches, NoMatch -from ormar.exceptions import ModelPersistenceError, QueryDefinitionError +from ormar.exceptions import ( + ModelPersistenceError, QueryDefinitionError, + ModelListEmptyError +) from ormar.queryset import FieldAccessor, FilterQuery, SelectAction from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.clause import FilterGroup, QueryClause @@ -1037,9 +1040,7 @@ class QuerySet(Generic[T]): instance = await instance.save() return instance - async def bulk_create( - self, objects: List["T"], send_signals: bool = False - ) -> None: + async def bulk_create(self, objects: List["T"]) -> None: """ Performs a bulk create in one database session to speed up the process. @@ -1051,26 +1052,7 @@ class QuerySet(Generic[T]): :param objects: list of ormar models already initialized and ready to save. :type objects: List[Model] - :param send_signals: confirm send the pre/post create signals - :type send_signals: bool """ - - async def before_create(entity: "T") -> None: - await entity.signals.pre_save.send( - sender=entity.__class__, instance=entity) - - async def after_create(entity: "T") -> None: - entity.set_save_status(True) - # FIXME do not have the `id` value, because not reload from db - if send_signals: - await entity.signals.post_save.send( - sender=entity.__class__, instance=entity) - - if send_signals: - await asyncio.gather( - *[before_create(entity) for entity in objects] - ) - ready_objects = [ obj.prepare_model_to_save(obj.dict()) for obj in objects @@ -1080,13 +1062,11 @@ class QuerySet(Generic[T]): # shouldn't use the execute_many, it's `queries.foreach(execute)` await self.database.execute(expr) - await asyncio.gather( - *[after_create(entity) for entity in objects] - ) + for obj in objects: + obj.set_save_status(True) async def bulk_update( # noqa: CCR001 - self, objects: List["T"], columns: List[str] = None, - send_signals: bool = False + self, objects: List["T"], columns: List[str] = None ) -> None: """ Performs bulk update in one database session to speed up the process. @@ -1104,9 +1084,10 @@ class QuerySet(Generic[T]): :type objects: List[Model] :param columns: list of columns to update :type columns: List[str] - :param send_signals: confirm send the pre/post create signals - :type send_signals: bool """ + if not objects: + raise ModelListEmptyError("Bulk update objects are empty!") + ready_objects = [] pk_name = self.model_meta.pkname if not columns: @@ -1151,14 +1132,11 @@ class QuerySet(Generic[T]): expr = str(expr) await self.database.execute_many(expr, ready_objects) - # FIXME: add pre-update-signals - async def after_update(entity: "T") -> None: - entity.set_save_status(True) - if send_signals: - await entity.signals.post_update.send( - sender=entity.__class__, instance=entity - ) + entity = list(objects)[0] - await asyncio.gather( - *[after_update(entity) for entity in objects] + await entity.signals.post_bulk_update.send( + sender=entity.__class__, instances=objects ) + + for obj in objects: + obj.set_save_status(True) diff --git a/tests/test_queries/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py index 80febf5..e7a9543 100644 --- a/tests/test_queries/test_queryset_level_methods.py +++ b/tests/test_queries/test_queryset_level_methods.py @@ -5,8 +5,10 @@ import pytest import sqlalchemy import ormar -from ormar import post_save, post_update, pre_save -from ormar.exceptions import ModelPersistenceError, QueryDefinitionError +from ormar.exceptions import ( + ModelPersistenceError, QueryDefinitionError, + ModelListEmptyError +) from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -194,35 +196,6 @@ async def test_bulk_create(): assert len(completed) == 2 -@pytest.mark.asyncio -async def test_bulk_create_send_signals(): - async with database: - - @pre_save(ToDo) - async def before_save(sender, instance, **kwargs): - instance.completed = False - - @post_save(ToDo) - async def after_save(sender, instance, **kwargs): - assert not instance.completed - - await ToDo.objects.bulk_create( - [ - ToDo(text="Buy the groceries."), - ToDo(text="Call Mum.", completed=True), - ToDo(text="Send invoices.", completed=True), - ], send_signals=True - ) - - todoes = await ToDo.objects.all() - assert len(todoes) == 3 - for todo in todoes: - assert todo.pk is not None - - count = await ToDo.objects.filter(completed=False).count() - assert count == 3 - - @pytest.mark.asyncio async def test_bulk_create_with_relation(): async with database: @@ -270,34 +243,6 @@ async def test_bulk_update(): assert todo.text[-2:] == "_1" -@pytest.mark.asyncio -async def test_bulk_update_send_signals(): - async with database: - - @post_update(ToDo) - async def after_update(sender, instance, **kwargs): - await instance.delete() - - await ToDo.objects.bulk_create( - [ - ToDo(text="Buy the groceries."), - ToDo(text="Call Mum.", completed=True), - ToDo(text="Send invoices.", completed=True), - ] - ) - todoes = await ToDo.objects.all() - assert len(todoes) == 3 - - for todo in todoes: - todo.text = todo.text + "_1" - todo.completed = False - - await ToDo.objects.bulk_update(todoes, send_signals=True) - - count = await ToDo.objects.filter(completed=False).count() - assert count == 0 - - @pytest.mark.asyncio async def test_bulk_update_with_only_selected_columns(): async with database: @@ -367,3 +312,6 @@ async def test_bulk_update_not_saved_objts(): Note(text="Call Mum.", category=category), ] ) + + with pytest.raises(ModelListEmptyError): + await Note.objects.bulk_update([]) diff --git a/tests/test_signals/test_signals.py b/tests/test_signals/test_signals.py index 3dbcfa6..1002a2f 100644 --- a/tests/test_signals/test_signals.py +++ b/tests/test_signals/test_signals.py @@ -6,7 +6,10 @@ import pytest import sqlalchemy import ormar -from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update +from ormar import ( + post_bulk_update, post_delete, post_save, post_update, + pre_delete, pre_save, pre_update +) from ormar.signals import SignalEmitter from ormar.exceptions import SignalDefinitionError from tests.settings import DATABASE_URL @@ -131,6 +134,14 @@ async def test_signal_functions(cleanup): event_log=instance.json(), ).save() + @post_bulk_update(Album) + async def after_bulk_update(sender, instances, **kwargs): + for it in instances: + await AuditLog( + event_type=f"BULK_POST_UPDATE_{sender.get_name()}", + event_log=it.json(), + ).save() + album = await Album.objects.create(name="Venice") audits = await AuditLog.objects.all() @@ -183,6 +194,19 @@ async def test_signal_functions(cleanup): album.signals.pre_save.disconnect(before_save) album.signals.post_save.disconnect(after_save) + albums = await Album.objects.all() + assert len(albums) + + for album in albums: + album.play_count = 1 + + await Album.objects.bulk_update(albums) + + cnt = await AuditLog.objects.filter(event_type__contains="BULK_POST").count() + assert cnt == len(albums) + + album.signals.bulk_post_update.disconnect(after_bulk_update) + @pytest.mark.asyncio async def test_multiple_signals(cleanup):