diff --git a/ormar/signals/signal.py b/ormar/signals/signal.py index 4ed12d7..f3d92d9 100644 --- a/ormar/signals/signal.py +++ b/ormar/signals/signal.py @@ -57,7 +57,7 @@ class Signal: class SignalEmitter: - if TYPE_CHECKING: + if TYPE_CHECKING: # pragma: no cover signals: Dict[str, Signal] def __init__(self) -> None: diff --git a/tests/test_signals.py b/tests/test_signals.py index 6f59d6d..bc72706 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,3 +1,5 @@ +from typing import Optional + import databases import pydantic import pytest @@ -30,6 +32,16 @@ class AuditLog(ormar.Model): event_log: pydantic.Json = ormar.JSON() +class Cover(ormar.Model): + class Meta: + tablename = "covers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=100) + + class Album(ormar.Model): class Meta: tablename = "albums" @@ -39,6 +51,7 @@ class Album(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) is_best_seller: bool = ormar.Boolean(default=False) + cover: Optional[Cover] = ormar.ForeignKey(Cover) @pytest.fixture(autouse=True, scope="module") @@ -57,7 +70,6 @@ def test_passing_not_callable(): def test_passing_callable_without_kwargs(): with pytest.raises(SignalDefinitionError): - @pre_save(Album) def trigger(sender, instance): # pragma: no cover pass @@ -67,7 +79,6 @@ def test_passing_callable_without_kwargs(): async def test_signal_functions(): async with database: async with database.transaction(force_rollback=True): - @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -151,9 +162,9 @@ async def test_signal_functions(): assert len(audits) == 2 assert audits[0].event_type == "PRE_DELETE_album" assert ( - audits[0].event_log.get("id") - == audits[1].event_log.get("id") - == album.id + audits[0].event_log.get("id") + == audits[1].event_log.get("id") + == album.id ) assert audits[1].event_type == "POST_DELETE_album" @@ -167,7 +178,6 @@ async def test_signal_functions(): async def test_multiple_signals(): async with database: async with database.transaction(force_rollback=True): - @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -198,7 +208,6 @@ async def test_multiple_signals(): async def test_static_methods_as_signals(): async with database: async with database.transaction(force_rollback=True): - class AlbumAuditor: event_type = "ALBUM_INSTANCE" @@ -223,7 +232,6 @@ async def test_static_methods_as_signals(): async def test_methods_as_signals(): async with database: async with database.transaction(force_rollback=True): - class AlbumAuditor: def __init__(self): self.event_type = "ALBUM_INSTANCE" @@ -243,3 +251,27 @@ async def test_methods_as_signals(): assert audits[0].event_log.get("name") == album.name album.signals.pre_save.disconnect(auditor.before_save) + +@pytest.mark.asyncio +async def test_multiple_senders_signal(): + async with database: + async with database.transaction(force_rollback=True): + @pre_save([Album, Cover]) + async def before_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_SAVE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + cover = await Cover(title='Blue').save() + album = await Album.objects.create(name="San Francisco", cover=cover) + + audits = await AuditLog.objects.all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_SAVE_cover" + assert audits[0].event_log.get("title") == cover.title + assert audits[1].event_type == "PRE_SAVE_album" + assert audits[1].event_log.get("cover") == album.cover.dict(exclude={'albums'}) + + album.signals.pre_save.disconnect(before_save) + cover.signals.pre_save.disconnect(before_save)