diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index 00a1be1..4ffe0c5 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -14,6 +14,7 @@ jobs: tests: name: "Python ${{ matrix.python-version }}" runs-on: ubuntu-latest + if: github.event_name == 'push' || github.event.pull_request.head.repo.full_name != 'collerek/ormar' strategy: matrix: python-version: [3.6, 3.7, 3.8, 3.9] diff --git a/docs/models/methods.md b/docs/models/methods.md index d8d25b1..3b81559 100644 --- a/docs/models/methods.md +++ b/docs/models/methods.md @@ -10,6 +10,13 @@ Each model instance have a set of methods to `save`, `update` or `load` itself. Available methods are described below. +## `pydantic` methods + +Note that each `ormar.Model` is also a `pydantic.BaseModel`, so all `pydantic` methods are also available on a model, +especially `dict()` and `json()` methods that can also accept `exclude`, `include` and other parameters. + +To read more check [pydantic][pydantic] documentation + ## load By default when you query a table without prefetching related models, the ormar will still construct @@ -81,7 +88,7 @@ await track.save() # will raise integrity error as pk is populated ## update -`update(**kwargs) -> self` +`update(_columns: List[str] = None, **kwargs) -> self` You can update models by using `QuerySet.update()` method or by updating your model attributes (fields) and calling `update()` method. @@ -94,6 +101,42 @@ track = await Track.objects.get(name='The Bird') await track.update(name='The Bird Strikes Again') ``` +To update only selected columns from model into the database provide a list of columns that should be updated to `_columns` argument. + +In example: + +```python +class Movie(ormar.Model): + class Meta: + tablename = "movies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="title") + year: int = ormar.Integer() + profit: float = ormar.Float() + +terminator = await Movie(name='Terminator', year=1984, profit=0.078).save() + +terminator.name = "Terminator 2" +terminator.year = 1991 +terminator.profit = 0.520 + +# update only name +await terminator.update(_columns=["name"]) + +# note that terminator instance was not reloaded so +assert terminator.year == 1991 + +# but once you load the data from db you see it was not updated +await terminator.load() +assert terminator.year == 1984 +``` + +!!!warning + Note that `update()` does not refresh the instance of the Model, so if you change more columns than you pass in `_columns` list your Model instance will have different values than the database! + ## upsert `upsert(**kwargs) -> self` @@ -127,7 +170,7 @@ await track.delete() # will delete the model from database ## save_related -`save_related(follow: bool = False) -> None` +`save_related(follow: bool = False, save_all: bool = False, exclude=Optional[Union[Set, Dict]]) -> None` Method goes through all relations of the `Model` on which the method is called, and calls `upsert()` method on each model that is **not** saved. @@ -138,16 +181,27 @@ By default the `save_related` method saved only models that are directly related But you can specify the `follow=True` parameter to traverse through nested models and save all of them in the relation tree. +By default save_related saves only model that has not `saved` status, meaning that they were modified in current scope. + +If you want to force saving all of the related methods use `save_all=True` flag, which will upsert all related models, regardless of their save status. + +If you want to skip saving some of the relations you can pass `exclude` parameter. + +`Exclude` can be a set of own model relations, +or it can be a dictionary that can also contain nested items. + +!!!note + Note that `exclude` parameter in `save_related` accepts only relation fields names, so + if you pass any other fields they will be saved anyway + +!!!note + To read more about the structure of possible values passed to `exclude` check `Queryset.fields` method documentation. + !!!warning To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models, and won't perform nested `save_related` on Models that were already visited. So if you have a diamond or circular relations types you need to perform the updates in a manual way. - - ```python - # in example like this the second Street (coming from City) won't be save_related, so ZipCode won't be updated - Street -> District -> City -> Street -> ZipCode - ``` [fields]: ../fields.md [relations]: ../relations/index.md diff --git a/docs/releases.md b/docs/releases.md index 7f2b475..e22082c 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,40 @@ +# 0.10.2 + +## ✨ Features + +* `Model.save_related(follow=False)` now accept also two additional arguments: `Model.save_related(follow=False, save_all=False, exclude=None)`. + * `save_all:bool` -> By default (so with `save_all=False`) `ormar` only upserts models that are not saved (so new or updated ones), + with `save_all=True` all related models are saved, regardless of `saved` status, which might be useful if updated + models comes from api call, so are not changed in the backend. + * `exclude: Union[Set, Dict, None]` -> set/dict of relations to exclude from save, those relation won't be saved even with `follow=True` and `save_all=True`. + To exclude nested relations pass a nested dictionary like: `exclude={"child":{"sub_child": {"exclude_sub_child_realtion"}}}`. The allowed values follow + the `fields/exclude_fields` (from `QuerySet`) methods schema so when in doubt you can refer to docs in queries -> selecting subset of fields -> fields. +* `Model.update()` method now accepts `_columns: List[str] = None` parameter, that accepts list of column names to update. If passed only those columns will be updated in database. + Note that `update()` does not refresh the instance of the Model, so if you change more columns than you pass in `_columns` list your Model instance will have different values than the database! +* `Model.dict()` method previously included only directly related models or nested models if they were not nullable and not virtual, + now all related models not previously visited without loops are included in `dict()`. This should be not breaking + as just more data will be dumped to dict, but it should not be missing. +* `QuerySet.delete(each=False, **kwargs)` previously required that you either pass a `filter` (by `**kwargs` or as a separate `filter()` call) or set `each=True` now also accepts + `exclude()` calls that generates NOT filter. So either `each=True` needs to be set to delete whole table or at least one of `filter/exclude` clauses. +* Same thing applies to `QuerySet.update(each=False, **kwargs)` which also previously required that you either pass a `filter` (by `**kwargs` or as a separate `filter()` call) or set `each=True` now also accepts + `exclude()` calls that generates NOT filter. So either `each=True` needs to be set to update whole table or at least one of `filter/exclude` clauses. +* Same thing applies to `QuerysetProxy.update(each=False, **kwargs)` which also previously required that you either pass a `filter` (by `**kwargs` or as a separate `filter()` call) or set `each=True` now also accepts + `exclude()` calls that generates NOT filter. So either `each=True` needs to be set to update whole table or at least one of `filter/exclude` clauses. + +## 🐛 Fixes + +* Fix improper relation field resolution in `QuerysetProxy` if fk column has different database alias. +* Fix hitting recursion error with very complicated models structure with loops when calling `dict()`. +* Fix bug when two non-relation fields were merged (appended) in query result when they were not relation fields (i.e. JSON) +* Fix bug when during translation to dict from list the same relation name is used in chain but leads to different models +* Fix bug when bulk_create would try to save also `property_field` decorated methods and `pydantic` fields +* Fix wrong merging of deeply nested chain of reversed relations + +## 💬 Other + +* Performance optimizations +* Split tests into packages based on tested area + # 0.10.1 ## Features diff --git a/mkdocs.yml b/mkdocs.yml index 735c732..5432018 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -59,6 +59,7 @@ nav: - Model Table Proxy: api/models/model-table-proxy.md - Model Metaclass: api/models/model-metaclass.md - Excludable Items: api/models/excludable-items.md + - Traversible: api/models/traversible.md - Fields: - Base Field: api/fields/base-field.md - Model Fields: api/fields/model-fields.md diff --git a/ormar/__init__.py b/ormar/__init__.py index 2e8ae3f..e05928b 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -75,7 +75,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.10.1" +__version__ = "0.10.2" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/models/mixins/excludable_mixin.py b/ormar/models/mixins/excludable_mixin.py index a7850d5..3a2bb04 100644 --- a/ormar/models/mixins/excludable_mixin.py +++ b/ormar/models/mixins/excludable_mixin.py @@ -138,10 +138,8 @@ class ExcludableMixin(RelationMixin): return columns @classmethod - def _update_excluded_with_related_not_required( - cls, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None], - nested: bool = False, + def _update_excluded_with_related( + cls, exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None], ) -> Union[Set, Dict]: """ Used during generation of the dict(). @@ -159,8 +157,9 @@ class ExcludableMixin(RelationMixin): :rtype: Union[Set, Dict] """ exclude = exclude or {} - related_set = cls._exclude_related_names_not_required(nested=nested) + related_set = cls.extract_related_names() if isinstance(exclude, set): + exclude = {s for s in exclude} exclude.union(related_set) else: related_dict = translate_list_to_dict(related_set) diff --git a/ormar/models/mixins/merge_mixin.py b/ormar/models/mixins/merge_mixin.py index 32d7288..524d5e5 100644 --- a/ormar/models/mixins/merge_mixin.py +++ b/ormar/models/mixins/merge_mixin.py @@ -1,7 +1,8 @@ from collections import OrderedDict -from typing import List, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING, cast import ormar +from ormar.queryset.utils import translate_list_to_dict if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -46,13 +47,17 @@ class MergeModelMixin: return merged_rows @classmethod - def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": + def merge_two_instances( + cls, one: "Model", other: "Model", relation_map: Dict = None + ) -> "Model": """ Merges current (other) Model and previous one (one) and returns the current Model instance with data merged from previous one. If needed it's calling itself recurrently and merges also children models. + :param relation_map: map of models relations to follow + :type relation_map: Dict :param one: previous model instance :type one: Model :param other: current model instance @@ -60,20 +65,80 @@ class MergeModelMixin: :return: current Model instance with data merged from previous one. :rtype: Model """ - for field in one.Meta.model_fields.keys(): - current_field = getattr(one, field) - if isinstance(current_field, list) and not isinstance( - current_field, ormar.Model - ): - setattr(other, field, current_field + getattr(other, field)) + relation_map = ( + relation_map + if relation_map is not None + else translate_list_to_dict(one._iterate_related_models()) + ) + for field_name in relation_map: + current_field = getattr(one, field_name) + other_value = getattr(other, field_name, []) + if isinstance(current_field, list): + value_to_set = cls._merge_items_lists( + field_name=field_name, + current_field=current_field, + other_value=other_value, + relation_map=relation_map, + ) + setattr(other, field_name, value_to_set) elif ( isinstance(current_field, ormar.Model) - and current_field.pk == getattr(other, field).pk + and current_field.pk == other_value.pk ): setattr( other, - field, - cls.merge_two_instances(current_field, getattr(other, field)), + field_name, + cls.merge_two_instances( + current_field, + other_value, + relation_map=one._skip_ellipsis( # type: ignore + relation_map, field_name, default_return=dict() + ), + ), ) other.set_save_status(True) return other + + @classmethod + def _merge_items_lists( + cls, + field_name: str, + current_field: List, + other_value: List, + relation_map: Optional[Dict], + ) -> List: + """ + Takes two list of nested models and process them going deeper + according with the map. + + If model from one's list is in other -> they are merged with relations + to follow passed from map. + + If one's model is not in other it's simply appended to the list. + + :param field_name: name of the current relation field + :type field_name: str + :param current_field: list of nested models from one model + :type current_field: List[Model] + :param other_value: list of nested models from other model + :type other_value: List[Model] + :param relation_map: map of relations to follow + :type relation_map: Dict + :return: merged list of models + :rtype: List[Model] + """ + value_to_set = [x for x in other_value] + for cur_field in current_field: + if cur_field in other_value: + old_value = next((x for x in other_value if x == cur_field), None) + new_val = cls.merge_two_instances( + cur_field, + cast("Model", old_value), + relation_map=cur_field._skip_ellipsis( # type: ignore + relation_map, field_name, default_return=dict() + ), + ) + value_to_set = [x for x in value_to_set if x != cur_field] + [new_val] + else: + value_to_set.append(cur_field) + return value_to_set diff --git a/ormar/models/mixins/relation_mixin.py b/ormar/models/mixins/relation_mixin.py index 2a20dcc..151725a 100644 --- a/ormar/models/mixins/relation_mixin.py +++ b/ormar/models/mixins/relation_mixin.py @@ -4,11 +4,10 @@ from typing import ( Optional, Set, TYPE_CHECKING, - Type, - Union, ) from ormar import BaseField +from ormar.models.traversible import NodeList class RelationMixin: @@ -17,7 +16,7 @@ class RelationMixin: """ if TYPE_CHECKING: # pragma no cover - from ormar import ModelMeta, Model + from ormar import ModelMeta Meta: ModelMeta _related_names: Optional[Set] @@ -112,84 +111,39 @@ class RelationMixin: } return related_names - @classmethod - def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: - """ - Returns a set of non mandatory related models field names. - - For a main model (not nested) only nullable related field names are returned, - for nested models all related models are returned. - - :param nested: flag setting nested models (child of previous one, not main one) - :type nested: bool - :return: set of non mandatory related fields - :rtype: Set - """ - if nested: - return cls.extract_related_names() - related_names = cls.extract_related_names() - related_names = { - name for name in related_names if cls.Meta.model_fields[name].nullable - } - return related_names - @classmethod def _iterate_related_models( # noqa: CCR001 - cls, - visited: Set[str] = None, - source_visited: Set[str] = None, - source_relation: str = None, - source_model: Union[Type["Model"], Type["RelationMixin"]] = None, + cls, node_list: NodeList = None, source_relation: str = None ) -> List[str]: """ Iterates related models recursively to extract relation strings of nested not visited models. - :param visited: set of already visited models - :type visited: Set[str] - :param source_relation: name of the current relation - :type source_relation: str - :param source_model: model from which relation comes in nested relations - :type source_model: Type["Model"] :return: list of relation strings to be passed to select_related :rtype: List[str] """ - source_visited = source_visited or cls._populate_source_model_prefixes() + if not node_list: + node_list = NodeList() + current_node = node_list.add(node_class=cls) + else: + current_node = node_list[-1] relations = cls.extract_related_names() processed_relations = [] for relation in relations: - target_model = cls.Meta.model_fields[relation].to - if cls._is_reverse_side_of_same_relation(source_model, target_model): - continue - if target_model not in source_visited or not source_model: + if not current_node.visited(relation): + target_model = cls.Meta.model_fields[relation].to + node_list.add( + node_class=target_model, + relation_name=relation, + parent_node=current_node, + ) deep_relations = target_model._iterate_related_models( - visited=visited, - source_visited=source_visited, - source_relation=relation, - source_model=cls, + source_relation=relation, node_list=node_list ) processed_relations.extend(deep_relations) - else: - processed_relations.append(relation) return cls._get_final_relations(processed_relations, source_relation) - @staticmethod - def _is_reverse_side_of_same_relation( - source_model: Optional[Union[Type["Model"], Type["RelationMixin"]]], - target_model: Type["Model"], - ) -> bool: - """ - Alias to check if source model is the same as target - :param source_model: source model - relation comes from it - :type source_model: Type["Model"] - :param target_model: target model - relation leads to it - :type target_model: Type["Model"] - :return: result of the check - :rtype: bool - """ - return bool(source_model and target_model == source_model) - @staticmethod def _get_final_relations( processed_relations: List, source_relation: Optional[str] @@ -212,12 +166,3 @@ class RelationMixin: else: final_relations = [source_relation] if source_relation else [] return final_relations - - @classmethod - def _populate_source_model_prefixes(cls) -> Set: - relations = cls.extract_related_names() - visited = {cls} - for relation in relations: - target_model = cls.Meta.model_fields[relation].to - visited.add(target_model) - return visited diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index dfca964..5ade49f 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -32,11 +32,29 @@ class SavePrepareMixin(RelationMixin, AliasMixin): :rtype: Dict[str, str] """ new_kwargs = cls._remove_pk_from_kwargs(new_kwargs) + new_kwargs = cls._remove_not_ormar_fields(new_kwargs) new_kwargs = cls.substitute_models_with_pks(new_kwargs) new_kwargs = cls.populate_default_values(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs) return new_kwargs + @classmethod + def _remove_not_ormar_fields(cls, new_kwargs: dict) -> dict: + """ + Removes primary key for if it's nullable or autoincrement pk field, + and it's set to None. + + :param new_kwargs: dictionary of model that is about to be saved + :type new_kwargs: Dict[str, str] + :return: dictionary of model that is about to be saved + :rtype: Dict[str, str] + """ + ormar_fields = { + k for k, v in cls.Meta.model_fields.items() if not v.pydantic_only + } + new_kwargs = {k: v for k, v in new_kwargs.items() if k in ormar_fields} + return new_kwargs + @classmethod def _remove_pk_from_kwargs(cls, new_kwargs: dict) -> dict: """ diff --git a/ormar/models/model.py b/ormar/models/model.py index 6a46696..568a787 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -4,7 +4,6 @@ from typing import ( List, Set, TYPE_CHECKING, - Tuple, TypeVar, Union, ) @@ -14,7 +13,7 @@ from ormar.exceptions import ModelPersistenceError, NoMatch from ormar.models import NewBaseModel # noqa I100 from ormar.models.metaclass import ModelMeta from ormar.models.model_row import ModelRow - +from ormar.queryset.utils import subtract_dict, translate_list_to_dict T = TypeVar("T", bound="Model") @@ -101,8 +100,13 @@ class Model(ModelRow): return self async def save_related( # noqa: CCR001 - self, follow: bool = False, visited: Set = None, update_count: int = 0 - ) -> int: # noqa: CCR001 + self, + follow: bool = False, + save_all: bool = False, + relation_map: Dict = None, + exclude: Union[Set, Dict] = None, + update_count: int = 0, + ) -> int: """ Triggers a upsert method on all related models if the instances are not already saved. @@ -118,77 +122,89 @@ class Model(ModelRow): Model A but will never follow into Model C. Nested relations of those kind need to be persisted manually. + :param exclude: items to exclude during saving of relations + :type exclude: Union[Set, Dict] + :param relation_map: map of relations to follow + :type relation_map: Dict + :param save_all: flag if all models should be saved or only not saved ones + :type save_all: bool :param follow: flag to trigger deep save - by default only directly related models are saved with follow=True also related models of related models are saved :type follow: bool - :param visited: internal parameter for recursive calls - already visited models - :type visited: Set :param update_count: internal parameter for recursive calls - number of updated instances :type update_count: int :return: number of updated/saved models :rtype: int """ - if not visited: - visited = {self.__class__} - else: - visited = {x for x in visited} - visited.add(self.__class__) + relation_map = ( + relation_map + if relation_map is not None + else translate_list_to_dict(self._iterate_related_models()) + ) + if exclude and isinstance(exclude, Set): + exclude = translate_list_to_dict(exclude) + relation_map = subtract_dict(relation_map, exclude or {}) for related in self.extract_related_names(): - if ( - self.Meta.model_fields[related].virtual - or self.Meta.model_fields[related].is_multi - ): - for rel in getattr(self, related): - update_count, visited = await self._update_and_follow( - rel=rel, + if relation_map and related in relation_map: + value = getattr(self, related) + if value: + update_count = await self._update_and_follow( + value=value, follow=follow, - visited=visited, + save_all=save_all, + relation_map=self._skip_ellipsis( # type: ignore + relation_map, related, default_return={} + ), update_count=update_count, ) - visited.add(self.Meta.model_fields[related].to) - else: - rel = getattr(self, related) - update_count, visited = await self._update_and_follow( - rel=rel, follow=follow, visited=visited, update_count=update_count - ) - visited.add(rel.__class__) return update_count @staticmethod async def _update_and_follow( - rel: "Model", follow: bool, visited: Set, update_count: int - ) -> Tuple[int, Set]: + value: Union["Model", List["Model"]], + follow: bool, + save_all: bool, + relation_map: Dict, + update_count: int, + ) -> int: """ Internal method used in save_related to follow related models and update numbers of updated related instances. - :param rel: Model to follow - :type rel: Model + :param value: Model to follow + :type value: Model + :param relation_map: map of relations to follow + :type relation_map: Dict :param follow: flag to trigger deep save - by default only directly related models are saved with follow=True also related models of related models are saved :type follow: bool - :param visited: internal parameter for recursive calls - already visited models - :type visited: Set :param update_count: internal parameter for recursive calls - number of updated instances :type update_count: int :return: tuple of update count and visited - :rtype: Tuple[int, Set] + :rtype: int """ - if follow and rel.__class__ not in visited: - update_count = await rel.save_related( - follow=follow, visited=visited, update_count=update_count - ) - if not rel.saved: - await rel.upsert() - update_count += 1 - return update_count, visited + if not isinstance(value, list): + value = [value] - async def update(self: T, **kwargs: Any) -> T: + for val in value: + if (not val.saved or save_all) and not val.__pk_only__: + await val.upsert() + update_count += 1 + if follow: + update_count = await val.save_related( + follow=follow, + save_all=save_all, + relation_map=relation_map, + update_count=update_count, + ) + return update_count + + async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: """ Performs update of Model instance in the database. Fields can be updated before or you can pass them as kwargs. @@ -197,6 +213,8 @@ class Model(ModelRow): Sets model save status to True. + :param _columns: list of columns to update, if None all are updated + :type _columns: List :raises ModelPersistenceError: If the pk column is not set :param kwargs: list of fields to update as field=value pairs @@ -217,6 +235,8 @@ class Model(ModelRow): ) self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) + if _columns: + self_fields = {k: v for k, v in self_fields.items() if k in _columns} self_fields = self.translate_columns_to_aliases(self_fields) expr = self.Meta.table.update().values(**self_fields) expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 35c5c7c..a2a1f1f 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -64,7 +64,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass the logic concerned with database connection and data persistance. """ - __slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column") + __slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column", "__pk_only__") if TYPE_CHECKING: # pragma no cover pk: Any @@ -134,6 +134,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass ) pk_only = kwargs.pop("__pk_only__", False) + object.__setattr__(self, "__pk_only__", pk_only) + excluded: Set[str] = kwargs.pop("__excluded__", set()) if "pk" in kwargs: @@ -267,9 +269,13 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass if item == "pk": return object.__getattribute__(self, "__dict__").get(self.Meta.pkname, None) if item in object.__getattribute__(self, "extract_related_names")(): - return self._extract_related_model_instead_of_field(item) + return object.__getattribute__( + self, "_extract_related_model_instead_of_field" + )(item) if item in object.__getattribute__(self, "extract_through_names")(): - return self._extract_related_model_instead_of_field(item) + return object.__getattribute__( + self, "_extract_related_model_instead_of_field" + )(item) if item in object.__getattribute__(self, "Meta").property_fields: value = object.__getattribute__(self, item) return value() if callable(value) else value @@ -337,8 +343,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass return ( self._orm_id == other._orm_id or (self.pk == other.pk and self.pk is not None) - or self.dict(exclude=self.extract_related_names()) - == other.dict(exclude=other.extract_related_names()) + or ( + (self.pk is None and other.pk is None) + and { + k: v + for k, v in self.__dict__.items() + if k not in self.extract_related_names() + } + == { + k: v + for k, v in other.__dict__.items() + if k not in other.extract_related_names() + } + ) ) @classmethod @@ -489,6 +506,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass @staticmethod def _extract_nested_models_from_list( + relation_map: Dict, models: MutableSequence, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None], @@ -509,14 +527,16 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass for model in models: try: result.append( - model.dict(nested=True, include=include, exclude=exclude,) + model.dict( + relation_map=relation_map, include=include, exclude=exclude, + ) ) except ReferenceError: # pragma no cover continue return result def _skip_ellipsis( - self, items: Union[Set, Dict, None], key: str + self, items: Union[Set, Dict, None], key: str, default_return: Any = None ) -> Union[Set, Dict, None]: """ Helper to traverse the include/exclude dictionaries. @@ -531,11 +551,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass :rtype: Union[Set, Dict, None] """ result = self.get_child(items, key) - return result if result is not Ellipsis else None + return result if result is not Ellipsis else default_return def _extract_nested_models( # noqa: CCR001 self, - nested: bool, + relation_map: Dict, dict_instance: Dict, include: Optional[Dict], exclude: Optional[Dict], @@ -559,18 +579,23 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass fields = self._get_related_not_excluded_fields(include=include, exclude=exclude) for field in fields: - if self.Meta.model_fields[field].virtual and nested: + if not relation_map or field not in relation_map: continue nested_model = getattr(self, field) if isinstance(nested_model, MutableSequence): dict_instance[field] = self._extract_nested_models_from_list( + relation_map=self._skip_ellipsis( # type: ignore + relation_map, field, default_return=dict() + ), models=nested_model, include=self._skip_ellipsis(include, field), exclude=self._skip_ellipsis(exclude, field), ) elif nested_model is not None: dict_instance[field] = nested_model.dict( - nested=True, + relation_map=self._skip_ellipsis( + relation_map, field, default_return=dict() + ), include=self._skip_ellipsis(include, field), exclude=self._skip_ellipsis(exclude, field), ) @@ -588,7 +613,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - nested: bool = False, + relation_map: Dict = None, ) -> "DictStrAny": # noqa: A003' """ @@ -613,14 +638,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass :type exclude_defaults: bool :param exclude_none: flag to exclude None values - passed to pydantic :type exclude_none: bool - :param nested: flag if the current model is nested - :type nested: bool + :param relation_map: map of the relations to follow to avoid circural deps + :type relation_map: Dict :return: :rtype: """ dict_instance = super().dict( include=include, - exclude=self._update_excluded_with_related_not_required(exclude, nested), + exclude=self._update_excluded_with_related(exclude), by_alias=by_alias, skip_defaults=skip_defaults, exclude_unset=exclude_unset, @@ -633,12 +658,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass if exclude and isinstance(exclude, Set): exclude = translate_list_to_dict(exclude) - dict_instance = self._extract_nested_models( - nested=nested, - dict_instance=dict_instance, - include=include, # type: ignore - exclude=exclude, # type: ignore + relation_map = ( + relation_map + if relation_map is not None + else translate_list_to_dict(self._iterate_related_models()) ) + pk_only = object.__getattribute__(self, "__pk_only__") + if relation_map and not pk_only: + dict_instance = self._extract_nested_models( + relation_map=relation_map, + dict_instance=dict_instance, + include=include, # type: ignore + exclude=exclude, # type: ignore + ) # include model properties as fields in dict if object.__getattribute__(self, "Meta").property_fields: @@ -714,7 +746,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass :rtype: Dict """ related_names = self.extract_related_names() - self_fields = self.dict(exclude=related_names) + self_fields = {k: v for k, v in self.__dict__.items() if k not in related_names} return self_fields def _extract_model_db_fields(self) -> Dict: diff --git a/ormar/models/quick_access_views.py b/ormar/models/quick_access_views.py index dd4cb5a..0dfc835 100644 --- a/ormar/models/quick_access_views.py +++ b/ormar/models/quick_access_views.py @@ -14,6 +14,7 @@ quick_access_set = { "__json_encoder__", "__post_root_validators__", "__pre_root_validators__", + "__private_attributes__", "__same__", "_calculate_keys", "_choices_fields", @@ -26,8 +27,10 @@ quick_access_set = { "_extract_related_model_instead_of_field", "_get_related_not_excluded_fields", "_get_value", + "_init_private_attributes", "_is_conversion_to_json_needed", "_iter", + "_iterate_related_models", "_orm", "_orm_id", "_orm_saved", @@ -40,8 +43,10 @@ quick_access_set = { "delete", "dict", "extract_related_names", + "extract_related_fields", "extract_through_names", "update_from_dict", + "get_child", "get_column_alias", "get_column_name_from_alias", "get_filtered_names_to_extract", @@ -52,9 +57,11 @@ quick_access_set = { "json", "keys", "load", + "load_all", "pk_column", "pk_type", "populate_default_values", + "prepare_model_to_save", "remove", "resolve_relation_field", "resolve_relation_name", @@ -62,6 +69,7 @@ quick_access_set = { "save_related", "saved", "set_save_status", + "signals", "translate_aliases_to_columns", "translate_columns_to_aliases", "update", diff --git a/ormar/models/traversible.py b/ormar/models/traversible.py new file mode 100644 index 0000000..3a90bb9 --- /dev/null +++ b/ormar/models/traversible.py @@ -0,0 +1,118 @@ +from typing import Any, List, Optional, TYPE_CHECKING, Type + +if TYPE_CHECKING: # pragma no cover + from ormar.models.mixins.relation_mixin import RelationMixin + + +class NodeList: + """ + Helper class that helps with iterating nested models + """ + + def __init__(self) -> None: + self.node_list: List["Node"] = [] + + def __getitem__(self, item: Any) -> Any: + return self.node_list.__getitem__(item) + + def add( + self, + node_class: Type["RelationMixin"], + relation_name: str = None, + parent_node: "Node" = None, + ) -> "Node": + """ + Adds new Node or returns the existing one + + :param node_class: Model in current node + :type node_class: ormar.models.metaclass.ModelMetaclass + :param relation_name: name of the current relation + :type relation_name: str + :param parent_node: parent node + :type parent_node: Optional[Node] + :return: returns new or already existing node + :rtype: Node + """ + existing_node = self.find( + relation_name=relation_name, node_class=node_class, parent_node=parent_node + ) + if not existing_node: + current_node = Node( + node_class=node_class, + relation_name=relation_name, + parent_node=parent_node, + ) + self.node_list.append(current_node) + return current_node + return existing_node # pragma: no cover + + def find( + self, + node_class: Type["RelationMixin"], + relation_name: Optional[str] = None, + parent_node: "Node" = None, + ) -> Optional["Node"]: + """ + Searches for existing node with given parameters + + :param node_class: Model in current node + :type node_class: ormar.models.metaclass.ModelMetaclass + :param relation_name: name of the current relation + :type relation_name: str + :param parent_node: parent node + :type parent_node: Optional[Node] + :return: returns already existing node or None + :rtype: Optional[Node] + """ + for node in self.node_list: + if ( + node.node_class == node_class + and node.parent_node == parent_node + and node.relation_name == relation_name + ): + return node # pragma: no cover + return None + + +class Node: + def __init__( + self, + node_class: Type["RelationMixin"], + relation_name: str = None, + parent_node: "Node" = None, + ) -> None: + self.relation_name = relation_name + self.node_class = node_class + self.parent_node = parent_node + self.visited_children: List["Node"] = [] + if self.parent_node: + self.parent_node.visited_children.append(self) + + def __repr__(self) -> str: # pragma: no cover + return ( + f"{self.node_class.get_name(lower=False)}, " + f"relation:{self.relation_name}, " + f"parent: {self.parent_node}" + ) + + def visited(self, relation_name: str) -> bool: + """ + Checks if given relation was already visited. + + Relation was visited if it's name is in current node children. + + Relation was visited if one of the parent node had the same Model class + + :param relation_name: name of relation + :type relation_name: str + :return: result of the check + :rtype: bool + """ + target_model = self.node_class.Meta.model_fields[relation_name].to + if self.parent_node: + node = self + while node.parent_node: + node = node.parent_node + if node.node_class == target_model: + return True + return False diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index e05b108..c325a20 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -650,7 +650,7 @@ class QuerySet(Generic[T]): :return: number of updated rows :rtype: int """ - if not each and not self.filter_clauses: + if not each and not (self.filter_clauses or self.exclude_clauses): raise QueryDefinitionError( "You cannot update without filtering the queryset first. " "If you want to update all rows use update(each=True, **kwargs)" @@ -666,6 +666,9 @@ class QuerySet(Generic[T]): expr = FilterQuery(filter_clauses=self.filter_clauses).apply( self.table.update().values(**updates) ) + expr = FilterQuery(filter_clauses=self.exclude_clauses, exclude=True).apply( + expr + ) return await self.database.execute(expr) async def delete(self, each: bool = False, **kwargs: Any) -> int: @@ -684,7 +687,7 @@ class QuerySet(Generic[T]): """ if kwargs: return await self.filter(**kwargs).delete() - if not each and not self.filter_clauses: + if not each and not (self.filter_clauses or self.exclude_clauses): raise QueryDefinitionError( "You cannot delete without filtering the queryset first. " "If you want to delete all rows use delete(each=True)" @@ -692,6 +695,9 @@ class QuerySet(Generic[T]): expr = FilterQuery(filter_clauses=self.filter_clauses).apply( self.table.delete() ) + expr = FilterQuery(filter_clauses=self.exclude_clauses, exclude=True).apply( + expr + ) return await self.database.execute(expr) def paginate(self, page: int, page_size: int = 20) -> "QuerySet[T]": diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index f5f00ac..dc24fd9 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: # pragma no cover def check_node_not_dict_or_not_last_node( - part: str, parts: List, current_level: Any + part: str, is_last: bool, current_level: Any ) -> bool: """ Checks if given name is not present in the current level of the structure. @@ -36,7 +36,7 @@ def check_node_not_dict_or_not_last_node( :return: result of the check :rtype: bool """ - return (part not in current_level and part != parts[-1]) or ( + return (part not in current_level and not is_last) or ( part in current_level and not isinstance(current_level[part], dict) ) @@ -71,9 +71,10 @@ def translate_list_to_dict( # noqa: CCR001 else: def_val = "asc" - for part in parts: + for ind, part in enumerate(parts): + is_last = ind == len(parts) - 1 if check_node_not_dict_or_not_last_node( - part=part, parts=parts, current_level=current_level + part=part, is_last=is_last, current_level=current_level ): current_level[part] = dict() elif part not in current_level: @@ -127,6 +128,49 @@ def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001 return current_dict +def subtract_dict(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001 + """ + Update one dict with another but with regard for nested keys. + + That way nested sets are unionised, dicts updated and + only other values are overwritten. + + :param current_dict: dict to update + :type current_dict: Dict[str, ellipsis] + :param updating_dict: dict with values to update + :type updating_dict: Dict + :return: combination of both dicts + :rtype: Dict + """ + for key, value in updating_dict.items(): + old_key = current_dict.get(key, {}) + new_value: Optional[Union[Dict, Set]] = None + if not old_key: + continue + if isinstance(value, set) and isinstance(old_key, set): + new_value = old_key.difference(value) + elif isinstance(value, (set, collections.abc.Mapping)) and isinstance( + old_key, (set, collections.abc.Mapping) + ): + value = ( + convert_set_to_required_dict(value) + if not isinstance(value, collections.abc.Mapping) + else value + ) + old_key = ( + convert_set_to_required_dict(old_key) + if not isinstance(old_key, collections.abc.Mapping) + else old_key + ) + new_value = subtract_dict(old_key, value) + + if new_value: + current_dict[key] = new_value + else: + current_dict.pop(key, None) + return current_dict + + def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict: """ Converts the list into dictionary and later performs special update, where diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 7874b06..953b43d 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -389,7 +389,11 @@ class QuerysetProxy(Generic[T]): :rtype: int """ # queryset proxy always have one filter for pk of parent model - if not each and len(self.queryset.filter_clauses) == 1: + if ( + not each + and (len(self.queryset.filter_clauses) + len(self.queryset.exclude_clauses)) + == 1 + ): raise QueryDefinitionError( "You cannot update without filtering the queryset first. " "If you want to update all rows use update(each=True, **kwargs)" diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index fa5475f..900f8f3 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -127,7 +127,7 @@ class RelationProxy(Generic[T], list): related_field = self.relation.to.Meta.model_fields[related_field_name] pkname = self._owner.get_column_alias(self._owner.Meta.pkname) self._check_if_model_saved() - kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk} + kwargs = {f"{related_field.name}__{pkname}": self._owner.pk} queryset = ( ormar.QuerySet( model_cls=self.relation.to, proxy_source_model=self._owner.__class__ diff --git a/pydoc-markdown.yml b/pydoc-markdown.yml index 6f0188f..7243a2a 100644 --- a/pydoc-markdown.yml +++ b/pydoc-markdown.yml @@ -30,6 +30,9 @@ renderer: - title: Excludable Items contents: - models.excludable.* + - title: Traversible + contents: + - models.traversible.* - title: Model Table Proxy contents: - models.modelproxy.* diff --git a/tests/test_deferred/__init__.py b/tests/test_deferred/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_forward_cross_refs.py b/tests/test_deferred/test_forward_cross_refs.py similarity index 100% rename from tests/test_forward_cross_refs.py rename to tests/test_deferred/test_forward_cross_refs.py diff --git a/tests/test_forward_refs.py b/tests/test_deferred/test_forward_refs.py similarity index 100% rename from tests/test_forward_refs.py rename to tests/test_deferred/test_forward_refs.py diff --git a/tests/test_more_same_table_joins.py b/tests/test_deferred/test_more_same_table_joins.py similarity index 100% rename from tests/test_more_same_table_joins.py rename to tests/test_deferred/test_more_same_table_joins.py diff --git a/tests/test_same_table_joins.py b/tests/test_deferred/test_same_table_joins.py similarity index 100% rename from tests/test_same_table_joins.py rename to tests/test_deferred/test_same_table_joins.py diff --git a/tests/test_encryption/__init__.py b/tests/test_encryption/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_encrypted_columns.py b/tests/test_encryption/test_encrypted_columns.py similarity index 100% rename from tests/test_encrypted_columns.py rename to tests/test_encryption/test_encrypted_columns.py diff --git a/tests/test_exclude_include_dict/__init__.py b/tests/test_exclude_include_dict/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_dumping_model_to_dict.py b/tests/test_exclude_include_dict/test_dumping_model_to_dict.py similarity index 100% rename from tests/test_dumping_model_to_dict.py rename to tests/test_exclude_include_dict/test_dumping_model_to_dict.py diff --git a/tests/test_excludable_items.py b/tests/test_exclude_include_dict/test_excludable_items.py similarity index 100% rename from tests/test_excludable_items.py rename to tests/test_exclude_include_dict/test_excludable_items.py diff --git a/tests/test_excluding_fields_in_fastapi.py b/tests/test_exclude_include_dict/test_excluding_fields_in_fastapi.py similarity index 100% rename from tests/test_excluding_fields_in_fastapi.py rename to tests/test_exclude_include_dict/test_excluding_fields_in_fastapi.py diff --git a/tests/test_excluding_fields_with_default.py b/tests/test_exclude_include_dict/test_excluding_fields_with_default.py similarity index 100% rename from tests/test_excluding_fields_with_default.py rename to tests/test_exclude_include_dict/test_excluding_fields_with_default.py diff --git a/tests/test_excluding_subset_of_columns.py b/tests/test_exclude_include_dict/test_excluding_subset_of_columns.py similarity index 100% rename from tests/test_excluding_subset_of_columns.py rename to tests/test_exclude_include_dict/test_excluding_subset_of_columns.py diff --git a/tests/test_fastapi/__init__.py b/tests/test_fastapi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_choices_schema.py b/tests/test_fastapi/test_choices_schema.py similarity index 100% rename from tests/test_choices_schema.py rename to tests/test_fastapi/test_choices_schema.py diff --git a/tests/test_docs_with_multiple_relations_to_one.py b/tests/test_fastapi/test_docs_with_multiple_relations_to_one.py similarity index 100% rename from tests/test_docs_with_multiple_relations_to_one.py rename to tests/test_fastapi/test_docs_with_multiple_relations_to_one.py diff --git a/tests/test_fastapi_docs.py b/tests/test_fastapi/test_fastapi_docs.py similarity index 100% rename from tests/test_fastapi_docs.py rename to tests/test_fastapi/test_fastapi_docs.py diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi/test_fastapi_usage.py similarity index 89% rename from tests/test_fastapi_usage.py rename to tests/test_fastapi/test_fastapi_usage.py index 503e582..f52003a 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi/test_fastapi_usage.py @@ -48,7 +48,11 @@ def test_read_main(): ) assert response.status_code == 200 assert response.json() == { - "category": {"id": None, "name": "test cat"}, + "category": { + "id": None, + "items": [{"id": 1, "name": "test"}], + "name": "test cat", + }, "id": 1, "name": "test", } diff --git a/tests/test_inheritance_concrete_fastapi.py b/tests/test_fastapi/test_inheritance_concrete_fastapi.py similarity index 98% rename from tests/test_inheritance_concrete_fastapi.py rename to tests/test_fastapi/test_inheritance_concrete_fastapi.py index 217fe3c..a4a8310 100644 --- a/tests/test_inheritance_concrete_fastapi.py +++ b/tests/test_fastapi/test_inheritance_concrete_fastapi.py @@ -6,7 +6,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient from tests.settings import DATABASE_URL -from tests.test_inheritance_concrete import ( # type: ignore +from tests.test_inheritance.test_inheritance_concrete import ( # type: ignore Category, Subject, Person, diff --git a/tests/test_inheritance_mixins_fastapi.py b/tests/test_fastapi/test_inheritance_mixins_fastapi.py similarity index 94% rename from tests/test_inheritance_mixins_fastapi.py rename to tests/test_fastapi/test_inheritance_mixins_fastapi.py index bfd6979..681f5ef 100644 --- a/tests/test_inheritance_mixins_fastapi.py +++ b/tests/test_fastapi/test_inheritance_mixins_fastapi.py @@ -6,7 +6,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient from tests.settings import DATABASE_URL -from tests.test_inheritance_mixins import Category, Subject, metadata, db as database # type: ignore +from tests.test_inheritance.test_inheritance_mixins import Category, Subject, metadata, db as database # type: ignore app = FastAPI() app.state.database = database diff --git a/tests/test_json_field_fastapi.py b/tests/test_fastapi/test_json_field_fastapi.py similarity index 100% rename from tests/test_json_field_fastapi.py rename to tests/test_fastapi/test_json_field_fastapi.py diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_fastapi/test_more_reallife_fastapi.py similarity index 100% rename from tests/test_more_reallife_fastapi.py rename to tests/test_fastapi/test_more_reallife_fastapi.py diff --git a/tests/test_wekref_exclusion.py b/tests/test_fastapi/test_wekref_exclusion.py similarity index 100% rename from tests/test_wekref_exclusion.py rename to tests/test_fastapi/test_wekref_exclusion.py diff --git a/tests/test_inheritance/__init__.py b/tests/test_inheritance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_inheritance_concrete.py b/tests/test_inheritance/test_inheritance_concrete.py similarity index 100% rename from tests/test_inheritance_concrete.py rename to tests/test_inheritance/test_inheritance_concrete.py diff --git a/tests/test_inheritance_mixins.py b/tests/test_inheritance/test_inheritance_mixins.py similarity index 100% rename from tests/test_inheritance_mixins.py rename to tests/test_inheritance/test_inheritance_mixins.py diff --git a/tests/test_meta_constraints/__init__.py b/tests/test_meta_constraints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_unique_constraints.py b/tests/test_meta_constraints/test_unique_constraints.py similarity index 100% rename from tests/test_unique_constraints.py rename to tests/test_meta_constraints/test_unique_constraints.py diff --git a/tests/test_model_definition/__init__.py b/tests/test_model_definition/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_model_definition/pks_and_fks/__init__.py b/tests/test_model_definition/pks_and_fks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_non_integer_pkey.py b/tests/test_model_definition/pks_and_fks/test_non_integer_pkey.py similarity index 100% rename from tests/test_non_integer_pkey.py rename to tests/test_model_definition/pks_and_fks/test_non_integer_pkey.py diff --git a/tests/test_saving_string_pks.py b/tests/test_model_definition/pks_and_fks/test_saving_string_pks.py similarity index 100% rename from tests/test_saving_string_pks.py rename to tests/test_model_definition/pks_and_fks/test_saving_string_pks.py diff --git a/tests/test_uuid_fks.py b/tests/test_model_definition/pks_and_fks/test_uuid_fks.py similarity index 100% rename from tests/test_uuid_fks.py rename to tests/test_model_definition/pks_and_fks/test_uuid_fks.py diff --git a/tests/test_aliases.py b/tests/test_model_definition/test_aliases.py similarity index 100% rename from tests/test_aliases.py rename to tests/test_model_definition/test_aliases.py diff --git a/tests/test_columns.py b/tests/test_model_definition/test_columns.py similarity index 100% rename from tests/test_columns.py rename to tests/test_model_definition/test_columns.py diff --git a/tests/test_model_definition.py b/tests/test_model_definition/test_model_definition.py similarity index 100% rename from tests/test_model_definition.py rename to tests/test_model_definition/test_model_definition.py diff --git a/tests/test_models.py b/tests/test_model_definition/test_models.py similarity index 100% rename from tests/test_models.py rename to tests/test_model_definition/test_models.py diff --git a/tests/test_properties.py b/tests/test_model_definition/test_properties.py similarity index 100% rename from tests/test_properties.py rename to tests/test_model_definition/test_properties.py diff --git a/tests/test_pydantic_only_fields.py b/tests/test_model_definition/test_pydantic_only_fields.py similarity index 100% rename from tests/test_pydantic_only_fields.py rename to tests/test_model_definition/test_pydantic_only_fields.py diff --git a/tests/test_save_status.py b/tests/test_model_definition/test_save_status.py similarity index 100% rename from tests/test_save_status.py rename to tests/test_model_definition/test_save_status.py diff --git a/tests/test_saving_nullable_fields.py b/tests/test_model_definition/test_saving_nullable_fields.py similarity index 100% rename from tests/test_saving_nullable_fields.py rename to tests/test_model_definition/test_saving_nullable_fields.py diff --git a/tests/test_server_default.py b/tests/test_model_definition/test_server_default.py similarity index 100% rename from tests/test_server_default.py rename to tests/test_model_definition/test_server_default.py diff --git a/tests/test_model_methods/__init__.py b/tests/test_model_methods/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_load_all.py b/tests/test_model_methods/test_load_all.py similarity index 100% rename from tests/test_load_all.py rename to tests/test_model_methods/test_load_all.py diff --git a/tests/test_save_related.py b/tests/test_model_methods/test_save_related.py similarity index 84% rename from tests/test_save_related.py rename to tests/test_model_methods/test_save_related.py index fc9f3ff..774d167 100644 --- a/tests/test_save_related.py +++ b/tests/test_model_methods/test_save_related.py @@ -91,6 +91,14 @@ async def test_saving_related_fk_rel(): assert count == 1 assert comp.hq.saved + comp.hq.name = "Suburbs 2" + assert not comp.hq.saved + assert comp.saved + + count = await comp.save_related(exclude={"hq"}) + assert count == 0 + assert not comp.hq.saved + @pytest.mark.asyncio async def test_saving_many_to_many(): @@ -110,6 +118,9 @@ async def test_saving_many_to_many(): count = await hq.save_related() assert count == 0 + count = await hq.save_related(save_all=True) + assert count == 2 + hq.nicks[0].name = "Kabucha" hq.nicks[1].name = "Kabucha2" assert not hq.nicks[0].saved @@ -120,6 +131,16 @@ async def test_saving_many_to_many(): assert hq.nicks[0].saved assert hq.nicks[1].saved + hq.nicks[0].name = "Kabucha a" + hq.nicks[1].name = "Kabucha2 a" + assert not hq.nicks[0].saved + assert not hq.nicks[1].saved + + count = await hq.save_related(exclude={"nicks": ...}) + assert count == 0 + assert not hq.nicks[0].saved + assert not hq.nicks[1].saved + @pytest.mark.asyncio async def test_saving_reversed_relation(): @@ -208,3 +229,16 @@ async def test_saving_nested(): assert hq.nicks[0].level.saved assert hq.nicks[1].saved assert hq.nicks[1].level.saved + + hq.nicks[0].level.name = "Low 2" + hq.nicks[1].level.name = "Medium 2" + assert not hq.nicks[0].level.saved + assert not hq.nicks[1].level.saved + assert hq.nicks[0].saved + assert hq.nicks[1].saved + count = await hq.save_related(follow=True, exclude={"nicks": {"level"}}) + assert count == 0 + assert hq.nicks[0].saved + assert not hq.nicks[0].level.saved + assert hq.nicks[1].saved + assert not hq.nicks[1].level.saved diff --git a/tests/test_model_methods/test_update.py b/tests/test_model_methods/test_update.py new file mode 100644 index 0000000..391baf7 --- /dev/null +++ b/tests/test_model_methods/test_update.py @@ -0,0 +1,111 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Director(ormar.Model): + class Meta: + tablename = "directors" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="first_name") + last_name: str = ormar.String(max_length=100, nullable=False, name="last_name") + + +class Movie(ormar.Model): + class Meta: + tablename = "movies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="title") + year: int = ormar.Integer() + profit: float = ormar.Float() + director: Optional[Director] = ormar.ForeignKey(Director) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_updating_selected_columns(): + async with database: + director1 = await Director(name="Peter", last_name="Jackson").save() + director2 = await Director(name="James", last_name="Cameron").save() + + lotr = await Movie( + name="LOTR", year=2001, director=director1, profit=1.140 + ).save() + + lotr.name = "Lord of The Rings" + lotr.year = 2003 + lotr.profit = 1.212 + + await lotr.update(_columns=["name"]) + + # before reload the field has current value even if not saved + assert lotr.year == 2003 + + lotr = await Movie.objects.get() + assert lotr.name == "Lord of The Rings" + assert lotr.year == 2001 + assert round(lotr.profit, 3) == 1.140 + assert lotr.director.pk == director1.pk + + lotr.year = 2003 + lotr.profit = 1.212 + lotr.director = director2 + + await lotr.update(_columns=["year", "profit"]) + lotr = await Movie.objects.get() + assert lotr.year == 2003 + assert round(lotr.profit, 3) == 1.212 + assert lotr.director.pk == director1.pk + + +@pytest.mark.asyncio +async def test_not_passing_columns_or_empty_list_saves_all(): + async with database: + director = await Director(name="James", last_name="Cameron").save() + terminator = await Movie( + name="Terminator", year=1984, director=director, profit=0.078 + ).save() + + terminator.name = "Terminator 2" + terminator.year = 1991 + terminator.profit = 0.520 + + await terminator.update(_columns=[]) + + terminator = await Movie.objects.get() + assert terminator.name == "Terminator 2" + assert terminator.year == 1991 + assert round(terminator.profit, 3) == 0.520 + + terminator.name = "Terminator 3" + terminator.year = 2003 + terminator.profit = 0.433 + + await terminator.update() + + terminator = await terminator.load() + assert terminator.name == "Terminator 3" + assert terminator.year == 2003 + assert round(terminator.profit, 3) == 0.433 diff --git a/tests/test_ordering/__init__.py b/tests/test_ordering/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_default_model_order.py b/tests/test_ordering/test_default_model_order.py similarity index 100% rename from tests/test_default_model_order.py rename to tests/test_ordering/test_default_model_order.py diff --git a/tests/test_default_relation_order.py b/tests/test_ordering/test_default_relation_order.py similarity index 100% rename from tests/test_default_relation_order.py rename to tests/test_ordering/test_default_relation_order.py diff --git a/tests/test_default_through_relation_order.py b/tests/test_ordering/test_default_through_relation_order.py similarity index 100% rename from tests/test_default_through_relation_order.py rename to tests/test_ordering/test_default_through_relation_order.py diff --git a/tests/test_proper_order_of_sorting_apply.py b/tests/test_ordering/test_proper_order_of_sorting_apply.py similarity index 100% rename from tests/test_proper_order_of_sorting_apply.py rename to tests/test_ordering/test_proper_order_of_sorting_apply.py diff --git a/tests/test_queries/__init__.py b/tests/test_queries/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_aggr_functions.py b/tests/test_queries/test_aggr_functions.py similarity index 100% rename from tests/test_aggr_functions.py rename to tests/test_queries/test_aggr_functions.py diff --git a/tests/test_queries/test_deep_relations_select_all.py b/tests/test_queries/test_deep_relations_select_all.py new file mode 100644 index 0000000..948b81f --- /dev/null +++ b/tests/test_queries/test_deep_relations_select_all.py @@ -0,0 +1,158 @@ +import databases +import pytest +from sqlalchemy import func + +import ormar +import sqlalchemy +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Chart(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "charts" + database = database + metadata = metadata + + chart_id = ormar.Integer(primary_key=True, autoincrement=True) + name = ormar.String(max_length=200, unique=True, index=True) + query_text = ormar.Text() + datasets = ormar.JSON() + layout = ormar.JSON() + data_config = ormar.JSON() + created_date = ormar.DateTime(server_default=func.now()) + library = ormar.String(max_length=200, default="plotly") + used_filters = ormar.JSON() + + +class Report(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "reports" + database = database + metadata = metadata + + report_id = ormar.Integer(primary_key=True, autoincrement=True) + name = ormar.String(max_length=200, unique=True, index=True) + filters_position = ormar.String(max_length=200) + created_date = ormar.DateTime(server_default=func.now()) + + +class Language(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "languages" + database = database + metadata = metadata + + language_id = ormar.Integer(primary_key=True, autoincrement=True) + code = ormar.String(max_length=5) + name = ormar.String(max_length=200) + + +class TranslationNode(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "translation_nodes" + database = database + metadata = metadata + + node_id = ormar.Integer(primary_key=True, autoincrement=True) + node_type = ormar.String(max_length=200) + + +class Translation(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "translations" + database = database + metadata = metadata + + translation_id = ormar.Integer(primary_key=True, autoincrement=True) + node_id = ormar.ForeignKey(TranslationNode, related_name="translations") + language = ormar.ForeignKey(Language, name="language_id") + value = ormar.String(max_length=500) + + +class Filter(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "filters" + database = database + metadata = metadata + + filter_id = ormar.Integer(primary_key=True, autoincrement=True) + name = ormar.String(max_length=200, unique=True, index=True) + label = ormar.String(max_length=200) + query_text = ormar.Text() + allow_multiselect = ormar.Boolean(default=True) + created_date = ormar.DateTime(server_default=func.now()) + is_dynamic = ormar.Boolean(default=True) + is_date = ormar.Boolean(default=False) + translation = ormar.ForeignKey(TranslationNode, name="translation_node_id") + + +class FilterValue(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "filter_values" + database = database + metadata = metadata + + value_id = ormar.Integer(primary_key=True, autoincrement=True) + value = ormar.String(max_length=300) + label = ormar.String(max_length=300) + filter = ormar.ForeignKey(Filter, name="filter_id", related_name="values") + translation = ormar.ForeignKey(TranslationNode, name="translation_node_id") + + +class FilterXReport(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "filters_x_reports" + database = database + metadata = metadata + + filter_x_report_id = ormar.Integer(primary_key=True) + filter = ormar.ForeignKey(Filter, name="filter_id", related_name="reports") + report = ormar.ForeignKey(Report, name="report_id", related_name="filters") + sort_order = ormar.Integer() + default_value = ormar.Text() + is_visible = ormar.Boolean() + + +class ChartXReport(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "charts_x_reports" + database = database + metadata = metadata + + chart_x_report_id = ormar.Integer(primary_key=True) + chart = ormar.ForeignKey(Chart, name="chart_id", related_name="reports") + report = ormar.ForeignKey(Report, name="report_id", related_name="charts") + sort_order = ormar.Integer() + width = ormar.Integer() + + +class ChartColumn(ormar.Model): + class Meta(ormar.ModelMeta): + tablename = "charts_columns" + database = database + metadata = metadata + + column_id = ormar.Integer(primary_key=True, autoincrement=True) + chart = ormar.ForeignKey(Chart, name="chart_id", related_name="columns") + column_name = ormar.String(max_length=200) + column_type = ormar.String(max_length=200) + translation = ormar.ForeignKey(TranslationNode, name="translation_node_id") + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_saving_related_fk_rel(): + async with database: + async with database.transaction(force_rollback=True): + await Report.objects.select_all(follow=True).all() diff --git a/tests/test_filter_groups.py b/tests/test_queries/test_filter_groups.py similarity index 100% rename from tests/test_filter_groups.py rename to tests/test_queries/test_filter_groups.py diff --git a/tests/test_isnull_filter.py b/tests/test_queries/test_isnull_filter.py similarity index 100% rename from tests/test_isnull_filter.py rename to tests/test_queries/test_isnull_filter.py diff --git a/tests/test_queries/test_nested_reverse_relations.py b/tests/test_queries/test_nested_reverse_relations.py new file mode 100644 index 0000000..4c87b92 --- /dev/null +++ b/tests/test_queries/test_nested_reverse_relations.py @@ -0,0 +1,101 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL) +metadata = sqlalchemy.MetaData() + + +class BaseMeta(ormar.ModelMeta): + metadata = metadata + database = database + + +class DataSource(ormar.Model): + class Meta(BaseMeta): + tablename = "datasources" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=200, unique=True, index=True) + + +class DataSourceTable(ormar.Model): + class Meta(BaseMeta): + tablename = "source_tables" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=200, index=True) + source: Optional[DataSource] = ormar.ForeignKey( + DataSource, name="source_id", related_name="tables", ondelete="CASCADE", + ) + + +class DataSourceTableColumn(ormar.Model): + class Meta(BaseMeta): + tablename = "source_columns" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=200, index=True) + data_type: str = ormar.String(max_length=200) + table: Optional[DataSourceTable] = ormar.ForeignKey( + DataSourceTable, name="table_id", related_name="columns", ondelete="CASCADE", + ) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_double_nested_reverse_relation(): + async with database: + data_source = await DataSource(name="local").save() + test_tables = [ + { + "name": "test1", + "columns": [ + {"name": "col1", "data_type": "test"}, + {"name": "col2", "data_type": "test2"}, + {"name": "col3", "data_type": "test3"}, + ], + }, + { + "name": "test2", + "columns": [ + {"name": "col4", "data_type": "test"}, + {"name": "col5", "data_type": "test2"}, + {"name": "col6", "data_type": "test3"}, + ], + }, + ] + data_source.tables = test_tables + await data_source.save_related(save_all=True, follow=True) + + tables = await DataSourceTable.objects.all() + assert len(tables) == 2 + + columns = await DataSourceTableColumn.objects.all() + assert len(columns) == 6 + + data_source = ( + await DataSource.objects.select_related("tables__columns") + .filter(tables__name__in=["test1", "test2"], name="local") + .get() + ) + assert len(data_source.tables) == 2 + assert len(data_source.tables[0].columns) == 3 + assert data_source.tables[0].columns[0].name == "col1" + assert data_source.tables[0].columns[2].name == "col3" + assert len(data_source.tables[1].columns) == 3 + assert data_source.tables[1].columns[0].name == "col4" + assert data_source.tables[1].columns[2].name == "col6" diff --git a/tests/test_queries/test_non_relation_fields_not_merged.py b/tests/test_queries/test_non_relation_fields_not_merged.py new file mode 100644 index 0000000..38cda87 --- /dev/null +++ b/tests/test_queries/test_non_relation_fields_not_merged.py @@ -0,0 +1,53 @@ +from typing import Dict, List, Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL) +metadata = sqlalchemy.MetaData() + + +class BaseMeta(ormar.ModelMeta): + metadata = metadata + database = database + + +class Chart(ormar.Model): + class Meta(BaseMeta): + tablename = "authors" + + id: int = ormar.Integer(primary_key=True) + datasets = ormar.JSON() + + +class Config(ormar.Model): + class Meta(BaseMeta): + tablename = "books" + + id: int = ormar.Integer(primary_key=True) + chart: Optional[Chart] = ormar.ForeignKey(Chart) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_list_field_that_is_not_relation_is_not_merged(): + async with database: + chart = await Chart.objects.create(datasets=[{"test": "ok"}]) + await Config.objects.create(chart=chart) + await Config.objects.create(chart=chart) + + chart2 = await Chart.objects.select_related("configs").get() + assert len(chart2.datasets) == 1 + assert chart2.datasets == [{"test": "ok"}] diff --git a/tests/test_or_filters.py b/tests/test_queries/test_or_filters.py similarity index 100% rename from tests/test_or_filters.py rename to tests/test_queries/test_or_filters.py diff --git a/tests/test_order_by.py b/tests/test_queries/test_order_by.py similarity index 100% rename from tests/test_order_by.py rename to tests/test_queries/test_order_by.py diff --git a/tests/test_pagination.py b/tests/test_queries/test_pagination.py similarity index 100% rename from tests/test_pagination.py rename to tests/test_queries/test_pagination.py diff --git a/tests/test_queryproxy_on_m2m_models.py b/tests/test_queries/test_queryproxy_on_m2m_models.py similarity index 100% rename from tests/test_queryproxy_on_m2m_models.py rename to tests/test_queries/test_queryproxy_on_m2m_models.py diff --git a/tests/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py similarity index 100% rename from tests/test_queryset_level_methods.py rename to tests/test_queries/test_queryset_level_methods.py diff --git a/tests/test_reserved_sql_keywords_escaped.py b/tests/test_queries/test_reserved_sql_keywords_escaped.py similarity index 100% rename from tests/test_reserved_sql_keywords_escaped.py rename to tests/test_queries/test_reserved_sql_keywords_escaped.py diff --git a/tests/test_reverse_fk_queryset.py b/tests/test_queries/test_reverse_fk_queryset.py similarity index 96% rename from tests/test_reverse_fk_queryset.py rename to tests/test_queries/test_reverse_fk_queryset.py index dd2c49b..80193bc 100644 --- a/tests/test_reverse_fk_queryset.py +++ b/tests/test_queries/test_reverse_fk_queryset.py @@ -18,7 +18,7 @@ class Album(ormar.Model): metadata = metadata database = database - id: int = ormar.Integer(primary_key=True) + id: int = ormar.Integer(primary_key=True, name="album_id") name: str = ormar.String(max_length=100) is_best_seller: bool = ormar.Boolean(default=False) @@ -29,7 +29,7 @@ class Writer(ormar.Model): metadata = metadata database = database - id: int = ormar.Integer(primary_key=True) + id: int = ormar.Integer(primary_key=True, name="writer_id") name: str = ormar.String(max_length=100) @@ -40,11 +40,11 @@ class Track(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - album: Optional[Album] = ormar.ForeignKey(Album) + album: Optional[Album] = ormar.ForeignKey(Album, name="album_id") title: str = ormar.String(max_length=100) position: int = ormar.Integer() play_count: int = ormar.Integer(nullable=True) - written_by: Optional[Writer] = ormar.ForeignKey(Writer) + written_by: Optional[Writer] = ormar.ForeignKey(Writer, name="writer_id") async def get_sample_data(): diff --git a/tests/test_selecting_subset_of_columns.py b/tests/test_queries/test_selecting_subset_of_columns.py similarity index 100% rename from tests/test_selecting_subset_of_columns.py rename to tests/test_queries/test_selecting_subset_of_columns.py diff --git a/tests/test_relations/__init__.py b/tests/test_relations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_cascades.py b/tests/test_relations/test_cascades.py similarity index 100% rename from tests/test_cascades.py rename to tests/test_relations/test_cascades.py diff --git a/tests/test_database_fk_creation.py b/tests/test_relations/test_database_fk_creation.py similarity index 100% rename from tests/test_database_fk_creation.py rename to tests/test_relations/test_database_fk_creation.py diff --git a/tests/test_foreign_keys.py b/tests/test_relations/test_foreign_keys.py similarity index 100% rename from tests/test_foreign_keys.py rename to tests/test_relations/test_foreign_keys.py diff --git a/tests/test_m2m_through_fields.py b/tests/test_relations/test_m2m_through_fields.py similarity index 100% rename from tests/test_m2m_through_fields.py rename to tests/test_relations/test_m2m_through_fields.py diff --git a/tests/test_many_to_many.py b/tests/test_relations/test_many_to_many.py similarity index 100% rename from tests/test_many_to_many.py rename to tests/test_relations/test_many_to_many.py diff --git a/tests/test_prefetch_related.py b/tests/test_relations/test_prefetch_related.py similarity index 100% rename from tests/test_prefetch_related.py rename to tests/test_relations/test_prefetch_related.py diff --git a/tests/test_prefetch_related_multiple_models_relation.py b/tests/test_relations/test_prefetch_related_multiple_models_relation.py similarity index 100% rename from tests/test_prefetch_related_multiple_models_relation.py rename to tests/test_relations/test_prefetch_related_multiple_models_relation.py diff --git a/tests/test_relations_default_exception.py b/tests/test_relations/test_relations_default_exception.py similarity index 100% rename from tests/test_relations_default_exception.py rename to tests/test_relations/test_relations_default_exception.py diff --git a/tests/test_saving_related.py b/tests/test_relations/test_saving_related.py similarity index 100% rename from tests/test_saving_related.py rename to tests/test_relations/test_saving_related.py diff --git a/tests/test_select_related_with_limit.py b/tests/test_relations/test_select_related_with_limit.py similarity index 100% rename from tests/test_select_related_with_limit.py rename to tests/test_relations/test_select_related_with_limit.py diff --git a/tests/test_select_related_with_m2m_and_pk_name_set.py b/tests/test_relations/test_select_related_with_m2m_and_pk_name_set.py similarity index 100% rename from tests/test_select_related_with_m2m_and_pk_name_set.py rename to tests/test_relations/test_select_related_with_m2m_and_pk_name_set.py diff --git a/tests/test_selecting_proper_table_prefix.py b/tests/test_relations/test_selecting_proper_table_prefix.py similarity index 100% rename from tests/test_selecting_proper_table_prefix.py rename to tests/test_relations/test_selecting_proper_table_prefix.py diff --git a/tests/test_through_relations_fail.py b/tests/test_relations/test_through_relations_fail.py similarity index 100% rename from tests/test_through_relations_fail.py rename to tests/test_relations/test_through_relations_fail.py diff --git a/tests/test_signals/__init__.py b/tests/test_signals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_signals.py b/tests/test_signals/test_signals.py similarity index 100% rename from tests/test_signals.py rename to tests/test_signals/test_signals.py diff --git a/tests/test_signals_for_relations.py b/tests/test_signals/test_signals_for_relations.py similarity index 100% rename from tests/test_signals_for_relations.py rename to tests/test_signals/test_signals_for_relations.py diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_models_helpers.py b/tests/test_utils/test_models_helpers.py similarity index 100% rename from tests/test_models_helpers.py rename to tests/test_utils/test_models_helpers.py diff --git a/tests/test_queryset_utils.py b/tests/test_utils/test_queryset_utils.py similarity index 63% rename from tests/test_queryset_utils.py rename to tests/test_utils/test_queryset_utils.py index cd96dc8..fe76092 100644 --- a/tests/test_queryset_utils.py +++ b/tests/test_utils/test_queryset_utils.py @@ -4,7 +4,12 @@ import sqlalchemy import ormar from ormar.models.mixins import ExcludableMixin from ormar.queryset.prefetch_query import sort_models -from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update +from ormar.queryset.utils import ( + subtract_dict, + translate_list_to_dict, + update_dict_from_list, + update, +) from tests.settings import DATABASE_URL @@ -79,6 +84,21 @@ def test_updating_dict_inc_set_with_dict(): } +def test_subtracting_dict_inc_set_with_dict(): + curr_dict = { + "aa": Ellipsis, + "bb": Ellipsis, + "cc": {"aa": {"xx", "yy"}, "bb": Ellipsis}, + } + dict_to_update = { + "uu": Ellipsis, + "bb": {"cc", "dd"}, + "cc": {"aa": {"xx": {"oo": Ellipsis}}, "bb": Ellipsis}, + } + test = subtract_dict(curr_dict, dict_to_update) + assert test == {"aa": Ellipsis, "cc": {"aa": {"yy": Ellipsis}}} + + def test_updating_dict_inc_set_with_dict_inc_set(): curr_dict = { "aa": Ellipsis, @@ -99,6 +119,61 @@ def test_updating_dict_inc_set_with_dict_inc_set(): } +def test_subtracting_dict_inc_set_with_dict_inc_set(): + curr_dict = { + "aa": Ellipsis, + "bb": Ellipsis, + "cc": {"aa": {"xx", "yy"}, "bb": Ellipsis}, + "dd": {"aa", "bb"}, + } + dict_to_update = { + "aa": Ellipsis, + "bb": {"cc", "dd"}, + "cc": {"aa": {"xx", "oo", "zz", "ii"}}, + "dd": {"aa", "bb"}, + } + test = subtract_dict(curr_dict, dict_to_update) + assert test == {"cc": {"aa": {"yy"}, "bb": Ellipsis}} + + +def test_subtracting_with_set_and_dict(): + curr_dict = { + "translation": { + "filters": { + "values": Ellipsis, + "reports": {"report": {"charts": {"chart": Ellipsis}}}, + }, + "translations": {"language": Ellipsis}, + "filtervalues": { + "filter": {"reports": {"report": {"charts": {"chart": Ellipsis}}}} + }, + }, + "chart": { + "reports": { + "report": { + "filters": { + "filter": { + "translation": { + "translations": {"language": Ellipsis}, + "filtervalues": Ellipsis, + }, + "values": { + "translation": {"translations": {"language": Ellipsis}} + }, + } + } + } + } + }, + } + dict_to_update = { + "chart": Ellipsis, + "translation": {"filters", "filtervalues", "chartcolumns"}, + } + test = subtract_dict(curr_dict, dict_to_update) + assert test == {"translation": {"translations": {"language": Ellipsis}}} + + database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData()