From d478ea6e15a2b2dd689743d26afc5ef3ba2c1b05 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 15 Nov 2020 10:33:03 +0100 Subject: [PATCH] add follow=True for save_related, update docs --- docs/models.md | 98 +++++++++++++++++++++++++++++++++++--- ormar/models/model.py | 54 +++++++++++++++++---- tests/test_save_related.py | 59 ++++++++++++++++++++++- 3 files changed, 193 insertions(+), 18 deletions(-) diff --git a/docs/models.md b/docs/models.md index d7d3393..fc29f6e 100644 --- a/docs/models.md +++ b/docs/models.md @@ -228,6 +228,33 @@ Each model has a `QuerySet` initialised as `objects` parameter !!!info To read more about `QuerySets` (including bulk operations) and available methods visit [queries][queries] +## `Model` save status + +Each model instance is a separate python object and they do not know anything about each other. + +```python +track1 = await Track.objects.get(name='The Bird') +track2 = await Track.objects.get(name='The Bird') +assert track1 == track2 # True + +track1.name = 'The Bird2' +await track1.save() +assert track1.name == track2.name # False +# track2 does not update and knows nothing about track1 +``` + +The objects itself have a saved status, which is set as following: + +* Model is saved after `save/update/load/upsert` method on model +* Model is saved after `create/get/first/all/get_or_create/update_or_create` method +* Model is saved when passed to `bulk_update` and `bulk_create` +* Model is saved after `adding/removing` `ManyToMany` related objects (through model instance auto saved/deleted) +* Model is **not** saved after change of any own field (including `pk` as `Model.pk` alias) +* Model is **not** saved after adding/removing `ForeignKey` related object (fk column not saved) +* Model is **not** saved after instantiation with `__init__` (w/o `QuerySet.create` or before calling `save`) + +You can check if model is saved with `ModelInstance.saved` property + ## `Model` methods ### load @@ -249,16 +276,57 @@ track.album.name # will return 'Malibu' ### save +`save() -> self` + You can create new models by using `QuerySet.create()` method or by initializing your model as a normal pydantic model and later calling `save()` method. -`save()` can also be used to persist changes that you made to the model. +`save()` can also be used to persist changes that you made to the model, but only if the primary key is not set or the model does not exist in database. + +The `save()` method does not check if the model exists in db, so if it does you will get a integrity error from your selected db backend if trying to save model with already existing primary key. ```python track = Track(name='The Bird') await track.save() # will persist the model in database + +track = await Track.objects.get(name='The Bird') +await track.save() # will raise integrity error as pk is populated ``` +### update + +`update(**kwargs) -> self` + +You can update models by using `QuerySet.update()` method or by updating your model attributes (fields) and calling `update()` method. + +If you try to update a model without a primary key set a `ModelPersistenceError` exception will be thrown. + +To persist a newly created model use `save()` or `upsert(**kwargs)` methods. + +```python +track = await Track.objects.get(name='The Bird') +await track.update(name='The Bird Strikes Again') +``` + +### upsert + +`upsert(**kwargs) -> self` + +It's an proxy to either `save()` or `update(**kwargs)` methods described above. + +If the primary key is set -> the `update` method will be called. + +If the pk is not set the `save()` method will be called. + +```python +track = Track(name='The Bird') +await track.upsert() # will call save as the pk is empty + +track = await Track.objects.get(name='The Bird') +await track.upsert(name='The Bird Strikes Again') # will call update as pk is already populated +``` + + ### delete You can delete models by using `QuerySet.delete()` method or by using your model and calling `delete()` method. @@ -271,14 +339,29 @@ await track.delete() # will delete the model from database !!!tip Note that that `track` object stays the same, only record in the database is removed. -### update +### save_related -You can delete models by using `QuerySet.update()` method or by using your model and calling `update()` method. +`save_related(follow: bool = False) -> None` -```python -track = await Track.objects.get(name='The Bird') -await track.update(name='The Bird Strikes Again') -``` +Method goes through all relations of the `Model` on which the method is called, +and calls `upsert()` method on each model that is **not** saved. + +To understand when a model is saved check [save status][save status] section above. + +By default the `save_related` method saved only models that are directly related (one step away) to the model on which the method is called. + +But you can specify the `follow=True` parameter to traverse through nested models and save all of them in the relation tree. + +!!!warning + To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models, + and won't perform nested `save_related` on Models that were already visited. + + So if you have a diamond or circular relations types you need to perform the updates in a manual way. + + ```python + # in example like this the second Street (coming from Company) won't be save_related, so Whatever won't be updated + Street -> District -> City -> Companies -> Street -> Whatever + ``` ## Internals @@ -348,3 +431,4 @@ For example to list table model fields you can: [sqlalchemy connection string]: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls [sqlalchemy table creation]: https://docs.sqlalchemy.org/en/13/core/metadata.html#creating-and-dropping-database-tables [alembic]: https://alembic.sqlalchemy.org/en/latest/tutorial.html +[save status]: ../models/#model-save-status diff --git a/ormar/models/model.py b/ormar/models/model.py index 80c9807..aeb63c3 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,5 +1,16 @@ import itertools -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union +from typing import ( + Any, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Union, +) import sqlalchemy @@ -192,23 +203,48 @@ class Model(NewBaseModel): self.set_save_status(True) return self - async def save_related(self) -> int: # noqa: CCR001 - update_count = 0 + async def save_related( # noqa: CCR001 + self, follow: bool = False, visited: Set = None, update_count: int = 0 + ) -> int: # noqa: CCR001 + if not visited: + visited = {self.__class__} + else: + visited = {x for x in visited} + visited.add(self.__class__) + for related in self.extract_related_names(): if self.Meta.model_fields[related].virtual or issubclass( self.Meta.model_fields[related], ManyToManyField ): for rel in getattr(self, related): - if not rel.saved: - await rel.upsert() - update_count += 1 + update_count, visited = await self._update_and_follow( + rel=rel, + follow=follow, + visited=visited, + update_count=update_count, + ) + visited.add(self.Meta.model_fields[related].to) else: rel = getattr(self, related) - if not rel.saved: - await rel.upsert() - update_count += 1 + update_count, visited = await self._update_and_follow( + rel=rel, follow=follow, visited=visited, update_count=update_count + ) + visited.add(rel.__class__) return update_count + @staticmethod + async def _update_and_follow( + rel: "Model", follow: bool, visited: Set, update_count: int + ) -> Tuple[int, Set]: + if follow and rel.__class__ not in visited: + update_count = await rel.save_related( + follow=follow, visited=visited, update_count=update_count + ) + if not rel.saved: + await rel.upsert() + update_count += 1 + return update_count, visited + async def update(self: T, **kwargs: Any) -> T: if kwargs: new_values = {**self.dict(), **kwargs} diff --git a/tests/test_save_related.py b/tests/test_save_related.py index 1defabf..6cfe108 100644 --- a/tests/test_save_related.py +++ b/tests/test_save_related.py @@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class CringeLevel(ormar.Model): + class Meta: + tablename = "levels" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + class NickNames(ormar.Model): class Meta: tablename = "nicks" @@ -20,6 +30,7 @@ class NickNames(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") is_lame: bool = ormar.Boolean(nullable=True) + level: CringeLevel = ormar.ForeignKey(CringeLevel) class NicksHq(ormar.Model): @@ -82,7 +93,7 @@ async def test_saving_related_fk_rel(): @pytest.mark.asyncio -async def test_adding_many_to_many_does_not_gets_dirty(): +async def test_saving_many_to_many(): async with database: async with database.transaction(force_rollback=True): nick1 = await NickNames.objects.create(name="BazingaO", is_lame=False) @@ -111,7 +122,7 @@ async def test_adding_many_to_many_does_not_gets_dirty(): @pytest.mark.asyncio -async def test_queryset_methods(): +async def test_saving_reversed_relation(): async with database: async with database.transaction(force_rollback=True): hq = await HQ.objects.create(name="Main") @@ -149,3 +160,47 @@ async def test_queryset_methods(): assert count == 2 assert hq.companies[0].saved assert hq.companies[1].saved + + +@pytest.mark.asyncio +async def test_saving_nested(): + async with database: + async with database.transaction(force_rollback=True): + level = await CringeLevel.objects.create(name='High') + level2 = await CringeLevel.objects.create(name='Low') + nick1 = await NickNames.objects.create(name="BazingaO", is_lame=False, level=level) + nick2 = await NickNames.objects.create(name="Bazinga20", is_lame=True, level=level2) + + hq = await HQ.objects.create(name="Main") + assert hq.saved + + await hq.nicks.add(nick1) + assert hq.saved + await hq.nicks.add(nick2) + assert hq.saved + + count = await hq.save_related() + assert count == 0 + + hq.nicks[0].level.name = "Medium" + assert not hq.nicks[0].level.saved + assert hq.nicks[0].saved + + count = await hq.save_related(follow=True) + assert count == 1 + assert hq.nicks[0].saved + assert hq.nicks[0].level.saved + + hq.nicks[0].level.name = "Low" + hq.nicks[1].level.name = "Medium" + assert not hq.nicks[0].level.saved + assert not hq.nicks[1].level.saved + assert hq.nicks[0].saved + assert hq.nicks[1].saved + + count = await hq.save_related(follow=True) + assert count == 2 + assert hq.nicks[0].saved + assert hq.nicks[0].level.saved + assert hq.nicks[1].saved + assert hq.nicks[1].level.saved