From 29d04887b14806a76751e641d8168df12af2b39a Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 28 Oct 2020 15:34:40 +0100 Subject: [PATCH] fix server_default not setting value in defaults, pop server_default fields if set to None, add tests, update docs --- docs/fields.md | 23 ++++++++++- docs_src/fields/docs004.py | 21 ++++++++++ ormar/__init__.py | 2 +- ormar/fields/base.py | 18 +++++--- ormar/models/model.py | 2 +- ormar/models/modelproxy.py | 10 +++++ ormar/queryset/queryset.py | 11 ++--- tests/test_server_default.py | 80 ++++++++++++++++++++++++++++++++++++ 8 files changed, 150 insertions(+), 17 deletions(-) create mode 100644 docs_src/fields/docs004.py create mode 100644 tests/test_server_default.py diff --git a/docs/fields.md b/docs/fields.md index 8381b69..6dc0a17 100644 --- a/docs/fields.md +++ b/docs/fields.md @@ -68,13 +68,31 @@ Used both in sql and pydantic. A default value used if no other value is passed. -In sql invoked on the server side so you can pass i.e. sql function (like now() wrapped in sqlalchemy text() clause). +In sql invoked on the server side so you can pass i.e. sql function (like now() or query/value wrapped in sqlalchemy text() clause). If the field has a server_default value it becomes optional. You can pass a static value or a Callable (function etc.) Used in sql only. + +Sample usage: + +```Python hl_lines="19-21" +--8<-- "../docs_src/fields/docs004.py" +``` + +!!!warning + `server_default` accepts `str`, `sqlalchemy.sql.elements.ClauseElement` or `sqlalchemy.sql.elements.TextClause` + so if you want to set i.e. Integer value you need to wrap it in `sqlalchemy.text()` function like above + +!!!tip + You can pass also valid sql (dialect specific) wrapped in `sqlalchemy.text()` + + For example `func.now()` above could be exchanged for `text('(CURRENT_TIMESTAMP)')` for sqlite backend + +!!!info + `server_default` is passed straight to sqlalchemy table definition so you can read more in [server default][server default] sqlalchemy documentation ### index @@ -240,4 +258,5 @@ You can use either `length` and `precision` parameters or `max_digits` and `deci [relations]: ./relations.md [queries]: ./queries.md -[pydantic]: https://pydantic-docs.helpmanual.io/usage/types/#constrained-types \ No newline at end of file +[pydantic]: https://pydantic-docs.helpmanual.io/usage/types/#constrained-types +[server default]: https://docs.sqlalchemy.org/en/13/core/defaults.html#server-invoked-ddl-explicit-default-expressions \ No newline at end of file diff --git a/docs_src/fields/docs004.py b/docs_src/fields/docs004.py new file mode 100644 index 0000000..b04cdb2 --- /dev/null +++ b/docs_src/fields/docs004.py @@ -0,0 +1,21 @@ +import databases +import sqlalchemy +from sqlalchemy import func, text + +import ormar + +database = databases.Database("sqlite:///test.db") +metadata = sqlalchemy.MetaData() + + +class Product(ormar.Model): + class Meta: + tablename = "product" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + company: ormar.String(max_length=200, server_default='Acme') + sort_order: ormar.Integer(server_default=text("10")) + created: ormar.DateTime(server_default=func.now()) diff --git a/ormar/__init__.py b/ormar/__init__.py index 4aefddc..f009f6c 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -28,7 +28,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.3.9" +__version__ = "0.3.11" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 88e1313..e55de13 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -33,10 +33,10 @@ class BaseField: server_default: Any @classmethod - def default_value(cls) -> Optional[FieldInfo]: + def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]: if cls.is_auto_primary_key(): return Field(default=None) - if cls.has_default(): + if cls.has_default(use_server=use_server): default = cls.default if cls.default is not None else cls.server_default if callable(default): return Field(default_factory=default) @@ -44,16 +44,22 @@ class BaseField: return None @classmethod - def get_default(cls) -> Any: + def get_default(cls, use_server: bool = False) -> Any: # noqa CCR001 if cls.has_default(): - default = cls.default if cls.default is not None else cls.server_default + default = ( + cls.default + if cls.default is not None + else (cls.server_default if use_server else None) + ) if callable(default): default = default() return default @classmethod - def has_default(cls) -> bool: - return cls.default is not None or cls.server_default is not None + def has_default(cls, use_server: bool = True) -> bool: + return cls.default is not None or ( + cls.server_default is not None and use_server + ) @classmethod def is_auto_primary_key(cls) -> bool: diff --git a/ormar/models/model.py b/ormar/models/model.py index 9017db0..d0b07e4 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -129,7 +129,7 @@ class Model(NewBaseModel): if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: self_fields.pop(self.Meta.pkname, None) - self_fields = self.objects._populate_default_values(self_fields) + self_fields = self.populate_default_values(self_fields) expr = self.Meta.table.insert() expr = expr.values(**self_fields) item_id = await self.Meta.database.execute(expr) diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index c965154..56ee930 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -49,6 +49,16 @@ class ModelTableProxy: model_dict.pop(field, None) return model_dict + @classmethod + def populate_default_values(cls, new_kwargs: Dict) -> Dict: + for field_name, field in cls.Meta.model_fields.items(): + if field_name not in new_kwargs and field.has_default(use_server=False): + new_kwargs[field_name] = field.get_default() + # clear fields with server_default set as None + if field.server_default is not None and not new_kwargs.get(field_name): + new_kwargs.pop(field_name, None) + return new_kwargs + @classmethod def get_column_alias(cls, field_name: str) -> str: field = cls.Meta.model_fields.get(field_name) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index a6c6b26..1decda1 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -73,16 +73,10 @@ class QuerySet: def _prepare_model_to_save(self, new_kwargs: dict) -> dict: new_kwargs = self._remove_pk_from_kwargs(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs) - new_kwargs = self._populate_default_values(new_kwargs) + new_kwargs = self.model.populate_default_values(new_kwargs) new_kwargs = self.model.translate_columns_to_aliases(new_kwargs) return new_kwargs - def _populate_default_values(self, new_kwargs: dict) -> dict: - for field_name, field in self.model_meta.model_fields.items(): - if field_name not in new_kwargs and field.has_default(): - new_kwargs[field_name] = field.get_default() - return new_kwargs - def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: pkname = self.model_meta.pkname pk = self.model_meta.model_fields[pkname] @@ -300,6 +294,9 @@ class QuerySet: if pk and isinstance(pk, self.model.pk_type()): setattr(instance, self.model_meta.pkname, pk) + # refresh server side defaults + instance = await instance.load() + return instance async def bulk_create(self, objects: List["Model"]) -> None: diff --git a/tests/test_server_default.py b/tests/test_server_default.py new file mode 100644 index 0000000..205996b --- /dev/null +++ b/tests/test_server_default.py @@ -0,0 +1,80 @@ +import asyncio +import time +from datetime import datetime + +import databases +import pytest +import sqlalchemy +from sqlalchemy import func, text + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Product(ormar.Model): + class Meta: + tablename = "product" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + company: ormar.String(max_length=200, server_default='Acme') + sort_order: ormar.Integer(server_default=text("10")) + created: ormar.DateTime(server_default=func.now()) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_table_defined_properly(): + assert Product.Meta.model_fields['created'].nullable + assert not Product.__fields__['created'].required + assert Product.Meta.table.columns['created'].server_default.arg.name == 'now' + + +@pytest.mark.asyncio +async def test_model_creation(): + p1 = Product(name='Test') + assert p1.created is None + await p1.save() + await p1.load() + assert p1.created is not None + assert p1.company == 'Acme' + assert p1.sort_order == 10 + + date = datetime.strptime('2020-10-27 11:30', '%Y-%m-%d %H:%M') + p3 = await Product.objects.create(name='Test2', created=date, company='Roadrunner', sort_order=1) + assert p3.created is not None + assert p3.created == date + assert p1.created != p3.created + assert p3.company == 'Roadrunner' + assert p3.sort_order == 1 + + p3 = await Product.objects.get(name='Test2') + assert p3.company == 'Roadrunner' + assert p3.sort_order == 1 + + time.sleep(1) + + p2 = await Product.objects.create(name='Test3') + assert p2.created is not None + assert p1.created != p2.created + assert p2.company == 'Acme' + assert p2.sort_order == 10