diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 43e02da..7073f63 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1037,7 +1037,9 @@ class QuerySet(Generic[T]): instance = await instance.save() return instance - async def bulk_create(self, objects: List["T"]) -> None: + async def bulk_create( + self, objects: List["T"], send_signals: bool = False + ) -> None: """ Performs a bulk create in one database session to speed up the process. @@ -1049,27 +1051,42 @@ class QuerySet(Generic[T]): :param objects: list of ormar models already initialized and ready to save. :type objects: List[Model] + :param send_signals: confirm send the pre/post create signals + :type send_signals: bool """ + + async def before_create(entity: "T") -> None: + await entity.signals.pre_save.send( + sender=entity.__class__, instance=entity) + + async def after_create(entity: "T") -> None: + entity.set_save_status(True) + # FIXME do not have the `id` value, because not reload from db + if send_signals: + await entity.signals.post_save.send( + sender=entity.__class__, instance=entity) + + if send_signals: + await asyncio.gather( + *[before_create(entity) for entity in objects] + ) + ready_objects = [ obj.prepare_model_to_save(obj.dict()) for obj in objects ] expr = self.table.insert().values(ready_objects) + # shouldn't use the execute_many, it's `queries.foreach(execute)` await self.database.execute(expr) - # FIXME: add pre_save signals - async def after_create(entity: "T") -> None: - entity.set_save_status(True) - await entity.signals.post_save.send( - sender=entity.__class__, instance=entity) - await asyncio.gather( *[after_create(entity) for entity in objects] ) async def bulk_update( # noqa: CCR001 - self, objects: List["T"], columns: List[str] = None + self, objects: List["T"], columns: List[str] = None, + send_signals: bool = False ) -> None: """ Performs bulk update in one database session to speed up the process. @@ -1087,6 +1104,8 @@ class QuerySet(Generic[T]): :type objects: List[Model] :param columns: list of columns to update :type columns: List[str] + :param send_signals: confirm send the pre/post create signals + :type send_signals: bool """ ready_objects = [] pk_name = self.model_meta.pkname @@ -1135,9 +1154,10 @@ class QuerySet(Generic[T]): # FIXME: add pre-update-signals async def after_update(entity: "T") -> None: entity.set_save_status(True) - await entity.signals.post_update.send( - sender=entity.__class__, instance=entity - ) + if send_signals: + await entity.signals.post_update.send( + sender=entity.__class__, instance=entity + ) await asyncio.gather( *[after_update(entity) for entity in objects] diff --git a/tests/test_queries/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py index c7db39f..80febf5 100644 --- a/tests/test_queries/test_queryset_level_methods.py +++ b/tests/test_queries/test_queryset_level_methods.py @@ -5,6 +5,7 @@ import pytest import sqlalchemy import ormar +from ormar import post_save, post_update, pre_save from ormar.exceptions import ModelPersistenceError, QueryDefinitionError from tests.settings import DATABASE_URL @@ -193,6 +194,35 @@ async def test_bulk_create(): assert len(completed) == 2 +@pytest.mark.asyncio +async def test_bulk_create_send_signals(): + async with database: + + @pre_save(ToDo) + async def before_save(sender, instance, **kwargs): + instance.completed = False + + @post_save(ToDo) + async def after_save(sender, instance, **kwargs): + assert not instance.completed + + await ToDo.objects.bulk_create( + [ + ToDo(text="Buy the groceries."), + ToDo(text="Call Mum.", completed=True), + ToDo(text="Send invoices.", completed=True), + ], send_signals=True + ) + + todoes = await ToDo.objects.all() + assert len(todoes) == 3 + for todo in todoes: + assert todo.pk is not None + + count = await ToDo.objects.filter(completed=False).count() + assert count == 3 + + @pytest.mark.asyncio async def test_bulk_create_with_relation(): async with database: @@ -240,6 +270,34 @@ async def test_bulk_update(): assert todo.text[-2:] == "_1" +@pytest.mark.asyncio +async def test_bulk_update_send_signals(): + async with database: + + @post_update(ToDo) + async def after_update(sender, instance, **kwargs): + await instance.delete() + + await ToDo.objects.bulk_create( + [ + ToDo(text="Buy the groceries."), + ToDo(text="Call Mum.", completed=True), + ToDo(text="Send invoices.", completed=True), + ] + ) + todoes = await ToDo.objects.all() + assert len(todoes) == 3 + + for todo in todoes: + todo.text = todo.text + "_1" + todo.completed = False + + await ToDo.objects.bulk_update(todoes, send_signals=True) + + count = await ToDo.objects.filter(completed=False).count() + assert count == 0 + + @pytest.mark.asyncio async def test_bulk_update_with_only_selected_columns(): async with database: