Merge branch 'master' of https://github.com/collerek/ormar into check_timezones_filters

This commit is contained in:
collerek
2022-01-14 17:54:20 +01:00
12 changed files with 136 additions and 55 deletions

View File

@ -679,6 +679,7 @@ Signals allow to trigger your function for a given event on a given Model.
* `post_relation_add` * `post_relation_add`
* `pre_relation_remove` * `pre_relation_remove`
* `post_relation_remove` * `post_relation_remove`
* `post_bulk_update`
[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/

View File

@ -232,6 +232,11 @@ Send for `Model.relation_name.remove()` method for `ManyToMany` relations and re
`sender` - sender class, `instance` - instance to which related model is added, `child` - model being added, `sender` - sender class, `instance` - instance to which related model is added, `child` - model being added,
`relation_name` - name of the relation to which child is added. `relation_name` - name of the relation to which child is added.
### post_bulk_update
`post_bulk_update(sender: Type["Model"], instances: List["Model"], **kwargs),
Send for `Model.objects.bulk_update(List[objects])` method.
## Defining your own signals ## Defining your own signals

View File

@ -30,6 +30,7 @@ from ormar.decorators import ( # noqa: I100
post_relation_remove, post_relation_remove,
post_save, post_save,
post_update, post_update,
post_bulk_update,
pre_delete, pre_delete,
pre_relation_add, pre_relation_add,
pre_relation_remove, pre_relation_remove,
@ -113,6 +114,7 @@ __all__ = [
"RelationProtocol", "RelationProtocol",
"ModelMeta", "ModelMeta",
"property_field", "property_field",
"post_bulk_update",
"post_delete", "post_delete",
"post_save", "post_save",
"post_update", "post_update",

View File

@ -9,6 +9,7 @@ Currently only:
""" """
from ormar.decorators.property_field import property_field from ormar.decorators.property_field import property_field
from ormar.decorators.signals import ( from ormar.decorators.signals import (
post_bulk_update,
post_delete, post_delete,
post_relation_add, post_relation_add,
post_relation_remove, post_relation_remove,
@ -23,6 +24,7 @@ from ormar.decorators.signals import (
__all__ = [ __all__ = [
"property_field", "property_field",
"post_bulk_update",
"post_delete", "post_delete",
"post_save", "post_save",
"post_update", "post_update",

View File

@ -1,4 +1,4 @@
from typing import Callable, List, TYPE_CHECKING, Type, Union from typing import Callable, List, Type, TYPE_CHECKING, Union
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ormar import Model from ormar import Model
@ -171,3 +171,18 @@ def post_relation_remove(
:rtype: Callable :rtype: Callable
""" """
return receiver(signal="post_relation_remove", senders=senders) return receiver(signal="post_relation_remove", senders=senders)
def post_bulk_update(
senders: Union[Type["Model"], List[Type["Model"]]]
) -> Callable:
"""
Connect given function to all senders for post_bulk_update signal.
:param senders: one or a list of "Model" classes
that should have the signal receiver registered
:type senders: Union[Type["Model"], List[Type["Model"]]]
:return: returns the original function untouched
:rtype: Callable
"""
return receiver(signal="post_bulk_update", senders=senders)

View File

@ -81,3 +81,10 @@ class SignalDefinitionError(AsyncOrmException):
""" """
pass pass
class ModelListEmptyError(AsyncOrmException):
"""
Raised for objects is empty when bulk_update
"""
pass

View File

@ -159,6 +159,7 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
signals.post_relation_add = Signal() signals.post_relation_add = Signal()
signals.pre_relation_remove = Signal() signals.pre_relation_remove = Signal()
signals.post_relation_remove = Signal() signals.post_relation_remove = Signal()
signals.post_bulk_update = Signal()
new_model.Meta.signals = signals new_model.Meta.signals = signals

View File

@ -56,6 +56,21 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
new_kwargs = cls.translate_columns_to_aliases(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
return new_kwargs return new_kwargs
@classmethod
def prepare_model_to_update(cls, new_kwargs: dict) -> dict:
"""
Combines all preparation methods before updating.
:param new_kwargs: dictionary of model that is about to be saved
:type new_kwargs: Dict[str, str]
:return: dictionary of model that is about to be updated
:rtype: Dict[str, str]
"""
new_kwargs = cls.parse_non_db_fields(new_kwargs)
new_kwargs = cls.substitute_models_with_pks(new_kwargs)
new_kwargs = cls.reconvert_str_to_bytes(new_kwargs)
new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
return new_kwargs
@classmethod @classmethod
def _remove_not_ormar_fields(cls, new_kwargs: dict) -> dict: def _remove_not_ormar_fields(cls, new_kwargs: dict) -> dict:
""" """

View File

@ -29,7 +29,10 @@ except ImportError: # pragma: no cover
import ormar # noqa I100 import ormar # noqa I100
from ormar import MultipleMatches, NoMatch from ormar import MultipleMatches, NoMatch
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError from ormar.exceptions import (
ModelPersistenceError, QueryDefinitionError,
ModelListEmptyError
)
from ormar.queryset import FieldAccessor, FilterQuery, SelectAction from ormar.queryset import FieldAccessor, FilterQuery, SelectAction
from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.actions.order_action import OrderAction
from ormar.queryset.clause import FilterGroup, QueryClause from ormar.queryset.clause import FilterGroup, QueryClause
@ -1049,7 +1052,7 @@ class QuerySet(Generic[T]):
async def bulk_create(self, objects: List["T"]) -> None: async def bulk_create(self, objects: List["T"]) -> None:
""" """
Performs a bulk update in one database session to speed up the process. Performs a bulk create in one database session to speed up the process.
Allows you to create multiple objects at once. Allows you to create multiple objects at once.
@ -1060,17 +1063,15 @@ class QuerySet(Generic[T]):
:param objects: list of ormar models already initialized and ready to save. :param objects: list of ormar models already initialized and ready to save.
:type objects: List[Model] :type objects: List[Model]
""" """
ready_objects = [] ready_objects = [
for objt in objects: obj.prepare_model_to_save(obj.dict())
new_kwargs = objt.dict() for obj in objects
new_kwargs = objt.prepare_model_to_save(new_kwargs) ]
ready_objects.append(new_kwargs) expr = self.table.insert().values(ready_objects)
await self.database.execute(expr)
expr = self.table.insert() for obj in objects:
await self.database.execute_many(expr, ready_objects) obj.set_save_status(True)
for objt in objects:
objt.set_save_status(True)
async def bulk_update( # noqa: CCR001 async def bulk_update( # noqa: CCR001
self, objects: List["T"], columns: List[str] = None self, objects: List["T"], columns: List[str] = None
@ -1078,7 +1079,7 @@ class QuerySet(Generic[T]):
""" """
Performs bulk update in one database session to speed up the process. Performs bulk update in one database session to speed up the process.
Allows to update multiple instance at once. Allows you to update multiple instance at once.
All `Models` passed need to have primary key column populated. All `Models` passed need to have primary key column populated.
@ -1092,6 +1093,9 @@ class QuerySet(Generic[T]):
:param columns: list of columns to update :param columns: list of columns to update
:type columns: List[str] :type columns: List[str]
""" """
if not objects:
raise ModelListEmptyError("Bulk update objects are empty!")
ready_objects = [] ready_objects = []
pk_name = self.model_meta.pkname pk_name = self.model_meta.pkname
if not columns: if not columns:
@ -1106,19 +1110,17 @@ class QuerySet(Generic[T]):
columns = [self.model.get_column_alias(k) for k in columns] columns = [self.model.get_column_alias(k) for k in columns]
for objt in objects: for obj in objects:
new_kwargs = objt.dict() new_kwargs = obj.dict()
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None: if new_kwargs.get(pk_name) is None:
raise ModelPersistenceError( raise ModelPersistenceError(
"You cannot update unsaved objects. " "You cannot update unsaved objects. "
f"{self.model.__name__} has to have {pk_name} filled." f"{self.model.__name__} has to have {pk_name} filled."
) )
new_kwargs = self.model.parse_non_db_fields(new_kwargs) new_kwargs = obj.prepare_model_to_update(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs) ready_objects.append({
new_kwargs = self.model.reconvert_str_to_bytes(new_kwargs) "new_" + k: v for k, v in new_kwargs.items() if k in columns
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs) })
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
ready_objects.append(new_kwargs)
pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name)) pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name))
pk_column_name = self.model.get_column_alias(pk_name) pk_column_name = self.model.get_column_alias(pk_name)
@ -1138,5 +1140,10 @@ class QuerySet(Generic[T]):
expr = str(expr) expr = str(expr)
await self.database.execute_many(expr, ready_objects) await self.database.execute_many(expr, ready_objects)
for objt in objects: for obj in objects:
objt.set_save_status(True) obj.set_save_status(True)
await cast(Type["Model"], self.model_cls).Meta.signals.post_bulk_update.send(
sender=self.model_cls, instances=objects # type: ignore
)

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
import inspect import inspect
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Tuple, Type, Union from typing import Any, Callable, Dict, Tuple, Type, TYPE_CHECKING, Union
from ormar.exceptions import SignalDefinitionError from ormar.exceptions import SignalDefinitionError
@ -45,7 +45,7 @@ class Signal:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._receivers: List[Tuple[Union[int, Tuple[int, int]], Callable]] = [] self._receivers: Dict[Union[int, Tuple[int, int]], Callable] = {}
def connect(self, receiver: Callable) -> None: def connect(self, receiver: Callable) -> None:
""" """
@ -63,8 +63,8 @@ class Signal:
"Signal receivers must accept **kwargs argument." "Signal receivers must accept **kwargs argument."
) )
new_receiver_key = make_id(receiver) new_receiver_key = make_id(receiver)
if not any(rec_id == new_receiver_key for rec_id, _ in self._receivers): if new_receiver_key not in self._receivers:
self._receivers.append((new_receiver_key, receiver)) self._receivers[new_receiver_key] = receiver
def disconnect(self, receiver: Callable) -> bool: def disconnect(self, receiver: Callable) -> bool:
""" """
@ -75,15 +75,10 @@ class Signal:
:return: flag if receiver was removed :return: flag if receiver was removed
:rtype: bool :rtype: bool
""" """
removed = False
new_receiver_key = make_id(receiver) new_receiver_key = make_id(receiver)
for ind, rec in enumerate(self._receivers): receiver_func: Union[Callable, None] = self._receivers.pop(
rec_id, _ = rec new_receiver_key, None)
if rec_id == new_receiver_key: return True if receiver_func is not None else False
removed = True
del self._receivers[ind]
break
return removed
async def send(self, sender: Type["Model"], **kwargs: Any) -> None: async def send(self, sender: Type["Model"], **kwargs: Any) -> None:
""" """
@ -93,28 +88,22 @@ class Signal:
:param kwargs: arguments passed to receivers :param kwargs: arguments passed to receivers
:type kwargs: Any :type kwargs: Any
""" """
receivers = [] receivers = [
for receiver in self._receivers: receiver_func(sender=sender, **kwargs)
_, receiver_func = receiver for receiver_func in self._receivers.values()
receivers.append(receiver_func(sender=sender, **kwargs)) ]
await asyncio.gather(*receivers) await asyncio.gather(*receivers)
class SignalEmitter: class SignalEmitter(dict):
""" """
Emitter that registers the signals in internal dictionary. Emitter that registers the signals in internal dictionary.
If signal with given name does not exist it's auto added on access. If signal with given name does not exist it's auto added on access.
""" """
if TYPE_CHECKING: # pragma: no cover
signals: Dict[str, Signal]
def __init__(self) -> None:
object.__setattr__(self, "signals", dict())
def __getattr__(self, item: str) -> Signal: def __getattr__(self, item: str) -> Signal:
return self.signals.setdefault(item, Signal()) return self.setdefault(item, Signal())
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Signal) -> None:
signals = object.__getattribute__(self, "signals") if not isinstance(value, Signal):
signals[key] = value raise SignalDefinitionError(f"{value} is not valid signal")
self[key] = value

View File

@ -5,7 +5,10 @@ import pytest
import sqlalchemy import sqlalchemy
import ormar import ormar
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError from ormar.exceptions import (
ModelPersistenceError, QueryDefinitionError,
ModelListEmptyError
)
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -309,3 +312,6 @@ async def test_bulk_update_not_saved_objts():
Note(text="Call Mum.", category=category), Note(text="Call Mum.", category=category),
] ]
) )
with pytest.raises(ModelListEmptyError):
await Note.objects.bulk_update([])

View File

@ -6,7 +6,11 @@ import pytest
import sqlalchemy import sqlalchemy
import ormar import ormar
from ormar import post_delete, post_save, post_update, pre_delete, pre_save, pre_update from ormar import (
post_bulk_update, post_delete, post_save, post_update,
pre_delete, pre_save, pre_update
)
from ormar.signals import SignalEmitter
from ormar.exceptions import SignalDefinitionError from ormar.exceptions import SignalDefinitionError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -77,6 +81,12 @@ def test_passing_callable_without_kwargs():
pass pass
def test_invalid_signal():
emitter = SignalEmitter()
with pytest.raises(SignalDefinitionError):
emitter.save = 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_signal_functions(cleanup): async def test_signal_functions(cleanup):
async with database: async with database:
@ -124,6 +134,14 @@ async def test_signal_functions(cleanup):
event_log=instance.json(), event_log=instance.json(),
).save() ).save()
@post_bulk_update(Album)
async def after_bulk_update(sender, instances, **kwargs):
for it in instances:
await AuditLog(
event_type=f"BULK_POST_UPDATE_{sender.get_name()}",
event_log=it.json(),
).save()
album = await Album.objects.create(name="Venice") album = await Album.objects.create(name="Venice")
audits = await AuditLog.objects.all() audits = await AuditLog.objects.all()
@ -176,6 +194,19 @@ async def test_signal_functions(cleanup):
album.signals.pre_save.disconnect(before_save) album.signals.pre_save.disconnect(before_save)
album.signals.post_save.disconnect(after_save) album.signals.post_save.disconnect(after_save)
albums = await Album.objects.all()
assert len(albums)
for album in albums:
album.play_count = 1
await Album.objects.bulk_update(albums)
cnt = await AuditLog.objects.filter(event_type__contains="BULK_POST").count()
assert cnt == len(albums)
album.signals.bulk_post_update.disconnect(after_bulk_update)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_signals(cleanup): async def test_multiple_signals(cleanup):