Merge pull request #66 from collerek/signals

Add Signals
This commit is contained in:
collerek
2020-12-07 02:06:26 +07:00
committed by GitHub
25 changed files with 939 additions and 87 deletions


@ -203,6 +203,7 @@ The following keyword arguments are supported on all field types.
* `unique: bool`
* `choices: typing.Sequence`
* `name: str`
* `pydantic_only: bool`
All fields are required unless one of the following is set:
@ -211,7 +212,18 @@ All fields are required unless one of the following is set:
* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`).
* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column.
Autoincrement is set by default on int primary keys.
* `pydantic_only` - Field is available only as normal pydantic field, not stored in the database.
### Available signals
Signals allow you to trigger a function for a given event on a given Model.
* `pre_save`
* `post_save`
* `pre_update`
* `post_update`
* `pre_delete`
* `post_delete`
[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/


@ -203,6 +203,7 @@ The following keyword arguments are supported on all field types.
* `unique: bool`
* `choices: typing.Sequence`
* `name: str`
* `pydantic_only: bool`
All fields are required unless one of the following is set:
@ -211,7 +212,18 @@ All fields are required unless one of the following is set:
* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`).
* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column.
Autoincrement is set by default on int primary keys.
* `pydantic_only` - Field is available only as normal pydantic field, not stored in the database.
### Available signals
Signals allow you to trigger a function for a given event on a given Model.
* `pre_save`
* `post_save`
* `pre_update`
* `post_update`
* `pre_delete`
* `post_delete`
[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/


@ -1,3 +1,13 @@
# 0.7.0
* **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError`
* **Breaking:** Model initialization with unknown field name now raises `ModelError` instead of `KeyError`
* Added **Signals**, with a pre-defined list of signals and decorators: `post_delete`, `post_save`, `post_update`, `pre_delete`, `pre_save`, `pre_update`
* Add `py.typed` and modify `setup.py` for mypy support
* Performance optimization
* Updated docs
# 0.6.2
* Performance optimization
@ -12,7 +22,7 @@
# 0.6.0
* **Breaking:** calling instance.load() when the instance row was deleted from db now raises ormar.NoMatch instead of ValueError
* **Breaking:** calling instance.load() when the instance row was deleted from db now raises `NoMatch` instead of `ValueError`
* **Breaking:** calling add and remove on ReverseForeignKey relation now updates the child model in db setting/removing fk column
* **Breaking:** ReverseForeignKey relation now exposes QuerySetProxy API like ManyToMany relation
* **Breaking:** querying related models from ManyToMany cleans list of related models loaded on parent model:

docs/signals.md Normal file

@ -0,0 +1,249 @@
# Signals
Signals are a mechanism to fire a piece of your code (a function or method) whenever a given type of event happens in `ormar`.
To achieve this you register your receiver for a given type of signal on the selected model(s).
## Defining receivers
Given a sample model like the following:
```Python
import databases
import sqlalchemy

import ormar

database = databases.Database("sqlite:///db.sqlite")
metadata = sqlalchemy.MetaData()


class Album(ormar.Model):
    class Meta:
        tablename = "albums"
        metadata = metadata
        database = database

    id: int = ormar.Integer(primary_key=True)
    name: str = ormar.String(max_length=100)
    is_best_seller: bool = ormar.Boolean(default=False)
    play_count: int = ormar.Integer(default=0)
```
You can, for example, define a trigger that will set the `album.is_best_seller` status once it is played more than 50 times.
Import the `pre_update` decorator; for a list of the currently available decorators/signals, see below.
```Python hl_lines="1"
--8<-- "../docs_src/signals/docs002.py"
```
Define your function.
Note that each receiver function:
* has to be **callable**
* has to accept a first **`sender`** argument that receives the class of the sending object
* has to accept a **`**kwargs`** argument, as the parameters sent by each `ormar.Signal` can change at any time, so your function has to handle them
* has to be **`async`**, because callbacks are gathered and awaited
`pre_update` currently sends only one argument apart from `sender`, and it's the `instance` argument.
Note how the `pre_update` decorator accepts a `senders` argument that can be a single model or a list of models
for which you want to run the signal receiver.
Currently there is no way to register a signal receiver for all models at once without explicitly passing them all at receiver registration.
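Because receivers are coroutines that get gathered and awaited, the dispatch can be sketched with the stdlib alone. `MiniSignal` below is a simplified stand-in for `ormar.Signal` (the real class also validates and deduplicates receivers), and the sender/instance values are illustrative strings:

```python
import asyncio

# Simplified stand-in for ormar.Signal: receivers are coroutine
# functions called with sender plus arbitrary keyword arguments.
class MiniSignal:
    def __init__(self):
        self._receivers = []

    def connect(self, receiver):
        self._receivers.append(receiver)

    async def send(self, sender, **kwargs):
        # all connected receivers are scheduled and awaited together
        await asyncio.gather(*(r(sender=sender, **kwargs) for r in self._receivers))

calls = []

async def before_update(sender, instance, **kwargs):
    calls.append((sender, instance))

signal = MiniSignal()
signal.connect(before_update)
asyncio.run(signal.send(sender="Album", instance="album #1"))
```

The receiver signature mirrors the requirements above: a positional `sender` plus `**kwargs` for whatever extra parameters the signal sends.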
```Python hl_lines="4-7"
--8<-- "../docs_src/signals/docs002.py"
```
!!!note
Note that receivers are defined on a class level -> so even if you connect/disconnect a function through an instance,
it will run/stop running for all operations on that `ormar.Model` class.
Note that our newly created function receives the instance and the class of the instance, so you can easily run database
queries inside your receivers if you want to.
```Python hl_lines="15-22"
--8<-- "../docs_src/signals/docs002.py"
```
You can define the same receiver for multiple models at once by passing a list of models to the signal decorator.
```python
# define a dummy debug function
@pre_update([Album, Track])
async def before_update(sender, instance, **kwargs):
    print(f"{sender.get_name()}: {instance.json()}: {kwargs}")
```
Of course you can also register multiple functions for the same signal and model. Each of them will run every time the signal is sent.
```python
@pre_update(Album)
async def before_update(sender, instance, **kwargs):
    print(f"{sender.get_name()}: {instance.json()}: {kwargs}")


@pre_update(Album)
async def before_update2(sender, instance, **kwargs):
    print(f'About to update {sender.get_name()} with pk: {instance.pk}')
```
Note that `ormar` decorators are syntactic sugar; you can directly connect your function or method to a given signal on a
given model. `connect` accepts only one parameter: your `receiver` function/method.
```python hl_lines="11 13 16"
class AlbumAuditor:
    def __init__(self):
        self.event_type = "ALBUM_INSTANCE"

    async def before_save(self, sender, instance, **kwargs):
        await AuditLog(
            event_type=f"{self.event_type}_SAVE", event_log=instance.json()
        ).save()

auditor = AlbumAuditor()
pre_save(Album)(auditor.before_save)
# the call above has the same result as the one below
Album.Meta.signals.pre_save.connect(auditor.before_save)
# signals are also exposed on instances
album = Album(name='Miami')
album.signals.pre_save.connect(auditor.before_save)
```
!!!warning
    Note that signals keep a plain reference to your receiver (not a `weakref`), so keep that in mind to avoid circular references.
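This is visible in the source: receivers are stored as plain references keyed by `id()` (bound methods by `(id(self), id(func))`), which also makes repeated `connect` calls with the same receiver a no-op. A stdlib-only sketch, simplified from `ormar/signals/signal.py`:

```python
# How ormar keys receivers: plain ids and plain references, no weakref,
# so a connected receiver object is kept alive by the signal.
def make_id(target):
    if hasattr(target, "__func__"):  # bound method: key by (instance, function)
        return id(target.__self__), id(target.__func__)
    return id(target)

class MiniSignal:
    def __init__(self):
        self._receivers = []

    def connect(self, receiver):
        key = make_id(receiver)
        # duplicate connects with the same receiver are ignored
        if not any(rec_id == key for rec_id, _ in self._receivers):
            self._receivers.append((key, receiver))

async def receiver(sender, **kwargs):
    pass

signal = MiniSignal()
signal.connect(receiver)
signal.connect(receiver)  # second connect is a no-op thanks to make_id
```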
## Disconnecting the receivers
To stop the receiver from running for a given model you need to disconnect it.
```python hl_lines="7 10"
@pre_update(Album)
async def before_update(sender, instance, **kwargs):
    if instance.play_count > 50 and not instance.is_best_seller:
        instance.is_best_seller = True

# disconnect the given function from the signal for a given Model
Album.Meta.signals.pre_update.disconnect(before_update)
# signals are also exposed on instances
album = Album(name='Miami')
album.signals.pre_update.disconnect(before_update)
```
## Available signals
!!!warning
    Note that signals are **not** sent for:
    * bulk operations (`QuerySet.bulk_create` and `QuerySet.bulk_update`), as they are designed for speed.
    * queryset table-level operations (`QuerySet.update` and `QuerySet.delete`), as they run on the underlying tables
    (more like raw SQL update/delete operations) and do not have a specific instance.
### pre_save
`pre_save(sender: Type["Model"], instance: "Model")`
Sent for the `Model.save()` and `Model.objects.create()` methods.
`sender` is an `ormar.Model` class and `instance` is the model to be saved.
### post_save
`post_save(sender: Type["Model"], instance: "Model")`
Sent for the `Model.save()` and `Model.objects.create()` methods.
`sender` is an `ormar.Model` class and `instance` is the model that was saved.
### pre_update
`pre_update(sender: Type["Model"], instance: "Model")`
Sent for the `Model.update()` method.
`sender` is an `ormar.Model` class and `instance` is the model to be updated.
### post_update
`post_update(sender: Type["Model"], instance: "Model")`
Sent for the `Model.update()` method.
`sender` is an `ormar.Model` class and `instance` is the model that was updated.
### pre_delete
`pre_delete(sender: Type["Model"], instance: "Model")`
Sent for the `Model.delete()` method.
`sender` is an `ormar.Model` class and `instance` is the model to be deleted.
### post_delete
`post_delete(sender: Type["Model"], instance: "Model")`
Sent for the `Model.delete()` method.
`sender` is an `ormar.Model` class and `instance` is the model that was deleted.
## Defining your own signals
Note that you can create your own signals, although you will have to send them manually in your code or subclass `ormar.Model`
and trigger your signals there.
Creating a new signal is very easy. The following example registers a new signal with the name `your_custom_signal`.
```python hl_lines="21"
import databases
import sqlalchemy

import ormar

database = databases.Database("sqlite:///db.sqlite")
metadata = sqlalchemy.MetaData()


class Album(ormar.Model):
    class Meta:
        tablename = "albums"
        metadata = metadata
        database = database

    id: int = ormar.Integer(primary_key=True)
    name: str = ormar.String(max_length=100)
    is_best_seller: bool = ormar.Boolean(default=False)
    play_count: int = ormar.Integer(default=0)

Album.Meta.signals.your_custom_signal = ormar.Signal()
Album.Meta.signals.your_custom_signal.connect(your_receiver_name)
```
Under the hood, `Meta.signals` is a `SignalEmitter` instance that keeps a dictionary of known signals and allows you
to access them as attributes. When you try to access a signal that does not exist, the `SignalEmitter` will create one for you,
so the example above can be simplified to the following; the `Signal` will be created on first access.
```python
Album.Meta.signals.your_custom_signal.connect(your_receiver_name)
```
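That attribute behaviour can be sketched with the stdlib alone, mirroring `SignalEmitter` from `ormar/signals/signal.py` (the `Signal` class here is a bare placeholder for the real one):

```python
class Signal:  # placeholder for ormar.Signal
    pass

class SignalEmitter:
    def __init__(self):
        # bypass __setattr__ so the backing dict itself is a real attribute
        object.__setattr__(self, "signals", dict())

    def __getattr__(self, item):
        # unknown signal names are created on first access
        return self.signals.setdefault(item, Signal())

    def __setattr__(self, key, value):
        object.__getattribute__(self, "signals")[key] = value

emitter = SignalEmitter()
first = emitter.your_custom_signal   # created on the fly
second = emitter.your_custom_signal  # same instance returned
```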
Now, to trigger this signal, you need to call the `send` method of the `Signal`.
```python
await Album.Meta.signals.your_custom_signal.send(sender=Album)
```
Note that `sender` is the only required parameter and it should be an ormar `Model` class.
Additional parameters have to be passed as keyword arguments.
```python
await Album.Meta.signals.your_custom_signal.send(sender=Album, my_param=True)
```
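For those extra keyword arguments to reach your receiver, it must accept `**kwargs`; `ormar` enforces this at `connect` time with an `inspect`-based check (shown in `ormar/signals/signal.py`). The receiver names below are illustrative:

```python
import inspect

# The check ormar runs when you connect a receiver: the callable
# must declare a **kwargs (VAR_KEYWORD) parameter.
def callable_accepts_kwargs(func):
    return any(
        p
        for p in inspect.signature(func).parameters.values()
        if p.kind == p.VAR_KEYWORD
    )

async def good_receiver(sender, **kwargs):
    pass

async def bad_receiver(sender, instance):  # rejected: no **kwargs
    pass
```

Connecting `bad_receiver` to a real `ormar.Signal` raises `SignalDefinitionError`.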


@ -0,0 +1,22 @@
from ormar import pre_update


@pre_update(Album)
async def before_update(sender, instance, **kwargs):
    if instance.play_count > 50 and not instance.is_best_seller:
        instance.is_best_seller = True


# here album.play_count and is_best_seller get default values
album = await Album.objects.create(name="Venice")
assert not album.is_best_seller
assert album.play_count == 0

album.play_count = 30
# here the trigger is called but play_count is too low
await album.update()
assert not album.is_best_seller

album.play_count = 60
await album.update()
assert album.is_best_seller


@ -7,6 +7,7 @@ nav:
- Fields: fields.md
- Relations: relations.md
- Queries: queries.md
- Signals: signals.md
- Use with Fastapi: fastapi.md
- Use with mypy: mypy.md
- PyCharm plugin: plugin.md


@ -1,6 +1,18 @@
from ormar.decorators import property_field
from ormar.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch
from ormar.decorators import (
post_delete,
post_save,
post_update,
pre_delete,
pre_save,
pre_update,
property_field,
)
from ormar.protocols import QuerySetProtocol, RelationProtocol # noqa: I100
from ormar.exceptions import ( # noqa: I100
ModelDefinitionError,
MultipleMatches,
NoMatch,
)
from ormar.fields import ( # noqa: I100
BigInteger,
Boolean,
@ -22,6 +34,7 @@ from ormar.models import Model
from ormar.models.metaclass import ModelMeta
from ormar.queryset import QuerySet
from ormar.relations import RelationType
from ormar.signals import Signal
class UndefinedType: # pragma no cover
@ -31,7 +44,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType()
__version__ = "0.6.2"
__version__ = "0.7.0"
__all__ = [
"Integer",
"BigInteger",
@ -47,7 +60,6 @@ __all__ = [
"ManyToMany",
"Model",
"ModelDefinitionError",
"ModelNotSet",
"MultipleMatches",
"NoMatch",
"ForeignKey",
@ -60,4 +72,11 @@ __all__ = [
"RelationProtocol",
"ModelMeta",
"property_field",
"post_delete",
"post_save",
"post_update",
"pre_delete",
"pre_save",
"pre_update",
"Signal",
]


@ -1,5 +1,19 @@
from ormar.decorators.property_field import property_field
from ormar.decorators.signals import (
post_delete,
post_save,
post_update,
pre_delete,
pre_save,
pre_update,
)
__all__ = [
"property_field",
"post_delete",
"post_save",
"post_update",
"pre_delete",
"pre_save",
"pre_update",
]


@ -13,7 +13,7 @@ def property_field(func: Callable) -> Union[property, Callable]:
if len(arguments) > 1 or arguments[0] != "self":
raise ModelDefinitionError(
"property_field decorator can be used "
"only on class methods with no arguments"
"only on methods with no arguments"
)
func.__dict__["__property_field__"] = True
return func


@ -0,0 +1,44 @@
from typing import Callable, List, TYPE_CHECKING, Type, Union

if TYPE_CHECKING:  # pragma: no cover
    from ormar import Model


def receiver(
    signal: str, senders: Union[Type["Model"], List[Type["Model"]]]
) -> Callable:
    def _decorator(func: Callable) -> Callable:
        if not isinstance(senders, list):
            _senders = [senders]
        else:
            _senders = senders
        for sender in _senders:
            signals = getattr(sender.Meta.signals, signal)
            signals.connect(func)
        return func

    return _decorator


def post_save(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable:
    return receiver(signal="post_save", senders=senders)


def post_update(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable:
    return receiver(signal="post_update", senders=senders)


def post_delete(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable:
    return receiver(signal="post_delete", senders=senders)


def pre_save(senders: Union[Type["Model"], List[Type["Model"]]],) -> Callable:
    return receiver(signal="pre_save", senders=senders)


def pre_update(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
    return receiver(signal="pre_update", senders=senders)


def pre_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
    return receiver(signal="pre_delete", senders=senders)


@ -1,28 +1,57 @@
class AsyncOrmException(Exception):
"""
Base ormar Exception
"""
pass
class ModelDefinitionError(AsyncOrmException):
"""
Raised for errors related to the model definition itself.
* setting @property_field on method with arguments other than func(self)
* defining a Field without required parameters
* defining a model with more than one primary_key
* defining a model without primary_key
* setting primary_key column as pydantic_only
"""
pass
class ModelError(AsyncOrmException):
pass
"""
Raised for initialization of a model with a non-existing field keyword.
"""
class ModelNotSet(AsyncOrmException):
pass
class NoMatch(AsyncOrmException):
"""
Raised for database queries that have no matching result (an empty result).
"""
pass
class MultipleMatches(AsyncOrmException):
"""
Raised for database queries that should return one row (i.e. get, first etc.)
but have multiple matching results in response.
"""
pass
class QueryDefinitionError(AsyncOrmException):
"""
Raised for errors in query definition.
* using contains or icontains filter with instance of the Model
* using Queryset.update() without filter and setting each flag to True
* using Queryset.delete() without filter and setting each flag to True
"""
pass
@ -31,4 +60,17 @@ class RelationshipInstanceError(AsyncOrmException):
class ModelPersistenceError(AsyncOrmException):
"""
Raised for update of models without primary_key set (cannot retrieve from db)
or for saving a model with relation to unsaved model (cannot extract fk value).
"""
pass
class SignalDefinitionError(AsyncOrmException):
"""
Raised when a non-callable receiver is passed as a signal callback.
"""
pass


@ -18,6 +18,7 @@ from ormar.fields.many_to_many import ManyToMany, ManyToManyField
from ormar.models.quick_access_views import quick_access_set
from ormar.queryset import QuerySet
from ormar.relations.alias_manager import AliasManager
from ormar.signals import Signal, SignalEmitter
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -38,6 +39,7 @@ class ModelMeta:
]
alias_manager: AliasManager
property_fields: Set
signals: SignalEmitter
def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None:
@ -332,15 +334,12 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
new_model._pydantic_fields = {name for name in new_model.__fields__}
def property_fields_not_set(new_model: Type["Model"]) -> bool:
return (
not hasattr(new_model.Meta, "property_fields")
or not new_model.Meta.property_fields
)
def meta_field_not_set(model: Type["Model"], field_name: str) -> bool:
return not hasattr(model.Meta, field_name) or not getattr(model.Meta, field_name)
def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001
if property_fields_not_set(new_model):
if meta_field_not_set(model=new_model, field_name="property_fields"):
props = set()
for var_name, value in attrs.items():
if isinstance(value, property):
@ -351,6 +350,18 @@ def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa:
new_model.Meta.property_fields = props
def register_signals(new_model: Type["Model"]) -> None:  # noqa: CCR001
    if meta_field_not_set(model=new_model, field_name="signals"):
        signals = SignalEmitter()
        signals.pre_save = Signal()
        signals.pre_update = Signal()
        signals.pre_delete = Signal()
        signals.post_save = Signal()
        signals.post_update = Signal()
        signals.post_delete = Signal()
        new_model.Meta.signals = signals
class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__( # type: ignore
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
@ -379,5 +390,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
new_model.Meta.alias_manager = alias_manager
new_model.objects = QuerySet(new_model)
add_property_fields(new_model, attrs)
register_signals(new_model=new_model)
return new_model


@ -195,6 +195,9 @@ class Model(NewBaseModel):
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
self_fields.pop(self.Meta.pkname, None)
self_fields = self.populate_default_values(self_fields)
self.from_dict(self_fields)
await self.signals.pre_save.send(sender=self.__class__, instance=self)
self_fields = self.translate_columns_to_aliases(self_fields)
expr = self.Meta.table.insert()
@ -204,6 +207,7 @@ class Model(NewBaseModel):
if pk and isinstance(pk, self.pk_type()):
setattr(self, self.Meta.pkname, pk)
self.set_save_status(True)
# refresh server side defaults
if any(
field.server_default is not None
@ -211,9 +215,8 @@ class Model(NewBaseModel):
if name not in self_fields
):
await self.load()
return self
self.set_save_status(True)
await self.signals.post_save.send(sender=self.__class__, instance=self)
return self
async def save_related( # noqa: CCR001
@ -268,6 +271,7 @@ class Model(NewBaseModel):
"You cannot update not saved model! Use save or upsert method."
)
await self.signals.pre_update.send(sender=self.__class__, instance=self)
self_fields = self._extract_model_db_fields()
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
self_fields = self.translate_columns_to_aliases(self_fields)
@ -276,13 +280,16 @@ class Model(NewBaseModel):
await self.Meta.database.execute(expr)
self.set_save_status(True)
await self.signals.post_update.send(sender=self.__class__, instance=self)
return self
async def delete(self: T) -> int:
await self.signals.pre_delete.send(sender=self.__class__, instance=self)
expr = self.Meta.table.delete()
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
result = await self.Meta.database.execute(expr)
self.set_save_status(False)
await self.signals.post_delete.send(sender=self.__class__, instance=self)
return result
async def load(self: T) -> T:


@ -35,6 +35,7 @@ from ormar.relations.relation_manager import RelationsManager
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.signals import SignalEmitter
T = TypeVar("T", bound=Model)
@ -47,11 +48,7 @@ if TYPE_CHECKING: # pragma no cover
class NewBaseModel(
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
):
__slots__ = (
"_orm_id",
"_orm_saved",
"_orm",
)
__slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column")
if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, Type[BaseField]]
@ -75,6 +72,7 @@ class NewBaseModel(
def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
object.__setattr__(self, "_orm_saved", False)
object.__setattr__(self, "_pk_column", None)
object.__setattr__(
self,
"_orm",
@ -94,13 +92,8 @@ class NewBaseModel(
if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk")
# remove property fields values from validation
kwargs = {
k: v
for k, v in kwargs.items()
if k not in object.__getattribute__(self, "Meta").property_fields
}
# build the models to set them and validate but don't register
# also remove property fields values from validation
try:
new_kwargs: Dict[str, Any] = {
k: self._convert_json(
@ -111,14 +104,15 @@ class NewBaseModel(
"dumps",
)
for k, v in kwargs.items()
if k not in object.__getattribute__(self, "Meta").property_fields
}
except KeyError as e:
raise ModelError(
f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}"
)
# explicitly set None to excluded fields with default
# as pydantic populates them with default
# explicitly set None to excluded fields
# as pydantic populates them with default if set
for field_to_nullify in excluded:
new_kwargs[field_to_nullify] = None
@ -195,7 +189,8 @@ class NewBaseModel(
return (
self._orm_id == other._orm_id
or (self.pk == other.pk and self.pk is not None)
or self.dict() == other.dict()
or self.dict(exclude=self.extract_related_names())
== other.dict(exclude=other.extract_related_names())
)
@classmethod
@ -207,12 +202,21 @@ class NewBaseModel(
@property
def pk_column(self) -> sqlalchemy.Column:
return self.Meta.table.primary_key.columns.values()[0]
if object.__getattribute__(self, "_pk_column") is not None:
return object.__getattribute__(self, "_pk_column")
pk_columns = self.Meta.table.primary_key.columns.values()
pk_col = pk_columns[0]
object.__setattr__(self, "_pk_column", pk_col)
return pk_col
@property
def saved(self) -> bool:
return self._orm_saved
@property
def signals(self) -> "SignalEmitter":
return self.Meta.signals
@classmethod
def pk_type(cls) -> Any:
return cls.Meta.model_fields[cls.Meta.pkname].__type__

ormar/py.typed Normal file

@ -6,7 +6,7 @@ from sqlalchemy import bindparam
import ormar # noqa I100
from ormar import MultipleMatches, NoMatch
from ormar.exceptions import QueryDefinitionError
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause
from ormar.queryset.prefetch_query import PrefetchQuery
@ -395,6 +395,9 @@ class QuerySet:
expr = expr.values(**new_kwargs)
instance = self.model(**kwargs)
await self.model.Meta.signals.pre_save.send(
sender=self.model, instance=instance
)
pk = await self.database.execute(expr)
pk_name = self.model.get_column_alias(self.model_meta.pkname)
@ -411,6 +414,9 @@ class QuerySet:
):
instance = await instance.load()
instance.set_save_status(True)
await self.model.Meta.signals.post_save.send(
sender=self.model, instance=instance
)
return instance
async def bulk_create(self, objects: List["Model"]) -> None:
@ -446,7 +452,7 @@ class QuerySet:
for objt in objects:
new_kwargs = objt.dict()
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
raise QueryDefinitionError(
raise ModelPersistenceError(
"You cannot update unsaved objects. "
f"{self.model.__name__} has to have {pk_name} filled."
)


@ -0,0 +1,3 @@
from ormar.signals.signal import Signal, SignalEmitter
__all__ = ["Signal", "SignalEmitter"]

ormar/signals/signal.py Normal file

@ -0,0 +1,71 @@
import asyncio
import inspect
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union

from ormar.exceptions import SignalDefinitionError

if TYPE_CHECKING:  # pragma: no cover
    from ormar import Model


def callable_accepts_kwargs(func: Callable) -> bool:
    return any(
        p
        for p in inspect.signature(func).parameters.values()
        if p.kind == p.VAR_KEYWORD
    )


def make_id(target: Any) -> Union[int, Tuple[int, int]]:
    if hasattr(target, "__func__"):
        return id(target.__self__), id(target.__func__)
    return id(target)


class Signal:
    def __init__(self) -> None:
        self._receivers: List[Tuple[Union[int, Tuple[int, int]], Callable]] = []

    def connect(self, receiver: Callable) -> None:
        if not callable(receiver):
            raise SignalDefinitionError("Signal receivers must be callable.")
        if not callable_accepts_kwargs(receiver):
            raise SignalDefinitionError(
                "Signal receivers must accept **kwargs argument."
            )
        new_receiver_key = make_id(receiver)
        if not any(rec_id == new_receiver_key for rec_id, _ in self._receivers):
            self._receivers.append((new_receiver_key, receiver))

    def disconnect(self, receiver: Callable) -> bool:
        removed = False
        new_receiver_key = make_id(receiver)
        for ind, rec in enumerate(self._receivers):
            rec_id, _ = rec
            if rec_id == new_receiver_key:
                removed = True
                del self._receivers[ind]
                break
        return removed

    async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
        receivers = []
        for receiver in self._receivers:
            _, receiver_func = receiver
            receivers.append(receiver_func(sender=sender, **kwargs))
        await asyncio.gather(*receivers)


class SignalEmitter:
    if TYPE_CHECKING:  # pragma: no cover
        signals: Dict[str, Signal]

    def __init__(self) -> None:
        object.__setattr__(self, "signals", dict())

    def __getattr__(self, item: str) -> Signal:
        return self.signals.setdefault(item, Signal())

    def __setattr__(self, key: str, value: Any) -> None:
        signals = object.__getattribute__(self, "signals")
        signals[key] = value


@ -1,19 +0,0 @@
#!/bin/sh -e
PACKAGE="ormar"
if [ -d 'dist' ] ; then
rm -r dist
fi
if [ -d 'site' ] ; then
rm -r site
fi
if [ -d 'htmlcov' ] ; then
rm -r htmlcov
fi
if [ -d "${PACKAGE}.egg-info" ] ; then
rm -r "${PACKAGE}.egg-info"
fi
find ${PACKAGE} -type f -name "*.py[co]" -delete
find ${PACKAGE} -type d -name __pycache__ -delete
find tests -type f -name "*.py[co]" -delete
find tests -type d -name __pycache__ -delete


@ -1,23 +0,0 @@
#!/bin/sh -e
PACKAGE="ormar"
PREFIX=""
if [ -d 'venv' ] ; then
PREFIX="venv/bin/"
fi
VERSION=`cat ${PACKAGE}/__init__.py | grep __version__ | sed "s/__version__ = //" | sed "s/'//g"`
set -x
scripts/clean.sh
${PREFIX}python setup.py sdist
${PREFIX}twine upload dist/*
echo "You probably want to also tag the version now:"
echo "git tag -a ${VERSION} -m 'version ${VERSION}'"
echo "git push --tags"
scripts/clean.sh


@ -50,6 +50,8 @@ setup(
author_email="collerek@gmail.com",
packages=get_packages(PACKAGE),
package_data={PACKAGE: ["py.typed"]},
include_package_data=True,
zip_safe=False,
data_files=[("", ["LICENSE.md"])],
install_requires=["databases", "pydantic>=1.5", "sqlalchemy", "typing_extensions"],
extras_require={
@ -65,9 +67,11 @@ setup(
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Topic :: Internet :: WWW/HTTP",
"Framework :: AsyncIO",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3 :: Only",
],
)


@ -10,7 +10,7 @@ from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from ormar import property_field
from ormar import post_save, property_field
from tests.settings import DATABASE_URL
app = FastAPI()
@ -65,8 +65,6 @@ class RandomModel(ormar.Model):
metadata = metadata
database = database
include_props_in_dict = True
id: int = ormar.Integer(primary_key=True)
password: str = ormar.String(max_length=255, default=gen_pass)
first_name: str = ormar.String(max_length=255, default="John")
@ -219,7 +217,7 @@ def test_excluding_fields_in_endpoints():
assert isinstance(user_instance.timestamp, datetime.datetime)
assert user_instance.timestamp == timestamp
response = client.post("/users4/", json=user)
response = client.post("/users4/", json=user3)
assert list(response.json().keys()) == [
"id",
"email",
@ -228,8 +226,12 @@ def test_excluding_fields_in_endpoints():
"category",
"timestamp",
]
assert response.json().get("timestamp") != str(timestamp).replace(" ", "T")
assert response.json().get("timestamp") is not None
assert (
datetime.datetime.strptime(
response.json().get("timestamp"), "%Y-%m-%dT%H:%M:%S.%f"
)
== timestamp
)
def test_adding_fields_in_endpoints():
@ -247,7 +249,6 @@ def test_adding_fields_in_endpoints():
]
assert response.json().get("full_name") == "John Test"
RandomModel.Meta.include_props_in_fields = False
user3 = {"last_name": "Test"}
response = client.post("/random/", json=user3)
assert list(response.json().keys()) == [
@ -264,7 +265,6 @@ def test_adding_fields_in_endpoints():
def test_adding_fields_in_endpoints2():
client = TestClient(app)
with client as client:
RandomModel.Meta.include_props_in_dict = True
user3 = {"last_name": "Test"}
response = client.post("/random2/", json=user3)
assert list(response.json().keys()) == [
@ -279,9 +279,15 @@ def test_adding_fields_in_endpoints2():
def test_excluding_property_field_in_endpoints2():
dummy_registry = {}
@post_save(RandomModel)
async def after_save(sender, instance, **kwargs):
dummy_registry[instance.pk] = instance.dict()
client = TestClient(app)
with client as client:
RandomModel.Meta.include_props_in_dict = True
user3 = {"last_name": "Test"}
response = client.post("/random3/", json=user3)
assert list(response.json().keys()) == [
@ -292,3 +298,7 @@ def test_excluding_property_field_in_endpoints2():
"created_date",
]
assert response.json().get("full_name") is None
assert len(dummy_registry) == 1
check_dict = dummy_registry.get(response.json().get("id"))
check_dict.pop("full_name")
assert response.json().get("password") == check_dict.get("password")


@ -5,7 +5,7 @@ import pytest
import sqlalchemy
import ormar
from ormar.exceptions import QueryDefinitionError
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
@ -302,7 +302,7 @@ async def test_bulk_update_with_relation():
async def test_bulk_update_not_saved_objts():
async with database:
category = await Category.objects.create(name="Sample Category")
with pytest.raises(QueryDefinitionError):
with pytest.raises(ModelPersistenceError):
await Note.objects.bulk_update(
[
Note(text="Buy the groceries.", category=category),

tests/test_signals.py Normal file

@ -0,0 +1,352 @@
from typing import Optional
import databases
import pydantic
import pytest
import sqlalchemy
import ormar
from ormar import (
post_delete,
post_save,
post_update,
pre_delete,
pre_save,
pre_update,
)
from ormar.exceptions import SignalDefinitionError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class AuditLog(ormar.Model):
class Meta:
tablename = "audits"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
event_type: str = ormar.String(max_length=100)
event_log: pydantic.Json = ormar.JSON()
class Cover(ormar.Model):
class Meta:
tablename = "covers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=100)
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
is_best_seller: bool = ormar.Boolean(default=False)
play_count: int = ormar.Integer(default=0)
cover: Optional[Cover] = ormar.ForeignKey(Cover)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture(scope="function")
async def cleanup():
yield
async with database:
await AuditLog.objects.delete(each=True)
def test_passing_not_callable():
with pytest.raises(SignalDefinitionError):
pre_save(Album)("wrong")
def test_passing_callable_without_kwargs():
with pytest.raises(SignalDefinitionError):
@pre_save(Album)
def trigger(sender, instance): # pragma: no cover
pass
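The two registration tests above exercise the same contract: a signal receiver must be callable and must accept `**kwargs`, so that senders can pass extra context without breaking existing receivers. A simplified, self-contained sketch of that validation (not ormar's actual code, just an `inspect`-based illustration of the rule these tests check):

```python
import inspect


class SignalDefinitionError(Exception):
    """Raised when a receiver does not meet the signal contract."""


def validate_receiver(receiver):
    # Receivers must be callable...
    if not callable(receiver):
        raise SignalDefinitionError("Signal receivers must be callable.")
    # ...and must accept **kwargs so senders can add arguments later.
    params = inspect.signature(receiver).parameters.values()
    if not any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
        raise SignalDefinitionError("Signal receivers must accept **kwargs.")


async def good(sender, instance, **kwargs):  # passes validation
    ...


def bad(sender, instance):  # rejected: no **kwargs
    ...
```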
@pytest.mark.asyncio
async def test_signal_functions(cleanup):
async with database:
async with database.transaction(force_rollback=True):
@pre_save(Album)
async def before_save(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_SAVE_{sender.get_name()}",
event_log=instance.json(),
).save()
@post_save(Album)
async def after_save(sender, instance, **kwargs):
await AuditLog(
event_type=f"POST_SAVE_{sender.get_name()}",
event_log=instance.json(),
).save()
@pre_update(Album)
async def before_update(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_UPDATE_{sender.get_name()}",
event_log=instance.json(),
).save()
@post_update(Album)
async def after_update(sender, instance, **kwargs):
await AuditLog(
event_type=f"POST_UPDATE_{sender.get_name()}",
event_log=instance.json(),
).save()
@pre_delete(Album)
async def before_delete(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_DELETE_{sender.get_name()}",
event_log=instance.json(),
).save()
@post_delete(Album)
async def after_delete(sender, instance, **kwargs):
await AuditLog(
event_type=f"POST_DELETE_{sender.get_name()}",
event_log=instance.json(),
).save()
album = await Album.objects.create(name="Venice")
audits = await AuditLog.objects.all()
assert len(audits) == 2
assert audits[0].event_type == "PRE_SAVE_album"
assert audits[0].event_log.get("name") == album.name
assert audits[1].event_type == "POST_SAVE_album"
assert audits[1].event_log.get("id") == album.pk
album = await Album(name="Rome").save()
audits = await AuditLog.objects.all()
assert len(audits) == 4
assert audits[2].event_type == "PRE_SAVE_album"
assert audits[2].event_log.get("name") == album.name
assert audits[3].event_type == "POST_SAVE_album"
assert audits[3].event_log.get("id") == album.pk
album.is_best_seller = True
await album.update()
audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all()
assert len(audits) == 2
assert audits[0].event_type == "PRE_UPDATE_album"
assert audits[0].event_log.get("name") == album.name
assert audits[1].event_type == "POST_UPDATE_album"
assert audits[1].event_log.get("is_best_seller") == album.is_best_seller
album.signals.pre_update.disconnect(before_update)
album.signals.post_update.disconnect(after_update)
album.is_best_seller = False
await album.update()
audits = await AuditLog.objects.filter(event_type__contains="UPDATE").all()
assert len(audits) == 2
await album.delete()
audits = await AuditLog.objects.filter(event_type__contains="DELETE").all()
assert len(audits) == 2
assert audits[0].event_type == "PRE_DELETE_album"
assert (
audits[0].event_log.get("id")
== audits[1].event_log.get("id")
== album.id
)
assert audits[1].event_type == "POST_DELETE_album"
album.signals.pre_delete.disconnect(before_delete)
album.signals.post_delete.disconnect(after_delete)
album.signals.pre_save.disconnect(before_save)
album.signals.post_save.disconnect(after_save)
@pytest.mark.asyncio
async def test_multiple_signals(cleanup):
async with database:
async with database.transaction(force_rollback=True):
@pre_save(Album)
async def before_save(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_SAVE_{sender.get_name()}",
event_log=instance.json(),
).save()
@pre_save(Album)
async def before_save2(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_SAVE_{sender.get_name()}",
event_log=instance.json(),
).save()
album = await Album.objects.create(name="Miami")
audits = await AuditLog.objects.all()
assert len(audits) == 2
assert audits[0].event_type == "PRE_SAVE_album"
assert audits[0].event_log.get("name") == album.name
assert audits[1].event_type == "PRE_SAVE_album"
assert audits[1].event_log.get("name") == album.name
album.signals.pre_save.disconnect(before_save)
album.signals.pre_save.disconnect(before_save2)
@pytest.mark.asyncio
async def test_static_methods_as_signals(cleanup):
async with database:
async with database.transaction(force_rollback=True):
class AlbumAuditor:
event_type = "ALBUM_INSTANCE"
@staticmethod
@pre_save(Album)
async def before_save(sender, instance, **kwargs):
await AuditLog(
event_type=f"{AlbumAuditor.event_type}_SAVE",
event_log=instance.json(),
).save()
album = await Album.objects.create(name="Colorado")
audits = await AuditLog.objects.all()
assert len(audits) == 1
assert audits[0].event_type == "ALBUM_INSTANCE_SAVE"
assert audits[0].event_log.get("name") == album.name
album.signals.pre_save.disconnect(AlbumAuditor.before_save)
@pytest.mark.asyncio
async def test_methods_as_signals(cleanup):
async with database:
async with database.transaction(force_rollback=True):
class AlbumAuditor:
def __init__(self):
self.event_type = "ALBUM_INSTANCE"
async def before_save(self, sender, instance, **kwargs):
await AuditLog(
event_type=f"{self.event_type}_SAVE", event_log=instance.json()
).save()
auditor = AlbumAuditor()
pre_save(Album)(auditor.before_save)
album = await Album.objects.create(name="San Francisco")
audits = await AuditLog.objects.all()
assert len(audits) == 1
assert audits[0].event_type == "ALBUM_INSTANCE_SAVE"
assert audits[0].event_log.get("name") == album.name
album.signals.pre_save.disconnect(auditor.before_save)
@pytest.mark.asyncio
async def test_multiple_senders_signal(cleanup):
async with database:
async with database.transaction(force_rollback=True):
@pre_save([Album, Cover])
async def before_save(sender, instance, **kwargs):
await AuditLog(
event_type=f"PRE_SAVE_{sender.get_name()}",
event_log=instance.json(),
).save()
cover = await Cover(title="Blue").save()
album = await Album.objects.create(name="San Francisco", cover=cover)
audits = await AuditLog.objects.all()
assert len(audits) == 2
assert audits[0].event_type == "PRE_SAVE_cover"
assert audits[0].event_log.get("title") == cover.title
assert audits[1].event_type == "PRE_SAVE_album"
assert audits[1].event_log.get("cover") == album.cover.dict(
exclude={"albums"}
)
album.signals.pre_save.disconnect(before_save)
cover.signals.pre_save.disconnect(before_save)
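The test above passes a list of models to `pre_save([Album, Cover])`, registering one receiver for several senders at once. A hypothetical sketch of how such a decorator can normalize a single sender or a list of senders into a shared registry (the `registry` dict and `connect` helper are illustrative, not ormar's internals):

```python
registry = {}  # hypothetical: signal name -> sender -> list of receivers


def connect(signal_name, sender, func):
    # Register func under the given signal and sender.
    registry.setdefault(signal_name, {}).setdefault(sender, []).append(func)


def receiver(signal_name, senders):
    """Decorator mirroring pre_save([Album, Cover]): one sender or many."""
    if not isinstance(senders, (list, tuple)):
        senders = [senders]

    def decorator(func):
        for sender in senders:
            connect(signal_name, sender, func)
        return func

    return decorator


@receiver("pre_save", ["Album", "Cover"])
def before_save(sender, instance, **kwargs):
    pass
```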
@pytest.mark.asyncio
async def test_modifying_the_instance(cleanup):
async with database:
async with database.transaction(force_rollback=True):
@pre_update(Album)
async def before_update(sender, instance, **kwargs):
if instance.play_count > 50 and not instance.is_best_seller:
instance.is_best_seller = True
# here album.play_count and is_best_seller get their default values
album = await Album.objects.create(name="Venice")
assert not album.is_best_seller
assert album.play_count == 0
album.play_count = 30
# the pre_update trigger fires, but play_count is still too low
await album.update()
assert not album.is_best_seller
album.play_count = 60
await album.update()
assert album.is_best_seller
album.signals.pre_update.disconnect(before_update)
@pytest.mark.asyncio
async def test_custom_signal(cleanup):
async with database:
async with database.transaction(force_rollback=True):
async def after_update(sender, instance, **kwargs):
if instance.play_count > 50 and not instance.is_best_seller:
instance.is_best_seller = True
elif instance.play_count < 50 and instance.is_best_seller:
instance.is_best_seller = False
await instance.update()
Album.Meta.signals.custom.connect(after_update)
# here album.play_count and is_best_seller get their default values
album = await Album.objects.create(name="Venice")
assert not album.is_best_seller
assert album.play_count == 0
album.play_count = 30
# update() alone does not send the custom signal, so nothing changes
await album.update()
assert not album.is_best_seller
album.play_count = 60
await album.update()
assert not album.is_best_seller
await Album.Meta.signals.custom.send(sender=Album, instance=album)
assert album.is_best_seller
album.play_count = 30
await album.update()
assert album.is_best_seller
await Album.Meta.signals.custom.send(sender=Album, instance=album)
assert not album.is_best_seller
Album.Meta.signals.custom.disconnect(after_update)
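The custom-signal test relies on `Meta.signals` exposing `connect`, `send`, and `disconnect` primitives, with `send` awaiting each receiver. A minimal, hypothetical stand-in for such an async signal (a sketch of the dispatch mechanism, not ormar's actual implementation):

```python
import asyncio


class Signal:
    """Minimal async signal: receivers are awaited in registration order."""

    def __init__(self):
        self._receivers = []

    def connect(self, receiver):
        if receiver not in self._receivers:
            self._receivers.append(receiver)

    def disconnect(self, receiver):
        if receiver in self._receivers:
            self._receivers.remove(receiver)

    async def send(self, **kwargs):
        for receiver in self._receivers:
            await receiver(**kwargs)


# Usage mirroring the custom-signal test above (sender/instance are
# plain strings here for illustration).
calls = []


async def after_update(sender, instance, **kwargs):
    calls.append((sender, instance))


sig = Signal()
sig.connect(after_update)
asyncio.run(sig.send(sender="Album", instance="album1"))
sig.disconnect(after_update)
asyncio.run(sig.send(sender="Album", instance="album2"))  # no receivers left
```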