diff --git a/.coverage b/.coverage index f743d7d..10f7b3a 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 31c0096..470ac57 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -21,13 +21,13 @@ class ModelTableProxy: raise NotImplementedError # pragma no cover def _extract_own_model_fields(self) -> dict: - related_names = self._extract_related_names() + related_names = self.extract_related_names() self_fields = {k: v for k, v in self.dict().items() if k not in related_names} return self_fields @classmethod - def extract_db_own_fields(cls) -> set: - related_names = cls._extract_related_names() + def extract_db_own_fields(cls) -> Set: + related_names = cls.extract_related_names() self_fields = { name for name in cls.Meta.model_fields.keys() if name not in related_names } @@ -35,7 +35,7 @@ class ModelTableProxy: @classmethod def substitute_models_with_pks(cls, model_dict: dict) -> dict: - for field in cls._extract_related_names(): + for field in cls.extract_related_names(): field_value = model_dict.get(field, None) if field_value is not None: target_field = cls.Meta.model_fields[field] @@ -47,7 +47,7 @@ class ModelTableProxy: return model_dict @classmethod - def _extract_related_names(cls) -> Set: + def extract_related_names(cls) -> Set: related_names = set() for name, field in cls.Meta.model_fields.items(): if inspect.isclass(field) and issubclass(field, ForeignKeyField): @@ -69,7 +69,7 @@ class ModelTableProxy: @classmethod def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: if nested: - return cls._extract_related_names() + return cls.extract_related_names() related_names = set() for name, field in cls.Meta.model_fields.items(): if ( diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index ee94013..08294e8 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -92,7 +92,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass object.__setattr__(self, "__fields_set__", fields_set) # register the related models after initialization - for related in self._extract_related_names(): + for related in self.extract_related_names(): self.Meta.model_fields[related].expand_relationship( kwargs.get(related), self, to_register=True ) @@ -119,7 +119,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def __getattribute__(self, item: str) -> Any: if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"): return object.__getattribute__(self, item) - if item != "_extract_related_names" and item in self._extract_related_names(): + if item != "extract_related_names" and item in self.extract_related_names(): return self._extract_related_model_instead_of_field(item) if item == "pk": return self.__dict__.get(self.Meta.pkname, None) @@ -186,7 +186,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) - for field in self._extract_related_names(): + for field in self.extract_related_names(): nested_model = getattr(self, field) if self.Meta.model_fields[field].virtual and nested: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 6c7f657..81684d7 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -247,3 +247,52 @@ class QuerySet: if pk and isinstance(pk, self.model_cls.pk_type()): setattr(instance, self.model_cls.Meta.pkname, pk) return instance + + async def bulk_create(self, objects: List["Model"]) -> None: + ready_objects = [] + for objt in objects: + new_kwargs = objt.dict() + new_kwargs = self._remove_pk_from_kwargs(new_kwargs) + new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) + new_kwargs = self._populate_default_values(new_kwargs) + ready_objects.append(new_kwargs) + + expr = self.table.insert() + await self.database.execute_many(expr, ready_objects) + + async def bulk_update( + self, objects: List["Model"], columns: List[str] = None + ) -> None: + ready_expressions = [] + pk_name = self.model_cls.Meta.pkname + if not columns: + columns = self.model_cls.extract_db_own_fields().union( + self.model_cls.extract_related_names() + ) + + if pk_name not in columns: + columns.append(pk_name) + + for objt in objects: + new_kwargs = objt.dict() + if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None: + raise QueryDefinitionError( + "You cannot update unsaved objects. " + f"{self.model_cls.__name__} has to have {pk_name} filled." + ) + new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) + new_kwargs = self._populate_default_values(new_kwargs) + new_kwargs = {k: v for k, v in new_kwargs.items() if k in columns} + expr = self.table.update().values( + **{k: v for k, v in new_kwargs.items() if k != pk_name} + ) + pk_column = self.model_cls.Meta.table.c.get(pk_name) + expr = expr.where(pk_column == new_kwargs.get(pk_name)) + ready_expressions.append(expr) + + # databases does not bind params for where clause and values separately + # no way to pass one dict with both uses + # so we need to resort to lower connection api + async with self.model_cls.Meta.database.connection() as connection: + for single_query in ready_expressions: + await connection.execute(single_query) diff --git a/tests/test_queryset_level_methods.py b/tests/test_queryset_level_methods.py index 9611ad6..af8385e 100644 --- a/tests/test_queryset_level_methods.py +++ b/tests/test_queryset_level_methods.py @@ -19,7 +19,43 @@ class Book(ormar.Model): id: ormar.Integer(primary_key=True) title: ormar.String(max_length=200) author: ormar.String(max_length=100) - genre: ormar.String(max_length=100, default='Fiction', choices=['Fiction', 'Adventure', 'Historic', 'Fantasy']) + genre: ormar.String( + max_length=100, + default="Fiction", + choices=["Fiction", "Adventure", "Historic", "Fantasy"], + ) + + +class ToDo(ormar.Model): + class Meta: + tablename = "todos" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + text: ormar.String(max_length=500) + completed: ormar.Boolean(default=False) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=500) + + +class Note(ormar.Model): + class Meta: + tablename = "notes" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + text: ormar.String(max_length=500) + category: ormar.ForeignKey(Category) @pytest.fixture(autouse=True, scope="module") @@ -35,36 +71,48 @@ def create_test_database(): async def test_delete_and_update(): async with database: async with database.transaction(force_rollback=True): - await Book.objects.create(title='Tom Sawyer', author="Twain, Mark", genre='Adventure') - await Book.objects.create(title='War and Peace', author="Tolstoy, Leo", genre='Fiction') - await Book.objects.create(title='Anna Karenina', author="Tolstoy, Leo", genre='Fiction') - await Book.objects.create(title='Harry Potter', author="Rowling, J.K.", genre='Fantasy') - await Book.objects.create(title='Lord of the Rings', author="Tolkien, J.R.", genre='Fantasy') + await Book.objects.create( + title="Tom Sawyer", author="Twain, Mark", genre="Adventure" + ) + await Book.objects.create( + title="War and Peace", author="Tolstoy, Leo", genre="Fiction" + ) + await Book.objects.create( + title="Anna Karenina", author="Tolstoy, Leo", genre="Fiction" + ) + await Book.objects.create( + title="Harry Potter", author="Rowling, J.K.", genre="Fantasy" + ) + await Book.objects.create( + title="Lord of the Rings", author="Tolkien, J.R.", genre="Fantasy" + ) all_books = await Book.objects.all() assert len(all_books) == 5 - await Book.objects.filter(author="Tolstoy, Leo").update(author="Lenin, Vladimir") + await Book.objects.filter(author="Tolstoy, Leo").update( + author="Lenin, Vladimir" + ) all_books = await Book.objects.filter(author="Lenin, Vladimir").all() assert len(all_books) == 2 - historic_books = await Book.objects.filter(genre='Historic').all() + historic_books = await Book.objects.filter(genre="Historic").all() assert len(historic_books) == 0 with pytest.raises(QueryDefinitionError): - await Book.objects.update(genre='Historic') + await Book.objects.update(genre="Historic") - await Book.objects.filter(author="Lenin, Vladimir").update(genre='Historic') + await Book.objects.filter(author="Lenin, Vladimir").update(genre="Historic") - historic_books = await Book.objects.filter(genre='Historic').all() + historic_books = await Book.objects.filter(genre="Historic").all() assert len(historic_books) == 2 - await Book.objects.delete(genre='Fantasy') + await Book.objects.delete(genre="Fantasy") all_books = await Book.objects.all() assert len(all_books) == 3 - await Book.objects.update(each=True, genre='Fiction') - all_books = await Book.objects.filter(genre='Fiction').all() + await Book.objects.update(each=True, genre="Fiction") + all_books = await Book.objects.filter(genre="Fiction").all() assert len(all_books) == 3 with pytest.raises(QueryDefinitionError): @@ -78,29 +126,184 @@ async def test_delete_and_update(): @pytest.mark.asyncio async def test_get_or_create(): async with database: - tom = await Book.objects.get_or_create(title="Volume I", author='Anonymous', genre='Fiction') + tom = await Book.objects.get_or_create( + title="Volume I", author="Anonymous", genre="Fiction" + ) assert await Book.objects.count() == 1 - assert await Book.objects.get_or_create(title="Volume I", author='Anonymous', genre='Fiction') == tom + assert ( + await Book.objects.get_or_create( + title="Volume I", author="Anonymous", genre="Fiction" + ) + == tom + ) assert await Book.objects.count() == 1 - assert await Book.objects.create(title="Volume I", author='Anonymous', genre='Fiction') + assert await Book.objects.create( + title="Volume I", author="Anonymous", genre="Fiction" + ) with pytest.raises(ormar.exceptions.MultipleMatches): - await Book.objects.get_or_create(title="Volume I", author='Anonymous', genre='Fiction') + await Book.objects.get_or_create( + title="Volume I", author="Anonymous", genre="Fiction" + ) @pytest.mark.asyncio async def test_update_or_create(): async with database: - tom = await Book.objects.update_or_create(title="Volume I", author='Anonymous', genre='Fiction') + tom = await Book.objects.update_or_create( + title="Volume I", author="Anonymous", genre="Fiction" + ) assert await Book.objects.count() == 1 - assert await Book.objects.update_or_create(id=tom.id, genre='Historic') + assert await Book.objects.update_or_create(id=tom.id, genre="Historic") assert await Book.objects.count() == 1 - assert await Book.objects.update_or_create(pk=tom.id, genre='Fantasy') + assert await Book.objects.update_or_create(pk=tom.id, genre="Fantasy") assert await Book.objects.count() == 1 - assert await Book.objects.create(title="Volume I", author='Anonymous', genre='Fantasy') + assert await Book.objects.create( + title="Volume I", author="Anonymous", genre="Fantasy" + ) with pytest.raises(ormar.exceptions.MultipleMatches): - await Book.objects.get(title="Volume I", author='Anonymous', genre='Fantasy') + await Book.objects.get( + title="Volume I", author="Anonymous", genre="Fantasy" + ) + + +@pytest.mark.asyncio +async def test_bulk_create(): + async with database: + await ToDo.objects.bulk_create( + [ + ToDo(text="Buy the groceries."), + ToDo(text="Call Mum.", completed=True), + ToDo(text="Send invoices.", completed=True), + ] + ) + + todoes = await ToDo.objects.all() + assert len(todoes) == 3 + for todo in todoes: + assert todo.pk is not None + + completed = await ToDo.objects.filter(completed=True).all() + assert len(completed) == 2 + + +@pytest.mark.asyncio +async def test_bulk_create_with_relation(): + async with database: + category = await Category.objects.create(name="Sample Category") + + await Note.objects.bulk_create( + [ + Note(text="Buy the groceries.", category=category), + Note(text="Call Mum.", category=category), + ] + ) + + todoes = await Note.objects.all() + assert len(todoes) == 2 + for todo in todoes: + assert todo.category.pk == category.pk + + +@pytest.mark.asyncio +async def test_bulk_update(): + async with database: + await ToDo.objects.bulk_create( + [ + ToDo(text="Buy the groceries."), + ToDo(text="Call Mum.", completed=True), + ToDo(text="Send invoices.", completed=True), + ] + ) + todoes = await ToDo.objects.all() + assert len(todoes) == 3 + + for todo in todoes: + todo.text = todo.text + "_1" + todo.completed = False + + await ToDo.objects.bulk_update(todoes) + + completed = await ToDo.objects.filter(completed=False).all() + assert len(completed) == 3 + + todoes = await ToDo.objects.all() + assert len(todoes) == 3 + + for todo in todoes: + assert todo.text[-2:] == "_1" + + +@pytest.mark.asyncio +async def test_bulk_update_with_only_selected_columns(): + async with database: + await ToDo.objects.bulk_create( + [ + ToDo(text="Reset the world simulation.", completed=False), + ToDo(text="Watch kittens.", completed=True), + ] + ) + + todoes = await ToDo.objects.all() + assert len(todoes) == 2 + + for todo in todoes: + todo.text = todo.text + "_1" + todo.completed = False + + await ToDo.objects.bulk_update(todoes, columns=["completed"]) + + completed = await ToDo.objects.filter(completed=False).all() + assert len(completed) == 2 + + todoes = await ToDo.objects.all() + assert len(todoes) == 2 + + for todo in todoes: + assert todo.text[-2:] != "_1" + + +@pytest.mark.asyncio +async def test_bulk_update_with_relation(): + async with database: + category = await Category.objects.create(name="Sample Category") + category2 = await Category.objects.create(name="Sample II Category") + + await Note.objects.bulk_create( + [ + Note(text="Buy the groceries.", category=category), + Note(text="Call Mum.", category=category), + Note(text="Text skynet.", category=category), + ] + ) + + notes = await Note.objects.all() + assert len(notes) == 3 + + for note in notes: + note.category = category2 + + await Note.objects.bulk_update(notes) + + notes_upd = await Note.objects.all() + assert len(notes_upd) == 3 + + for note in notes_upd: + assert note.category.pk == category2.pk + + +@pytest.mark.asyncio +async def test_bulk_update_not_saved_objts(): + async with database: + category = await Category.objects.create(name="Sample Category") + with pytest.raises(QueryDefinitionError): + await Note.objects.bulk_update( + [ + Note(text="Buy the groceries.", category=category), + Note(text="Call Mum.", category=category), + ] + )