add signals, register six signals on each models (pre/post + save/update/delete)
This commit is contained in:
@ -3,6 +3,7 @@
|
|||||||
* **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError`
|
* **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError`
|
||||||
* **Breaking:** Model initialization with unknown field name now raises `ModelError` instead of `KeyError`
|
* **Breaking:** Model initialization with unknown field name now raises `ModelError` instead of `KeyError`
|
||||||
*
|
*
|
||||||
|
* Add py.typed and modify setup.py for mypy support
|
||||||
* Performance optimization
|
* Performance optimization
|
||||||
|
|
||||||
# 0.6.2
|
# 0.6.2
|
||||||
|
|||||||
56
ormar/decorators/signals.py
Normal file
56
ormar/decorators/signals.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from typing import Any, Callable, List, TYPE_CHECKING, Type, Union
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
from ormar import Model
|
||||||
|
|
||||||
|
|
||||||
|
def receiver(
|
||||||
|
signal: str, senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||||
|
) -> Callable:
|
||||||
|
def _decorator(func: Callable) -> Callable:
|
||||||
|
if not isinstance(senders, list):
|
||||||
|
_senders = [senders]
|
||||||
|
else:
|
||||||
|
_senders = senders
|
||||||
|
for sender in _senders:
|
||||||
|
signals = getattr(sender.Meta.signals, signal)
|
||||||
|
signals.connect(func, **kwargs)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return _decorator
|
||||||
|
|
||||||
|
|
||||||
|
def post_save(
|
||||||
|
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||||
|
) -> Callable:
|
||||||
|
return receiver(signal="post_save", senders=senders, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def post_update(
|
||||||
|
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||||
|
) -> Callable:
|
||||||
|
return receiver(signal="post_update", senders=senders, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def post_delete(
|
||||||
|
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||||
|
) -> Callable:
|
||||||
|
return receiver(signal="post_delete", senders=senders, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def pre_save(
|
||||||
|
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||||
|
) -> Callable:
|
||||||
|
return receiver(signal="pre_save", senders=senders, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def pre_update(
|
||||||
|
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||||
|
) -> Callable:
|
||||||
|
return receiver(signal="pre_update", senders=senders, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def pre_delete(
|
||||||
|
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
|
||||||
|
) -> Callable:
|
||||||
|
return receiver(signal="pre_delete", senders=senders, **kwargs)
|
||||||
@ -18,6 +18,7 @@ from ormar.fields.many_to_many import ManyToMany, ManyToManyField
|
|||||||
from ormar.models.quick_access_views import quick_access_set
|
from ormar.models.quick_access_views import quick_access_set
|
||||||
from ormar.queryset import QuerySet
|
from ormar.queryset import QuerySet
|
||||||
from ormar.relations.alias_manager import AliasManager
|
from ormar.relations.alias_manager import AliasManager
|
||||||
|
from ormar.signals import Signal, SignalEmitter
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -38,6 +39,7 @@ class ModelMeta:
|
|||||||
]
|
]
|
||||||
alias_manager: AliasManager
|
alias_manager: AliasManager
|
||||||
property_fields: Set
|
property_fields: Set
|
||||||
|
signals: SignalEmitter
|
||||||
|
|
||||||
|
|
||||||
def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None:
|
def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None:
|
||||||
@ -332,15 +334,12 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
|
|||||||
new_model._pydantic_fields = {name for name in new_model.__fields__}
|
new_model._pydantic_fields = {name for name in new_model.__fields__}
|
||||||
|
|
||||||
|
|
||||||
def property_fields_not_set(new_model: Type["Model"]) -> bool:
|
def meta_field_not_set(model: Type["Model"], field_name: str) -> bool:
|
||||||
return (
|
return not hasattr(model.Meta, field_name) or not getattr(model.Meta, field_name)
|
||||||
not hasattr(new_model.Meta, "property_fields")
|
|
||||||
or not new_model.Meta.property_fields
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001
|
def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001
|
||||||
if property_fields_not_set(new_model):
|
if meta_field_not_set(model=new_model, field_name="property_fields"):
|
||||||
props = set()
|
props = set()
|
||||||
for var_name, value in attrs.items():
|
for var_name, value in attrs.items():
|
||||||
if isinstance(value, property):
|
if isinstance(value, property):
|
||||||
@ -351,6 +350,18 @@ def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa:
|
|||||||
new_model.Meta.property_fields = props
|
new_model.Meta.property_fields = props
|
||||||
|
|
||||||
|
|
||||||
|
def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
|
||||||
|
if meta_field_not_set(model=new_model, field_name="signals"):
|
||||||
|
signals = SignalEmitter()
|
||||||
|
signals.pre_save = Signal()
|
||||||
|
signals.pre_update = Signal()
|
||||||
|
signals.pre_delete = Signal()
|
||||||
|
signals.post_save = Signal()
|
||||||
|
signals.post_update = Signal()
|
||||||
|
signals.post_delete = Signal()
|
||||||
|
new_model.Meta.signals = signals
|
||||||
|
|
||||||
|
|
||||||
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||||
def __new__( # type: ignore
|
def __new__( # type: ignore
|
||||||
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
|
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
|
||||||
@ -379,5 +390,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
|||||||
new_model.Meta.alias_manager = alias_manager
|
new_model.Meta.alias_manager = alias_manager
|
||||||
new_model.objects = QuerySet(new_model)
|
new_model.objects = QuerySet(new_model)
|
||||||
add_property_fields(new_model, attrs)
|
add_property_fields(new_model, attrs)
|
||||||
|
register_signals(new_model=new_model)
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|||||||
@ -195,6 +195,9 @@ class Model(NewBaseModel):
|
|||||||
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
||||||
self_fields.pop(self.Meta.pkname, None)
|
self_fields.pop(self.Meta.pkname, None)
|
||||||
self_fields = self.populate_default_values(self_fields)
|
self_fields = self.populate_default_values(self_fields)
|
||||||
|
self.from_dict(self_fields)
|
||||||
|
|
||||||
|
await self.signals.pre_save.send(sender=self.__class__, instance=self)
|
||||||
|
|
||||||
self_fields = self.translate_columns_to_aliases(self_fields)
|
self_fields = self.translate_columns_to_aliases(self_fields)
|
||||||
expr = self.Meta.table.insert()
|
expr = self.Meta.table.insert()
|
||||||
@ -204,6 +207,7 @@ class Model(NewBaseModel):
|
|||||||
if pk and isinstance(pk, self.pk_type()):
|
if pk and isinstance(pk, self.pk_type()):
|
||||||
setattr(self, self.Meta.pkname, pk)
|
setattr(self, self.Meta.pkname, pk)
|
||||||
|
|
||||||
|
self.set_save_status(True)
|
||||||
# refresh server side defaults
|
# refresh server side defaults
|
||||||
if any(
|
if any(
|
||||||
field.server_default is not None
|
field.server_default is not None
|
||||||
@ -211,9 +215,8 @@ class Model(NewBaseModel):
|
|||||||
if name not in self_fields
|
if name not in self_fields
|
||||||
):
|
):
|
||||||
await self.load()
|
await self.load()
|
||||||
return self
|
|
||||||
|
|
||||||
self.set_save_status(True)
|
await self.signals.post_save.send(sender=self.__class__, instance=self)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def save_related( # noqa: CCR001
|
async def save_related( # noqa: CCR001
|
||||||
@ -268,6 +271,7 @@ class Model(NewBaseModel):
|
|||||||
"You cannot update not saved model! Use save or upsert method."
|
"You cannot update not saved model! Use save or upsert method."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self.signals.pre_update.send(sender=self.__class__, instance=self)
|
||||||
self_fields = self._extract_model_db_fields()
|
self_fields = self._extract_model_db_fields()
|
||||||
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
|
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
|
||||||
self_fields = self.translate_columns_to_aliases(self_fields)
|
self_fields = self.translate_columns_to_aliases(self_fields)
|
||||||
@ -276,13 +280,16 @@ class Model(NewBaseModel):
|
|||||||
|
|
||||||
await self.Meta.database.execute(expr)
|
await self.Meta.database.execute(expr)
|
||||||
self.set_save_status(True)
|
self.set_save_status(True)
|
||||||
|
await self.signals.post_update.send(sender=self.__class__, instance=self)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def delete(self: T) -> int:
|
async def delete(self: T) -> int:
|
||||||
|
await self.signals.pre_delete.send(sender=self.__class__, instance=self)
|
||||||
expr = self.Meta.table.delete()
|
expr = self.Meta.table.delete()
|
||||||
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
|
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
|
||||||
result = await self.Meta.database.execute(expr)
|
result = await self.Meta.database.execute(expr)
|
||||||
self.set_save_status(False)
|
self.set_save_status(False)
|
||||||
|
await self.signals.post_delete.send(sender=self.__class__, instance=self)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def load(self: T) -> T:
|
async def load(self: T) -> T:
|
||||||
|
|||||||
@ -35,6 +35,7 @@ from ormar.relations.relation_manager import RelationsManager
|
|||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
|
from ormar.signals import SignalEmitter
|
||||||
|
|
||||||
T = TypeVar("T", bound=Model)
|
T = TypeVar("T", bound=Model)
|
||||||
|
|
||||||
@ -212,6 +213,10 @@ class NewBaseModel(
|
|||||||
def saved(self) -> bool:
|
def saved(self) -> bool:
|
||||||
return self._orm_saved
|
return self._orm_saved
|
||||||
|
|
||||||
|
@property
|
||||||
|
def signals(self) -> "SignalEmitter":
|
||||||
|
return self.Meta.signals
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def pk_type(cls) -> Any:
|
def pk_type(cls) -> Any:
|
||||||
return cls.Meta.model_fields[cls.Meta.pkname].__type__
|
return cls.Meta.model_fields[cls.Meta.pkname].__type__
|
||||||
|
|||||||
0
ormar/py.typed
Normal file
0
ormar/py.typed
Normal file
@ -395,6 +395,9 @@ class QuerySet:
|
|||||||
expr = expr.values(**new_kwargs)
|
expr = expr.values(**new_kwargs)
|
||||||
|
|
||||||
instance = self.model(**kwargs)
|
instance = self.model(**kwargs)
|
||||||
|
await self.model.Meta.signals.pre_save.send(
|
||||||
|
sender=self.model, instance=instance
|
||||||
|
)
|
||||||
pk = await self.database.execute(expr)
|
pk = await self.database.execute(expr)
|
||||||
|
|
||||||
pk_name = self.model.get_column_alias(self.model_meta.pkname)
|
pk_name = self.model.get_column_alias(self.model_meta.pkname)
|
||||||
@ -411,6 +414,9 @@ class QuerySet:
|
|||||||
):
|
):
|
||||||
instance = await instance.load()
|
instance = await instance.load()
|
||||||
instance.set_save_status(True)
|
instance.set_save_status(True)
|
||||||
|
await self.model.Meta.signals.post_save.send(
|
||||||
|
sender=self.model, instance=instance
|
||||||
|
)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
async def bulk_create(self, objects: List["Model"]) -> None:
|
async def bulk_create(self, objects: List["Model"]) -> None:
|
||||||
|
|||||||
@ -0,0 +1,3 @@
|
|||||||
|
from ormar.signals.signal import Signal, SignalEmitter
|
||||||
|
|
||||||
|
__all__ = ["Signal", "SignalEmitter"]
|
||||||
|
|||||||
@ -1,9 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, List, Tuple, Union
|
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union
|
||||||
|
|
||||||
from ormar.exceptions import SignalDefinitionError
|
from ormar.exceptions import SignalDefinitionError
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
from ormar import Model
|
||||||
|
|
||||||
|
|
||||||
def callable_accepts_kwargs(func: Callable) -> bool:
|
def callable_accepts_kwargs(func: Callable) -> bool:
|
||||||
return any(
|
return any(
|
||||||
@ -45,9 +48,24 @@ class Signal:
|
|||||||
break
|
break
|
||||||
return removed
|
return removed
|
||||||
|
|
||||||
async def send(self, sender: Any, **kwargs: Any) -> None:
|
async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
|
||||||
receivers = []
|
receivers = []
|
||||||
for receiver in self._receivers:
|
for receiver in self._receivers:
|
||||||
_, receiver_func = receiver
|
_, receiver_func = receiver
|
||||||
receivers.append(receiver_func(sender, **kwargs))
|
receivers.append(receiver_func(sender=sender, **kwargs))
|
||||||
await asyncio.gather(*receivers)
|
await asyncio.gather(*receivers)
|
||||||
|
|
||||||
|
|
||||||
|
class SignalEmitter:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
signals: Dict[str, Signal]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
object.__setattr__(self, "signals", dict())
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Signal:
|
||||||
|
return self.signals.setdefault(item, Signal())
|
||||||
|
|
||||||
|
def __setattr__(self, key: str, value: Any) -> None:
|
||||||
|
signals = object.__getattribute__(self, "signals")
|
||||||
|
signals[key] = value
|
||||||
|
|||||||
@ -1,19 +0,0 @@
|
|||||||
#!/bin/sh -e
|
|
||||||
PACKAGE="ormar"
|
|
||||||
if [ -d 'dist' ] ; then
|
|
||||||
rm -r dist
|
|
||||||
fi
|
|
||||||
if [ -d 'site' ] ; then
|
|
||||||
rm -r site
|
|
||||||
fi
|
|
||||||
if [ -d 'htmlcov' ] ; then
|
|
||||||
rm -r htmlcov
|
|
||||||
fi
|
|
||||||
if [ -d "${PACKAGE}.egg-info" ] ; then
|
|
||||||
rm -r "${PACKAGE}.egg-info"
|
|
||||||
fi
|
|
||||||
find ${PACKAGE} -type f -name "*.py[co]" -delete
|
|
||||||
find ${PACKAGE} -type d -name __pycache__ -delete
|
|
||||||
|
|
||||||
find tests -type f -name "*.py[co]" -delete
|
|
||||||
find tests -type d -name __pycache__ -delete
|
|
||||||
@ -1,23 +0,0 @@
|
|||||||
#!/bin/sh -e
|
|
||||||
|
|
||||||
PACKAGE="ormar"
|
|
||||||
|
|
||||||
PREFIX=""
|
|
||||||
if [ -d 'venv' ] ; then
|
|
||||||
PREFIX="venv/bin/"
|
|
||||||
fi
|
|
||||||
|
|
||||||
VERSION=`cat ${PACKAGE}/__init__.py | grep __version__ | sed "s/__version__ = //" | sed "s/'//g"`
|
|
||||||
|
|
||||||
set -x
|
|
||||||
|
|
||||||
scripts/clean.sh
|
|
||||||
|
|
||||||
${PREFIX}python setup.py sdist
|
|
||||||
${PREFIX}twine upload dist/*
|
|
||||||
|
|
||||||
echo "You probably want to also tag the version now:"
|
|
||||||
echo "git tag -a ${VERSION} -m 'version ${VERSION}'"
|
|
||||||
echo "git push --tags"
|
|
||||||
|
|
||||||
scripts/clean.sh
|
|
||||||
4
setup.py
4
setup.py
@ -50,6 +50,8 @@ setup(
|
|||||||
author_email="collerek@gmail.com",
|
author_email="collerek@gmail.com",
|
||||||
packages=get_packages(PACKAGE),
|
packages=get_packages(PACKAGE),
|
||||||
package_data={PACKAGE: ["py.typed"]},
|
package_data={PACKAGE: ["py.typed"]},
|
||||||
|
include_package_data=True,
|
||||||
|
zip_safe=False,
|
||||||
data_files=[("", ["LICENSE.md"])],
|
data_files=[("", ["LICENSE.md"])],
|
||||||
install_requires=["databases", "pydantic>=1.5", "sqlalchemy", "typing_extensions"],
|
install_requires=["databases", "pydantic>=1.5", "sqlalchemy", "typing_extensions"],
|
||||||
extras_require={
|
extras_require={
|
||||||
@ -65,9 +67,11 @@ setup(
|
|||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Topic :: Internet :: WWW/HTTP",
|
"Topic :: Internet :: WWW/HTTP",
|
||||||
|
"Framework :: AsyncIO",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.6",
|
"Programming Language :: Python :: 3.6",
|
||||||
"Programming Language :: Python :: 3.7",
|
"Programming Language :: Python :: 3.7",
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3 :: Only",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
245
tests/test_signals.py
Normal file
245
tests/test_signals.py
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
import databases
|
||||||
|
import pydantic
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
import ormar
|
||||||
|
from ormar.decorators.signals import (
|
||||||
|
post_delete,
|
||||||
|
post_save,
|
||||||
|
post_update,
|
||||||
|
pre_delete,
|
||||||
|
pre_save,
|
||||||
|
pre_update,
|
||||||
|
)
|
||||||
|
from ormar.exceptions import SignalDefinitionError
|
||||||
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLog(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "audits"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
event_type: str = ormar.String(max_length=100)
|
||||||
|
event_log: pydantic.Json = ormar.JSON()
|
||||||
|
|
||||||
|
|
||||||
|
class Album(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "albums"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
name: str = ormar.String(max_length=100)
|
||||||
|
is_best_seller: bool = ormar.Boolean(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
|
def create_test_database():
|
||||||
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
metadata.create_all(engine)
|
||||||
|
yield
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
def test_passing_not_callable():
|
||||||
|
with pytest.raises(SignalDefinitionError):
|
||||||
|
pre_save(Album)("wrong")
|
||||||
|
|
||||||
|
|
||||||
|
def test_passing_callable_without_kwargs():
|
||||||
|
with pytest.raises(SignalDefinitionError):
|
||||||
|
|
||||||
|
@pre_save(Album)
|
||||||
|
def trigger(sender, instance): # pragma: no cover
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_signal_functions():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
|
||||||
|
@pre_save(Album)
|
||||||
|
async def before_save(sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"PRE_SAVE_{sender.get_name()}",
|
||||||
|
event_log=instance.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
|
@post_save(Album)
|
||||||
|
async def after_save(sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"POST_SAVE_{sender.get_name()}",
|
||||||
|
event_log=instance.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
|
@pre_update(Album)
|
||||||
|
async def before_update(sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"PRE_UPDATE_{sender.get_name()}",
|
||||||
|
event_log=instance.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
|
@post_update(Album)
|
||||||
|
async def after_update(sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"POST_UPDATE_{sender.get_name()}",
|
||||||
|
event_log=instance.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
|
@pre_delete(Album)
|
||||||
|
async def before_delete(sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"PRE_DELETE_{sender.get_name()}",
|
||||||
|
event_log=instance.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
|
@post_delete(Album)
|
||||||
|
async def after_delete(sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"POST_DELETE_{sender.get_name()}",
|
||||||
|
event_log=instance.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
|
album = await Album.objects.create(name="Venice")
|
||||||
|
|
||||||
|
audits = await AuditLog.objects.all()
|
||||||
|
assert len(audits) == 2
|
||||||
|
assert audits[0].event_type == "PRE_SAVE_album"
|
||||||
|
assert audits[0].event_log.get("name") == album.name
|
||||||
|
assert audits[1].event_type == "POST_SAVE_album"
|
||||||
|
assert audits[1].event_log.get("id") == album.pk
|
||||||
|
|
||||||
|
album = await Album(name="Rome").save()
|
||||||
|
audits = await AuditLog.objects.all()
|
||||||
|
assert len(audits) == 4
|
||||||
|
assert audits[2].event_type == "PRE_SAVE_album"
|
||||||
|
assert audits[2].event_log.get("name") == album.name
|
||||||
|
assert audits[3].event_type == "POST_SAVE_album"
|
||||||
|
assert audits[3].event_log.get("id") == album.pk
|
||||||
|
|
||||||
|
album.is_best_seller = True
|
||||||
|
await album.update()
|
||||||
|
|
||||||
|
audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all()
|
||||||
|
assert len(audits) == 2
|
||||||
|
assert audits[0].event_type == "PRE_UPDATE_album"
|
||||||
|
assert audits[0].event_log.get("name") == album.name
|
||||||
|
assert audits[1].event_type == "POST_UPDATE_album"
|
||||||
|
assert audits[1].event_log.get("is_best_seller") == album.is_best_seller
|
||||||
|
|
||||||
|
album.signals.pre_update.disconnect(before_update)
|
||||||
|
album.signals.post_update.disconnect(after_update)
|
||||||
|
|
||||||
|
album.is_best_seller = False
|
||||||
|
await album.update()
|
||||||
|
|
||||||
|
audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all()
|
||||||
|
assert len(audits) == 2
|
||||||
|
|
||||||
|
await album.delete()
|
||||||
|
audits = await AuditLog.objects.filter(event_type__contains="DELETE").all()
|
||||||
|
assert len(audits) == 2
|
||||||
|
assert audits[0].event_type == "PRE_DELETE_album"
|
||||||
|
assert (
|
||||||
|
audits[0].event_log.get("id")
|
||||||
|
== audits[1].event_log.get("id")
|
||||||
|
== album.id
|
||||||
|
)
|
||||||
|
assert audits[1].event_type == "POST_DELETE_album"
|
||||||
|
|
||||||
|
album.signals.pre_delete.disconnect(before_delete)
|
||||||
|
album.signals.post_delete.disconnect(after_delete)
|
||||||
|
album.signals.pre_save.disconnect(before_save)
|
||||||
|
album.signals.post_save.disconnect(after_save)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_signals():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
|
||||||
|
@pre_save(Album)
|
||||||
|
async def before_save(sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"PRE_SAVE_{sender.get_name()}",
|
||||||
|
event_log=instance.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
|
@pre_save(Album)
|
||||||
|
async def before_save2(sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"PRE_SAVE2_{sender.get_name()}",
|
||||||
|
event_log=instance.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
|
album = await Album.objects.create(name="Miami")
|
||||||
|
audits = await AuditLog.objects.all()
|
||||||
|
assert len(audits) == 2
|
||||||
|
assert audits[0].event_type == "PRE_SAVE_album"
|
||||||
|
assert audits[0].event_log.get("name") == album.name
|
||||||
|
assert audits[1].event_type == "PRE_SAVE2_album"
|
||||||
|
assert audits[1].event_log.get("name") == album.name
|
||||||
|
|
||||||
|
album.signals.pre_save.disconnect(before_save)
|
||||||
|
album.signals.pre_save.disconnect(before_save2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_static_methods_as_signals():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
|
||||||
|
class AlbumAuditor:
|
||||||
|
event_type = "ALBUM_INSTANCE"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@pre_save(Album)
|
||||||
|
async def before_save(sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"{AlbumAuditor.event_type}_SAVE",
|
||||||
|
event_log=instance.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
|
album = await Album.objects.create(name="Colorado")
|
||||||
|
audits = await AuditLog.objects.all()
|
||||||
|
assert len(audits) == 1
|
||||||
|
assert audits[0].event_type == "ALBUM_INSTANCE_SAVE"
|
||||||
|
assert audits[0].event_log.get("name") == album.name
|
||||||
|
|
||||||
|
album.signals.pre_save.disconnect(AlbumAuditor.before_save)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_methods_as_signals():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
|
||||||
|
class AlbumAuditor:
|
||||||
|
def __init__(self):
|
||||||
|
self.event_type = "ALBUM_INSTANCE"
|
||||||
|
|
||||||
|
async def before_save(self, sender, instance, **kwargs):
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"{self.event_type}_SAVE", event_log=instance.json()
|
||||||
|
).save()
|
||||||
|
|
||||||
|
auditor = AlbumAuditor()
|
||||||
|
pre_save(Album)(auditor.before_save)
|
||||||
|
|
||||||
|
album = await Album.objects.create(name="San Francisco")
|
||||||
|
audits = await AuditLog.objects.all()
|
||||||
|
assert len(audits) == 1
|
||||||
|
assert audits[0].event_type == "ALBUM_INSTANCE_SAVE"
|
||||||
|
assert audits[0].event_log.get("name") == album.name
|
||||||
|
|
||||||
|
album.signals.pre_save.disconnect(auditor.before_save)
|
||||||
Reference in New Issue
Block a user