diff --git a/ormar/decorators/signals.py b/ormar/decorators/signals.py index 468fc0c..54e97e5 100644 --- a/ormar/decorators/signals.py +++ b/ormar/decorators/signals.py @@ -1,4 +1,4 @@ -from typing import Callable, List, TYPE_CHECKING, Type, Union +from typing import Callable, List, Type, TYPE_CHECKING, Union if TYPE_CHECKING: # pragma: no cover from ormar import Model diff --git a/ormar/signals/signal.py b/ormar/signals/signal.py index e2c5275..d5b4a5c 100644 --- a/ormar/signals/signal.py +++ b/ormar/signals/signal.py @@ -1,6 +1,6 @@ import asyncio import inspect -from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union +from typing import Any, Callable, Dict, Tuple, Type, TYPE_CHECKING, Union from ormar.exceptions import SignalDefinitionError @@ -45,7 +45,7 @@ class Signal: """ def __init__(self) -> None: - self._receivers: List[Tuple[Union[int, Tuple[int, int]], Callable]] = [] + self._receivers: Dict[Union[int, Tuple[int, int]], Callable] = {} def connect(self, receiver: Callable) -> None: """ @@ -63,8 +63,8 @@ class Signal: "Signal receivers must accept **kwargs argument." ) new_receiver_key = make_id(receiver) - if not any(rec_id == new_receiver_key for rec_id, _ in self._receivers): - self._receivers.append((new_receiver_key, receiver)) + if new_receiver_key not in self._receivers: + self._receivers[new_receiver_key] = receiver def disconnect(self, receiver: Callable) -> bool: """ @@ -75,15 +75,10 @@ class Signal: :return: flag if receiver was removed :rtype: bool """ - removed = False new_receiver_key = make_id(receiver) - for ind, rec in enumerate(self._receivers): - rec_id, _ = rec - if rec_id == new_receiver_key: - removed = True - del self._receivers[ind] - break - return removed + receiver_func: Union[Callable, None] = self._receivers.pop( + new_receiver_key, None) + return True if receiver_func is not None else False async def send(self, sender: Type["Model"], **kwargs: Any) -> None: """ @@ -93,28 +88,22 @@ class Signal: :param kwargs: arguments passed to receivers :type kwargs: Any """ - receivers = [] - for receiver in self._receivers: - _, receiver_func = receiver - receivers.append(receiver_func(sender=sender, **kwargs)) + receivers = [ + receiver_func(sender=sender, **kwargs) + for receiver_func in self._receivers.values() + ] await asyncio.gather(*receivers) -class SignalEmitter: +class SignalEmitter(dict): """ Emitter that registers the signals in internal dictionary. If signal with given name does not exist it's auto added on access. """ - - if TYPE_CHECKING: # pragma: no cover - signals: Dict[str, Signal] - - def __init__(self) -> None: - object.__setattr__(self, "signals", dict()) - def __getattr__(self, item: str) -> Signal: - return self.signals.setdefault(item, Signal()) + return self.setdefault(item, Signal()) - def __setattr__(self, key: str, value: Any) -> None: - signals = object.__getattribute__(self, "signals") - signals[key] = value + def __setattr__(self, key: str, value: Signal) -> None: + if not isinstance(value, Signal): + raise SignalDefinitionError(f"{value} is not valid signal") + self[key] = value diff --git a/tests/test_signals/test_signals.py b/tests/test_signals/test_signals.py index 4832682..3dbcfa6 100644 --- a/tests/test_signals/test_signals.py +++ b/tests/test_signals/test_signals.py @@ -7,6 +7,7 @@ import sqlalchemy import ormar from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update +from ormar.signals import SignalEmitter from ormar.exceptions import SignalDefinitionError from tests.settings import DATABASE_URL @@ -77,6 +78,12 @@ def test_passing_callable_without_kwargs(): pass +def test_invalid_signal(): + emitter = SignalEmitter() + with pytest.raises(SignalDefinitionError): + emitter.save = 1 + + @pytest.mark.asyncio async def test_signal_functions(cleanup): async with database: