diff --git a/.coverage b/.coverage index de993f7..6da7884 100644 Binary files a/.coverage and b/.coverage differ diff --git a/README.md b/README.md index 5c5701a..c66305c 100644 --- a/README.md +++ b/README.md @@ -394,8 +394,8 @@ All fields are required unless one of the following is set: * `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. Autoincrement is set by default on int primary keys. -Available Model Fields: -* `String(length)` +Available Model Fields (with required args - optional ones in docs): +* `String(max_length)` * `Text()` * `Boolean()` * `Integer()` diff --git a/ormar/__init__.py b/ormar/__init__.py index c3d1ed7..1616879 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -14,6 +14,7 @@ from ormar.fields import ( Text, Time, UUID, + UniqueColumns, ) from ormar.models import Model from ormar.queryset import QuerySet @@ -51,4 +52,5 @@ __all__ = [ "RelationType", "Undefined", "UUID", + "UniqueColumns", ] diff --git a/ormar/fields/__init__.py b/ormar/fields/__init__.py index 0035a4f..325fcf6 100644 --- a/ormar/fields/__init__.py +++ b/ormar/fields/__init__.py @@ -1,5 +1,5 @@ from ormar.fields.base import BaseField -from ormar.fields.foreign_key import ForeignKey +from ormar.fields.foreign_key import ForeignKey, UniqueColumns from ormar.fields.many_to_many import ManyToMany, ManyToManyField from ormar.fields.model_fields import ( BigInteger, @@ -33,4 +33,5 @@ __all__ = [ "ManyToMany", "ManyToManyField", "BaseField", + "UniqueColumns", ] diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 84d9110..41cfd19 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,6 +1,7 @@ from typing import Any, Generator, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy +from sqlalchemy import UniqueConstraint import ormar # noqa I101 from ormar.exceptions import RelationshipInstanceError @@ -22,16 +23,20 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": return fk(**init_dict) +class UniqueColumns(UniqueConstraint): + pass + + def ForeignKey( # noqa CFQ002 - to: Type["Model"], - *, - name: str = None, - unique: bool = False, - nullable: bool = True, - related_name: str = None, - virtual: bool = False, - onupdate: str = None, - ondelete: str = None, + to: Type["Model"], + *, + name: str = None, + unique: bool = False, + nullable: bool = True, + related_name: str = None, + virtual: bool = False, + onupdate: str = None, + ondelete: str = None, ) -> Type["ForeignKeyField"]: fk_string = to.Meta.tablename + "." + to.Meta.pkname to_field = to.__fields__[to.Meta.pkname] @@ -74,7 +79,7 @@ class ForeignKeyField(BaseField): @classmethod def _extract_model_from_sequence( - cls, value: List, child: "Model", to_register: bool + cls, value: List, child: "Model", to_register: bool ) -> List["Model"]: return [ cls.expand_relationship(val, child, to_register) # type: ignore @@ -83,7 +88,7 @@ class ForeignKeyField(BaseField): @classmethod def _register_existing_model( - cls, value: "Model", child: "Model", to_register: bool + cls, value: "Model", child: "Model", to_register: bool ) -> "Model": if to_register: cls.register_relation(value, child) @@ -91,7 +96,7 @@ class ForeignKeyField(BaseField): @classmethod def _construct_model_from_dict( - cls, value: dict, child: "Model", to_register: bool + cls, value: dict, child: "Model", to_register: bool ) -> "Model": if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: value["__pk_only__"] = True @@ -102,7 +107,7 @@ class ForeignKeyField(BaseField): @classmethod def _construct_model_from_pk( - cls, value: Any, child: "Model", to_register: bool + cls, value: Any, child: "Model", to_register: bool ) -> "Model": if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( @@ -123,7 +128,7 @@ class ForeignKeyField(BaseField): @classmethod def expand_relationship( - cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True + cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None if not cls.virtual else [] diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 0ba7840..e8b9c5f 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -6,6 +6,7 @@ import pydantic import sqlalchemy from pydantic import BaseConfig from pydantic.fields import FieldInfo, ModelField +from sqlalchemy.sql.schema import ColumnCollectionConstraint import ormar # noqa I100 from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100 @@ -27,6 +28,7 @@ class ModelMeta: metadata: sqlalchemy.MetaData database: databases.Database columns: List[sqlalchemy.Column] + constraints: List[ColumnCollectionConstraint] pkname: str model_fields: Dict[ str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] @@ -39,7 +41,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] + table_name: str, field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( @@ -48,11 +50,11 @@ def register_many_to_many_relation_on_build( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -63,7 +65,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -71,10 +73,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -89,7 +91,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, name=model.get_name(), ondelete="CASCADE" @@ -106,7 +108,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -118,7 +120,7 @@ def create_pydantic_field( def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), @@ -134,7 +136,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -144,7 +146,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -158,9 +160,9 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field_name)) register_relation_in_alias_manager(table_name, field) @@ -168,7 +170,7 @@ def sqlalchemy_columns_from_model_fields( def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] + table_name: str, field: Type[ForeignKeyField] ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build(table_name, field) @@ -177,7 +179,7 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - type_: Type[BaseField], field: str, attrs: dict + type_: Type[BaseField], field: str, attrs: dict ) -> dict: def_value = type_.default_value() curr_def_value = attrs.get(field, "NONE") @@ -206,7 +208,7 @@ def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict: def populate_meta_orm_model_fields( - attrs: dict, new_model: Type["Model"] + attrs: dict, new_model: Type["Model"] ) -> Type["Model"]: model_fields = { field_name: field @@ -218,7 +220,7 @@ def populate_meta_orm_model_fields( def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = new_model.Meta.tablename or tablename @@ -242,11 +244,11 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: if not hasattr(new_model.Meta, "table"): new_model.Meta.table = sqlalchemy.Table( - new_model.Meta.tablename, new_model.Meta.metadata, *new_model.Meta.columns + new_model.Meta.tablename, new_model.Meta.metadata, *new_model.Meta.columns, *new_model.Meta.constraints ) return new_model @@ -281,7 +283,7 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A def populate_choices_validators( # noqa CCR001 - model: Type["Model"], attrs: Dict + model: Type["Model"], attrs: Dict ) -> None: if model_initialized_and_has_model_fields(model): for _, field in model.Meta.model_fields.items(): @@ -294,7 +296,7 @@ def populate_choices_validators( # noqa CCR001 class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name @@ -304,6 +306,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): ) if hasattr(new_model, "Meta"): + if not hasattr(new_model.Meta, 'constraints'): + new_model.Meta.constraints = [] new_model = populate_meta_orm_model_fields(attrs, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model) diff --git a/tests/test_unique_constraints.py b/tests/test_unique_constraints.py new file mode 100644 index 0000000..788ec9d --- /dev/null +++ b/tests/test_unique_constraints.py @@ -0,0 +1,53 @@ +import asyncio +import sqlite3 + +import databases +import pytest +import sqlalchemy +from sqlalchemy.exc import IntegrityError + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Product(ormar.Model): + class Meta: + tablename = "products" + metadata = metadata + database = database + constraints = [ormar.UniqueColumns('name', 'company')] + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + company: ormar.String(max_length=200) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_unique_columns(): + async with database: + async with database.transaction(force_rollback=True): + await Product.objects.create(name='Cookies', company='Nestle') + await Product.objects.create(name='Mars', company='Mars') + await Product.objects.create(name='Mars', company='Nestle') + + with pytest.raises((IntegrityError, sqlite3.IntegrityError)): + await Product.objects.create(name='Mars', company='Mars')