add bulk_create and bulk_update and tests
This commit is contained in:
@ -21,13 +21,13 @@ class ModelTableProxy:
|
|||||||
raise NotImplementedError # pragma no cover
|
raise NotImplementedError # pragma no cover
|
||||||
|
|
||||||
def _extract_own_model_fields(self) -> dict:
|
def _extract_own_model_fields(self) -> dict:
|
||||||
related_names = self._extract_related_names()
|
related_names = self.extract_related_names()
|
||||||
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
|
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
|
||||||
return self_fields
|
return self_fields
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_db_own_fields(cls) -> set:
|
def extract_db_own_fields(cls) -> Set:
|
||||||
related_names = cls._extract_related_names()
|
related_names = cls.extract_related_names()
|
||||||
self_fields = {
|
self_fields = {
|
||||||
name for name in cls.Meta.model_fields.keys() if name not in related_names
|
name for name in cls.Meta.model_fields.keys() if name not in related_names
|
||||||
}
|
}
|
||||||
@ -35,7 +35,7 @@ class ModelTableProxy:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
|
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
|
||||||
for field in cls._extract_related_names():
|
for field in cls.extract_related_names():
|
||||||
field_value = model_dict.get(field, None)
|
field_value = model_dict.get(field, None)
|
||||||
if field_value is not None:
|
if field_value is not None:
|
||||||
target_field = cls.Meta.model_fields[field]
|
target_field = cls.Meta.model_fields[field]
|
||||||
@ -47,7 +47,7 @@ class ModelTableProxy:
|
|||||||
return model_dict
|
return model_dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_related_names(cls) -> Set:
|
def extract_related_names(cls) -> Set:
|
||||||
related_names = set()
|
related_names = set()
|
||||||
for name, field in cls.Meta.model_fields.items():
|
for name, field in cls.Meta.model_fields.items():
|
||||||
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
|
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
|
||||||
@ -69,7 +69,7 @@ class ModelTableProxy:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
|
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
|
||||||
if nested:
|
if nested:
|
||||||
return cls._extract_related_names()
|
return cls.extract_related_names()
|
||||||
related_names = set()
|
related_names = set()
|
||||||
for name, field in cls.Meta.model_fields.items():
|
for name, field in cls.Meta.model_fields.items():
|
||||||
if (
|
if (
|
||||||
|
|||||||
@ -92,7 +92,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
object.__setattr__(self, "__fields_set__", fields_set)
|
object.__setattr__(self, "__fields_set__", fields_set)
|
||||||
|
|
||||||
# register the related models after initialization
|
# register the related models after initialization
|
||||||
for related in self._extract_related_names():
|
for related in self.extract_related_names():
|
||||||
self.Meta.model_fields[related].expand_relationship(
|
self.Meta.model_fields[related].expand_relationship(
|
||||||
kwargs.get(related), self, to_register=True
|
kwargs.get(related), self, to_register=True
|
||||||
)
|
)
|
||||||
@ -119,7 +119,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
def __getattribute__(self, item: str) -> Any:
|
def __getattribute__(self, item: str) -> Any:
|
||||||
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
|
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
|
||||||
return object.__getattribute__(self, item)
|
return object.__getattribute__(self, item)
|
||||||
if item != "_extract_related_names" and item in self._extract_related_names():
|
if item != "extract_related_names" and item in self.extract_related_names():
|
||||||
return self._extract_related_model_instead_of_field(item)
|
return self._extract_related_model_instead_of_field(item)
|
||||||
if item == "pk":
|
if item == "pk":
|
||||||
return self.__dict__.get(self.Meta.pkname, None)
|
return self.__dict__.get(self.Meta.pkname, None)
|
||||||
@ -186,7 +186,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
exclude_defaults=exclude_defaults,
|
exclude_defaults=exclude_defaults,
|
||||||
exclude_none=exclude_none,
|
exclude_none=exclude_none,
|
||||||
)
|
)
|
||||||
for field in self._extract_related_names():
|
for field in self.extract_related_names():
|
||||||
nested_model = getattr(self, field)
|
nested_model = getattr(self, field)
|
||||||
|
|
||||||
if self.Meta.model_fields[field].virtual and nested:
|
if self.Meta.model_fields[field].virtual and nested:
|
||||||
|
|||||||
@ -247,3 +247,52 @@ class QuerySet:
|
|||||||
if pk and isinstance(pk, self.model_cls.pk_type()):
|
if pk and isinstance(pk, self.model_cls.pk_type()):
|
||||||
setattr(instance, self.model_cls.Meta.pkname, pk)
|
setattr(instance, self.model_cls.Meta.pkname, pk)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
async def bulk_create(self, objects: List["Model"]) -> None:
|
||||||
|
ready_objects = []
|
||||||
|
for objt in objects:
|
||||||
|
new_kwargs = objt.dict()
|
||||||
|
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
|
||||||
|
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
|
||||||
|
new_kwargs = self._populate_default_values(new_kwargs)
|
||||||
|
ready_objects.append(new_kwargs)
|
||||||
|
|
||||||
|
expr = self.table.insert()
|
||||||
|
await self.database.execute_many(expr, ready_objects)
|
||||||
|
|
||||||
|
async def bulk_update(
|
||||||
|
self, objects: List["Model"], columns: List[str] = None
|
||||||
|
) -> None:
|
||||||
|
ready_expressions = []
|
||||||
|
pk_name = self.model_cls.Meta.pkname
|
||||||
|
if not columns:
|
||||||
|
columns = self.model_cls.extract_db_own_fields().union(
|
||||||
|
self.model_cls.extract_related_names()
|
||||||
|
)
|
||||||
|
|
||||||
|
if pk_name not in columns:
|
||||||
|
columns.append(pk_name)
|
||||||
|
|
||||||
|
for objt in objects:
|
||||||
|
new_kwargs = objt.dict()
|
||||||
|
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
|
||||||
|
raise QueryDefinitionError(
|
||||||
|
"You cannot update unsaved objects. "
|
||||||
|
f"{self.model_cls.__name__} has to have {pk_name} filled."
|
||||||
|
)
|
||||||
|
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
|
||||||
|
new_kwargs = self._populate_default_values(new_kwargs)
|
||||||
|
new_kwargs = {k: v for k, v in new_kwargs.items() if k in columns}
|
||||||
|
expr = self.table.update().values(
|
||||||
|
**{k: v for k, v in new_kwargs.items() if k != pk_name}
|
||||||
|
)
|
||||||
|
pk_column = self.model_cls.Meta.table.c.get(pk_name)
|
||||||
|
expr = expr.where(pk_column == new_kwargs.get(pk_name))
|
||||||
|
ready_expressions.append(expr)
|
||||||
|
|
||||||
|
# databases does not bind params for where clause and values separately
|
||||||
|
# no way to pass one dict with both uses
|
||||||
|
# so we need to resort to lower connection api
|
||||||
|
async with self.model_cls.Meta.database.connection() as connection:
|
||||||
|
for single_query in ready_expressions:
|
||||||
|
await connection.execute(single_query)
|
||||||
|
|||||||
@ -19,7 +19,43 @@ class Book(ormar.Model):
|
|||||||
id: ormar.Integer(primary_key=True)
|
id: ormar.Integer(primary_key=True)
|
||||||
title: ormar.String(max_length=200)
|
title: ormar.String(max_length=200)
|
||||||
author: ormar.String(max_length=100)
|
author: ormar.String(max_length=100)
|
||||||
genre: ormar.String(max_length=100, default='Fiction', choices=['Fiction', 'Adventure', 'Historic', 'Fantasy'])
|
genre: ormar.String(
|
||||||
|
max_length=100,
|
||||||
|
default="Fiction",
|
||||||
|
choices=["Fiction", "Adventure", "Historic", "Fantasy"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToDo(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "todos"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
text: ormar.String(max_length=500)
|
||||||
|
completed: ormar.Boolean(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Category(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "categories"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=500)
|
||||||
|
|
||||||
|
|
||||||
|
class Note(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "notes"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
text: ormar.String(max_length=500)
|
||||||
|
category: ormar.ForeignKey(Category)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True, scope="module")
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
@ -35,36 +71,48 @@ def create_test_database():
|
|||||||
async def test_delete_and_update():
|
async def test_delete_and_update():
|
||||||
async with database:
|
async with database:
|
||||||
async with database.transaction(force_rollback=True):
|
async with database.transaction(force_rollback=True):
|
||||||
await Book.objects.create(title='Tom Sawyer', author="Twain, Mark", genre='Adventure')
|
await Book.objects.create(
|
||||||
await Book.objects.create(title='War and Peace', author="Tolstoy, Leo", genre='Fiction')
|
title="Tom Sawyer", author="Twain, Mark", genre="Adventure"
|
||||||
await Book.objects.create(title='Anna Karenina', author="Tolstoy, Leo", genre='Fiction')
|
)
|
||||||
await Book.objects.create(title='Harry Potter', author="Rowling, J.K.", genre='Fantasy')
|
await Book.objects.create(
|
||||||
await Book.objects.create(title='Lord of the Rings', author="Tolkien, J.R.", genre='Fantasy')
|
title="War and Peace", author="Tolstoy, Leo", genre="Fiction"
|
||||||
|
)
|
||||||
|
await Book.objects.create(
|
||||||
|
title="Anna Karenina", author="Tolstoy, Leo", genre="Fiction"
|
||||||
|
)
|
||||||
|
await Book.objects.create(
|
||||||
|
title="Harry Potter", author="Rowling, J.K.", genre="Fantasy"
|
||||||
|
)
|
||||||
|
await Book.objects.create(
|
||||||
|
title="Lord of the Rings", author="Tolkien, J.R.", genre="Fantasy"
|
||||||
|
)
|
||||||
|
|
||||||
all_books = await Book.objects.all()
|
all_books = await Book.objects.all()
|
||||||
assert len(all_books) == 5
|
assert len(all_books) == 5
|
||||||
|
|
||||||
await Book.objects.filter(author="Tolstoy, Leo").update(author="Lenin, Vladimir")
|
await Book.objects.filter(author="Tolstoy, Leo").update(
|
||||||
|
author="Lenin, Vladimir"
|
||||||
|
)
|
||||||
all_books = await Book.objects.filter(author="Lenin, Vladimir").all()
|
all_books = await Book.objects.filter(author="Lenin, Vladimir").all()
|
||||||
assert len(all_books) == 2
|
assert len(all_books) == 2
|
||||||
|
|
||||||
historic_books = await Book.objects.filter(genre='Historic').all()
|
historic_books = await Book.objects.filter(genre="Historic").all()
|
||||||
assert len(historic_books) == 0
|
assert len(historic_books) == 0
|
||||||
|
|
||||||
with pytest.raises(QueryDefinitionError):
|
with pytest.raises(QueryDefinitionError):
|
||||||
await Book.objects.update(genre='Historic')
|
await Book.objects.update(genre="Historic")
|
||||||
|
|
||||||
await Book.objects.filter(author="Lenin, Vladimir").update(genre='Historic')
|
await Book.objects.filter(author="Lenin, Vladimir").update(genre="Historic")
|
||||||
|
|
||||||
historic_books = await Book.objects.filter(genre='Historic').all()
|
historic_books = await Book.objects.filter(genre="Historic").all()
|
||||||
assert len(historic_books) == 2
|
assert len(historic_books) == 2
|
||||||
|
|
||||||
await Book.objects.delete(genre='Fantasy')
|
await Book.objects.delete(genre="Fantasy")
|
||||||
all_books = await Book.objects.all()
|
all_books = await Book.objects.all()
|
||||||
assert len(all_books) == 3
|
assert len(all_books) == 3
|
||||||
|
|
||||||
await Book.objects.update(each=True, genre='Fiction')
|
await Book.objects.update(each=True, genre="Fiction")
|
||||||
all_books = await Book.objects.filter(genre='Fiction').all()
|
all_books = await Book.objects.filter(genre="Fiction").all()
|
||||||
assert len(all_books) == 3
|
assert len(all_books) == 3
|
||||||
|
|
||||||
with pytest.raises(QueryDefinitionError):
|
with pytest.raises(QueryDefinitionError):
|
||||||
@ -78,29 +126,184 @@ async def test_delete_and_update():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_or_create():
|
async def test_get_or_create():
|
||||||
async with database:
|
async with database:
|
||||||
tom = await Book.objects.get_or_create(title="Volume I", author='Anonymous', genre='Fiction')
|
tom = await Book.objects.get_or_create(
|
||||||
|
title="Volume I", author="Anonymous", genre="Fiction"
|
||||||
|
)
|
||||||
assert await Book.objects.count() == 1
|
assert await Book.objects.count() == 1
|
||||||
|
|
||||||
assert await Book.objects.get_or_create(title="Volume I", author='Anonymous', genre='Fiction') == tom
|
assert (
|
||||||
|
await Book.objects.get_or_create(
|
||||||
|
title="Volume I", author="Anonymous", genre="Fiction"
|
||||||
|
)
|
||||||
|
== tom
|
||||||
|
)
|
||||||
assert await Book.objects.count() == 1
|
assert await Book.objects.count() == 1
|
||||||
|
|
||||||
assert await Book.objects.create(title="Volume I", author='Anonymous', genre='Fiction')
|
assert await Book.objects.create(
|
||||||
|
title="Volume I", author="Anonymous", genre="Fiction"
|
||||||
|
)
|
||||||
with pytest.raises(ormar.exceptions.MultipleMatches):
|
with pytest.raises(ormar.exceptions.MultipleMatches):
|
||||||
await Book.objects.get_or_create(title="Volume I", author='Anonymous', genre='Fiction')
|
await Book.objects.get_or_create(
|
||||||
|
title="Volume I", author="Anonymous", genre="Fiction"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_or_create():
|
async def test_update_or_create():
|
||||||
async with database:
|
async with database:
|
||||||
tom = await Book.objects.update_or_create(title="Volume I", author='Anonymous', genre='Fiction')
|
tom = await Book.objects.update_or_create(
|
||||||
|
title="Volume I", author="Anonymous", genre="Fiction"
|
||||||
|
)
|
||||||
assert await Book.objects.count() == 1
|
assert await Book.objects.count() == 1
|
||||||
|
|
||||||
assert await Book.objects.update_or_create(id=tom.id, genre='Historic')
|
assert await Book.objects.update_or_create(id=tom.id, genre="Historic")
|
||||||
assert await Book.objects.count() == 1
|
assert await Book.objects.count() == 1
|
||||||
|
|
||||||
assert await Book.objects.update_or_create(pk=tom.id, genre='Fantasy')
|
assert await Book.objects.update_or_create(pk=tom.id, genre="Fantasy")
|
||||||
assert await Book.objects.count() == 1
|
assert await Book.objects.count() == 1
|
||||||
|
|
||||||
assert await Book.objects.create(title="Volume I", author='Anonymous', genre='Fantasy')
|
assert await Book.objects.create(
|
||||||
|
title="Volume I", author="Anonymous", genre="Fantasy"
|
||||||
|
)
|
||||||
with pytest.raises(ormar.exceptions.MultipleMatches):
|
with pytest.raises(ormar.exceptions.MultipleMatches):
|
||||||
await Book.objects.get(title="Volume I", author='Anonymous', genre='Fantasy')
|
await Book.objects.get(
|
||||||
|
title="Volume I", author="Anonymous", genre="Fantasy"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_create():
|
||||||
|
async with database:
|
||||||
|
await ToDo.objects.bulk_create(
|
||||||
|
[
|
||||||
|
ToDo(text="Buy the groceries."),
|
||||||
|
ToDo(text="Call Mum.", completed=True),
|
||||||
|
ToDo(text="Send invoices.", completed=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
todoes = await ToDo.objects.all()
|
||||||
|
assert len(todoes) == 3
|
||||||
|
for todo in todoes:
|
||||||
|
assert todo.pk is not None
|
||||||
|
|
||||||
|
completed = await ToDo.objects.filter(completed=True).all()
|
||||||
|
assert len(completed) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_create_with_relation():
|
||||||
|
async with database:
|
||||||
|
category = await Category.objects.create(name="Sample Category")
|
||||||
|
|
||||||
|
await Note.objects.bulk_create(
|
||||||
|
[
|
||||||
|
Note(text="Buy the groceries.", category=category),
|
||||||
|
Note(text="Call Mum.", category=category),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
todoes = await Note.objects.all()
|
||||||
|
assert len(todoes) == 2
|
||||||
|
for todo in todoes:
|
||||||
|
assert todo.category.pk == category.pk
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_update():
|
||||||
|
async with database:
|
||||||
|
await ToDo.objects.bulk_create(
|
||||||
|
[
|
||||||
|
ToDo(text="Buy the groceries."),
|
||||||
|
ToDo(text="Call Mum.", completed=True),
|
||||||
|
ToDo(text="Send invoices.", completed=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
todoes = await ToDo.objects.all()
|
||||||
|
assert len(todoes) == 3
|
||||||
|
|
||||||
|
for todo in todoes:
|
||||||
|
todo.text = todo.text + "_1"
|
||||||
|
todo.completed = False
|
||||||
|
|
||||||
|
await ToDo.objects.bulk_update(todoes)
|
||||||
|
|
||||||
|
completed = await ToDo.objects.filter(completed=False).all()
|
||||||
|
assert len(completed) == 3
|
||||||
|
|
||||||
|
todoes = await ToDo.objects.all()
|
||||||
|
assert len(todoes) == 3
|
||||||
|
|
||||||
|
for todo in todoes:
|
||||||
|
assert todo.text[-2:] == "_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_update_with_only_selected_columns():
|
||||||
|
async with database:
|
||||||
|
await ToDo.objects.bulk_create(
|
||||||
|
[
|
||||||
|
ToDo(text="Reset the world simulation.", completed=False),
|
||||||
|
ToDo(text="Watch kittens.", completed=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
todoes = await ToDo.objects.all()
|
||||||
|
assert len(todoes) == 2
|
||||||
|
|
||||||
|
for todo in todoes:
|
||||||
|
todo.text = todo.text + "_1"
|
||||||
|
todo.completed = False
|
||||||
|
|
||||||
|
await ToDo.objects.bulk_update(todoes, columns=["completed"])
|
||||||
|
|
||||||
|
completed = await ToDo.objects.filter(completed=False).all()
|
||||||
|
assert len(completed) == 2
|
||||||
|
|
||||||
|
todoes = await ToDo.objects.all()
|
||||||
|
assert len(todoes) == 2
|
||||||
|
|
||||||
|
for todo in todoes:
|
||||||
|
assert todo.text[-2:] != "_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_update_with_relation():
|
||||||
|
async with database:
|
||||||
|
category = await Category.objects.create(name="Sample Category")
|
||||||
|
category2 = await Category.objects.create(name="Sample II Category")
|
||||||
|
|
||||||
|
await Note.objects.bulk_create(
|
||||||
|
[
|
||||||
|
Note(text="Buy the groceries.", category=category),
|
||||||
|
Note(text="Call Mum.", category=category),
|
||||||
|
Note(text="Text skynet.", category=category),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
notes = await Note.objects.all()
|
||||||
|
assert len(notes) == 3
|
||||||
|
|
||||||
|
for note in notes:
|
||||||
|
note.category = category2
|
||||||
|
|
||||||
|
await Note.objects.bulk_update(notes)
|
||||||
|
|
||||||
|
notes_upd = await Note.objects.all()
|
||||||
|
assert len(notes_upd) == 3
|
||||||
|
|
||||||
|
for note in notes_upd:
|
||||||
|
assert note.category.pk == category2.pk
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_update_not_saved_objts():
|
||||||
|
async with database:
|
||||||
|
category = await Category.objects.create(name="Sample Category")
|
||||||
|
with pytest.raises(QueryDefinitionError):
|
||||||
|
await Note.objects.bulk_update(
|
||||||
|
[
|
||||||
|
Note(text="Buy the groceries.", category=category),
|
||||||
|
Note(text="Call Mum.", category=category),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user