add signals, register six signals on each models (pre/post + save/update/delete)
This commit is contained in:
56
ormar/decorators/signals.py
Normal file
56
ormar/decorators/signals.py
Normal file
@ -0,0 +1,56 @@
|
||||
from typing import Any, Callable, List, TYPE_CHECKING, Type, Union
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ormar import Model
|
||||
|
||||
|
||||
def receiver(
|
||||
signal: str, senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||
) -> Callable:
|
||||
def _decorator(func: Callable) -> Callable:
|
||||
if not isinstance(senders, list):
|
||||
_senders = [senders]
|
||||
else:
|
||||
_senders = senders
|
||||
for sender in _senders:
|
||||
signals = getattr(sender.Meta.signals, signal)
|
||||
signals.connect(func, **kwargs)
|
||||
return func
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
def post_save(
|
||||
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||
) -> Callable:
|
||||
return receiver(signal="post_save", senders=senders, **kwargs)
|
||||
|
||||
|
||||
def post_update(
|
||||
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||
) -> Callable:
|
||||
return receiver(signal="post_update", senders=senders, **kwargs)
|
||||
|
||||
|
||||
def post_delete(
|
||||
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||
) -> Callable:
|
||||
return receiver(signal="post_delete", senders=senders, **kwargs)
|
||||
|
||||
|
||||
def pre_save(
|
||||
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||
) -> Callable:
|
||||
return receiver(signal="pre_save", senders=senders, **kwargs)
|
||||
|
||||
|
||||
def pre_update(
|
||||
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||
) -> Callable:
|
||||
return receiver(signal="pre_update", senders=senders, **kwargs)
|
||||
|
||||
|
||||
def pre_delete(
|
||||
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||
) -> Callable:
|
||||
return receiver(signal="pre_delete", senders=senders, **kwargs)
|
||||
@ -18,6 +18,7 @@ from ormar.fields.many_to_many import ManyToMany, ManyToManyField
|
||||
from ormar.models.quick_access_views import quick_access_set
|
||||
from ormar.queryset import QuerySet
|
||||
from ormar.relations.alias_manager import AliasManager
|
||||
from ormar.signals import Signal, SignalEmitter
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
@ -38,6 +39,7 @@ class ModelMeta:
|
||||
]
|
||||
alias_manager: AliasManager
|
||||
property_fields: Set
|
||||
signals: SignalEmitter
|
||||
|
||||
|
||||
def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None:
|
||||
@ -332,15 +334,12 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
|
||||
new_model._pydantic_fields = {name for name in new_model.__fields__}
|
||||
|
||||
|
||||
def property_fields_not_set(new_model: Type["Model"]) -> bool:
|
||||
return (
|
||||
not hasattr(new_model.Meta, "property_fields")
|
||||
or not new_model.Meta.property_fields
|
||||
)
|
||||
def meta_field_not_set(model: Type["Model"], field_name: str) -> bool:
|
||||
return not hasattr(model.Meta, field_name) or not getattr(model.Meta, field_name)
|
||||
|
||||
|
||||
def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001
|
||||
if property_fields_not_set(new_model):
|
||||
if meta_field_not_set(model=new_model, field_name="property_fields"):
|
||||
props = set()
|
||||
for var_name, value in attrs.items():
|
||||
if isinstance(value, property):
|
||||
@ -351,6 +350,18 @@ def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa:
|
||||
new_model.Meta.property_fields = props
|
||||
|
||||
|
||||
def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
|
||||
if meta_field_not_set(model=new_model, field_name="signals"):
|
||||
signals = SignalEmitter()
|
||||
signals.pre_save = Signal()
|
||||
signals.pre_update = Signal()
|
||||
signals.pre_delete = Signal()
|
||||
signals.post_save = Signal()
|
||||
signals.post_update = Signal()
|
||||
signals.post_delete = Signal()
|
||||
new_model.Meta.signals = signals
|
||||
|
||||
|
||||
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
def __new__( # type: ignore
|
||||
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
|
||||
@ -379,5 +390,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
new_model.Meta.alias_manager = alias_manager
|
||||
new_model.objects = QuerySet(new_model)
|
||||
add_property_fields(new_model, attrs)
|
||||
register_signals(new_model=new_model)
|
||||
|
||||
return new_model
|
||||
|
||||
@ -195,6 +195,9 @@ class Model(NewBaseModel):
|
||||
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
||||
self_fields.pop(self.Meta.pkname, None)
|
||||
self_fields = self.populate_default_values(self_fields)
|
||||
self.from_dict(self_fields)
|
||||
|
||||
await self.signals.pre_save.send(sender=self.__class__, instance=self)
|
||||
|
||||
self_fields = self.translate_columns_to_aliases(self_fields)
|
||||
expr = self.Meta.table.insert()
|
||||
@ -204,6 +207,7 @@ class Model(NewBaseModel):
|
||||
if pk and isinstance(pk, self.pk_type()):
|
||||
setattr(self, self.Meta.pkname, pk)
|
||||
|
||||
self.set_save_status(True)
|
||||
# refresh server side defaults
|
||||
if any(
|
||||
field.server_default is not None
|
||||
@ -211,9 +215,8 @@ class Model(NewBaseModel):
|
||||
if name not in self_fields
|
||||
):
|
||||
await self.load()
|
||||
return self
|
||||
|
||||
self.set_save_status(True)
|
||||
await self.signals.post_save.send(sender=self.__class__, instance=self)
|
||||
return self
|
||||
|
||||
async def save_related( # noqa: CCR001
|
||||
@ -268,6 +271,7 @@ class Model(NewBaseModel):
|
||||
"You cannot update not saved model! Use save or upsert method."
|
||||
)
|
||||
|
||||
await self.signals.pre_update.send(sender=self.__class__, instance=self)
|
||||
self_fields = self._extract_model_db_fields()
|
||||
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
|
||||
self_fields = self.translate_columns_to_aliases(self_fields)
|
||||
@ -276,13 +280,16 @@ class Model(NewBaseModel):
|
||||
|
||||
await self.Meta.database.execute(expr)
|
||||
self.set_save_status(True)
|
||||
await self.signals.post_update.send(sender=self.__class__, instance=self)
|
||||
return self
|
||||
|
||||
async def delete(self: T) -> int:
|
||||
await self.signals.pre_delete.send(sender=self.__class__, instance=self)
|
||||
expr = self.Meta.table.delete()
|
||||
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
|
||||
result = await self.Meta.database.execute(expr)
|
||||
self.set_save_status(False)
|
||||
await self.signals.post_delete.send(sender=self.__class__, instance=self)
|
||||
return result
|
||||
|
||||
async def load(self: T) -> T:
|
||||
|
||||
@ -35,6 +35,7 @@ from ormar.relations.relation_manager import RelationsManager
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
from ormar.signals import SignalEmitter
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
@ -212,6 +213,10 @@ class NewBaseModel(
|
||||
def saved(self) -> bool:
|
||||
return self._orm_saved
|
||||
|
||||
@property
|
||||
def signals(self) -> "SignalEmitter":
|
||||
return self.Meta.signals
|
||||
|
||||
@classmethod
|
||||
def pk_type(cls) -> Any:
|
||||
return cls.Meta.model_fields[cls.Meta.pkname].__type__
|
||||
|
||||
0
ormar/py.typed
Normal file
0
ormar/py.typed
Normal file
@ -395,6 +395,9 @@ class QuerySet:
|
||||
expr = expr.values(**new_kwargs)
|
||||
|
||||
instance = self.model(**kwargs)
|
||||
await self.model.Meta.signals.pre_save.send(
|
||||
sender=self.model, instance=instance
|
||||
)
|
||||
pk = await self.database.execute(expr)
|
||||
|
||||
pk_name = self.model.get_column_alias(self.model_meta.pkname)
|
||||
@ -411,6 +414,9 @@ class QuerySet:
|
||||
):
|
||||
instance = await instance.load()
|
||||
instance.set_save_status(True)
|
||||
await self.model.Meta.signals.post_save.send(
|
||||
sender=self.model, instance=instance
|
||||
)
|
||||
return instance
|
||||
|
||||
async def bulk_create(self, objects: List["Model"]) -> None:
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
from ormar.signals.signal import Signal, SignalEmitter
|
||||
|
||||
__all__ = ["Signal", "SignalEmitter"]
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union
|
||||
|
||||
from ormar.exceptions import SignalDefinitionError
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ormar import Model
|
||||
|
||||
|
||||
def callable_accepts_kwargs(func: Callable) -> bool:
|
||||
return any(
|
||||
@ -45,9 +48,24 @@ class Signal:
|
||||
break
|
||||
return removed
|
||||
|
||||
async def send(self, sender: Any, **kwargs: Any) -> None:
|
||||
async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
|
||||
receivers = []
|
||||
for receiver in self._receivers:
|
||||
_, receiver_func = receiver
|
||||
receivers.append(receiver_func(sender, **kwargs))
|
||||
receivers.append(receiver_func(sender=sender, **kwargs))
|
||||
await asyncio.gather(*receivers)
|
||||
|
||||
|
||||
class SignalEmitter:
|
||||
if TYPE_CHECKING:
|
||||
signals: Dict[str, Signal]
|
||||
|
||||
def __init__(self) -> None:
|
||||
object.__setattr__(self, "signals", dict())
|
||||
|
||||
def __getattr__(self, item: str) -> Signal:
|
||||
return self.signals.setdefault(item, Signal())
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
signals = object.__getattribute__(self, "signals")
|
||||
signals[key] = value
|
||||
|
||||
Reference in New Issue
Block a user