diff --git a/.coverage b/.coverage index 105fceb..12fa709 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index b5659b4..827ab19 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -25,12 +25,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -50,36 +50,42 @@ class ForeignKey(BaseField): return to_column.get_column_type() def extract_model_from_sequence( - self, value: Any, child: "Model" - ) -> Tuple[Union["Model", List["Model"]], bool]: + self, value: Any, child: "Model" + ) -> Union["Model", List["Model"]]: if isinstance(value, list) and not isinstance(value, self.to): model = [self.expand_relationship(val, child) for val in value] - return model, True + return model if isinstance(value, self.to): model = value else: model = self.to(**value) - return model, False + self.register_relation(model, child) + return model - def construct_model_from_pk(self, value: Any) -> "Model": + def construct_model_from_pk(self, value: Any, child: "Model") -> "Model": if not isinstance(value, self.to.pk_type()): raise RelationshipInstanceError( f"Relationship error - ForeignKey {self.to.__name__} " f"is of type {self.to.pk_type()} " f"while {type(value)} passed as a parameter." ) - return create_dummy_instance(fk=self.to, pk=value) + model = create_dummy_instance(fk=self.to, pk=value) + self.register_relation(model, child) + return model + + def register_relation(self, model, child): + model._orm_relationship_manager.add_relation( + model, child, virtual=self.virtual + ) def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None - is_sequence = False - if isinstance(value, orm.models.Model) and not isinstance(value, self.to): raise RelationshipInstanceError( f"Relationship error - expecting: {self.to.__name__}, " @@ -87,13 +93,8 @@ class ForeignKey(BaseField): ) if isinstance(value, (dict, list, self.to)): - model, is_sequence = self.extract_model_from_sequence(value, child) + model = self.extract_model_from_sequence(value, child) else: - model = self.construct_model_from_pk(value) - - if not is_sequence: - model._orm_relationship_manager.add_relation( - model, child, virtual=self.virtual - ) + model = self.construct_model_from_pk(value, child) return model