diff --git a/docs/models.md b/docs/models.md index d7d3393..dfc8324 100644 --- a/docs/models.md +++ b/docs/models.md @@ -228,6 +228,33 @@ Each model has a `QuerySet` initialised as `objects` parameter !!!info To read more about `QuerySets` (including bulk operations) and available methods visit [queries][queries] +## `Model` save status + +Each model instance is a separate python object and they do not know anything about each other. + +```python +track1 = await Track.objects.get(name='The Bird') +track2 = await Track.objects.get(name='The Bird') +assert track1 == track2 # True + +track1.name = 'The Bird2' +await track1.save() +assert track1.name == track2.name # False +# track2 does not update and knows nothing about track1 +``` + +The objects itself have a saved status, which is set as following: + +* Model is saved after `save/update/load/upsert` method on model +* Model is saved after `create/get/first/all/get_or_create/update_or_create` method +* Model is saved when passed to `bulk_update` and `bulk_create` +* Model is saved after `adding/removing` `ManyToMany` related objects (through model instance auto saved/deleted) +* Model is **not** saved after change of any own field (including `pk` as `Model.pk` alias) +* Model is **not** saved after adding/removing `ForeignKey` related object (fk column not saved) +* Model is **not** saved after instantiation with `__init__` (w/o `QuerySet.create` or before calling `save`) + +You can check if model is saved with `ModelInstance.saved` property + ## `Model` methods ### load @@ -249,16 +276,57 @@ track.album.name # will return 'Malibu' ### save +`save() -> self` + You can create new models by using `QuerySet.create()` method or by initializing your model as a normal pydantic model and later calling `save()` method. -`save()` can also be used to persist changes that you made to the model. +`save()` can also be used to persist changes that you made to the model, but only if the primary key is not set or the model does not exist in database. + +The `save()` method does not check if the model exists in db, so if it does you will get a integrity error from your selected db backend if trying to save model with already existing primary key. ```python track = Track(name='The Bird') await track.save() # will persist the model in database + +track = await Track.objects.get(name='The Bird') +await track.save() # will raise integrity error as pk is populated ``` +### update + +`update(**kwargs) -> self` + +You can update models by using `QuerySet.update()` method or by updating your model attributes (fields) and calling `update()` method. + +If you try to update a model without a primary key set a `ModelPersistenceError` exception will be thrown. + +To persist a newly created model use `save()` or `upsert(**kwargs)` methods. + +```python +track = await Track.objects.get(name='The Bird') +await track.update(name='The Bird Strikes Again') +``` + +### upsert + +`upsert(**kwargs) -> self` + +It's an proxy to either `save()` or `update(**kwargs)` methods described above. + +If the primary key is set -> the `update` method will be called. + +If the pk is not set the `save()` method will be called. + +```python +track = Track(name='The Bird') +await track.upsert() # will call save as the pk is empty + +track = await Track.objects.get(name='The Bird') +await track.upsert(name='The Bird Strikes Again') # will call update as pk is already populated +``` + + ### delete You can delete models by using `QuerySet.delete()` method or by using your model and calling `delete()` method. @@ -271,14 +339,29 @@ await track.delete() # will delete the model from database !!!tip Note that that `track` object stays the same, only record in the database is removed. -### update +### save_related -You can delete models by using `QuerySet.update()` method or by using your model and calling `update()` method. +`save_related(follow: bool = False) -> None` -```python -track = await Track.objects.get(name='The Bird') -await track.update(name='The Bird Strikes Again') -``` +Method goes through all relations of the `Model` on which the method is called, +and calls `upsert()` method on each model that is **not** saved. + +To understand when a model is saved check [save status][save status] section above. + +By default the `save_related` method saved only models that are directly related (one step away) to the model on which the method is called. + +But you can specify the `follow=True` parameter to traverse through nested models and save all of them in the relation tree. + +!!!warning + To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models, + and won't perform nested `save_related` on Models that were already visited. + + So if you have a diamond or circular relations types you need to perform the updates in a manual way. + + ```python + # in example like this the second Street (coming from City) won't be save_related, so ZipCode won't be updated + Street -> District -> City -> Street -> ZipCode + ``` ## Internals @@ -348,3 +431,4 @@ For example to list table model fields you can: [sqlalchemy connection string]: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls [sqlalchemy table creation]: https://docs.sqlalchemy.org/en/13/core/metadata.html#creating-and-dropping-database-tables [alembic]: https://alembic.sqlalchemy.org/en/latest/tutorial.html +[save status]: ../models/#model-save-status diff --git a/docs/releases.md b/docs/releases.md index c4fd384..e636ceb 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,24 @@ +# 0.5.0 + +* Added save status -> you can check if model is saved with `ModelInstance.saved` property + * Model is saved after `save/update/load/upsert` method on model + * Model is saved after `create/get/first/all/get_or_create/update_or_create` method + * Model is saved when passed to `bulk_update` and `bulk_create` + * Model is saved after adding/removing `ManyToMany` related objects (through model instance auto saved/deleted) + * Model is **not** saved after change of any own field (including pk as `Model.pk` alias) + * Model is **not** saved after adding/removing `ForeignKey` related object (fk column not saved) + * Model is **not** saved after instantation with `__init__` (w/o `QuerySet.create` or before calling `save`) +* Added `Model.upsert(**kwargs)` that performs `save()` if pk not set otherwise `update(**kwargs)` +* Added `Model.save_related(follow=False)` that iterates all related objects in all relations and checks if they are saved. If not it calls `upsert()` on each of them. +* **Breaking:** added raising exceptions if `add`-ing/`remove`-ing not saved (pk is None) models to `ManyToMany` relation +* Allow passing dictionaries and sets to fields and exclude_fields +* Auto translate str and lists to dicts for fields and exclude_fields +* **Breaking:** passing nested models to fields and exclude_fields is now by related ForeignKey name and not by target model name +* Performance optimizations - in modelproxy, newbasemodel - > less queries, some properties are cached on models +* Cleanup of unused relations code +* Optional performance dependency orjson added (**strongly recommended**) +* Updated docs + # 0.4.4 * add exclude_fields() method to exclude fields from sql diff --git a/ormar/__init__.py b/ormar/__init__.py index 03476f4..7f9e9e2 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.4.4" +__version__ = "0.5.0" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/exceptions.py b/ormar/exceptions.py index 40cfd26..0800a81 100644 --- a/ormar/exceptions.py +++ b/ormar/exceptions.py @@ -24,3 +24,7 @@ class QueryDefinitionError(AsyncOrmException): class RelationshipInstanceError(AsyncOrmException): pass + + +class ModelPersistenceError(AsyncOrmException): + pass diff --git a/ormar/models/model.py b/ormar/models/model.py index 688eed0..fb5c295 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,9 +1,21 @@ import itertools -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union +from typing import ( + Any, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Union, +) import sqlalchemy import ormar.queryset # noqa I100 +from ormar.exceptions import ModelPersistenceError from ormar.fields.many_to_many import ManyToManyField from ormar.models import NewBaseModel # noqa I100 from ormar.models.metaclass import ModelMeta @@ -90,9 +102,13 @@ class Model(NewBaseModel): exclude_fields=exclude_fields, ) - instance: Optional[T] = cls(**item) if item.get( - cls.Meta.pkname, None - ) is not None else None + instance: Optional[T] = None + if item.get(cls.Meta.pkname, None) is not None: + instance = cls(**item) + instance.set_save_status(True) + else: + instance = None + return instance @classmethod @@ -165,6 +181,11 @@ class Model(NewBaseModel): return item + async def upsert(self: T, **kwargs: Any) -> T: + if not self.pk: + return await self.save() + return await self.update(**kwargs) + async def save(self: T) -> T: self_fields = self._extract_model_db_fields() @@ -179,13 +200,61 @@ class Model(NewBaseModel): item_id = await self.Meta.database.execute(expr) if item_id: # postgress does not return id if it's already there setattr(self, self.Meta.pkname, item_id) + self.set_save_status(True) return self + async def save_related( # noqa: CCR001 + self, follow: bool = False, visited: Set = None, update_count: int = 0 + ) -> int: # noqa: CCR001 + if not visited: + visited = {self.__class__} + else: + visited = {x for x in visited} + visited.add(self.__class__) + + for related in self.extract_related_names(): + if self.Meta.model_fields[related].virtual or issubclass( + self.Meta.model_fields[related], ManyToManyField + ): + for rel in getattr(self, related): + update_count, visited = await self._update_and_follow( + rel=rel, + follow=follow, + visited=visited, + update_count=update_count, + ) + visited.add(self.Meta.model_fields[related].to) + else: + rel = getattr(self, related) + update_count, visited = await self._update_and_follow( + rel=rel, follow=follow, visited=visited, update_count=update_count + ) + visited.add(rel.__class__) + return update_count + + @staticmethod + async def _update_and_follow( + rel: T, follow: bool, visited: Set, update_count: int + ) -> Tuple[int, Set]: + if follow and rel.__class__ not in visited: + update_count = await rel.save_related( + follow=follow, visited=visited, update_count=update_count + ) + if not rel.saved: + await rel.upsert() + update_count += 1 + return update_count, visited + async def update(self: T, **kwargs: Any) -> T: if kwargs: new_values = {**self.dict(), **kwargs} self.from_dict(new_values) + if not self.pk: + raise ModelPersistenceError( + "You cannot update not saved model! Use save or upsert method." + ) + self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) self_fields = self.translate_columns_to_aliases(self_fields) @@ -193,12 +262,14 @@ class Model(NewBaseModel): expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) await self.Meta.database.execute(expr) + self.set_save_status(True) return self async def delete(self: T) -> int: expr = self.Meta.table.delete() expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) result = await self.Meta.database.execute(expr) + self.set_save_status(False) return result async def load(self: T) -> T: @@ -211,4 +282,5 @@ class Model(NewBaseModel): kwargs = dict(row) kwargs = self.translate_aliases_to_columns(kwargs) self.from_dict(kwargs) + self.set_save_status(True) return self diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 294d592..3835b41 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -12,7 +12,7 @@ from typing import ( Union, ) -from ormar.exceptions import RelationshipInstanceError +from ormar.exceptions import ModelPersistenceError, RelationshipInstanceError try: import orjson as json @@ -63,8 +63,14 @@ class ModelTableProxy: target_field = cls.Meta.model_fields[field] target_pkname = target_field.to.Meta.pkname if isinstance(field_value, ormar.Model): - model_dict[field] = getattr(field_value, target_pkname) - elif field_value: + pk_value = getattr(field_value, target_pkname) + if not pk_value: + raise ModelPersistenceError( + f"You cannot save {field_value.get_name()} " + f"model without pk set!" + ) + model_dict[field] = pk_value + elif field_value: # nested dict model_dict[field] = field_value.get(target_pkname) else: model_dict.pop(field, None) @@ -220,6 +226,7 @@ class ModelTableProxy: field, cls.merge_two_instances(current_field, getattr(other, field)), ) + other.set_save_status(True) return other @staticmethod diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 2a96c19..5acf4ca 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -123,12 +123,16 @@ class NewBaseModel( object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) + self.set_save_status(False) elif name in self._orm: model = self.Meta.model_fields[name].expand_relationship(value, self) if isinstance(self.__dict__.get(name), list): + # virtual foreign key or many to many self.__dict__[name].append(model) else: + # foreign key relation self.__dict__[name] = model + self.set_save_status(False) else: value = ( self._convert_json(name, value, "dumps") @@ -136,6 +140,7 @@ class NewBaseModel( else value ) super().__setattr__(name, value) + self.set_save_status(False) def __getattribute__(self, item: str) -> Any: if item in ( @@ -188,6 +193,10 @@ class NewBaseModel( def pk_column(self) -> sqlalchemy.Column: return self.Meta.table.primary_key.columns.values()[0] + @property + def saved(self) -> bool: + return self._orm_saved + @classmethod def pk_type(cls) -> Any: return cls.Meta.model_fields[cls.Meta.pkname].__type__ @@ -199,6 +208,9 @@ class NewBaseModel( def remove(self, name: "T") -> None: self._orm.remove_parent(self, name) + def set_save_status(self, status: bool) -> None: + object.__setattr__(self, "_orm_saved", status) + @classmethod def get_properties( cls, @@ -212,7 +224,8 @@ class NewBaseModel( prop for prop in dir(cls) if isinstance(getattr(cls, prop), property) - and prop not in ("__values__", "__fields__", "fields", "pk_column") + and prop + not in ("__values__", "__fields__", "fields", "pk_column", "saved") ] cls._props = props if include: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 448a0a9..b8b4aa3 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -358,8 +358,13 @@ class QuerySet: instance.pk = pk # refresh server side defaults - instance = await instance.load() - + if any( + field.server_default is not None + for name, field in self.model.Meta.model_fields.items() + if name not in kwargs + ): + instance = await instance.load() + instance.set_save_status(True) return instance async def bulk_create(self, objects: List["Model"]) -> None: @@ -372,7 +377,10 @@ class QuerySet: expr = self.table.insert() await self.database.execute_many(expr, ready_objects) - async def bulk_update( + for objt in objects: + objt.set_save_status(True) + + async def bulk_update( # noqa: CCR001 self, objects: List["Model"], columns: List[str] = None ) -> None: ready_objects = [] @@ -418,3 +426,6 @@ class QuerySet: # otherwise it just passes all data to values and results in unconsumed columns expr = str(expr) await self.database.execute_many(expr, ready_objects) + + for objt in objects: + objt.set_save_status(True) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 88130d5..f03252e 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -1,7 +1,7 @@ from typing import Any, TYPE_CHECKING import ormar -from ormar.exceptions import RelationshipInstanceError +from ormar.exceptions import NoMatch, RelationshipInstanceError from ormar.relations.querysetproxy import QuerysetProxy if TYPE_CHECKING: # pragma no cover @@ -54,6 +54,11 @@ class RelationProxy(list): return queryset async def remove(self, item: "Model") -> None: # type: ignore + if item not in self: + raise NoMatch( + f"Object {self._owner.get_name()} has no " + f"{item.get_name()} with given primary key!" + ) super().remove(item) rel_name = item.resolve_relation_name(item, self._owner) relation = item._orm._get(rel_name) diff --git a/tests/test_many_to_many.py b/tests/test_many_to_many.py index 3ddeb61..8d8b258 100644 --- a/tests/test_many_to_many.py +++ b/tests/test_many_to_many.py @@ -6,7 +6,7 @@ import pytest import sqlalchemy import ormar -from ormar.exceptions import RelationshipInstanceError +from ormar.exceptions import ModelPersistenceError, NoMatch, RelationshipInstanceError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -196,3 +196,27 @@ async def test_selecting_related_fail_without_saving(cleanup): post = Post(title="Hello, M2M", author=guido) with pytest.raises(RelationshipInstanceError): await post.categories.all() + + +@pytest.mark.asyncio +async def test_adding_unsaved_related(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = Category(name="News") + with pytest.raises(ModelPersistenceError): + await post.categories.add(news) + + await news.save() + await post.categories.add(news) + assert len(await post.categories.all()) == 1 + + +@pytest.mark.asyncio +async def test_removing_unsaved_related(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = Category(name="News") + with pytest.raises(NoMatch): + await post.categories.remove(news) diff --git a/tests/test_save_related.py b/tests/test_save_related.py new file mode 100644 index 0000000..0696018 --- /dev/null +++ b/tests/test_save_related.py @@ -0,0 +1,210 @@ +from typing import List + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class CringeLevel(ormar.Model): + class Meta: + tablename = "levels" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class NickNames(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + level: CringeLevel = ormar.ForeignKey(CringeLevel) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq) + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="company_name") + founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ, related_name="companies") + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_saving_related_fk_rel(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name="Main") + comp = await Company.objects.create(name="Banzai", founded=1988, hq=hq) + assert comp.saved + + count = await comp.save_related() + assert count == 0 + + comp.hq.name = "Suburbs" + assert not comp.hq.saved + assert comp.saved + + count = await comp.save_related() + assert count == 1 + assert comp.hq.saved + + +@pytest.mark.asyncio +async def test_saving_many_to_many(): + async with database: + async with database.transaction(force_rollback=True): + nick1 = await NickNames.objects.create(name="BazingaO", is_lame=False) + nick2 = await NickNames.objects.create(name="Bazinga20", is_lame=True) + + hq = await HQ.objects.create(name="Main") + assert hq.saved + + await hq.nicks.add(nick1) + assert hq.saved + await hq.nicks.add(nick2) + assert hq.saved + + count = await hq.save_related() + assert count == 0 + + hq.nicks[0].name = "Kabucha" + hq.nicks[1].name = "Kabucha2" + assert not hq.nicks[0].saved + assert not hq.nicks[1].saved + + count = await hq.save_related() + assert count == 2 + assert hq.nicks[0].saved + assert hq.nicks[1].saved + + +@pytest.mark.asyncio +async def test_saving_reversed_relation(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name="Main") + await Company.objects.create(name="Banzai", founded=1988, hq=hq) + + hq = await HQ.objects.select_related("companies").get(name="Main") + assert hq.saved + assert hq.companies[0].saved + + hq.companies[0].name = "Konichiwa" + assert not hq.companies[0].saved + count = await hq.save_related() + assert count == 1 + assert hq.companies[0].saved + + await Company.objects.create(name="Joshua", founded=1888, hq=hq) + + hq = await HQ.objects.select_related("companies").get(name="Main") + assert hq.saved + assert hq.companies[0].saved + assert hq.companies[1].saved + + hq.companies[0].name = hq.companies[0].name + "20" + assert not hq.companies[0].saved + # save only if not saved so now only one + count = await hq.save_related() + assert count == 1 + assert hq.companies[0].saved + + hq.companies[0].name = hq.companies[0].name + "20" + hq.companies[1].name = hq.companies[1].name + "30" + assert not hq.companies[0].saved + assert not hq.companies[1].saved + count = await hq.save_related() + assert count == 2 + assert hq.companies[0].saved + assert hq.companies[1].saved + + +@pytest.mark.asyncio +async def test_saving_nested(): + async with database: + async with database.transaction(force_rollback=True): + level = await CringeLevel.objects.create(name="High") + level2 = await CringeLevel.objects.create(name="Low") + nick1 = await NickNames.objects.create( + name="BazingaO", is_lame=False, level=level + ) + nick2 = await NickNames.objects.create( + name="Bazinga20", is_lame=True, level=level2 + ) + + hq = await HQ.objects.create(name="Main") + assert hq.saved + + await hq.nicks.add(nick1) + assert hq.saved + await hq.nicks.add(nick2) + assert hq.saved + + count = await hq.save_related() + assert count == 0 + + hq.nicks[0].level.name = "Medium" + assert not hq.nicks[0].level.saved + assert hq.nicks[0].saved + + count = await hq.save_related(follow=True) + assert count == 1 + assert hq.nicks[0].saved + assert hq.nicks[0].level.saved + + hq.nicks[0].level.name = "Low" + hq.nicks[1].level.name = "Medium" + assert not hq.nicks[0].level.saved + assert not hq.nicks[1].level.saved + assert hq.nicks[0].saved + assert hq.nicks[1].saved + + count = await hq.save_related(follow=True) + assert count == 2 + assert hq.nicks[0].saved + assert hq.nicks[0].level.saved + assert hq.nicks[1].saved + assert hq.nicks[1].level.saved diff --git a/tests/test_save_status.py b/tests/test_save_status.py new file mode 100644 index 0000000..93e89ac --- /dev/null +++ b/tests/test_save_status.py @@ -0,0 +1,257 @@ +from typing import List + +import databases +import pytest +import sqlalchemy + +import ormar +from ormar.exceptions import ModelPersistenceError +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class NickNames(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq) + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="company_name") + founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_instantation_false_save_true(): + async with database: + async with database.transaction(force_rollback=True): + comp = Company(name="Banzai", founded=1988) + assert not comp.saved + await comp.save() + assert comp.saved + + +@pytest.mark.asyncio +async def test_saved_edited_not_saved(): + async with database: + async with database.transaction(force_rollback=True): + comp = await Company.objects.create(name="Banzai", founded=1988) + assert comp.saved + comp.name = "Banzai2" + assert not comp.saved + + await comp.update() + assert comp.saved + + await comp.update(name="Banzai3") + assert comp.saved + + comp.pk = 999 + assert not comp.saved + + await comp.update() + assert comp.saved + + +@pytest.mark.asyncio +async def test_adding_related_gets_dirty(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name="Main") + comp = await Company.objects.create(name="Banzai", founded=1988) + assert comp.saved + + comp.hq = hq + assert not comp.saved + await comp.update() + assert comp.saved + + comp = await Company.objects.select_related("hq").get(name="Banzai") + assert comp.saved + + assert comp.hq.pk == hq.pk + assert comp.hq.saved + + comp.hq.name = "Suburbs" + assert not comp.hq.saved + assert comp.saved + + await comp.hq.update() + assert comp.hq.saved + + +@pytest.mark.asyncio +async def test_adding_many_to_many_does_not_gets_dirty(): + async with database: + async with database.transaction(force_rollback=True): + nick1 = await NickNames.objects.create(name="Bazinga", is_lame=False) + nick2 = await NickNames.objects.create(name="Bazinga2", is_lame=True) + + hq = await HQ.objects.create(name="Main") + assert hq.saved + + await hq.nicks.add(nick1) + assert hq.saved + await hq.nicks.add(nick2) + assert hq.saved + + hq = await HQ.objects.select_related("nicks").get(name="Main") + assert hq.saved + assert hq.nicks[0].saved + + await hq.nicks.remove(nick1) + assert hq.saved + + hq.nicks[0].name = "Kabucha" + assert not hq.nicks[0].saved + + await hq.nicks[0].update() + assert hq.nicks[0].saved + + +@pytest.mark.asyncio +async def test_delete(): + async with database: + async with database.transaction(force_rollback=True): + comp = await Company.objects.create(name="Banzai", founded=1988) + assert comp.saved + await comp.delete() + assert not comp.saved + + await comp.update() + assert comp.saved + + +@pytest.mark.asyncio +async def test_load(): + async with database: + async with database.transaction(force_rollback=True): + comp = await Company.objects.create(name="Banzai", founded=1988) + assert comp.saved + comp.name = "AA" + assert not comp.saved + + await comp.load() + assert comp.saved + assert comp.name == "Banzai" + + +@pytest.mark.asyncio +async def test_queryset_methods(): + async with database: + async with database.transaction(force_rollback=True): + await Company.objects.create(name="Banzai", founded=1988) + await Company.objects.create(name="Yuhu", founded=1989) + await Company.objects.create(name="Konono", founded=1990) + await Company.objects.create(name="Sumaaa", founded=1991) + + comp = await Company.objects.get(name="Banzai") + assert comp.saved + + comp = await Company.objects.first() + assert comp.saved + + comps = await Company.objects.all() + assert [comp.saved for comp in comps] + + comp2 = await Company.objects.get_or_create(name="Banzai_new", founded=2001) + assert comp2.saved + + comp3 = await Company.objects.get_or_create(name="Banzai", founded=1988) + assert comp3.saved + assert comp3.pk == comp.pk + + update_dict = comp.dict() + update_dict["founded"] = 2010 + comp = await Company.objects.update_or_create(**update_dict) + assert comp.saved + assert comp.founded == 2010 + + create_dict = {"name": "Yoko", "founded": 2005} + comp = await Company.objects.update_or_create(**create_dict) + assert comp.saved + assert comp.founded == 2005 + + +@pytest.mark.asyncio +async def test_bulk_methods(): + async with database: + async with database.transaction(force_rollback=True): + c1 = Company(name="Banzai", founded=1988) + c2 = Company(name="Yuhu", founded=1989) + + await Company.objects.bulk_create([c1, c2]) + assert c1.saved + assert c2.saved + + c1, c2 = await Company.objects.all() + c1.name = "Banzai2" + c2.name = "Yuhu2" + + assert not c1.saved + assert not c2.saved + + await Company.objects.bulk_update([c1, c2]) + assert c1.saved + assert c2.saved + + c3 = Company(name="Cobra", founded=2088) + assert not c3.saved + + with pytest.raises(ModelPersistenceError): + await c3.update() + + await c3.upsert() + assert c3.saved + + c3.name = "Python" + assert not c3.saved + + await c3.upsert() + assert c3.saved + assert c3.name == "Python" + + await c3.upsert(founded=2077) + assert c3.saved + assert c3.founded == 2077