Merge pull request #45 from collerek/save_status
Add save status, add save_related method, safe fails for adding/removing to many2many, added upsert method
This commit is contained in:
@ -228,6 +228,33 @@ Each model has a `QuerySet` initialised as `objects` parameter
|
||||
!!!info
|
||||
To read more about `QuerySets` (including bulk operations) and available methods visit [queries][queries]
|
||||
|
||||
## `Model` save status
|
||||
|
||||
Each model instance is a separate python object and they do not know anything about each other.
|
||||
|
||||
```python
|
||||
track1 = await Track.objects.get(name='The Bird')
|
||||
track2 = await Track.objects.get(name='The Bird')
|
||||
assert track1 == track2 # True
|
||||
|
||||
track1.name = 'The Bird2'
|
||||
await track1.save()
|
||||
assert track1.name == track2.name # False
|
||||
# track2 does not update and knows nothing about track1
|
||||
```
|
||||
|
||||
The objects itself have a saved status, which is set as following:
|
||||
|
||||
* Model is saved after `save/update/load/upsert` method on model
|
||||
* Model is saved after `create/get/first/all/get_or_create/update_or_create` method
|
||||
* Model is saved when passed to `bulk_update` and `bulk_create`
|
||||
* Model is saved after `adding/removing` `ManyToMany` related objects (through model instance auto saved/deleted)
|
||||
* Model is **not** saved after change of any own field (including `pk` as `Model.pk` alias)
|
||||
* Model is **not** saved after adding/removing `ForeignKey` related object (fk column not saved)
|
||||
* Model is **not** saved after instantiation with `__init__` (w/o `QuerySet.create` or before calling `save`)
|
||||
|
||||
You can check if model is saved with `ModelInstance.saved` property
|
||||
|
||||
## `Model` methods
|
||||
|
||||
### load
|
||||
@ -249,16 +276,57 @@ track.album.name # will return 'Malibu'
|
||||
|
||||
### save
|
||||
|
||||
`save() -> self`
|
||||
|
||||
You can create new models by using `QuerySet.create()` method or by initializing your model as a normal pydantic model
|
||||
and later calling `save()` method.
|
||||
|
||||
`save()` can also be used to persist changes that you made to the model.
|
||||
`save()` can also be used to persist changes that you made to the model, but only if the primary key is not set or the model does not exist in database.
|
||||
|
||||
The `save()` method does not check if the model exists in db, so if it does you will get a integrity error from your selected db backend if trying to save model with already existing primary key.
|
||||
|
||||
```python
|
||||
track = Track(name='The Bird')
|
||||
await track.save() # will persist the model in database
|
||||
|
||||
track = await Track.objects.get(name='The Bird')
|
||||
await track.save() # will raise integrity error as pk is populated
|
||||
```
|
||||
|
||||
### update
|
||||
|
||||
`update(**kwargs) -> self`
|
||||
|
||||
You can update models by using `QuerySet.update()` method or by updating your model attributes (fields) and calling `update()` method.
|
||||
|
||||
If you try to update a model without a primary key set a `ModelPersistenceError` exception will be thrown.
|
||||
|
||||
To persist a newly created model use `save()` or `upsert(**kwargs)` methods.
|
||||
|
||||
```python
|
||||
track = await Track.objects.get(name='The Bird')
|
||||
await track.update(name='The Bird Strikes Again')
|
||||
```
|
||||
|
||||
### upsert
|
||||
|
||||
`upsert(**kwargs) -> self`
|
||||
|
||||
It's an proxy to either `save()` or `update(**kwargs)` methods described above.
|
||||
|
||||
If the primary key is set -> the `update` method will be called.
|
||||
|
||||
If the pk is not set the `save()` method will be called.
|
||||
|
||||
```python
|
||||
track = Track(name='The Bird')
|
||||
await track.upsert() # will call save as the pk is empty
|
||||
|
||||
track = await Track.objects.get(name='The Bird')
|
||||
await track.upsert(name='The Bird Strikes Again') # will call update as pk is already populated
|
||||
```
|
||||
|
||||
|
||||
### delete
|
||||
|
||||
You can delete models by using `QuerySet.delete()` method or by using your model and calling `delete()` method.
|
||||
@ -271,13 +339,28 @@ await track.delete() # will delete the model from database
|
||||
!!!tip
|
||||
Note that that `track` object stays the same, only record in the database is removed.
|
||||
|
||||
### update
|
||||
### save_related
|
||||
|
||||
You can delete models by using `QuerySet.update()` method or by using your model and calling `update()` method.
|
||||
`save_related(follow: bool = False) -> None`
|
||||
|
||||
Method goes through all relations of the `Model` on which the method is called,
|
||||
and calls `upsert()` method on each model that is **not** saved.
|
||||
|
||||
To understand when a model is saved check [save status][save status] section above.
|
||||
|
||||
By default the `save_related` method saved only models that are directly related (one step away) to the model on which the method is called.
|
||||
|
||||
But you can specify the `follow=True` parameter to traverse through nested models and save all of them in the relation tree.
|
||||
|
||||
!!!warning
|
||||
To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models,
|
||||
and won't perform nested `save_related` on Models that were already visited.
|
||||
|
||||
So if you have a diamond or circular relations types you need to perform the updates in a manual way.
|
||||
|
||||
```python
|
||||
track = await Track.objects.get(name='The Bird')
|
||||
await track.update(name='The Bird Strikes Again')
|
||||
# in example like this the second Street (coming from City) won't be save_related, so ZipCode won't be updated
|
||||
Street -> District -> City -> Street -> ZipCode
|
||||
```
|
||||
|
||||
## Internals
|
||||
@ -348,3 +431,4 @@ For example to list table model fields you can:
|
||||
[sqlalchemy connection string]: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
|
||||
[sqlalchemy table creation]: https://docs.sqlalchemy.org/en/13/core/metadata.html#creating-and-dropping-database-tables
|
||||
[alembic]: https://alembic.sqlalchemy.org/en/latest/tutorial.html
|
||||
[save status]: ../models/#model-save-status
|
||||
|
||||
@ -1,3 +1,24 @@
|
||||
# 0.5.0
|
||||
|
||||
* Added save status -> you can check if model is saved with `ModelInstance.saved` property
|
||||
* Model is saved after `save/update/load/upsert` method on model
|
||||
* Model is saved after `create/get/first/all/get_or_create/update_or_create` method
|
||||
* Model is saved when passed to `bulk_update` and `bulk_create`
|
||||
* Model is saved after adding/removing `ManyToMany` related objects (through model instance auto saved/deleted)
|
||||
* Model is **not** saved after change of any own field (including pk as `Model.pk` alias)
|
||||
* Model is **not** saved after adding/removing `ForeignKey` related object (fk column not saved)
|
||||
* Model is **not** saved after instantation with `__init__` (w/o `QuerySet.create` or before calling `save`)
|
||||
* Added `Model.upsert(**kwargs)` that performs `save()` if pk not set otherwise `update(**kwargs)`
|
||||
* Added `Model.save_related(follow=False)` that iterates all related objects in all relations and checks if they are saved. If not it calls `upsert()` on each of them.
|
||||
* **Breaking:** added raising exceptions if `add`-ing/`remove`-ing not saved (pk is None) models to `ManyToMany` relation
|
||||
* Allow passing dictionaries and sets to fields and exclude_fields
|
||||
* Auto translate str and lists to dicts for fields and exclude_fields
|
||||
* **Breaking:** passing nested models to fields and exclude_fields is now by related ForeignKey name and not by target model name
|
||||
* Performance optimizations - in modelproxy, newbasemodel - > less queries, some properties are cached on models
|
||||
* Cleanup of unused relations code
|
||||
* Optional performance dependency orjson added (**strongly recommended**)
|
||||
* Updated docs
|
||||
|
||||
# 0.4.4
|
||||
|
||||
* add exclude_fields() method to exclude fields from sql
|
||||
|
||||
@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover
|
||||
|
||||
Undefined = UndefinedType()
|
||||
|
||||
__version__ = "0.4.4"
|
||||
__version__ = "0.5.0"
|
||||
__all__ = [
|
||||
"Integer",
|
||||
"BigInteger",
|
||||
|
||||
@ -24,3 +24,7 @@ class QueryDefinitionError(AsyncOrmException):
|
||||
|
||||
class RelationshipInstanceError(AsyncOrmException):
|
||||
pass
|
||||
|
||||
|
||||
class ModelPersistenceError(AsyncOrmException):
|
||||
pass
|
||||
|
||||
@ -1,9 +1,21 @@
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
import ormar.queryset # noqa I100
|
||||
from ormar.exceptions import ModelPersistenceError
|
||||
from ormar.fields.many_to_many import ManyToManyField
|
||||
from ormar.models import NewBaseModel # noqa I100
|
||||
from ormar.models.metaclass import ModelMeta
|
||||
@ -90,9 +102,13 @@ class Model(NewBaseModel):
|
||||
exclude_fields=exclude_fields,
|
||||
)
|
||||
|
||||
instance: Optional[T] = cls(**item) if item.get(
|
||||
cls.Meta.pkname, None
|
||||
) is not None else None
|
||||
instance: Optional[T] = None
|
||||
if item.get(cls.Meta.pkname, None) is not None:
|
||||
instance = cls(**item)
|
||||
instance.set_save_status(True)
|
||||
else:
|
||||
instance = None
|
||||
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
@ -165,6 +181,11 @@ class Model(NewBaseModel):
|
||||
|
||||
return item
|
||||
|
||||
async def upsert(self: T, **kwargs: Any) -> T:
|
||||
if not self.pk:
|
||||
return await self.save()
|
||||
return await self.update(**kwargs)
|
||||
|
||||
async def save(self: T) -> T:
|
||||
self_fields = self._extract_model_db_fields()
|
||||
|
||||
@ -179,13 +200,61 @@ class Model(NewBaseModel):
|
||||
item_id = await self.Meta.database.execute(expr)
|
||||
if item_id: # postgress does not return id if it's already there
|
||||
setattr(self, self.Meta.pkname, item_id)
|
||||
self.set_save_status(True)
|
||||
return self
|
||||
|
||||
async def save_related( # noqa: CCR001
|
||||
self, follow: bool = False, visited: Set = None, update_count: int = 0
|
||||
) -> int: # noqa: CCR001
|
||||
if not visited:
|
||||
visited = {self.__class__}
|
||||
else:
|
||||
visited = {x for x in visited}
|
||||
visited.add(self.__class__)
|
||||
|
||||
for related in self.extract_related_names():
|
||||
if self.Meta.model_fields[related].virtual or issubclass(
|
||||
self.Meta.model_fields[related], ManyToManyField
|
||||
):
|
||||
for rel in getattr(self, related):
|
||||
update_count, visited = await self._update_and_follow(
|
||||
rel=rel,
|
||||
follow=follow,
|
||||
visited=visited,
|
||||
update_count=update_count,
|
||||
)
|
||||
visited.add(self.Meta.model_fields[related].to)
|
||||
else:
|
||||
rel = getattr(self, related)
|
||||
update_count, visited = await self._update_and_follow(
|
||||
rel=rel, follow=follow, visited=visited, update_count=update_count
|
||||
)
|
||||
visited.add(rel.__class__)
|
||||
return update_count
|
||||
|
||||
@staticmethod
|
||||
async def _update_and_follow(
|
||||
rel: T, follow: bool, visited: Set, update_count: int
|
||||
) -> Tuple[int, Set]:
|
||||
if follow and rel.__class__ not in visited:
|
||||
update_count = await rel.save_related(
|
||||
follow=follow, visited=visited, update_count=update_count
|
||||
)
|
||||
if not rel.saved:
|
||||
await rel.upsert()
|
||||
update_count += 1
|
||||
return update_count, visited
|
||||
|
||||
async def update(self: T, **kwargs: Any) -> T:
|
||||
if kwargs:
|
||||
new_values = {**self.dict(), **kwargs}
|
||||
self.from_dict(new_values)
|
||||
|
||||
if not self.pk:
|
||||
raise ModelPersistenceError(
|
||||
"You cannot update not saved model! Use save or upsert method."
|
||||
)
|
||||
|
||||
self_fields = self._extract_model_db_fields()
|
||||
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
|
||||
self_fields = self.translate_columns_to_aliases(self_fields)
|
||||
@ -193,12 +262,14 @@ class Model(NewBaseModel):
|
||||
expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname))
|
||||
|
||||
await self.Meta.database.execute(expr)
|
||||
self.set_save_status(True)
|
||||
return self
|
||||
|
||||
async def delete(self: T) -> int:
|
||||
expr = self.Meta.table.delete()
|
||||
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
|
||||
result = await self.Meta.database.execute(expr)
|
||||
self.set_save_status(False)
|
||||
return result
|
||||
|
||||
async def load(self: T) -> T:
|
||||
@ -211,4 +282,5 @@ class Model(NewBaseModel):
|
||||
kwargs = dict(row)
|
||||
kwargs = self.translate_aliases_to_columns(kwargs)
|
||||
self.from_dict(kwargs)
|
||||
self.set_save_status(True)
|
||||
return self
|
||||
|
||||
@ -12,7 +12,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from ormar.exceptions import RelationshipInstanceError
|
||||
from ormar.exceptions import ModelPersistenceError, RelationshipInstanceError
|
||||
|
||||
try:
|
||||
import orjson as json
|
||||
@ -63,8 +63,14 @@ class ModelTableProxy:
|
||||
target_field = cls.Meta.model_fields[field]
|
||||
target_pkname = target_field.to.Meta.pkname
|
||||
if isinstance(field_value, ormar.Model):
|
||||
model_dict[field] = getattr(field_value, target_pkname)
|
||||
elif field_value:
|
||||
pk_value = getattr(field_value, target_pkname)
|
||||
if not pk_value:
|
||||
raise ModelPersistenceError(
|
||||
f"You cannot save {field_value.get_name()} "
|
||||
f"model without pk set!"
|
||||
)
|
||||
model_dict[field] = pk_value
|
||||
elif field_value: # nested dict
|
||||
model_dict[field] = field_value.get(target_pkname)
|
||||
else:
|
||||
model_dict.pop(field, None)
|
||||
@ -220,6 +226,7 @@ class ModelTableProxy:
|
||||
field,
|
||||
cls.merge_two_instances(current_field, getattr(other, field)),
|
||||
)
|
||||
other.set_save_status(True)
|
||||
return other
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -123,12 +123,16 @@ class NewBaseModel(
|
||||
object.__setattr__(self, name, value)
|
||||
elif name == "pk":
|
||||
object.__setattr__(self, self.Meta.pkname, value)
|
||||
self.set_save_status(False)
|
||||
elif name in self._orm:
|
||||
model = self.Meta.model_fields[name].expand_relationship(value, self)
|
||||
if isinstance(self.__dict__.get(name), list):
|
||||
# virtual foreign key or many to many
|
||||
self.__dict__[name].append(model)
|
||||
else:
|
||||
# foreign key relation
|
||||
self.__dict__[name] = model
|
||||
self.set_save_status(False)
|
||||
else:
|
||||
value = (
|
||||
self._convert_json(name, value, "dumps")
|
||||
@ -136,6 +140,7 @@ class NewBaseModel(
|
||||
else value
|
||||
)
|
||||
super().__setattr__(name, value)
|
||||
self.set_save_status(False)
|
||||
|
||||
def __getattribute__(self, item: str) -> Any:
|
||||
if item in (
|
||||
@ -188,6 +193,10 @@ class NewBaseModel(
|
||||
def pk_column(self) -> sqlalchemy.Column:
|
||||
return self.Meta.table.primary_key.columns.values()[0]
|
||||
|
||||
@property
|
||||
def saved(self) -> bool:
|
||||
return self._orm_saved
|
||||
|
||||
@classmethod
|
||||
def pk_type(cls) -> Any:
|
||||
return cls.Meta.model_fields[cls.Meta.pkname].__type__
|
||||
@ -199,6 +208,9 @@ class NewBaseModel(
|
||||
def remove(self, name: "T") -> None:
|
||||
self._orm.remove_parent(self, name)
|
||||
|
||||
def set_save_status(self, status: bool) -> None:
|
||||
object.__setattr__(self, "_orm_saved", status)
|
||||
|
||||
@classmethod
|
||||
def get_properties(
|
||||
cls,
|
||||
@ -212,7 +224,8 @@ class NewBaseModel(
|
||||
prop
|
||||
for prop in dir(cls)
|
||||
if isinstance(getattr(cls, prop), property)
|
||||
and prop not in ("__values__", "__fields__", "fields", "pk_column")
|
||||
and prop
|
||||
not in ("__values__", "__fields__", "fields", "pk_column", "saved")
|
||||
]
|
||||
cls._props = props
|
||||
if include:
|
||||
|
||||
@ -358,8 +358,13 @@ class QuerySet:
|
||||
instance.pk = pk
|
||||
|
||||
# refresh server side defaults
|
||||
if any(
|
||||
field.server_default is not None
|
||||
for name, field in self.model.Meta.model_fields.items()
|
||||
if name not in kwargs
|
||||
):
|
||||
instance = await instance.load()
|
||||
|
||||
instance.set_save_status(True)
|
||||
return instance
|
||||
|
||||
async def bulk_create(self, objects: List["Model"]) -> None:
|
||||
@ -372,7 +377,10 @@ class QuerySet:
|
||||
expr = self.table.insert()
|
||||
await self.database.execute_many(expr, ready_objects)
|
||||
|
||||
async def bulk_update(
|
||||
for objt in objects:
|
||||
objt.set_save_status(True)
|
||||
|
||||
async def bulk_update( # noqa: CCR001
|
||||
self, objects: List["Model"], columns: List[str] = None
|
||||
) -> None:
|
||||
ready_objects = []
|
||||
@ -418,3 +426,6 @@ class QuerySet:
|
||||
# otherwise it just passes all data to values and results in unconsumed columns
|
||||
expr = str(expr)
|
||||
await self.database.execute_many(expr, ready_objects)
|
||||
|
||||
for objt in objects:
|
||||
objt.set_save_status(True)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import RelationshipInstanceError
|
||||
from ormar.exceptions import NoMatch, RelationshipInstanceError
|
||||
from ormar.relations.querysetproxy import QuerysetProxy
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
@ -54,6 +54,11 @@ class RelationProxy(list):
|
||||
return queryset
|
||||
|
||||
async def remove(self, item: "Model") -> None: # type: ignore
|
||||
if item not in self:
|
||||
raise NoMatch(
|
||||
f"Object {self._owner.get_name()} has no "
|
||||
f"{item.get_name()} with given primary key!"
|
||||
)
|
||||
super().remove(item)
|
||||
rel_name = item.resolve_relation_name(item, self._owner)
|
||||
relation = item._orm._get(rel_name)
|
||||
|
||||
@ -6,7 +6,7 @@ import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import RelationshipInstanceError
|
||||
from ormar.exceptions import ModelPersistenceError, NoMatch, RelationshipInstanceError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
@ -196,3 +196,27 @@ async def test_selecting_related_fail_without_saving(cleanup):
|
||||
post = Post(title="Hello, M2M", author=guido)
|
||||
with pytest.raises(RelationshipInstanceError):
|
||||
await post.categories.all()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adding_unsaved_related(cleanup):
|
||||
async with database:
|
||||
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||
news = Category(name="News")
|
||||
with pytest.raises(ModelPersistenceError):
|
||||
await post.categories.add(news)
|
||||
|
||||
await news.save()
|
||||
await post.categories.add(news)
|
||||
assert len(await post.categories.all()) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removing_unsaved_related(cleanup):
|
||||
async with database:
|
||||
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||
news = Category(name="News")
|
||||
with pytest.raises(NoMatch):
|
||||
await post.categories.remove(news)
|
||||
|
||||
210
tests/test_save_related.py
Normal file
210
tests/test_save_related.py
Normal file
@ -0,0 +1,210 @@
|
||||
from typing import List
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class CringeLevel(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "levels"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100)
|
||||
|
||||
|
||||
class NickNames(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "nicks"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
|
||||
is_lame: bool = ormar.Boolean(nullable=True)
|
||||
level: CringeLevel = ormar.ForeignKey(CringeLevel)
|
||||
|
||||
|
||||
class NicksHq(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "nicks_x_hq"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
|
||||
class HQ(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "hqs"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
|
||||
nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq)
|
||||
|
||||
|
||||
class Company(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "companies"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
|
||||
founded: int = ormar.Integer(nullable=True)
|
||||
hq: HQ = ormar.ForeignKey(HQ, related_name="companies")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.drop_all(engine)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saving_related_fk_rel():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
hq = await HQ.objects.create(name="Main")
|
||||
comp = await Company.objects.create(name="Banzai", founded=1988, hq=hq)
|
||||
assert comp.saved
|
||||
|
||||
count = await comp.save_related()
|
||||
assert count == 0
|
||||
|
||||
comp.hq.name = "Suburbs"
|
||||
assert not comp.hq.saved
|
||||
assert comp.saved
|
||||
|
||||
count = await comp.save_related()
|
||||
assert count == 1
|
||||
assert comp.hq.saved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saving_many_to_many():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
nick1 = await NickNames.objects.create(name="BazingaO", is_lame=False)
|
||||
nick2 = await NickNames.objects.create(name="Bazinga20", is_lame=True)
|
||||
|
||||
hq = await HQ.objects.create(name="Main")
|
||||
assert hq.saved
|
||||
|
||||
await hq.nicks.add(nick1)
|
||||
assert hq.saved
|
||||
await hq.nicks.add(nick2)
|
||||
assert hq.saved
|
||||
|
||||
count = await hq.save_related()
|
||||
assert count == 0
|
||||
|
||||
hq.nicks[0].name = "Kabucha"
|
||||
hq.nicks[1].name = "Kabucha2"
|
||||
assert not hq.nicks[0].saved
|
||||
assert not hq.nicks[1].saved
|
||||
|
||||
count = await hq.save_related()
|
||||
assert count == 2
|
||||
assert hq.nicks[0].saved
|
||||
assert hq.nicks[1].saved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saving_reversed_relation():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
hq = await HQ.objects.create(name="Main")
|
||||
await Company.objects.create(name="Banzai", founded=1988, hq=hq)
|
||||
|
||||
hq = await HQ.objects.select_related("companies").get(name="Main")
|
||||
assert hq.saved
|
||||
assert hq.companies[0].saved
|
||||
|
||||
hq.companies[0].name = "Konichiwa"
|
||||
assert not hq.companies[0].saved
|
||||
count = await hq.save_related()
|
||||
assert count == 1
|
||||
assert hq.companies[0].saved
|
||||
|
||||
await Company.objects.create(name="Joshua", founded=1888, hq=hq)
|
||||
|
||||
hq = await HQ.objects.select_related("companies").get(name="Main")
|
||||
assert hq.saved
|
||||
assert hq.companies[0].saved
|
||||
assert hq.companies[1].saved
|
||||
|
||||
hq.companies[0].name = hq.companies[0].name + "20"
|
||||
assert not hq.companies[0].saved
|
||||
# save only if not saved so now only one
|
||||
count = await hq.save_related()
|
||||
assert count == 1
|
||||
assert hq.companies[0].saved
|
||||
|
||||
hq.companies[0].name = hq.companies[0].name + "20"
|
||||
hq.companies[1].name = hq.companies[1].name + "30"
|
||||
assert not hq.companies[0].saved
|
||||
assert not hq.companies[1].saved
|
||||
count = await hq.save_related()
|
||||
assert count == 2
|
||||
assert hq.companies[0].saved
|
||||
assert hq.companies[1].saved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saving_nested():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
level = await CringeLevel.objects.create(name="High")
|
||||
level2 = await CringeLevel.objects.create(name="Low")
|
||||
nick1 = await NickNames.objects.create(
|
||||
name="BazingaO", is_lame=False, level=level
|
||||
)
|
||||
nick2 = await NickNames.objects.create(
|
||||
name="Bazinga20", is_lame=True, level=level2
|
||||
)
|
||||
|
||||
hq = await HQ.objects.create(name="Main")
|
||||
assert hq.saved
|
||||
|
||||
await hq.nicks.add(nick1)
|
||||
assert hq.saved
|
||||
await hq.nicks.add(nick2)
|
||||
assert hq.saved
|
||||
|
||||
count = await hq.save_related()
|
||||
assert count == 0
|
||||
|
||||
hq.nicks[0].level.name = "Medium"
|
||||
assert not hq.nicks[0].level.saved
|
||||
assert hq.nicks[0].saved
|
||||
|
||||
count = await hq.save_related(follow=True)
|
||||
assert count == 1
|
||||
assert hq.nicks[0].saved
|
||||
assert hq.nicks[0].level.saved
|
||||
|
||||
hq.nicks[0].level.name = "Low"
|
||||
hq.nicks[1].level.name = "Medium"
|
||||
assert not hq.nicks[0].level.saved
|
||||
assert not hq.nicks[1].level.saved
|
||||
assert hq.nicks[0].saved
|
||||
assert hq.nicks[1].saved
|
||||
|
||||
count = await hq.save_related(follow=True)
|
||||
assert count == 2
|
||||
assert hq.nicks[0].saved
|
||||
assert hq.nicks[0].level.saved
|
||||
assert hq.nicks[1].saved
|
||||
assert hq.nicks[1].level.saved
|
||||
257
tests/test_save_status.py
Normal file
257
tests/test_save_status.py
Normal file
@ -0,0 +1,257 @@
|
||||
from typing import List
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import ModelPersistenceError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class NickNames(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "nicks"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
|
||||
is_lame: bool = ormar.Boolean(nullable=True)
|
||||
|
||||
|
||||
class NicksHq(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "nicks_x_hq"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
|
||||
class HQ(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "hqs"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
|
||||
nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq)
|
||||
|
||||
|
||||
class Company(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "companies"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
|
||||
founded: int = ormar.Integer(nullable=True)
|
||||
hq: HQ = ormar.ForeignKey(HQ)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.drop_all(engine)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_instantation_false_save_true():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
comp = Company(name="Banzai", founded=1988)
|
||||
assert not comp.saved
|
||||
await comp.save()
|
||||
assert comp.saved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saved_edited_not_saved():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
comp = await Company.objects.create(name="Banzai", founded=1988)
|
||||
assert comp.saved
|
||||
comp.name = "Banzai2"
|
||||
assert not comp.saved
|
||||
|
||||
await comp.update()
|
||||
assert comp.saved
|
||||
|
||||
await comp.update(name="Banzai3")
|
||||
assert comp.saved
|
||||
|
||||
comp.pk = 999
|
||||
assert not comp.saved
|
||||
|
||||
await comp.update()
|
||||
assert comp.saved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adding_related_gets_dirty():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
hq = await HQ.objects.create(name="Main")
|
||||
comp = await Company.objects.create(name="Banzai", founded=1988)
|
||||
assert comp.saved
|
||||
|
||||
comp.hq = hq
|
||||
assert not comp.saved
|
||||
await comp.update()
|
||||
assert comp.saved
|
||||
|
||||
comp = await Company.objects.select_related("hq").get(name="Banzai")
|
||||
assert comp.saved
|
||||
|
||||
assert comp.hq.pk == hq.pk
|
||||
assert comp.hq.saved
|
||||
|
||||
comp.hq.name = "Suburbs"
|
||||
assert not comp.hq.saved
|
||||
assert comp.saved
|
||||
|
||||
await comp.hq.update()
|
||||
assert comp.hq.saved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adding_many_to_many_does_not_gets_dirty():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
nick1 = await NickNames.objects.create(name="Bazinga", is_lame=False)
|
||||
nick2 = await NickNames.objects.create(name="Bazinga2", is_lame=True)
|
||||
|
||||
hq = await HQ.objects.create(name="Main")
|
||||
assert hq.saved
|
||||
|
||||
await hq.nicks.add(nick1)
|
||||
assert hq.saved
|
||||
await hq.nicks.add(nick2)
|
||||
assert hq.saved
|
||||
|
||||
hq = await HQ.objects.select_related("nicks").get(name="Main")
|
||||
assert hq.saved
|
||||
assert hq.nicks[0].saved
|
||||
|
||||
await hq.nicks.remove(nick1)
|
||||
assert hq.saved
|
||||
|
||||
hq.nicks[0].name = "Kabucha"
|
||||
assert not hq.nicks[0].saved
|
||||
|
||||
await hq.nicks[0].update()
|
||||
assert hq.nicks[0].saved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
comp = await Company.objects.create(name="Banzai", founded=1988)
|
||||
assert comp.saved
|
||||
await comp.delete()
|
||||
assert not comp.saved
|
||||
|
||||
await comp.update()
|
||||
assert comp.saved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
comp = await Company.objects.create(name="Banzai", founded=1988)
|
||||
assert comp.saved
|
||||
comp.name = "AA"
|
||||
assert not comp.saved
|
||||
|
||||
await comp.load()
|
||||
assert comp.saved
|
||||
assert comp.name == "Banzai"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queryset_methods():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
await Company.objects.create(name="Banzai", founded=1988)
|
||||
await Company.objects.create(name="Yuhu", founded=1989)
|
||||
await Company.objects.create(name="Konono", founded=1990)
|
||||
await Company.objects.create(name="Sumaaa", founded=1991)
|
||||
|
||||
comp = await Company.objects.get(name="Banzai")
|
||||
assert comp.saved
|
||||
|
||||
comp = await Company.objects.first()
|
||||
assert comp.saved
|
||||
|
||||
comps = await Company.objects.all()
|
||||
assert [comp.saved for comp in comps]
|
||||
|
||||
comp2 = await Company.objects.get_or_create(name="Banzai_new", founded=2001)
|
||||
assert comp2.saved
|
||||
|
||||
comp3 = await Company.objects.get_or_create(name="Banzai", founded=1988)
|
||||
assert comp3.saved
|
||||
assert comp3.pk == comp.pk
|
||||
|
||||
update_dict = comp.dict()
|
||||
update_dict["founded"] = 2010
|
||||
comp = await Company.objects.update_or_create(**update_dict)
|
||||
assert comp.saved
|
||||
assert comp.founded == 2010
|
||||
|
||||
create_dict = {"name": "Yoko", "founded": 2005}
|
||||
comp = await Company.objects.update_or_create(**create_dict)
|
||||
assert comp.saved
|
||||
assert comp.founded == 2005
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_methods():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
c1 = Company(name="Banzai", founded=1988)
|
||||
c2 = Company(name="Yuhu", founded=1989)
|
||||
|
||||
await Company.objects.bulk_create([c1, c2])
|
||||
assert c1.saved
|
||||
assert c2.saved
|
||||
|
||||
c1, c2 = await Company.objects.all()
|
||||
c1.name = "Banzai2"
|
||||
c2.name = "Yuhu2"
|
||||
|
||||
assert not c1.saved
|
||||
assert not c2.saved
|
||||
|
||||
await Company.objects.bulk_update([c1, c2])
|
||||
assert c1.saved
|
||||
assert c2.saved
|
||||
|
||||
c3 = Company(name="Cobra", founded=2088)
|
||||
assert not c3.saved
|
||||
|
||||
with pytest.raises(ModelPersistenceError):
|
||||
await c3.update()
|
||||
|
||||
await c3.upsert()
|
||||
assert c3.saved
|
||||
|
||||
c3.name = "Python"
|
||||
assert not c3.saved
|
||||
|
||||
await c3.upsert()
|
||||
assert c3.saved
|
||||
assert c3.name == "Python"
|
||||
|
||||
await c3.upsert(founded=2077)
|
||||
assert c3.saved
|
||||
assert c3.founded == 2077
|
||||
Reference in New Issue
Block a user