12
README.md
12
README.md
@ -203,6 +203,7 @@ The following keyword arguments are supported on all field types.
|
||||
* `unique: bool`
|
||||
* `choices: typing.Sequence`
|
||||
* `name: str`
|
||||
* `pydantic_only: bool`
|
||||
|
||||
All fields are required unless one of the following is set:
|
||||
|
||||
@ -211,7 +212,18 @@ All fields are required unless one of the following is set:
|
||||
* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`).
|
||||
* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column.
|
||||
Autoincrement is set by default on int primary keys.
|
||||
* `pydantic_only` - Field is available only as normal pydantic field, not stored in the database.
|
||||
|
||||
### Available signals
|
||||
|
||||
Signals allow to trigger your function for a given event on a given Model.
|
||||
|
||||
* `pre_save`
|
||||
* `post_save`
|
||||
* `pre_update`
|
||||
* `post_update`
|
||||
* `pre_delete`
|
||||
* `post_delete`
|
||||
|
||||
|
||||
[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
|
||||
|
||||
@ -203,6 +203,7 @@ The following keyword arguments are supported on all field types.
|
||||
* `unique: bool`
|
||||
* `choices: typing.Sequence`
|
||||
* `name: str`
|
||||
* `pydantic_only: bool`
|
||||
|
||||
All fields are required unless one of the following is set:
|
||||
|
||||
@ -211,7 +212,18 @@ All fields are required unless one of the following is set:
|
||||
* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`).
|
||||
* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column.
|
||||
Autoincrement is set by default on int primary keys.
|
||||
* `pydantic_only` - Field is available only as normal pydantic field, not stored in the database.
|
||||
|
||||
### Available signals
|
||||
|
||||
Signals allow to trigger your function for a given event on a given Model.
|
||||
|
||||
* `pre_save`
|
||||
* `post_save`
|
||||
* `pre_update`
|
||||
* `post_update`
|
||||
* `pre_delete`
|
||||
* `post_delete`
|
||||
|
||||
|
||||
[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
|
||||
|
||||
@ -1,3 +1,13 @@
|
||||
# 0.7.0
|
||||
|
||||
* **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError`
|
||||
* **Breaking:** Model initialization with unknown field name now raises `ModelError` instead of `KeyError`
|
||||
* Added **Signals**, with pre-defined list signals and decorators: `post_delete`, `post_save`, `post_update`, `pre_delete`,
|
||||
`pre_save`, `pre_update`
|
||||
* Add `py.typed` and modify `setup.py` for mypy support
|
||||
* Performance optimization
|
||||
* Updated docs
|
||||
|
||||
# 0.6.2
|
||||
|
||||
* Performance optimization
|
||||
@ -12,7 +22,7 @@
|
||||
|
||||
# 0.6.0
|
||||
|
||||
* **Breaking:** calling instance.load() when the instance row was deleted from db now raises ormar.NoMatch instead of ValueError
|
||||
* **Breaking:** calling instance.load() when the instance row was deleted from db now raises `NoMatch` instead of `ValueError`
|
||||
* **Breaking:** calling add and remove on ReverseForeignKey relation now updates the child model in db setting/removing fk column
|
||||
* **Breaking:** ReverseForeignKey relation now exposes QuerySetProxy API like ManyToMany relation
|
||||
* **Breaking:** querying related models from ManyToMany cleans list of related models loaded on parent model:
|
||||
|
||||
249
docs/signals.md
Normal file
249
docs/signals.md
Normal file
@ -0,0 +1,249 @@
|
||||
# Signals
|
||||
|
||||
Signals are a mechanism to fire your piece of code (function / method) whenever given type of event happens in `ormar`.
|
||||
|
||||
To achieve this you need to register your receiver for a given type of signal for selected model(s).
|
||||
|
||||
## Defining receivers
|
||||
|
||||
Given a sample model like following:
|
||||
|
||||
```Python
|
||||
import databases
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
|
||||
database = databases.Database("sqlite:///db.sqlite")
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class Album(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "albums"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100)
|
||||
is_best_seller: bool = ormar.Boolean(default=False)
|
||||
play_count: int = ormar.Integer(default=0)
|
||||
```
|
||||
|
||||
You can for example define a trigger that will set `album.is_best_seller` status if it will be played more than 50 times.
|
||||
|
||||
Import `pre_update` decorator, for list of currently available decorators/ signals check below.
|
||||
|
||||
```Python hl_lines="1"
|
||||
--8<-- "../docs_src/signals/docs002.py"
|
||||
```
|
||||
|
||||
Define your function.
|
||||
|
||||
Note that each receiver function:
|
||||
|
||||
* has to be **callable**
|
||||
* has to accept first **`sender`** argument that receives the class of sending object
|
||||
* has to accept **`**kwargs`** argument as the parameters send in each `ormar.Signal` can change at any time so your function has to serve them.
|
||||
* has to be **`async`** cause callbacks are gathered and awaited.
|
||||
|
||||
`pre_update` currently sends only one argument apart from `sender` and it's `instance` one.
|
||||
|
||||
Note how `pre_update` decorator accepts a `senders` argument that can be a single model or a list of models,
|
||||
for which you want to run the signal receiver.
|
||||
|
||||
Currently there is no way to set signal for all models at once without explicitly passing them all into registration of receiver.
|
||||
|
||||
```Python hl_lines="4-7"
|
||||
--8<-- "../docs_src/signals/docs002.py"
|
||||
```
|
||||
|
||||
!!!note
|
||||
Note that receivers are defined on a class level -> so even if you connect/disconnect function through instance
|
||||
it will run/ stop running for all operations on that `ormar.Model` class.
|
||||
|
||||
Note that our newly created function has instance and class of the instance so you can easily run database
|
||||
queries inside your receivers if you want to.
|
||||
|
||||
```Python hl_lines="15-22"
|
||||
--8<-- "../docs_src/signals/docs002.py"
|
||||
```
|
||||
|
||||
You can define same receiver for multiple models at once by passing a list of models to signal decorator.
|
||||
|
||||
```python
|
||||
# define a dummy debug function
|
||||
@pre_update([Album, Track])
|
||||
async def before_update(sender, instance, **kwargs):
|
||||
print(f"{sender.get_name()}: {instance.json()}: {kwargs}")
|
||||
```
|
||||
|
||||
Of course you can also create multiple functions for the same signal and model. Each of them will run at each signal.
|
||||
|
||||
```python
|
||||
@pre_update(Album)
|
||||
async def before_update(sender, instance, **kwargs):
|
||||
print(f"{sender.get_name()}: {instance.json()}: {kwargs}")
|
||||
|
||||
@pre_update(Album)
|
||||
async def before_update2(sender, instance, **kwargs):
|
||||
print(f'About to update {sender.get_name()} with pk: {instance.pk}')
|
||||
```
|
||||
|
||||
Note that `ormar` decorators are the syntactic sugar, you can directly connect your function or method for given signal for
|
||||
given model. Connect accept only one parameter - your `receiver` function / method.
|
||||
|
||||
```python hl_lines="11 13 16"
|
||||
class AlbumAuditor:
|
||||
def __init__(self):
|
||||
self.event_type = "ALBUM_INSTANCE"
|
||||
|
||||
async def before_save(self, sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"{self.event_type}_SAVE", event_log=instance.json()
|
||||
).save()
|
||||
|
||||
auditor = AlbumAuditor()
|
||||
pre_save(Album)(auditor.before_save)
|
||||
# call above has same result like the one below
|
||||
Album.Meta.signals.pre_save.connect(auditor.before_save)
|
||||
# signals are also exposed on instance
|
||||
album = Album(name='Miami')
|
||||
album.signals.pre_save.connect(auditor.before_save)
|
||||
```
|
||||
|
||||
!!!warning
|
||||
Note that signals keep the reference to your receiver (not a `weakref`) so keep that in mind to avoid circular references.
|
||||
|
||||
## Disconnecting the receivers
|
||||
|
||||
To disconnect the receiver and stop it for running for given model you need to disconnect it.
|
||||
|
||||
```python hl_lines="7 10"
|
||||
|
||||
@pre_update(Album)
|
||||
async def before_update(sender, instance, **kwargs):
|
||||
if instance.play_count > 50 and not instance.is_best_seller:
|
||||
instance.is_best_seller = True
|
||||
|
||||
# disconnect given function from signal for given Model
|
||||
Album.Meta.signals.pre_save.disconnect(before_save)
|
||||
# signals are also exposed on instance
|
||||
album = Album(name='Miami')
|
||||
album.signals.pre_save.disconnect(before_save)
|
||||
```
|
||||
|
||||
|
||||
## Available signals
|
||||
|
||||
!!!warning
|
||||
Note that signals are **not** send for:
|
||||
|
||||
* bulk operations (`QuerySet.bulk_create` and `QuerySet.bulk_update`) as they are designed for speed.
|
||||
|
||||
* queyset table level operations (`QuerySet.update` and `QuerySet.delete`) as they run on the underlying tables
|
||||
(more lak raw sql update/delete operations) and do not have specific instance.
|
||||
|
||||
### pre_save
|
||||
|
||||
`pre_save(sender: Type["Model"], instance: "Model")`
|
||||
|
||||
Send for `Model.save()` and `Model.objects.create()` methods.
|
||||
|
||||
`sender` is a `ormar.Model` class and `instance` is the model to be saved.
|
||||
|
||||
### post_save
|
||||
|
||||
`post_save(sender: Type["Model"], instance: "Model")`
|
||||
|
||||
Send for `Model.save()` and `Model.objects.create()` methods.
|
||||
|
||||
`sender` is a `ormar.Model` class and `instance` is the model that was saved.
|
||||
|
||||
### pre_update
|
||||
|
||||
`pre_update(sender: Type["Model"], instance: "Model")`
|
||||
|
||||
Send for `Model.update()` method.
|
||||
|
||||
`sender` is a `ormar.Model` class and `instance` is the model to be updated.
|
||||
|
||||
### post_update
|
||||
|
||||
`post_update(sender: Type["Model"], instance: "Model")`
|
||||
|
||||
Send for `Model.update()` method.
|
||||
|
||||
`sender` is a `ormar.Model` class and `instance` is the model that was updated.
|
||||
|
||||
### pre_delete
|
||||
|
||||
`pre_delete(sender: Type["Model"], instance: "Model")`
|
||||
|
||||
Send for `Model.save()` and `Model.objects.create()` methods.
|
||||
|
||||
`sender` is a `ormar.Model` class and `instance` is the model to be deleted.
|
||||
|
||||
### post_delete
|
||||
|
||||
`post_delete(sender: Type["Model"], instance: "Model")`
|
||||
|
||||
Send for `Model.update()` method.
|
||||
|
||||
`sender` is a `ormar.Model` class and `instance` is the model that was deleted.
|
||||
|
||||
## Defining your own signals
|
||||
|
||||
Note that you can create your own signals although you will have to send them manually in your code or subclass `ormar.Model`
|
||||
and trigger your signals there.
|
||||
|
||||
Creating new signal is super easy. Following example will set a new signal with name your_custom_signal.
|
||||
|
||||
```python hl_lines="21"
|
||||
import databases
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
|
||||
database = databases.Database("sqlite:///db.sqlite")
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class Album(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "albums"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100)
|
||||
is_best_seller: bool = ormar.Boolean(default=False)
|
||||
play_count: int = ormar.Integer(default=0)
|
||||
|
||||
Album.Meta.signals.your_custom_signal = ormar.Signal()
|
||||
Album.Meta.signals.your_custom_signal.connect(your_receiver_name)
|
||||
```
|
||||
|
||||
Actually under the hood signal is a `SignalEmitter` instance that keeps a dictionary of know signals, and allows you
|
||||
to access them as attributes. When you try to access a signal that does not exist `SignalEmitter` will create one for you.
|
||||
|
||||
So example above can be simplified to. The `Signal` will be created for you.
|
||||
|
||||
```
|
||||
Album.Meta.signals.your_custom_signal.connect(your_receiver_name)
|
||||
```
|
||||
|
||||
Now to trigger this signal you need to call send method of the Signal.
|
||||
|
||||
```python
|
||||
await Album.Meta.signals.your_custom_signal.send(sender=Album)
|
||||
```
|
||||
|
||||
Note that sender is the only required parameter and it should be ormar Model class.
|
||||
|
||||
Additional parameters have to be passed as keyword arguments.
|
||||
|
||||
```python
|
||||
await Album.Meta.signals.your_custom_signal.send(sender=Album, my_param=True)
|
||||
```
|
||||
|
||||
0
docs_src/signals/__init__.py
Normal file
0
docs_src/signals/__init__.py
Normal file
22
docs_src/signals/docs002.py
Normal file
22
docs_src/signals/docs002.py
Normal file
@ -0,0 +1,22 @@
|
||||
from ormar import pre_update
|
||||
|
||||
|
||||
@pre_update(Album)
|
||||
async def before_update(sender, instance, **kwargs):
|
||||
if instance.play_count > 50 and not instance.is_best_seller:
|
||||
instance.is_best_seller = True
|
||||
|
||||
|
||||
# here album.play_count ans is_best_seller get default values
|
||||
album = await Album.objects.create(name="Venice")
|
||||
assert not album.is_best_seller
|
||||
assert album.play_count == 0
|
||||
|
||||
album.play_count = 30
|
||||
# here a trigger is called but play_count is too low
|
||||
await album.update()
|
||||
assert not album.is_best_seller
|
||||
|
||||
album.play_count = 60
|
||||
await album.update()
|
||||
assert album.is_best_seller
|
||||
@ -7,6 +7,7 @@ nav:
|
||||
- Fields: fields.md
|
||||
- Relations: relations.md
|
||||
- Queries: queries.md
|
||||
- Signals: signals.md
|
||||
- Use with Fastapi: fastapi.md
|
||||
- Use with mypy: mypy.md
|
||||
- PyCharm plugin: plugin.md
|
||||
|
||||
@ -1,6 +1,18 @@
|
||||
from ormar.decorators import property_field
|
||||
from ormar.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch
|
||||
from ormar.decorators import (
|
||||
post_delete,
|
||||
post_save,
|
||||
post_update,
|
||||
pre_delete,
|
||||
pre_save,
|
||||
pre_update,
|
||||
property_field,
|
||||
)
|
||||
from ormar.protocols import QuerySetProtocol, RelationProtocol # noqa: I100
|
||||
from ormar.exceptions import ( # noqa: I100
|
||||
ModelDefinitionError,
|
||||
MultipleMatches,
|
||||
NoMatch,
|
||||
)
|
||||
from ormar.fields import ( # noqa: I100
|
||||
BigInteger,
|
||||
Boolean,
|
||||
@ -22,6 +34,7 @@ from ormar.models import Model
|
||||
from ormar.models.metaclass import ModelMeta
|
||||
from ormar.queryset import QuerySet
|
||||
from ormar.relations import RelationType
|
||||
from ormar.signals import Signal
|
||||
|
||||
|
||||
class UndefinedType: # pragma no cover
|
||||
@ -31,7 +44,7 @@ class UndefinedType: # pragma no cover
|
||||
|
||||
Undefined = UndefinedType()
|
||||
|
||||
__version__ = "0.6.2"
|
||||
__version__ = "0.7.0"
|
||||
__all__ = [
|
||||
"Integer",
|
||||
"BigInteger",
|
||||
@ -47,7 +60,6 @@ __all__ = [
|
||||
"ManyToMany",
|
||||
"Model",
|
||||
"ModelDefinitionError",
|
||||
"ModelNotSet",
|
||||
"MultipleMatches",
|
||||
"NoMatch",
|
||||
"ForeignKey",
|
||||
@ -60,4 +72,11 @@ __all__ = [
|
||||
"RelationProtocol",
|
||||
"ModelMeta",
|
||||
"property_field",
|
||||
"post_delete",
|
||||
"post_save",
|
||||
"post_update",
|
||||
"pre_delete",
|
||||
"pre_save",
|
||||
"pre_update",
|
||||
"Signal",
|
||||
]
|
||||
|
||||
@ -1,5 +1,19 @@
|
||||
from ormar.decorators.property_field import property_field
|
||||
from ormar.decorators.signals import (
|
||||
post_delete,
|
||||
post_save,
|
||||
post_update,
|
||||
pre_delete,
|
||||
pre_save,
|
||||
pre_update,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"property_field",
|
||||
"post_delete",
|
||||
"post_save",
|
||||
"post_update",
|
||||
"pre_delete",
|
||||
"pre_save",
|
||||
"pre_update",
|
||||
]
|
||||
|
||||
@ -13,7 +13,7 @@ def property_field(func: Callable) -> Union[property, Callable]:
|
||||
if len(arguments) > 1 or arguments[0] != "self":
|
||||
raise ModelDefinitionError(
|
||||
"property_field decorator can be used "
|
||||
"only on class methods with no arguments"
|
||||
"only on methods with no arguments"
|
||||
)
|
||||
func.__dict__["__property_field__"] = True
|
||||
return func
|
||||
|
||||
44
ormar/decorators/signals.py
Normal file
44
ormar/decorators/signals.py
Normal file
@ -0,0 +1,44 @@
|
||||
from typing import Callable, List, TYPE_CHECKING, Type, Union
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ormar import Model
|
||||
|
||||
|
||||
def receiver(
|
||||
signal: str, senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
) -> Callable:
|
||||
def _decorator(func: Callable) -> Callable:
|
||||
if not isinstance(senders, list):
|
||||
_senders = [senders]
|
||||
else:
|
||||
_senders = senders
|
||||
for sender in _senders:
|
||||
signals = getattr(sender.Meta.signals, signal)
|
||||
signals.connect(func)
|
||||
return func
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
def post_save(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable:
|
||||
return receiver(signal="post_save", senders=senders)
|
||||
|
||||
|
||||
def post_update(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable:
|
||||
return receiver(signal="post_update", senders=senders)
|
||||
|
||||
|
||||
def post_delete(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable:
|
||||
return receiver(signal="post_delete", senders=senders)
|
||||
|
||||
|
||||
def pre_save(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable:
|
||||
return receiver(signal="pre_save", senders=senders)
|
||||
|
||||
|
||||
def pre_update(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
return receiver(signal="pre_update", senders=senders)
|
||||
|
||||
|
||||
def pre_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
return receiver(signal="pre_delete", senders=senders)
|
||||
@ -1,28 +1,57 @@
|
||||
class AsyncOrmException(Exception):
|
||||
"""
|
||||
Base ormar Exception
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelDefinitionError(AsyncOrmException):
|
||||
"""
|
||||
Raised for errors related to the model definition itself.
|
||||
* setting @property_field on method with arguments other than func(self)
|
||||
* defining a Field without required parameters
|
||||
* defining a model with more than one primary_key
|
||||
* defining a model without primary_key
|
||||
* setting primary_key column as pydantic_only
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelError(AsyncOrmException):
|
||||
pass
|
||||
"""
|
||||
Raised for initialization of model with non-existing field keyword.
|
||||
"""
|
||||
|
||||
|
||||
class ModelNotSet(AsyncOrmException):
|
||||
pass
|
||||
|
||||
|
||||
class NoMatch(AsyncOrmException):
|
||||
"""
|
||||
Raised for database queries that has no matching result (empty result).
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MultipleMatches(AsyncOrmException):
|
||||
"""
|
||||
Raised for database queries that should return one row (i.e. get, first etc.)
|
||||
but has multiple matching results in response.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QueryDefinitionError(AsyncOrmException):
|
||||
"""
|
||||
Raised for errors in query definition.
|
||||
* using contains or icontains filter with instance of the Model
|
||||
* using Queryset.update() without filter and setting each flag to True
|
||||
* using Queryset.delete() without filter and setting each flag to True
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -31,4 +60,17 @@ class RelationshipInstanceError(AsyncOrmException):
|
||||
|
||||
|
||||
class ModelPersistenceError(AsyncOrmException):
|
||||
"""
|
||||
Raised for update of models without primary_key set (cannot retrieve from db)
|
||||
or for saving a model with relation to unsaved model (cannot extract fk value).
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SignalDefinitionError(AsyncOrmException):
|
||||
"""
|
||||
Raised when non callable receiver is passed as signal callback.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@ -18,6 +18,7 @@ from ormar.fields.many_to_many import ManyToMany, ManyToManyField
|
||||
from ormar.models.quick_access_views import quick_access_set
|
||||
from ormar.queryset import QuerySet
|
||||
from ormar.relations.alias_manager import AliasManager
|
||||
from ormar.signals import Signal, SignalEmitter
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
@ -38,6 +39,7 @@ class ModelMeta:
|
||||
]
|
||||
alias_manager: AliasManager
|
||||
property_fields: Set
|
||||
signals: SignalEmitter
|
||||
|
||||
|
||||
def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None:
|
||||
@ -332,15 +334,12 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
|
||||
new_model._pydantic_fields = {name for name in new_model.__fields__}
|
||||
|
||||
|
||||
def property_fields_not_set(new_model: Type["Model"]) -> bool:
|
||||
return (
|
||||
not hasattr(new_model.Meta, "property_fields")
|
||||
or not new_model.Meta.property_fields
|
||||
)
|
||||
def meta_field_not_set(model: Type["Model"], field_name: str) -> bool:
|
||||
return not hasattr(model.Meta, field_name) or not getattr(model.Meta, field_name)
|
||||
|
||||
|
||||
def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001
|
||||
if property_fields_not_set(new_model):
|
||||
if meta_field_not_set(model=new_model, field_name="property_fields"):
|
||||
props = set()
|
||||
for var_name, value in attrs.items():
|
||||
if isinstance(value, property):
|
||||
@ -351,6 +350,18 @@ def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa:
|
||||
new_model.Meta.property_fields = props
|
||||
|
||||
|
||||
def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
|
||||
if meta_field_not_set(model=new_model, field_name="signals"):
|
||||
signals = SignalEmitter()
|
||||
signals.pre_save = Signal()
|
||||
signals.pre_update = Signal()
|
||||
signals.pre_delete = Signal()
|
||||
signals.post_save = Signal()
|
||||
signals.post_update = Signal()
|
||||
signals.post_delete = Signal()
|
||||
new_model.Meta.signals = signals
|
||||
|
||||
|
||||
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
def __new__( # type: ignore
|
||||
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
|
||||
@ -379,5 +390,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
new_model.Meta.alias_manager = alias_manager
|
||||
new_model.objects = QuerySet(new_model)
|
||||
add_property_fields(new_model, attrs)
|
||||
register_signals(new_model=new_model)
|
||||
|
||||
return new_model
|
||||
|
||||
@ -195,6 +195,9 @@ class Model(NewBaseModel):
|
||||
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
||||
self_fields.pop(self.Meta.pkname, None)
|
||||
self_fields = self.populate_default_values(self_fields)
|
||||
self.from_dict(self_fields)
|
||||
|
||||
await self.signals.pre_save.send(sender=self.__class__, instance=self)
|
||||
|
||||
self_fields = self.translate_columns_to_aliases(self_fields)
|
||||
expr = self.Meta.table.insert()
|
||||
@ -204,6 +207,7 @@ class Model(NewBaseModel):
|
||||
if pk and isinstance(pk, self.pk_type()):
|
||||
setattr(self, self.Meta.pkname, pk)
|
||||
|
||||
self.set_save_status(True)
|
||||
# refresh server side defaults
|
||||
if any(
|
||||
field.server_default is not None
|
||||
@ -211,9 +215,8 @@ class Model(NewBaseModel):
|
||||
if name not in self_fields
|
||||
):
|
||||
await self.load()
|
||||
return self
|
||||
|
||||
self.set_save_status(True)
|
||||
await self.signals.post_save.send(sender=self.__class__, instance=self)
|
||||
return self
|
||||
|
||||
async def save_related( # noqa: CCR001
|
||||
@ -268,6 +271,7 @@ class Model(NewBaseModel):
|
||||
"You cannot update not saved model! Use save or upsert method."
|
||||
)
|
||||
|
||||
await self.signals.pre_update.send(sender=self.__class__, instance=self)
|
||||
self_fields = self._extract_model_db_fields()
|
||||
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
|
||||
self_fields = self.translate_columns_to_aliases(self_fields)
|
||||
@ -276,13 +280,16 @@ class Model(NewBaseModel):
|
||||
|
||||
await self.Meta.database.execute(expr)
|
||||
self.set_save_status(True)
|
||||
await self.signals.post_update.send(sender=self.__class__, instance=self)
|
||||
return self
|
||||
|
||||
async def delete(self: T) -> int:
|
||||
await self.signals.pre_delete.send(sender=self.__class__, instance=self)
|
||||
expr = self.Meta.table.delete()
|
||||
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
|
||||
result = await self.Meta.database.execute(expr)
|
||||
self.set_save_status(False)
|
||||
await self.signals.post_delete.send(sender=self.__class__, instance=self)
|
||||
return result
|
||||
|
||||
async def load(self: T) -> T:
|
||||
|
||||
@ -35,6 +35,7 @@ from ormar.relations.relation_manager import RelationsManager
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
from ormar.signals import SignalEmitter
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
@ -47,11 +48,7 @@ if TYPE_CHECKING: # pragma no cover
|
||||
class NewBaseModel(
|
||||
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
|
||||
):
|
||||
__slots__ = (
|
||||
"_orm_id",
|
||||
"_orm_saved",
|
||||
"_orm",
|
||||
)
|
||||
__slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column")
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
__model_fields__: Dict[str, Type[BaseField]]
|
||||
@ -75,6 +72,7 @@ class NewBaseModel(
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
|
||||
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
|
||||
object.__setattr__(self, "_orm_saved", False)
|
||||
object.__setattr__(self, "_pk_column", None)
|
||||
object.__setattr__(
|
||||
self,
|
||||
"_orm",
|
||||
@ -94,13 +92,8 @@ class NewBaseModel(
|
||||
if "pk" in kwargs:
|
||||
kwargs[self.Meta.pkname] = kwargs.pop("pk")
|
||||
|
||||
# remove property fields values from validation
|
||||
kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in object.__getattribute__(self, "Meta").property_fields
|
||||
}
|
||||
# build the models to set them and validate but don't register
|
||||
# also remove property fields values from validation
|
||||
try:
|
||||
new_kwargs: Dict[str, Any] = {
|
||||
k: self._convert_json(
|
||||
@ -111,14 +104,15 @@ class NewBaseModel(
|
||||
"dumps",
|
||||
)
|
||||
for k, v in kwargs.items()
|
||||
if k not in object.__getattribute__(self, "Meta").property_fields
|
||||
}
|
||||
except KeyError as e:
|
||||
raise ModelError(
|
||||
f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}"
|
||||
)
|
||||
|
||||
# explicitly set None to excluded fields with default
|
||||
# as pydantic populates them with default
|
||||
# explicitly set None to excluded fields
|
||||
# as pydantic populates them with default if set
|
||||
for field_to_nullify in excluded:
|
||||
new_kwargs[field_to_nullify] = None
|
||||
|
||||
@ -195,7 +189,8 @@ class NewBaseModel(
|
||||
return (
|
||||
self._orm_id == other._orm_id
|
||||
or (self.pk == other.pk and self.pk is not None)
|
||||
or self.dict() == other.dict()
|
||||
or self.dict(exclude=self.extract_related_names())
|
||||
== other.dict(exclude=other.extract_related_names())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -207,12 +202,21 @@ class NewBaseModel(
|
||||
|
||||
@property
|
||||
def pk_column(self) -> sqlalchemy.Column:
|
||||
return self.Meta.table.primary_key.columns.values()[0]
|
||||
if object.__getattribute__(self, "_pk_column") is not None:
|
||||
return object.__getattribute__(self, "_pk_column")
|
||||
pk_columns = self.Meta.table.primary_key.columns.values()
|
||||
pk_col = pk_columns[0]
|
||||
object.__setattr__(self, "_pk_column", pk_col)
|
||||
return pk_col
|
||||
|
||||
@property
|
||||
def saved(self) -> bool:
|
||||
return self._orm_saved
|
||||
|
||||
@property
|
||||
def signals(self) -> "SignalEmitter":
|
||||
return self.Meta.signals
|
||||
|
||||
@classmethod
|
||||
def pk_type(cls) -> Any:
|
||||
return cls.Meta.model_fields[cls.Meta.pkname].__type__
|
||||
|
||||
0
ormar/py.typed
Normal file
0
ormar/py.typed
Normal file
@ -6,7 +6,7 @@ from sqlalchemy import bindparam
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar import MultipleMatches, NoMatch
|
||||
from ormar.exceptions import QueryDefinitionError
|
||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
||||
from ormar.queryset import FilterQuery
|
||||
from ormar.queryset.clause import QueryClause
|
||||
from ormar.queryset.prefetch_query import PrefetchQuery
|
||||
@ -395,6 +395,9 @@ class QuerySet:
|
||||
expr = expr.values(**new_kwargs)
|
||||
|
||||
instance = self.model(**kwargs)
|
||||
await self.model.Meta.signals.pre_save.send(
|
||||
sender=self.model, instance=instance
|
||||
)
|
||||
pk = await self.database.execute(expr)
|
||||
|
||||
pk_name = self.model.get_column_alias(self.model_meta.pkname)
|
||||
@ -411,6 +414,9 @@ class QuerySet:
|
||||
):
|
||||
instance = await instance.load()
|
||||
instance.set_save_status(True)
|
||||
await self.model.Meta.signals.post_save.send(
|
||||
sender=self.model, instance=instance
|
||||
)
|
||||
return instance
|
||||
|
||||
async def bulk_create(self, objects: List["Model"]) -> None:
|
||||
@ -446,7 +452,7 @@ class QuerySet:
|
||||
for objt in objects:
|
||||
new_kwargs = objt.dict()
|
||||
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
|
||||
raise QueryDefinitionError(
|
||||
raise ModelPersistenceError(
|
||||
"You cannot update unsaved objects. "
|
||||
f"{self.model.__name__} has to have {pk_name} filled."
|
||||
)
|
||||
|
||||
3
ormar/signals/__init__.py
Normal file
3
ormar/signals/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from ormar.signals.signal import Signal, SignalEmitter
|
||||
|
||||
__all__ = ["Signal", "SignalEmitter"]
|
||||
71
ormar/signals/signal.py
Normal file
71
ormar/signals/signal.py
Normal file
@ -0,0 +1,71 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union
|
||||
|
||||
from ormar.exceptions import SignalDefinitionError
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ormar import Model
|
||||
|
||||
|
||||
def callable_accepts_kwargs(func: Callable) -> bool:
|
||||
return any(
|
||||
p
|
||||
for p in inspect.signature(func).parameters.values()
|
||||
if p.kind == p.VAR_KEYWORD
|
||||
)
|
||||
|
||||
|
||||
def make_id(target: Any) -> Union[int, Tuple[int, int]]:
|
||||
if hasattr(target, "__func__"):
|
||||
return id(target.__self__), id(target.__func__)
|
||||
return id(target)
|
||||
|
||||
|
||||
class Signal:
|
||||
def __init__(self) -> None:
|
||||
self._receivers: List[Tuple[Union[int, Tuple[int, int]], Callable]] = []
|
||||
|
||||
def connect(self, receiver: Callable) -> None:
|
||||
if not callable(receiver):
|
||||
raise SignalDefinitionError("Signal receivers must be callable.")
|
||||
if not callable_accepts_kwargs(receiver):
|
||||
raise SignalDefinitionError(
|
||||
"Signal receivers must accept **kwargs argument."
|
||||
)
|
||||
new_receiver_key = make_id(receiver)
|
||||
if not any(rec_id == new_receiver_key for rec_id, _ in self._receivers):
|
||||
self._receivers.append((new_receiver_key, receiver))
|
||||
|
||||
def disconnect(self, receiver: Callable) -> bool:
|
||||
removed = False
|
||||
new_receiver_key = make_id(receiver)
|
||||
for ind, rec in enumerate(self._receivers):
|
||||
rec_id, _ = rec
|
||||
if rec_id == new_receiver_key:
|
||||
removed = True
|
||||
del self._receivers[ind]
|
||||
break
|
||||
return removed
|
||||
|
||||
async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
|
||||
receivers = []
|
||||
for receiver in self._receivers:
|
||||
_, receiver_func = receiver
|
||||
receivers.append(receiver_func(sender=sender, **kwargs))
|
||||
await asyncio.gather(*receivers)
|
||||
|
||||
|
||||
class SignalEmitter:
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
signals: Dict[str, Signal]
|
||||
|
||||
def __init__(self) -> None:
|
||||
object.__setattr__(self, "signals", dict())
|
||||
|
||||
def __getattr__(self, item: str) -> Signal:
|
||||
return self.signals.setdefault(item, Signal())
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
signals = object.__getattribute__(self, "signals")
|
||||
signals[key] = value
|
||||
@ -1,19 +0,0 @@
|
||||
#!/bin/sh -e
|
||||
PACKAGE="ormar"
|
||||
if [ -d 'dist' ] ; then
|
||||
rm -r dist
|
||||
fi
|
||||
if [ -d 'site' ] ; then
|
||||
rm -r site
|
||||
fi
|
||||
if [ -d 'htmlcov' ] ; then
|
||||
rm -r htmlcov
|
||||
fi
|
||||
if [ -d "${PACKAGE}.egg-info" ] ; then
|
||||
rm -r "${PACKAGE}.egg-info"
|
||||
fi
|
||||
find ${PACKAGE} -type f -name "*.py[co]" -delete
|
||||
find ${PACKAGE} -type d -name __pycache__ -delete
|
||||
|
||||
find tests -type f -name "*.py[co]" -delete
|
||||
find tests -type d -name __pycache__ -delete
|
||||
@ -1,23 +0,0 @@
|
||||
#!/bin/sh -e
|
||||
|
||||
PACKAGE="ormar"
|
||||
|
||||
PREFIX=""
|
||||
if [ -d 'venv' ] ; then
|
||||
PREFIX="venv/bin/"
|
||||
fi
|
||||
|
||||
VERSION=`cat ${PACKAGE}/__init__.py | grep __version__ | sed "s/__version__ = //" | sed "s/'//g"`
|
||||
|
||||
set -x
|
||||
|
||||
scripts/clean.sh
|
||||
|
||||
${PREFIX}python setup.py sdist
|
||||
${PREFIX}twine upload dist/*
|
||||
|
||||
echo "You probably want to also tag the version now:"
|
||||
echo "git tag -a ${VERSION} -m 'version ${VERSION}'"
|
||||
echo "git push --tags"
|
||||
|
||||
scripts/clean.sh
|
||||
4
setup.py
4
setup.py
@ -50,6 +50,8 @@ setup(
|
||||
author_email="collerek@gmail.com",
|
||||
packages=get_packages(PACKAGE),
|
||||
package_data={PACKAGE: ["py.typed"]},
|
||||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
data_files=[("", ["LICENSE.md"])],
|
||||
install_requires=["databases", "pydantic>=1.5", "sqlalchemy", "typing_extensions"],
|
||||
extras_require={
|
||||
@ -65,9 +67,11 @@ setup(
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Internet :: WWW/HTTP",
|
||||
"Framework :: AsyncIO",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
],
|
||||
)
|
||||
|
||||
@ -10,7 +10,7 @@ from fastapi import FastAPI
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
import ormar
|
||||
from ormar import property_field
|
||||
from ormar import post_save, property_field
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
app = FastAPI()
|
||||
@ -65,8 +65,6 @@ class RandomModel(ormar.Model):
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
include_props_in_dict = True
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
password: str = ormar.String(max_length=255, default=gen_pass)
|
||||
first_name: str = ormar.String(max_length=255, default="John")
|
||||
@ -219,7 +217,7 @@ def test_excluding_fields_in_endpoints():
|
||||
assert isinstance(user_instance.timestamp, datetime.datetime)
|
||||
assert user_instance.timestamp == timestamp
|
||||
|
||||
response = client.post("/users4/", json=user)
|
||||
response = client.post("/users4/", json=user3)
|
||||
assert list(response.json().keys()) == [
|
||||
"id",
|
||||
"email",
|
||||
@ -228,8 +226,12 @@ def test_excluding_fields_in_endpoints():
|
||||
"category",
|
||||
"timestamp",
|
||||
]
|
||||
assert response.json().get("timestamp") != str(timestamp).replace(" ", "T")
|
||||
assert response.json().get("timestamp") is not None
|
||||
assert (
|
||||
datetime.datetime.strptime(
|
||||
response.json().get("timestamp"), "%Y-%m-%dT%H:%M:%S.%f"
|
||||
)
|
||||
== timestamp
|
||||
)
|
||||
|
||||
|
||||
def test_adding_fields_in_endpoints():
|
||||
@ -247,7 +249,6 @@ def test_adding_fields_in_endpoints():
|
||||
]
|
||||
assert response.json().get("full_name") == "John Test"
|
||||
|
||||
RandomModel.Meta.include_props_in_fields = False
|
||||
user3 = {"last_name": "Test"}
|
||||
response = client.post("/random/", json=user3)
|
||||
assert list(response.json().keys()) == [
|
||||
@ -264,7 +265,6 @@ def test_adding_fields_in_endpoints():
|
||||
def test_adding_fields_in_endpoints2():
|
||||
client = TestClient(app)
|
||||
with client as client:
|
||||
RandomModel.Meta.include_props_in_dict = True
|
||||
user3 = {"last_name": "Test"}
|
||||
response = client.post("/random2/", json=user3)
|
||||
assert list(response.json().keys()) == [
|
||||
@ -279,9 +279,15 @@ def test_adding_fields_in_endpoints2():
|
||||
|
||||
|
||||
def test_excluding_property_field_in_endpoints2():
|
||||
|
||||
dummy_registry = {}
|
||||
|
||||
@post_save(RandomModel)
|
||||
async def after_save(sender, instance, **kwargs):
|
||||
dummy_registry[instance.pk] = instance.dict()
|
||||
|
||||
client = TestClient(app)
|
||||
with client as client:
|
||||
RandomModel.Meta.include_props_in_dict = True
|
||||
user3 = {"last_name": "Test"}
|
||||
response = client.post("/random3/", json=user3)
|
||||
assert list(response.json().keys()) == [
|
||||
@ -292,3 +298,7 @@ def test_excluding_property_field_in_endpoints2():
|
||||
"created_date",
|
||||
]
|
||||
assert response.json().get("full_name") is None
|
||||
assert len(dummy_registry) == 1
|
||||
check_dict = dummy_registry.get(response.json().get("id"))
|
||||
check_dict.pop("full_name")
|
||||
assert response.json().get("password") == check_dict.get("password")
|
||||
|
||||
@ -5,7 +5,7 @@ import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import QueryDefinitionError
|
||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
@ -302,7 +302,7 @@ async def test_bulk_update_with_relation():
|
||||
async def test_bulk_update_not_saved_objts():
|
||||
async with database:
|
||||
category = await Category.objects.create(name="Sample Category")
|
||||
with pytest.raises(QueryDefinitionError):
|
||||
with pytest.raises(ModelPersistenceError):
|
||||
await Note.objects.bulk_update(
|
||||
[
|
||||
Note(text="Buy the groceries.", category=category),
|
||||
|
||||
352
tests/test_signals.py
Normal file
352
tests/test_signals.py
Normal file
@ -0,0 +1,352 @@
|
||||
from typing import Optional
|
||||
|
||||
import databases
|
||||
import pydantic
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar import (
|
||||
post_delete,
|
||||
post_save,
|
||||
post_update,
|
||||
pre_delete,
|
||||
pre_save,
|
||||
pre_update,
|
||||
)
|
||||
from ormar.exceptions import SignalDefinitionError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class AuditLog(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "audits"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
event_type: str = ormar.String(max_length=100)
|
||||
event_log: pydantic.Json = ormar.JSON()
|
||||
|
||||
|
||||
class Cover(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "covers"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
title: str = ormar.String(max_length=100)
|
||||
|
||||
|
||||
class Album(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "albums"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100)
|
||||
is_best_seller: bool = ormar.Boolean(default=False)
|
||||
play_count: int = ormar.Integer(default=0)
|
||||
cover: Optional[Cover] = ormar.ForeignKey(Cover)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.drop_all(engine)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def cleanup():
|
||||
yield
|
||||
async with database:
|
||||
await AuditLog.objects.delete(each=True)
|
||||
|
||||
|
||||
def test_passing_not_callable():
|
||||
with pytest.raises(SignalDefinitionError):
|
||||
pre_save(Album)("wrong")
|
||||
|
||||
|
||||
def test_passing_callable_without_kwargs():
|
||||
with pytest.raises(SignalDefinitionError):
|
||||
@pre_save(Album)
|
||||
def trigger(sender, instance): # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_functions(cleanup):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
@pre_save(Album)
|
||||
async def before_save(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"PRE_SAVE_{sender.get_name()}",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
@post_save(Album)
|
||||
async def after_save(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"POST_SAVE_{sender.get_name()}",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
@pre_update(Album)
|
||||
async def before_update(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"PRE_UPDATE_{sender.get_name()}",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
@post_update(Album)
|
||||
async def after_update(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"POST_UPDATE_{sender.get_name()}",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
@pre_delete(Album)
|
||||
async def before_delete(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"PRE_DELETE_{sender.get_name()}",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
@post_delete(Album)
|
||||
async def after_delete(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"POST_DELETE_{sender.get_name()}",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
album = await Album.objects.create(name="Venice")
|
||||
|
||||
audits = await AuditLog.objects.all()
|
||||
assert len(audits) == 2
|
||||
assert audits[0].event_type == "PRE_SAVE_album"
|
||||
assert audits[0].event_log.get("name") == album.name
|
||||
assert audits[1].event_type == "POST_SAVE_album"
|
||||
assert audits[1].event_log.get("id") == album.pk
|
||||
|
||||
album = await Album(name="Rome").save()
|
||||
audits = await AuditLog.objects.all()
|
||||
assert len(audits) == 4
|
||||
assert audits[2].event_type == "PRE_SAVE_album"
|
||||
assert audits[2].event_log.get("name") == album.name
|
||||
assert audits[3].event_type == "POST_SAVE_album"
|
||||
assert audits[3].event_log.get("id") == album.pk
|
||||
|
||||
album.is_best_seller = True
|
||||
await album.update()
|
||||
|
||||
audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all()
|
||||
assert len(audits) == 2
|
||||
assert audits[0].event_type == "PRE_UPDATE_album"
|
||||
assert audits[0].event_log.get("name") == album.name
|
||||
assert audits[1].event_type == "POST_UPDATE_album"
|
||||
assert audits[1].event_log.get("is_best_seller") == album.is_best_seller
|
||||
|
||||
album.signals.pre_update.disconnect(before_update)
|
||||
album.signals.post_update.disconnect(after_update)
|
||||
|
||||
album.is_best_seller = False
|
||||
await album.update()
|
||||
|
||||
audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all()
|
||||
assert len(audits) == 2
|
||||
|
||||
await album.delete()
|
||||
audits = await AuditLog.objects.filter(event_type__contains="DELETE").all()
|
||||
assert len(audits) == 2
|
||||
assert audits[0].event_type == "PRE_DELETE_album"
|
||||
assert (
|
||||
audits[0].event_log.get("id")
|
||||
== audits[1].event_log.get("id")
|
||||
== album.id
|
||||
)
|
||||
assert audits[1].event_type == "POST_DELETE_album"
|
||||
|
||||
album.signals.pre_delete.disconnect(before_delete)
|
||||
album.signals.post_delete.disconnect(after_delete)
|
||||
album.signals.pre_save.disconnect(before_save)
|
||||
album.signals.post_save.disconnect(after_save)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_signals(cleanup):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
@pre_save(Album)
|
||||
async def before_save(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"PRE_SAVE_{sender.get_name()}",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
@pre_save(Album)
|
||||
async def before_save2(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"PRE_SAVE_{sender.get_name()}",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
album = await Album.objects.create(name="Miami")
|
||||
audits = await AuditLog.objects.all()
|
||||
assert len(audits) == 2
|
||||
assert audits[0].event_type == "PRE_SAVE_album"
|
||||
assert audits[0].event_log.get("name") == album.name
|
||||
assert audits[1].event_type == "PRE_SAVE_album"
|
||||
assert audits[1].event_log.get("name") == album.name
|
||||
|
||||
album.signals.pre_save.disconnect(before_save)
|
||||
album.signals.pre_save.disconnect(before_save2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_methods_as_signals(cleanup):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
class AlbumAuditor:
|
||||
event_type = "ALBUM_INSTANCE"
|
||||
|
||||
@staticmethod
|
||||
@pre_save(Album)
|
||||
async def before_save(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"{AlbumAuditor.event_type}_SAVE",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
album = await Album.objects.create(name="Colorado")
|
||||
audits = await AuditLog.objects.all()
|
||||
assert len(audits) == 1
|
||||
assert audits[0].event_type == "ALBUM_INSTANCE_SAVE"
|
||||
assert audits[0].event_log.get("name") == album.name
|
||||
|
||||
album.signals.pre_save.disconnect(AlbumAuditor.before_save)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_methods_as_signals(cleanup):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
class AlbumAuditor:
|
||||
def __init__(self):
|
||||
self.event_type = "ALBUM_INSTANCE"
|
||||
|
||||
async def before_save(self, sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"{self.event_type}_SAVE", event_log=instance.json()
|
||||
).save()
|
||||
|
||||
auditor = AlbumAuditor()
|
||||
pre_save(Album)(auditor.before_save)
|
||||
|
||||
album = await Album.objects.create(name="San Francisco")
|
||||
audits = await AuditLog.objects.all()
|
||||
assert len(audits) == 1
|
||||
assert audits[0].event_type == "ALBUM_INSTANCE_SAVE"
|
||||
assert audits[0].event_log.get("name") == album.name
|
||||
|
||||
album.signals.pre_save.disconnect(auditor.before_save)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_senders_signal(cleanup):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
@pre_save([Album, Cover])
|
||||
async def before_save(sender, instance, **kwargs):
|
||||
await AuditLog(
|
||||
event_type=f"PRE_SAVE_{sender.get_name()}",
|
||||
event_log=instance.json(),
|
||||
).save()
|
||||
|
||||
cover = await Cover(title="Blue").save()
|
||||
album = await Album.objects.create(name="San Francisco", cover=cover)
|
||||
|
||||
audits = await AuditLog.objects.all()
|
||||
assert len(audits) == 2
|
||||
assert audits[0].event_type == "PRE_SAVE_cover"
|
||||
assert audits[0].event_log.get("title") == cover.title
|
||||
assert audits[1].event_type == "PRE_SAVE_album"
|
||||
assert audits[1].event_log.get("cover") == album.cover.dict(
|
||||
exclude={"albums"}
|
||||
)
|
||||
|
||||
album.signals.pre_save.disconnect(before_save)
|
||||
cover.signals.pre_save.disconnect(before_save)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modifing_the_instance(cleanup):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
@pre_update(Album)
|
||||
async def before_update(sender, instance, **kwargs):
|
||||
if instance.play_count > 50 and not instance.is_best_seller:
|
||||
instance.is_best_seller = True
|
||||
|
||||
# here album.play_count ans is_best_seller get default values
|
||||
album = await Album.objects.create(name="Venice")
|
||||
assert not album.is_best_seller
|
||||
assert album.play_count == 0
|
||||
|
||||
album.play_count = 30
|
||||
# here a trigger is called but play_count is too low
|
||||
await album.update()
|
||||
assert not album.is_best_seller
|
||||
|
||||
album.play_count = 60
|
||||
await album.update()
|
||||
assert album.is_best_seller
|
||||
album.signals.pre_update.disconnect(before_update)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_signal(cleanup):
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
|
||||
async def after_update(sender, instance, **kwargs):
|
||||
if instance.play_count > 50 and not instance.is_best_seller:
|
||||
instance.is_best_seller = True
|
||||
elif instance.play_count < 50 and instance.is_best_seller:
|
||||
instance.is_best_seller = False
|
||||
await instance.update()
|
||||
|
||||
Album.Meta.signals.custom.connect(after_update)
|
||||
|
||||
# here album.play_count ans is_best_seller get default values
|
||||
album = await Album.objects.create(name="Venice")
|
||||
assert not album.is_best_seller
|
||||
assert album.play_count == 0
|
||||
|
||||
album.play_count = 30
|
||||
# here a trigger is called but play_count is too low
|
||||
await album.update()
|
||||
assert not album.is_best_seller
|
||||
|
||||
album.play_count = 60
|
||||
await album.update()
|
||||
assert not album.is_best_seller
|
||||
await Album.Meta.signals.custom.send(sender=Album, instance=album)
|
||||
assert album.is_best_seller
|
||||
|
||||
album.play_count = 30
|
||||
await album.update()
|
||||
assert album.is_best_seller
|
||||
await Album.Meta.signals.custom.send(sender=Album, instance=album)
|
||||
assert not album.is_best_seller
|
||||
|
||||
Album.Meta.signals.custom.disconnect(after_update)
|
||||
Reference in New Issue
Block a user