fix saving string pk and add db refresh for server_defaults in save() also, bump version

This commit is contained in:
collerek
2020-11-20 11:03:57 +01:00
parent a77fb01b10
commit 2385f95a9f
3 changed files with 95 additions and 4 deletions

View File

@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.5.0" __version__ = "0.5.1"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -197,9 +197,19 @@ class Model(NewBaseModel):
expr = self.Meta.table.insert() expr = self.Meta.table.insert()
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.Meta.database.execute(expr) pk = await self.Meta.database.execute(expr)
if item_id: # postgress does not return id if it's already there if pk and isinstance(pk, self.pk_type()):
setattr(self, self.Meta.pkname, item_id) setattr(self, self.Meta.pkname, pk)
# refresh server side defaults
if any(
field.server_default is not None
for name, field in self.Meta.model_fields.items()
if name not in self_fields
):
await self.load()
return self
self.set_save_status(True) self.set_save_status(True)
return self return self

View File

@ -0,0 +1,81 @@
from random import choice
from string import ascii_uppercase
import databases
import pytest
import sqlalchemy
from sqlalchemy import create_engine
import ormar
from ormar import Float, String
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
def get_id() -> str:
return "".join(choice(ascii_uppercase) for _ in range(12))
class MainMeta(ormar.ModelMeta):
metadata = metadata
database = database
class PositionOrm(ormar.Model):
class Meta(MainMeta):
pass
name: str = String(primary_key=True, max_length=50)
x: float = Float()
y: float = Float()
degrees: float = Float()
class PositionOrmDef(ormar.Model):
class Meta(MainMeta):
pass
name: str = String(primary_key=True, max_length=50, default=get_id)
x: float = Float()
y: float = Float()
degrees: float = Float()
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture(scope="function")
async def cleanup():
yield
async with database:
await PositionOrm.objects.delete(each=True)
await PositionOrmDef.objects.delete(each=True)
@pytest.mark.asyncio
async def test_creating_a_position(cleanup):
async with database:
instance = PositionOrm(name="my_pos", x=1.0, y=2.0, degrees=3.0,)
await instance.save()
assert instance.saved
assert instance.name == "my_pos"
instance2 = PositionOrmDef(x=1.0, y=2.0, degrees=3.0,)
await instance2.save()
assert instance2.saved
assert instance2.name is not None
assert len(instance2.name) == 12
instance3 = PositionOrmDef(x=1.0, y=2.0, degrees=3.0,)
await instance3.save()
assert instance3.saved
assert instance3.name is not None
assert len(instance3.name) == 12
assert instance2.name != instance3.name