Merge branch 'master' of https://github.com/collerek/ormar into check_timezones_filters
This commit is contained in:
@ -679,6 +679,7 @@ Signals allow to trigger your function for a given event on a given Model.
|
|||||||
* `post_relation_add`
|
* `post_relation_add`
|
||||||
* `pre_relation_remove`
|
* `pre_relation_remove`
|
||||||
* `post_relation_remove`
|
* `post_relation_remove`
|
||||||
|
* `post_bulk_update`
|
||||||
|
|
||||||
|
|
||||||
[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
|
[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
|
||||||
|
|||||||
@ -232,6 +232,11 @@ Send for `Model.relation_name.remove()` method for `ManyToMany` relations and re
|
|||||||
`sender` - sender class, `instance` - instance to which related model is added, `child` - model being added,
|
`sender` - sender class, `instance` - instance to which related model is added, `child` - model being added,
|
||||||
`relation_name` - name of the relation to which child is added.
|
`relation_name` - name of the relation to which child is added.
|
||||||
|
|
||||||
|
### post_bulk_update
|
||||||
|
|
||||||
|
`post_bulk_update(sender: Type["Model"], instances: List["Model"], **kwargs),
|
||||||
|
Send for `Model.objects.bulk_update(List[objects])` method.
|
||||||
|
|
||||||
|
|
||||||
## Defining your own signals
|
## Defining your own signals
|
||||||
|
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from ormar.decorators import ( # noqa: I100
|
|||||||
post_relation_remove,
|
post_relation_remove,
|
||||||
post_save,
|
post_save,
|
||||||
post_update,
|
post_update,
|
||||||
|
post_bulk_update,
|
||||||
pre_delete,
|
pre_delete,
|
||||||
pre_relation_add,
|
pre_relation_add,
|
||||||
pre_relation_remove,
|
pre_relation_remove,
|
||||||
@ -113,6 +114,7 @@ __all__ = [
|
|||||||
"RelationProtocol",
|
"RelationProtocol",
|
||||||
"ModelMeta",
|
"ModelMeta",
|
||||||
"property_field",
|
"property_field",
|
||||||
|
"post_bulk_update",
|
||||||
"post_delete",
|
"post_delete",
|
||||||
"post_save",
|
"post_save",
|
||||||
"post_update",
|
"post_update",
|
||||||
|
|||||||
@ -9,6 +9,7 @@ Currently only:
|
|||||||
"""
|
"""
|
||||||
from ormar.decorators.property_field import property_field
|
from ormar.decorators.property_field import property_field
|
||||||
from ormar.decorators.signals import (
|
from ormar.decorators.signals import (
|
||||||
|
post_bulk_update,
|
||||||
post_delete,
|
post_delete,
|
||||||
post_relation_add,
|
post_relation_add,
|
||||||
post_relation_remove,
|
post_relation_remove,
|
||||||
@ -23,6 +24,7 @@ from ormar.decorators.signals import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"property_field",
|
"property_field",
|
||||||
|
"post_bulk_update",
|
||||||
"post_delete",
|
"post_delete",
|
||||||
"post_save",
|
"post_save",
|
||||||
"post_update",
|
"post_update",
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Callable, List, TYPE_CHECKING, Type, Union
|
from typing import Callable, List, Type, TYPE_CHECKING, Union
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -171,3 +171,18 @@ def post_relation_remove(
|
|||||||
:rtype: Callable
|
:rtype: Callable
|
||||||
"""
|
"""
|
||||||
return receiver(signal="post_relation_remove", senders=senders)
|
return receiver(signal="post_relation_remove", senders=senders)
|
||||||
|
|
||||||
|
|
||||||
|
def post_bulk_update(
|
||||||
|
senders: Union[Type["Model"], List[Type["Model"]]]
|
||||||
|
) -> Callable:
|
||||||
|
"""
|
||||||
|
Connect given function to all senders for post_bulk_update signal.
|
||||||
|
|
||||||
|
:param senders: one or a list of "Model" classes
|
||||||
|
that should have the signal receiver registered
|
||||||
|
:type senders: Union[Type["Model"], List[Type["Model"]]]
|
||||||
|
:return: returns the original function untouched
|
||||||
|
:rtype: Callable
|
||||||
|
"""
|
||||||
|
return receiver(signal="post_bulk_update", senders=senders)
|
||||||
|
|||||||
@ -81,3 +81,10 @@ class SignalDefinitionError(AsyncOrmException):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ModelListEmptyError(AsyncOrmException):
|
||||||
|
"""
|
||||||
|
Raised for objects is empty when bulk_update
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|||||||
@ -159,6 +159,7 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
|
|||||||
signals.post_relation_add = Signal()
|
signals.post_relation_add = Signal()
|
||||||
signals.pre_relation_remove = Signal()
|
signals.pre_relation_remove = Signal()
|
||||||
signals.post_relation_remove = Signal()
|
signals.post_relation_remove = Signal()
|
||||||
|
signals.post_bulk_update = Signal()
|
||||||
new_model.Meta.signals = signals
|
new_model.Meta.signals = signals
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -56,6 +56,21 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
|
|||||||
new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
|
new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
|
||||||
return new_kwargs
|
return new_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def prepare_model_to_update(cls, new_kwargs: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Combines all preparation methods before updating.
|
||||||
|
:param new_kwargs: dictionary of model that is about to be saved
|
||||||
|
:type new_kwargs: Dict[str, str]
|
||||||
|
:return: dictionary of model that is about to be updated
|
||||||
|
:rtype: Dict[str, str]
|
||||||
|
"""
|
||||||
|
new_kwargs = cls.parse_non_db_fields(new_kwargs)
|
||||||
|
new_kwargs = cls.substitute_models_with_pks(new_kwargs)
|
||||||
|
new_kwargs = cls.reconvert_str_to_bytes(new_kwargs)
|
||||||
|
new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
|
||||||
|
return new_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _remove_not_ormar_fields(cls, new_kwargs: dict) -> dict:
|
def _remove_not_ormar_fields(cls, new_kwargs: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -29,7 +29,10 @@ except ImportError: # pragma: no cover
|
|||||||
|
|
||||||
import ormar # noqa I100
|
import ormar # noqa I100
|
||||||
from ormar import MultipleMatches, NoMatch
|
from ormar import MultipleMatches, NoMatch
|
||||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
from ormar.exceptions import (
|
||||||
|
ModelPersistenceError, QueryDefinitionError,
|
||||||
|
ModelListEmptyError
|
||||||
|
)
|
||||||
from ormar.queryset import FieldAccessor, FilterQuery, SelectAction
|
from ormar.queryset import FieldAccessor, FilterQuery, SelectAction
|
||||||
from ormar.queryset.actions.order_action import OrderAction
|
from ormar.queryset.actions.order_action import OrderAction
|
||||||
from ormar.queryset.clause import FilterGroup, QueryClause
|
from ormar.queryset.clause import FilterGroup, QueryClause
|
||||||
@ -1049,7 +1052,7 @@ class QuerySet(Generic[T]):
|
|||||||
|
|
||||||
async def bulk_create(self, objects: List["T"]) -> None:
|
async def bulk_create(self, objects: List["T"]) -> None:
|
||||||
"""
|
"""
|
||||||
Performs a bulk update in one database session to speed up the process.
|
Performs a bulk create in one database session to speed up the process.
|
||||||
|
|
||||||
Allows you to create multiple objects at once.
|
Allows you to create multiple objects at once.
|
||||||
|
|
||||||
@ -1060,17 +1063,15 @@ class QuerySet(Generic[T]):
|
|||||||
:param objects: list of ormar models already initialized and ready to save.
|
:param objects: list of ormar models already initialized and ready to save.
|
||||||
:type objects: List[Model]
|
:type objects: List[Model]
|
||||||
"""
|
"""
|
||||||
ready_objects = []
|
ready_objects = [
|
||||||
for objt in objects:
|
obj.prepare_model_to_save(obj.dict())
|
||||||
new_kwargs = objt.dict()
|
for obj in objects
|
||||||
new_kwargs = objt.prepare_model_to_save(new_kwargs)
|
]
|
||||||
ready_objects.append(new_kwargs)
|
expr = self.table.insert().values(ready_objects)
|
||||||
|
await self.database.execute(expr)
|
||||||
|
|
||||||
expr = self.table.insert()
|
for obj in objects:
|
||||||
await self.database.execute_many(expr, ready_objects)
|
obj.set_save_status(True)
|
||||||
|
|
||||||
for objt in objects:
|
|
||||||
objt.set_save_status(True)
|
|
||||||
|
|
||||||
async def bulk_update( # noqa: CCR001
|
async def bulk_update( # noqa: CCR001
|
||||||
self, objects: List["T"], columns: List[str] = None
|
self, objects: List["T"], columns: List[str] = None
|
||||||
@ -1078,7 +1079,7 @@ class QuerySet(Generic[T]):
|
|||||||
"""
|
"""
|
||||||
Performs bulk update in one database session to speed up the process.
|
Performs bulk update in one database session to speed up the process.
|
||||||
|
|
||||||
Allows to update multiple instance at once.
|
Allows you to update multiple instance at once.
|
||||||
|
|
||||||
All `Models` passed need to have primary key column populated.
|
All `Models` passed need to have primary key column populated.
|
||||||
|
|
||||||
@ -1092,6 +1093,9 @@ class QuerySet(Generic[T]):
|
|||||||
:param columns: list of columns to update
|
:param columns: list of columns to update
|
||||||
:type columns: List[str]
|
:type columns: List[str]
|
||||||
"""
|
"""
|
||||||
|
if not objects:
|
||||||
|
raise ModelListEmptyError("Bulk update objects are empty!")
|
||||||
|
|
||||||
ready_objects = []
|
ready_objects = []
|
||||||
pk_name = self.model_meta.pkname
|
pk_name = self.model_meta.pkname
|
||||||
if not columns:
|
if not columns:
|
||||||
@ -1106,19 +1110,17 @@ class QuerySet(Generic[T]):
|
|||||||
|
|
||||||
columns = [self.model.get_column_alias(k) for k in columns]
|
columns = [self.model.get_column_alias(k) for k in columns]
|
||||||
|
|
||||||
for objt in objects:
|
for obj in objects:
|
||||||
new_kwargs = objt.dict()
|
new_kwargs = obj.dict()
|
||||||
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
|
if new_kwargs.get(pk_name) is None:
|
||||||
raise ModelPersistenceError(
|
raise ModelPersistenceError(
|
||||||
"You cannot update unsaved objects. "
|
"You cannot update unsaved objects. "
|
||||||
f"{self.model.__name__} has to have {pk_name} filled."
|
f"{self.model.__name__} has to have {pk_name} filled."
|
||||||
)
|
)
|
||||||
new_kwargs = self.model.parse_non_db_fields(new_kwargs)
|
new_kwargs = obj.prepare_model_to_update(new_kwargs)
|
||||||
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
ready_objects.append({
|
||||||
new_kwargs = self.model.reconvert_str_to_bytes(new_kwargs)
|
"new_" + k: v for k, v in new_kwargs.items() if k in columns
|
||||||
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
|
})
|
||||||
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
|
|
||||||
ready_objects.append(new_kwargs)
|
|
||||||
|
|
||||||
pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name))
|
pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name))
|
||||||
pk_column_name = self.model.get_column_alias(pk_name)
|
pk_column_name = self.model.get_column_alias(pk_name)
|
||||||
@ -1138,5 +1140,10 @@ class QuerySet(Generic[T]):
|
|||||||
expr = str(expr)
|
expr = str(expr)
|
||||||
await self.database.execute_many(expr, ready_objects)
|
await self.database.execute_many(expr, ready_objects)
|
||||||
|
|
||||||
for objt in objects:
|
for obj in objects:
|
||||||
objt.set_save_status(True)
|
obj.set_save_status(True)
|
||||||
|
|
||||||
|
await cast(Type["Model"], self.model_cls).Meta.signals.post_bulk_update.send(
|
||||||
|
sender=self.model_cls, instances=objects # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Tuple, Type, TYPE_CHECKING, Union
|
||||||
|
|
||||||
from ormar.exceptions import SignalDefinitionError
|
from ormar.exceptions import SignalDefinitionError
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ class Signal:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._receivers: List[Tuple[Union[int, Tuple[int, int]], Callable]] = []
|
self._receivers: Dict[Union[int, Tuple[int, int]], Callable] = {}
|
||||||
|
|
||||||
def connect(self, receiver: Callable) -> None:
|
def connect(self, receiver: Callable) -> None:
|
||||||
"""
|
"""
|
||||||
@ -63,8 +63,8 @@ class Signal:
|
|||||||
"Signal receivers must accept **kwargs argument."
|
"Signal receivers must accept **kwargs argument."
|
||||||
)
|
)
|
||||||
new_receiver_key = make_id(receiver)
|
new_receiver_key = make_id(receiver)
|
||||||
if not any(rec_id == new_receiver_key for rec_id, _ in self._receivers):
|
if new_receiver_key not in self._receivers:
|
||||||
self._receivers.append((new_receiver_key, receiver))
|
self._receivers[new_receiver_key] = receiver
|
||||||
|
|
||||||
def disconnect(self, receiver: Callable) -> bool:
|
def disconnect(self, receiver: Callable) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -75,15 +75,10 @@ class Signal:
|
|||||||
:return: flag if receiver was removed
|
:return: flag if receiver was removed
|
||||||
:rtype: bool
|
:rtype: bool
|
||||||
"""
|
"""
|
||||||
removed = False
|
|
||||||
new_receiver_key = make_id(receiver)
|
new_receiver_key = make_id(receiver)
|
||||||
for ind, rec in enumerate(self._receivers):
|
receiver_func: Union[Callable, None] = self._receivers.pop(
|
||||||
rec_id, _ = rec
|
new_receiver_key, None)
|
||||||
if rec_id == new_receiver_key:
|
return True if receiver_func is not None else False
|
||||||
removed = True
|
|
||||||
del self._receivers[ind]
|
|
||||||
break
|
|
||||||
return removed
|
|
||||||
|
|
||||||
async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
|
async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
@ -93,28 +88,22 @@ class Signal:
|
|||||||
:param kwargs: arguments passed to receivers
|
:param kwargs: arguments passed to receivers
|
||||||
:type kwargs: Any
|
:type kwargs: Any
|
||||||
"""
|
"""
|
||||||
receivers = []
|
receivers = [
|
||||||
for receiver in self._receivers:
|
receiver_func(sender=sender, **kwargs)
|
||||||
_, receiver_func = receiver
|
for receiver_func in self._receivers.values()
|
||||||
receivers.append(receiver_func(sender=sender, **kwargs))
|
]
|
||||||
await asyncio.gather(*receivers)
|
await asyncio.gather(*receivers)
|
||||||
|
|
||||||
|
|
||||||
class SignalEmitter:
|
class SignalEmitter(dict):
|
||||||
"""
|
"""
|
||||||
Emitter that registers the signals in internal dictionary.
|
Emitter that registers the signals in internal dictionary.
|
||||||
If signal with given name does not exist it's auto added on access.
|
If signal with given name does not exist it's auto added on access.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
|
||||||
signals: Dict[str, Signal]
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
object.__setattr__(self, "signals", dict())
|
|
||||||
|
|
||||||
def __getattr__(self, item: str) -> Signal:
|
def __getattr__(self, item: str) -> Signal:
|
||||||
return self.signals.setdefault(item, Signal())
|
return self.setdefault(item, Signal())
|
||||||
|
|
||||||
def __setattr__(self, key: str, value: Any) -> None:
|
def __setattr__(self, key: str, value: Signal) -> None:
|
||||||
signals = object.__getattribute__(self, "signals")
|
if not isinstance(value, Signal):
|
||||||
signals[key] = value
|
raise SignalDefinitionError(f"{value} is not valid signal")
|
||||||
|
self[key] = value
|
||||||
|
|||||||
@ -5,7 +5,10 @@ import pytest
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
from ormar.exceptions import (
|
||||||
|
ModelPersistenceError, QueryDefinitionError,
|
||||||
|
ModelListEmptyError
|
||||||
|
)
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
@ -309,3 +312,6 @@ async def test_bulk_update_not_saved_objts():
|
|||||||
Note(text="Call Mum.", category=category),
|
Note(text="Call Mum.", category=category),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ModelListEmptyError):
|
||||||
|
await Note.objects.bulk_update([])
|
||||||
|
|||||||
@ -6,7 +6,11 @@ import pytest
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update
|
from ormar import (
|
||||||
|
post_bulk_update, post_delete, post_save, post_update,
|
||||||
|
pre_delete, pre_save, pre_update
|
||||||
|
)
|
||||||
|
from ormar.signals import SignalEmitter
|
||||||
from ormar.exceptions import SignalDefinitionError
|
from ormar.exceptions import SignalDefinitionError
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
@ -77,6 +81,12 @@ def test_passing_callable_without_kwargs():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_signal():
|
||||||
|
emitter = SignalEmitter()
|
||||||
|
with pytest.raises(SignalDefinitionError):
|
||||||
|
emitter.save = 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_signal_functions(cleanup):
|
async def test_signal_functions(cleanup):
|
||||||
async with database:
|
async with database:
|
||||||
@ -124,6 +134,14 @@ async def test_signal_functions(cleanup):
|
|||||||
event_log=instance.json(),
|
event_log=instance.json(),
|
||||||
).save()
|
).save()
|
||||||
|
|
||||||
|
@post_bulk_update(Album)
|
||||||
|
async def after_bulk_update(sender, instances, **kwargs):
|
||||||
|
for it in instances:
|
||||||
|
await AuditLog(
|
||||||
|
event_type=f"BULK_POST_UPDATE_{sender.get_name()}",
|
||||||
|
event_log=it.json(),
|
||||||
|
).save()
|
||||||
|
|
||||||
album = await Album.objects.create(name="Venice")
|
album = await Album.objects.create(name="Venice")
|
||||||
|
|
||||||
audits = await AuditLog.objects.all()
|
audits = await AuditLog.objects.all()
|
||||||
@ -176,6 +194,19 @@ async def test_signal_functions(cleanup):
|
|||||||
album.signals.pre_save.disconnect(before_save)
|
album.signals.pre_save.disconnect(before_save)
|
||||||
album.signals.post_save.disconnect(after_save)
|
album.signals.post_save.disconnect(after_save)
|
||||||
|
|
||||||
|
albums = await Album.objects.all()
|
||||||
|
assert len(albums)
|
||||||
|
|
||||||
|
for album in albums:
|
||||||
|
album.play_count = 1
|
||||||
|
|
||||||
|
await Album.objects.bulk_update(albums)
|
||||||
|
|
||||||
|
cnt = await AuditLog.objects.filter(event_type__contains="BULK_POST").count()
|
||||||
|
assert cnt == len(albums)
|
||||||
|
|
||||||
|
album.signals.bulk_post_update.disconnect(after_bulk_update)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_multiple_signals(cleanup):
|
async def test_multiple_signals(cleanup):
|
||||||
|
|||||||
Reference in New Issue
Block a user