add bulk_post_update: signals
This commit is contained in:
@ -679,6 +679,7 @@ Signals allow to trigger your function for a given event on a given Model.
|
||||
* `post_relation_add`
|
||||
* `pre_relation_remove`
|
||||
* `post_relation_remove`
|
||||
* `post_bulk_update`
|
||||
|
||||
|
||||
[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
|
||||
|
||||
@ -232,6 +232,11 @@ Send for `Model.relation_name.remove()` method for `ManyToMany` relations and re
|
||||
`sender` - sender class, `instance` - instance to which related model is added, `child` - model being added,
|
||||
`relation_name` - name of the relation to which child is added.
|
||||
|
||||
### post_bulk_update
|
||||
|
||||
`post_bulk_update(sender: Type["Model"], instances: List["Model"], **kwargs),
|
||||
Send for `Model.objects.bulk_update(List[objects])` method.
|
||||
|
||||
|
||||
## Defining your own signals
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ from ormar.decorators import ( # noqa: I100
|
||||
post_relation_remove,
|
||||
post_save,
|
||||
post_update,
|
||||
post_bulk_update,
|
||||
pre_delete,
|
||||
pre_relation_add,
|
||||
pre_relation_remove,
|
||||
@ -113,6 +114,7 @@ __all__ = [
|
||||
"RelationProtocol",
|
||||
"ModelMeta",
|
||||
"property_field",
|
||||
"post_bulk_update",
|
||||
"post_delete",
|
||||
"post_save",
|
||||
"post_update",
|
||||
|
||||
@ -9,6 +9,7 @@ Currently only:
|
||||
"""
|
||||
from ormar.decorators.property_field import property_field
|
||||
from ormar.decorators.signals import (
|
||||
post_bulk_update,
|
||||
post_delete,
|
||||
post_relation_add,
|
||||
post_relation_remove,
|
||||
@ -23,6 +24,7 @@ from ormar.decorators.signals import (
|
||||
|
||||
__all__ = [
|
||||
"property_field",
|
||||
"post_bulk_update",
|
||||
"post_delete",
|
||||
"post_save",
|
||||
"post_update",
|
||||
|
||||
@ -54,7 +54,8 @@ def post_save(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
return receiver(signal="post_save", senders=senders)
|
||||
|
||||
|
||||
def post_update(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
def post_update(senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for post_update signal.
|
||||
|
||||
@ -67,7 +68,8 @@ def post_update(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
return receiver(signal="post_update", senders=senders)
|
||||
|
||||
|
||||
def post_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
def post_delete(senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for post_delete signal.
|
||||
|
||||
@ -119,7 +121,8 @@ def pre_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
return receiver(signal="pre_delete", senders=senders)
|
||||
|
||||
|
||||
def pre_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
def pre_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for pre_relation_add signal.
|
||||
|
||||
@ -132,7 +135,8 @@ def pre_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Call
|
||||
return receiver(signal="pre_relation_add", senders=senders)
|
||||
|
||||
|
||||
def post_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
def post_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for post_relation_add signal.
|
||||
|
||||
@ -145,7 +149,8 @@ def post_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Cal
|
||||
return receiver(signal="post_relation_add", senders=senders)
|
||||
|
||||
|
||||
def pre_relation_remove(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
def pre_relation_remove(senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for pre_relation_remove signal.
|
||||
|
||||
@ -171,3 +176,18 @@ def post_relation_remove(
|
||||
:rtype: Callable
|
||||
"""
|
||||
return receiver(signal="post_relation_remove", senders=senders)
|
||||
|
||||
|
||||
def post_bulk_update(
|
||||
senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for post_bulk_update signal.
|
||||
|
||||
:param senders: one or a list of "Model" classes
|
||||
that should have the signal receiver registered
|
||||
:type senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
:return: returns the original function untouched
|
||||
:rtype: Callable
|
||||
"""
|
||||
return receiver(signal="post_bulk_update", senders=senders)
|
||||
|
||||
@ -81,3 +81,10 @@ class SignalDefinitionError(AsyncOrmException):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelListEmptyError(AsyncOrmException):
|
||||
"""
|
||||
Raised for objects is empty when bulk_update
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -159,6 +159,7 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
|
||||
signals.post_relation_add = Signal()
|
||||
signals.pre_relation_remove = Signal()
|
||||
signals.post_relation_remove = Signal()
|
||||
signals.post_bulk_update = Signal()
|
||||
new_model.Meta.signals = signals
|
||||
|
||||
|
||||
|
||||
@ -21,7 +21,10 @@ from sqlalchemy import bindparam
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar import MultipleMatches, NoMatch
|
||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
||||
from ormar.exceptions import (
|
||||
ModelPersistenceError, QueryDefinitionError,
|
||||
ModelListEmptyError
|
||||
)
|
||||
from ormar.queryset import FieldAccessor, FilterQuery, SelectAction
|
||||
from ormar.queryset.actions.order_action import OrderAction
|
||||
from ormar.queryset.clause import FilterGroup, QueryClause
|
||||
@ -1037,9 +1040,7 @@ class QuerySet(Generic[T]):
|
||||
instance = await instance.save()
|
||||
return instance
|
||||
|
||||
async def bulk_create(
|
||||
self, objects: List["T"], send_signals: bool = False
|
||||
) -> None:
|
||||
async def bulk_create(self, objects: List["T"]) -> None:
|
||||
"""
|
||||
Performs a bulk create in one database session to speed up the process.
|
||||
|
||||
@ -1051,26 +1052,7 @@ class QuerySet(Generic[T]):
|
||||
|
||||
:param objects: list of ormar models already initialized and ready to save.
|
||||
:type objects: List[Model]
|
||||
:param send_signals: confirm send the pre/post create signals
|
||||
:type send_signals: bool
|
||||
"""
|
||||
|
||||
async def before_create(entity: "T") -> None:
|
||||
await entity.signals.pre_save.send(
|
||||
sender=entity.__class__, instance=entity)
|
||||
|
||||
async def after_create(entity: "T") -> None:
|
||||
entity.set_save_status(True)
|
||||
# FIXME do not have the `id` value, because not reload from db
|
||||
if send_signals:
|
||||
await entity.signals.post_save.send(
|
||||
sender=entity.__class__, instance=entity)
|
||||
|
||||
if send_signals:
|
||||
await asyncio.gather(
|
||||
*[before_create(entity) for entity in objects]
|
||||
)
|
||||
|
||||
ready_objects = [
|
||||
obj.prepare_model_to_save(obj.dict())
|
||||
for obj in objects
|
||||
@ -1080,13 +1062,11 @@ class QuerySet(Generic[T]):
|
||||
# shouldn't use the execute_many, it's `queries.foreach(execute)`
|
||||
await self.database.execute(expr)
|
||||
|
||||
await asyncio.gather(
|
||||
*[after_create(entity) for entity in objects]
|
||||
)
|
||||
for obj in objects:
|
||||
obj.set_save_status(True)
|
||||
|
||||
async def bulk_update( # noqa: CCR001
|
||||
self, objects: List["T"], columns: List[str] = None,
|
||||
send_signals: bool = False
|
||||
self, objects: List["T"], columns: List[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Performs bulk update in one database session to speed up the process.
|
||||
@ -1104,9 +1084,10 @@ class QuerySet(Generic[T]):
|
||||
:type objects: List[Model]
|
||||
:param columns: list of columns to update
|
||||
:type columns: List[str]
|
||||
:param send_signals: confirm send the pre/post create signals
|
||||
:type send_signals: bool
|
||||
"""
|
||||
if not objects:
|
||||
raise ModelListEmptyError("Bulk update objects are empty!")
|
||||
|
||||
ready_objects = []
|
||||
pk_name = self.model_meta.pkname
|
||||
if not columns:
|
||||
@ -1151,14 +1132,11 @@ class QuerySet(Generic[T]):
|
||||
expr = str(expr)
|
||||
await self.database.execute_many(expr, ready_objects)
|
||||
|
||||
# FIXME: add pre-update-signals
|
||||
async def after_update(entity: "T") -> None:
|
||||
entity.set_save_status(True)
|
||||
if send_signals:
|
||||
await entity.signals.post_update.send(
|
||||
sender=entity.__class__, instance=entity
|
||||
)
|
||||
entity = list(objects)[0]
|
||||
|
||||
await asyncio.gather(
|
||||
*[after_update(entity) for entity in objects]
|
||||
await entity.signals.post_bulk_update.send(
|
||||
sender=entity.__class__, instances=objects
|
||||
)
|
||||
|
||||
for obj in objects:
|
||||
obj.set_save_status(True)
|
||||
|
||||
@ -5,8 +5,10 @@ import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar import post_save, post_update, pre_save
|
||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
||||
from ormar.exceptions import (
|
||||
ModelPersistenceError, QueryDefinitionError,
|
||||
ModelListEmptyError
|
||||
)
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
@ -194,35 +196,6 @@ async def test_bulk_create():
|
||||
assert len(completed) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_send_signals():
|
||||
async with database:
|
||||
|
||||
@pre_save(ToDo)
|
||||
async def before_save(sender, instance, **kwargs):
|
||||
instance.completed = False
|
||||
|
||||
@post_save(ToDo)
|
||||
async def after_save(sender, instance, **kwargs):
|
||||
assert not instance.completed
|
||||
|
||||
await ToDo.objects.bulk_create(
|
||||
[
|
||||
ToDo(text="Buy the groceries."),
|
||||
ToDo(text="Call Mum.", completed=True),
|
||||
ToDo(text="Send invoices.", completed=True),
|
||||
], send_signals=True
|
||||
)
|
||||
|
||||
todoes = await ToDo.objects.all()
|
||||
assert len(todoes) == 3
|
||||
for todo in todoes:
|
||||
assert todo.pk is not None
|
||||
|
||||
count = await ToDo.objects.filter(completed=False).count()
|
||||
assert count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_with_relation():
|
||||
async with database:
|
||||
@ -270,34 +243,6 @@ async def test_bulk_update():
|
||||
assert todo.text[-2:] == "_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_send_signals():
|
||||
async with database:
|
||||
|
||||
@post_update(ToDo)
|
||||
async def after_update(sender, instance, **kwargs):
|
||||
await instance.delete()
|
||||
|
||||
await ToDo.objects.bulk_create(
|
||||
[
|
||||
ToDo(text="Buy the groceries."),
|
||||
ToDo(text="Call Mum.", completed=True),
|
||||
ToDo(text="Send invoices.", completed=True),
|
||||
]
|
||||
)
|
||||
todoes = await ToDo.objects.all()
|
||||
assert len(todoes) == 3
|
||||
|
||||
for todo in todoes:
|
||||
todo.text = todo.text + "_1"
|
||||
todo.completed = False
|
||||
|
||||
await ToDo.objects.bulk_update(todoes, send_signals=True)
|
||||
|
||||
count = await ToDo.objects.filter(completed=False).count()
|
||||
assert count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_with_only_selected_columns():
|
||||
async with database:
|
||||
@ -367,3 +312,6 @@ async def test_bulk_update_not_saved_objts():
|
||||
Note(text="Call Mum.", category=category),
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ModelListEmptyError):
|
||||
await Note.objects.bulk_update([])
|
||||
|
||||
@ -6,7 +6,10 @@ import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update
|
||||
from ormar import (
|
||||
post_bulk_update, post_delete, post_save, post_update,
|
||||
pre_delete, pre_save, pre_update
|
||||
)
|
||||
from ormar.signals import SignalEmitter
|
||||
from ormar.exceptions import SignalDefinitionError
|
||||
from tests.settings import DATABASE_URL
|
||||
@ -131,6 +134,14 @@ async def test_signal_functions(cleanup):
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
@post_bulk_update(Album)
|
||||
async def after_bulk_update(sender, instances, **kwargs):
|
||||
for it in instances:
|
||||
await AuditLog(
|
||||
event_type=f"BULK_POST_UPDATE_{sender.get_name()}",
|
||||
event_log=it.json(),
|
||||
).save()
|
||||
|
||||
album = await Album.objects.create(name="Venice")
|
||||
|
||||
audits = await AuditLog.objects.all()
|
||||
@ -183,6 +194,19 @@ async def test_signal_functions(cleanup):
|
||||
album.signals.pre_save.disconnect(before_save)
|
||||
album.signals.post_save.disconnect(after_save)
|
||||
|
||||
albums = await Album.objects.all()
|
||||
assert len(albums)
|
||||
|
||||
for album in albums:
|
||||
album.play_count = 1
|
||||
|
||||
await Album.objects.bulk_update(albums)
|
||||
|
||||
cnt = await AuditLog.objects.filter(event_type__contains="BULK_POST").count()
|
||||
assert cnt == len(albums)
|
||||
|
||||
album.signals.bulk_post_update.disconnect(after_bulk_update)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_signals(cleanup):
|
||||
|
||||
Reference in New Issue
Block a user