From b0cf9165312db63c6ba3a7857b94642c6e486f2d Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 13 Nov 2020 13:39:19 +0100 Subject: [PATCH] add saving status and basic test for this --- ormar/models/model.py | 59 +++++---- ormar/models/modelproxy.py | 1 + ormar/models/newbasemodel.py | 56 +++++---- ormar/queryset/queryset.py | 12 +- tests/test_save_status.py | 224 +++++++++++++++++++++++++++++++++++ 5 files changed, 300 insertions(+), 52 deletions(-) create mode 100644 tests/test_save_status.py diff --git a/ormar/models/model.py b/ormar/models/model.py index 688eed0..f215ffc 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -42,13 +42,13 @@ class Model(NewBaseModel): @classmethod def from_row( # noqa CCR001 - cls: Type[T], - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - related_models: Any = None, - previous_table: str = None, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls: Type[T], + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + related_models: Any = None, + previous_table: str = None, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> Optional[T]: item: Dict[str, Any] = {} @@ -58,9 +58,9 @@ class Model(NewBaseModel): related_models = group_related_list(select_related) if ( - previous_table - and previous_table in cls.Meta.model_fields - and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) + previous_table + and previous_table in cls.Meta.model_fields + and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) ): previous_table = cls.Meta.model_fields[ previous_table @@ -90,20 +90,23 @@ class Model(NewBaseModel): exclude_fields=exclude_fields, ) - instance: Optional[T] = cls(**item) if item.get( - cls.Meta.pkname, None - ) is not None else None + if item.get(cls.Meta.pkname, None) is not None: + instance: Optional[T] = cls(**item) + instance.set_save_status(True) + else: + instance = None + return instance @classmethod def populate_nested_models_from_row( # noqa: CFQ002 - cls, - item: dict, - row: sqlalchemy.engine.ResultProxy, - related_models: Any, - previous_table: sqlalchemy.Table, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls, + item: dict, + row: sqlalchemy.engine.ResultProxy, + related_models: Any, + previous_table: sqlalchemy.Table, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> dict: for related in related_models: if isinstance(related_models, dict) and related_models[related]: @@ -137,12 +140,12 @@ class Model(NewBaseModel): @classmethod def extract_prefixed_table_columns( # noqa CCR001 - cls, - item: dict, - row: sqlalchemy.engine.result.ResultProxy, - table_prefix: str, - fields: Optional[Union[Dict, Set]] = None, - exclude_fields: Optional[Union[Dict, Set]] = None, + cls, + item: dict, + row: sqlalchemy.engine.result.ResultProxy, + table_prefix: str, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> dict: # databases does not keep aliases in Record for postgres, change to raw row @@ -179,6 +182,7 @@ class Model(NewBaseModel): item_id = await self.Meta.database.execute(expr) if item_id: # postgress does not return id if it's already there setattr(self, self.Meta.pkname, item_id) + self.set_save_status(True) return self async def update(self: T, **kwargs: Any) -> T: @@ -193,12 +197,14 @@ class Model(NewBaseModel): expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) await self.Meta.database.execute(expr) + self.set_save_status(True) return self async def delete(self: T) -> int: expr = self.Meta.table.delete() expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) result = await self.Meta.database.execute(expr) + self.set_save_status(False) return result async def load(self: T) -> T: @@ -211,4 +217,5 @@ class Model(NewBaseModel): kwargs = dict(row) kwargs = self.translate_aliases_to_columns(kwargs) self.from_dict(kwargs) + self.set_save_status(True) return self diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 294d592..c457bb2 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -220,6 +220,7 @@ class ModelTableProxy: field, cls.merge_two_instances(current_field, getattr(other, field)), ) + other.set_save_status(True) return other @staticmethod diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 2a96c19..7e3daa7 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -123,12 +123,16 @@ class NewBaseModel( object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) + self.set_save_status(False) elif name in self._orm: model = self.Meta.model_fields[name].expand_relationship(value, self) if isinstance(self.__dict__.get(name), list): + # virtual foreign key or many to many self.__dict__[name].append(model) else: + # foreign key relation self.__dict__[name] = model + self.set_save_status(False) else: value = ( self._convert_json(name, value, "dumps") @@ -136,15 +140,16 @@ class NewBaseModel( else value ) super().__setattr__(name, value) + self.set_save_status(False) def __getattribute__(self, item: str) -> Any: if item in ( - "_orm_id", - "_orm_saved", - "_orm", - "__fields__", - "_related_names", - "_props", + "_orm_id", + "_orm_saved", + "_orm", + "__fields__", + "_related_names", + "_props", ): return object.__getattribute__(self, item) if item == "pk": @@ -158,7 +163,7 @@ class NewBaseModel( return super().__getattribute__(item) def _extract_related_model_instead_of_field( - self, item: str + self, item: str ) -> Optional[Union["T", Sequence["T"]]]: # alias = self.get_column_alias(item) if item in self._orm: @@ -172,9 +177,9 @@ class NewBaseModel( def __same__(self, other: "NewBaseModel") -> bool: return ( - self._orm_id == other._orm_id - or self.dict() == other.dict() - or (self.pk == other.pk and self.pk is not None) + self._orm_id == other._orm_id + or self.dict() == other.dict() + or (self.pk == other.pk and self.pk is not None) ) @classmethod @@ -199,11 +204,14 @@ class NewBaseModel( def remove(self, name: "T") -> None: self._orm.remove_parent(self, name) + def set_save_status(self, status: bool) -> None: + object.__setattr__(self, "_orm_saved", status) + @classmethod def get_properties( - cls, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + cls, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, ) -> List[str]: if isinstance(cls._props, list): props = cls._props @@ -212,7 +220,7 @@ class NewBaseModel( prop for prop in dir(cls) if isinstance(getattr(cls, prop), property) - and prop not in ("__values__", "__fields__", "fields", "pk_column") + and prop not in ("__values__", "__fields__", "fields", "pk_column") ] cls._props = props if include: @@ -222,16 +230,16 @@ class NewBaseModel( return props def dict( # noqa A003 - self, - *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - nested: bool = False + self, + *, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 448a0a9..a77a9a8 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -358,8 +358,10 @@ class QuerySet: instance.pk = pk # refresh server side defaults - instance = await instance.load() - + if any(field.server_default is not None + for name, field in self.model.Meta.model_fields.items() if name not in kwargs): + instance = await instance.load() + instance.set_save_status(True) return instance async def bulk_create(self, objects: List["Model"]) -> None: @@ -372,6 +374,9 @@ class QuerySet: expr = self.table.insert() await self.database.execute_many(expr, ready_objects) + for objt in objects: + objt.set_save_status(True) + async def bulk_update( self, objects: List["Model"], columns: List[str] = None ) -> None: @@ -418,3 +423,6 @@ class QuerySet: # otherwise it just passes all data to values and results in unconsumed columns expr = str(expr) await self.database.execute_many(expr, ready_objects) + + for objt in objects: + objt.set_save_status(True) diff --git a/tests/test_save_status.py b/tests/test_save_status.py new file mode 100644 index 0000000..baea547 --- /dev/null +++ b/tests/test_save_status.py @@ -0,0 +1,224 @@ +import itertools +from typing import Optional, List + +import databases +import pydantic +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class NickNames(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq) + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="company_name") + founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_instantation_false_save_true(): + async with database: + async with database.transaction(force_rollback=True): + comp = Company(name='Banzai', founded=1988) + assert not comp._orm_saved + await comp.save() + assert comp._orm_saved + + +@pytest.mark.asyncio +async def test_saved_edited_not_saved(): + async with database: + async with database.transaction(force_rollback=True): + comp = await Company.objects.create(name='Banzai', founded=1988) + assert comp._orm_saved + comp.name = 'Banzai2' + assert not comp._orm_saved + + await comp.update() + assert comp._orm_saved + + await comp.update(name='Banzai3') + assert comp._orm_saved + + comp.pk = 999 + assert not comp._orm_saved + + await comp.save() + assert comp._orm_saved + + +@pytest.mark.asyncio +async def test_adding_related_gets_dirty(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name='Main') + comp = await Company.objects.create(name='Banzai', founded=1988) + assert comp._orm_saved + + comp.hq = hq + assert not comp._orm_saved + await comp.update() + assert comp._orm_saved + + comp = await Company.objects.select_related('hq').get(name='Banzai') + assert comp._orm_saved + assert comp.hq.pk == hq.pk + assert comp.hq._orm_saved + + +@pytest.mark.asyncio +async def test_adding_many_to_many_does_not_gets_dirty(): + async with database: + async with database.transaction(force_rollback=True): + nick1 = await NickNames.objects.create(name='Bazinga', is_lame=False) + nick2 = await NickNames.objects.create(name='Bazinga2', is_lame=True) + + hq = await HQ.objects.create(name='Main') + assert hq._orm_saved + + await hq.nicks.add(nick1) + assert hq._orm_saved + await hq.nicks.add(nick2) + assert hq._orm_saved + + hq = await HQ.objects.select_related('nicks').get(name='Main') + assert hq._orm_saved + assert hq.nicks[0]._orm_saved + + await hq.nicks.remove(nick1) + assert hq._orm_saved + + +@pytest.mark.asyncio +async def test_delete(): + async with database: + async with database.transaction(force_rollback=True): + comp = await Company.objects.create(name='Banzai', founded=1988) + assert comp._orm_saved + await comp.delete() + assert not comp._orm_saved + + await comp.save() + assert comp._orm_saved + + +@pytest.mark.asyncio +async def test_load(): + async with database: + async with database.transaction(force_rollback=True): + comp = await Company.objects.create(name='Banzai', founded=1988) + assert comp._orm_saved + comp.name = 'AA' + assert not comp._orm_saved + + await comp.load() + assert comp._orm_saved + assert comp.name == 'Banzai' + + +@pytest.mark.asyncio +async def test_queryset_methods(): + async with database: + async with database.transaction(force_rollback=True): + await Company.objects.create(name='Banzai', founded=1988) + await Company.objects.create(name='Yuhu', founded=1989) + await Company.objects.create(name='Konono', founded=1990) + await Company.objects.create(name='Sumaaa', founded=1991) + + comp = await Company.objects.get(name='Banzai') + assert comp._orm_saved + + comp = await Company.objects.first() + assert comp._orm_saved + + comps = await Company.objects.all() + assert [comp._orm_saved for comp in comps] + + comp2 = await Company.objects.get_or_create(name='Banzai_new', founded=2001) + assert comp2._orm_saved + + comp3 = await Company.objects.get_or_create(name='Banzai', founded=1988) + assert comp3._orm_saved + assert comp3.pk == comp.pk + + update_dict = comp.dict() + update_dict['founded'] = 2010 + comp = await Company.objects.update_or_create(**update_dict) + assert comp._orm_saved + assert comp.founded == 2010 + + create_dict = {'name': "Yoko", "founded": 2005} + comp = await Company.objects.update_or_create(**create_dict) + assert comp._orm_saved + assert comp.founded == 2005 + + +@pytest.mark.asyncio +async def test_bulk_methods(): + async with database: + async with database.transaction(force_rollback=True): + c1 = Company(name='Banzai', founded=1988) + c2 = Company(name='Yuhu', founded=1989) + + await Company.objects.bulk_create([c1, c2]) + assert c1._orm_saved + assert c2._orm_saved + + c1, c2 = await Company.objects.all() + c1.name = 'Banzai2' + c2.name = 'Yuhu2' + + assert not c1._orm_saved + assert not c2._orm_saved + + await Company.objects.bulk_update([c1, c2]) + assert c1._orm_saved + assert c2._orm_saved