From b6e057c303b99909ca20d177f2729993604a35af Mon Sep 17 00:00:00 2001 From: Sepehr Bazyar Date: Thu, 14 Jul 2022 12:35:30 +0430 Subject: [PATCH] CheckColumns Constraint (#730) * feat: add check columns class * feat: write document of check columns part * test: write a test for check columns constraints * fix: debuging test exception raise mysql * fix: set pragma no cover to ignore cov * fix: ignore pytest raise in python 3.x not 10 * feat: set constraint name for check columns * refactor: support index and check overwrites * fix: debuging check constraint arguments * fix: debug coverage all code tests * fix: pass the map of type constraint to counter * refactor: edit check name replace sapce underline * refactor: write new function copy constraints * test: write test for invalid constraint type * fix: debug text cluase replaced names * fix: set pragma no cover for result returned * refactor: no coverage for main if statement * perf: change get constraint copy func code * fix: fix bug in mypy typing check --- docs/models/index.md | 11 ++++ docs_src/models/docs018.py | 25 ++++++++ ormar/__init__.py | 2 + ormar/fields/__init__.py | 2 +- ormar/fields/constraints.py | 11 +++- ormar/models/helpers/sqlalchemy.py | 5 +- ormar/models/metaclass.py | 36 ++++++++++-- .../test_inheritance_concrete.py | 21 ++++++- .../test_check_constraints.py | 57 +++++++++++++++++++ 9 files changed, 160 insertions(+), 10 deletions(-) create mode 100644 docs_src/models/docs018.py create mode 100644 tests/test_meta_constraints/test_check_constraints.py diff --git a/docs/models/index.md b/docs/models/index.md index c99a71f..1f891a4 100644 --- a/docs/models/index.md +++ b/docs/models/index.md @@ -429,6 +429,17 @@ You can set this parameter by providing `Meta` class `constraints` argument. To set one column index use [`unique`](../fields/common-parameters.md#index) common parameter. Of course, you can set many columns as indexes with this param but each of them will be a separate index. +#### CheckColumns + +You can set this parameter by providing `Meta` class `constraints` argument. + +```Python hl_lines="14-17" +--8<-- "../docs_src/models/docs018.py" +``` + +!!!note + Note that some databases do not actively support check constraints such as MySQL. + ### Pydantic configuration diff --git a/docs_src/models/docs018.py b/docs_src/models/docs018.py new file mode 100644 index 0000000..d3aee4e --- /dev/null +++ b/docs_src/models/docs018.py @@ -0,0 +1,25 @@ +import datetime +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(ormar.Model): + class Meta: + database = database + metadata = metadata + # define your constraints in Meta class of the model + # it's a list that can contain multiple constraints + # hera a combination of name and column will have a level check in the db + constraints = [ + ormar.CheckColumns("start_time < end_time", name="date_check"), + ] + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + start_date: datetime.date = ormar.Date() + end_date: datetime.date = ormar.Date() diff --git a/ormar/__init__.py b/ormar/__init__.py index 1b61eb6..62eda35 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -70,6 +70,7 @@ from ormar.fields import ( Time, UUID, UniqueColumns, + CheckColumns, ) # noqa: I100 from ormar.models import ExcludableItems, Extra, Model from ormar.models.metaclass import ModelMeta @@ -112,6 +113,7 @@ __all__ = [ "UUID", "UniqueColumns", "IndexColumns", + "CheckColumns", "QuerySetProtocol", "RelationProtocol", "ModelMeta", diff --git a/ormar/fields/__init__.py b/ormar/fields/__init__.py index e55f2d3..42ee83a 100644 --- a/ormar/fields/__init__.py +++ b/ormar/fields/__init__.py @@ -5,7 +5,7 @@ as well as relation Fields (ForeignKey, ManyToMany). Also a definition for custom CHAR based sqlalchemy UUID field """ from ormar.fields.base import BaseField -from ormar.fields.constraints import IndexColumns, UniqueColumns +from ormar.fields.constraints import IndexColumns, UniqueColumns, CheckColumns from ormar.fields.foreign_key import ForeignKey, ForeignKeyField from ormar.fields.many_to_many import ManyToMany, ManyToManyField from ormar.fields.model_fields import ( diff --git a/ormar/fields/constraints.py b/ormar/fields/constraints.py index 1a0d52e..e799d0c 100644 --- a/ormar/fields/constraints.py +++ b/ormar/fields/constraints.py @@ -1,6 +1,6 @@ from typing import Any -from sqlalchemy import Index, UniqueConstraint +from sqlalchemy import Index, UniqueConstraint, CheckConstraint class UniqueColumns(UniqueConstraint): @@ -20,3 +20,12 @@ class IndexColumns(Index): Subclass of sqlalchemy.Index. Used to avoid importing anything from sqlalchemy by user. """ + + +class CheckColumns(CheckConstraint): + """ + Subclass of sqlalchemy.CheckConstraint. + Used to avoid importing anything from sqlalchemy by user. + + Note that some databases do not actively support check constraints such as MySQL. + """ diff --git a/ormar/models/helpers/sqlalchemy.py b/ormar/models/helpers/sqlalchemy.py index 22de0e3..b0ade1d 100644 --- a/ormar/models/helpers/sqlalchemy.py +++ b/ormar/models/helpers/sqlalchemy.py @@ -298,7 +298,7 @@ def populate_meta_sqlalchemy_table_if_required(meta: "ModelMeta") -> None: def set_constraint_names(meta: "ModelMeta") -> None: """ - Populates the names on IndexColumn and UniqueColumns constraints. + Populates the names on IndexColumns and UniqueColumns and CheckColumns constraints. :param meta: Meta class of the Model without sqlalchemy table constructed :type meta: Model class Meta @@ -317,6 +317,9 @@ def set_constraint_names(meta: "ModelMeta") -> None: f"ix_{meta.tablename}_" f'{"_".join([col for col in constraint._pending_colargs])}' ) + elif isinstance(constraint, sqlalchemy.CheckConstraint) and not constraint.name: + sql_condition: str = str(constraint.sqltext).replace(" ", "_") + constraint.name = f"check_{meta.tablename}_{sql_condition}" def update_column_definition( diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index bc22727..208e0eb 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -9,6 +9,7 @@ from typing import ( Type, Union, cast, + Callable, ) import databases @@ -18,6 +19,7 @@ from sqlalchemy.sql.schema import ColumnCollectionConstraint import ormar # noqa I100 import ormar.fields.constraints +from ormar.fields.constraints import UniqueColumns, IndexColumns, CheckColumns from ormar import ModelDefinitionError # noqa I100 from ormar.exceptions import ModelError from ormar.fields import BaseField @@ -186,7 +188,7 @@ def verify_constraint_names( for column_set in constraints_columns: if any(x not in old_aliases.values() for x in column_set): raise ModelDefinitionError( - f"Unique columns constraint " + f"Column constraints " f"{column_set} " f"has column names " f"that are not in the model fields." @@ -195,6 +197,33 @@ def verify_constraint_names( ) +def get_constraint_copy( + constraint: ColumnCollectionConstraint, +) -> Union[UniqueColumns, IndexColumns, CheckColumns]: + """ + Copy the constraint and unpacking it's values + + :raises ValueError: if non subclass of ColumnCollectionConstraint + :param value: an instance of the ColumnCollectionConstraint class + :type value: Instance of ColumnCollectionConstraint child + :return: copy ColumnCollectionConstraint ormar constraints + :rtype: Union[UniqueColumns, IndexColumns, CheckColumns] + """ + + constraints = { + sqlalchemy.UniqueConstraint: lambda x: UniqueColumns(*x._pending_colargs), + sqlalchemy.Index: lambda x: IndexColumns(*x._pending_colargs), + sqlalchemy.CheckConstraint: lambda x: CheckColumns(x.sqltext), + } + checks = (key if isinstance(constraint, key) else None for key in constraints) + target_class = next((target for target in checks if target is not None), None) + constructor: Optional[Callable] = constraints.get(target_class) + if not constructor: + raise ValueError(f"{constraint} must be a ColumnCollectionMixin!") + + return constructor(constraint) + + def update_attrs_from_base_meta( # noqa: CCR001 base_class: "Model", attrs: Dict, model_fields: Dict ) -> None: @@ -222,10 +251,7 @@ def update_attrs_from_base_meta( # noqa: CCR001 model_fields=model_fields, parent_value=parent_value, ) - parent_value = [ - ormar.fields.constraints.UniqueColumns(*x._pending_colargs) - for x in parent_value - ] + parent_value = [get_constraint_copy(value) for value in parent_value] if isinstance(current_value, list): current_value.extend(parent_value) else: diff --git a/tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py b/tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py index 1acb802..ccaf27f 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py +++ b/tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py @@ -1,6 +1,7 @@ # type: ignore import datetime from typing import List, Optional +from collections import Counter import databases import pytest @@ -11,6 +12,7 @@ import ormar import ormar.fields.constraints from ormar import ModelDefinitionError, property_field from ormar.exceptions import ModelError +from ormar.models.metaclass import get_constraint_copy from tests.settings import DATABASE_URL metadata = sa.MetaData() @@ -47,7 +49,13 @@ class DateFieldsModel(ormar.Model): metadata = metadata database = db constraints = [ - ormar.fields.constraints.UniqueColumns("creation_date", "modification_date") + ormar.fields.constraints.UniqueColumns( + "creation_date", + "modification_date", + ), + ormar.fields.constraints.CheckColumns( + "creation_date <= modification_date", + ), ] created_date: datetime.datetime = ormar.DateTime( @@ -234,9 +242,13 @@ def test_model_subclassing_non_abstract_raises_error(): def test_params_are_inherited(): assert Category.Meta.metadata == metadata assert Category.Meta.database == db - assert len(Category.Meta.constraints) == 2 assert len(Category.Meta.property_fields) == 2 + constraints = Counter(map(lambda c: type(c), Category.Meta.constraints)) + assert constraints[ormar.fields.constraints.UniqueColumns] == 2 + assert constraints[ormar.fields.constraints.IndexColumns] == 0 + assert constraints[ormar.fields.constraints.CheckColumns] == 1 + def round_date_to_seconds( date: datetime.datetime, @@ -519,3 +531,8 @@ def test_custom_config(): sam = ImmutablePerson(name="Sam") with pytest.raises(TypeError): sam.name = "Not Sam" + + +def test_get_constraint_copy(): + with pytest.raises(ValueError): + get_constraint_copy("INVALID CONSTRAINT") diff --git a/tests/test_meta_constraints/test_check_constraints.py b/tests/test_meta_constraints/test_check_constraints.py new file mode 100644 index 0000000..2fc69d4 --- /dev/null +++ b/tests/test_meta_constraints/test_check_constraints.py @@ -0,0 +1,57 @@ +import sqlite3 + +import asyncpg # type: ignore +import databases +import pytest +import sqlalchemy + +import ormar.fields.constraints +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Product(ormar.Model): + class Meta: + tablename = "products" + metadata = metadata + database = database + constraints = [ + ormar.fields.constraints.CheckColumns("inventory > buffer"), + ] + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + company: str = ormar.String(max_length=200) + inventory: int = ormar.Integer() + buffer: int = ormar.Integer() + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_check_columns_exclude_mysql(): + if Product.Meta.database._backend._dialect.name != "mysql": + async with database: # pragma: no cover + async with database.transaction(force_rollback=True): + await Product.objects.create( + name="Mars", company="Nestle", inventory=100, buffer=10 + ) + + with pytest.raises( + ( + sqlite3.IntegrityError, + asyncpg.exceptions.CheckViolationError, + ) + ): + await Product.objects.create( + name="Cookies", company="Nestle", inventory=1, buffer=10 + )