From e805ff16b215b46a752cd2179d4517667ecff5a0 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 14 Nov 2020 13:53:32 +0100 Subject: [PATCH] introduce upsert method on model, add tests to see if save status properly changing on nested models --- ormar/exceptions.py | 4 ++ ormar/models/model.py | 11 +++++ tests/test_save_status.py | 95 ++++++++++++++++++++++++++------------- 3 files changed, 80 insertions(+), 30 deletions(-) diff --git a/ormar/exceptions.py b/ormar/exceptions.py index 40cfd26..0800a81 100644 --- a/ormar/exceptions.py +++ b/ormar/exceptions.py @@ -24,3 +24,7 @@ class QueryDefinitionError(AsyncOrmException): class RelationshipInstanceError(AsyncOrmException): pass + + +class ModelPersistenceError(AsyncOrmException): + pass diff --git a/ormar/models/model.py b/ormar/models/model.py index 5141604..ac3f7ce 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, import sqlalchemy import ormar.queryset # noqa I100 +from ormar.exceptions import ModelPersistenceError from ormar.fields.many_to_many import ManyToManyField from ormar.models import NewBaseModel # noqa I100 from ormar.models.metaclass import ModelMeta @@ -169,6 +170,11 @@ class Model(NewBaseModel): return item + async def upsert(self: T, **kwargs: Any) -> T: + if not self.pk: + return await self.save() + return await self.update(**kwargs) + async def save(self: T) -> T: self_fields = self._extract_model_db_fields() @@ -191,6 +197,11 @@ class Model(NewBaseModel): new_values = {**self.dict(), **kwargs} self.from_dict(new_values) + if not self.pk: + raise ModelPersistenceError( + "You cannot update not saved model! Use save or upsert method." + ) + self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) self_fields = self.translate_columns_to_aliases(self_fields) diff --git a/tests/test_save_status.py b/tests/test_save_status.py index baea547..c94a171 100644 --- a/tests/test_save_status.py +++ b/tests/test_save_status.py @@ -7,6 +7,7 @@ import pytest import sqlalchemy import ormar +from ormar.exceptions import ModelPersistenceError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -67,7 +68,7 @@ def create_test_database(): async def test_instantation_false_save_true(): async with database: async with database.transaction(force_rollback=True): - comp = Company(name='Banzai', founded=1988) + comp = Company(name="Banzai", founded=1988) assert not comp._orm_saved await comp.save() assert comp._orm_saved @@ -77,21 +78,21 @@ async def test_instantation_false_save_true(): async def test_saved_edited_not_saved(): async with database: async with database.transaction(force_rollback=True): - comp = await Company.objects.create(name='Banzai', founded=1988) + comp = await Company.objects.create(name="Banzai", founded=1988) assert comp._orm_saved - comp.name = 'Banzai2' + comp.name = "Banzai2" assert not comp._orm_saved await comp.update() assert comp._orm_saved - await comp.update(name='Banzai3') + await comp.update(name="Banzai3") assert comp._orm_saved comp.pk = 999 assert not comp._orm_saved - await comp.save() + await comp.update() assert comp._orm_saved @@ -99,8 +100,8 @@ async def test_saved_edited_not_saved(): async def test_adding_related_gets_dirty(): async with database: async with database.transaction(force_rollback=True): - hq = await HQ.objects.create(name='Main') - comp = await Company.objects.create(name='Banzai', founded=1988) + hq = await HQ.objects.create(name="Main") + comp = await Company.objects.create(name="Banzai", founded=1988) assert comp._orm_saved comp.hq = hq @@ -108,20 +109,28 @@ async def test_adding_related_gets_dirty(): await comp.update() assert comp._orm_saved - comp = await Company.objects.select_related('hq').get(name='Banzai') + comp = await Company.objects.select_related("hq").get(name="Banzai") assert comp._orm_saved + assert comp.hq.pk == hq.pk assert comp.hq._orm_saved + comp.hq.name = "Suburbs" + assert not comp.hq._orm_saved + assert comp._orm_saved + + await comp.hq.update() + assert comp.hq._orm_saved + @pytest.mark.asyncio async def test_adding_many_to_many_does_not_gets_dirty(): async with database: async with database.transaction(force_rollback=True): - nick1 = await NickNames.objects.create(name='Bazinga', is_lame=False) - nick2 = await NickNames.objects.create(name='Bazinga2', is_lame=True) + nick1 = await NickNames.objects.create(name="Bazinga", is_lame=False) + nick2 = await NickNames.objects.create(name="Bazinga2", is_lame=True) - hq = await HQ.objects.create(name='Main') + hq = await HQ.objects.create(name="Main") assert hq._orm_saved await hq.nicks.add(nick1) @@ -129,24 +138,30 @@ async def test_adding_many_to_many_does_not_gets_dirty(): await hq.nicks.add(nick2) assert hq._orm_saved - hq = await HQ.objects.select_related('nicks').get(name='Main') + hq = await HQ.objects.select_related("nicks").get(name="Main") assert hq._orm_saved assert hq.nicks[0]._orm_saved await hq.nicks.remove(nick1) assert hq._orm_saved + hq.nicks[0].name = "Kabucha" + assert not hq.nicks[0]._orm_saved + + await hq.nicks[0].update() + assert hq.nicks[0]._orm_saved + @pytest.mark.asyncio async def test_delete(): async with database: async with database.transaction(force_rollback=True): - comp = await Company.objects.create(name='Banzai', founded=1988) + comp = await Company.objects.create(name="Banzai", founded=1988) assert comp._orm_saved await comp.delete() assert not comp._orm_saved - await comp.save() + await comp.update() assert comp._orm_saved @@ -154,26 +169,26 @@ async def test_delete(): async def test_load(): async with database: async with database.transaction(force_rollback=True): - comp = await Company.objects.create(name='Banzai', founded=1988) + comp = await Company.objects.create(name="Banzai", founded=1988) assert comp._orm_saved - comp.name = 'AA' + comp.name = "AA" assert not comp._orm_saved await comp.load() assert comp._orm_saved - assert comp.name == 'Banzai' + assert comp.name == "Banzai" @pytest.mark.asyncio async def test_queryset_methods(): async with database: async with database.transaction(force_rollback=True): - await Company.objects.create(name='Banzai', founded=1988) - await Company.objects.create(name='Yuhu', founded=1989) - await Company.objects.create(name='Konono', founded=1990) - await Company.objects.create(name='Sumaaa', founded=1991) + await Company.objects.create(name="Banzai", founded=1988) + await Company.objects.create(name="Yuhu", founded=1989) + await Company.objects.create(name="Konono", founded=1990) + await Company.objects.create(name="Sumaaa", founded=1991) - comp = await Company.objects.get(name='Banzai') + comp = await Company.objects.get(name="Banzai") assert comp._orm_saved comp = await Company.objects.first() @@ -182,20 +197,20 @@ async def test_queryset_methods(): comps = await Company.objects.all() assert [comp._orm_saved for comp in comps] - comp2 = await Company.objects.get_or_create(name='Banzai_new', founded=2001) + comp2 = await Company.objects.get_or_create(name="Banzai_new", founded=2001) assert comp2._orm_saved - comp3 = await Company.objects.get_or_create(name='Banzai', founded=1988) + comp3 = await Company.objects.get_or_create(name="Banzai", founded=1988) assert comp3._orm_saved assert comp3.pk == comp.pk update_dict = comp.dict() - update_dict['founded'] = 2010 + update_dict["founded"] = 2010 comp = await Company.objects.update_or_create(**update_dict) assert comp._orm_saved assert comp.founded == 2010 - create_dict = {'name': "Yoko", "founded": 2005} + create_dict = {"name": "Yoko", "founded": 2005} comp = await Company.objects.update_or_create(**create_dict) assert comp._orm_saved assert comp.founded == 2005 @@ -205,16 +220,16 @@ async def test_queryset_methods(): async def test_bulk_methods(): async with database: async with database.transaction(force_rollback=True): - c1 = Company(name='Banzai', founded=1988) - c2 = Company(name='Yuhu', founded=1989) + c1 = Company(name="Banzai", founded=1988) + c2 = Company(name="Yuhu", founded=1989) await Company.objects.bulk_create([c1, c2]) assert c1._orm_saved assert c2._orm_saved c1, c2 = await Company.objects.all() - c1.name = 'Banzai2' - c2.name = 'Yuhu2' + c1.name = "Banzai2" + c2.name = "Yuhu2" assert not c1._orm_saved assert not c2._orm_saved @@ -222,3 +237,23 @@ async def test_bulk_methods(): await Company.objects.bulk_update([c1, c2]) assert c1._orm_saved assert c2._orm_saved + + c3 = Company(name="Cobra", founded=2088) + assert not c3._orm_saved + + with pytest.raises(ModelPersistenceError): + await c3.update() + + await c3.upsert() + assert c3._orm_saved + + c3.name = "Python" + assert not c3._orm_saved + + await c3.upsert() + assert c3._orm_saved + assert c3.name == "Python" + + await c3.upsert(founded=2077) + assert c3._orm_saved + assert c3.founded == 2077