allow fields redefining - check column names with names used in constraints
This commit is contained in:
@ -198,7 +198,15 @@ class BaseField(FieldInfo):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def construct_contraints(cls) -> List:
|
def construct_constraints(cls) -> List:
|
||||||
|
"""
|
||||||
|
Converts list of ormar constraints into sqlalchemy ForeignKeys.
|
||||||
|
Has to be done dynamically as sqlalchemy binds ForeignKey to the table.
|
||||||
|
And we need a new ForeignKey for subclasses of current model
|
||||||
|
|
||||||
|
:return: List of sqlalchemy foreign keys - by default one.
|
||||||
|
:rtype: List[sqlalchemy.schema.ForeignKey]
|
||||||
|
"""
|
||||||
return [
|
return [
|
||||||
sqlalchemy.schema.ForeignKey(
|
sqlalchemy.schema.ForeignKey(
|
||||||
con.name, ondelete=con.ondelete, onupdate=con.onupdate
|
con.name, ondelete=con.ondelete, onupdate=con.onupdate
|
||||||
@ -221,7 +229,7 @@ class BaseField(FieldInfo):
|
|||||||
return sqlalchemy.Column(
|
return sqlalchemy.Column(
|
||||||
cls.alias or name,
|
cls.alias or name,
|
||||||
cls.column_type,
|
cls.column_type,
|
||||||
*cls.construct_contraints(),
|
*cls.construct_constraints(),
|
||||||
primary_key=cls.primary_key,
|
primary_key=cls.primary_key,
|
||||||
nullable=cls.nullable and not cls.primary_key,
|
nullable=cls.nullable and not cls.primary_key,
|
||||||
index=cls.index,
|
index=cls.index,
|
||||||
|
|||||||
@ -14,7 +14,6 @@ from typing import (
|
|||||||
import databases
|
import databases
|
||||||
import pydantic
|
import pydantic
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
from sqlalchemy.sql.schema import ColumnCollectionConstraint
|
from sqlalchemy.sql.schema import ColumnCollectionConstraint
|
||||||
|
|
||||||
import ormar # noqa I100
|
import ormar # noqa I100
|
||||||
@ -206,47 +205,13 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
|
|||||||
new_model.Meta.signals = signals
|
new_model.Meta.signals = signals
|
||||||
|
|
||||||
|
|
||||||
def check_conflicting_fields(
|
|
||||||
new_fields: Set,
|
|
||||||
attrs: Dict,
|
|
||||||
base_class: type,
|
|
||||||
curr_class: type,
|
|
||||||
previous_fields: Set = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
You cannot redefine fields with same names in inherited classes.
|
|
||||||
Ormar will raise an exception if it encounters a field that is an ormar
|
|
||||||
Field and at the same time was already declared in one of base classes.
|
|
||||||
|
|
||||||
:param previous_fields: set of names of fields defined in base model
|
|
||||||
:type previous_fields: Set[str]
|
|
||||||
:param new_fields: set of names of fields defined in current model
|
|
||||||
:type new_fields: Set[str]
|
|
||||||
:param attrs: namespace of current class
|
|
||||||
:type attrs: Dict
|
|
||||||
:param base_class: one of the parent classes
|
|
||||||
:type base_class: Model or model parent class
|
|
||||||
:param curr_class: current constructed class
|
|
||||||
:type curr_class: Model or model parent class
|
|
||||||
"""
|
|
||||||
if not previous_fields:
|
|
||||||
previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)})
|
|
||||||
overwrite = new_fields.intersection(previous_fields)
|
|
||||||
|
|
||||||
if overwrite:
|
|
||||||
raise ModelDefinitionError(
|
|
||||||
f"Model {curr_class} redefines the fields: "
|
|
||||||
f"{overwrite} already defined in {base_class}!"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def update_attrs_and_fields(
|
def update_attrs_and_fields(
|
||||||
attrs: Dict,
|
attrs: Dict,
|
||||||
new_attrs: Dict,
|
new_attrs: Dict,
|
||||||
model_fields: Dict,
|
model_fields: Dict,
|
||||||
new_model_fields: Dict,
|
new_model_fields: Dict,
|
||||||
new_fields: Set,
|
new_fields: Set,
|
||||||
) -> None:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Updates __annotations__, values of model fields (so pydantic FieldInfos)
|
Updates __annotations__, values of model fields (so pydantic FieldInfos)
|
||||||
as well as model.Meta.model_fields definitions from parents.
|
as well as model.Meta.model_fields definitions from parents.
|
||||||
@ -265,11 +230,43 @@ def update_attrs_and_fields(
|
|||||||
key = "__annotations__"
|
key = "__annotations__"
|
||||||
attrs[key].update(new_attrs[key])
|
attrs[key].update(new_attrs[key])
|
||||||
attrs.update({name: new_attrs[name] for name in new_fields})
|
attrs.update({name: new_attrs[name] for name in new_fields})
|
||||||
model_fields.update(new_model_fields)
|
updated_model_fields = {k: v for k, v in new_model_fields.items()}
|
||||||
|
updated_model_fields.update(model_fields)
|
||||||
|
return updated_model_fields
|
||||||
|
|
||||||
|
|
||||||
|
def verify_constraint_names(
|
||||||
|
base_class: "Model", model_fields: Dict, parent_value: List
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Verifies if redefined fields that are overwritten in subclasses did not remove
|
||||||
|
any name of the column that is used in constraint as it will fail.
|
||||||
|
|
||||||
|
:param base_class: one of the parent classes
|
||||||
|
:type base_class: Model or model parent class
|
||||||
|
:param model_fields: ormar fields in defined in current class
|
||||||
|
:type model_fields: Dict[str, BaseField]
|
||||||
|
:param parent_value: list of base class constraints
|
||||||
|
:type parent_value: List
|
||||||
|
"""
|
||||||
|
new_aliases = {x.name: x.get_alias() for x in model_fields.values()}
|
||||||
|
old_aliases = {x.name: x.get_alias() for x in base_class.Meta.model_fields.values()}
|
||||||
|
old_aliases.update(new_aliases)
|
||||||
|
constraints_columns = [x._pending_colargs for x in parent_value]
|
||||||
|
for column_set in constraints_columns:
|
||||||
|
if any(x not in old_aliases.values() for x in column_set):
|
||||||
|
raise ModelDefinitionError(
|
||||||
|
f"Unique columns constraint "
|
||||||
|
f"{column_set} "
|
||||||
|
f"has column names "
|
||||||
|
f"that are not in the model fields."
|
||||||
|
f"\n Check columns redefined in subclasses "
|
||||||
|
f"to verify that they have proper 'name' set."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def update_attrs_from_base_meta( # noqa: CCR001
|
def update_attrs_from_base_meta( # noqa: CCR001
|
||||||
base_class: "Model", attrs: Dict,
|
base_class: "Model", attrs: Dict, model_fields: Dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Updates Meta parameters in child from parent if needed.
|
Updates Meta parameters in child from parent if needed.
|
||||||
@ -278,7 +275,10 @@ def update_attrs_from_base_meta( # noqa: CCR001
|
|||||||
:type base_class: Model or model parent class
|
:type base_class: Model or model parent class
|
||||||
:param attrs: new namespace for class being constructed
|
:param attrs: new namespace for class being constructed
|
||||||
:type attrs: Dict
|
:type attrs: Dict
|
||||||
|
:param model_fields: ormar fields in defined in current class
|
||||||
|
:type model_fields: Dict[str, BaseField]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params_to_update = ["metadata", "database", "constraints"]
|
params_to_update = ["metadata", "database", "constraints"]
|
||||||
for param in params_to_update:
|
for param in params_to_update:
|
||||||
current_value = attrs.get("Meta", {}).__dict__.get(param, ormar.Undefined)
|
current_value = attrs.get("Meta", {}).__dict__.get(param, ormar.Undefined)
|
||||||
@ -287,6 +287,11 @@ def update_attrs_from_base_meta( # noqa: CCR001
|
|||||||
)
|
)
|
||||||
if parent_value:
|
if parent_value:
|
||||||
if param == "constraints":
|
if param == "constraints":
|
||||||
|
verify_constraint_names(
|
||||||
|
base_class=base_class,
|
||||||
|
model_fields=model_fields,
|
||||||
|
parent_value=parent_value,
|
||||||
|
)
|
||||||
parent_value = [
|
parent_value = [
|
||||||
ormar.UniqueColumns(*x._pending_colargs) for x in parent_value
|
ormar.UniqueColumns(*x._pending_colargs) for x in parent_value
|
||||||
]
|
]
|
||||||
@ -326,16 +331,7 @@ def copy_data_from_parent_model( # noqa: CCR001
|
|||||||
:rtype: Tuple[Dict, Dict]
|
:rtype: Tuple[Dict, Dict]
|
||||||
"""
|
"""
|
||||||
if attrs.get("Meta"):
|
if attrs.get("Meta"):
|
||||||
new_fields = set(base_class.Meta.model_fields.keys()) # type: ignore
|
if model_fields and not base_class.Meta.abstract: # type: ignore
|
||||||
previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)})
|
|
||||||
check_conflicting_fields(
|
|
||||||
new_fields=new_fields,
|
|
||||||
attrs=attrs,
|
|
||||||
base_class=base_class,
|
|
||||||
curr_class=curr_class,
|
|
||||||
previous_fields=previous_fields,
|
|
||||||
)
|
|
||||||
if previous_fields and not base_class.Meta.abstract: # type: ignore
|
|
||||||
raise ModelDefinitionError(
|
raise ModelDefinitionError(
|
||||||
f"{curr_class.__name__} cannot inherit "
|
f"{curr_class.__name__} cannot inherit "
|
||||||
f"from non abstract class {base_class.__name__}"
|
f"from non abstract class {base_class.__name__}"
|
||||||
@ -343,6 +339,7 @@ def copy_data_from_parent_model( # noqa: CCR001
|
|||||||
update_attrs_from_base_meta(
|
update_attrs_from_base_meta(
|
||||||
base_class=base_class, # type: ignore
|
base_class=base_class, # type: ignore
|
||||||
attrs=attrs,
|
attrs=attrs,
|
||||||
|
model_fields=model_fields,
|
||||||
)
|
)
|
||||||
parent_fields = dict()
|
parent_fields = dict()
|
||||||
meta = attrs.get("Meta")
|
meta = attrs.get("Meta")
|
||||||
@ -364,7 +361,8 @@ def copy_data_from_parent_model( # noqa: CCR001
|
|||||||
else:
|
else:
|
||||||
parent_fields[field_name] = field
|
parent_fields[field_name] = field
|
||||||
|
|
||||||
model_fields.update(parent_fields) # type: ignore
|
parent_fields.update(model_fields) # type: ignore
|
||||||
|
model_fields = parent_fields
|
||||||
return attrs, model_fields
|
return attrs, model_fields
|
||||||
|
|
||||||
|
|
||||||
@ -416,14 +414,7 @@ def extract_from_parents_definition( # noqa: CCR001
|
|||||||
new_attrs, new_model_fields = getattr(base_class, PARSED_FIELDS_KEY)
|
new_attrs, new_model_fields = getattr(base_class, PARSED_FIELDS_KEY)
|
||||||
|
|
||||||
new_fields = set(new_model_fields.keys())
|
new_fields = set(new_model_fields.keys())
|
||||||
check_conflicting_fields(
|
model_fields = update_attrs_and_fields(
|
||||||
new_fields=new_fields,
|
|
||||||
attrs=attrs,
|
|
||||||
base_class=base_class,
|
|
||||||
curr_class=curr_class,
|
|
||||||
)
|
|
||||||
|
|
||||||
update_attrs_and_fields(
|
|
||||||
attrs=attrs,
|
attrs=attrs,
|
||||||
new_attrs=new_attrs,
|
new_attrs=new_attrs,
|
||||||
model_fields=model_fields,
|
model_fields=model_fields,
|
||||||
@ -435,23 +426,16 @@ def extract_from_parents_definition( # noqa: CCR001
|
|||||||
potential_fields = get_potential_fields(base_class.__dict__)
|
potential_fields = get_potential_fields(base_class.__dict__)
|
||||||
if potential_fields:
|
if potential_fields:
|
||||||
# parent model has ormar fields defined and was not parsed before
|
# parent model has ormar fields defined and was not parsed before
|
||||||
new_attrs = {key: base_class.__dict__.get(key, {})}
|
new_attrs = {key: {k: v for k, v in base_class.__dict__.get(key, {}).items()}}
|
||||||
new_attrs.update(potential_fields)
|
new_attrs.update(potential_fields)
|
||||||
|
|
||||||
new_fields = set(potential_fields.keys())
|
new_fields = set(potential_fields.keys())
|
||||||
check_conflicting_fields(
|
|
||||||
new_fields=new_fields,
|
|
||||||
attrs=attrs,
|
|
||||||
base_class=base_class,
|
|
||||||
curr_class=curr_class,
|
|
||||||
)
|
|
||||||
for name in new_fields:
|
for name in new_fields:
|
||||||
delattr(base_class, name)
|
delattr(base_class, name)
|
||||||
|
|
||||||
new_attrs, new_model_fields = extract_annotations_and_default_vals(new_attrs)
|
new_attrs, new_model_fields = extract_annotations_and_default_vals(new_attrs)
|
||||||
setattr(base_class, PARSED_FIELDS_KEY, (new_attrs, new_model_fields))
|
setattr(base_class, PARSED_FIELDS_KEY, (new_attrs, new_model_fields))
|
||||||
|
model_fields = update_attrs_and_fields(
|
||||||
update_attrs_and_fields(
|
|
||||||
attrs=attrs,
|
attrs=attrs,
|
||||||
new_attrs=new_attrs,
|
new_attrs=new_attrs,
|
||||||
model_fields=model_fields,
|
model_fields=model_fields,
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pytest
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
@ -132,17 +133,37 @@ def test_init_of_abstract_model():
|
|||||||
DateFieldsModel()
|
DateFieldsModel()
|
||||||
|
|
||||||
|
|
||||||
def test_field_redefining_raises_error():
|
def test_field_redefining_in_concrete_models():
|
||||||
|
class RedefinedField(DateFieldsModel):
|
||||||
|
class Meta(ormar.ModelMeta):
|
||||||
|
tablename = "redefines"
|
||||||
|
metadata = metadata
|
||||||
|
database = db
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
created_date: str = ormar.String(max_length=200, name="creation_date")
|
||||||
|
|
||||||
|
changed_field = RedefinedField.Meta.model_fields["created_date"]
|
||||||
|
assert changed_field.default is None
|
||||||
|
assert changed_field.alias == "creation_date"
|
||||||
|
assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns)
|
||||||
|
assert isinstance(
|
||||||
|
RedefinedField.Meta.table.columns["creation_date"].type,
|
||||||
|
sqlalchemy.sql.sqltypes.String,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_subclassing_that_redefines_constraints_column_names():
|
||||||
with pytest.raises(ModelDefinitionError):
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
class WrongField(DateFieldsModel): # pragma: no cover
|
class WrongField2(DateFieldsModel): # pragma: no cover
|
||||||
class Meta(ormar.ModelMeta):
|
class Meta(ormar.ModelMeta):
|
||||||
tablename = "wrongs"
|
tablename = "wrongs"
|
||||||
metadata = metadata
|
metadata = metadata
|
||||||
database = db
|
database = db
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
id: int = ormar.Integer(primary_key=True)
|
||||||
created_date: datetime.datetime = ormar.DateTime()
|
created_date: str = ormar.String(max_length=200)
|
||||||
|
|
||||||
|
|
||||||
def test_model_subclassing_non_abstract_raises_error():
|
def test_model_subclassing_non_abstract_raises_error():
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pytest
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
@ -55,17 +56,19 @@ def create_test_database():
|
|||||||
metadata.drop_all(engine)
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
def test_field_redefining_raises_error():
|
def test_field_redefining():
|
||||||
with pytest.raises(ModelDefinitionError):
|
class RedefinedField(ormar.Model, DateFieldsMixins):
|
||||||
|
class Meta(ormar.ModelMeta):
|
||||||
|
tablename = "redefined"
|
||||||
|
metadata = metadata
|
||||||
|
database = db
|
||||||
|
|
||||||
class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover
|
id: int = ormar.Integer(primary_key=True)
|
||||||
class Meta(ormar.ModelMeta):
|
created_date: datetime.datetime = ormar.DateTime(name="creation_date")
|
||||||
tablename = "wrongs"
|
|
||||||
metadata = metadata
|
|
||||||
database = db
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
assert RedefinedField.Meta.model_fields["created_date"].default is None
|
||||||
created_date: datetime.datetime = ormar.DateTime()
|
assert RedefinedField.Meta.model_fields["created_date"].alias == "creation_date"
|
||||||
|
assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns)
|
||||||
|
|
||||||
|
|
||||||
def test_field_redefining_in_second_raises_error():
|
def test_field_redefining_in_second_raises_error():
|
||||||
@ -77,16 +80,22 @@ def test_field_redefining_in_second_raises_error():
|
|||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
|
||||||
with pytest.raises(ModelDefinitionError):
|
class RedefinedField2(ormar.Model, DateFieldsMixins):
|
||||||
|
class Meta(ormar.ModelMeta):
|
||||||
|
tablename = "redefines2"
|
||||||
|
metadata = metadata
|
||||||
|
database = db
|
||||||
|
|
||||||
class WrongField(ormar.Model, DateFieldsMixins): # pragma: no cover
|
id: int = ormar.Integer(primary_key=True)
|
||||||
class Meta(ormar.ModelMeta):
|
created_date: str = ormar.String(max_length=200, name="creation_date")
|
||||||
tablename = "wrongs"
|
|
||||||
metadata = metadata
|
|
||||||
database = db
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
assert RedefinedField2.Meta.model_fields["created_date"].default is None
|
||||||
created_date: datetime.datetime = ormar.DateTime()
|
assert RedefinedField2.Meta.model_fields["created_date"].alias == "creation_date"
|
||||||
|
assert any(x.name == "creation_date" for x in RedefinedField2.Meta.table.columns)
|
||||||
|
assert isinstance(
|
||||||
|
RedefinedField2.Meta.table.columns["creation_date"].type,
|
||||||
|
sqlalchemy.sql.sqltypes.String,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def round_date_to_seconds(
|
def round_date_to_seconds(
|
||||||
|
|||||||
Reference in New Issue
Block a user