From 594542a7f86a6d64ed92a6e8a304758b365c8ad9 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 7 Dec 2020 12:58:37 +0100 Subject: [PATCH] fix issue #68 --- docs/releases.md | 7 ++++- ormar/__init__.py | 2 +- ormar/models/model.py | 8 +++++- tests/test_saving_related.py | 53 ++++++++++++++++++++++++++++++++++++ tests/test_signals.py | 13 +++++++-- 5 files changed, 77 insertions(+), 6 deletions(-) create mode 100644 tests/test_saving_related.py diff --git a/docs/releases.md b/docs/releases.md index 1ab8e39..ac31dff 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,7 @@ +# 0.7.1 + +* Fix for overwriting related models with pk only in `Model.save()` (fix [#68][#68]) + # 0.7.0 * **Breaking:** QuerySet `bulk_update` method now raises `ModelPersistenceError` for unsaved models passed instead of `QueryDefinitionError` @@ -178,4 +182,5 @@ Add queryset level methods [#19]: https://github.com/collerek/ormar/issues/19 -[#60]: https://github.com/collerek/ormar/issues/60 \ No newline at end of file +[#60]: https://github.com/collerek/ormar/issues/60 +[#68]: https://github.com/collerek/ormar/issues/68 \ No newline at end of file diff --git a/ormar/__init__.py b/ormar/__init__.py index 5891f71..173c4f5 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -44,7 +44,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.7.0" +__version__ = "0.7.1" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/models/model.py b/ormar/models/model.py index 8ae50d2..d412a89 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -195,7 +195,13 @@ class Model(NewBaseModel): if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: self_fields.pop(self.Meta.pkname, None) self_fields = self.populate_default_values(self_fields) - self.from_dict(self_fields) + self.from_dict( + { + k: v + for k, v in self_fields.items() + if k not in self.extract_related_names() + } + ) await self.signals.pre_save.send(sender=self.__class__, instance=self) diff --git a/tests/test_saving_related.py b/tests/test_saving_related.py new file mode 100644 index 0000000..4967ec6 --- /dev/null +++ b/tests/test_saving_related.py @@ -0,0 +1,53 @@ +from typing import Union + +import databases +import ormar +import pytest +import sqlalchemy as sa + +from tests.settings import DATABASE_URL + +engine = sa.create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) +metadata = sa.MetaData() +db = databases.Database(DATABASE_URL) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50, unique=True, index=True) + code: int = ormar.Integer() + + +class Workshop(ormar.Model): + class Meta: + tablename = "workshops" + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + topic: str = ormar.String(max_length=255, index=True) + category: Union[ormar.Model, Category] = ormar.ForeignKey( + Category, related_name="workshops", nullable=False + ) + + +@pytest.fixture +def create_test_database(): + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_model_relationship(create_test_database): + cat = await Category(name="Foo", code=123).save() + ws = await Workshop(topic="Topic 1", category=cat).save() + + assert ws.id == 1 + assert ws.topic == "Topic 1" + assert ws.category.name == "Foo" diff --git a/tests/test_signals.py b/tests/test_signals.py index 5670980..a5c859d 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -78,6 +78,7 @@ def test_passing_not_callable(): def test_passing_callable_without_kwargs(): with pytest.raises(SignalDefinitionError): + @pre_save(Album) def trigger(sender, instance): # pragma: no cover pass @@ -87,6 +88,7 @@ def test_passing_callable_without_kwargs(): async def test_signal_functions(cleanup): async with database: async with database.transaction(force_rollback=True): + @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -170,9 +172,9 @@ async def test_signal_functions(cleanup): assert len(audits) == 2 assert audits[0].event_type == "PRE_DELETE_album" assert ( - audits[0].event_log.get("id") - == audits[1].event_log.get("id") - == album.id + audits[0].event_log.get("id") + == audits[1].event_log.get("id") + == album.id ) assert audits[1].event_type == "POST_DELETE_album" @@ -186,6 +188,7 @@ async def test_signal_functions(cleanup): async def test_multiple_signals(cleanup): async with database: async with database.transaction(force_rollback=True): + @pre_save(Album) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -216,6 +219,7 @@ async def test_multiple_signals(cleanup): async def test_static_methods_as_signals(cleanup): async with database: async with database.transaction(force_rollback=True): + class AlbumAuditor: event_type = "ALBUM_INSTANCE" @@ -240,6 +244,7 @@ async def test_static_methods_as_signals(cleanup): async def test_methods_as_signals(cleanup): async with database: async with database.transaction(force_rollback=True): + class AlbumAuditor: def __init__(self): self.event_type = "ALBUM_INSTANCE" @@ -265,6 +270,7 @@ async def test_methods_as_signals(cleanup): async def test_multiple_senders_signal(cleanup): async with database: async with database.transaction(force_rollback=True): + @pre_save([Album, Cover]) async def before_save(sender, instance, **kwargs): await AuditLog( @@ -292,6 +298,7 @@ async def test_multiple_senders_signal(cleanup): async def test_modifing_the_instance(cleanup): async with database: async with database.transaction(force_rollback=True): + @pre_update(Album) async def before_update(sender, instance, **kwargs): if instance.play_count > 50 and not instance.is_best_seller: