Merge pull request #493 from ponytailer/master
use the dict to store the receivers in signal
This commit is contained in:
@ -1,4 +1,4 @@
|
|||||||
from typing import Callable, List, TYPE_CHECKING, Type, Union
|
from typing import Callable, List, Type, TYPE_CHECKING, Union
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Tuple, Type, TYPE_CHECKING, Union
|
||||||
|
|
||||||
from ormar.exceptions import SignalDefinitionError
|
from ormar.exceptions import SignalDefinitionError
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ class Signal:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._receivers: List[Tuple[Union[int, Tuple[int, int]], Callable]] = []
|
self._receivers: Dict[Union[int, Tuple[int, int]], Callable] = {}
|
||||||
|
|
||||||
def connect(self, receiver: Callable) -> None:
|
def connect(self, receiver: Callable) -> None:
|
||||||
"""
|
"""
|
||||||
@ -63,8 +63,8 @@ class Signal:
|
|||||||
"Signal receivers must accept **kwargs argument."
|
"Signal receivers must accept **kwargs argument."
|
||||||
)
|
)
|
||||||
new_receiver_key = make_id(receiver)
|
new_receiver_key = make_id(receiver)
|
||||||
if not any(rec_id == new_receiver_key for rec_id, _ in self._receivers):
|
if new_receiver_key not in self._receivers:
|
||||||
self._receivers.append((new_receiver_key, receiver))
|
self._receivers[new_receiver_key] = receiver
|
||||||
|
|
||||||
def disconnect(self, receiver: Callable) -> bool:
|
def disconnect(self, receiver: Callable) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -75,15 +75,10 @@ class Signal:
|
|||||||
:return: flag if receiver was removed
|
:return: flag if receiver was removed
|
||||||
:rtype: bool
|
:rtype: bool
|
||||||
"""
|
"""
|
||||||
removed = False
|
|
||||||
new_receiver_key = make_id(receiver)
|
new_receiver_key = make_id(receiver)
|
||||||
for ind, rec in enumerate(self._receivers):
|
receiver_func: Union[Callable, None] = self._receivers.pop(
|
||||||
rec_id, _ = rec
|
new_receiver_key, None)
|
||||||
if rec_id == new_receiver_key:
|
return True if receiver_func is not None else False
|
||||||
removed = True
|
|
||||||
del self._receivers[ind]
|
|
||||||
break
|
|
||||||
return removed
|
|
||||||
|
|
||||||
async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
|
async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
@ -93,28 +88,22 @@ class Signal:
|
|||||||
:param kwargs: arguments passed to receivers
|
:param kwargs: arguments passed to receivers
|
||||||
:type kwargs: Any
|
:type kwargs: Any
|
||||||
"""
|
"""
|
||||||
receivers = []
|
receivers = [
|
||||||
for receiver in self._receivers:
|
receiver_func(sender=sender, **kwargs)
|
||||||
_, receiver_func = receiver
|
for receiver_func in self._receivers.values()
|
||||||
receivers.append(receiver_func(sender=sender, **kwargs))
|
]
|
||||||
await asyncio.gather(*receivers)
|
await asyncio.gather(*receivers)
|
||||||
|
|
||||||
|
|
||||||
class SignalEmitter:
|
class SignalEmitter(dict):
|
||||||
"""
|
"""
|
||||||
Emitter that registers the signals in internal dictionary.
|
Emitter that registers the signals in internal dictionary.
|
||||||
If signal with given name does not exist it's auto added on access.
|
If signal with given name does not exist it's auto added on access.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
|
||||||
signals: Dict[str, Signal]
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
object.__setattr__(self, "signals", dict())
|
|
||||||
|
|
||||||
def __getattr__(self, item: str) -> Signal:
|
def __getattr__(self, item: str) -> Signal:
|
||||||
return self.signals.setdefault(item, Signal())
|
return self.setdefault(item, Signal())
|
||||||
|
|
||||||
def __setattr__(self, key: str, value: Any) -> None:
|
def __setattr__(self, key: str, value: Signal) -> None:
|
||||||
signals = object.__getattribute__(self, "signals")
|
if not isinstance(value, Signal):
|
||||||
signals[key] = value
|
raise SignalDefinitionError(f"{value} is not valid signal")
|
||||||
|
self[key] = value
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import sqlalchemy
|
|||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update
|
from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update
|
||||||
|
from ormar.signals import SignalEmitter
|
||||||
from ormar.exceptions import SignalDefinitionError
|
from ormar.exceptions import SignalDefinitionError
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
@ -77,6 +78,12 @@ def test_passing_callable_without_kwargs():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_signal():
|
||||||
|
emitter = SignalEmitter()
|
||||||
|
with pytest.raises(SignalDefinitionError):
|
||||||
|
emitter.save = 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_signal_functions(cleanup):
|
async def test_signal_functions(cleanup):
|
||||||
async with database:
|
async with database:
|
||||||
|
|||||||
Reference in New Issue
Block a user