diff --git a/.coverage b/.coverage index 52c345a..105fceb 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/__init__.py b/orm/__init__.py index bcc1d9c..2cf9401 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -6,13 +6,13 @@ from orm.fields import ( DateTime, Decimal, Float, + ForeignKey, Integer, JSON, String, Text, Time, ) -from orm.fields.foreign_key import ForeignKey from orm.models import Model __version__ = "0.0.1" diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index 506a9ac..b5659b4 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union import sqlalchemy from pydantic import BaseModel @@ -49,6 +49,28 @@ class ForeignKey(BaseField): to_column = self.to.__model_fields__[self.to.__pkname__] return to_column.get_column_type() + def extract_model_from_sequence( + self, value: Any, child: "Model" + ) -> Tuple[Union["Model", List["Model"]], bool]: + if isinstance(value, list) and not isinstance(value, self.to): + model = [self.expand_relationship(val, child) for val in value] + return model, True + + if isinstance(value, self.to): + model = value + else: + model = self.to(**value) + return model, False + + def construct_model_from_pk(self, value: Any) -> "Model": + if not isinstance(value, self.to.pk_type()): + raise RelationshipInstanceError( + f"Relationship error - ForeignKey {self.to.__name__} " + f"is of type {self.to.pk_type()} " + f"while {type(value)} passed as a parameter." + ) + return create_dummy_instance(fk=self.to, pk=value) + def expand_relationship( self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: @@ -56,31 +78,22 @@ class ForeignKey(BaseField): if value is None: return None + is_sequence = False + if isinstance(value, orm.models.Model) and not isinstance(value, self.to): raise RelationshipInstanceError( f"Relationship error - expecting: {self.to.__name__}, " f"but {value.__class__.__name__} encountered." ) - if isinstance(value, list) and not isinstance(value, self.to): - model = [self.expand_relationship(val, child) for val in value] - return model - - if isinstance(value, self.to): - model = value - elif isinstance(value, dict): - model = self.to(**value) + if isinstance(value, (dict, list, self.to)): + model, is_sequence = self.extract_model_from_sequence(value, child) else: - if not isinstance(value, self.to.pk_type()): - raise RelationshipInstanceError( - f"Relationship error - ForeignKey {self.to.__name__} " - f"is of type {self.to.pk_type()} " - f"while {type(value)} passed as a parameter." - ) - model = create_dummy_instance(fk=self.to, pk=value) + model = self.construct_model_from_pk(value) - model._orm_relationship_manager.add_relation( - model, child, virtual=self.virtual, - ) + if not is_sequence: + model._orm_relationship_manager.add_relation( + model, child, virtual=self.virtual + ) return model