added unique columns constraints to Meta options

This commit is contained in:
collerek
2020-10-01 11:42:20 +02:00
parent c4d1d00ad3
commit d0b6e75470
7 changed files with 107 additions and 42 deletions

BIN
.coverage

Binary file not shown.

View File

@ -394,8 +394,8 @@ All fields are required unless one of the following is set:
* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. * `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column.
Autoincrement is set by default on int primary keys. Autoincrement is set by default on int primary keys.
Available Model Fields: Available Model Fields (with required args - optional ones in docs):
* `String(length)` * `String(max_length)`
* `Text()` * `Text()`
* `Boolean()` * `Boolean()`
* `Integer()` * `Integer()`

View File

@ -14,6 +14,7 @@ from ormar.fields import (
Text, Text,
Time, Time,
UUID, UUID,
UniqueColumns,
) )
from ormar.models import Model from ormar.models import Model
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
@ -51,4 +52,5 @@ __all__ = [
"RelationType", "RelationType",
"Undefined", "Undefined",
"UUID", "UUID",
"UniqueColumns",
] ]

View File

@ -1,5 +1,5 @@
from ormar.fields.base import BaseField from ormar.fields.base import BaseField
from ormar.fields.foreign_key import ForeignKey from ormar.fields.foreign_key import ForeignKey, UniqueColumns
from ormar.fields.many_to_many import ManyToMany, ManyToManyField from ormar.fields.many_to_many import ManyToMany, ManyToManyField
from ormar.fields.model_fields import ( from ormar.fields.model_fields import (
BigInteger, BigInteger,
@ -33,4 +33,5 @@ __all__ = [
"ManyToMany", "ManyToMany",
"ManyToManyField", "ManyToManyField",
"BaseField", "BaseField",
"UniqueColumns",
] ]

View File

@ -1,6 +1,7 @@
from typing import Any, Generator, List, Optional, TYPE_CHECKING, Type, Union from typing import Any, Generator, List, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy import sqlalchemy
from sqlalchemy import UniqueConstraint
import ormar # noqa I101 import ormar # noqa I101
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
@ -22,16 +23,20 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
return fk(**init_dict) return fk(**init_dict)
class UniqueColumns(UniqueConstraint):
pass
def ForeignKey( # noqa CFQ002 def ForeignKey( # noqa CFQ002
to: Type["Model"], to: Type["Model"],
*, *,
name: str = None, name: str = None,
unique: bool = False, unique: bool = False,
nullable: bool = True, nullable: bool = True,
related_name: str = None, related_name: str = None,
virtual: bool = False, virtual: bool = False,
onupdate: str = None, onupdate: str = None,
ondelete: str = None, ondelete: str = None,
) -> Type["ForeignKeyField"]: ) -> Type["ForeignKeyField"]:
fk_string = to.Meta.tablename + "." + to.Meta.pkname fk_string = to.Meta.tablename + "." + to.Meta.pkname
to_field = to.__fields__[to.Meta.pkname] to_field = to.__fields__[to.Meta.pkname]
@ -74,7 +79,7 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def _extract_model_from_sequence( def _extract_model_from_sequence(
cls, value: List, child: "Model", to_register: bool cls, value: List, child: "Model", to_register: bool
) -> List["Model"]: ) -> List["Model"]:
return [ return [
cls.expand_relationship(val, child, to_register) # type: ignore cls.expand_relationship(val, child, to_register) # type: ignore
@ -83,7 +88,7 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def _register_existing_model( def _register_existing_model(
cls, value: "Model", child: "Model", to_register: bool cls, value: "Model", child: "Model", to_register: bool
) -> "Model": ) -> "Model":
if to_register: if to_register:
cls.register_relation(value, child) cls.register_relation(value, child)
@ -91,7 +96,7 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def _construct_model_from_dict( def _construct_model_from_dict(
cls, value: dict, child: "Model", to_register: bool cls, value: dict, child: "Model", to_register: bool
) -> "Model": ) -> "Model":
if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname:
value["__pk_only__"] = True value["__pk_only__"] = True
@ -102,7 +107,7 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def _construct_model_from_pk( def _construct_model_from_pk(
cls, value: Any, child: "Model", to_register: bool cls, value: Any, child: "Model", to_register: bool
) -> "Model": ) -> "Model":
if not isinstance(value, cls.to.pk_type()): if not isinstance(value, cls.to.pk_type()):
raise RelationshipInstanceError( raise RelationshipInstanceError(
@ -123,7 +128,7 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def expand_relationship( def expand_relationship(
cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
if value is None: if value is None:
return None if not cls.virtual else [] return None if not cls.virtual else []

View File

@ -6,6 +6,7 @@ import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseConfig from pydantic import BaseConfig
from pydantic.fields import FieldInfo, ModelField from pydantic.fields import FieldInfo, ModelField
from sqlalchemy.sql.schema import ColumnCollectionConstraint
import ormar # noqa I100 import ormar # noqa I100
from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100 from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100
@ -27,6 +28,7 @@ class ModelMeta:
metadata: sqlalchemy.MetaData metadata: sqlalchemy.MetaData
database: databases.Database database: databases.Database
columns: List[sqlalchemy.Column] columns: List[sqlalchemy.Column]
constraints: List[ColumnCollectionConstraint]
pkname: str pkname: str
model_fields: Dict[ model_fields: Dict[
str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
@ -39,7 +41,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) ->
def register_many_to_many_relation_on_build( def register_many_to_many_relation_on_build(
table_name: str, field: Type[ManyToManyField] table_name: str, field: Type[ManyToManyField]
) -> None: ) -> None:
alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
alias_manager.add_relation_type( alias_manager.add_relation_type(
@ -48,11 +50,11 @@ def register_many_to_many_relation_on_build(
def reverse_field_not_already_registered( def reverse_field_not_already_registered(
child: Type["Model"], child_model_name: str, parent_model: Type["Model"] child: Type["Model"], child_model_name: str, parent_model: Type["Model"]
) -> bool: ) -> bool:
return ( return (
child_model_name not in parent_model.__fields__ child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__ and child.get_name() not in parent_model.__fields__
) )
@ -63,7 +65,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
parent_model = model_field.to parent_model = model_field.to
child = model child = model
if reverse_field_not_already_registered( if reverse_field_not_already_registered(
child, child_model_name, parent_model child, child_model_name, parent_model
): ):
register_reverse_model_fields( register_reverse_model_fields(
parent_model, child, child_model_name, model_field parent_model, child, child_model_name, model_field
@ -71,10 +73,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
def register_reverse_model_fields( def register_reverse_model_fields(
model: Type["Model"], model: Type["Model"],
child: Type["Model"], child: Type["Model"],
child_model_name: str, child_model_name: str,
model_field: Type["ForeignKeyField"], model_field: Type["ForeignKeyField"],
) -> None: ) -> None:
if issubclass(model_field, ManyToManyField): if issubclass(model_field, ManyToManyField):
model.Meta.model_fields[child_model_name] = ManyToMany( model.Meta.model_fields[child_model_name] = ManyToMany(
@ -89,7 +91,7 @@ def register_reverse_model_fields(
def adjust_through_many_to_many_model( def adjust_through_many_to_many_model(
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField]
) -> None: ) -> None:
model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model_field.through.Meta.model_fields[model.get_name()] = ForeignKey(
model, name=model.get_name(), ondelete="CASCADE" model, name=model.get_name(), ondelete="CASCADE"
@ -106,7 +108,7 @@ def adjust_through_many_to_many_model(
def create_pydantic_field( def create_pydantic_field(
field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] field_name: str, model: Type["Model"], model_field: Type[ManyToManyField]
) -> None: ) -> None:
model_field.through.__fields__[field_name] = ModelField( model_field.through.__fields__[field_name] = ModelField(
name=field_name, name=field_name,
@ -118,7 +120,7 @@ def create_pydantic_field(
def create_and_append_m2m_fk( def create_and_append_m2m_fk(
model: Type["Model"], model_field: Type[ManyToManyField] model: Type["Model"], model_field: Type[ManyToManyField]
) -> None: ) -> None:
column = sqlalchemy.Column( column = sqlalchemy.Column(
model.get_name(), model.get_name(),
@ -134,7 +136,7 @@ def create_and_append_m2m_fk(
def check_pk_column_validity( def check_pk_column_validity(
field_name: str, field: BaseField, pkname: Optional[str] field_name: str, field: BaseField, pkname: Optional[str]
) -> Optional[str]: ) -> Optional[str]:
if pkname is not None: if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.") raise ModelDefinitionError("Only one primary key column is allowed.")
@ -144,7 +146,7 @@ def check_pk_column_validity(
def sqlalchemy_columns_from_model_fields( def sqlalchemy_columns_from_model_fields(
model_fields: Dict, table_name: str model_fields: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
columns = [] columns = []
pkname = None pkname = None
@ -158,9 +160,9 @@ def sqlalchemy_columns_from_model_fields(
if field.primary_key: if field.primary_key:
pkname = check_pk_column_validity(field_name, field, pkname) pkname = check_pk_column_validity(field_name, field, pkname)
if ( if (
not field.pydantic_only not field.pydantic_only
and not field.virtual and not field.virtual
and not issubclass(field, ManyToManyField) and not issubclass(field, ManyToManyField)
): ):
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
register_relation_in_alias_manager(table_name, field) register_relation_in_alias_manager(table_name, field)
@ -168,7 +170,7 @@ def sqlalchemy_columns_from_model_fields(
def register_relation_in_alias_manager( def register_relation_in_alias_manager(
table_name: str, field: Type[ForeignKeyField] table_name: str, field: Type[ForeignKeyField]
) -> None: ) -> None:
if issubclass(field, ManyToManyField): if issubclass(field, ManyToManyField):
register_many_to_many_relation_on_build(table_name, field) register_many_to_many_relation_on_build(table_name, field)
@ -177,7 +179,7 @@ def register_relation_in_alias_manager(
def populate_default_pydantic_field_value( def populate_default_pydantic_field_value(
type_: Type[BaseField], field: str, attrs: dict type_: Type[BaseField], field: str, attrs: dict
) -> dict: ) -> dict:
def_value = type_.default_value() def_value = type_.default_value()
curr_def_value = attrs.get(field, "NONE") curr_def_value = attrs.get(field, "NONE")
@ -206,7 +208,7 @@ def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict:
def populate_meta_orm_model_fields( def populate_meta_orm_model_fields(
attrs: dict, new_model: Type["Model"] attrs: dict, new_model: Type["Model"]
) -> Type["Model"]: ) -> Type["Model"]:
model_fields = { model_fields = {
field_name: field field_name: field
@ -218,7 +220,7 @@ def populate_meta_orm_model_fields(
def populate_meta_tablename_columns_and_pk( def populate_meta_tablename_columns_and_pk(
name: str, new_model: Type["Model"] name: str, new_model: Type["Model"]
) -> Type["Model"]: ) -> Type["Model"]:
tablename = name.lower() + "s" tablename = name.lower() + "s"
new_model.Meta.tablename = new_model.Meta.tablename or tablename new_model.Meta.tablename = new_model.Meta.tablename or tablename
@ -242,11 +244,11 @@ def populate_meta_tablename_columns_and_pk(
def populate_meta_sqlalchemy_table_if_required( def populate_meta_sqlalchemy_table_if_required(
new_model: Type["Model"], new_model: Type["Model"],
) -> Type["Model"]: ) -> Type["Model"]:
if not hasattr(new_model.Meta, "table"): if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table( new_model.Meta.table = sqlalchemy.Table(
new_model.Meta.tablename, new_model.Meta.metadata, *new_model.Meta.columns new_model.Meta.tablename, new_model.Meta.metadata, *new_model.Meta.columns, *new_model.Meta.constraints
) )
return new_model return new_model
@ -281,7 +283,7 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A
def populate_choices_validators( # noqa CCR001 def populate_choices_validators( # noqa CCR001
model: Type["Model"], attrs: Dict model: Type["Model"], attrs: Dict
) -> None: ) -> None:
if model_initialized_and_has_model_fields(model): if model_initialized_and_has_model_fields(model):
for _, field in model.Meta.model_fields.items(): for _, field in model.Meta.model_fields.items():
@ -294,7 +296,7 @@ def populate_choices_validators( # noqa CCR001
class ModelMetaclass(pydantic.main.ModelMetaclass): class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__( # type: ignore def __new__( # type: ignore
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
) -> "ModelMetaclass": ) -> "ModelMetaclass":
attrs["Config"] = get_pydantic_base_orm_config() attrs["Config"] = get_pydantic_base_orm_config()
attrs["__name__"] = name attrs["__name__"] = name
@ -304,6 +306,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
) )
if hasattr(new_model, "Meta"): if hasattr(new_model, "Meta"):
if not hasattr(new_model.Meta, 'constraints'):
new_model.Meta.constraints = []
new_model = populate_meta_orm_model_fields(attrs, new_model) new_model = populate_meta_orm_model_fields(attrs, new_model)
new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model)
new_model = populate_meta_sqlalchemy_table_if_required(new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model)

View File

@ -0,0 +1,53 @@
import asyncio
import sqlite3
import databases
import pytest
import sqlalchemy
from sqlalchemy.exc import IntegrityError
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Product(ormar.Model):
class Meta:
tablename = "products"
metadata = metadata
database = database
constraints = [ormar.UniqueColumns('name', 'company')]
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
company: ormar.String(max_length=200)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_unique_columns():
async with database:
async with database.transaction(force_rollback=True):
await Product.objects.create(name='Cookies', company='Nestle')
await Product.objects.create(name='Mars', company='Mars')
await Product.objects.create(name='Mars', company='Nestle')
with pytest.raises((IntegrityError, sqlite3.IntegrityError)):
await Product.objects.create(name='Mars', company='Mars')