From cd33f6a96baf15bccaccb31a9285a02f61ef8bd6 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 14 Nov 2020 14:29:54 +0100 Subject: [PATCH] introduce save_related method that traverses the related objects and upserts them if they are not saved --- ormar/models/model.py | 17 +++++ tests/test_save_related.py | 151 +++++++++++++++++++++++++++++++++++++ tests/test_save_status.py | 4 +- 3 files changed, 169 insertions(+), 3 deletions(-) create mode 100644 tests/test_save_related.py diff --git a/ormar/models/model.py b/ormar/models/model.py index ac3f7ce..437ed94 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -192,6 +192,23 @@ class Model(NewBaseModel): self.set_save_status(True) return self + async def save_related(self) -> int: + update_count = 0 + for related in self.extract_related_names(): + if self.Meta.model_fields[related].virtual or issubclass( + self.Meta.model_fields[related], ManyToManyField + ): + for rel in getattr(self, related): + if not rel.saved: + await rel.upsert() + update_count += 1 + else: + rel = getattr(self, related) + if not rel.saved: + await rel.upsert() + update_count += 1 + return update_count + async def update(self: T, **kwargs: Any) -> T: if kwargs: new_values = {**self.dict(), **kwargs} diff --git a/tests/test_save_related.py b/tests/test_save_related.py new file mode 100644 index 0000000..1defabf --- /dev/null +++ b/tests/test_save_related.py @@ -0,0 +1,151 @@ +from typing import List + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class NickNames(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq) + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="company_name") + founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ, related_name="companies") + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_saving_related_fk_rel(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name="Main") + comp = await Company.objects.create(name="Banzai", founded=1988, hq=hq) + assert comp.saved + + count = await comp.save_related() + assert count == 0 + + comp.hq.name = "Suburbs" + assert not comp.hq.saved + assert comp.saved + + count = await comp.save_related() + assert count == 1 + assert comp.hq.saved + + +@pytest.mark.asyncio +async def test_adding_many_to_many_does_not_gets_dirty(): + async with database: + async with database.transaction(force_rollback=True): + nick1 = await NickNames.objects.create(name="BazingaO", is_lame=False) + nick2 = await NickNames.objects.create(name="Bazinga20", is_lame=True) + + hq = await HQ.objects.create(name="Main") + assert hq.saved + + await hq.nicks.add(nick1) + assert hq.saved + await hq.nicks.add(nick2) + assert hq.saved + + count = await hq.save_related() + assert count == 0 + + hq.nicks[0].name = "Kabucha" + hq.nicks[1].name = "Kabucha2" + assert not hq.nicks[0].saved + assert not hq.nicks[1].saved + + count = await hq.save_related() + assert count == 2 + assert hq.nicks[0].saved + assert hq.nicks[1].saved + + +@pytest.mark.asyncio +async def test_queryset_methods(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name="Main") + await Company.objects.create(name="Banzai", founded=1988, hq=hq) + + hq = await HQ.objects.select_related("companies").get(name="Main") + assert hq.saved + assert hq.companies[0].saved + + hq.companies[0].name = "Konichiwa" + assert not hq.companies[0].saved + count = await hq.save_related() + assert count == 1 + assert hq.companies[0].saved + + await Company.objects.create(name="Joshua", founded=1888, hq=hq) + + hq = await HQ.objects.select_related("companies").get(name="Main") + assert hq.saved + assert hq.companies[0].saved + assert hq.companies[1].saved + + hq.companies[0].name = hq.companies[0].name + "20" + assert not hq.companies[0].saved + # save only if not saved so now only one + count = await hq.save_related() + assert count == 1 + assert hq.companies[0].saved + + hq.companies[0].name = hq.companies[0].name + "20" + hq.companies[1].name = hq.companies[1].name + "30" + assert not hq.companies[0].saved + assert not hq.companies[1].saved + count = await hq.save_related() + assert count == 2 + assert hq.companies[0].saved + assert hq.companies[1].saved diff --git a/tests/test_save_status.py b/tests/test_save_status.py index ec15ccf..93e89ac 100644 --- a/tests/test_save_status.py +++ b/tests/test_save_status.py @@ -1,8 +1,6 @@ -import itertools -from typing import Optional, List +from typing import List import databases -import pydantic import pytest import sqlalchemy