From 9838547c4ff1a4b0989892f2c0924fdbbfbff6cc Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 6 Dec 2020 08:23:57 +0100 Subject: [PATCH 1/7] some cleanup and refactoring --- ormar/models/newbasemodel.py | 29 +++++++++++------------ tests/test_excluding_fields_in_fastapi.py | 10 +++++--- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 9105f7b..45e0fa0 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -47,11 +47,7 @@ if TYPE_CHECKING: # pragma no cover class NewBaseModel( pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass ): - __slots__ = ( - "_orm_id", - "_orm_saved", - "_orm", - ) + __slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column") if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, Type[BaseField]] @@ -75,6 +71,7 @@ class NewBaseModel( def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_saved", False) + object.__setattr__(self, "_pk_column", None) object.__setattr__( self, "_orm", @@ -94,13 +91,8 @@ class NewBaseModel( if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") - # remove property fields values from validation - kwargs = { - k: v - for k, v in kwargs.items() - if k not in object.__getattribute__(self, "Meta").property_fields - } # build the models to set them and validate but don't register + # also remove property fields values from validation try: new_kwargs: Dict[str, Any] = { k: self._convert_json( @@ -111,14 +103,15 @@ class NewBaseModel( "dumps", ) for k, v in kwargs.items() + if k not in object.__getattribute__(self, "Meta").property_fields } except KeyError as e: raise ModelError( f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}" ) - # explicitly set None to excluded fields with default - # as pydantic populates them with default + # explicitly set None to excluded fields + # as pydantic populates them with default if set for field_to_nullify in excluded: new_kwargs[field_to_nullify] = None @@ -195,7 +188,8 @@ class NewBaseModel( return ( self._orm_id == other._orm_id or (self.pk == other.pk and self.pk is not None) - or self.dict() == other.dict() + or self.dict(exclude=self.extract_related_names()) + == other.dict(exclude=other.extract_related_names()) ) @classmethod @@ -207,7 +201,12 @@ class NewBaseModel( @property def pk_column(self) -> sqlalchemy.Column: - return self.Meta.table.primary_key.columns.values()[0] + if object.__getattribute__(self, "_pk_column") is not None: + return object.__getattribute__(self, "_pk_column") + pk_columns = self.Meta.table.primary_key.columns.values() + pk_col = pk_columns[0] + object.__setattr__(self, "_pk_column", pk_col) + return pk_col @property def saved(self) -> bool: diff --git a/tests/test_excluding_fields_in_fastapi.py b/tests/test_excluding_fields_in_fastapi.py index a75d0e1..3e3349b 100644 --- a/tests/test_excluding_fields_in_fastapi.py +++ b/tests/test_excluding_fields_in_fastapi.py @@ -219,7 +219,7 @@ def test_excluding_fields_in_endpoints(): assert isinstance(user_instance.timestamp, datetime.datetime) assert user_instance.timestamp == timestamp - response = client.post("/users4/", json=user) + response = client.post("/users4/", json=user3) assert list(response.json().keys()) == [ "id", "email", @@ -228,8 +228,12 @@ def test_excluding_fields_in_endpoints(): "category", "timestamp", ] - assert response.json().get("timestamp") != str(timestamp).replace(" ", "T") - assert response.json().get("timestamp") is not None + assert ( + datetime.datetime.strptime( + response.json().get("timestamp"), "%Y-%m-%dT%H:%M:%S.%f" + ) + == timestamp + ) def test_adding_fields_in_endpoints(): From 2bbfd0501743b199d8fd1d0a5778c9f50383317c Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 6 Dec 2020 10:28:48 +0100 Subject: [PATCH 2/7] add base signal class --- docs/releases.md | 9 ++++- ormar/__init__.py | 3 +- ormar/decorators/property_field.py | 2 +- ormar/exceptions.py | 48 +++++++++++++++++++++++-- ormar/queryset/queryset.py | 4 +-- ormar/signals/__init__.py | 0 ormar/signals/signal.py | 53 ++++++++++++++++++++++++++++ tests/test_queryset_level_methods.py | 4 +-- 8 files changed, 112 insertions(+), 11 deletions(-) create mode 100644 ormar/signals/__init__.py create mode 100644 ormar/signals/signal.py diff --git a/docs/releases.md b/docs/releases.md index f940a36..96e799e 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,10 @@ +# 0.7.0 + +* **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError` +* **Breaking:** Model initialization with unknown field name now raises `ModelError` instead of `KeyError` +* +* Performance optimization + # 0.6.2 * Performance optimization @@ -12,7 +19,7 @@ # 0.6.0 -* **Breaking:** calling instance.load() when the instance row was deleted from db now raises ormar.NoMatch instead of ValueError +* **Breaking:** calling instance.load() when the instance row was deleted from db now raises `NoMatch` instead of `ValueError` * **Breaking:** calling add and remove on ReverseForeignKey relation now updates the child model in db setting/removing fk column * **Breaking:** ReverseForeignKey relation now exposes QuerySetProxy API like ManyToMany relation * **Breaking:** querying related models from ManyToMany cleans list of related models loaded on parent model: diff --git a/ormar/__init__.py b/ormar/__init__.py index 6e196db..036e372 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -1,5 +1,5 @@ from ormar.decorators import property_field -from ormar.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch +from ormar.exceptions import ModelDefinitionError, MultipleMatches, NoMatch from ormar.protocols import QuerySetProtocol, RelationProtocol # noqa: I100 from ormar.fields import ( # noqa: I100 BigInteger, @@ -47,7 +47,6 @@ __all__ = [ "ManyToMany", "Model", "ModelDefinitionError", - "ModelNotSet", "MultipleMatches", "NoMatch", "ForeignKey", diff --git a/ormar/decorators/property_field.py b/ormar/decorators/property_field.py index 8d6a2e2..a377202 100644 --- a/ormar/decorators/property_field.py +++ b/ormar/decorators/property_field.py @@ -13,7 +13,7 @@ def property_field(func: Callable) -> Union[property, Callable]: if len(arguments) > 1 or arguments[0] != "self": raise ModelDefinitionError( "property_field decorator can be used " - "only on class methods with no arguments" + "only on methods with no arguments" ) func.__dict__["__property_field__"] = True return func diff --git a/ormar/exceptions.py b/ormar/exceptions.py index 0ec0c8e..3dbf763 100644 --- a/ormar/exceptions.py +++ b/ormar/exceptions.py @@ -1,28 +1,57 @@ class AsyncOrmException(Exception): + """ + Base ormar Exception + """ + pass class ModelDefinitionError(AsyncOrmException): + """ + Raised for errors related to the model definition itself. + * setting @property_field on method with arguments other than func(self) + * defining a Field without required parameters + * defining a model with more than one primary_key + * defining a model without primary_key + * setting primary_key column as pydantic_only + """ + pass class ModelError(AsyncOrmException): - pass + """ + Raised for initialization of model with non-existing field keyword. + """ - -class ModelNotSet(AsyncOrmException): pass class NoMatch(AsyncOrmException): + """ + Raised for database queries that has no matching result (empty result). + """ + pass class MultipleMatches(AsyncOrmException): + """ + Raised for database queries that should return one row (i.e. get, first etc.) + but has multiple matching results in response. + """ + pass class QueryDefinitionError(AsyncOrmException): + """ + Raised for errors in query definition. + * using contains or icontains filter with instance of the Model + * using Queryset.update() without filter and setting each flag to True + * using Queryset.delete() without filter and setting each flag to True + """ + pass @@ -31,4 +60,17 @@ class RelationshipInstanceError(AsyncOrmException): class ModelPersistenceError(AsyncOrmException): + """ + Raised for update of models without primary_key set (cannot retrieve from db) + or for saving a model with relation to unsaved model (cannot extract fk value). + """ + + pass + + +class SignalDefinitionError(AsyncOrmException): + """ + Raised when non callable receiver is passed as signal callback. + """ + pass diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index f6defd4..0b7a31d 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -6,7 +6,7 @@ from sqlalchemy import bindparam import ormar # noqa I100 from ormar import MultipleMatches, NoMatch -from ormar.exceptions import QueryDefinitionError +from ormar.exceptions import ModelPersistenceError, QueryDefinitionError from ormar.queryset import FilterQuery from ormar.queryset.clause import QueryClause from ormar.queryset.prefetch_query import PrefetchQuery @@ -446,7 +446,7 @@ class QuerySet: for objt in objects: new_kwargs = objt.dict() if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None: - raise QueryDefinitionError( + raise ModelPersistenceError( "You cannot update unsaved objects. " f"{self.model.__name__} has to have {pk_name} filled." ) diff --git a/ormar/signals/__init__.py b/ormar/signals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ormar/signals/signal.py b/ormar/signals/signal.py new file mode 100644 index 0000000..836bd2d --- /dev/null +++ b/ormar/signals/signal.py @@ -0,0 +1,53 @@ +import asyncio +import inspect +from typing import Any, Callable, List, Tuple, Union + +from ormar.exceptions import SignalDefinitionError + + +def callable_accepts_kwargs(func: Callable) -> bool: + return any( + p + for p in inspect.signature(func).parameters.values() + if p.kind == p.VAR_KEYWORD + ) + + +def make_id(target: Any) -> Union[int, Tuple[int, int]]: + if hasattr(target, "__func__"): + return id(target.__self__), id(target.__func__) + return id(target) + + +class Signal: + def __init__(self) -> None: + self._receivers: List[Tuple[Union[int, Tuple[int, int]], Callable]] = [] + + def connect(self, receiver: Callable) -> None: + if not callable(receiver): + raise SignalDefinitionError("Signal receivers must be callable.") + if not callable_accepts_kwargs(receiver): + raise SignalDefinitionError( + "Signal receivers must accept **kwargs argument." + ) + new_receiver_key = make_id(receiver) + if not any(rec_id == new_receiver_key for rec_id, _ in self._receivers): + self._receivers.append((new_receiver_key, receiver)) + + def disconnect(self, receiver: Callable) -> bool: + removed = False + new_receiver_key = make_id(receiver) + for ind, rec in enumerate(self._receivers): + rec_id, _ = rec + if rec_id == new_receiver_key: + removed = True + del self._receivers[ind] + break + return removed + + async def send(self, sender: Any, **kwargs: Any) -> None: + receivers = [] + for receiver in self._receivers: + _, receiver_func = receiver + receivers.append(receiver_func(sender, **kwargs)) + await asyncio.gather(*receivers) diff --git a/tests/test_queryset_level_methods.py b/tests/test_queryset_level_methods.py index 5ea5443..c7db39f 100644 --- a/tests/test_queryset_level_methods.py +++ b/tests/test_queryset_level_methods.py @@ -5,7 +5,7 @@ import pytest import sqlalchemy import ormar -from ormar.exceptions import QueryDefinitionError +from ormar.exceptions import ModelPersistenceError, QueryDefinitionError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -302,7 +302,7 @@ async def test_bulk_update_with_relation(): async def test_bulk_update_not_saved_objts(): async with database: category = await Category.objects.create(name="Sample Category") - with pytest.raises(QueryDefinitionError): + with pytest.raises(ModelPersistenceError): await Note.objects.bulk_update( [ Note(text="Buy the groceries.", category=category), From 85be9e8b80979cb607931ef90a5a9fea552a987e Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 6 Dec 2020 17:23:46 +0100 Subject: [PATCH 3/7] add signals, register six signals on each models (pre/post + save/update/delete) --- docs/releases.md | 1 + ormar/decorators/signals.py | 56 ++++++++ ormar/models/metaclass.py | 24 +++- ormar/models/model.py | 11 +- ormar/models/newbasemodel.py | 5 + ormar/py.typed | 0 ormar/queryset/queryset.py | 6 + ormar/signals/__init__.py | 3 + ormar/signals/signal.py | 24 +++- scripts/clean.sh | 19 --- scripts/publish.sh | 23 ---- setup.py | 4 + tests/test_signals.py | 245 +++++++++++++++++++++++++++++++++++ 13 files changed, 368 insertions(+), 53 deletions(-) create mode 100644 ormar/decorators/signals.py create mode 100644 ormar/py.typed delete mode 100755 scripts/clean.sh delete mode 100755 scripts/publish.sh create mode 100644 tests/test_signals.py diff --git a/docs/releases.md b/docs/releases.md index 96e799e..bd8f3db 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -3,6 +3,7 @@ * **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError` * **Breaking:** Model initialization with unknown field name now raises `ModelError` instead of `KeyError` * +* Add py.typed and modify setup.py for mypy support * Performance optimization # 0.6.2 diff --git a/ormar/decorators/signals.py b/ormar/decorators/signals.py new file mode 100644 index 0000000..d149f97 --- /dev/null +++ b/ormar/decorators/signals.py @@ -0,0 +1,56 @@ +from typing import Any, Callable, List, TYPE_CHECKING, Type, Union + +if TYPE_CHECKING: # pragma: no cover + from ormar import Model + + +def receiver( + signal: str, senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + def _decorator(func: Callable) -> Callable: + if not isinstance(senders, list): + _senders = [senders] + else: + _senders = senders + for sender in _senders: + signals = getattr(sender.Meta.signals, signal) + signals.connect(func, **kwargs) + return func + + return _decorator + + +def post_save( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="post_save", senders=senders, **kwargs) + + +def post_update( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="post_update", senders=senders, **kwargs) + + +def post_delete( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="post_delete", senders=senders, **kwargs) + + +def pre_save( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="pre_save", senders=senders, **kwargs) + + +def pre_update( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="pre_update", senders=senders, **kwargs) + + +def pre_delete( + senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any +) -> Callable: + return receiver(signal="pre_delete", senders=senders, **kwargs) diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 8460e13..5649d22 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -18,6 +18,7 @@ from ormar.fields.many_to_many import ManyToMany, ManyToManyField from ormar.models.quick_access_views import quick_access_set from ormar.queryset import QuerySet from ormar.relations.alias_manager import AliasManager +from ormar.signals import Signal, SignalEmitter if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -38,6 +39,7 @@ class ModelMeta: ] alias_manager: AliasManager property_fields: Set + signals: SignalEmitter def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None: @@ -332,15 +334,12 @@ def add_cached_properties(new_model: Type["Model"]) -> None: new_model._pydantic_fields = {name for name in new_model.__fields__} -def property_fields_not_set(new_model: Type["Model"]) -> bool: - return ( - not hasattr(new_model.Meta, "property_fields") - or not new_model.Meta.property_fields - ) +def meta_field_not_set(model: Type["Model"], field_name: str) -> bool: + return not hasattr(model.Meta, field_name) or not getattr(model.Meta, field_name) def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001 - if property_fields_not_set(new_model): + if meta_field_not_set(model=new_model, field_name="property_fields"): props = set() for var_name, value in attrs.items(): if isinstance(value, property): @@ -351,6 +350,18 @@ def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: new_model.Meta.property_fields = props +def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001 + if meta_field_not_set(model=new_model, field_name="signals"): + signals = SignalEmitter() + signals.pre_save = Signal() + signals.pre_update = Signal() + signals.pre_delete = Signal() + signals.post_save = Signal() + signals.post_update = Signal() + signals.post_delete = Signal() + new_model.Meta.signals = signals + + class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict @@ -379,5 +390,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): new_model.Meta.alias_manager = alias_manager new_model.objects = QuerySet(new_model) add_property_fields(new_model, attrs) + register_signals(new_model=new_model) return new_model diff --git a/ormar/models/model.py b/ormar/models/model.py index 8bf7d76..8ae50d2 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -195,6 +195,9 @@ class Model(NewBaseModel): if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: self_fields.pop(self.Meta.pkname, None) self_fields = self.populate_default_values(self_fields) + self.from_dict(self_fields) + + await self.signals.pre_save.send(sender=self.__class__, instance=self) self_fields = self.translate_columns_to_aliases(self_fields) expr = self.Meta.table.insert() @@ -204,6 +207,7 @@ class Model(NewBaseModel): if pk and isinstance(pk, self.pk_type()): setattr(self, self.Meta.pkname, pk) + self.set_save_status(True) # refresh server side defaults if any( field.server_default is not None @@ -211,9 +215,8 @@ class Model(NewBaseModel): if name not in self_fields ): await self.load() - return self - self.set_save_status(True) + await self.signals.post_save.send(sender=self.__class__, instance=self) return self async def save_related( # noqa: CCR001 @@ -268,6 +271,7 @@ class Model(NewBaseModel): "You cannot update not saved model! Use save or upsert method." ) + await self.signals.pre_update.send(sender=self.__class__, instance=self) self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) self_fields = self.translate_columns_to_aliases(self_fields) @@ -276,13 +280,16 @@ class Model(NewBaseModel): await self.Meta.database.execute(expr) self.set_save_status(True) + await self.signals.post_update.send(sender=self.__class__, instance=self) return self async def delete(self: T) -> int: + await self.signals.pre_delete.send(sender=self.__class__, instance=self) expr = self.Meta.table.delete() expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) result = await self.Meta.database.execute(expr) self.set_save_status(False) + await self.signals.post_delete.send(sender=self.__class__, instance=self) return result async def load(self: T) -> T: diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 45e0fa0..ff3f8b0 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -35,6 +35,7 @@ from ormar.relations.relation_manager import RelationsManager if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.signals import SignalEmitter T = TypeVar("T", bound=Model) @@ -212,6 +213,10 @@ class NewBaseModel( def saved(self) -> bool: return self._orm_saved + @property + def signals(self) -> "SignalEmitter": + return self.Meta.signals + @classmethod def pk_type(cls) -> Any: return cls.Meta.model_fields[cls.Meta.pkname].__type__ diff --git a/ormar/py.typed b/ormar/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 0b7a31d..63bbf1f 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -395,6 +395,9 @@ class QuerySet: expr = expr.values(**new_kwargs) instance = self.model(**kwargs) + await self.model.Meta.signals.pre_save.send( + sender=self.model, instance=instance + ) pk = await self.database.execute(expr) pk_name = self.model.get_column_alias(self.model_meta.pkname) @@ -411,6 +414,9 @@ class QuerySet: ): instance = await instance.load() instance.set_save_status(True) + await self.model.Meta.signals.post_save.send( + sender=self.model, instance=instance + ) return instance async def bulk_create(self, objects: List["Model"]) -> None: diff --git a/ormar/signals/__init__.py b/ormar/signals/__init__.py index e69de29..6f4706e 100644 --- a/ormar/signals/__init__.py +++ b/ormar/signals/__init__.py @@ -0,0 +1,3 @@ +from ormar.signals.signal import Signal, SignalEmitter + +__all__ = ["Signal", "SignalEmitter"] diff --git a/ormar/signals/signal.py b/ormar/signals/signal.py index 836bd2d..4ed12d7 100644 --- a/ormar/signals/signal.py +++ b/ormar/signals/signal.py @@ -1,9 +1,12 @@ import asyncio import inspect -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union from ormar.exceptions import SignalDefinitionError +if TYPE_CHECKING: # pragma: no cover + from ormar import Model + def callable_accepts_kwargs(func: Callable) -> bool: return any( @@ -45,9 +48,24 @@ class Signal: break return removed - async def send(self, sender: Any, **kwargs: Any) -> None: + async def send(self, sender: Type["Model"], **kwargs: Any) -> None: receivers = [] for receiver in self._receivers: _, receiver_func = receiver - receivers.append(receiver_func(sender, **kwargs)) + receivers.append(receiver_func(sender=sender, **kwargs)) await asyncio.gather(*receivers) + + +class SignalEmitter: + if TYPE_CHECKING: + signals: Dict[str, Signal] + + def __init__(self) -> None: + object.__setattr__(self, "signals", dict()) + + def __getattr__(self, item: str) -> Signal: + return self.signals.setdefault(item, Signal()) + + def __setattr__(self, key: str, value: Any) -> None: + signals = object.__getattribute__(self, "signals") + signals[key] = value diff --git a/scripts/clean.sh b/scripts/clean.sh deleted file mode 100755 index 9b30e14..0000000 --- a/scripts/clean.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/sh -e -PACKAGE="ormar" -if [ -d 'dist' ] ; then - rm -r dist -fi -if [ -d 'site' ] ; then - rm -r site -fi -if [ -d 'htmlcov' ] ; then - rm -r htmlcov -fi -if [ -d "${PACKAGE}.egg-info" ] ; then - rm -r "${PACKAGE}.egg-info" -fi -find ${PACKAGE} -type f -name "*.py[co]" -delete -find ${PACKAGE} -type d -name __pycache__ -delete - -find tests -type f -name "*.py[co]" -delete -find tests -type d -name __pycache__ -delete \ No newline at end of file diff --git a/scripts/publish.sh b/scripts/publish.sh deleted file mode 100755 index 419fa30..0000000 --- a/scripts/publish.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/sh -e - -PACKAGE="ormar" - -PREFIX="" -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" -fi - -VERSION=`cat ${PACKAGE}/__init__.py | grep __version__ | sed "s/__version__ = //" | sed "s/'//g"` - -set -x - -scripts/clean.sh - -${PREFIX}python setup.py sdist -${PREFIX}twine upload dist/* - -echo "You probably want to also tag the version now:" -echo "git tag -a ${VERSION} -m 'version ${VERSION}'" -echo "git push --tags" - -scripts/clean.sh \ No newline at end of file diff --git a/setup.py b/setup.py index f461d97..85a3929 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,8 @@ setup( author_email="collerek@gmail.com", packages=get_packages(PACKAGE), package_data={PACKAGE: ["py.typed"]}, + include_package_data=True, + zip_safe=False, data_files=[("", ["LICENSE.md"])], install_requires=["databases", "pydantic>=1.5", "sqlalchemy", "typing_extensions"], extras_require={ @@ -65,9 +67,11 @@ setup( "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Topic :: Internet :: WWW/HTTP", + "Framework :: AsyncIO", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3 :: Only", ], ) diff --git a/tests/test_signals.py b/tests/test_signals.py new file mode 100644 index 0000000..6f59d6d --- /dev/null +++ b/tests/test_signals.py @@ -0,0 +1,245 @@ +import databases +import pydantic +import pytest +import sqlalchemy + +import ormar +from ormar.decorators.signals import ( + post_delete, + post_save, + post_update, + pre_delete, + pre_save, + pre_update, +) +from ormar.exceptions import SignalDefinitionError +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class AuditLog(ormar.Model): + class Meta: + tablename = "audits" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + event_type: str = ormar.String(max_length=100) + event_log: pydantic.Json = ormar.JSON() + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + is_best_seller: bool = ormar.Boolean(default=False) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_passing_not_callable(): + with pytest.raises(SignalDefinitionError): + pre_save(Album)("wrong") + + +def test_passing_callable_without_kwargs(): + with pytest.raises(SignalDefinitionError): + + @pre_save(Album) + def trigger(sender, instance): # pragma: no cover + pass + + +@pytest.mark.asyncio +async def test_signal_functions(): + async with database: + async with database.transaction(force_rollback=True): + + @pre_save(Album) + async def before_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_SAVE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @post_save(Album) + async def after_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"POST_SAVE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @pre_update(Album) + async def before_update(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_UPDATE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @post_update(Album) + async def after_update(sender, instance, **kwargs): + await AuditLog( + event_type=f"POST_UPDATE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @pre_delete(Album) + async def before_delete(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_DELETE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @post_delete(Album) + async def after_delete(sender, instance, **kwargs): + await AuditLog( + event_type=f"POST_DELETE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + album = await Album.objects.create(name="Venice") + + audits = await AuditLog.objects.all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_SAVE_album" + assert audits[0].event_log.get("name") == album.name + assert audits[1].event_type == "POST_SAVE_album" + assert audits[1].event_log.get("id") == album.pk + + album = await Album(name="Rome").save() + audits = await AuditLog.objects.all() + assert len(audits) == 4 + assert audits[2].event_type == "PRE_SAVE_album" + assert audits[2].event_log.get("name") == album.name + assert audits[3].event_type == "POST_SAVE_album" + assert audits[3].event_log.get("id") == album.pk + + album.is_best_seller = True + await album.update() + + audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_UPDATE_album" + assert audits[0].event_log.get("name") == album.name + assert audits[1].event_type == "POST_UPDATE_album" + assert audits[1].event_log.get("is_best_seller") == album.is_best_seller + + album.signals.pre_update.disconnect(before_update) + album.signals.post_update.disconnect(after_update) + + album.is_best_seller = False + await album.update() + + audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all() + assert len(audits) == 2 + + await album.delete() + audits = await AuditLog.objects.filter(event_type__contains="DELETE").all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_DELETE_album" + assert ( + audits[0].event_log.get("id") + == audits[1].event_log.get("id") + == album.id + ) + assert audits[1].event_type == "POST_DELETE_album" + + album.signals.pre_delete.disconnect(before_delete) + album.signals.post_delete.disconnect(after_delete) + album.signals.pre_save.disconnect(before_save) + album.signals.post_save.disconnect(after_save) + + +@pytest.mark.asyncio +async def test_multiple_signals(): + async with database: + async with database.transaction(force_rollback=True): + + @pre_save(Album) + async def before_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_SAVE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + @pre_save(Album) + async def before_save2(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_SAVE2_{sender.get_name()}", + event_log=instance.json(), + ).save() + + album = await Album.objects.create(name="Miami") + audits = await AuditLog.objects.all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_SAVE_album" + assert audits[0].event_log.get("name") == album.name + assert audits[1].event_type == "PRE_SAVE2_album" + assert audits[1].event_log.get("name") == album.name + + album.signals.pre_save.disconnect(before_save) + album.signals.pre_save.disconnect(before_save2) + + +@pytest.mark.asyncio +async def test_static_methods_as_signals(): + async with database: + async with database.transaction(force_rollback=True): + + class AlbumAuditor: + event_type = "ALBUM_INSTANCE" + + @staticmethod + @pre_save(Album) + async def before_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"{AlbumAuditor.event_type}_SAVE", + event_log=instance.json(), + ).save() + + album = await Album.objects.create(name="Colorado") + audits = await AuditLog.objects.all() + assert len(audits) == 1 + assert audits[0].event_type == "ALBUM_INSTANCE_SAVE" + assert audits[0].event_log.get("name") == album.name + + album.signals.pre_save.disconnect(AlbumAuditor.before_save) + + +@pytest.mark.asyncio +async def test_methods_as_signals(): + async with database: + async with database.transaction(force_rollback=True): + + class AlbumAuditor: + def __init__(self): + self.event_type = "ALBUM_INSTANCE" + + async def before_save(self, sender, instance, **kwargs): + await AuditLog( + event_type=f"{self.event_type}_SAVE", event_log=instance.json() + ).save() + + auditor = AlbumAuditor() + pre_save(Album)(auditor.before_save) + + album = await Album.objects.create(name="San Francisco") + audits = await AuditLog.objects.all() + assert len(audits) == 1 + assert audits[0].event_type == "ALBUM_INSTANCE_SAVE" + assert audits[0].event_log.get("name") == album.name + + album.signals.pre_save.disconnect(auditor.before_save) From 420826f47229650fc050eb2e230498e77022aa5b Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 6 Dec 2020 17:34:27 +0100 Subject: [PATCH 4/7] fix coverage --- ormar/signals/signal.py | 2 +- tests/test_signals.py | 48 ++++++++++++++++++++++++++++++++++------- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/ormar/signals/signal.py b/ormar/signals/signal.py index 4ed12d7..f3d92d9 100644 --- a/ormar/signals/signal.py +++ b/ormar/signals/signal.py @@ -57,7 +57,7 @@ class Signal: class SignalEmitter: - if TYPE_CHECKING: + if TYPE_CHECKING: # pragma: no cover signals: Dict[str, Signal] def __init__(self) -> None: diff --git a/tests/test_signals.py b/tests/test_signals.py index 6f59d6d..bc72706 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,3 +1,5 @@ +from typing import Optional + import databases import pydantic import pytest @@ -30,6 +32,16 @@ class AuditLog(ormar.Model): event_log: pydantic.Json = ormar.JSON() +class Cover(ormar.Model): + class Meta: + tablename = "covers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=100) + + class Album(ormar.Model): class Meta: tablename = "albums" @@ -39,6 +51,7 @@ class Album(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) is_best_seller: bool = ormar.Boolean(default=False) + cover: Optional[Cover] = ormar.ForeignKey(Cover) @pytest.fixture(autouse=True, scope="module") @@ -57,7 +70,6 @@ def test_passing_not_callable(): def test_passing_callable_without_kwargs(): with pytest.raises(SignalDefinitionError): - @pre_save(Album) def trigger(sender, instance): # pragma: no cover pass @@ -67,7 +79,6 @@ def test_passing_callable_without_kwargs(): async def test_signal_functions(): async with database: async with database.transaction(force_rollback=True): - @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -151,9 +162,9 @@ async def test_signal_functions(): assert len(audits) == 2 assert audits[0].event_type == "PRE_DELETE_album" assert ( - audits[0].event_log.get("id") - == audits[1].event_log.get("id") - == album.id + audits[0].event_log.get("id") + == audits[1].event_log.get("id") + == album.id ) assert audits[1].event_type == "POST_DELETE_album" @@ -167,7 +178,6 @@ async def test_signal_functions(): async def test_multiple_signals(): async with database: async with database.transaction(force_rollback=True): - @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -198,7 +208,6 @@ async def test_multiple_signals(): async def test_static_methods_as_signals(): async with database: async with database.transaction(force_rollback=True): - class AlbumAuditor: event_type = "ALBUM_INSTANCE" @@ -223,7 +232,6 @@ async def test_static_methods_as_signals(): async def test_methods_as_signals(): async with database: async with database.transaction(force_rollback=True): - class AlbumAuditor: def __init__(self): self.event_type = "ALBUM_INSTANCE" @@ -243,3 +251,27 @@ async def test_methods_as_signals(): assert audits[0].event_log.get("name") == album.name album.signals.pre_save.disconnect(auditor.before_save) + +@pytest.mark.asyncio +async def test_multiple_senders_signal(): + async with database: + async with database.transaction(force_rollback=True): + @pre_save([Album, Cover]) + async def before_save(sender, instance, **kwargs): + await AuditLog( + event_type=f"PRE_SAVE_{sender.get_name()}", + event_log=instance.json(), + ).save() + + cover = await Cover(title='Blue').save() + album = await Album.objects.create(name="San Francisco", cover=cover) + + audits = await AuditLog.objects.all() + assert len(audits) == 2 + assert audits[0].event_type == "PRE_SAVE_cover" + assert audits[0].event_log.get("title") == cover.title + assert audits[1].event_type == "PRE_SAVE_album" + assert audits[1].event_log.get("cover") == album.cover.dict(exclude={'albums'}) + + album.signals.pre_save.disconnect(before_save) + cover.signals.pre_save.disconnect(before_save) From 9f86e1d46eece44c5c2ad3d31a5c8152185ed778 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 6 Dec 2020 19:45:09 +0100 Subject: [PATCH 5/7] bump version, more tests, update docs --- docs/releases.md | 6 +- docs/signals.md | 249 ++++++++++++++++++++++ docs_src/signals/__init__.py | 0 docs_src/signals/docs002.py | 22 ++ mkdocs.yml | 1 + ormar/__init__.py | 26 ++- ormar/decorators/__init__.py | 14 ++ ormar/decorators/signals.py | 42 ++-- tests/test_excluding_fields_in_fastapi.py | 18 +- tests/test_signals.py | 85 +++++++- 10 files changed, 419 insertions(+), 44 deletions(-) create mode 100644 docs/signals.md create mode 100644 docs_src/signals/__init__.py create mode 100644 docs_src/signals/docs002.py diff --git a/docs/releases.md b/docs/releases.md index bd8f3db..1ab8e39 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -2,9 +2,11 @@ * **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError` * **Breaking:** Model initialization with unknown field name now raises `ModelError` instead of `KeyError` -* -* Add py.typed and modify setup.py for mypy support +* Added **Signals**, with pre-defined list signals and decorators: `post_delete`, `post_save`, `post_update`, `pre_delete`, +`pre_save`, `pre_update` +* Add `py.typed` and modify `setup.py` for mypy support * Performance optimization +* Updated docs # 0.6.2 diff --git a/docs/signals.md b/docs/signals.md new file mode 100644 index 0000000..14286ca --- /dev/null +++ b/docs/signals.md @@ -0,0 +1,249 @@ +# Signals + +Signals are a mechanism to fire your piece of code (function / method) whenever given type of event happens in `ormar`. + +To achieve this you need to register your receiver for a given type of signal for selected model(s). + +## Defining receivers + +Given a sample model like following: + +```Python +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + is_best_seller: bool = ormar.Boolean(default=False) + play_count: int = ormar.Integer(default=0) +``` + +You can for example define a trigger that will set `album.is_best_seller` status if it will be played more than 50 times. + +Import `pre_update` decorator, for list of currently available decorators/ signals check below. + +```Python hl_lines="1" +--8<-- "../docs_src/signals/docs002.py" +``` + +Define your function. + +Note that each receiver function: + +* has to be **callable** +* has to accept first **`sender`** argument that receives the class of sending object +* has to accept **`**kwargs`** argument as the parameters send in each `ormar.Signal` can change at any time so your function has to serve them. +* has to be **`async`** cause callbacks are gathered and awaited. + +`pre_update` currently sends only one argument apart from `sender` and it's `instance` one. + +Note how `pre_update` decorator accepts a `senders` argument that can be a single model or a list of models, +for which you want to run the signal receiver. + +Currently there is no way to set signal for all models at once without explicitly passing them all into registration of receiver. + +```Python hl_lines="4-7" +--8<-- "../docs_src/signals/docs002.py" +``` + +!!!note + Note that receivers are defined on a class level -> so even if you connect/disconnect function through instance + it will run/ stop running for all operations on that `ormar.Model` class. + +Note that our newly created function has instance and class of the instance so you can easily run database +queries inside your receivers if you want to. + +```Python hl_lines="15-22" +--8<-- "../docs_src/signals/docs002.py" +``` + +You can define same receiver for multiple models at once by passing a list of models to signal decorator. + +```python +# define a dummy debug function +@pre_update([Album, Track]) +async def before_update(sender, instance, **kwargs): + print(f"{sender.get_name()}: {instance.json()}: {kwargs}") +``` + +Of course you can also create multiple functions for the same signal and model. Each of them will run at each signal. + +```python +@pre_update(Album) +async def before_update(sender, instance, **kwargs): + print(f"{sender.get_name()}: {instance.json()}: {kwargs}") + +@pre_update(Album) +async def before_update2(sender, instance, **kwargs): + print(f'About to update {sender.get_name()} with pk: {instance.pk}') +``` + +Note that `ormar` decorators are the syntactic sugar, you can directly connect your function or method for given signal for +given model. Connect accept only one parameter - your `receiver` function / method. + +```python hl_lines="11 13 16" +class AlbumAuditor: + def __init__(self): + self.event_type = "ALBUM_INSTANCE" + + async def before_save(self, sender, instance, **kwargs): + await AuditLog( + event_type=f"{self.event_type}_SAVE", event_log=instance.json() + ).save() + +auditor = AlbumAuditor() +pre_save(Album)(auditor.before_save) +# call above has same result like the one below +Album.Meta.signals.pre_save.connect(auditor.before_save) +# signals are also exposed on instance +album = Album(name='Miami') +album.signals.pre_save.connect(auditor.before_save) +``` + +!!!warning + Note that signals keep the reference to your receiver (not a `weakref`) so keep that in mind to avoid circular references. + +## Disconnecting the receivers + +To disconnect the receiver and stop it for running for given model you need to disconnect it. + +```python hl_lines="7 10" + +@pre_update(Album) +async def before_update(sender, instance, **kwargs): + if instance.play_count > 50 and not instance.is_best_seller: + instance.is_best_seller = True + +# disconnect given function from signal for given Model +Album.Meta.signals.pre_save.disconnect(before_save) +# signals are also exposed on instance +album = Album(name='Miami') +album.signals.pre_save.disconnect(before_save) +``` + + +## Available signals + +!!!warning + Note that signals are **not** send for: + + * bulk operations (`QuerySet.bulk_create` and `QuerySet.bulk_update`) as they are designed for speed. + + * queyset table level operations (`QuerySet.update` and `QuerySet.delete`) as they run on the underlying tables + (more lak raw sql update/delete operations) and do not have specific instance. + +### pre_save + +`pre_save(sender: Type["Model"], instance: "Model")` + +Send for `Model.save()` and `Model.objects.create()` methods. + +`sender` is a `ormar.Model` class and `instance` is the model to be saved. + +### post_save + +`post_save(sender: Type["Model"], instance: "Model")` + +Send for `Model.save()` and `Model.objects.create()` methods. + +`sender` is a `ormar.Model` class and `instance` is the model that was saved. + +### pre_update + +`pre_update(sender: Type["Model"], instance: "Model")` + +Send for `Model.update()` method. + +`sender` is a `ormar.Model` class and `instance` is the model to be updated. + +### post_update + +`post_update(sender: Type["Model"], instance: "Model")` + +Send for `Model.update()` method. + +`sender` is a `ormar.Model` class and `instance` is the model that was updated. + +### pre_delete + +`pre_delete(sender: Type["Model"], instance: "Model")` + +Send for `Model.save()` and `Model.objects.create()` methods. + +`sender` is a `ormar.Model` class and `instance` is the model to be deleted. + +### post_delete + +`post_delete(sender: Type["Model"], instance: "Model")` + +Send for `Model.update()` method. + +`sender` is a `ormar.Model` class and `instance` is the model that was deleted. + +## Defining your own signals + +Note that you can create your own signals although you will have to send them manually in your code or subclass `ormar.Model` +and trigger your signals there. + +Creating new signal is super easy. Following example will set a new signal with name your_custom_signal. + +```python hl_lines="21" +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + is_best_seller: bool = ormar.Boolean(default=False) + play_count: int = ormar.Integer(default=0) + +Album.Meta.signals.your_custom_signal = ormar.Signal() +Album.Meta.signals.your_custom_signal.connect(your_receiver_name) +``` + +Actually under the hood signal is a `SignalEmitter` instance that keeps a dictionary of know signals, and allows you +to access them as attributes. When you try to access a signal that does not exist `SignalEmitter` will create one for you. + +So example above can be simplified to. The `Signal` will be created for you. + +``` +Album.Meta.signals.your_custom_signal.connect(your_receiver_name) +``` + +Now to trigger this signal you need to call send method of the Signal. + +```python +await Album.Meta.signals.your_custom_signal.send(sender=Album) +``` + +Note that sender is the only required parameter and it should be ormar Model class. + +Additional parameters have to be passed as keyword arguments. + +```python +await Album.Meta.signals.your_custom_signal.send(sender=Album, my_param=True) +``` + diff --git a/docs_src/signals/__init__.py b/docs_src/signals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docs_src/signals/docs002.py b/docs_src/signals/docs002.py new file mode 100644 index 0000000..5b1a15f --- /dev/null +++ b/docs_src/signals/docs002.py @@ -0,0 +1,22 @@ +from ormar import pre_update + + +@pre_update(Album) +async def before_update(sender, instance, **kwargs): + if instance.play_count > 50 and not instance.is_best_seller: + instance.is_best_seller = True + + +# here album.play_count ans is_best_seller get default values +album = await Album.objects.create(name="Venice") +assert not album.is_best_seller +assert album.play_count == 0 + +album.play_count = 30 +# here a trigger is called but play_count is too low +await album.update() +assert not album.is_best_seller + +album.play_count = 60 +await album.update() +assert album.is_best_seller diff --git a/mkdocs.yml b/mkdocs.yml index ccaae7c..60b42b2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -7,6 +7,7 @@ nav: - Fields: fields.md - Relations: relations.md - Queries: queries.md + - Signals: signals.md - Use with Fastapi: fastapi.md - Use with mypy: mypy.md - PyCharm plugin: plugin.md diff --git a/ormar/__init__.py b/ormar/__init__.py index 036e372..5891f71 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -1,6 +1,18 @@ -from ormar.decorators import property_field -from ormar.exceptions import ModelDefinitionError, MultipleMatches, NoMatch +from ormar.decorators import ( + post_delete, + post_save, + post_update, + pre_delete, + pre_save, + pre_update, + property_field, +) from ormar.protocols import QuerySetProtocol, RelationProtocol # noqa: I100 +from ormar.exceptions import ( # noqa: I100 + ModelDefinitionError, + MultipleMatches, + NoMatch, +) from ormar.fields import ( # noqa: I100 BigInteger, Boolean, @@ -22,6 +34,7 @@ from ormar.models import Model from ormar.models.metaclass import ModelMeta from ormar.queryset import QuerySet from ormar.relations import RelationType +from ormar.signals import Signal class UndefinedType: # pragma no cover @@ -31,7 +44,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.6.2" +__version__ = "0.7.0" __all__ = [ "Integer", "BigInteger", @@ -59,4 +72,11 @@ __all__ = [ "RelationProtocol", "ModelMeta", "property_field", + "post_delete", + "post_save", + "post_update", + "pre_delete", + "pre_save", + "pre_update", + "Signal", ] diff --git a/ormar/decorators/__init__.py b/ormar/decorators/__init__.py index 7dfbe5e..395e3e2 100644 --- a/ormar/decorators/__init__.py +++ b/ormar/decorators/__init__.py @@ -1,5 +1,19 @@ from ormar.decorators.property_field import property_field +from ormar.decorators.signals import ( + post_delete, + post_save, + post_update, + pre_delete, + pre_save, + pre_update, +) __all__ = [ "property_field", + "post_delete", + "post_save", + "post_update", + "pre_delete", + "pre_save", + "pre_update", ] diff --git a/ormar/decorators/signals.py b/ormar/decorators/signals.py index d149f97..0505e1a 100644 --- a/ormar/decorators/signals.py +++ b/ormar/decorators/signals.py @@ -1,11 +1,11 @@ -from typing import Any, Callable, List, TYPE_CHECKING, Type, Union +from typing import Callable, List, TYPE_CHECKING, Type, Union if TYPE_CHECKING: # pragma: no cover from ormar import Model def receiver( - signal: str, senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any + signal: str, senders: Union[Type["Model"], List[Type["Model"]]] ) -> Callable: def _decorator(func: Callable) -> Callable: if not isinstance(senders, list): @@ -14,43 +14,31 @@ def receiver( _senders = senders for sender in _senders: signals = getattr(sender.Meta.signals, signal) - signals.connect(func, **kwargs) + signals.connect(func) return func return _decorator -def post_save( - senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any -) -> Callable: - return receiver(signal="post_save", senders=senders, **kwargs) +def post_save(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable: + return receiver(signal="post_save", senders=senders) -def post_update( - senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any -) -> Callable: - return receiver(signal="post_update", senders=senders, **kwargs) +def post_update(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable: + return receiver(signal="post_update", senders=senders) -def post_delete( - senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any -) -> Callable: - return receiver(signal="post_delete", senders=senders, **kwargs) +def post_delete(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable: + return receiver(signal="post_delete", senders=senders) -def pre_save( - senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any -) -> Callable: - return receiver(signal="pre_save", senders=senders, **kwargs) +def pre_save(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable: + return receiver(signal="pre_save", senders=senders) -def pre_update( - senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any -) -> Callable: - return receiver(signal="pre_update", senders=senders, **kwargs) +def pre_update(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: + return receiver(signal="pre_update", senders=senders) -def pre_delete( - senders: Union[Type["Model"], List[Type["Model"]]], **kwargs: Any -) -> Callable: - return receiver(signal="pre_delete", senders=senders, **kwargs) +def pre_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: + return receiver(signal="pre_delete", senders=senders) diff --git a/tests/test_excluding_fields_in_fastapi.py b/tests/test_excluding_fields_in_fastapi.py index 3e3349b..6568d9b 100644 --- a/tests/test_excluding_fields_in_fastapi.py +++ b/tests/test_excluding_fields_in_fastapi.py @@ -10,7 +10,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient import ormar -from ormar import property_field +from ormar import post_save, property_field from tests.settings import DATABASE_URL app = FastAPI() @@ -65,8 +65,6 @@ class RandomModel(ormar.Model): metadata = metadata database = database - include_props_in_dict = True - id: int = ormar.Integer(primary_key=True) password: str = ormar.String(max_length=255, default=gen_pass) first_name: str = ormar.String(max_length=255, default="John") @@ -251,7 +249,6 @@ def test_adding_fields_in_endpoints(): ] assert response.json().get("full_name") == "John Test" - RandomModel.Meta.include_props_in_fields = False user3 = {"last_name": "Test"} response = client.post("/random/", json=user3) assert list(response.json().keys()) == [ @@ -268,7 +265,6 @@ def test_adding_fields_in_endpoints(): def test_adding_fields_in_endpoints2(): client = TestClient(app) with client as client: - RandomModel.Meta.include_props_in_dict = True user3 = {"last_name": "Test"} response = client.post("/random2/", json=user3) assert list(response.json().keys()) == [ @@ -283,9 +279,15 @@ def test_adding_fields_in_endpoints2(): def test_excluding_property_field_in_endpoints2(): + + dummy_registry = {} + + @post_save(RandomModel) + async def after_save(sender, instance, **kwargs): + dummy_registry[instance.pk] = instance.dict() + client = TestClient(app) with client as client: - RandomModel.Meta.include_props_in_dict = True user3 = {"last_name": "Test"} response = client.post("/random3/", json=user3) assert list(response.json().keys()) == [ @@ -296,3 +298,7 @@ def test_excluding_property_field_in_endpoints2(): "created_date", ] assert response.json().get("full_name") is None + assert len(dummy_registry) == 1 + check_dict = dummy_registry.get(response.json().get("id")) + check_dict.pop("full_name") + assert response.json().get("password") == check_dict.get("password") diff --git a/tests/test_signals.py b/tests/test_signals.py index bc72706..72eac77 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -6,7 +6,7 @@ import pytest import sqlalchemy import ormar -from ormar.decorators.signals import ( +from ormar import ( post_delete, post_save, post_update, @@ -51,6 +51,7 @@ class Album(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) is_best_seller: bool = ormar.Boolean(default=False) + play_count: int = ormar.Integer(default=0) cover: Optional[Cover] = ormar.ForeignKey(Cover) @@ -70,6 +71,7 @@ def test_passing_not_callable(): def test_passing_callable_without_kwargs(): with pytest.raises(SignalDefinitionError): + @pre_save(Album) def trigger(sender, instance): # pragma: no cover pass @@ -79,6 +81,7 @@ def test_passing_callable_without_kwargs(): async def test_signal_functions(): async with database: async with database.transaction(force_rollback=True): + @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -162,9 +165,9 @@ async def test_signal_functions(): assert len(audits) == 2 assert audits[0].event_type == "PRE_DELETE_album" assert ( - audits[0].event_log.get("id") - == audits[1].event_log.get("id") - == album.id + audits[0].event_log.get("id") + == audits[1].event_log.get("id") + == album.id ) assert audits[1].event_type == "POST_DELETE_album" @@ -178,6 +181,7 @@ async def test_signal_functions(): async def test_multiple_signals(): async with database: async with database.transaction(force_rollback=True): + @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -208,6 +212,7 @@ async def test_multiple_signals(): async def test_static_methods_as_signals(): async with database: async with database.transaction(force_rollback=True): + class AlbumAuditor: event_type = "ALBUM_INSTANCE" @@ -232,6 +237,7 @@ async def test_static_methods_as_signals(): async def test_methods_as_signals(): async with database: async with database.transaction(force_rollback=True): + class AlbumAuditor: def __init__(self): self.event_type = "ALBUM_INSTANCE" @@ -252,10 +258,12 @@ async def test_methods_as_signals(): album.signals.pre_save.disconnect(auditor.before_save) + @pytest.mark.asyncio async def test_multiple_senders_signal(): async with database: async with database.transaction(force_rollback=True): + @pre_save([Album, Cover]) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -263,7 +271,7 @@ async def test_multiple_senders_signal(): event_log=instance.json(), ).save() - cover = await Cover(title='Blue').save() + cover = await Cover(title="Blue").save() album = await Album.objects.create(name="San Francisco", cover=cover) audits = await AuditLog.objects.all() @@ -271,7 +279,72 @@ async def test_multiple_senders_signal(): assert audits[0].event_type == "PRE_SAVE_cover" assert audits[0].event_log.get("title") == cover.title assert audits[1].event_type == "PRE_SAVE_album" - assert audits[1].event_log.get("cover") == album.cover.dict(exclude={'albums'}) + assert audits[1].event_log.get("cover") == album.cover.dict( + exclude={"albums"} + ) album.signals.pre_save.disconnect(before_save) cover.signals.pre_save.disconnect(before_save) + + +@pytest.mark.asyncio +async def test_modifing_the_instance(): + async with database: + async with database.transaction(force_rollback=True): + + @pre_update(Album) + async def before_update(sender, instance, **kwargs): + if instance.play_count > 50 and not instance.is_best_seller: + instance.is_best_seller = True + + # here album.play_count ans is_best_seller get default values + album = await Album.objects.create(name="Venice") + assert not album.is_best_seller + assert album.play_count == 0 + + album.play_count = 30 + # here a trigger is called but play_count is too low + await album.update() + assert not album.is_best_seller + + album.play_count = 60 + await album.update() + assert album.is_best_seller + album.signals.pre_update.disconnect(before_update) + + +@pytest.mark.asyncio +async def test_custom_signal(): + async with database: + async with database.transaction(force_rollback=True): + + async def after_update(sender, instance, **kwargs): + if instance.play_count > 50 and not instance.is_best_seller: + instance.is_best_seller = True + elif instance.play_count < 50 and instance.is_best_seller: + instance.is_best_seller = False + await instance.update() + + Album.Meta.signals.custom.connect(after_update) + + # here album.play_count ans is_best_seller get default values + album = await Album.objects.create(name="Venice") + assert not album.is_best_seller + assert album.play_count == 0 + + album.play_count = 30 + # here a trigger is called but play_count is too low + await album.update() + assert not album.is_best_seller + + album.play_count = 60 + await album.update() + assert not album.is_best_seller + await Album.Meta.signals.custom.send(sender=Album, instance=album) + assert album.is_best_seller + + album.play_count = 30 + await album.update() + assert album.is_best_seller + await Album.Meta.signals.custom.send(sender=Album, instance=album) + assert not album.is_best_seller From a24f1b923bea9fce3b70e2406f2a698778e4d9f7 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 6 Dec 2020 19:50:35 +0100 Subject: [PATCH 6/7] update readme --- README.md | 12 ++++++++++++ docs/index.md | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/README.md b/README.md index f7adfbe..3d4546f 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,7 @@ The following keyword arguments are supported on all field types. * `unique: bool` * `choices: typing.Sequence` * `name: str` + * `pydantic_only: bool` All fields are required unless one of the following is set: @@ -211,7 +212,18 @@ All fields are required unless one of the following is set: * `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). * `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. Autoincrement is set by default on int primary keys. + * `pydantic_only` - Field is available only as normal pydantic field, not stored in the database. + +### Available signals +Signals allow to trigger your function for a given event on a given Model. + + * `pre_save` + * `post_save` + * `pre_update` + * `post_update` + * `pre_delete` + * `post_delete` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ diff --git a/docs/index.md b/docs/index.md index f7adfbe..3d4546f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -203,6 +203,7 @@ The following keyword arguments are supported on all field types. * `unique: bool` * `choices: typing.Sequence` * `name: str` + * `pydantic_only: bool` All fields are required unless one of the following is set: @@ -211,7 +212,18 @@ All fields are required unless one of the following is set: * `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). * `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. Autoincrement is set by default on int primary keys. + * `pydantic_only` - Field is available only as normal pydantic field, not stored in the database. + +### Available signals +Signals allow to trigger your function for a given event on a given Model. + + * `pre_save` + * `post_save` + * `pre_update` + * `post_update` + * `pre_delete` + * `post_delete` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ From 7dd892edb5799f682cff5fcbb1a997ee47973b6e Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 6 Dec 2020 19:59:43 +0100 Subject: [PATCH 7/7] add cleanup in tests --- tests/test_signals.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/tests/test_signals.py b/tests/test_signals.py index 72eac77..5670980 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -64,6 +64,13 @@ def create_test_database(): metadata.drop_all(engine) +@pytest.fixture(scope="function") +async def cleanup(): + yield + async with database: + await AuditLog.objects.delete(each=True) + + def test_passing_not_callable(): with pytest.raises(SignalDefinitionError): pre_save(Album)("wrong") @@ -71,17 +78,15 @@ def test_passing_not_callable(): def test_passing_callable_without_kwargs(): with pytest.raises(SignalDefinitionError): - @pre_save(Album) def trigger(sender, instance): # pragma: no cover pass @pytest.mark.asyncio -async def test_signal_functions(): +async def test_signal_functions(cleanup): async with database: async with database.transaction(force_rollback=True): - @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -165,9 +170,9 @@ async def test_signal_functions(): assert len(audits) == 2 assert audits[0].event_type == "PRE_DELETE_album" assert ( - audits[0].event_log.get("id") - == audits[1].event_log.get("id") - == album.id + audits[0].event_log.get("id") + == audits[1].event_log.get("id") + == album.id ) assert audits[1].event_type == "POST_DELETE_album" @@ -178,10 +183,9 @@ async def test_signal_functions(): @pytest.mark.asyncio -async def test_multiple_signals(): +async def test_multiple_signals(cleanup): async with database: async with database.transaction(force_rollback=True): - @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -192,7 +196,7 @@ async def test_multiple_signals(): @pre_save(Album) async def before_save2(sender, instance, **kwargs): await AuditLog( - event_type=f"PRE_SAVE2_{sender.get_name()}", + event_type=f"PRE_SAVE_{sender.get_name()}", event_log=instance.json(), ).save() @@ -201,7 +205,7 @@ async def test_multiple_signals(): assert len(audits) == 2 assert audits[0].event_type == "PRE_SAVE_album" assert audits[0].event_log.get("name") == album.name - assert audits[1].event_type == "PRE_SAVE2_album" + assert audits[1].event_type == "PRE_SAVE_album" assert audits[1].event_log.get("name") == album.name album.signals.pre_save.disconnect(before_save) @@ -209,10 +213,9 @@ async def test_multiple_signals(): @pytest.mark.asyncio -async def test_static_methods_as_signals(): +async def test_static_methods_as_signals(cleanup): async with database: async with database.transaction(force_rollback=True): - class AlbumAuditor: event_type = "ALBUM_INSTANCE" @@ -234,10 +237,9 @@ async def test_static_methods_as_signals(): @pytest.mark.asyncio -async def test_methods_as_signals(): +async def test_methods_as_signals(cleanup): async with database: async with database.transaction(force_rollback=True): - class AlbumAuditor: def __init__(self): self.event_type = "ALBUM_INSTANCE" @@ -260,10 +262,9 @@ async def test_methods_as_signals(): @pytest.mark.asyncio -async def test_multiple_senders_signal(): +async def test_multiple_senders_signal(cleanup): async with database: async with database.transaction(force_rollback=True): - @pre_save([Album, Cover]) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -288,10 +289,9 @@ async def test_multiple_senders_signal(): @pytest.mark.asyncio -async def test_modifing_the_instance(): +async def test_modifing_the_instance(cleanup): async with database: async with database.transaction(force_rollback=True): - @pre_update(Album) async def before_update(sender, instance, **kwargs): if instance.play_count > 50 and not instance.is_best_seller: @@ -314,7 +314,7 @@ async def test_modifing_the_instance(): @pytest.mark.asyncio -async def test_custom_signal(): +async def test_custom_signal(cleanup): async with database: async with database.transaction(force_rollback=True): @@ -348,3 +348,5 @@ async def test_custom_signal(): assert album.is_best_seller await Album.Meta.signals.custom.send(sender=Album, instance=album) assert not album.is_best_seller + + Album.Meta.signals.custom.disconnect(after_update)