allow fields redefining - check column names with names used in constraints

This commit is contained in:
collerek
2020-12-18 10:43:24 +01:00
parent e98300233e
commit 2d74b7bd47
4 changed files with 110 additions and 88 deletions

View File

@ -198,7 +198,15 @@ class BaseField(FieldInfo):
return False return False
@classmethod @classmethod
def construct_contraints(cls) -> List: def construct_constraints(cls) -> List:
"""
Converts list of ormar constraints into sqlalchemy ForeignKeys.
Has to be done dynamically as sqlalchemy binds ForeignKey to the table.
And we need a new ForeignKey for subclasses of current model
:return: List of sqlalchemy foreign keys - by default one.
:rtype: List[sqlalchemy.schema.ForeignKey]
"""
return [ return [
sqlalchemy.schema.ForeignKey( sqlalchemy.schema.ForeignKey(
con.name, ondelete=con.ondelete, onupdate=con.onupdate con.name, ondelete=con.ondelete, onupdate=con.onupdate
@ -221,7 +229,7 @@ class BaseField(FieldInfo):
return sqlalchemy.Column( return sqlalchemy.Column(
cls.alias or name, cls.alias or name,
cls.column_type, cls.column_type,
*cls.construct_contraints(), *cls.construct_constraints(),
primary_key=cls.primary_key, primary_key=cls.primary_key,
nullable=cls.nullable and not cls.primary_key, nullable=cls.nullable and not cls.primary_key,
index=cls.index, index=cls.index,

View File

@ -14,7 +14,6 @@ from typing import (
import databases import databases
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic.fields import FieldInfo
from sqlalchemy.sql.schema import ColumnCollectionConstraint from sqlalchemy.sql.schema import ColumnCollectionConstraint
import ormar # noqa I100 import ormar # noqa I100
@ -206,47 +205,13 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
new_model.Meta.signals = signals new_model.Meta.signals = signals
def check_conflicting_fields(
new_fields: Set,
attrs: Dict,
base_class: type,
curr_class: type,
previous_fields: Set = None,
) -> None:
"""
You cannot redefine fields with same names in inherited classes.
Ormar will raise an exception if it encounters a field that is an ormar
Field and at the same time was already declared in one of base classes.
:param previous_fields: set of names of fields defined in base model
:type previous_fields: Set[str]
:param new_fields: set of names of fields defined in current model
:type new_fields: Set[str]
:param attrs: namespace of current class
:type attrs: Dict
:param base_class: one of the parent classes
:type base_class: Model or model parent class
:param curr_class: current constructed class
:type curr_class: Model or model parent class
"""
if not previous_fields:
previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)})
overwrite = new_fields.intersection(previous_fields)
if overwrite:
raise ModelDefinitionError(
f"Model {curr_class} redefines the fields: "
f"{overwrite} already defined in {base_class}!"
)
def update_attrs_and_fields( def update_attrs_and_fields(
attrs: Dict, attrs: Dict,
new_attrs: Dict, new_attrs: Dict,
model_fields: Dict, model_fields: Dict,
new_model_fields: Dict, new_model_fields: Dict,
new_fields: Set, new_fields: Set,
) -> None: ) -> Dict:
""" """
Updates __annotations__, values of model fields (so pydantic FieldInfos) Updates __annotations__, values of model fields (so pydantic FieldInfos)
as well as model.Meta.model_fields definitions from parents. as well as model.Meta.model_fields definitions from parents.
@ -265,11 +230,43 @@ def update_attrs_and_fields(
key = "__annotations__" key = "__annotations__"
attrs[key].update(new_attrs[key]) attrs[key].update(new_attrs[key])
attrs.update({name: new_attrs[name] for name in new_fields}) attrs.update({name: new_attrs[name] for name in new_fields})
model_fields.update(new_model_fields) updated_model_fields = {k: v for k, v in new_model_fields.items()}
updated_model_fields.update(model_fields)
return updated_model_fields
def verify_constraint_names(
base_class: "Model", model_fields: Dict, parent_value: List
) -> None:
"""
Verifies if redefined fields that are overwritten in subclasses did not remove
any name of the column that is used in constraint as it will fail.
:param base_class: one of the parent classes
:type base_class: Model or model parent class
:param model_fields: ormar fields in defined in current class
:type model_fields: Dict[str, BaseField]
:param parent_value: list of base class constraints
:type parent_value: List
"""
new_aliases = {x.name: x.get_alias() for x in model_fields.values()}
old_aliases = {x.name: x.get_alias() for x in base_class.Meta.model_fields.values()}
old_aliases.update(new_aliases)
constraints_columns = [x._pending_colargs for x in parent_value]
for column_set in constraints_columns:
if any(x not in old_aliases.values() for x in column_set):
raise ModelDefinitionError(
f"Unique columns constraint "
f"{column_set} "
f"has column names "
f"that are not in the model fields."
f"\n Check columns redefined in subclasses "
f"to verify that they have proper 'name' set."
)
def update_attrs_from_base_meta( # noqa: CCR001 def update_attrs_from_base_meta( # noqa: CCR001
base_class: "Model", attrs: Dict, base_class: "Model", attrs: Dict, model_fields: Dict
) -> None: ) -> None:
""" """
Updates Meta parameters in child from parent if needed. Updates Meta parameters in child from parent if needed.
@ -278,7 +275,10 @@ def update_attrs_from_base_meta( # noqa: CCR001
:type base_class: Model or model parent class :type base_class: Model or model parent class
:param attrs: new namespace for class being constructed :param attrs: new namespace for class being constructed
:type attrs: Dict :type attrs: Dict
:param model_fields: ormar fields in defined in current class
:type model_fields: Dict[str, BaseField]
""" """
params_to_update = ["metadata", "database", "constraints"] params_to_update = ["metadata", "database", "constraints"]
for param in params_to_update: for param in params_to_update:
current_value = attrs.get("Meta", {}).__dict__.get(param, ormar.Undefined) current_value = attrs.get("Meta", {}).__dict__.get(param, ormar.Undefined)
@ -287,6 +287,11 @@ def update_attrs_from_base_meta( # noqa: CCR001
) )
if parent_value: if parent_value:
if param == "constraints": if param == "constraints":
verify_constraint_names(
base_class=base_class,
model_fields=model_fields,
parent_value=parent_value,
)
parent_value = [ parent_value = [
ormar.UniqueColumns(*x._pending_colargs) for x in parent_value ormar.UniqueColumns(*x._pending_colargs) for x in parent_value
] ]
@ -326,16 +331,7 @@ def copy_data_from_parent_model( # noqa: CCR001
:rtype: Tuple[Dict, Dict] :rtype: Tuple[Dict, Dict]
""" """
if attrs.get("Meta"): if attrs.get("Meta"):
new_fields = set(base_class.Meta.model_fields.keys()) # type: ignore if model_fields and not base_class.Meta.abstract: # type: ignore
previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)})
check_conflicting_fields(
new_fields=new_fields,
attrs=attrs,
base_class=base_class,
curr_class=curr_class,
previous_fields=previous_fields,
)
if previous_fields and not base_class.Meta.abstract: # type: ignore
raise ModelDefinitionError( raise ModelDefinitionError(
f"{curr_class.__name__} cannot inherit " f"{curr_class.__name__} cannot inherit "
f"from non abstract class {base_class.__name__}" f"from non abstract class {base_class.__name__}"
@ -343,6 +339,7 @@ def copy_data_from_parent_model( # noqa: CCR001
update_attrs_from_base_meta( update_attrs_from_base_meta(
base_class=base_class, # type: ignore base_class=base_class, # type: ignore
attrs=attrs, attrs=attrs,
model_fields=model_fields,
) )
parent_fields = dict() parent_fields = dict()
meta = attrs.get("Meta") meta = attrs.get("Meta")
@ -364,7 +361,8 @@ def copy_data_from_parent_model( # noqa: CCR001
else: else:
parent_fields[field_name] = field parent_fields[field_name] = field
model_fields.update(parent_fields) # type: ignore parent_fields.update(model_fields) # type: ignore
model_fields = parent_fields
return attrs, model_fields return attrs, model_fields
@ -416,14 +414,7 @@ def extract_from_parents_definition( # noqa: CCR001
new_attrs, new_model_fields = getattr(base_class, PARSED_FIELDS_KEY) new_attrs, new_model_fields = getattr(base_class, PARSED_FIELDS_KEY)
new_fields = set(new_model_fields.keys()) new_fields = set(new_model_fields.keys())
check_conflicting_fields( model_fields = update_attrs_and_fields(
new_fields=new_fields,
attrs=attrs,
base_class=base_class,
curr_class=curr_class,
)
update_attrs_and_fields(
attrs=attrs, attrs=attrs,
new_attrs=new_attrs, new_attrs=new_attrs,
model_fields=model_fields, model_fields=model_fields,
@ -435,23 +426,16 @@ def extract_from_parents_definition( # noqa: CCR001
potential_fields = get_potential_fields(base_class.__dict__) potential_fields = get_potential_fields(base_class.__dict__)
if potential_fields: if potential_fields:
# parent model has ormar fields defined and was not parsed before # parent model has ormar fields defined and was not parsed before
new_attrs = {key: base_class.__dict__.get(key, {})} new_attrs = {key: {k: v for k, v in base_class.__dict__.get(key, {}).items()}}
new_attrs.update(potential_fields) new_attrs.update(potential_fields)
new_fields = set(potential_fields.keys()) new_fields = set(potential_fields.keys())
check_conflicting_fields(
new_fields=new_fields,
attrs=attrs,
base_class=base_class,
curr_class=curr_class,
)
for name in new_fields: for name in new_fields:
delattr(base_class, name) delattr(base_class, name)
new_attrs, new_model_fields = extract_annotations_and_default_vals(new_attrs) new_attrs, new_model_fields = extract_annotations_and_default_vals(new_attrs)
setattr(base_class, PARSED_FIELDS_KEY, (new_attrs, new_model_fields)) setattr(base_class, PARSED_FIELDS_KEY, (new_attrs, new_model_fields))
model_fields = update_attrs_and_fields(
update_attrs_and_fields(
attrs=attrs, attrs=attrs,
new_attrs=new_attrs, new_attrs=new_attrs,
model_fields=model_fields, model_fields=model_fields,

View File

@ -4,6 +4,7 @@ from typing import Optional
import databases import databases
import pytest import pytest
import sqlalchemy
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import create_engine from sqlalchemy import create_engine
@ -132,17 +133,37 @@ def test_init_of_abstract_model():
DateFieldsModel() DateFieldsModel()
def test_field_redefining_raises_error(): def test_field_redefining_in_concrete_models():
class RedefinedField(DateFieldsModel):
class Meta(ormar.ModelMeta):
tablename = "redefines"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
created_date: str = ormar.String(max_length=200, name="creation_date")
changed_field = RedefinedField.Meta.model_fields["created_date"]
assert changed_field.default is None
assert changed_field.alias == "creation_date"
assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns)
assert isinstance(
RedefinedField.Meta.table.columns["creation_date"].type,
sqlalchemy.sql.sqltypes.String,
)
def test_model_subclassing_that_redefines_constraints_column_names():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class WrongField(DateFieldsModel): # pragma: no cover class WrongField2(DateFieldsModel): # pragma: no cover
class Meta(ormar.ModelMeta): class Meta(ormar.ModelMeta):
tablename = "wrongs" tablename = "wrongs"
metadata = metadata metadata = metadata
database = db database = db
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
created_date: datetime.datetime = ormar.DateTime() created_date: str = ormar.String(max_length=200)
def test_model_subclassing_non_abstract_raises_error(): def test_model_subclassing_non_abstract_raises_error():

View File

@ -4,6 +4,7 @@ from typing import Optional
import databases import databases
import pytest import pytest
import sqlalchemy
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import create_engine from sqlalchemy import create_engine
@ -55,17 +56,19 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
def test_field_redefining_raises_error(): def test_field_redefining():
with pytest.raises(ModelDefinitionError): class RedefinedField(ormar.Model, DateFieldsMixins):
class Meta(ormar.ModelMeta):
tablename = "redefined"
metadata = metadata
database = db
class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover id: int = ormar.Integer(primary_key=True)
class Meta(ormar.ModelMeta): created_date: datetime.datetime = ormar.DateTime(name="creation_date")
tablename = "wrongs"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True) assert RedefinedField.Meta.model_fields["created_date"].default is None
created_date: datetime.datetime = ormar.DateTime() assert RedefinedField.Meta.model_fields["created_date"].alias == "creation_date"
assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns)
def test_field_redefining_in_second_raises_error(): def test_field_redefining_in_second_raises_error():
@ -77,16 +80,22 @@ def test_field_redefining_in_second_raises_error():
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
with pytest.raises(ModelDefinitionError): class RedefinedField2(ormar.Model, DateFieldsMixins):
class Meta(ormar.ModelMeta):
tablename = "redefines2"
metadata = metadata
database = db
class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover id: int = ormar.Integer(primary_key=True)
class Meta(ormar.ModelMeta): created_date: str = ormar.String(max_length=200, name="creation_date")
tablename = "wrongs"
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True) assert RedefinedField2.Meta.model_fields["created_date"].default is None
created_date: datetime.datetime = ormar.DateTime() assert RedefinedField2.Meta.model_fields["created_date"].alias == "creation_date"
assert any(x.name == "creation_date" for x in RedefinedField2.Meta.table.columns)
assert isinstance(
RedefinedField2.Meta.table.columns["creation_date"].type,
sqlalchemy.sql.sqltypes.String,
)
def round_date_to_seconds( def round_date_to_seconds(