diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 3a824e2..704b88b 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -198,7 +198,15 @@ class BaseField(FieldInfo): return False @classmethod - def construct_contraints(cls) -> List: + def construct_constraints(cls) -> List: + """ + Converts list of ormar constraints into sqlalchemy ForeignKeys. + Has to be done dynamically as sqlalchemy binds ForeignKey to the table. + And we need a new ForeignKey for subclasses of current model + + :return: List of sqlalchemy foreign keys - by default one. + :rtype: List[sqlalchemy.schema.ForeignKey] + """ return [ sqlalchemy.schema.ForeignKey( con.name, ondelete=con.ondelete, onupdate=con.onupdate @@ -221,7 +229,7 @@ class BaseField(FieldInfo): return sqlalchemy.Column( cls.alias or name, cls.column_type, - *cls.construct_contraints(), + *cls.construct_constraints(), primary_key=cls.primary_key, nullable=cls.nullable and not cls.primary_key, index=cls.index, diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 8a12716..452fa47 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -14,7 +14,6 @@ from typing import ( import databases import pydantic import sqlalchemy -from pydantic.fields import FieldInfo from sqlalchemy.sql.schema import ColumnCollectionConstraint import ormar # noqa I100 @@ -206,47 +205,13 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001 new_model.Meta.signals = signals -def check_conflicting_fields( - new_fields: Set, - attrs: Dict, - base_class: type, - curr_class: type, - previous_fields: Set = None, -) -> None: - """ - You cannot redefine fields with same names in inherited classes. - Ormar will raise an exception if it encounters a field that is an ormar - Field and at the same time was already declared in one of base classes. - - :param previous_fields: set of names of fields defined in base model - :type previous_fields: Set[str] - :param new_fields: set of names of fields defined in current model - :type new_fields: Set[str] - :param attrs: namespace of current class - :type attrs: Dict - :param base_class: one of the parent classes - :type base_class: Model or model parent class - :param curr_class: current constructed class - :type curr_class: Model or model parent class - """ - if not previous_fields: - previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)}) - overwrite = new_fields.intersection(previous_fields) - - if overwrite: - raise ModelDefinitionError( - f"Model {curr_class} redefines the fields: " - f"{overwrite} already defined in {base_class}!" - ) - - def update_attrs_and_fields( attrs: Dict, new_attrs: Dict, model_fields: Dict, new_model_fields: Dict, new_fields: Set, -) -> None: +) -> Dict: """ Updates __annotations__, values of model fields (so pydantic FieldInfos) as well as model.Meta.model_fields definitions from parents. @@ -265,11 +230,43 @@ def update_attrs_and_fields( key = "__annotations__" attrs[key].update(new_attrs[key]) attrs.update({name: new_attrs[name] for name in new_fields}) - model_fields.update(new_model_fields) + updated_model_fields = {k: v for k, v in new_model_fields.items()} + updated_model_fields.update(model_fields) + return updated_model_fields + + +def verify_constraint_names( + base_class: "Model", model_fields: Dict, parent_value: List +) -> None: + """ + Verifies if redefined fields that are overwritten in subclasses did not remove + any name of the column that is used in constraint as it will fail. + + :param base_class: one of the parent classes + :type base_class: Model or model parent class + :param model_fields: ormar fields in defined in current class + :type model_fields: Dict[str, BaseField] + :param parent_value: list of base class constraints + :type parent_value: List + """ + new_aliases = {x.name: x.get_alias() for x in model_fields.values()} + old_aliases = {x.name: x.get_alias() for x in base_class.Meta.model_fields.values()} + old_aliases.update(new_aliases) + constraints_columns = [x._pending_colargs for x in parent_value] + for column_set in constraints_columns: + if any(x not in old_aliases.values() for x in column_set): + raise ModelDefinitionError( + f"Unique columns constraint " + f"{column_set} " + f"has column names " + f"that are not in the model fields." + f"\n Check columns redefined in subclasses " + f"to verify that they have proper 'name' set." + ) def update_attrs_from_base_meta( # noqa: CCR001 - base_class: "Model", attrs: Dict, + base_class: "Model", attrs: Dict, model_fields: Dict ) -> None: """ Updates Meta parameters in child from parent if needed. @@ -278,7 +275,10 @@ def update_attrs_from_base_meta( # noqa: CCR001 :type base_class: Model or model parent class :param attrs: new namespace for class being constructed :type attrs: Dict + :param model_fields: ormar fields in defined in current class + :type model_fields: Dict[str, BaseField] """ + params_to_update = ["metadata", "database", "constraints"] for param in params_to_update: current_value = attrs.get("Meta", {}).__dict__.get(param, ormar.Undefined) @@ -287,6 +287,11 @@ def update_attrs_from_base_meta( # noqa: CCR001 ) if parent_value: if param == "constraints": + verify_constraint_names( + base_class=base_class, + model_fields=model_fields, + parent_value=parent_value, + ) parent_value = [ ormar.UniqueColumns(*x._pending_colargs) for x in parent_value ] @@ -326,16 +331,7 @@ def copy_data_from_parent_model( # noqa: CCR001 :rtype: Tuple[Dict, Dict] """ if attrs.get("Meta"): - new_fields = set(base_class.Meta.model_fields.keys()) # type: ignore - previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)}) - check_conflicting_fields( - new_fields=new_fields, - attrs=attrs, - base_class=base_class, - curr_class=curr_class, - previous_fields=previous_fields, - ) - if previous_fields and not base_class.Meta.abstract: # type: ignore + if model_fields and not base_class.Meta.abstract: # type: ignore raise ModelDefinitionError( f"{curr_class.__name__} cannot inherit " f"from non abstract class {base_class.__name__}" @@ -343,6 +339,7 @@ def copy_data_from_parent_model( # noqa: CCR001 update_attrs_from_base_meta( base_class=base_class, # type: ignore attrs=attrs, + model_fields=model_fields, ) parent_fields = dict() meta = attrs.get("Meta") @@ -364,7 +361,8 @@ def copy_data_from_parent_model( # noqa: CCR001 else: parent_fields[field_name] = field - model_fields.update(parent_fields) # type: ignore + parent_fields.update(model_fields) # type: ignore + model_fields = parent_fields return attrs, model_fields @@ -416,14 +414,7 @@ def extract_from_parents_definition( # noqa: CCR001 new_attrs, new_model_fields = getattr(base_class, PARSED_FIELDS_KEY) new_fields = set(new_model_fields.keys()) - check_conflicting_fields( - new_fields=new_fields, - attrs=attrs, - base_class=base_class, - curr_class=curr_class, - ) - - update_attrs_and_fields( + model_fields = update_attrs_and_fields( attrs=attrs, new_attrs=new_attrs, model_fields=model_fields, @@ -435,23 +426,16 @@ def extract_from_parents_definition( # noqa: CCR001 potential_fields = get_potential_fields(base_class.__dict__) if potential_fields: # parent model has ormar fields defined and was not parsed before - new_attrs = {key: base_class.__dict__.get(key, {})} + new_attrs = {key: {k: v for k, v in base_class.__dict__.get(key, {}).items()}} new_attrs.update(potential_fields) new_fields = set(potential_fields.keys()) - check_conflicting_fields( - new_fields=new_fields, - attrs=attrs, - base_class=base_class, - curr_class=curr_class, - ) for name in new_fields: delattr(base_class, name) new_attrs, new_model_fields = extract_annotations_and_default_vals(new_attrs) setattr(base_class, PARSED_FIELDS_KEY, (new_attrs, new_model_fields)) - - update_attrs_and_fields( + model_fields = update_attrs_and_fields( attrs=attrs, new_attrs=new_attrs, model_fields=model_fields, diff --git a/tests/test_inheritance_concrete.py b/tests/test_inheritance_concrete.py index 9211102..f8f9ce4 100644 --- a/tests/test_inheritance_concrete.py +++ b/tests/test_inheritance_concrete.py @@ -4,6 +4,7 @@ from typing import Optional import databases import pytest +import sqlalchemy import sqlalchemy as sa from sqlalchemy import create_engine @@ -132,17 +133,37 @@ def test_init_of_abstract_model(): DateFieldsModel() -def test_field_redefining_raises_error(): +def test_field_redefining_in_concrete_models(): + class RedefinedField(DateFieldsModel): + class Meta(ormar.ModelMeta): + tablename = "redefines" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + created_date: str = ormar.String(max_length=200, name="creation_date") + + changed_field = RedefinedField.Meta.model_fields["created_date"] + assert changed_field.default is None + assert changed_field.alias == "creation_date" + assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns) + assert isinstance( + RedefinedField.Meta.table.columns["creation_date"].type, + sqlalchemy.sql.sqltypes.String, + ) + + +def test_model_subclassing_that_redefines_constraints_column_names(): with pytest.raises(ModelDefinitionError): - class WrongField(DateFieldsModel): # pragma: no cover + class WrongField2(DateFieldsModel): # pragma: no cover class Meta(ormar.ModelMeta): tablename = "wrongs" metadata = metadata database = db id: int = ormar.Integer(primary_key=True) - created_date: datetime.datetime = ormar.DateTime() + created_date: str = ormar.String(max_length=200) def test_model_subclassing_non_abstract_raises_error(): diff --git a/tests/test_inheritance_mixins.py b/tests/test_inheritance_mixins.py index 2c93b90..6a580a1 100644 --- a/tests/test_inheritance_mixins.py +++ b/tests/test_inheritance_mixins.py @@ -4,6 +4,7 @@ from typing import Optional import databases import pytest +import sqlalchemy import sqlalchemy as sa from sqlalchemy import create_engine @@ -55,17 +56,19 @@ def create_test_database(): metadata.drop_all(engine) -def test_field_redefining_raises_error(): - with pytest.raises(ModelDefinitionError): +def test_field_redefining(): + class RedefinedField(ormar.Model, DateFieldsMixins): + class Meta(ormar.ModelMeta): + tablename = "redefined" + metadata = metadata + database = db - class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover - class Meta(ormar.ModelMeta): - tablename = "wrongs" - metadata = metadata - database = db + id: int = ormar.Integer(primary_key=True) + created_date: datetime.datetime = ormar.DateTime(name="creation_date") - id: int = ormar.Integer(primary_key=True) - created_date: datetime.datetime = ormar.DateTime() + assert RedefinedField.Meta.model_fields["created_date"].default is None + assert RedefinedField.Meta.model_fields["created_date"].alias == "creation_date" + assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns) def test_field_redefining_in_second_raises_error(): @@ -77,16 +80,22 @@ def test_field_redefining_in_second_raises_error(): id: int = ormar.Integer(primary_key=True) - with pytest.raises(ModelDefinitionError): + class RedefinedField2(ormar.Model, DateFieldsMixins): + class Meta(ormar.ModelMeta): + tablename = "redefines2" + metadata = metadata + database = db - class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover - class Meta(ormar.ModelMeta): - tablename = "wrongs" - metadata = metadata - database = db + id: int = ormar.Integer(primary_key=True) + created_date: str = ormar.String(max_length=200, name="creation_date") - id: int = ormar.Integer(primary_key=True) - created_date: datetime.datetime = ormar.DateTime() + assert RedefinedField2.Meta.model_fields["created_date"].default is None + assert RedefinedField2.Meta.model_fields["created_date"].alias == "creation_date" + assert any(x.name == "creation_date" for x in RedefinedField2.Meta.table.columns) + assert isinstance( + RedefinedField2.Meta.table.columns["creation_date"].type, + sqlalchemy.sql.sqltypes.String, + ) def round_date_to_seconds(