fix server_default not setting value in defaults, pop server_default fields if set to None, add tests, update docs

This commit is contained in:
collerek
2020-10-28 15:34:40 +01:00
parent 858fbce67f
commit 29d04887b1
8 changed files with 150 additions and 17 deletions

View File

@ -68,13 +68,31 @@ Used both in sql and pydantic.
A default value used if no other value is passed.
In sql invoked on the server side so you can pass i.e. sql function (like now() wrapped in sqlalchemy text() clause).
In sql invoked on the server side so you can pass i.e. sql function (like now() or query/value wrapped in sqlalchemy text() clause).
If the field has a server_default value it becomes optional.
You can pass a static value or a Callable (function etc.)
Used in sql only.
Sample usage:
```Python hl_lines="19-21"
--8<-- "../docs_src/fields/docs004.py"
```
!!!warning
`server_default` accepts `str`, `sqlalchemy.sql.elements.ClauseElement` or `sqlalchemy.sql.elements.TextClause`
so if you want to set i.e. Integer value you need to wrap it in `sqlalchemy.text()` function like above
!!!tip
You can pass also valid sql (dialect specific) wrapped in `sqlalchemy.text()`
For example `func.now()` above could be exchanged for `text('(CURRENT_TIMESTAMP)')` for sqlite backend
!!!info
`server_default` is passed straight to sqlalchemy table definition so you can read more in [server default][server default] sqlalchemy documentation
### index
@ -240,4 +258,5 @@ You can use either `length` and `precision` parameters or `max_digits` and `deci
[relations]: ./relations.md
[queries]: ./queries.md
[pydantic]: https://pydantic-docs.helpmanual.io/usage/types/#constrained-types
[pydantic]: https://pydantic-docs.helpmanual.io/usage/types/#constrained-types
[server default]: https://docs.sqlalchemy.org/en/13/core/defaults.html#server-invoked-ddl-explicit-default-expressions

View File

@ -0,0 +1,21 @@
import databases
import sqlalchemy
from sqlalchemy import func, text
import ormar
database = databases.Database("sqlite:///test.db")
metadata = sqlalchemy.MetaData()
class Product(ormar.Model):
class Meta:
tablename = "product"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
company: ormar.String(max_length=200, server_default='Acme')
sort_order: ormar.Integer(server_default=text("10"))
created: ormar.DateTime(server_default=func.now())

View File

@ -28,7 +28,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType()
__version__ = "0.3.9"
__version__ = "0.3.11"
__all__ = [
"Integer",
"BigInteger",

View File

@ -33,10 +33,10 @@ class BaseField:
server_default: Any
@classmethod
def default_value(cls) -> Optional[FieldInfo]:
def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]:
if cls.is_auto_primary_key():
return Field(default=None)
if cls.has_default():
if cls.has_default(use_server=use_server):
default = cls.default if cls.default is not None else cls.server_default
if callable(default):
return Field(default_factory=default)
@ -44,16 +44,22 @@ class BaseField:
return None
@classmethod
def get_default(cls) -> Any:
def get_default(cls, use_server: bool = False) -> Any: # noqa CCR001
if cls.has_default():
default = cls.default if cls.default is not None else cls.server_default
default = (
cls.default
if cls.default is not None
else (cls.server_default if use_server else None)
)
if callable(default):
default = default()
return default
@classmethod
def has_default(cls) -> bool:
return cls.default is not None or cls.server_default is not None
def has_default(cls, use_server: bool = True) -> bool:
return cls.default is not None or (
cls.server_default is not None and use_server
)
@classmethod
def is_auto_primary_key(cls) -> bool:

View File

@ -129,7 +129,7 @@ class Model(NewBaseModel):
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
self_fields.pop(self.Meta.pkname, None)
self_fields = self.objects._populate_default_values(self_fields)
self_fields = self.populate_default_values(self_fields)
expr = self.Meta.table.insert()
expr = expr.values(**self_fields)
item_id = await self.Meta.database.execute(expr)

View File

@ -49,6 +49,16 @@ class ModelTableProxy:
model_dict.pop(field, None)
return model_dict
@classmethod
def populate_default_values(cls, new_kwargs: Dict) -> Dict:
for field_name, field in cls.Meta.model_fields.items():
if field_name not in new_kwargs and field.has_default(use_server=False):
new_kwargs[field_name] = field.get_default()
# clear fields with server_default set as None
if field.server_default is not None and not new_kwargs.get(field_name):
new_kwargs.pop(field_name, None)
return new_kwargs
@classmethod
def get_column_alias(cls, field_name: str) -> str:
field = cls.Meta.model_fields.get(field_name)

View File

@ -73,16 +73,10 @@ class QuerySet:
def _prepare_model_to_save(self, new_kwargs: dict) -> dict:
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
new_kwargs = self.model.populate_default_values(new_kwargs)
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
return new_kwargs
def _populate_default_values(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_meta.model_fields.items():
if field_name not in new_kwargs and field.has_default():
new_kwargs[field_name] = field.get_default()
return new_kwargs
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
pkname = self.model_meta.pkname
pk = self.model_meta.model_fields[pkname]
@ -300,6 +294,9 @@ class QuerySet:
if pk and isinstance(pk, self.model.pk_type()):
setattr(instance, self.model_meta.pkname, pk)
# refresh server side defaults
instance = await instance.load()
return instance
async def bulk_create(self, objects: List["Model"]) -> None:

View File

@ -0,0 +1,80 @@
import asyncio
import time
from datetime import datetime
import databases
import pytest
import sqlalchemy
from sqlalchemy import func, text
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Product(ormar.Model):
class Meta:
tablename = "product"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
company: ormar.String(max_length=200, server_default='Acme')
sort_order: ormar.Integer(server_default=text("10"))
created: ormar.DateTime(server_default=func.now())
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_table_defined_properly():
assert Product.Meta.model_fields['created'].nullable
assert not Product.__fields__['created'].required
assert Product.Meta.table.columns['created'].server_default.arg.name == 'now'
@pytest.mark.asyncio
async def test_model_creation():
p1 = Product(name='Test')
assert p1.created is None
await p1.save()
await p1.load()
assert p1.created is not None
assert p1.company == 'Acme'
assert p1.sort_order == 10
date = datetime.strptime('2020-10-27 11:30', '%Y-%m-%d %H:%M')
p3 = await Product.objects.create(name='Test2', created=date, company='Roadrunner', sort_order=1)
assert p3.created is not None
assert p3.created == date
assert p1.created != p3.created
assert p3.company == 'Roadrunner'
assert p3.sort_order == 1
p3 = await Product.objects.get(name='Test2')
assert p3.company == 'Roadrunner'
assert p3.sort_order == 1
time.sleep(1)
p2 = await Product.objects.create(name='Test3')
assert p2.created is not None
assert p1.created != p2.created
assert p2.company == 'Acme'
assert p2.sort_order == 10