Merge pull request #493 from ponytailer/master

use the dict to store the receivers in signal
This commit is contained in:
collerek
2022-01-13 17:58:47 +01:00
committed by GitHub
3 changed files with 25 additions and 29 deletions

View File

@ -1,4 +1,4 @@
from typing import Callable, List, TYPE_CHECKING, Type, Union from typing import Callable, List, Type, TYPE_CHECKING, Union
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ormar import Model from ormar import Model

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
import inspect import inspect
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union from typing import Any, Callable, Dict, Tuple, Type, TYPE_CHECKING, Union
from ormar.exceptions import SignalDefinitionError from ormar.exceptions import SignalDefinitionError
@ -45,7 +45,7 @@ class Signal:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._receivers: List[Tuple[Union[int, Tuple[int, int]], Callable]] = [] self._receivers: Dict[Union[int, Tuple[int, int]], Callable] = {}
def connect(self, receiver: Callable) -> None: def connect(self, receiver: Callable) -> None:
""" """
@ -63,8 +63,8 @@ class Signal:
"Signal receivers must accept **kwargs argument." "Signal receivers must accept **kwargs argument."
) )
new_receiver_key = make_id(receiver) new_receiver_key = make_id(receiver)
if not any(rec_id == new_receiver_key for rec_id, _ in self._receivers): if new_receiver_key not in self._receivers:
self._receivers.append((new_receiver_key, receiver)) self._receivers[new_receiver_key] = receiver
def disconnect(self, receiver: Callable) -> bool: def disconnect(self, receiver: Callable) -> bool:
""" """
@ -75,15 +75,10 @@ class Signal:
:return: flag if receiver was removed :return: flag if receiver was removed
:rtype: bool :rtype: bool
""" """
removed = False
new_receiver_key = make_id(receiver) new_receiver_key = make_id(receiver)
for ind, rec in enumerate(self._receivers): receiver_func: Union[Callable, None] = self._receivers.pop(
rec_id, _ = rec new_receiver_key, None)
if rec_id == new_receiver_key: return True if receiver_func is not None else False
removed = True
del self._receivers[ind]
break
return removed
async def send(self, sender: Type["Model"], **kwargs: Any) -> None: async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
""" """
@ -93,28 +88,22 @@ class Signal:
:param kwargs: arguments passed to receivers :param kwargs: arguments passed to receivers
:type kwargs: Any :type kwargs: Any
""" """
receivers = [] receivers = [
for receiver in self._receivers: receiver_func(sender=sender, **kwargs)
_, receiver_func = receiver for receiver_func in self._receivers.values()
receivers.append(receiver_func(sender=sender, **kwargs)) ]
await asyncio.gather(*receivers) await asyncio.gather(*receivers)
class SignalEmitter: class SignalEmitter(dict):
""" """
Emitter that registers the signals in internal dictionary. Emitter that registers the signals in internal dictionary.
If signal with given name does not exist it's auto added on access. If signal with given name does not exist it's auto added on access.
""" """
if TYPE_CHECKING: # pragma: no cover
signals: Dict[str, Signal]
def __init__(self) -> None:
object.__setattr__(self, "signals", dict())
def __getattr__(self, item: str) -> Signal: def __getattr__(self, item: str) -> Signal:
return self.signals.setdefault(item, Signal()) return self.setdefault(item, Signal())
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Signal) -> None:
signals = object.__getattribute__(self, "signals") if not isinstance(value, Signal):
signals[key] = value raise SignalDefinitionError(f"{value} is not valid signal")
self[key] = value

View File

@ -7,6 +7,7 @@ import sqlalchemy
import ormar import ormar
from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update
from ormar.signals import SignalEmitter
from ormar.exceptions import SignalDefinitionError from ormar.exceptions import SignalDefinitionError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -77,6 +78,12 @@ def test_passing_callable_without_kwargs():
pass pass
def test_invalid_signal():
emitter = SignalEmitter()
with pytest.raises(SignalDefinitionError):
emitter.save = 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_signal_functions(cleanup): async def test_signal_functions(cleanup):
async with database: async with database: