From 2385f95a9ffc86f9c2dddff528cf0cc2c0447a7d Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 20 Nov 2020 11:03:57 +0100 Subject: [PATCH] fix saving string pk and add db refresh for server_defaults in save() also, bump version --- ormar/__init__.py | 2 +- ormar/models/model.py | 16 +++++-- tests/test_saving_string_pks.py | 81 +++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 tests/test_saving_string_pks.py diff --git a/ormar/__init__.py b/ormar/__init__.py index 7f9e9e2..87f1b5b 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.5.0" +__version__ = "0.5.1" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/models/model.py b/ormar/models/model.py index fb5c295..70d8e42 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -197,9 +197,19 @@ class Model(NewBaseModel): expr = self.Meta.table.insert() expr = expr.values(**self_fields) - item_id = await self.Meta.database.execute(expr) - if item_id: # postgress does not return id if it's already there - setattr(self, self.Meta.pkname, item_id) + pk = await self.Meta.database.execute(expr) + if pk and isinstance(pk, self.pk_type()): + setattr(self, self.Meta.pkname, pk) + + # refresh server side defaults + if any( + field.server_default is not None + for name, field in self.Meta.model_fields.items() + if name not in self_fields + ): + await self.load() + return self + self.set_save_status(True) return self diff --git a/tests/test_saving_string_pks.py b/tests/test_saving_string_pks.py new file mode 100644 index 0000000..65bda9c --- /dev/null +++ b/tests/test_saving_string_pks.py @@ -0,0 +1,81 @@ +from random import choice +from string import ascii_uppercase + +import databases +import pytest +import sqlalchemy +from sqlalchemy import create_engine + +import ormar +from ormar import Float, String +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +def get_id() -> str: + return "".join(choice(ascii_uppercase) for _ in range(12)) + + +class MainMeta(ormar.ModelMeta): + metadata = metadata + database = database + + +class PositionOrm(ormar.Model): + class Meta(MainMeta): + pass + + name: str = String(primary_key=True, max_length=50) + x: float = Float() + y: float = Float() + degrees: float = Float() + + +class PositionOrmDef(ormar.Model): + class Meta(MainMeta): + pass + + name: str = String(primary_key=True, max_length=50, default=get_id) + x: float = Float() + y: float = Float() + degrees: float = Float() + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.fixture(scope="function") +async def cleanup(): + yield + async with database: + await PositionOrm.objects.delete(each=True) + await PositionOrmDef.objects.delete(each=True) + + +@pytest.mark.asyncio +async def test_creating_a_position(cleanup): + async with database: + instance = PositionOrm(name="my_pos", x=1.0, y=2.0, degrees=3.0,) + await instance.save() + assert instance.saved + assert instance.name == "my_pos" + + instance2 = PositionOrmDef(x=1.0, y=2.0, degrees=3.0,) + await instance2.save() + assert instance2.saved + assert instance2.name is not None + assert len(instance2.name) == 12 + + instance3 = PositionOrmDef(x=1.0, y=2.0, degrees=3.0,) + await instance3.save() + assert instance3.saved + assert instance3.name is not None + assert len(instance3.name) == 12 + assert instance2.name != instance3.name