diff --git a/ormar/fields/base.py b/ormar/fields/base.py index db9ae57..8c25810 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -87,9 +87,9 @@ class BaseField(FieldInfo): :rtype: bool """ return ( - field_name not in ["default", "default_factory", "alias"] - and not field_name.startswith("__") - and hasattr(cls, field_name) + field_name not in ["default", "default_factory", "alias"] + and not field_name.startswith("__") + and hasattr(cls, field_name) ) @classmethod @@ -180,7 +180,7 @@ class BaseField(FieldInfo): :rtype: bool """ return cls.default is not None or ( - cls.server_default is not None and use_server + cls.server_default is not None and use_server ) @classmethod @@ -197,6 +197,12 @@ class BaseField(FieldInfo): return cls.autoincrement return False + @classmethod + def construct_contraints(cls) -> List: + return [sqlalchemy.schema.ForeignKey( + con.name, ondelete=con.ondelete, onupdate=con.onupdate + ) for con in cls.constraints] + @classmethod def get_column(cls, name: str) -> sqlalchemy.Column: """ @@ -212,7 +218,7 @@ class BaseField(FieldInfo): return sqlalchemy.Column( cls.alias or name, cls.column_type, - *cls.constraints, + *cls.construct_contraints(), primary_key=cls.primary_key, nullable=cls.nullable and not cls.primary_key, index=cls.index, @@ -223,11 +229,11 @@ class BaseField(FieldInfo): @classmethod def expand_relationship( - cls, - value: Any, - child: Union["Model", "NewBaseModel"], - to_register: bool = True, - relation_name: str = None, + cls, + value: Any, + child: Union["Model", "NewBaseModel"], + to_register: bool = True, + relation_name: str = None, ) -> Any: """ Function overwritten for relations, in basic field the value is returned as is. diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 272d34b..5ef8e78 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,4 +1,5 @@ import uuid +from dataclasses import dataclass from typing import Any, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy @@ -45,8 +46,8 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def create_dummy_model( - base_model: Type["Model"], - pk_field: Type[Union[BaseField, "ForeignKeyField", "ManyToManyField"]], + base_model: Type["Model"], + pk_field: Type[Union[BaseField, "ForeignKeyField", "ManyToManyField"]], ) -> Type["BaseModel"]: """ Used to construct a dummy pydantic model for type hints and pydantic validation. @@ -75,17 +76,24 @@ class UniqueColumns(UniqueConstraint): pass +@dataclass +class ForeignKeyConstraint: + name: str + ondelete: str + onupdate: str + + def ForeignKey( # noqa CFQ002 - to: Type["Model"], - *, - name: str = None, - unique: bool = False, - nullable: bool = True, - related_name: str = None, - virtual: bool = False, - onupdate: str = None, - ondelete: str = None, - **kwargs: Any, + to: Type["Model"], + *, + name: str = None, + unique: bool = False, + nullable: bool = True, + related_name: str = None, + virtual: bool = False, + onupdate: str = None, + ondelete: str = None, + **kwargs: Any, ) -> Any: """ Despite a name it's a function that returns constructed ForeignKeyField. @@ -132,9 +140,7 @@ def ForeignKey( # noqa CFQ002 name=kwargs.pop("real_name", None), nullable=nullable, constraints=[ - sqlalchemy.schema.ForeignKey( - fk_string, ondelete=ondelete, onupdate=onupdate - ) + ForeignKeyConstraint(name=fk_string, ondelete=ondelete, onupdate=onupdate) ], unique=unique, column_type=to_field.column_type, @@ -162,7 +168,7 @@ class ForeignKeyField(BaseField): @classmethod def _extract_model_from_sequence( - cls, value: List, child: "Model", to_register: bool, relation_name: str + cls, value: List, child: "Model", to_register: bool, relation_name: str ) -> List["Model"]: """ Takes a list of Models and registers them on parent. @@ -191,7 +197,7 @@ class ForeignKeyField(BaseField): @classmethod def _register_existing_model( - cls, value: "Model", child: "Model", to_register: bool, relation_name: str + cls, value: "Model", child: "Model", to_register: bool, relation_name: str ) -> "Model": """ Takes already created instance and registers it for parent. @@ -214,7 +220,7 @@ class ForeignKeyField(BaseField): @classmethod def _construct_model_from_dict( - cls, value: dict, child: "Model", to_register: bool, relation_name: str + cls, value: dict, child: "Model", to_register: bool, relation_name: str ) -> "Model": """ Takes a dictionary, creates a instance and registers it for parent. @@ -241,7 +247,7 @@ class ForeignKeyField(BaseField): @classmethod def _construct_model_from_pk( - cls, value: Any, child: "Model", to_register: bool, relation_name: str + cls, value: Any, child: "Model", to_register: bool, relation_name: str ) -> "Model": """ Takes a pk value, creates a dummy instance and registers it for parent. @@ -273,7 +279,7 @@ class ForeignKeyField(BaseField): @classmethod def register_relation( - cls, model: "Model", child: "Model", relation_name: str + cls, model: "Model", child: "Model", relation_name: str ) -> None: """ Registers relation between parent and child in relation manager. @@ -297,11 +303,11 @@ class ForeignKeyField(BaseField): @classmethod def expand_relationship( - cls, - value: Any, - child: Union["Model", "NewBaseModel"], - to_register: bool = True, - relation_name: str = None, + cls, + value: Any, + child: Union["Model", "NewBaseModel"], + to_register: bool = True, + relation_name: str = None, ) -> Optional[Union["Model", List["Model"]]]: """ For relations the child model is first constructed (if needed), diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index c2baf9e..b2b0110 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,3 +1,4 @@ +import copy import logging import warnings from typing import ( @@ -60,7 +61,7 @@ def register_relation_on_build_new(new_model: Type["Model"], field_name: str) -> def register_many_to_many_relation_on_build_new( - new_model: Type["Model"], field: Type[ManyToManyField] + new_model: Type["Model"], field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type_new( field.through, new_model.get_name(), is_multi=True @@ -71,11 +72,11 @@ def register_many_to_many_relation_on_build_new( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -86,7 +87,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -94,10 +95,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -119,10 +120,10 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], - child: Type["Model"], - model_field: Type[ManyToManyField], - child_model_name: str, + model: Type["Model"], + child: Type["Model"], + model_field: Type[ManyToManyField], + child_model_name: str, ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, real_name=model.get_name(), ondelete="CASCADE" @@ -139,7 +140,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -161,7 +162,7 @@ def get_pydantic_field(field_name: str, model: Type["Model"]) -> "ModelField": def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), @@ -177,7 +178,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -187,7 +188,7 @@ def check_pk_column_validity( def validate_related_names_in_relations( - model_fields: Dict, new_model: Type["Model"] + model_fields: Dict, new_model: Type["Model"] ) -> None: already_registered: Dict[str, List[Optional[str]]] = dict() for field in model_fields.values(): @@ -206,7 +207,7 @@ def validate_related_names_in_relations( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, new_model: Type["Model"] + model_fields: Dict, new_model: Type["Model"] ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -221,16 +222,16 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field.get_alias())) return pkname, columns def register_relation_in_alias_manager_new( - new_model: Type["Model"], field: Type[ForeignKeyField], field_name: str + new_model: Type["Model"], field: Type[ForeignKeyField], field_name: str ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build_new(new_model=new_model, field=field) @@ -239,7 +240,7 @@ def register_relation_in_alias_manager_new( def populate_default_pydantic_field_value( - ormar_field: Type[BaseField], field_name: str, attrs: dict + ormar_field: Type[BaseField], field_name: str, attrs: dict ) -> dict: curr_def_value = attrs.get(field_name, ormar.Undefined) if lenient_issubclass(curr_def_value, ormar.fields.BaseField): @@ -284,7 +285,7 @@ def extract_annotations_and_default_vals(attrs: dict) -> Tuple[Dict, Dict]: def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = ( @@ -309,7 +310,7 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: """ Constructs sqlalchemy table out of columns and parameters set on Meta class. @@ -400,7 +401,7 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001 def populate_default_options_values( - new_model: Type["Model"], model_fields: Dict + new_model: Type["Model"], model_fields: Dict ) -> None: """ Sets all optional Meta values to it's defaults @@ -522,11 +523,11 @@ def get_potential_fields(attrs: Dict) -> Dict: def check_conflicting_fields( - new_fields: Set, - attrs: Dict, - base_class: type, - curr_class: type, - previous_fields: Set = None, + new_fields: Set, + attrs: Dict, + base_class: type, + curr_class: type, + previous_fields: Set = None, ) -> None: """ You cannot redefine fields with same names in inherited classes. @@ -556,11 +557,11 @@ def check_conflicting_fields( def update_attrs_and_fields( - attrs: Dict, - new_attrs: Dict, - model_fields: Dict, - new_model_fields: Dict, - new_fields: Set, + attrs: Dict, + new_attrs: Dict, + model_fields: Dict, + new_model_fields: Dict, + new_fields: Set, ) -> None: """ Updates __annotations__, values of model fields (so pydantic FieldInfos) @@ -584,7 +585,7 @@ def update_attrs_and_fields( def update_attrs_from_base_meta( # noqa: CCR001 - base_class: "Model", attrs: Dict, + base_class: "Model", attrs: Dict, ) -> None: """ Updates Meta parameters in child from parent if needed. @@ -612,12 +613,12 @@ def update_attrs_from_base_meta( # noqa: CCR001 def extract_from_parents_definition( # noqa: CCR001 - base_class: type, - curr_class: type, - attrs: Dict, - model_fields: Dict[ - str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] - ], + base_class: type, + curr_class: type, + attrs: Dict, + model_fields: Dict[ + str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] + ], ) -> Tuple[Dict, Dict]: """ Extracts fields from base classes if they have valid oramr fields. @@ -664,7 +665,18 @@ def extract_from_parents_definition( # noqa: CCR001 base_class=base_class, # type: ignore attrs=attrs, ) - model_fields.update(base_class.Meta.model_fields) # type: ignore + parent_fields = dict() + table_name = attrs.get("Meta").tablename if hasattr(attrs.get("Meta"), "tablename") else attrs.get( + '__name__').lower() + 's' + for field_name, field in base_class.Meta.model_fields.items(): + if issubclass(field, ForeignKeyField) and field.related_name: + copy_field = type(field.__name__, (field,), dict(field.__dict__)) + copy_field.related_name = field.related_name + '_' + table_name + parent_fields[field_name] = copy_field + else: + parent_fields[field_name] = field + + model_fields.update(parent_fields) # type: ignore return attrs, model_fields key = "__annotations__" @@ -722,7 +734,7 @@ def extract_from_parents_definition( # noqa: CCR001 class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore # noqa: CCR001 - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name diff --git a/tests/test_inheritance_concrete.py b/tests/test_inheritance_concrete.py index 8cef051..2c221de 100644 --- a/tests/test_inheritance_concrete.py +++ b/tests/test_inheritance_concrete.py @@ -82,6 +82,44 @@ class Subject(DateFieldsModel): category: Optional[Category] = ormar.ForeignKey(Category) +class Person(ormar.Model): + class Meta: + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Car(ormar.Model): + class Meta: + abstract = True + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50) + owner: Person = ormar.ForeignKey(Person) + co_owner: Person = ormar.ForeignKey(Person, related_name='coowned') + + +class Truck(Car): + class Meta: + metadata = metadata + database = db + + max_capacity: int = ormar.Integer() + + +class Bus(Car): + class Meta: + tablename = 'buses' + metadata = metadata + database = db + + max_persons: int = ormar.Integer() + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): metadata.create_all(engine) @@ -96,7 +134,6 @@ def test_init_of_abstract_model(): def test_field_redefining_raises_error(): with pytest.raises(ModelDefinitionError): - class WrongField(DateFieldsModel): # pragma: no cover class Meta(ormar.ModelMeta): tablename = "wrongs" @@ -109,7 +146,6 @@ def test_field_redefining_raises_error(): def test_model_subclassing_non_abstract_raises_error(): with pytest.raises(ModelDefinitionError): - class WrongField2(DateFieldsModelNoSubclass): # pragma: no cover class Meta(ormar.ModelMeta): tablename = "wrongs" @@ -127,7 +163,7 @@ def test_params_are_inherited(): def round_date_to_seconds( - date: datetime.datetime, + date: datetime.datetime, ) -> datetime.datetime: # pragma: no cover if date.microsecond >= 500000: date = date + datetime.timedelta(seconds=1) @@ -170,9 +206,9 @@ async def test_fields_inherited_from_mixin(): sub2 = ( await Subject.objects.select_related("category") - .order_by("-created_date") - .exclude_fields("updated_date") - .get() + .order_by("-created_date") + .exclude_fields("updated_date") + .get() ) assert round_date_to_seconds(sub2.created_date) == round_date_to_seconds( sub.created_date @@ -187,9 +223,9 @@ async def test_fields_inherited_from_mixin(): sub3 = ( await Subject.objects.prefetch_related("category") - .order_by("-created_date") - .exclude_fields({"updated_date": ..., "category": {"updated_date"}}) - .get() + .order_by("-created_date") + .exclude_fields({"updated_date": ..., "category": {"updated_date"}}) + .get() ) assert round_date_to_seconds(sub3.created_date) == round_date_to_seconds( sub.created_date @@ -201,3 +237,21 @@ async def test_fields_inherited_from_mixin(): assert sub3.updated_date is None assert sub3.category.created_by == "Sam" assert sub3.category.updated_by == cat.updated_by + + +@pytest.mark.asyncio +async def test_inheritance_with_relation(): + async with db: + async with db.transaction(force_rollback=True): + sam = await Person(name='Sam').save() + joe = await Person(name='Joe').save() + await Truck(name='Shelby wanna be', max_capacity=1400, owner=sam, co_owner=joe).save() + + shelby = await Truck.objects.select_related(['owner', 'co_owner']).get() + assert shelby.name == 'Shelby wanna be' + assert shelby.owner.name == 'Sam' + assert shelby.co_owner.name == 'Joe' + + joe_check = await Person.objects.select_related('coowned_trucks').get(name='Joe') + assert joe_check.pk == joe.pk + assert joe_check.coowned_trucks[0] == shelby