diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 41cfd19..957c9e5 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -28,15 +28,15 @@ class UniqueColumns(UniqueConstraint): def ForeignKey( # noqa CFQ002 - to: Type["Model"], - *, - name: str = None, - unique: bool = False, - nullable: bool = True, - related_name: str = None, - virtual: bool = False, - onupdate: str = None, - ondelete: str = None, + to: Type["Model"], + *, + name: str = None, + unique: bool = False, + nullable: bool = True, + related_name: str = None, + virtual: bool = False, + onupdate: str = None, + ondelete: str = None, ) -> Type["ForeignKeyField"]: fk_string = to.Meta.tablename + "." + to.Meta.pkname to_field = to.__fields__[to.Meta.pkname] @@ -79,7 +79,7 @@ class ForeignKeyField(BaseField): @classmethod def _extract_model_from_sequence( - cls, value: List, child: "Model", to_register: bool + cls, value: List, child: "Model", to_register: bool ) -> List["Model"]: return [ cls.expand_relationship(val, child, to_register) # type: ignore @@ -88,7 +88,7 @@ class ForeignKeyField(BaseField): @classmethod def _register_existing_model( - cls, value: "Model", child: "Model", to_register: bool + cls, value: "Model", child: "Model", to_register: bool ) -> "Model": if to_register: cls.register_relation(value, child) @@ -96,7 +96,7 @@ class ForeignKeyField(BaseField): @classmethod def _construct_model_from_dict( - cls, value: dict, child: "Model", to_register: bool + cls, value: dict, child: "Model", to_register: bool ) -> "Model": if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: value["__pk_only__"] = True @@ -107,7 +107,7 @@ class ForeignKeyField(BaseField): @classmethod def _construct_model_from_pk( - cls, value: Any, child: "Model", to_register: bool + cls, value: Any, child: "Model", to_register: bool ) -> "Model": if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( @@ -128,7 +128,7 @@ class ForeignKeyField(BaseField): @classmethod def expand_relationship( - cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True + cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None if not cls.virtual else [] diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index e8b9c5f..30d3071 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -41,7 +41,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] + table_name: str, field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( @@ -50,11 +50,11 @@ def register_many_to_many_relation_on_build( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -65,7 +65,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -73,10 +73,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -91,7 +91,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, name=model.get_name(), ondelete="CASCADE" @@ -108,7 +108,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -120,7 +120,7 @@ def create_pydantic_field( def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), @@ -136,7 +136,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -146,7 +146,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -160,9 +160,9 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field_name)) register_relation_in_alias_manager(table_name, field) @@ -170,7 +170,7 @@ def sqlalchemy_columns_from_model_fields( def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] + table_name: str, field: Type[ForeignKeyField] ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build(table_name, field) @@ -179,7 +179,7 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - type_: Type[BaseField], field: str, attrs: dict + type_: Type[BaseField], field: str, attrs: dict ) -> dict: def_value = type_.default_value() curr_def_value = attrs.get(field, "NONE") @@ -208,7 +208,7 @@ def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict: def populate_meta_orm_model_fields( - attrs: dict, new_model: Type["Model"] + attrs: dict, new_model: Type["Model"] ) -> Type["Model"]: model_fields = { field_name: field @@ -220,7 +220,7 @@ def populate_meta_orm_model_fields( def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = new_model.Meta.tablename or tablename @@ -244,11 +244,14 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: if not hasattr(new_model.Meta, "table"): new_model.Meta.table = sqlalchemy.Table( - new_model.Meta.tablename, new_model.Meta.metadata, *new_model.Meta.columns, *new_model.Meta.constraints + new_model.Meta.tablename, + new_model.Meta.metadata, + *new_model.Meta.columns, + *new_model.Meta.constraints, ) return new_model @@ -283,7 +286,7 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A def populate_choices_validators( # noqa CCR001 - model: Type["Model"], attrs: Dict + model: Type["Model"], attrs: Dict ) -> None: if model_initialized_and_has_model_fields(model): for _, field in model.Meta.model_fields.items(): @@ -296,7 +299,7 @@ def populate_choices_validators( # noqa CCR001 class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name @@ -306,7 +309,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): ) if hasattr(new_model, "Meta"): - if not hasattr(new_model.Meta, 'constraints'): + if not hasattr(new_model.Meta, "constraints"): new_model.Meta.constraints = [] new_model = populate_meta_orm_model_fields(attrs, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model) diff --git a/tests/test_unique_constraints.py b/tests/test_unique_constraints.py index 788ec9d..7908ae4 100644 --- a/tests/test_unique_constraints.py +++ b/tests/test_unique_constraints.py @@ -1,10 +1,11 @@ import asyncio import sqlite3 +import asyncpg import databases +import pymysql import pytest import sqlalchemy -from sqlalchemy.exc import IntegrityError import ormar from tests.settings import DATABASE_URL @@ -18,7 +19,7 @@ class Product(ormar.Model): tablename = "products" metadata = metadata database = database - constraints = [ormar.UniqueColumns('name', 'company')] + constraints = [ormar.UniqueColumns("name", "company")] id: ormar.Integer(primary_key=True) name: ormar.String(max_length=100) @@ -45,9 +46,15 @@ async def create_test_database(): async def test_unique_columns(): async with database: async with database.transaction(force_rollback=True): - await Product.objects.create(name='Cookies', company='Nestle') - await Product.objects.create(name='Mars', company='Mars') - await Product.objects.create(name='Mars', company='Nestle') + await Product.objects.create(name="Cookies", company="Nestle") + await Product.objects.create(name="Mars", company="Mars") + await Product.objects.create(name="Mars", company="Nestle") - with pytest.raises((IntegrityError, sqlite3.IntegrityError)): - await Product.objects.create(name='Mars', company='Mars') + with pytest.raises( + ( + sqlite3.IntegrityError, + pymysql.IntegrityError, + asyncpg.exceptions.UniqueViolationError, + ) + ): + await Product.objects.create(name="Mars", company="Mars")