add signals, register six signals on each models (pre/post + save/update/delete)

This commit is contained in:
collerek
2020-12-06 17:23:46 +01:00
parent 2bbfd05017
commit 85be9e8b80
13 changed files with 368 additions and 53 deletions

View File

@ -3,6 +3,7 @@
* **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError`
* **Breaking:** Model initialization with unknown field name now raises `ModelError` instead of `KeyError`
*
* Add py.typed and modify setup.py for mypy support
* Performance optimization
# 0.6.2

View File

@ -0,0 +1,56 @@
from typing import Any, Callable, List, TYPE_CHECKING, Type, Union
if TYPE_CHECKING: # pragma: no cover
from ormar import Model
def receiver(
signal: str, senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
) -> Callable:
def _decorator(func: Callable) -> Callable:
if not isinstance(senders, list):
_senders = [senders]
else:
_senders = senders
for sender in _senders:
signals = getattr(sender.Meta.signals, signal)
signals.connect(func, **kwargs)
return func
return _decorator
def post_save(
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
) -> Callable:
return receiver(signal="post_save", senders=senders, **kwargs)
def post_update(
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
) -> Callable:
return receiver(signal="post_update", senders=senders, **kwargs)
def post_delete(
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
) -> Callable:
return receiver(signal="post_delete", senders=senders, **kwargs)
def pre_save(
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
) -> Callable:
return receiver(signal="pre_save", senders=senders, **kwargs)
def pre_update(
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
) -> Callable:
return receiver(signal="pre_update", senders=senders, **kwargs)
def pre_delete(
senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any
) -> Callable:
return receiver(signal="pre_delete", senders=senders, **kwargs)

View File

@ -18,6 +18,7 @@ from ormar.fields.many_to_many import ManyToMany, ManyToManyField
from ormar.models.quick_access_views import quick_access_set
from ormar.queryset import QuerySet
from ormar.relations.alias_manager import AliasManager
from ormar.signals import Signal, SignalEmitter
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -38,6 +39,7 @@ class ModelMeta:
]
alias_manager: AliasManager
property_fields: Set
signals: SignalEmitter
def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None:
@ -332,15 +334,12 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
new_model._pydantic_fields = {name for name in new_model.__fields__}
def property_fields_not_set(new_model: Type["Model"]) -> bool:
return (
not hasattr(new_model.Meta, "property_fields")
or not new_model.Meta.property_fields
)
def meta_field_not_set(model: Type["Model"], field_name: str) -> bool:
return not hasattr(model.Meta, field_name) or not getattr(model.Meta, field_name)
def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001
if property_fields_not_set(new_model):
if meta_field_not_set(model=new_model, field_name="property_fields"):
props = set()
for var_name, value in attrs.items():
if isinstance(value, property):
@ -351,6 +350,18 @@ def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa:
new_model.Meta.property_fields = props
def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
if meta_field_not_set(model=new_model, field_name="signals"):
signals = SignalEmitter()
signals.pre_save = Signal()
signals.pre_update = Signal()
signals.pre_delete = Signal()
signals.post_save = Signal()
signals.post_update = Signal()
signals.post_delete = Signal()
new_model.Meta.signals = signals
class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__( # type: ignore
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
@ -379,5 +390,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
new_model.Meta.alias_manager = alias_manager
new_model.objects = QuerySet(new_model)
add_property_fields(new_model, attrs)
register_signals(new_model=new_model)
return new_model

View File

@ -195,6 +195,9 @@ class Model(NewBaseModel):
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
self_fields.pop(self.Meta.pkname, None)
self_fields = self.populate_default_values(self_fields)
self.from_dict(self_fields)
await self.signals.pre_save.send(sender=self.__class__, instance=self)
self_fields = self.translate_columns_to_aliases(self_fields)
expr = self.Meta.table.insert()
@ -204,6 +207,7 @@ class Model(NewBaseModel):
if pk and isinstance(pk, self.pk_type()):
setattr(self, self.Meta.pkname, pk)
self.set_save_status(True)
# refresh server side defaults
if any(
field.server_default is not None
@ -211,9 +215,8 @@ class Model(NewBaseModel):
if name not in self_fields
):
await self.load()
return self
self.set_save_status(True)
await self.signals.post_save.send(sender=self.__class__, instance=self)
return self
async def save_related( # noqa: CCR001
@ -268,6 +271,7 @@ class Model(NewBaseModel):
"You cannot update not saved model! Use save or upsert method."
)
await self.signals.pre_update.send(sender=self.__class__, instance=self)
self_fields = self._extract_model_db_fields()
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
self_fields = self.translate_columns_to_aliases(self_fields)
@ -276,13 +280,16 @@ class Model(NewBaseModel):
await self.Meta.database.execute(expr)
self.set_save_status(True)
await self.signals.post_update.send(sender=self.__class__, instance=self)
return self
async def delete(self: T) -> int:
await self.signals.pre_delete.send(sender=self.__class__, instance=self)
expr = self.Meta.table.delete()
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
result = await self.Meta.database.execute(expr)
self.set_save_status(False)
await self.signals.post_delete.send(sender=self.__class__, instance=self)
return result
async def load(self: T) -> T:

View File

@ -35,6 +35,7 @@ from ormar.relations.relation_manager import RelationsManager
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.signals import SignalEmitter
T = TypeVar("T", bound=Model)
@ -212,6 +213,10 @@ class NewBaseModel(
def saved(self) -> bool:
return self._orm_saved
@property
def signals(self) -> "SignalEmitter":
return self.Meta.signals
@classmethod
def pk_type(cls) -> Any:
return cls.Meta.model_fields[cls.Meta.pkname].__type__

0
ormar/py.typed Normal file
View File

View File

@ -395,6 +395,9 @@ class QuerySet:
expr = expr.values(**new_kwargs)
instance = self.model(**kwargs)
await self.model.Meta.signals.pre_save.send(
sender=self.model, instance=instance
)
pk = await self.database.execute(expr)
pk_name = self.model.get_column_alias(self.model_meta.pkname)
@ -411,6 +414,9 @@ class QuerySet:
):
instance = await instance.load()
instance.set_save_status(True)
await self.model.Meta.signals.post_save.send(
sender=self.model, instance=instance
)
return instance
async def bulk_create(self, objects: List["Model"]) -> None:

View File

@ -0,0 +1,3 @@
from ormar.signals.signal import Signal, SignalEmitter
__all__ = ["Signal", "SignalEmitter"]

View File

@ -1,9 +1,12 @@
import asyncio
import inspect
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union
from ormar.exceptions import SignalDefinitionError
if TYPE_CHECKING: # pragma: no cover
from ormar import Model
def callable_accepts_kwargs(func: Callable) -> bool:
return any(
@ -45,9 +48,24 @@ class Signal:
break
return removed
async def send(self, sender: Any, **kwargs: Any) -> None:
async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
receivers = []
for receiver in self._receivers:
_, receiver_func = receiver
receivers.append(receiver_func(sender, **kwargs))
receivers.append(receiver_func(sender=sender, **kwargs))
await asyncio.gather(*receivers)
class SignalEmitter:
if TYPE_CHECKING:
signals: Dict[str, Signal]
def __init__(self) -> None:
object.__setattr__(self, "signals", dict())
def __getattr__(self, item: str) -> Signal:
return self.signals.setdefault(item, Signal())
def __setattr__(self, key: str, value: Any) -> None:
signals = object.__getattribute__(self, "signals")
signals[key] = value

View File

@ -1,19 +0,0 @@
#!/bin/sh -e
PACKAGE="ormar"
if [ -d 'dist' ] ; then
rm -r dist
fi
if [ -d 'site' ] ; then
rm -r site
fi
if [ -d 'htmlcov' ] ; then
rm -r htmlcov
fi
if [ -d "${PACKAGE}.egg-info" ] ; then
rm -r "${PACKAGE}.egg-info"
fi
find ${PACKAGE} -type f -name "*.py[co]" -delete
find ${PACKAGE} -type d -name __pycache__ -delete
find tests -type f -name "*.py[co]" -delete
find tests -type d -name __pycache__ -delete

View File

@ -1,23 +0,0 @@
#!/bin/sh -e
PACKAGE="ormar"
PREFIX=""
if [ -d 'venv' ] ; then
PREFIX="venv/bin/"
fi
VERSION=`cat ${PACKAGE}/__init__.py | grep __version__ | sed "s/__version__ = //" | sed "s/'//g"`
set -x
scripts/clean.sh
${PREFIX}python setup.py sdist
${PREFIX}twine upload dist/*
echo "You probably want to also tag the version now:"
echo "git tag -a ${VERSION} -m 'version ${VERSION}'"
echo "git push --tags"
scripts/clean.sh

View File

@ -50,6 +50,8 @@ setup(
author_email="collerek@gmail.com",
packages=get_packages(PACKAGE),
package_data={PACKAGE: ["py.typed"]},
include_package_data=True,
zip_safe=False,
data_files=[("", ["LICENSE.md"])],
install_requires=["databases", "pydantic>=1.5", "sqlalchemy", "typing_extensions"],
extras_require={
@ -65,9 +67,11 @@ setup(
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Topic :: Internet :: WWW/HTTP",
"Framework :: AsyncIO",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3 :: Only",
],
)

245
tests/test_signals.py Normal file
View File

@ -0,0 +1,245 @@
import databases
import pydantic
import pytest
import sqlalchemy
import ormar
from ormar.decorators.signals import (
post_delete,
post_save,
post_update,
pre_delete,
pre_save,
pre_update,
)
from ormar.exceptions import SignalDefinitionError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class AuditLog(ormar.Model):
class Meta:
tablename = "audits"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
event_type: str = ormar.String(max_length=100)
event_log: pydantic.Json = ormar.JSON()
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
is_best_seller: bool = ormar.Boolean(default=False)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_passing_not_callable():
with pytest.raises(SignalDefinitionError):
pre_save(Album)("wrong")
def test_passing_callable_without_kwargs():
with pytest.raises(SignalDefinitionError):
@pre_save(Album)
def trigger(sender, instance): # pragma: no cover
pass
@pytest.mark.asyncio
async def test_signal_functions():
async with database:
async with database.transaction(force_rollback=True):
@pre_save(Album)
async def before_save(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_SAVE_{sender.get_name()}",
event_log=instance.json(),
).save()
@post_save(Album)
async def after_save(sender, instance, **kwargs):
await AuditLog(
event_type=f"POST_SAVE_{sender.get_name()}",
event_log=instance.json(),
).save()
@pre_update(Album)
async def before_update(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_UPDATE_{sender.get_name()}",
event_log=instance.json(),
).save()
@post_update(Album)
async def after_update(sender, instance, **kwargs):
await AuditLog(
event_type=f"POST_UPDATE_{sender.get_name()}",
event_log=instance.json(),
).save()
@pre_delete(Album)
async def before_delete(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_DELETE_{sender.get_name()}",
event_log=instance.json(),
).save()
@post_delete(Album)
async def after_delete(sender, instance, **kwargs):
await AuditLog(
event_type=f"POST_DELETE_{sender.get_name()}",
event_log=instance.json(),
).save()
album = await Album.objects.create(name="Venice")
audits = await AuditLog.objects.all()
assert len(audits) == 2
assert audits[0].event_type == "PRE_SAVE_album"
assert audits[0].event_log.get("name") == album.name
assert audits[1].event_type == "POST_SAVE_album"
assert audits[1].event_log.get("id") == album.pk
album = await Album(name="Rome").save()
audits = await AuditLog.objects.all()
assert len(audits) == 4
assert audits[2].event_type == "PRE_SAVE_album"
assert audits[2].event_log.get("name") == album.name
assert audits[3].event_type == "POST_SAVE_album"
assert audits[3].event_log.get("id") == album.pk
album.is_best_seller = True
await album.update()
audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all()
assert len(audits) == 2
assert audits[0].event_type == "PRE_UPDATE_album"
assert audits[0].event_log.get("name") == album.name
assert audits[1].event_type == "POST_UPDATE_album"
assert audits[1].event_log.get("is_best_seller") == album.is_best_seller
album.signals.pre_update.disconnect(before_update)
album.signals.post_update.disconnect(after_update)
album.is_best_seller = False
await album.update()
audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all()
assert len(audits) == 2
await album.delete()
audits = await AuditLog.objects.filter(event_type__contains="DELETE").all()
assert len(audits) == 2
assert audits[0].event_type == "PRE_DELETE_album"
assert (
audits[0].event_log.get("id")
== audits[1].event_log.get("id")
== album.id
)
assert audits[1].event_type == "POST_DELETE_album"
album.signals.pre_delete.disconnect(before_delete)
album.signals.post_delete.disconnect(after_delete)
album.signals.pre_save.disconnect(before_save)
album.signals.post_save.disconnect(after_save)
@pytest.mark.asyncio
async def test_multiple_signals():
async with database:
async with database.transaction(force_rollback=True):
@pre_save(Album)
async def before_save(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_SAVE_{sender.get_name()}",
event_log=instance.json(),
).save()
@pre_save(Album)
async def before_save2(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_SAVE2_{sender.get_name()}",
event_log=instance.json(),
).save()
album = await Album.objects.create(name="Miami")
audits = await AuditLog.objects.all()
assert len(audits) == 2
assert audits[0].event_type == "PRE_SAVE_album"
assert audits[0].event_log.get("name") == album.name
assert audits[1].event_type == "PRE_SAVE2_album"
assert audits[1].event_log.get("name") == album.name
album.signals.pre_save.disconnect(before_save)
album.signals.pre_save.disconnect(before_save2)
@pytest.mark.asyncio
async def test_static_methods_as_signals():
async with database:
async with database.transaction(force_rollback=True):
class AlbumAuditor:
event_type = "ALBUM_INSTANCE"
@staticmethod
@pre_save(Album)
async def before_save(sender, instance, **kwargs):
await AuditLog(
event_type=f"{AlbumAuditor.event_type}_SAVE",
event_log=instance.json(),
).save()
album = await Album.objects.create(name="Colorado")
audits = await AuditLog.objects.all()
assert len(audits) == 1
assert audits[0].event_type == "ALBUM_INSTANCE_SAVE"
assert audits[0].event_log.get("name") == album.name
album.signals.pre_save.disconnect(AlbumAuditor.before_save)
@pytest.mark.asyncio
async def test_methods_as_signals():
async with database:
async with database.transaction(force_rollback=True):
class AlbumAuditor:
def __init__(self):
self.event_type = "ALBUM_INSTANCE"
async def before_save(self, sender, instance, **kwargs):
await AuditLog(
event_type=f"{self.event_type}_SAVE", event_log=instance.json()
).save()
auditor = AlbumAuditor()
pre_save(Album)(auditor.before_save)
album = await Album.objects.create(name="San Francisco")
audits = await AuditLog.objects.all()
assert len(audits) == 1
assert audits[0].event_type == "ALBUM_INSTANCE_SAVE"
assert audits[0].event_log.get("name") == album.name
album.signals.pre_save.disconnect(auditor.before_save)