Merge pull request #13 from collerek/bulk_operations

Bulk operations - add bulk_create and bulk_update
This commit is contained in:
collerek
2020-09-28 19:47:32 +07:00
committed by GitHub
7 changed files with 329 additions and 33 deletions

BIN
.coverage

Binary file not shown.

View File

@ -328,6 +328,52 @@ assert await Book.objects.count() == 1
``` ```
Since version >=0.3.5 Ormar supports also bulk operations -> bulk_create and bulk_update
```python
import databases
import ormar
import sqlalchemy
database = databases.Database("sqlite:///db.sqlite")
metadata = sqlalchemy.MetaData()
class ToDo(ormar.Model):
class Meta:
tablename = "todos"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
text: ormar.String(max_length=500)
completed: ormar.Boolean(default=False)
# create multiple instances at once with bulk_create
await ToDo.objects.bulk_create(
[
ToDo(text="Buy the groceries."),
ToDo(text="Call Mum.", completed=True),
ToDo(text="Send invoices.", completed=True),
]
)
todoes = await ToDo.objects.all()
assert len(todoes) == 3
# update objects
for todo in todoes:
todo.completed = False
# perform update of all objects at once
# objects need to have pk column set, otherwise exception is raised
await ToDo.objects.bulk_update(todoes)
completed = await ToDo.objects.filter(completed=False).all()
assert len(completed) == 3
```
## Data types ## Data types
The following keyword arguments are supported on all field types. The following keyword arguments are supported on all field types.

View File

@ -26,7 +26,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.3.4" __version__ = "0.3.5"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -21,13 +21,13 @@ class ModelTableProxy:
raise NotImplementedError # pragma no cover raise NotImplementedError # pragma no cover
def _extract_own_model_fields(self) -> dict: def _extract_own_model_fields(self) -> dict:
related_names = self._extract_related_names() related_names = self.extract_related_names()
self_fields = {k: v for k, v in self.dict().items() if k not in related_names} self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields return self_fields
@classmethod @classmethod
def extract_db_own_fields(cls) -> set: def extract_db_own_fields(cls) -> Set:
related_names = cls._extract_related_names() related_names = cls.extract_related_names()
self_fields = { self_fields = {
name for name in cls.Meta.model_fields.keys() if name not in related_names name for name in cls.Meta.model_fields.keys() if name not in related_names
} }
@ -35,7 +35,7 @@ class ModelTableProxy:
@classmethod @classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict: def substitute_models_with_pks(cls, model_dict: dict) -> dict:
for field in cls._extract_related_names(): for field in cls.extract_related_names():
field_value = model_dict.get(field, None) field_value = model_dict.get(field, None)
if field_value is not None: if field_value is not None:
target_field = cls.Meta.model_fields[field] target_field = cls.Meta.model_fields[field]
@ -47,7 +47,7 @@ class ModelTableProxy:
return model_dict return model_dict
@classmethod @classmethod
def _extract_related_names(cls) -> Set: def extract_related_names(cls) -> Set:
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField): if inspect.isclass(field) and issubclass(field, ForeignKeyField):
@ -69,7 +69,7 @@ class ModelTableProxy:
@classmethod @classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested: if nested:
return cls._extract_related_names() return cls.extract_related_names()
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if ( if (

View File

@ -92,7 +92,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "__fields_set__", fields_set) object.__setattr__(self, "__fields_set__", fields_set)
# register the related models after initialization # register the related models after initialization
for related in self._extract_related_names(): for related in self.extract_related_names():
self.Meta.model_fields[related].expand_relationship( self.Meta.model_fields[related].expand_relationship(
kwargs.get(related), self, to_register=True kwargs.get(related), self, to_register=True
) )
@ -119,7 +119,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"): if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
return object.__getattribute__(self, item) return object.__getattribute__(self, item)
if item != "_extract_related_names" and item in self._extract_related_names(): if item != "extract_related_names" and item in self.extract_related_names():
return self._extract_related_model_instead_of_field(item) return self._extract_related_model_instead_of_field(item)
if item == "pk": if item == "pk":
return self.__dict__.get(self.Meta.pkname, None) return self.__dict__.get(self.Meta.pkname, None)
@ -186,7 +186,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
exclude_defaults=exclude_defaults, exclude_defaults=exclude_defaults,
exclude_none=exclude_none, exclude_none=exclude_none,
) )
for field in self._extract_related_names(): for field in self.extract_related_names():
nested_model = getattr(self, field) nested_model = getattr(self, field)
if self.Meta.model_fields[field].virtual and nested: if self.Meta.model_fields[field].virtual and nested:

