fix saving string pk and add db refresh for server_defaults in save() also, bump version
This commit is contained in:
@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover
|
|||||||
|
|
||||||
Undefined = UndefinedType()
|
Undefined = UndefinedType()
|
||||||
|
|
||||||
__version__ = "0.5.0"
|
__version__ = "0.5.1"
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Integer",
|
"Integer",
|
||||||
"BigInteger",
|
"BigInteger",
|
||||||
|
|||||||
@ -197,9 +197,19 @@ class Model(NewBaseModel):
|
|||||||
expr = self.Meta.table.insert()
|
expr = self.Meta.table.insert()
|
||||||
expr = expr.values(**self_fields)
|
expr = expr.values(**self_fields)
|
||||||
|
|
||||||
item_id = await self.Meta.database.execute(expr)
|
pk = await self.Meta.database.execute(expr)
|
||||||
if item_id: # postgress does not return id if it's already there
|
if pk and isinstance(pk, self.pk_type()):
|
||||||
setattr(self, self.Meta.pkname, item_id)
|
setattr(self, self.Meta.pkname, pk)
|
||||||
|
|
||||||
|
# refresh server side defaults
|
||||||
|
if any(
|
||||||
|
field.server_default is not None
|
||||||
|
for name, field in self.Meta.model_fields.items()
|
||||||
|
if name not in self_fields
|
||||||
|
):
|
||||||
|
await self.load()
|
||||||
|
return self
|
||||||
|
|
||||||
self.set_save_status(True)
|
self.set_save_status(True)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
81
tests/test_saving_string_pks.py
Normal file
81
tests/test_saving_string_pks.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
from random import choice
|
||||||
|
from string import ascii_uppercase
|
||||||
|
|
||||||
|
import databases
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
import ormar
|
||||||
|
from ormar import Float, String
|
||||||
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
def get_id() -> str:
|
||||||
|
return "".join(choice(ascii_uppercase) for _ in range(12))
|
||||||
|
|
||||||
|
|
||||||
|
class MainMeta(ormar.ModelMeta):
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
|
||||||
|
class PositionOrm(ormar.Model):
|
||||||
|
class Meta(MainMeta):
|
||||||
|
pass
|
||||||
|
|
||||||
|
name: str = String(primary_key=True, max_length=50)
|
||||||
|
x: float = Float()
|
||||||
|
y: float = Float()
|
||||||
|
degrees: float = Float()
|
||||||
|
|
||||||
|
|
||||||
|
class PositionOrmDef(ormar.Model):
|
||||||
|
class Meta(MainMeta):
|
||||||
|
pass
|
||||||
|
|
||||||
|
name: str = String(primary_key=True, max_length=50, default=get_id)
|
||||||
|
x: float = Float()
|
||||||
|
y: float = Float()
|
||||||
|
degrees: float = Float()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
|
def create_test_database():
|
||||||
|
engine = create_engine(DATABASE_URL)
|
||||||
|
metadata.create_all(engine)
|
||||||
|
yield
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def cleanup():
|
||||||
|
yield
|
||||||
|
async with database:
|
||||||
|
await PositionOrm.objects.delete(each=True)
|
||||||
|
await PositionOrmDef.objects.delete(each=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_creating_a_position(cleanup):
|
||||||
|
async with database:
|
||||||
|
instance = PositionOrm(name="my_pos", x=1.0, y=2.0, degrees=3.0,)
|
||||||
|
await instance.save()
|
||||||
|
assert instance.saved
|
||||||
|
assert instance.name == "my_pos"
|
||||||
|
|
||||||
|
instance2 = PositionOrmDef(x=1.0, y=2.0, degrees=3.0,)
|
||||||
|
await instance2.save()
|
||||||
|
assert instance2.saved
|
||||||
|
assert instance2.name is not None
|
||||||
|
assert len(instance2.name) == 12
|
||||||
|
|
||||||
|
instance3 = PositionOrmDef(x=1.0, y=2.0, degrees=3.0,)
|
||||||
|
await instance3.save()
|
||||||
|
assert instance3.saved
|
||||||
|
assert instance3.name is not None
|
||||||
|
assert len(instance3.name) == 12
|
||||||
|
assert instance2.name != instance3.name
|
||||||
Reference in New Issue
Block a user