Merge pull request #26 from collerek/fix_server_default

Fix server_default Field parameter
This commit is contained in:
collerek
2020-10-28 22:01:03 +07:00
committed by GitHub
9 changed files with 160 additions and 17 deletions

View File

@ -68,13 +68,31 @@ Used both in sql and pydantic.
A default value used if no other value is passed. A default value used if no other value is passed.
In sql invoked on the server side so you can pass i.e. sql function (like now() wrapped in sqlalchemy text() clause). In sql invoked on the server side so you can pass i.e. sql function (like now() or query/value wrapped in sqlalchemy text() clause).
If the field has a server_default value it becomes optional. If the field has a server_default value it becomes optional.
You can pass a static value or a Callable (function etc.) You can pass a static value or a Callable (function etc.)
Used in sql only. Used in sql only.
Sample usage:
```Python hl_lines="19-21"
--8<-- "../docs_src/fields/docs004.py"
```
!!!warning
`server_default` accepts `str`, `sqlalchemy.sql.elements.ClauseElement` or `sqlalchemy.sql.elements.TextClause`
so if you want to set i.e. Integer value you need to wrap it in `sqlalchemy.text()` function like above
!!!tip
You can pass also valid sql (dialect specific) wrapped in `sqlalchemy.text()`
For example `func.now()` above could be exchanged for `text('(CURRENT_TIMESTAMP)')` for sqlite backend
!!!info
`server_default` is passed straight to sqlalchemy table definition so you can read more in [server default][server default] sqlalchemy documentation
### index ### index
@ -240,4 +258,5 @@ You can use either `length` and `precision` parameters or `max_digits` and `deci
[relations]: ./relations.md [relations]: ./relations.md
[queries]: ./queries.md [queries]: ./queries.md
[pydantic]: https://pydantic-docs.helpmanual.io/usage/types/#constrained-types [pydantic]: https://pydantic-docs.helpmanual.io/usage/types/#constrained-types
[server default]: https://docs.sqlalchemy.org/en/13/core/defaults.html#server-invoked-ddl-explicit-default-expressions

View File

@ -1,3 +1,7 @@
# 0.3.10
* Fix
# 0.3.9 # 0.3.9
* Fix json schema generation as of [#19][#19] * Fix json schema generation as of [#19][#19]

View File

@ -0,0 +1,21 @@
import databases
import sqlalchemy
from sqlalchemy import func, text
import ormar
database = databases.Database("sqlite:///test.db")
metadata = sqlalchemy.MetaData()
class Product(ormar.Model):
class Meta:
tablename = "product"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
company: ormar.String(max_length=200, server_default='Acme')
sort_order: ormar.Integer(server_default=text("10"))
created: ormar.DateTime(server_default=func.now())

View File

@ -28,7 +28,8 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.3.10"
__version__ = "0.3.11"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -33,10 +33,10 @@ class BaseField:
server_default: Any server_default: Any
@classmethod @classmethod
def default_value(cls) -> Optional[FieldInfo]: def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]:
if cls.is_auto_primary_key(): if cls.is_auto_primary_key():
return Field(default=None) return Field(default=None)
if cls.has_default(): if cls.has_default(use_server=use_server):
default = cls.default if cls.default is not None else cls.server_default default = cls.default if cls.default is not None else cls.server_default
if callable(default): if callable(default):
return Field(default_factory=default) return Field(default_factory=default)
@ -44,16 +44,22 @@ class BaseField:
return None return None
@classmethod @classmethod
def get_default(cls) -> Any: def get_default(cls, use_server: bool = False) -> Any: # noqa CCR001
if cls.has_default(): if cls.has_default():
default = cls.default if cls.default is not None else cls.server_default default = (
cls.default
if cls.default is not None
else (cls.server_default if use_server else None)
)
if callable(default): if callable(default):
default = default() default = default()
return default return default
@classmethod @classmethod
def has_default(cls) -> bool: def has_default(cls, use_server: bool = True) -> bool:
return cls.default is not None or cls.server_default is not None return cls.default is not None or (
cls.server_default is not None and use_server
)
@classmethod @classmethod
def is_auto_primary_key(cls) -> bool: def is_auto_primary_key(cls) -> bool:

View File

@ -129,7 +129,7 @@ class Model(NewBaseModel):
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
self_fields.pop(self.Meta.pkname, None) self_fields.pop(self.Meta.pkname, None)
self_fields = self.objects._populate_default_values(self_fields) self_fields = self.populate_default_values(self_fields)
expr = self.Meta.table.insert() expr = self.Meta.table.insert()
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.Meta.database.execute(expr) item_id = await self.Meta.database.execute(expr)

View File

@ -49,6 +49,16 @@ class ModelTableProxy:
model_dict.pop(field, None) model_dict.pop(field, None)
return model_dict return model_dict
@classmethod
def populate_default_values(cls, new_kwargs: Dict) -> Dict:
for field_name, field in cls.Meta.model_fields.items():
if field_name not in new_kwargs and field.has_default(use_server=False):
new_kwargs[field_name] = field.get_default()
# clear fields with server_default set as None
if field.server_default is not None and not new_kwargs.get(field_name):
new_kwargs.pop(field_name, None)
return new_kwargs
@classmethod @classmethod
def get_column_alias(cls, field_name: str) -> str: def get_column_alias(cls, field_name: str) -> str:
field = cls.Meta.model_fields.get(field_name) field = cls.Meta.model_fields.get(field_name)

View File

@ -73,16 +73,10 @@ class QuerySet:
def _prepare_model_to_save(self, new_kwargs: dict) -> dict: def _prepare_model_to_save(self, new_kwargs: dict) -> dict:
new_kwargs = self._remove_pk_from_kwargs(new_kwargs) new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs) new_kwargs = self.model.populate_default_values(new_kwargs)
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs) new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
return new_kwargs return new_kwargs
def _populate_default_values(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_meta.model_fields.items():
if field_name not in new_kwargs and field.has_default():
new_kwargs[field_name] = field.get_default()
return new_kwargs
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
pkname = self.model_meta.pkname pkname = self.model_meta.pkname
pk = self.model_meta.model_fields[pkname] pk = self.model_meta.model_fields[pkname]
@ -300,6 +294,9 @@ class QuerySet:
if pk and isinstance(pk, self.model.pk_type()): if pk and isinstance(pk, self.model.pk_type()):
setattr(instance, self.model_meta.pkname, pk) setattr(instance, self.model_meta.pkname, pk)
# refresh server side defaults
instance = await instance.load()
return instance return instance
async def bulk_create(self, objects: List["Model"]) -> None: async def bulk_create(self, objects: List["Model"]) -> None:

View File

@ -0,0 +1,85 @@
import asyncio
import time
from datetime import datetime
import databases
import pytest
import sqlalchemy
from sqlalchemy import func, text
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Product(ormar.Model):
class Meta:
tablename = "product"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
company: ormar.String(max_length=200, server_default='Acme')
sort_order: ormar.Integer(server_default=text("10"))
created: ormar.DateTime(server_default=func.now())
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_table_defined_properly():
assert Product.Meta.model_fields['created'].nullable
assert not Product.__fields__['created'].required
assert Product.Meta.table.columns['created'].server_default.arg.name == 'now'
@pytest.mark.asyncio
async def test_model_creation():
async with database:
async with database.transaction(force_rollback=True):
p1 = Product(name='Test')
assert p1.created is None
await p1.save()
await p1.load()
assert p1.created is not None
assert p1.company == 'Acme'
assert p1.sort_order == 10
date = datetime.strptime('2020-10-27 11:30', '%Y-%m-%d %H:%M')
p3 = await Product.objects.create(name='Test2', created=date, company='Roadrunner', sort_order=1)
assert p3.created is not None
assert p3.created == date
assert p1.created != p3.created
assert p3.company == 'Roadrunner'
assert p3.sort_order == 1
p3 = await Product.objects.get(name='Test2')
assert p3.company == 'Roadrunner'
assert p3.sort_order == 1
time.sleep(1)
p2 = await Product.objects.create(name='Test3')
assert p2.created is not None
assert p2.company == 'Acme'
assert p2.sort_order == 10
if Product.db_backend_name() != 'postgresql':
# postgres use transaction timestamp so it will remain the same
assert p1.created != p2.created # pragma nocover