diff --git a/.coverage b/.coverage index 12fa709..d8930bb 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index 827ab19..0788ce2 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy from pydantic import BaseModel @@ -25,12 +25,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -49,21 +49,21 @@ class ForeignKey(BaseField): to_column = self.to.__model_fields__[self.to.__pkname__] return to_column.get_column_type() - def extract_model_from_sequence( - self, value: Any, child: "Model" + def _extract_model_from_sequence( + self, value: List, child: "Model" ) -> Union["Model", List["Model"]]: - if isinstance(value, list) and not isinstance(value, self.to): - model = [self.expand_relationship(val, child) for val in value] - return model + return [self.expand_relationship(val, child) for val in value] - if isinstance(value, self.to): - model = value - else: - model = self.to(**value) + def _register_existing_model(self, value: "Model", child: "Model") -> "Model": + self.register_relation(value, child) + return value + + def _construct_model_from_dict(self, value: dict, child: "Model") -> "Model": + model = self.to(**value) self.register_relation(model, child) return model - def construct_model_from_pk(self, value: Any, child: "Model") -> "Model": + def _construct_model_from_pk(self, value: Any, child: "Model") -> "Model": if not isinstance(value, self.to.pk_type()): raise RelationshipInstanceError( f"Relationship error - ForeignKey {self.to.__name__} " @@ -74,27 +74,23 @@ class ForeignKey(BaseField): self.register_relation(model, child) return model - def register_relation(self, model, child): - model._orm_relationship_manager.add_relation( - model, child, virtual=self.virtual - ) + def register_relation(self, model: "Model", child: "Model") -> None: + model._orm_relationship_manager.add_relation(model, child, virtual=self.virtual) def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None - if isinstance(value, orm.models.Model) and not isinstance(value, self.to): - raise RelationshipInstanceError( - f"Relationship error - expecting: {self.to.__name__}, " - f"but {value.__class__.__name__} encountered." - ) - - if isinstance(value, (dict, list, self.to)): - model = self.extract_model_from_sequence(value, child) - else: - model = self.construct_model_from_pk(value, child) + constructors = { + f"{self.to.__name__}": self._register_existing_model, + "dict": self._construct_model_from_dict, + "list": self._extract_model_from_sequence, + } + model = constructors.get( + value.__class__.__name__, self._construct_model_from_pk + )(value, child) return model