From 0f36944fe1a49a42fadce1db3e0ead74bcf6ebe8 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 14 Nov 2020 14:47:33 +0100 Subject: [PATCH] add safe fails for adding and removing not saved models to many to many rel, add tests for save_related --- ormar/models/model.py | 2 +- ormar/models/modelproxy.py | 12 +++++++++--- ormar/relations/relation_proxy.py | 7 ++++++- tests/test_many_to_many.py | 26 +++++++++++++++++++++++++- 4 files changed, 41 insertions(+), 6 deletions(-) diff --git a/ormar/models/model.py b/ormar/models/model.py index 437ed94..80c9807 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -192,7 +192,7 @@ class Model(NewBaseModel): self.set_save_status(True) return self - async def save_related(self) -> int: + async def save_related(self) -> int: # noqa: CCR001 update_count = 0 for related in self.extract_related_names(): if self.Meta.model_fields[related].virtual or issubclass( diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index c457bb2..3835b41 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -12,7 +12,7 @@ from typing import ( Union, ) -from ormar.exceptions import RelationshipInstanceError +from ormar.exceptions import ModelPersistenceError, RelationshipInstanceError try: import orjson as json @@ -63,8 +63,14 @@ class ModelTableProxy: target_field = cls.Meta.model_fields[field] target_pkname = target_field.to.Meta.pkname if isinstance(field_value, ormar.Model): - model_dict[field] = getattr(field_value, target_pkname) - elif field_value: + pk_value = getattr(field_value, target_pkname) + if not pk_value: + raise ModelPersistenceError( + f"You cannot save {field_value.get_name()} " + f"model without pk set!" + ) + model_dict[field] = pk_value + elif field_value: # nested dict model_dict[field] = field_value.get(target_pkname) else: model_dict.pop(field, None) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 88130d5..f03252e 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -1,7 +1,7 @@ from typing import Any, TYPE_CHECKING import ormar -from ormar.exceptions import RelationshipInstanceError +from ormar.exceptions import NoMatch, RelationshipInstanceError from ormar.relations.querysetproxy import QuerysetProxy if TYPE_CHECKING: # pragma no cover @@ -54,6 +54,11 @@ class RelationProxy(list): return queryset async def remove(self, item: "Model") -> None: # type: ignore + if item not in self: + raise NoMatch( + f"Object {self._owner.get_name()} has no " + f"{item.get_name()} with given primary key!" + ) super().remove(item) rel_name = item.resolve_relation_name(item, self._owner) relation = item._orm._get(rel_name) diff --git a/tests/test_many_to_many.py b/tests/test_many_to_many.py index 3ddeb61..8d8b258 100644 --- a/tests/test_many_to_many.py +++ b/tests/test_many_to_many.py @@ -6,7 +6,7 @@ import pytest import sqlalchemy import ormar -from ormar.exceptions import RelationshipInstanceError +from ormar.exceptions import ModelPersistenceError, NoMatch, RelationshipInstanceError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -196,3 +196,27 @@ async def test_selecting_related_fail_without_saving(cleanup): post = Post(title="Hello, M2M", author=guido) with pytest.raises(RelationshipInstanceError): await post.categories.all() + + +@pytest.mark.asyncio +async def test_adding_unsaved_related(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = Category(name="News") + with pytest.raises(ModelPersistenceError): + await post.categories.add(news) + + await news.save() + await post.categories.add(news) + assert len(await post.categories.all()) == 1 + + +@pytest.mark.asyncio +async def test_removing_unsaved_related(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = Category(name="News") + with pytest.raises(NoMatch): + await post.categories.remove(news)