From b0cf9165312db63c6ba3a7857b94642c6e486f2d Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 13 Nov 2020 13:39:19 +0100 Subject: [PATCH 01/10] add saving status and basic test for this --- ormar/models/model.py | 59 +++++---- ormar/models/modelproxy.py | 1 + ormar/models/newbasemodel.py | 56 +++++---- ormar/queryset/queryset.py | 12 +- tests/test_save_status.py | 224 +++++++++++++++++++++++++++++++++++ 5 files changed, 300 insertions(+), 52 deletions(-) create mode 100644 tests/test_save_status.py diff --git a/ormar/models/model.py b/ormar/models/model.py index 688eed0..f215ffc 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -42,13 +42,13 @@ class Model(NewBaseModel): @classmethod def from_row( # noqa CCR001 - cls: Type[T], - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - related_models: Any = None, - previous_table: str = None, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls: Type[T], + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + related_models: Any = None, + previous_table: str = None, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> Optional[T]: item: Dict[str, Any] = {} @@ -58,9 +58,9 @@ class Model(NewBaseModel): related_models = group_related_list(select_related) if ( - previous_table - and previous_table in cls.Meta.model_fields - and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) + previous_table + and previous_table in cls.Meta.model_fields + and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) ): previous_table = cls.Meta.model_fields[ previous_table @@ -90,20 +90,23 @@ class Model(NewBaseModel): exclude_fields=exclude_fields, ) - instance: Optional[T] = cls(**item) if item.get( - cls.Meta.pkname, None - ) is not None else None + if item.get(cls.Meta.pkname, None) is not None: + instance: Optional[T] = cls(**item) + instance.set_save_status(True) + else: + instance = None + return instance @classmethod def populate_nested_models_from_row( # noqa: CFQ002 - cls, - item: dict, - row: sqlalchemy.engine.ResultProxy, - related_models: Any, - previous_table: sqlalchemy.Table, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls, + item: dict, + row: sqlalchemy.engine.ResultProxy, + related_models: Any, + previous_table: sqlalchemy.Table, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> dict: for related in related_models: if isinstance(related_models, dict) and related_models[related]: @@ -137,12 +140,12 @@ class Model(NewBaseModel): @classmethod def extract_prefixed_table_columns( # noqa CCR001 - cls, - item: dict, - row: sqlalchemy.engine.result.ResultProxy, - table_prefix: str, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls, + item: dict, + row: sqlalchemy.engine.result.ResultProxy, + table_prefix: str, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> dict: # databases does not keep aliases in Record for postgres, change to raw row @@ -179,6 +182,7 @@ class Model(NewBaseModel): item_id = await self.Meta.database.execute(expr) if item_id: # postgress does not return id if it's already there setattr(self, self.Meta.pkname, item_id) + self.set_save_status(True) return self async def update(self: T, **kwargs: Any) -> T: @@ -193,12 +197,14 @@ class Model(NewBaseModel): expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) await self.Meta.database.execute(expr) + self.set_save_status(True) return self async def delete(self: T) -> int: expr = self.Meta.table.delete() expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) result = await self.Meta.database.execute(expr) + self.set_save_status(False) return result async def load(self: T) -> T: @@ -211,4 +217,5 @@ class Model(NewBaseModel): kwargs = dict(row) kwargs = self.translate_aliases_to_columns(kwargs) self.from_dict(kwargs) + self.set_save_status(True) return self diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 294d592..c457bb2 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -220,6 +220,7 @@ class ModelTableProxy: field, cls.merge_two_instances(current_field, getattr(other, field)), ) + other.set_save_status(True) return other @staticmethod diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 2a96c19..7e3daa7 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -123,12 +123,16 @@ class NewBaseModel( object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) + self.set_save_status(False) elif name in self._orm: model = self.Meta.model_fields[name].expand_relationship(value, self) if isinstance(self.__dict__.get(name), list): + # virtual foreign key or many to many self.__dict__[name].append(model) else: + # foreign key relation self.__dict__[name] = model + self.set_save_status(False) else: value = ( self._convert_json(name, value, "dumps") @@ -136,15 +140,16 @@ class NewBaseModel( else value ) super().__setattr__(name, value) + self.set_save_status(False) def __getattribute__(self, item: str) -> Any: if item in ( - "_orm_id", - "_orm_saved", - "_orm", - "__fields__", - "_related_names", - "_props", + "_orm_id", + "_orm_saved", + "_orm", + "__fields__", + "_related_names", + "_props", ): return object.__getattribute__(self, item) if item == "pk": @@ -158,7 +163,7 @@ class NewBaseModel( return super().__getattribute__(item) def _extract_related_model_instead_of_field( - self, item: str + self, item: str ) -> Optional[Union["T", Sequence["T"]]]: # alias = self.get_column_alias(item) if item in self._orm: @@ -172,9 +177,9 @@ class NewBaseModel( def __same__(self, other: "NewBaseModel") -> bool: return ( - self._orm_id == other._orm_id - or self.dict() == other.dict() - or (self.pk == other.pk and self.pk is not None) + self._orm_id == other._orm_id + or self.dict() == other.dict() + or (self.pk == other.pk and self.pk is not None) ) @classmethod @@ -199,11 +204,14 @@ class NewBaseModel( def remove(self, name: "T") -> None: self._orm.remove_parent(self, name) + def set_save_status(self, status: bool) -> None: + object.__setattr__(self, "_orm_saved", status) + @classmethod def get_properties( - cls, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + cls, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, ) -> List[str]: if isinstance(cls._props, list): props = cls._props @@ -212,7 +220,7 @@ class NewBaseModel( prop for prop in dir(cls) if isinstance(getattr(cls, prop), property) - and prop not in ("__values__", "__fields__", "fields", "pk_column") + and prop not in ("__values__", "__fields__", "fields", "pk_column") ] cls._props = props if include: @@ -222,16 +230,16 @@ class NewBaseModel( return props def dict( # noqa A003 - self, - *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - nested: bool = False + self, + *, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 448a0a9..a77a9a8 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -358,8 +358,10 @@ class QuerySet: instance.pk = pk # refresh server side defaults - instance = await instance.load() - + if any(field.server_default is not None + for name, field in self.model.Meta.model_fields.items() if name not in kwargs): + instance = await instance.load() + instance.set_save_status(True) return instance async def bulk_create(self, objects: List["Model"]) -> None: @@ -372,6 +374,9 @@ class QuerySet: expr = self.table.insert() await self.database.execute_many(expr, ready_objects) + for objt in objects: + objt.set_save_status(True) + async def bulk_update( self, objects: List["Model"], columns: List[str] = None ) -> None: @@ -418,3 +423,6 @@ class QuerySet: # otherwise it just passes all data to values and results in unconsumed columns expr = str(expr) await self.database.execute_many(expr, ready_objects) + + for objt in objects: + objt.set_save_status(True) diff --git a/tests/test_save_status.py b/tests/test_save_status.py new file mode 100644 index 0000000..baea547 --- /dev/null +++ b/tests/test_save_status.py @@ -0,0 +1,224 @@ +import itertools +from typing import Optional, List + +import databases +import pydantic +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class NickNames(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq) + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="company_name") + founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_instantation_false_save_true(): + async with database: + async with database.transaction(force_rollback=True): + comp = Company(name='Banzai', founded=1988) + assert not comp._orm_saved + await comp.save() + assert comp._orm_saved + + +@pytest.mark.asyncio +async def test_saved_edited_not_saved(): + async with database: + async with database.transaction(force_rollback=True): + comp = await Company.objects.create(name='Banzai', founded=1988) + assert comp._orm_saved + comp.name = 'Banzai2' + assert not comp._orm_saved + + await comp.update() + assert comp._orm_saved + + await comp.update(name='Banzai3') + assert comp._orm_saved + + comp.pk = 999 + assert not comp._orm_saved + + await comp.save() + assert comp._orm_saved + + +@pytest.mark.asyncio +async def test_adding_related_gets_dirty(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name='Main') + comp = await Company.objects.create(name='Banzai', founded=1988) + assert comp._orm_saved + + comp.hq = hq + assert not comp._orm_saved + await comp.update() + assert comp._orm_saved + + comp = await Company.objects.select_related('hq').get(name='Banzai') + assert comp._orm_saved + assert comp.hq.pk == hq.pk + assert comp.hq._orm_saved + + +@pytest.mark.asyncio +async def test_adding_many_to_many_does_not_gets_dirty(): + async with database: + async with database.transaction(force_rollback=True): + nick1 = await NickNames.objects.create(name='Bazinga', is_lame=False) + nick2 = await NickNames.objects.create(name='Bazinga2', is_lame=True) + + hq = await HQ.objects.create(name='Main') + assert hq._orm_saved + + await hq.nicks.add(nick1) + assert hq._orm_saved + await hq.nicks.add(nick2) + assert hq._orm_saved + + hq = await HQ.objects.select_related('nicks').get(name='Main') + assert hq._orm_saved + assert hq.nicks[0]._orm_saved + + await hq.nicks.remove(nick1) + assert hq._orm_saved + + +@pytest.mark.asyncio +async def test_delete(): + async with database: + async with database.transaction(force_rollback=True): + comp = await Company.objects.create(name='Banzai', founded=1988) + assert comp._orm_saved + await comp.delete() + assert not comp._orm_saved + + await comp.save() + assert comp._orm_saved + + +@pytest.mark.asyncio +async def test_load(): + async with database: + async with database.transaction(force_rollback=True): + comp = await Company.objects.create(name='Banzai', founded=1988) + assert comp._orm_saved + comp.name = 'AA' + assert not comp._orm_saved + + await comp.load() + assert comp._orm_saved + assert comp.name == 'Banzai' + + +@pytest.mark.asyncio +async def test_queryset_methods(): + async with database: + async with database.transaction(force_rollback=True): + await Company.objects.create(name='Banzai', founded=1988) + await Company.objects.create(name='Yuhu', founded=1989) + await Company.objects.create(name='Konono', founded=1990) + await Company.objects.create(name='Sumaaa', founded=1991) + + comp = await Company.objects.get(name='Banzai') + assert comp._orm_saved + + comp = await Company.objects.first() + assert comp._orm_saved + + comps = await Company.objects.all() + assert [comp._orm_saved for comp in comps] + + comp2 = await Company.objects.get_or_create(name='Banzai_new', founded=2001) + assert comp2._orm_saved + + comp3 = await Company.objects.get_or_create(name='Banzai', founded=1988) + assert comp3._orm_saved + assert comp3.pk == comp.pk + + update_dict = comp.dict() + update_dict['founded'] = 2010 + comp = await Company.objects.update_or_create(**update_dict) + assert comp._orm_saved + assert comp.founded == 2010 + + create_dict = {'name': "Yoko", "founded": 2005} + comp = await Company.objects.update_or_create(**create_dict) + assert comp._orm_saved + assert comp.founded == 2005 + + +@pytest.mark.asyncio +async def test_bulk_methods(): + async with database: + async with database.transaction(force_rollback=True): + c1 = Company(name='Banzai', founded=1988) + c2 = Company(name='Yuhu', founded=1989) + + await Company.objects.bulk_create([c1, c2]) + assert c1._orm_saved + assert c2._orm_saved + + c1, c2 = await Company.objects.all() + c1.name = 'Banzai2' + c2.name = 'Yuhu2' + + assert not c1._orm_saved + assert not c2._orm_saved + + await Company.objects.bulk_update([c1, c2]) + assert c1._orm_saved + assert c2._orm_saved From 1f67da3a5c1dae583f32a90ff62597347978be93 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 13 Nov 2020 16:21:12 +0100 Subject: [PATCH 02/10] add save status and tests --- ormar/models/model.py | 49 ++++++++++++++++++------------------ ormar/models/newbasemodel.py | 48 +++++++++++++++++------------------ ormar/queryset/queryset.py | 9 ++++--- 3 files changed, 55 insertions(+), 51 deletions(-) diff --git a/ormar/models/model.py b/ormar/models/model.py index f215ffc..5141604 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -42,13 +42,13 @@ class Model(NewBaseModel): @classmethod def from_row( # noqa CCR001 - cls: Type[T], - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - related_models: Any = None, - previous_table: str = None, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls: Type[T], + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + related_models: Any = None, + previous_table: str = None, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> Optional[T]: item: Dict[str, Any] = {} @@ -58,9 +58,9 @@ class Model(NewBaseModel): related_models = group_related_list(select_related) if ( - previous_table - and previous_table in cls.Meta.model_fields - and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) + previous_table + and previous_table in cls.Meta.model_fields + and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) ): previous_table = cls.Meta.model_fields[ previous_table @@ -90,8 +90,9 @@ class Model(NewBaseModel): exclude_fields=exclude_fields, ) + instance: Optional[T] = None if item.get(cls.Meta.pkname, None) is not None: - instance: Optional[T] = cls(**item) + instance = cls(**item) instance.set_save_status(True) else: instance = None @@ -100,13 +101,13 @@ class Model(NewBaseModel): @classmethod def populate_nested_models_from_row( # noqa: CFQ002 - cls, - item: dict, - row: sqlalchemy.engine.ResultProxy, - related_models: Any, - previous_table: sqlalchemy.Table, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls, + item: dict, + row: sqlalchemy.engine.ResultProxy, + related_models: Any, + previous_table: sqlalchemy.Table, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> dict: for related in related_models: if isinstance(related_models, dict) and related_models[related]: @@ -140,12 +141,12 @@ class Model(NewBaseModel): @classmethod def extract_prefixed_table_columns( # noqa CCR001 - cls, - item: dict, - row: sqlalchemy.engine.result.ResultProxy, - table_prefix: str, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls, + item: dict, + row: sqlalchemy.engine.result.ResultProxy, + table_prefix: str, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> dict: # databases does not keep aliases in Record for postgres, change to raw row diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 7e3daa7..971bba4 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -144,12 +144,12 @@ class NewBaseModel( def __getattribute__(self, item: str) -> Any: if item in ( - "_orm_id", - "_orm_saved", - "_orm", - "__fields__", - "_related_names", - "_props", + "_orm_id", + "_orm_saved", + "_orm", + "__fields__", + "_related_names", + "_props", ): return object.__getattribute__(self, item) if item == "pk": @@ -163,7 +163,7 @@ class NewBaseModel( return super().__getattribute__(item) def _extract_related_model_instead_of_field( - self, item: str + self, item: str ) -> Optional[Union["T", Sequence["T"]]]: # alias = self.get_column_alias(item) if item in self._orm: @@ -177,9 +177,9 @@ class NewBaseModel( def __same__(self, other: "NewBaseModel") -> bool: return ( - self._orm_id == other._orm_id - or self.dict() == other.dict() - or (self.pk == other.pk and self.pk is not None) + self._orm_id == other._orm_id + or self.dict() == other.dict() + or (self.pk == other.pk and self.pk is not None) ) @classmethod @@ -209,9 +209,9 @@ class NewBaseModel( @classmethod def get_properties( - cls, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + cls, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, ) -> List[str]: if isinstance(cls._props, list): props = cls._props @@ -220,7 +220,7 @@ class NewBaseModel( prop for prop in dir(cls) if isinstance(getattr(cls, prop), property) - and prop not in ("__values__", "__fields__", "fields", "pk_column") + and prop not in ("__values__", "__fields__", "fields", "pk_column") ] cls._props = props if include: @@ -230,16 +230,16 @@ class NewBaseModel( return props def dict( # noqa A003 - self, - *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - nested: bool = False + self, + *, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index a77a9a8..b8b4aa3 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -358,8 +358,11 @@ class QuerySet: instance.pk = pk # refresh server side defaults - if any(field.server_default is not None - for name, field in self.model.Meta.model_fields.items() if name not in kwargs): + if any( + field.server_default is not None + for name, field in self.model.Meta.model_fields.items() + if name not in kwargs + ): instance = await instance.load() instance.set_save_status(True) return instance @@ -377,7 +380,7 @@ class QuerySet: for objt in objects: objt.set_save_status(True) - async def bulk_update( + async def bulk_update( # noqa: CCR001 self, objects: List["Model"], columns: List[str] = None ) -> None: ready_objects = [] From e805ff16b215b46a752cd2179d4517667ecff5a0 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 14 Nov 2020 13:53:32 +0100 Subject: [PATCH 03/10] introduce upsert method on model, add tests to see if save status properly changing on nested models --- ormar/exceptions.py | 4 ++ ormar/models/model.py | 11 +++++ tests/test_save_status.py | 95 ++++++++++++++++++++++++++------------- 3 files changed, 80 insertions(+), 30 deletions(-) diff --git a/ormar/exceptions.py b/ormar/exceptions.py index 40cfd26..0800a81 100644 --- a/ormar/exceptions.py +++ b/ormar/exceptions.py @@ -24,3 +24,7 @@ class QueryDefinitionError(AsyncOrmException): class RelationshipInstanceError(AsyncOrmException): pass + + +class ModelPersistenceError(AsyncOrmException): + pass diff --git a/ormar/models/model.py b/ormar/models/model.py index 5141604..ac3f7ce 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, import sqlalchemy import ormar.queryset # noqa I100 +from ormar.exceptions import ModelPersistenceError from ormar.fields.many_to_many import ManyToManyField from ormar.models import NewBaseModel # noqa I100 from ormar.models.metaclass import ModelMeta @@ -169,6 +170,11 @@ class Model(NewBaseModel): return item + async def upsert(self: T, **kwargs: Any) -> T: + if not self.pk: + return await self.save() + return await self.update(**kwargs) + async def save(self: T) -> T: self_fields = self._extract_model_db_fields() @@ -191,6 +197,11 @@ class Model(NewBaseModel): new_values = {**self.dict(), **kwargs} self.from_dict(new_values) + if not self.pk: + raise ModelPersistenceError( + "You cannot update not saved model! Use save or upsert method." + ) + self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) self_fields = self.translate_columns_to_aliases(self_fields) diff --git a/tests/test_save_status.py b/tests/test_save_status.py index baea547..c94a171 100644 --- a/tests/test_save_status.py +++ b/tests/test_save_status.py @@ -7,6 +7,7 @@ import pytest import sqlalchemy import ormar +from ormar.exceptions import ModelPersistenceError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -67,7 +68,7 @@ def create_test_database(): async def test_instantation_false_save_true(): async with database: async with database.transaction(force_rollback=True): - comp = Company(name='Banzai', founded=1988) + comp = Company(name="Banzai", founded=1988) assert not comp._orm_saved await comp.save() assert comp._orm_saved @@ -77,21 +78,21 @@ async def test_instantation_false_save_true(): async def test_saved_edited_not_saved(): async with database: async with database.transaction(force_rollback=True): - comp = await Company.objects.create(name='Banzai', founded=1988) + comp = await Company.objects.create(name="Banzai", founded=1988) assert comp._orm_saved - comp.name = 'Banzai2' + comp.name = "Banzai2" assert not comp._orm_saved await comp.update() assert comp._orm_saved - await comp.update(name='Banzai3') + await comp.update(name="Banzai3") assert comp._orm_saved comp.pk = 999 assert not comp._orm_saved - await comp.save() + await comp.update() assert comp._orm_saved @@ -99,8 +100,8 @@ async def test_saved_edited_not_saved(): async def test_adding_related_gets_dirty(): async with database: async with database.transaction(force_rollback=True): - hq = await HQ.objects.create(name='Main') - comp = await Company.objects.create(name='Banzai', founded=1988) + hq = await HQ.objects.create(name="Main") + comp = await Company.objects.create(name="Banzai", founded=1988) assert comp._orm_saved comp.hq = hq @@ -108,20 +109,28 @@ async def test_adding_related_gets_dirty(): await comp.update() assert comp._orm_saved - comp = await Company.objects.select_related('hq').get(name='Banzai') + comp = await Company.objects.select_related("hq").get(name="Banzai") assert comp._orm_saved + assert comp.hq.pk == hq.pk assert comp.hq._orm_saved + comp.hq.name = "Suburbs" + assert not comp.hq._orm_saved + assert comp._orm_saved + + await comp.hq.update() + assert comp.hq._orm_saved + @pytest.mark.asyncio async def test_adding_many_to_many_does_not_gets_dirty(): async with database: async with database.transaction(force_rollback=True): - nick1 = await NickNames.objects.create(name='Bazinga', is_lame=False) - nick2 = await NickNames.objects.create(name='Bazinga2', is_lame=True) + nick1 = await NickNames.objects.create(name="Bazinga", is_lame=False) + nick2 = await NickNames.objects.create(name="Bazinga2", is_lame=True) - hq = await HQ.objects.create(name='Main') + hq = await HQ.objects.create(name="Main") assert hq._orm_saved await hq.nicks.add(nick1) @@ -129,24 +138,30 @@ async def test_adding_many_to_many_does_not_gets_dirty(): await hq.nicks.add(nick2) assert hq._orm_saved - hq = await HQ.objects.select_related('nicks').get(name='Main') + hq = await HQ.objects.select_related("nicks").get(name="Main") assert hq._orm_saved assert hq.nicks[0]._orm_saved await hq.nicks.remove(nick1) assert hq._orm_saved + hq.nicks[0].name = "Kabucha" + assert not hq.nicks[0]._orm_saved + + await hq.nicks[0].update() + assert hq.nicks[0]._orm_saved + @pytest.mark.asyncio async def test_delete(): async with database: async with database.transaction(force_rollback=True): - comp = await Company.objects.create(name='Banzai', founded=1988) + comp = await Company.objects.create(name="Banzai", founded=1988) assert comp._orm_saved await comp.delete() assert not comp._orm_saved - await comp.save() + await comp.update() assert comp._orm_saved @@ -154,26 +169,26 @@ async def test_delete(): async def test_load(): async with database: async with database.transaction(force_rollback=True): - comp = await Company.objects.create(name='Banzai', founded=1988) + comp = await Company.objects.create(name="Banzai", founded=1988) assert comp._orm_saved - comp.name = 'AA' + comp.name = "AA" assert not comp._orm_saved await comp.load() assert comp._orm_saved - assert comp.name == 'Banzai' + assert comp.name == "Banzai" @pytest.mark.asyncio async def test_queryset_methods(): async with database: async with database.transaction(force_rollback=True): - await Company.objects.create(name='Banzai', founded=1988) - await Company.objects.create(name='Yuhu', founded=1989) - await Company.objects.create(name='Konono', founded=1990) - await Company.objects.create(name='Sumaaa', founded=1991) + await Company.objects.create(name="Banzai", founded=1988) + await Company.objects.create(name="Yuhu", founded=1989) + await Company.objects.create(name="Konono", founded=1990) + await Company.objects.create(name="Sumaaa", founded=1991) - comp = await Company.objects.get(name='Banzai') + comp = await Company.objects.get(name="Banzai") assert comp._orm_saved comp = await Company.objects.first() @@ -182,20 +197,20 @@ async def test_queryset_methods(): comps = await Company.objects.all() assert [comp._orm_saved for comp in comps] - comp2 = await Company.objects.get_or_create(name='Banzai_new', founded=2001) + comp2 = await Company.objects.get_or_create(name="Banzai_new", founded=2001) assert comp2._orm_saved - comp3 = await Company.objects.get_or_create(name='Banzai', founded=1988) + comp3 = await Company.objects.get_or_create(name="Banzai", founded=1988) assert comp3._orm_saved assert comp3.pk == comp.pk update_dict = comp.dict() - update_dict['founded'] = 2010 + update_dict["founded"] = 2010 comp = await Company.objects.update_or_create(**update_dict) assert comp._orm_saved assert comp.founded == 2010 - create_dict = {'name': "Yoko", "founded": 2005} + create_dict = {"name": "Yoko", "founded": 2005} comp = await Company.objects.update_or_create(**create_dict) assert comp._orm_saved assert comp.founded == 2005 @@ -205,16 +220,16 @@ async def test_queryset_methods(): async def test_bulk_methods(): async with database: async with database.transaction(force_rollback=True): - c1 = Company(name='Banzai', founded=1988) - c2 = Company(name='Yuhu', founded=1989) + c1 = Company(name="Banzai", founded=1988) + c2 = Company(name="Yuhu", founded=1989) await Company.objects.bulk_create([c1, c2]) assert c1._orm_saved assert c2._orm_saved c1, c2 = await Company.objects.all() - c1.name = 'Banzai2' - c2.name = 'Yuhu2' + c1.name = "Banzai2" + c2.name = "Yuhu2" assert not c1._orm_saved assert not c2._orm_saved @@ -222,3 +237,23 @@ async def test_bulk_methods(): await Company.objects.bulk_update([c1, c2]) assert c1._orm_saved assert c2._orm_saved + + c3 = Company(name="Cobra", founded=2088) + assert not c3._orm_saved + + with pytest.raises(ModelPersistenceError): + await c3.update() + + await c3.upsert() + assert c3._orm_saved + + c3.name = "Python" + assert not c3._orm_saved + + await c3.upsert() + assert c3._orm_saved + assert c3.name == "Python" + + await c3.upsert(founded=2077) + assert c3._orm_saved + assert c3.founded == 2077 From 58a3855697c6058850f5ae3d4b9998a8d3d0236b Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 14 Nov 2020 13:57:04 +0100 Subject: [PATCH 04/10] add saved property to avoid private prop access --- ormar/models/newbasemodel.py | 7 ++- tests/test_save_status.py | 96 ++++++++++++++++++------------------ 2 files changed, 54 insertions(+), 49 deletions(-) diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 971bba4..5acf4ca 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -193,6 +193,10 @@ class NewBaseModel( def pk_column(self) -> sqlalchemy.Column: return self.Meta.table.primary_key.columns.values()[0] + @property + def saved(self) -> bool: + return self._orm_saved + @classmethod def pk_type(cls) -> Any: return cls.Meta.model_fields[cls.Meta.pkname].__type__ @@ -220,7 +224,8 @@ class NewBaseModel( prop for prop in dir(cls) if isinstance(getattr(cls, prop), property) - and prop not in ("__values__", "__fields__", "fields", "pk_column") + and prop + not in ("__values__", "__fields__", "fields", "pk_column", "saved") ] cls._props = props if include: diff --git a/tests/test_save_status.py b/tests/test_save_status.py index c94a171..ec15ccf 100644 --- a/tests/test_save_status.py +++ b/tests/test_save_status.py @@ -69,9 +69,9 @@ async def test_instantation_false_save_true(): async with database: async with database.transaction(force_rollback=True): comp = Company(name="Banzai", founded=1988) - assert not comp._orm_saved + assert not comp.saved await comp.save() - assert comp._orm_saved + assert comp.saved @pytest.mark.asyncio @@ -79,21 +79,21 @@ async def test_saved_edited_not_saved(): async with database: async with database.transaction(force_rollback=True): comp = await Company.objects.create(name="Banzai", founded=1988) - assert comp._orm_saved + assert comp.saved comp.name = "Banzai2" - assert not comp._orm_saved + assert not comp.saved await comp.update() - assert comp._orm_saved + assert comp.saved await comp.update(name="Banzai3") - assert comp._orm_saved + assert comp.saved comp.pk = 999 - assert not comp._orm_saved + assert not comp.saved await comp.update() - assert comp._orm_saved + assert comp.saved @pytest.mark.asyncio @@ -102,25 +102,25 @@ async def test_adding_related_gets_dirty(): async with database.transaction(force_rollback=True): hq = await HQ.objects.create(name="Main") comp = await Company.objects.create(name="Banzai", founded=1988) - assert comp._orm_saved + assert comp.saved comp.hq = hq - assert not comp._orm_saved + assert not comp.saved await comp.update() - assert comp._orm_saved + assert comp.saved comp = await Company.objects.select_related("hq").get(name="Banzai") - assert comp._orm_saved + assert comp.saved assert comp.hq.pk == hq.pk - assert comp.hq._orm_saved + assert comp.hq.saved comp.hq.name = "Suburbs" - assert not comp.hq._orm_saved - assert comp._orm_saved + assert not comp.hq.saved + assert comp.saved await comp.hq.update() - assert comp.hq._orm_saved + assert comp.hq.saved @pytest.mark.asyncio @@ -131,25 +131,25 @@ async def test_adding_many_to_many_does_not_gets_dirty(): nick2 = await NickNames.objects.create(name="Bazinga2", is_lame=True) hq = await HQ.objects.create(name="Main") - assert hq._orm_saved + assert hq.saved await hq.nicks.add(nick1) - assert hq._orm_saved + assert hq.saved await hq.nicks.add(nick2) - assert hq._orm_saved + assert hq.saved hq = await HQ.objects.select_related("nicks").get(name="Main") - assert hq._orm_saved - assert hq.nicks[0]._orm_saved + assert hq.saved + assert hq.nicks[0].saved await hq.nicks.remove(nick1) - assert hq._orm_saved + assert hq.saved hq.nicks[0].name = "Kabucha" - assert not hq.nicks[0]._orm_saved + assert not hq.nicks[0].saved await hq.nicks[0].update() - assert hq.nicks[0]._orm_saved + assert hq.nicks[0].saved @pytest.mark.asyncio @@ -157,12 +157,12 @@ async def test_delete(): async with database: async with database.transaction(force_rollback=True): comp = await Company.objects.create(name="Banzai", founded=1988) - assert comp._orm_saved + assert comp.saved await comp.delete() - assert not comp._orm_saved + assert not comp.saved await comp.update() - assert comp._orm_saved + assert comp.saved @pytest.mark.asyncio @@ -170,12 +170,12 @@ async def test_load(): async with database: async with database.transaction(force_rollback=True): comp = await Company.objects.create(name="Banzai", founded=1988) - assert comp._orm_saved + assert comp.saved comp.name = "AA" - assert not comp._orm_saved + assert not comp.saved await comp.load() - assert comp._orm_saved + assert comp.saved assert comp.name == "Banzai" @@ -189,30 +189,30 @@ async def test_queryset_methods(): await Company.objects.create(name="Sumaaa", founded=1991) comp = await Company.objects.get(name="Banzai") - assert comp._orm_saved + assert comp.saved comp = await Company.objects.first() - assert comp._orm_saved + assert comp.saved comps = await Company.objects.all() - assert [comp._orm_saved for comp in comps] + assert [comp.saved for comp in comps] comp2 = await Company.objects.get_or_create(name="Banzai_new", founded=2001) - assert comp2._orm_saved + assert comp2.saved comp3 = await Company.objects.get_or_create(name="Banzai", founded=1988) - assert comp3._orm_saved + assert comp3.saved assert comp3.pk == comp.pk update_dict = comp.dict() update_dict["founded"] = 2010 comp = await Company.objects.update_or_create(**update_dict) - assert comp._orm_saved + assert comp.saved assert comp.founded == 2010 create_dict = {"name": "Yoko", "founded": 2005} comp = await Company.objects.update_or_create(**create_dict) - assert comp._orm_saved + assert comp.saved assert comp.founded == 2005 @@ -224,36 +224,36 @@ async def test_bulk_methods(): c2 = Company(name="Yuhu", founded=1989) await Company.objects.bulk_create([c1, c2]) - assert c1._orm_saved - assert c2._orm_saved + assert c1.saved + assert c2.saved c1, c2 = await Company.objects.all() c1.name = "Banzai2" c2.name = "Yuhu2" - assert not c1._orm_saved - assert not c2._orm_saved + assert not c1.saved + assert not c2.saved await Company.objects.bulk_update([c1, c2]) - assert c1._orm_saved - assert c2._orm_saved + assert c1.saved + assert c2.saved c3 = Company(name="Cobra", founded=2088) - assert not c3._orm_saved + assert not c3.saved with pytest.raises(ModelPersistenceError): await c3.update() await c3.upsert() - assert c3._orm_saved + assert c3.saved c3.name = "Python" - assert not c3._orm_saved + assert not c3.saved await c3.upsert() - assert c3._orm_saved + assert c3.saved assert c3.name == "Python" await c3.upsert(founded=2077) - assert c3._orm_saved + assert c3.saved assert c3.founded == 2077 From cd33f6a96baf15bccaccb31a9285a02f61ef8bd6 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 14 Nov 2020 14:29:54 +0100 Subject: [PATCH 05/10] introduce save_related method that traverses the related objects and upserts them if they are not saved --- ormar/models/model.py | 17 +++++ tests/test_save_related.py | 151 +++++++++++++++++++++++++++++++++++++ tests/test_save_status.py | 4 +- 3 files changed, 169 insertions(+), 3 deletions(-) create mode 100644 tests/test_save_related.py diff --git a/ormar/models/model.py b/ormar/models/model.py index ac3f7ce..437ed94 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -192,6 +192,23 @@ class Model(NewBaseModel): self.set_save_status(True) return self + async def save_related(self) -> int: + update_count = 0 + for related in self.extract_related_names(): + if self.Meta.model_fields[related].virtual or issubclass( + self.Meta.model_fields[related], ManyToManyField + ): + for rel in getattr(self, related): + if not rel.saved: + await rel.upsert() + update_count += 1 + else: + rel = getattr(self, related) + if not rel.saved: + await rel.upsert() + update_count += 1 + return update_count + async def update(self: T, **kwargs: Any) -> T: if kwargs: new_values = {**self.dict(), **kwargs} diff --git a/tests/test_save_related.py b/tests/test_save_related.py new file mode 100644 index 0000000..1defabf --- /dev/null +++ b/tests/test_save_related.py @@ -0,0 +1,151 @@ +from typing import List + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class NickNames(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq) + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="company_name") + founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ, related_name="companies") + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_saving_related_fk_rel(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name="Main") + comp = await Company.objects.create(name="Banzai", founded=1988, hq=hq) + assert comp.saved + + count = await comp.save_related() + assert count == 0 + + comp.hq.name = "Suburbs" + assert not comp.hq.saved + assert comp.saved + + count = await comp.save_related() + assert count == 1 + assert comp.hq.saved + + +@pytest.mark.asyncio +async def test_adding_many_to_many_does_not_gets_dirty(): + async with database: + async with database.transaction(force_rollback=True): + nick1 = await NickNames.objects.create(name="BazingaO", is_lame=False) + nick2 = await NickNames.objects.create(name="Bazinga20", is_lame=True) + + hq = await HQ.objects.create(name="Main") + assert hq.saved + + await hq.nicks.add(nick1) + assert hq.saved + await hq.nicks.add(nick2) + assert hq.saved + + count = await hq.save_related() + assert count == 0 + + hq.nicks[0].name = "Kabucha" + hq.nicks[1].name = "Kabucha2" + assert not hq.nicks[0].saved + assert not hq.nicks[1].saved + + count = await hq.save_related() + assert count == 2 + assert hq.nicks[0].saved + assert hq.nicks[1].saved + + +@pytest.mark.asyncio +async def test_queryset_methods(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name="Main") + await Company.objects.create(name="Banzai", founded=1988, hq=hq) + + hq = await HQ.objects.select_related("companies").get(name="Main") + assert hq.saved + assert hq.companies[0].saved + + hq.companies[0].name = "Konichiwa" + assert not hq.companies[0].saved + count = await hq.save_related() + assert count == 1 + assert hq.companies[0].saved + + await Company.objects.create(name="Joshua", founded=1888, hq=hq) + + hq = await HQ.objects.select_related("companies").get(name="Main") + assert hq.saved + assert hq.companies[0].saved + assert hq.companies[1].saved + + hq.companies[0].name = hq.companies[0].name + "20" + assert not hq.companies[0].saved + # save only if not saved so now only one + count = await hq.save_related() + assert count == 1 + assert hq.companies[0].saved + + hq.companies[0].name = hq.companies[0].name + "20" + hq.companies[1].name = hq.companies[1].name + "30" + assert not hq.companies[0].saved + assert not hq.companies[1].saved + count = await hq.save_related() + assert count == 2 + assert hq.companies[0].saved + assert hq.companies[1].saved diff --git a/tests/test_save_status.py b/tests/test_save_status.py index ec15ccf..93e89ac 100644 --- a/tests/test_save_status.py +++ b/tests/test_save_status.py @@ -1,8 +1,6 @@ -import itertools -from typing import Optional, List +from typing import List import databases -import pydantic import pytest import sqlalchemy From 0f36944fe1a49a42fadce1db3e0ead74bcf6ebe8 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 14 Nov 2020 14:47:33 +0100 Subject: [PATCH 06/10] add safe fails for adding and removing not saved models to many to many rel, add tests for save_related --- ormar/models/model.py | 2 +- ormar/models/modelproxy.py | 12 +++++++++--- ormar/relations/relation_proxy.py | 7 ++++++- tests/test_many_to_many.py | 26 +++++++++++++++++++++++++- 4 files changed, 41 insertions(+), 6 deletions(-) diff --git a/ormar/models/model.py b/ormar/models/model.py index 437ed94..80c9807 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -192,7 +192,7 @@ class Model(NewBaseModel): self.set_save_status(True) return self - async def save_related(self) -> int: + async def save_related(self) -> int: # noqa: CCR001 update_count = 0 for related in self.extract_related_names(): if self.Meta.model_fields[related].virtual or issubclass( diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index c457bb2..3835b41 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -12,7 +12,7 @@ from typing import ( Union, ) -from ormar.exceptions import RelationshipInstanceError +from ormar.exceptions import ModelPersistenceError, RelationshipInstanceError try: import orjson as json @@ -63,8 +63,14 @@ class ModelTableProxy: target_field = cls.Meta.model_fields[field] target_pkname = target_field.to.Meta.pkname if isinstance(field_value, ormar.Model): - model_dict[field] = getattr(field_value, target_pkname) - elif field_value: + pk_value = getattr(field_value, target_pkname) + if not pk_value: + raise ModelPersistenceError( + f"You cannot save {field_value.get_name()} " + f"model without pk set!" + ) + model_dict[field] = pk_value + elif field_value: # nested dict model_dict[field] = field_value.get(target_pkname) else: model_dict.pop(field, None) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 88130d5..f03252e 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -1,7 +1,7 @@ from typing import Any, TYPE_CHECKING import ormar -from ormar.exceptions import RelationshipInstanceError +from ormar.exceptions import NoMatch, RelationshipInstanceError from ormar.relations.querysetproxy import QuerysetProxy if TYPE_CHECKING: # pragma no cover @@ -54,6 +54,11 @@ class RelationProxy(list): return queryset async def remove(self, item: "Model") -> None: # type: ignore + if item not in self: + raise NoMatch( + f"Object {self._owner.get_name()} has no " + f"{item.get_name()} with given primary key!" + ) super().remove(item) rel_name = item.resolve_relation_name(item, self._owner) relation = item._orm._get(rel_name) diff --git a/tests/test_many_to_many.py b/tests/test_many_to_many.py index 3ddeb61..8d8b258 100644 --- a/tests/test_many_to_many.py +++ b/tests/test_many_to_many.py @@ -6,7 +6,7 @@ import pytest import sqlalchemy import ormar -from ormar.exceptions import RelationshipInstanceError +from ormar.exceptions import ModelPersistenceError, NoMatch, RelationshipInstanceError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -196,3 +196,27 @@ async def test_selecting_related_fail_without_saving(cleanup): post = Post(title="Hello, M2M", author=guido) with pytest.raises(RelationshipInstanceError): await post.categories.all() + + +@pytest.mark.asyncio +async def test_adding_unsaved_related(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = Category(name="News") + with pytest.raises(ModelPersistenceError): + await post.categories.add(news) + + await news.save() + await post.categories.add(news) + assert len(await post.categories.all()) == 1 + + +@pytest.mark.asyncio +async def test_removing_unsaved_related(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = Category(name="News") + with pytest.raises(NoMatch): + await post.categories.remove(news) From d478ea6e15a2b2dd689743d26afc5ef3ba2c1b05 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 15 Nov 2020 10:33:03 +0100 Subject: [PATCH 07/10] add follow=True for save_related, update docs --- docs/models.md | 98 +++++++++++++++++++++++++++++++++++--- ormar/models/model.py | 54 +++++++++++++++++---- tests/test_save_related.py | 59 ++++++++++++++++++++++- 3 files changed, 193 insertions(+), 18 deletions(-) diff --git a/docs/models.md b/docs/models.md index d7d3393..fc29f6e 100644 --- a/docs/models.md +++ b/docs/models.md @@ -228,6 +228,33 @@ Each model has a `QuerySet` initialised as `objects` parameter !!!info To read more about `QuerySets` (including bulk operations) and available methods visit [queries][queries] +## `Model` save status + +Each model instance is a separate python object and they do not know anything about each other. + +```python +track1 = await Track.objects.get(name='The Bird') +track2 = await Track.objects.get(name='The Bird') +assert track1 == track2 # True + +track1.name = 'The Bird2' +await track1.save() +assert track1.name == track2.name # False +# track2 does not update and knows nothing about track1 +``` + +The objects itself have a saved status, which is set as following: + +* Model is saved after `save/update/load/upsert` method on model +* Model is saved after `create/get/first/all/get_or_create/update_or_create` method +* Model is saved when passed to `bulk_update` and `bulk_create` +* Model is saved after `adding/removing` `ManyToMany` related objects (through model instance auto saved/deleted) +* Model is **not** saved after change of any own field (including `pk` as `Model.pk` alias) +* Model is **not** saved after adding/removing `ForeignKey` related object (fk column not saved) +* Model is **not** saved after instantiation with `__init__` (w/o `QuerySet.create` or before calling `save`) + +You can check if model is saved with `ModelInstance.saved` property + ## `Model` methods ### load @@ -249,16 +276,57 @@ track.album.name # will return 'Malibu' ### save +`save() -> self` + You can create new models by using `QuerySet.create()` method or by initializing your model as a normal pydantic model and later calling `save()` method. -`save()` can also be used to persist changes that you made to the model. +`save()` can also be used to persist changes that you made to the model, but only if the primary key is not set or the model does not exist in database. + +The `save()` method does not check if the model exists in db, so if it does you will get a integrity error from your selected db backend if trying to save model with already existing primary key. ```python track = Track(name='The Bird') await track.save() # will persist the model in database + +track = await Track.objects.get(name='The Bird') +await track.save() # will raise integrity error as pk is populated ``` +### update + +`update(**kwargs) -> self` + +You can update models by using `QuerySet.update()` method or by updating your model attributes (fields) and calling `update()` method. + +If you try to update a model without a primary key set a `ModelPersistenceError` exception will be thrown. + +To persist a newly created model use `save()` or `upsert(**kwargs)` methods. + +```python +track = await Track.objects.get(name='The Bird') +await track.update(name='The Bird Strikes Again') +``` + +### upsert + +`upsert(**kwargs) -> self` + +It's an proxy to either `save()` or `update(**kwargs)` methods described above. + +If the primary key is set -> the `update` method will be called. + +If the pk is not set the `save()` method will be called. + +```python +track = Track(name='The Bird') +await track.upsert() # will call save as the pk is empty + +track = await Track.objects.get(name='The Bird') +await track.upsert(name='The Bird Strikes Again') # will call update as pk is already populated +``` + + ### delete You can delete models by using `QuerySet.delete()` method or by using your model and calling `delete()` method. @@ -271,14 +339,29 @@ await track.delete() # will delete the model from database !!!tip Note that that `track` object stays the same, only record in the database is removed. -### update +### save_related -You can delete models by using `QuerySet.update()` method or by using your model and calling `update()` method. +`save_related(follow: bool = False) -> None` -```python -track = await Track.objects.get(name='The Bird') -await track.update(name='The Bird Strikes Again') -``` +Method goes through all relations of the `Model` on which the method is called, +and calls `upsert()` method on each model that is **not** saved. + +To understand when a model is saved check [save status][save status] section above. + +By default the `save_related` method saved only models that are directly related (one step away) to the model on which the method is called. + +But you can specify the `follow=True` parameter to traverse through nested models and save all of them in the relation tree. + +!!!warning + To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models, + and won't perform nested `save_related` on Models that were already visited. + + So if you have a diamond or circular relations types you need to perform the updates in a manual way. + + ```python + # in example like this the second Street (coming from Company) won't be save_related, so Whatever won't be updated + Street -> District -> City -> Companies -> Street -> Whatever + ``` ## Internals @@ -348,3 +431,4 @@ For example to list table model fields you can: [sqlalchemy connection string]: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls [sqlalchemy table creation]: https://docs.sqlalchemy.org/en/13/core/metadata.html#creating-and-dropping-database-tables [alembic]: https://alembic.sqlalchemy.org/en/latest/tutorial.html +[save status]: ../models/#model-save-status diff --git a/ormar/models/model.py b/ormar/models/model.py index 80c9807..aeb63c3 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,5 +1,16 @@ import itertools -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union +from typing import ( + Any, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Union, +) import sqlalchemy @@ -192,23 +203,48 @@ class Model(NewBaseModel): self.set_save_status(True) return self - async def save_related(self) -> int: # noqa: CCR001 - update_count = 0 + async def save_related( # noqa: CCR001 + self, follow: bool = False, visited: Set = None, update_count: int = 0 + ) -> int: # noqa: CCR001 + if not visited: + visited = {self.__class__} + else: + visited = {x for x in visited} + visited.add(self.__class__) + for related in self.extract_related_names(): if self.Meta.model_fields[related].virtual or issubclass( self.Meta.model_fields[related], ManyToManyField ): for rel in getattr(self, related): - if not rel.saved: - await rel.upsert() - update_count += 1 + update_count, visited = await self._update_and_follow( + rel=rel, + follow=follow, + visited=visited, + update_count=update_count, + ) + visited.add(self.Meta.model_fields[related].to) else: rel = getattr(self, related) - if not rel.saved: - await rel.upsert() - update_count += 1 + update_count, visited = await self._update_and_follow( + rel=rel, follow=follow, visited=visited, update_count=update_count + ) + visited.add(rel.__class__) return update_count + @staticmethod + async def _update_and_follow( + rel: "Model", follow: bool, visited: Set, update_count: int + ) -> Tuple[int, Set]: + if follow and rel.__class__ not in visited: + update_count = await rel.save_related( + follow=follow, visited=visited, update_count=update_count + ) + if not rel.saved: + await rel.upsert() + update_count += 1 + return update_count, visited + async def update(self: T, **kwargs: Any) -> T: if kwargs: new_values = {**self.dict(), **kwargs} diff --git a/tests/test_save_related.py b/tests/test_save_related.py index 1defabf..6cfe108 100644 --- a/tests/test_save_related.py +++ b/tests/test_save_related.py @@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class CringeLevel(ormar.Model): + class Meta: + tablename = "levels" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + class NickNames(ormar.Model): class Meta: tablename = "nicks" @@ -20,6 +30,7 @@ class NickNames(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") is_lame: bool = ormar.Boolean(nullable=True) + level: CringeLevel = ormar.ForeignKey(CringeLevel) class NicksHq(ormar.Model): @@ -82,7 +93,7 @@ async def test_saving_related_fk_rel(): @pytest.mark.asyncio -async def test_adding_many_to_many_does_not_gets_dirty(): +async def test_saving_many_to_many(): async with database: async with database.transaction(force_rollback=True): nick1 = await NickNames.objects.create(name="BazingaO", is_lame=False) @@ -111,7 +122,7 @@ async def test_adding_many_to_many_does_not_gets_dirty(): @pytest.mark.asyncio -async def test_queryset_methods(): +async def test_saving_reversed_relation(): async with database: async with database.transaction(force_rollback=True): hq = await HQ.objects.create(name="Main") @@ -149,3 +160,47 @@ async def test_queryset_methods(): assert count == 2 assert hq.companies[0].saved assert hq.companies[1].saved + + +@pytest.mark.asyncio +async def test_saving_nested(): + async with database: + async with database.transaction(force_rollback=True): + level = await CringeLevel.objects.create(name='High') + level2 = await CringeLevel.objects.create(name='Low') + nick1 = await NickNames.objects.create(name="BazingaO", is_lame=False, level=level) + nick2 = await NickNames.objects.create(name="Bazinga20", is_lame=True, level=level2) + + hq = await HQ.objects.create(name="Main") + assert hq.saved + + await hq.nicks.add(nick1) + assert hq.saved + await hq.nicks.add(nick2) + assert hq.saved + + count = await hq.save_related() + assert count == 0 + + hq.nicks[0].level.name = "Medium" + assert not hq.nicks[0].level.saved + assert hq.nicks[0].saved + + count = await hq.save_related(follow=True) + assert count == 1 + assert hq.nicks[0].saved + assert hq.nicks[0].level.saved + + hq.nicks[0].level.name = "Low" + hq.nicks[1].level.name = "Medium" + assert not hq.nicks[0].level.saved + assert not hq.nicks[1].level.saved + assert hq.nicks[0].saved + assert hq.nicks[1].saved + + count = await hq.save_related(follow=True) + assert count == 2 + assert hq.nicks[0].saved + assert hq.nicks[0].level.saved + assert hq.nicks[1].saved + assert hq.nicks[1].level.saved From 1168159a701b19c2591bad742d4c022df6bd3094 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 16 Nov 2020 13:10:03 +0100 Subject: [PATCH 08/10] bump ver, some cleanup --- docs/models.md | 4 ++-- ormar/__init__.py | 2 +- ormar/models/model.py | 2 +- tests/test_save_related.py | 20 ++++++++++++++++---- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/docs/models.md b/docs/models.md index fc29f6e..dfc8324 100644 --- a/docs/models.md +++ b/docs/models.md @@ -359,8 +359,8 @@ But you can specify the `follow=True` parameter to traverse through nested model So if you have a diamond or circular relations types you need to perform the updates in a manual way. ```python - # in example like this the second Street (coming from Company) won't be save_related, so Whatever won't be updated - Street -> District -> City -> Companies -> Street -> Whatever + # in example like this the second Street (coming from City) won't be save_related, so ZipCode won't be updated + Street -> District -> City -> Street -> ZipCode ``` ## Internals diff --git a/ormar/__init__.py b/ormar/__init__.py index 03476f4..7f9e9e2 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.4.4" +__version__ = "0.5.0" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/models/model.py b/ormar/models/model.py index aeb63c3..fb5c295 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -234,7 +234,7 @@ class Model(NewBaseModel): @staticmethod async def _update_and_follow( - rel: "Model", follow: bool, visited: Set, update_count: int + rel: T, follow: bool, visited: Set, update_count: int ) -> Tuple[int, Set]: if follow and rel.__class__ not in visited: update_count = await rel.save_related( diff --git a/tests/test_save_related.py b/tests/test_save_related.py index 6cfe108..967ef1c 100644 --- a/tests/test_save_related.py +++ b/tests/test_save_related.py @@ -161,15 +161,27 @@ async def test_saving_reversed_relation(): assert hq.companies[0].saved assert hq.companies[1].saved + hq = await HQ.objects.select_related( + ["companies", "companies__hq__nicks"] + ).get(name="Main") + hq.companies[0].hq.nicks[0].name = "Sub" + assert not hq.companies[0].hq.nicks[0].saved + await hq.save_related(follow=True) + assert not hq.companies[0].hq.nicks[0].saved + @pytest.mark.asyncio async def test_saving_nested(): async with database: async with database.transaction(force_rollback=True): - level = await CringeLevel.objects.create(name='High') - level2 = await CringeLevel.objects.create(name='Low') - nick1 = await NickNames.objects.create(name="BazingaO", is_lame=False, level=level) - nick2 = await NickNames.objects.create(name="Bazinga20", is_lame=True, level=level2) + level = await CringeLevel.objects.create(name="High") + level2 = await CringeLevel.objects.create(name="Low") + nick1 = await NickNames.objects.create( + name="BazingaO", is_lame=False, level=level + ) + nick2 = await NickNames.objects.create( + name="Bazinga20", is_lame=True, level=level2 + ) hq = await HQ.objects.create(name="Main") assert hq.saved From 5e1f8ddecd2bc7d4d2170524b3f80c79ff5573b0 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 16 Nov 2020 13:14:47 +0100 Subject: [PATCH 09/10] bump ver, some cleanup --- tests/test_save_related.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_save_related.py b/tests/test_save_related.py index 967ef1c..0696018 100644 --- a/tests/test_save_related.py +++ b/tests/test_save_related.py @@ -161,14 +161,6 @@ async def test_saving_reversed_relation(): assert hq.companies[0].saved assert hq.companies[1].saved - hq = await HQ.objects.select_related( - ["companies", "companies__hq__nicks"] - ).get(name="Main") - hq.companies[0].hq.nicks[0].name = "Sub" - assert not hq.companies[0].hq.nicks[0].saved - await hq.save_related(follow=True) - assert not hq.companies[0].hq.nicks[0].saved - @pytest.mark.asyncio async def test_saving_nested(): From 8bc2a0358d332071234a89827cbb1f2623ba589f Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 16 Nov 2020 13:20:18 +0100 Subject: [PATCH 10/10] update relaease docs --- docs/releases.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/docs/releases.md b/docs/releases.md index c4fd384..e636ceb 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,24 @@ +# 0.5.0 + +* Added save status -> you can check if model is saved with `ModelInstance.saved` property + * Model is saved after `save/update/load/upsert` method on model + * Model is saved after `create/get/first/all/get_or_create/update_or_create` method + * Model is saved when passed to `bulk_update` and `bulk_create` + * Model is saved after adding/removing `ManyToMany` related objects (through model instance auto saved/deleted) + * Model is **not** saved after change of any own field (including pk as `Model.pk` alias) + * Model is **not** saved after adding/removing `ForeignKey` related object (fk column not saved) + * Model is **not** saved after instantation with `__init__` (w/o `QuerySet.create` or before calling `save`) +* Added `Model.upsert(**kwargs)` that performs `save()` if pk not set otherwise `update(**kwargs)` +* Added `Model.save_related(follow=False)` that iterates all related objects in all relations and checks if they are saved. If not it calls `upsert()` on each of them. +* **Breaking:** added raising exceptions if `add`-ing/`remove`-ing not saved (pk is None) models to `ManyToMany` relation +* Allow passing dictionaries and sets to fields and exclude_fields +* Auto translate str and lists to dicts for fields and exclude_fields +* **Breaking:** passing nested models to fields and exclude_fields is now by related ForeignKey name and not by target model name +* Performance optimizations - in modelproxy, newbasemodel - > less queries, some properties are cached on models +* Cleanup of unused relations code +* Optional performance dependency orjson added (**strongly recommended**) +* Updated docs + # 0.4.4 * add exclude_fields() method to exclude fields from sql