add tests
This commit is contained in:
@ -1037,7 +1037,9 @@ class QuerySet(Generic[T]):
|
|||||||
instance = await instance.save()
|
instance = await instance.save()
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
async def bulk_create(self, objects: List["T"]) -> None:
|
async def bulk_create(
|
||||||
|
self, objects: List["T"], send_signals: bool = False
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Performs a bulk create in one database session to speed up the process.
|
Performs a bulk create in one database session to speed up the process.
|
||||||
|
|
||||||
@ -1049,27 +1051,42 @@ class QuerySet(Generic[T]):
|
|||||||
|
|
||||||
:param objects: list of ormar models already initialized and ready to save.
|
:param objects: list of ormar models already initialized and ready to save.
|
||||||
:type objects: List[Model]
|
:type objects: List[Model]
|
||||||
|
:param send_signals: confirm send the pre/post create signals
|
||||||
|
:type send_signals: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def before_create(entity: "T") -> None:
|
||||||
|
await entity.signals.pre_save.send(
|
||||||
|
sender=entity.__class__, instance=entity)
|
||||||
|
|
||||||
|
async def after_create(entity: "T") -> None:
|
||||||
|
entity.set_save_status(True)
|
||||||
|
# FIXME do not have the `id` value, because not reload from db
|
||||||
|
if send_signals:
|
||||||
|
await entity.signals.post_save.send(
|
||||||
|
sender=entity.__class__, instance=entity)
|
||||||
|
|
||||||
|
if send_signals:
|
||||||
|
await asyncio.gather(
|
||||||
|
*[before_create(entity) for entity in objects]
|
||||||
|
)
|
||||||
|
|
||||||
ready_objects = [
|
ready_objects = [
|
||||||
obj.prepare_model_to_save(obj.dict())
|
obj.prepare_model_to_save(obj.dict())
|
||||||
for obj in objects
|
for obj in objects
|
||||||
]
|
]
|
||||||
expr = self.table.insert().values(ready_objects)
|
expr = self.table.insert().values(ready_objects)
|
||||||
|
|
||||||
# shouldn't use the execute_many, it's `queries.foreach(execute)`
|
# shouldn't use the execute_many, it's `queries.foreach(execute)`
|
||||||
await self.database.execute(expr)
|
await self.database.execute(expr)
|
||||||
|
|
||||||
# FIXME: add pre_save signals
|
|
||||||
async def after_create(entity: "T") -> None:
|
|
||||||
entity.set_save_status(True)
|
|
||||||
await entity.signals.post_save.send(
|
|
||||||
sender=entity.__class__, instance=entity)
|
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[after_create(entity) for entity in objects]
|
*[after_create(entity) for entity in objects]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def bulk_update( # noqa: CCR001
|
async def bulk_update( # noqa: CCR001
|
||||||
self, objects: List["T"], columns: List[str] = None
|
self, objects: List["T"], columns: List[str] = None,
|
||||||
|
send_signals: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Performs bulk update in one database session to speed up the process.
|
Performs bulk update in one database session to speed up the process.
|
||||||
@ -1087,6 +1104,8 @@ class QuerySet(Generic[T]):
|
|||||||
:type objects: List[Model]
|
:type objects: List[Model]
|
||||||
:param columns: list of columns to update
|
:param columns: list of columns to update
|
||||||
:type columns: List[str]
|
:type columns: List[str]
|
||||||
|
:param send_signals: confirm send the pre/post create signals
|
||||||
|
:type send_signals: bool
|
||||||
"""
|
"""
|
||||||
ready_objects = []
|
ready_objects = []
|
||||||
pk_name = self.model_meta.pkname
|
pk_name = self.model_meta.pkname
|
||||||
@ -1135,9 +1154,10 @@ class QuerySet(Generic[T]):
|
|||||||
# FIXME: add pre-update-signals
|
# FIXME: add pre-update-signals
|
||||||
async def after_update(entity: "T") -> None:
|
async def after_update(entity: "T") -> None:
|
||||||
entity.set_save_status(True)
|
entity.set_save_status(True)
|
||||||
await entity.signals.post_update.send(
|
if send_signals:
|
||||||
sender=entity.__class__, instance=entity
|
await entity.signals.post_update.send(
|
||||||
)
|
sender=entity.__class__, instance=entity
|
||||||
|
)
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[after_update(entity) for entity in objects]
|
*[after_update(entity) for entity in objects]
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import pytest
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
|
from ormar import post_save, post_update, pre_save
|
||||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
@ -193,6 +194,35 @@ async def test_bulk_create():
|
|||||||
assert len(completed) == 2
|
assert len(completed) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_create_send_signals():
|
||||||
|
async with database:
|
||||||
|
|
||||||
|
@pre_save(ToDo)
|
||||||
|
async def before_save(sender, instance, **kwargs):
|
||||||
|
instance.completed = False
|
||||||
|
|
||||||
|
@post_save(ToDo)
|
||||||
|
async def after_save(sender, instance, **kwargs):
|
||||||
|
assert not instance.completed
|
||||||
|
|
||||||
|
await ToDo.objects.bulk_create(
|
||||||
|
[
|
||||||
|
ToDo(text="Buy the groceries."),
|
||||||
|
ToDo(text="Call Mum.", completed=True),
|
||||||
|
ToDo(text="Send invoices.", completed=True),
|
||||||
|
], send_signals=True
|
||||||
|
)
|
||||||
|
|
||||||
|
todoes = await ToDo.objects.all()
|
||||||
|
assert len(todoes) == 3
|
||||||
|
for todo in todoes:
|
||||||
|
assert todo.pk is not None
|
||||||
|
|
||||||
|
count = await ToDo.objects.filter(completed=False).count()
|
||||||
|
assert count == 3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bulk_create_with_relation():
|
async def test_bulk_create_with_relation():
|
||||||
async with database:
|
async with database:
|
||||||
@ -240,6 +270,34 @@ async def test_bulk_update():
|
|||||||
assert todo.text[-2:] == "_1"
|
assert todo.text[-2:] == "_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_update_send_signals():
|
||||||
|
async with database:
|
||||||
|
|
||||||
|
@post_update(ToDo)
|
||||||
|
async def after_update(sender, instance, **kwargs):
|
||||||
|
await instance.delete()
|
||||||
|
|
||||||
|
await ToDo.objects.bulk_create(
|
||||||
|
[
|
||||||
|
ToDo(text="Buy the groceries."),
|
||||||
|
ToDo(text="Call Mum.", completed=True),
|
||||||
|
ToDo(text="Send invoices.", completed=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
todoes = await ToDo.objects.all()
|
||||||
|
assert len(todoes) == 3
|
||||||
|
|
||||||
|
for todo in todoes:
|
||||||
|
todo.text = todo.text + "_1"
|
||||||
|
todo.completed = False
|
||||||
|
|
||||||
|
await ToDo.objects.bulk_update(todoes, send_signals=True)
|
||||||
|
|
||||||
|
count = await ToDo.objects.filter(completed=False).count()
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bulk_update_with_only_selected_columns():
|
async def test_bulk_update_with_only_selected_columns():
|
||||||
async with database:
|
async with database:
|
||||||
|
|||||||
Reference in New Issue
Block a user