diff --git a/.coverage b/.coverage index 26e4f9c..80aea2e 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index 1496fc9..a5e65dd 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Type, List, Any, Union, TYPE_CHECKING +from typing import Type, List, Any, Union, TYPE_CHECKING, Optional import sqlalchemy from pydantic import BaseModel @@ -12,9 +12,8 @@ if TYPE_CHECKING: # pragma no cover def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": - init_dict = {fk.__pkname__: pk or -1} init_dict = { - **init_dict, + **{fk.__pkname__: pk or -1}, **{ k: create_dummy_instance(v.to) for k, v in fk.__model_fields__.items() @@ -26,12 +25,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -51,8 +50,11 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship( - self, value: Any, child: "Model" - ) -> Union["Model", List["Model"]]: + self, value: Any, child: "Model" + ) -> Optional[Union["Model", List["Model"]]]: + + if value is None: + return None if isinstance(value, orm.models.Model) and not isinstance(value, self.to): raise RelationshipInstanceError( @@ -77,15 +79,10 @@ class ForeignKey(BaseField): ) model = create_dummy_instance(fk=self.to, pk=value) - self.add_to_relationship_registry(model, child) - - return model - - def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: model._orm_relationship_manager.add_relation( - model.__class__.__name__.lower(), - child.__class__.__name__.lower(), model, child, virtual=self.virtual, ) + + return model diff --git a/orm/relations.py b/orm/relations.py index 08f49fc..fa0dbcf 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -16,7 +16,7 @@ def get_table_alias() -> str: def get_relation_config( - relation_type: str, table_name: str, field: ForeignKey + relation_type: str, table_name: str, field: ForeignKey ) -> Dict[str, str]: alias = get_table_alias() config = { @@ -37,7 +37,7 @@ class RelationshipManager: self._relations = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str ) -> None: if relations_key not in self._relations: self._relations[relations_key] = get_relation_config( @@ -56,15 +56,15 @@ class RelationshipManager: del self._relations[rel_type][model._orm_id] def add_relation( - self, - parent_name: str, - child_name: str, - parent: "FakePydantic", - child: "FakePydantic", - virtual: bool = False, + self, + parent: "FakePydantic", + child: "FakePydantic", + virtual: bool = False, ) -> None: parent_id = parent._orm_id child_id = child._orm_id + parent_name = parent.get_name() + child_name = child.get_name() if virtual: child_name, parent_name = parent_name, child_name child_id, parent_id = parent_id, child_id @@ -97,7 +97,7 @@ class RelationshipManager: return False def get( - self, relations_key: str, instance: "FakePydantic" + self, relations_key: str, instance: "FakePydantic" ) -> Union["Model", List["Model"]]: if relations_key in self._relations: if instance._orm_id in self._relations[relations_key]: @@ -108,8 +108,8 @@ class RelationshipManager: def resolve_relation_join(self, from_table: str, to_table: str) -> str: for relation_name, relation in self._relations.items(): if ( - relation["source_table"] == from_table - and relation["target_table"] == to_table + relation["source_table"] == from_table + and relation["target_table"] == to_table ): return self._relations[relation_name]["table_alias"] return "" diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 6b69507..7121e5a 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -74,6 +74,13 @@ async def test_wrong_query_foreign_key_type(): Track(title="The Error", album="wrong_pk_type") +@pytest.mark.asyncio +async def test_setting_explicitly_empty_relation(): + async with database: + track = Track(album=None, title="The Bird", position=1) + assert track.album is None + + @pytest.mark.asyncio async def test_model_crud(): async with database: @@ -146,8 +153,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -155,8 +162,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -198,8 +205,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -218,8 +225,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1