diff --git a/docs/releases.md b/docs/releases.md index f858a08..a5576a2 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,16 +1,21 @@ # 0.10.24 +## ✨ Features + +* Add `post_bulk_update` signal (by @ponytailer - thanks!) [#524](https://github.com/collerek/ormar/pull/524) + ## 🐛 Fixes * Fix support for `pydantic==1.9.0` [#502](https://github.com/collerek/ormar/issues/502) * Fix timezone issues with datetime [#504](https://github.com/collerek/ormar/issues/504) * Remove literal binds in query generation to unblock postgres arrays [#/tophat/ormar-postgres-extensions/9](https://github.com/tophat/ormar-postgres-extensions/pull/9) +* Fix bulk update for `JSON` fields [#519](https://github.com/collerek/ormar/issues/519) ## 💬 Other +* Improve performance of `bulk_create` by bypassing `databases` `execute_many` suboptimal implementation. (by @Mng-dev-ai thanks!) [#520](https://github.com/collerek/ormar/pull/520) * Bump min. required `databases` version to `>=5.4`. - # 0.10.23 ## ✨ Features diff --git a/ormar/decorators/signals.py b/ormar/decorators/signals.py index cdb9f45..f70b5b7 100644 --- a/ormar/decorators/signals.py +++ b/ormar/decorators/signals.py @@ -173,9 +173,7 @@ def post_relation_remove( return receiver(signal="post_relation_remove", senders=senders) -def post_bulk_update( - senders: Union[Type["Model"], List[Type["Model"]]] -) -> Callable: +def post_bulk_update(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: """ Connect given function to all senders for post_bulk_update signal. diff --git a/ormar/exceptions.py b/ormar/exceptions.py index 8fd6bab..12e2e3a 100644 --- a/ormar/exceptions.py +++ b/ormar/exceptions.py @@ -87,4 +87,5 @@ class ModelListEmptyError(AsyncOrmException): """ Raised for objects is empty when bulk_update """ + pass diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 60efa37..89ddbfe 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -12,6 +12,11 @@ from typing import ( cast, ) +try: + import orjson as json +except ImportError: # pragma: no cover + import json # type: ignore + import pydantic import ormar # noqa: I100, I202 @@ -31,6 +36,8 @@ class SavePrepareMixin(RelationMixin, AliasMixin): if TYPE_CHECKING: # pragma: nocover _choices_fields: Optional[Set] _skip_ellipsis: Callable + _json_fields: Set[str] + _bytes_fields: Set[str] __fields__: Dict[str, pydantic.fields.ModelField] @classmethod @@ -53,6 +60,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): new_kwargs = cls.substitute_models_with_pks(new_kwargs) new_kwargs = cls.populate_default_values(new_kwargs) new_kwargs = cls.reconvert_str_to_bytes(new_kwargs) + new_kwargs = cls.dump_all_json_fields_to_str(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs) return new_kwargs @@ -68,6 +76,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): new_kwargs = cls.parse_non_db_fields(new_kwargs) new_kwargs = cls.substitute_models_with_pks(new_kwargs) new_kwargs = cls.reconvert_str_to_bytes(new_kwargs) + new_kwargs = cls.dump_all_json_fields_to_str(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs) return new_kwargs @@ -172,18 +181,13 @@ class SavePrepareMixin(RelationMixin, AliasMixin): :return: dictionary of model that is about to be saved :rtype: Dict """ - bytes_fields = { - name - for name, field in cls.Meta.model_fields.items() - if field.__type__ == bytes - } bytes_base64_fields = { name for name, field in cls.Meta.model_fields.items() if field.represent_as_base64_str } for key, value in model_dict.items(): - if key in bytes_fields and isinstance(value, str): + if key in cls._bytes_fields and isinstance(value, str): model_dict[key] = ( value.encode("utf-8") if key not in bytes_base64_fields @@ -191,6 +195,22 @@ class SavePrepareMixin(RelationMixin, AliasMixin): ) return model_dict + @classmethod + def dump_all_json_fields_to_str(cls, model_dict: Dict) -> Dict: + """ + Receives dictionary of model that is about to be saved and changes + all json fields into strings + + :param model_dict: dictionary of model that is about to be saved + :type model_dict: Dict + :return: dictionary of model that is about to be saved + :rtype: Dict + """ + for key, value in model_dict.items(): + if key in cls._json_fields and not isinstance(value, str): + model_dict[key] = json.dumps(value) + return model_dict + @classmethod def populate_default_values(cls, new_kwargs: Dict) -> Dict: """ diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 95cba43..103347d 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -30,8 +30,9 @@ except ImportError: # pragma: no cover import ormar # noqa I100 from ormar import MultipleMatches, NoMatch from ormar.exceptions import ( - ModelPersistenceError, QueryDefinitionError, - ModelListEmptyError + ModelPersistenceError, + QueryDefinitionError, + ModelListEmptyError, ) from ormar.queryset import FieldAccessor, FilterQuery, SelectAction from ormar.queryset.actions.order_action import OrderAction @@ -1063,10 +1064,7 @@ class QuerySet(Generic[T]): :param objects: list of ormar models already initialized and ready to save. :type objects: List[Model] """ - ready_objects = [ - obj.prepare_model_to_save(obj.dict()) - for obj in objects - ] + ready_objects = [obj.prepare_model_to_save(obj.dict()) for obj in objects] expr = self.table.insert().values(ready_objects) await self.database.execute(expr) @@ -1118,9 +1116,9 @@ class QuerySet(Generic[T]): f"{self.model.__name__} has to have {pk_name} filled." ) new_kwargs = obj.prepare_model_to_update(new_kwargs) - ready_objects.append({ - "new_" + k: v for k, v in new_kwargs.items() if k in columns - }) + ready_objects.append( + {"new_" + k: v for k, v in new_kwargs.items() if k in columns} + ) pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name)) pk_column_name = self.model.get_column_alias(pk_name) @@ -1146,4 +1144,3 @@ class QuerySet(Generic[T]): await cast(Type["Model"], self.model_cls).Meta.signals.post_bulk_update.send( sender=self.model_cls, instances=objects # type: ignore ) - diff --git a/ormar/signals/signal.py b/ormar/signals/signal.py index d5b4a5c..2b76787 100644 --- a/ormar/signals/signal.py +++ b/ormar/signals/signal.py @@ -77,7 +77,8 @@ class Signal: """ new_receiver_key = make_id(receiver) receiver_func: Union[Callable, None] = self._receivers.pop( - new_receiver_key, None) + new_receiver_key, None + ) return True if receiver_func is not None else False async def send(self, sender: Type["Model"], **kwargs: Any) -> None: @@ -100,6 +101,7 @@ class SignalEmitter(dict): Emitter that registers the signals in internal dictionary. If signal with given name does not exist it's auto added on access. """ + def __getattr__(self, item: str) -> Signal: return self.setdefault(item, Signal()) diff --git a/tests/test_queries/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py index e7a9543..449c7df 100644 --- a/tests/test_queries/test_queryset_level_methods.py +++ b/tests/test_queries/test_queryset_level_methods.py @@ -1,13 +1,15 @@ -from typing import Optional +from typing import List, Optional import databases +import pydantic import pytest import sqlalchemy import ormar from ormar.exceptions import ( - ModelPersistenceError, QueryDefinitionError, - ModelListEmptyError + ModelPersistenceError, + QueryDefinitionError, + ModelListEmptyError, ) from tests.settings import DATABASE_URL @@ -63,6 +65,17 @@ class Note(ormar.Model): category: Optional[Category] = ormar.ForeignKey(Category) +class ItemConfig(ormar.Model): + class Meta: + metadata = metadata + database = database + tablename = "item_config" + + id: Optional[int] = ormar.Integer(primary_key=True) + item_id: str = ormar.String(max_length=32, index=True) + pairs: pydantic.Json = ormar.JSON(default=["2", "3"]) + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -315,3 +328,23 @@ async def test_bulk_update_not_saved_objts(): with pytest.raises(ModelListEmptyError): await Note.objects.bulk_update([]) + + +@pytest.mark.asyncio +async def test_bulk_operations_with_json(): + async with database: + items = [ + ItemConfig(item_id="test1"), + ItemConfig(item_id="test2"), + ItemConfig(item_id="test3"), + ] + await ItemConfig.objects.bulk_create(items) + items = await ItemConfig.objects.all() + assert all(x.pairs == ["2", "3"] for x in items) + + for item in items: + item.pairs = ["1"] + + await ItemConfig.objects.bulk_update(items) + items = await ItemConfig.objects.all() + assert all(x.pairs == ["1"] for x in items) diff --git a/tests/test_signals/test_signals.py b/tests/test_signals/test_signals.py index 1002a2f..01e4f7a 100644 --- a/tests/test_signals/test_signals.py +++ b/tests/test_signals/test_signals.py @@ -7,8 +7,13 @@ import sqlalchemy import ormar from ormar import ( - post_bulk_update, post_delete, post_save, post_update, - pre_delete, pre_save, pre_update + post_bulk_update, + post_delete, + post_save, + post_update, + pre_delete, + pre_save, + pre_update, ) from ormar.signals import SignalEmitter from ormar.exceptions import SignalDefinitionError @@ -202,7 +207,9 @@ async def test_signal_functions(cleanup): await Album.objects.bulk_update(albums) - cnt = await AuditLog.objects.filter(event_type__contains="BULK_POST").count() + cnt = await AuditLog.objects.filter( + event_type__contains="BULK_POST" + ).count() assert cnt == len(albums) album.signals.bulk_post_update.disconnect(after_bulk_update)