diff --git a/docs/releases.md b/docs/releases.md index 96e799e..bd8f3db 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -3,6 +3,7 @@ * **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError` * **Breaking:** Model initialization with unknown field name now raises `ModelError` instead of `KeyError` * +* Add py.typed and modify setup.py for mypy support * Performance optimization # 0.6.2 diff --git a/ormar/decorators/signals.py b/ormar/decorators/signals.py new file mode 100644 index 0000000..d149f97 --- /dev/null +++ b/ormar/decorators/signals.py @@ -0,0 +1,56 @@ +from typing import Any, Callable, List, TYPE_CHECKING, Type, Union + +if TYPE_CHECKING: # pragma: no cover + from ormar import Model + + +def receiver( + signal: str, senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + def _decorator(func: Callable) -> Callable: + if not isinstance(senders, list): + _senders = [senders] + else: + _senders = senders + for sender in _senders: + signals = getattr(sender.Meta.signals, signal) + signals.connect(func, **kwargs) + return func + + return _decorator + + +def post_save( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="post_save", senders=senders, **kwargs) + + +def post_update( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="post_update", senders=senders, **kwargs) + + +def post_delete( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="post_delete", senders=senders, **kwargs) + + +def pre_save( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="pre_save", senders=senders, **kwargs) + + +def pre_update( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="pre_update", senders=senders, **kwargs) + + +def pre_delete( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="pre_delete", senders=senders, **kwargs) diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 8460e13..5649d22 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -18,6 +18,7 @@ from ormar.fields.many_to_many import ManyToMany, ManyToManyField from ormar.models.quick_access_views import quick_access_set from ormar.queryset import QuerySet from ormar.relations.alias_manager import AliasManager +from ormar.signals import Signal, SignalEmitter if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -38,6 +39,7 @@ class ModelMeta: ] alias_manager: AliasManager property_fields: Set + signals: SignalEmitter def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None: @@ -332,15 +334,12 @@ def add_cached_properties(new_model: Type["Model"]) -> None: new_model._pydantic_fields = {name for name in new_model.__fields__} -def property_fields_not_set(new_model: Type["Model"]) -> bool: - return ( - not hasattr(new_model.Meta, "property_fields") - or not new_model.Meta.property_fields - ) +def meta_field_not_set(model: Type["Model"], field_name: str) -> bool: + return not hasattr(model.Meta, field_name) or not getattr(model.Meta, field_name) def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001 - if property_fields_not_set(new_model): + if meta_field_not_set(model=new_model, field_name="property_fields"): props = set() for var_name, value in attrs.items(): if isinstance(value, property): @@ -351,6 +350,18 @@ def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: new_model.Meta.property_fields = props +def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001 + if meta_field_not_set(model=new_model, field_name="signals"): + signals = SignalEmitter() + signals.pre_save = Signal() + signals.pre_update = Signal() + signals.pre_delete = Signal() + signals.post_save = Signal() + signals.post_update = Signal() + signals.post_delete = Signal() + new_model.Meta.signals = signals + + class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict @@ -379,5 +390,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): new_model.Meta.alias_manager = alias_manager new_model.objects = QuerySet(new_model) add_property_fields(new_model, attrs) + register_signals(new_model=new_model) return new_model diff --git a/ormar/models/model.py b/ormar/models/model.py index 8bf7d76..8ae50d2 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -195,6 +195,9 @@ class Model(NewBaseModel): if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: self_fields.pop(self.Meta.pkname, None) self_fields = self.populate_default_values(self_fields) + self.from_dict(self_fields) + + await self.signals.pre_save.send(sender=self.__class__, instance=self) self_fields = self.translate_columns_to_aliases(self_fields) expr = self.Meta.table.insert() @@ -204,6 +207,7 @@ class Model(NewBaseModel): if pk and isinstance(pk, self.pk_type()): setattr(self, self.Meta.pkname, pk) + self.set_save_status(True) # refresh server side defaults if any( field.server_default is not None @@ -211,9 +215,8 @@ class Model(NewBaseModel): if name not in self_fields ): await self.load() - return self - self.set_save_status(True) + await self.signals.post_save.send(sender=self.__class__, instance=self) return self async def save_related( # noqa: CCR001 @@ -268,6 +271,7 @@ class Model(NewBaseModel): "You cannot update not saved model! Use save or upsert method." ) + await self.signals.pre_update.send(sender=self.__class__, instance=self) self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) self_fields = self.translate_columns_to_aliases(self_fields) @@ -276,13 +280,16 @@ class Model(NewBaseModel): await self.Meta.database.execute(expr) self.set_save_status(True) + await self.signals.post_update.send(sender=self.__class__, instance=self) return self async def delete(self: T) -> int: + await self.signals.pre_delete.send(sender=self.__class__, instance=self) expr = self.Meta.table.delete() expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) result = await self.Meta.database.execute(expr) self.set_save_status(False) + await self.signals.post_delete.send(sender=self.__class__, instance=self) return result async def load(self: T) -> T: diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 45e0fa0..ff3f8b0 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -35,6 +35,7 @@ from ormar.relations.relation_manager import RelationsManager if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.signals import SignalEmitter T = TypeVar("T", bound=Model) @@ -212,6 +213,10 @@ class NewBaseModel( def saved(self) -> bool: return self._orm_saved + @property + def signals(self) -> "SignalEmitter": + return self.Meta.signals + @classmethod def pk_type(cls) -> Any: return cls.Meta.model_fields[cls.Meta.pkname].__type__ diff --git a/ormar/py.typed b/ormar/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 0b7a31d..63bbf1f 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -395,6 +395,9 @@ class QuerySet: expr = expr.values(**new_kwargs) instance = self.model(**kwargs) + await self.model.Meta.signals.pre_save.send( + sender=self.model, instance=instance + ) pk = await self.database.execute(expr) pk_name = self.model.get_column_alias(self.model_meta.pkname) @@ -411,6 +414,9 @@ class QuerySet: ): instance = await instance.load() instance.set_save_status(True) + await self.model.Meta.signals.post_save.send( + sender=self.model, instance=instance + ) return instance async def bulk_create(self, objects: List["Model"]) -> None: diff --git a/ormar/signals/__init__.py b/ormar/signals/__init__.py index e69de29..6f4706e 100644 --- a/ormar/signals/__init__.py +++ b/ormar/signals/__init__.py @@ -0,0 +1,3 @@ +from ormar.signals.signal import Signal, SignalEmitter + +__all__ = ["Signal", "SignalEmitter"] diff --git a/ormar/signals/signal.py b/ormar/signals/signal.py index 836bd2d..4ed12d7 100644 --- a/ormar/signals/signal.py +++ b/ormar/signals/signal.py @@ -1,9 +1,12 @@ import asyncio import inspect -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union from ormar.exceptions import SignalDefinitionError +if TYPE_CHECKING: # pragma: no cover + from ormar import Model + def callable_accepts_kwargs(func: Callable) -> bool: return any( @@ -45,9 +48,24 @@ class Signal: break return removed - async def send(self, sender: Any, **kwargs: Any) -> None: + async def send(self, sender: Type["Model"], **kwargs: Any) -> None: receivers = [] for receiver in self._receivers: _, receiver_func = receiver - receivers.append(receiver_func(sender, **kwargs)) + receivers.append(receiver_func(sender=sender, **kwargs)) await asyncio.gather(*receivers) + + +class SignalEmitter: + if TYPE_CHECKING: + signals: Dict[str, Signal] + + def __init__(self) -> None: + object.__setattr__(self, "signals", dict()) + + def __getattr__(self, item: str) -> Signal: + return self.signals.setdefault(item, Signal()) + + def __setattr__(self, key: str, value: Any) -> None: + signals = object.__getattribute__(self, "signals") + signals[key] = value diff --git a/scripts/clean.sh b/scripts/clean.sh deleted file mode 100755 index 9b30e14..0000000 --- a/scripts/clean.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/sh -e -PACKAGE="ormar" -if [ -d 'dist' ] ; then - rm -r dist -fi -if [ -d 'site' ] ; then - rm -r site -fi -if [ -d 'htmlcov' ] ; then - rm -r htmlcov -fi -if [ -d "${PACKAGE}.egg-info" ] ; then - rm -r "${PACKAGE}.egg-info" -fi -find ${PACKAGE} -type f -name "*.py[co]" -delete -find ${PACKAGE} -type d -name __pycache__ -delete - -find tests -type f -name "*.py[co]" -delete -find tests -type d -name __pycache__ -delete \ No newline at end of file diff --git a/scripts/publish.sh b/scripts/publish.sh deleted file mode 100755 index 419fa30..0000000 --- a/scripts/publish.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/sh -e - -PACKAGE="ormar" - -PREFIX="" -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" -fi - -VERSION=`cat ${PACKAGE}/__init__.py | grep __version__ | sed "s/__version__ = //" | sed "s/'//g"` - -set -x - -scripts/clean.sh - -${PREFIX}python setup.py sdist -${PREFIX}twine upload dist/* - -echo "You probably want to also tag the version now:" -echo "git tag -a ${VERSION} -m 'version ${VERSION}'" -echo "git push --tags" - -scripts/clean.sh \ No newline at end of file diff --git a/setup.py b/setup.py index f461d97..85a3929 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,8 @@ setup( author_email="collerek@gmail.com", packages=get_packages(PACKAGE), package_data={PACKAGE: ["py.typed"]}, + include_package_data=True, + zip_safe=False, data_files=[("", ["LICENSE.md"])], install_requires=["databases", "pydantic>=1.5", "sqlalchemy", "typing_extensions"], extras_require={ @@ -65,9 +67,11 @@ setup( "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Topic :: Internet :: WWW/HTTP", + "Framework :: AsyncIO", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3 :: Only", ], ) diff --git a/tests/test_signals.py b/tests/test_signals.py new file mode 100644 index 0000000..6f59d6d --- /dev/null +++ b/tests/test_signals.py @@ -0,0 +1,245 @@ +import databases +import pydantic +import pytest +import sqlalchemy + +import ormar +from ormar.decorators.signals import ( + post_delete, + post_save, + post_update, + pre_delete, + pre_save, + pre_update, +) +from ormar.exceptions import SignalDefinitionError +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class AuditLog(ormar.Model): + class Meta: + tablename = "audits" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + event_type: str = ormar.String(max_length=100) + event_log: pydantic.Json = ormar.JSON() + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + is_best_seller: bool = ormar.Boolean(default=False) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_passing_not_callable(): + with pytest.raises(SignalDefinitionError): + pre_save(Album)("wrong") + + +def test_passing_callable_without_kwargs(): + with pytest.raises(SignalDefinitionError): + + @pre_save(Album) + def trigger(sender, instance): # pragma: no cover + pass + + +@pytest.mark.asyncio +async def test_signal_functions(): + async with database: + async with database.transaction(force_rollback=True): + + @pre_save(Album) + async def before_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_SAVE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @post_save(Album) + async def after_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"POST_SAVE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @pre_update(Album) + async def before_update(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_UPDATE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @post_update(Album) + async def after_update(sender, instance, **kwargs): + await AuditLog( + event_type=f"POST_UPDATE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @pre_delete(Album) + async def before_delete(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_DELETE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @post_delete(Album) + async def after_delete(sender, instance, **kwargs): + await AuditLog( + event_type=f"POST_DELETE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + album = await Album.objects.create(name="Venice") + + audits = await AuditLog.objects.all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_SAVE_album" + assert audits[0].event_log.get("name") == album.name + assert audits[1].event_type == "POST_SAVE_album" + assert audits[1].event_log.get("id") == album.pk + + album = await Album(name="Rome").save() + audits = await AuditLog.objects.all() + assert len(audits) == 4 + assert audits[2].event_type == "PRE_SAVE_album" + assert audits[2].event_log.get("name") == album.name + assert audits[3].event_type == "POST_SAVE_album" + assert audits[3].event_log.get("id") == album.pk + + album.is_best_seller = True + await album.update() + + audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_UPDATE_album" + assert audits[0].event_log.get("name") == album.name + assert audits[1].event_type == "POST_UPDATE_album" + assert audits[1].event_log.get("is_best_seller") == album.is_best_seller + + album.signals.pre_update.disconnect(before_update) + album.signals.post_update.disconnect(after_update) + + album.is_best_seller = False + await album.update() + + audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all() + assert len(audits) == 2 + + await album.delete() + audits = await AuditLog.objects.filter(event_type__contains="DELETE").all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_DELETE_album" + assert ( + audits[0].event_log.get("id") + == audits[1].event_log.get("id") + == album.id + ) + assert audits[1].event_type == "POST_DELETE_album" + + album.signals.pre_delete.disconnect(before_delete) + album.signals.post_delete.disconnect(after_delete) + album.signals.pre_save.disconnect(before_save) + album.signals.post_save.disconnect(after_save) + + +@pytest.mark.asyncio +async def test_multiple_signals(): + async with database: + async with database.transaction(force_rollback=True): + + @pre_save(Album) + async def before_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_SAVE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @pre_save(Album) + async def before_save2(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_SAVE2_{sender.get_name()}", + event_log=instance.json(), + ).save() + + album = await Album.objects.create(name="Miami") + audits = await AuditLog.objects.all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_SAVE_album" + assert audits[0].event_log.get("name") == album.name + assert audits[1].event_type == "PRE_SAVE2_album" + assert audits[1].event_log.get("name") == album.name + + album.signals.pre_save.disconnect(before_save) + album.signals.pre_save.disconnect(before_save2) + + +@pytest.mark.asyncio +async def test_static_methods_as_signals(): + async with database: + async with database.transaction(force_rollback=True): + + class AlbumAuditor: + event_type = "ALBUM_INSTANCE" + + @staticmethod + @pre_save(Album) + async def before_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"{AlbumAuditor.event_type}_SAVE", + event_log=instance.json(), + ).save() + + album = await Album.objects.create(name="Colorado") + audits = await AuditLog.objects.all() + assert len(audits) == 1 + assert audits[0].event_type == "ALBUM_INSTANCE_SAVE" + assert audits[0].event_log.get("name") == album.name + + album.signals.pre_save.disconnect(AlbumAuditor.before_save) + + +@pytest.mark.asyncio +async def test_methods_as_signals(): + async with database: + async with database.transaction(force_rollback=True): + + class AlbumAuditor: + def __init__(self): + self.event_type = "ALBUM_INSTANCE" + + async def before_save(self, sender, instance, **kwargs): + await AuditLog( + event_type=f"{self.event_type}_SAVE", event_log=instance.json() + ).save() + + auditor = AlbumAuditor() + pre_save(Album)(auditor.before_save) + + album = await Album.objects.create(name="San Francisco") + audits = await AuditLog.objects.all() + assert len(audits) == 1 + assert audits[0].event_type == "ALBUM_INSTANCE_SAVE" + assert audits[0].event_log.get("name") == album.name + + album.signals.pre_save.disconnect(auditor.before_save)