View File

@ -2,6 +2,7 @@ from typing import Any, List, Mapping, TYPE_CHECKING, Tuple, Type, Union
import databases import databases
import sqlalchemy import sqlalchemy
from sqlalchemy import bindparam
import ormar # noqa I100 import ormar # noqa I100
from ormar import MultipleMatches, NoMatch from ormar import MultipleMatches, NoMatch
@ -247,3 +248,49 @@ class QuerySet:
if pk and isinstance(pk, self.model_cls.pk_type()): if pk and isinstance(pk, self.model_cls.pk_type()):
setattr(instance, self.model_cls.Meta.pkname, pk) setattr(instance, self.model_cls.Meta.pkname, pk)
return instance return instance
async def bulk_create(self, objects: List["Model"]) -> None:
ready_objects = []
for objt in objects:
new_kwargs = objt.dict()
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
ready_objects.append(new_kwargs)
expr = self.table.insert()
await self.database.execute_many(expr, ready_objects)
async def bulk_update(
self, objects: List["Model"], columns: List[str] = None
) -> None:
ready_objects = []
pk_name = self.model_cls.Meta.pkname
if not columns:
columns = self.model_cls.extract_db_own_fields().union(
self.model_cls.extract_related_names()
)
if pk_name not in columns:
columns.append(pk_name)
for objt in objects:
new_kwargs = objt.dict()
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
raise QueryDefinitionError(
"You cannot update unsaved objects. "
f"{self.model_cls.__name__} has to have {pk_name} filled."
)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
ready_objects.append(new_kwargs)
pk_column = self.model_cls.Meta.table.c.get(pk_name)
expr = self.table.update().where(pk_column == bindparam("new_" + pk_name))
expr = expr.values(
**{k: bindparam("new_" + k) for k in columns if k != pk_name}
)
# databases bind params only where query is passed as string
# otherwise it just pases all data to values and results in unconsumed columns
expr = str(expr)
await self.database.execute_many(expr, ready_objects)

View File

