fix server_default not setting value in defaults, pop server_default fields if set to None, add tests, update docs
This commit is contained in:
@ -68,7 +68,7 @@ Used both in sql and pydantic.
|
|||||||
|
|
||||||
A default value used if no other value is passed.
|
A default value used if no other value is passed.
|
||||||
|
|
||||||
In sql invoked on the server side so you can pass i.e. sql function (like now() wrapped in sqlalchemy text() clause).
|
In sql invoked on the server side so you can pass i.e. sql function (like now() or query/value wrapped in sqlalchemy text() clause).
|
||||||
|
|
||||||
If the field has a server_default value it becomes optional.
|
If the field has a server_default value it becomes optional.
|
||||||
|
|
||||||
@ -76,6 +76,24 @@ You can pass a static value or a Callable (function etc.)
|
|||||||
|
|
||||||
Used in sql only.
|
Used in sql only.
|
||||||
|
|
||||||
|
Sample usage:
|
||||||
|
|
||||||
|
```Python hl_lines="19-21"
|
||||||
|
--8<-- "../docs_src/fields/docs004.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
!!!warning
|
||||||
|
`server_default` accepts `str`, `sqlalchemy.sql.elements.ClauseElement` or `sqlalchemy.sql.elements.TextClause`
|
||||||
|
so if you want to set i.e. Integer value you need to wrap it in `sqlalchemy.text()` function like above
|
||||||
|
|
||||||
|
!!!tip
|
||||||
|
You can pass also valid sql (dialect specific) wrapped in `sqlalchemy.text()`
|
||||||
|
|
||||||
|
For example `func.now()` above could be exchanged for `text('(CURRENT_TIMESTAMP)')` for sqlite backend
|
||||||
|
|
||||||
|
!!!info
|
||||||
|
`server_default` is passed straight to sqlalchemy table definition so you can read more in [server default][server default] sqlalchemy documentation
|
||||||
|
|
||||||
### index
|
### index
|
||||||
|
|
||||||
`index`: `bool` = `False` -> by default False,
|
`index`: `bool` = `False` -> by default False,
|
||||||
@ -241,3 +259,4 @@ You can use either `length` and `precision` parameters or `max_digits` and `deci
|
|||||||
[relations]: ./relations.md
|
[relations]: ./relations.md
|
||||||
[queries]: ./queries.md
|
[queries]: ./queries.md
|
||||||
[pydantic]: https://pydantic-docs.helpmanual.io/usage/types/#constrained-types
|
[pydantic]: https://pydantic-docs.helpmanual.io/usage/types/#constrained-types
|
||||||
|
[server default]: https://docs.sqlalchemy.org/en/13/core/defaults.html#server-invoked-ddl-explicit-default-expressions
|
||||||
21
docs_src/fields/docs004.py
Normal file
21
docs_src/fields/docs004.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import databases
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import func, text
|
||||||
|
|
||||||
|
import ormar
|
||||||
|
|
||||||
|
database = databases.Database("sqlite:///test.db")
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
class Product(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "product"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=100)
|
||||||
|
company: ormar.String(max_length=200, server_default='Acme')
|
||||||
|
sort_order: ormar.Integer(server_default=text("10"))
|
||||||
|
created: ormar.DateTime(server_default=func.now())
|
||||||
@ -28,7 +28,7 @@ class UndefinedType: # pragma no cover
|
|||||||
|
|
||||||
Undefined = UndefinedType()
|
Undefined = UndefinedType()
|
||||||
|
|
||||||
__version__ = "0.3.9"
|
__version__ = "0.3.11"
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Integer",
|
"Integer",
|
||||||
"BigInteger",
|
"BigInteger",
|
||||||
|
|||||||
@ -33,10 +33,10 @@ class BaseField:
|
|||||||
server_default: Any
|
server_default: Any
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_value(cls) -> Optional[FieldInfo]:
|
def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]:
|
||||||
if cls.is_auto_primary_key():
|
if cls.is_auto_primary_key():
|
||||||
return Field(default=None)
|
return Field(default=None)
|
||||||
if cls.has_default():
|
if cls.has_default(use_server=use_server):
|
||||||
default = cls.default if cls.default is not None else cls.server_default
|
default = cls.default if cls.default is not None else cls.server_default
|
||||||
if callable(default):
|
if callable(default):
|
||||||
return Field(default_factory=default)
|
return Field(default_factory=default)
|
||||||
@ -44,16 +44,22 @@ class BaseField:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default(cls) -> Any:
|
def get_default(cls, use_server: bool = False) -> Any: # noqa CCR001
|
||||||
if cls.has_default():
|
if cls.has_default():
|
||||||
default = cls.default if cls.default is not None else cls.server_default
|
default = (
|
||||||
|
cls.default
|
||||||
|
if cls.default is not None
|
||||||
|
else (cls.server_default if use_server else None)
|
||||||
|
)
|
||||||
if callable(default):
|
if callable(default):
|
||||||
default = default()
|
default = default()
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def has_default(cls) -> bool:
|
def has_default(cls, use_server: bool = True) -> bool:
|
||||||
return cls.default is not None or cls.server_default is not None
|
return cls.default is not None or (
|
||||||
|
cls.server_default is not None and use_server
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_auto_primary_key(cls) -> bool:
|
def is_auto_primary_key(cls) -> bool:
|
||||||
|
|||||||
@ -129,7 +129,7 @@ class Model(NewBaseModel):
|
|||||||
|
|
||||||
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
||||||
self_fields.pop(self.Meta.pkname, None)
|
self_fields.pop(self.Meta.pkname, None)
|
||||||
self_fields = self.objects._populate_default_values(self_fields)
|
self_fields = self.populate_default_values(self_fields)
|
||||||
expr = self.Meta.table.insert()
|
expr = self.Meta.table.insert()
|
||||||
expr = expr.values(**self_fields)
|
expr = expr.values(**self_fields)
|
||||||
item_id = await self.Meta.database.execute(expr)
|
item_id = await self.Meta.database.execute(expr)
|
||||||
|
|||||||
@ -49,6 +49,16 @@ class ModelTableProxy:
|
|||||||
model_dict.pop(field, None)
|
model_dict.pop(field, None)
|
||||||
return model_dict
|
return model_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def populate_default_values(cls, new_kwargs: Dict) -> Dict:
|
||||||
|
for field_name, field in cls.Meta.model_fields.items():
|
||||||
|
if field_name not in new_kwargs and field.has_default(use_server=False):
|
||||||
|
new_kwargs[field_name] = field.get_default()
|
||||||
|
# clear fields with server_default set as None
|
||||||
|
if field.server_default is not None and not new_kwargs.get(field_name):
|
||||||
|
new_kwargs.pop(field_name, None)
|
||||||
|
return new_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_column_alias(cls, field_name: str) -> str:
|
def get_column_alias(cls, field_name: str) -> str:
|
||||||
field = cls.Meta.model_fields.get(field_name)
|
field = cls.Meta.model_fields.get(field_name)
|
||||||
|
|||||||
@ -73,16 +73,10 @@ class QuerySet:
|
|||||||
def _prepare_model_to_save(self, new_kwargs: dict) -> dict:
|
def _prepare_model_to_save(self, new_kwargs: dict) -> dict:
|
||||||
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
|
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
|
||||||
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
||||||
new_kwargs = self._populate_default_values(new_kwargs)
|
new_kwargs = self.model.populate_default_values(new_kwargs)
|
||||||
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
|
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
|
||||||
return new_kwargs
|
return new_kwargs
|
||||||
|
|
||||||
def _populate_default_values(self, new_kwargs: dict) -> dict:
|
|
||||||
for field_name, field in self.model_meta.model_fields.items():
|
|
||||||
if field_name not in new_kwargs and field.has_default():
|
|
||||||
new_kwargs[field_name] = field.get_default()
|
|
||||||
return new_kwargs
|
|
||||||
|
|
||||||
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
|
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
|
||||||
pkname = self.model_meta.pkname
|
pkname = self.model_meta.pkname
|
||||||
pk = self.model_meta.model_fields[pkname]
|
pk = self.model_meta.model_fields[pkname]
|
||||||
@ -300,6 +294,9 @@ class QuerySet:
|
|||||||
if pk and isinstance(pk, self.model.pk_type()):
|
if pk and isinstance(pk, self.model.pk_type()):
|
||||||
setattr(instance, self.model_meta.pkname, pk)
|
setattr(instance, self.model_meta.pkname, pk)
|
||||||
|
|
||||||
|
# refresh server side defaults
|
||||||
|
instance = await instance.load()
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
async def bulk_create(self, objects: List["Model"]) -> None:
|
async def bulk_create(self, objects: List["Model"]) -> None:
|
||||||
|
|||||||
80
tests/test_server_default.py
Normal file
80
tests/test_server_default.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import databases
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import func, text
|
||||||
|
|
||||||
|
import ormar
|
||||||
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
class Product(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "product"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=100)
|
||||||
|
company: ormar.String(max_length=200, server_default='Acme')
|
||||||
|
sort_order: ormar.Integer(server_default=text("10"))
|
||||||
|
created: ormar.DateTime(server_default=func.now())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def event_loop():
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
|
async def create_test_database():
|
||||||
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
metadata.create_all(engine)
|
||||||
|
yield
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
def test_table_defined_properly():
|
||||||
|
assert Product.Meta.model_fields['created'].nullable
|
||||||
|
assert not Product.__fields__['created'].required
|
||||||
|
assert Product.Meta.table.columns['created'].server_default.arg.name == 'now'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_creation():
|
||||||
|
p1 = Product(name='Test')
|
||||||
|
assert p1.created is None
|
||||||
|
await p1.save()
|
||||||
|
await p1.load()
|
||||||
|
assert p1.created is not None
|
||||||
|
assert p1.company == 'Acme'
|
||||||
|
assert p1.sort_order == 10
|
||||||
|
|
||||||
|
date = datetime.strptime('2020-10-27 11:30', '%Y-%m-%d %H:%M')
|
||||||
|
p3 = await Product.objects.create(name='Test2', created=date, company='Roadrunner', sort_order=1)
|
||||||
|
assert p3.created is not None
|
||||||
|
assert p3.created == date
|
||||||
|
assert p1.created != p3.created
|
||||||
|
assert p3.company == 'Roadrunner'
|
||||||
|
assert p3.sort_order == 1
|
||||||
|
|
||||||
|
p3 = await Product.objects.get(name='Test2')
|
||||||
|
assert p3.company == 'Roadrunner'
|
||||||
|
assert p3.sort_order == 1
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
p2 = await Product.objects.create(name='Test3')
|
||||||
|
assert p2.created is not None
|
||||||
|
assert p1.created != p2.created
|
||||||
|
assert p2.company == 'Acme'
|
||||||
|
assert p2.sort_order == 10
|
||||||
Reference in New Issue
Block a user