@ -19,7 +19,43 @@ class Book(ormar.Model):
id: ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
title: ormar.String(max_length=200) title: ormar.String(max_length=200)
author: ormar.String(max_length=100) author: ormar.String(max_length=100)
genre: ormar.String(max_length=100, default='Fiction', choices=['Fiction', 'Adventure', 'Historic', 'Fantasy']) genre: ormar.String(
max_length=100,
default="Fiction",
choices=["Fiction", "Adventure", "Historic", "Fantasy"],
)
class ToDo(ormar.Model):
class Meta:
tablename = "todos"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
text: ormar.String(max_length=500)
completed: ormar.Boolean(default=False)
class Category(ormar.Model):
class Meta:
tablename = "categories"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=500)
class Note(ormar.Model):
class Meta:
tablename = "notes"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
text: ormar.String(max_length=500)
category: ormar.ForeignKey(Category)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -35,36 +71,48 @@ def create_test_database():
async def test_delete_and_update(): async def test_delete_and_update():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
await Book.objects.create(title='Tom Sawyer', author="Twain, Mark", genre='Adventure') await Book.objects.create(
await Book.objects.create(title='War and Peace', author="Tolstoy, Leo", genre='Fiction') title="Tom Sawyer", author="Twain, Mark", genre="Adventure"
await Book.objects.create(title='Anna Karenina', author="Tolstoy, Leo", genre='Fiction') )
await Book.objects.create(title='Harry Potter', author="Rowling, J.K.", genre='Fantasy') await Book.objects.create(
await Book.objects.create(title='Lord of the Rings', author="Tolkien, J.R.", genre='Fantasy') title="War and Peace", author="Tolstoy, Leo", genre="Fiction"
)
await Book.objects.create(
title="Anna Karenina", author="Tolstoy, Leo", genre="Fiction"
)
await Book.objects.create(
title="Harry Potter", author="Rowling, J.K.", genre="Fantasy"
)
await Book.objects.create(
title="Lord of the Rings", author="Tolkien, J.R.", genre="Fantasy"
)
all_books = await Book.objects.all() all_books = await Book.objects.all()
assert len(all_books) == 5 assert len(all_books) == 5
await Book.objects.filter(author="Tolstoy, Leo").update(author="Lenin, Vladimir") await Book.objects.filter(author="Tolstoy, Leo").update(
author="Lenin, Vladimir"
)
all_books = await Book.objects.filter(author="Lenin, Vladimir").all() all_books = await Book.objects.filter(author="Lenin, Vladimir").all()
assert len(all_books) == 2 assert len(all_books) == 2
historic_books = await Book.objects.filter(genre='Historic').all() historic_books = await Book.objects.filter(genre="Historic").all()
assert len(historic_books) == 0 assert len(historic_books) == 0
with pytest.raises(QueryDefinitionError): with pytest.raises(QueryDefinitionError):
await Book.objects.update(genre='Historic') await Book.objects.update(genre="Historic")
await Book.objects.filter(author="Lenin, Vladimir").update(genre='Historic') await Book.objects.filter(author="Lenin, Vladimir").update(genre="Historic")
historic_books = await Book.objects.filter(genre='Historic').all() historic_books = await Book.objects.filter(genre="Historic").all()
assert len(historic_books) == 2 assert len(historic_books) == 2
await Book.objects.delete(genre='Fantasy') await Book.objects.delete(genre="Fantasy")
all_books = await Book.objects.all() all_books = await Book.objects.all()
assert len(all_books) == 3 assert len(all_books) == 3
await Book.objects.update(each=True, genre='Fiction') await Book.objects.update(each=True, genre="Fiction")
all_books = await Book.objects.filter(genre='Fiction').all() all_books = await Book.objects.filter(genre="Fiction").all()
assert len(all_books) == 3 assert len(all_books) == 3
with pytest.raises(QueryDefinitionError): with pytest.raises(QueryDefinitionError):
@ -78,29 +126,184 @@ async def test_delete_and_update():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_or_create(): async def test_get_or_create():
async with database: async with database:
tom = await Book.objects.get_or_create(title="Volume I", author='Anonymous', genre='Fiction') tom = await Book.objects.get_or_create(
title="Volume I", author="Anonymous", genre="Fiction"
)
assert await Book.objects.count() == 1 assert await Book.objects.count() == 1
assert await Book.objects.get_or_create(title="Volume I", author='Anonymous', genre='Fiction') == tom assert (
await Book.objects.get_or_create(
title="Volume I", author="Anonymous", genre="Fiction"
)
== tom
)
assert await Book.objects.count() == 1 assert await Book.objects.count() == 1
assert await Book.objects.create(title="Volume I", author='Anonymous', genre='Fiction') assert await Book.objects.create(
title="Volume I", author="Anonymous", genre="Fiction"
)
with pytest.raises(ormar.exceptions.MultipleMatches): with pytest.raises(ormar.exceptions.MultipleMatches):
await Book.objects.get_or_create(title="Volume I", author='Anonymous', genre='Fiction') await Book.objects.get_or_create(
title="Volume I", author="Anonymous", genre="Fiction"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_or_create(): async def test_update_or_create():
async with database: async with database:
tom = await Book.objects.update_or_create(title="Volume I", author='Anonymous', genre='Fiction') tom = await Book.objects.update_or_create(
title="Volume I", author="Anonymous", genre="Fiction"
)
assert await Book.objects.count() == 1 assert await Book.objects.count() == 1
assert await Book.objects.update_or_create(id=tom.id, genre='Historic') assert await Book.objects.update_or_create(id=tom.id, genre="Historic")
assert await Book.objects.count() == 1 assert await Book.objects.count() == 1
assert await Book.objects.update_or_create(pk=tom.id, genre='Fantasy') assert await Book.objects.update_or_create(pk=tom.id, genre="Fantasy")
assert await Book.objects.count() == 1 assert await Book.objects.count() == 1
assert await Book.objects.create(title="Volume I", author='Anonymous', genre='Fantasy') assert await Book.objects.create(
title="Volume I", author="Anonymous", genre="Fantasy"
)
with pytest.raises(ormar.exceptions.MultipleMatches): with pytest.raises(ormar.exceptions.MultipleMatches):
await Book.objects.get(title="Volume I", author='Anonymous', genre='Fantasy') await Book.objects.get(
title="Volume I", author="Anonymous", genre="Fantasy"
)
@pytest.mark.asyncio
async def test_bulk_create():
async with database:
await ToDo.objects.bulk_create(
[
ToDo(text="Buy the groceries."),
ToDo(text="Call Mum.", completed=True),
ToDo(text="Send invoices.", completed=True),
]
)
todoes = await ToDo.objects.all()
assert len(todoes) == 3
for todo in todoes:
assert todo.pk is not None
completed = await ToDo.objects.filter(completed=True).all()
assert len(completed) == 2
@pytest.mark.asyncio
async def test_bulk_create_with_relation():
async with database:
category = await Category.objects.create(name="Sample Category")
await Note.objects.bulk_create(
[
Note(text="Buy the groceries.", category=category),
Note(text="Call Mum.", category=category),
]
)
todoes = await Note.objects.all()
assert len(todoes) == 2
for todo in todoes:
assert todo.category.pk == category.pk
@pytest.mark.asyncio
async def test_bulk_update():
async with database:
await ToDo.objects.bulk_create(
[
ToDo(text="Buy the groceries."),
ToDo(text="Call Mum.", completed=True),
ToDo(text="Send invoices.", completed=True),
]
)
todoes = await ToDo.objects.all()
assert len(todoes) == 3
for todo in todoes:
todo.text = todo.text + "_1"
todo.completed = False
await ToDo.objects.bulk_update(todoes)
completed = await ToDo.objects.filter(completed=False).all()
assert len(completed) == 3
todoes = await ToDo.objects.all()
assert len(todoes) == 3
for todo in todoes:
assert todo.text[-2:] == "_1"
@pytest.mark.asyncio
async def test_bulk_update_with_only_selected_columns():
async with database:
await ToDo.objects.bulk_create(
[
ToDo(text="Reset the world simulation.", completed=False),
ToDo(text="Watch kittens.", completed=True),
]
)
todoes = await ToDo.objects.all()
assert len(todoes) == 2
for todo in todoes:
todo.text = todo.text + "_1"
todo.completed = False
await ToDo.objects.bulk_update(todoes, columns=["completed"])
completed = await ToDo.objects.filter(completed=False).all()
assert len(completed) == 2
todoes = await ToDo.objects.all()
assert len(todoes) == 2
for todo in todoes:
assert todo.text[-2:] != "_1"
@pytest.mark.asyncio
async def test_bulk_update_with_relation():
async with database:
category = await Category.objects.create(name="Sample Category")
category2 = await Category.objects.create(name="Sample II Category")
await Note.objects.bulk_create(
[
Note(text="Buy the groceries.", category=category),
Note(text="Call Mum.", category=category),
Note(text="Text skynet.", category=category),
]
)
notes = await Note.objects.all()
assert len(notes) == 3
for note in notes:
note.category = category2
await Note.objects.bulk_update(notes)
notes_upd = await Note.objects.all()
assert len(notes_upd) == 3
for note in notes_upd:
assert note.category.pk == category2.pk
@pytest.mark.asyncio
async def test_bulk_update_not_saved_objts():
async with database:
category = await Category.objects.create(name="Sample Category")
with pytest.raises(QueryDefinitionError):
await Note.objects.bulk_update(
[
Note(text="Buy the groceries.", category=category),
Note(text="Call Mum.", category=category),
]
)