From 6d0a5477cd7dde4f154810bf8c631030dc35e385 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 14 Mar 2021 19:09:34 +0100 Subject: [PATCH] wip pc problems backup --- ormar/__init__.py | 8 +- ormar/decorators/__init__.py | 8 +- ormar/models/mixins/save_mixin.py | 30 ++- ormar/models/model.py | 7 +- ormar/queryset/actions/select_action.py | 16 +- ormar/queryset/join.py | 81 ++++---- ormar/queryset/queryset.py | 28 ++- ormar/relations/querysetproxy.py | 25 +-- tests/test_aggr_functions.py | 45 ++--- tests/test_default_through_relation_order.py | 196 ++++++++++++------- tests/test_proper_order_of_sorting_apply.py | 6 +- tests/test_signals_for_relations.py | 2 +- 12 files changed, 268 insertions(+), 184 deletions(-) diff --git a/ormar/__init__.py b/ormar/__init__.py index 1b3f89b..d9225a4 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -22,15 +22,15 @@ And what's a better name for python ORM than snakes cabinet :) from ormar.protocols import QuerySetProtocol, RelationProtocol # noqa: I100 from ormar.decorators import ( # noqa: I100 post_delete, - post_save, - post_update, post_relation_add, post_relation_remove, + post_save, + post_update, pre_delete, - pre_save, - pre_update, pre_relation_add, pre_relation_remove, + pre_save, + pre_update, property_field, ) from ormar.exceptions import ( # noqa: I100 diff --git a/ormar/decorators/__init__.py b/ormar/decorators/__init__.py index d38cf2b..ec320a8 100644 --- a/ormar/decorators/__init__.py +++ b/ormar/decorators/__init__.py @@ -10,15 +10,15 @@ Currently only: from ormar.decorators.property_field import property_field from ormar.decorators.signals import ( post_delete, - post_save, - post_update, post_relation_add, post_relation_remove, + post_save, + post_update, pre_delete, - pre_save, - pre_update, pre_relation_add, pre_relation_remove, + pre_save, + pre_update, ) __all__ = [ diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index abdda94..db3a33b 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -1,5 +1,8 @@ +import uuid from typing import Dict, Optional, Set, TYPE_CHECKING +import pydantic + import ormar from ormar.exceptions import ModelPersistenceError from ormar.models.helpers.validation import validate_choices @@ -50,11 +53,30 @@ class SavePrepareMixin(RelationMixin, AliasMixin): pkname = cls.Meta.pkname pk = cls.Meta.model_fields[pkname] if new_kwargs.get(pkname, ormar.Undefined) is None and ( - pk.nullable or pk.autoincrement + pk.nullable or pk.autoincrement ): del new_kwargs[pkname] return new_kwargs + @classmethod + def parse_non_db_fields(cls, model_dict: Dict) -> Dict: + """ + Receives dictionary of model that is about to be saved and changes uuid fields + to strings in bulk_update. + + :param model_dict: dictionary of model that is about to be saved + :type model_dict: Dict + :return: dictionary of model that is about to be saved + :rtype: Dict + """ + for name, field in cls.Meta.model_fields.items(): + if field.__type__ == uuid.UUID and name in model_dict: + if field.column_type.uuid_format == "string": + model_dict[name] = str(model_dict[name]) + else: + model_dict[name] = "%.32x" % model_dict[name].int + return model_dict + @classmethod def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: # noqa CCR001 """ @@ -104,9 +126,9 @@ class SavePrepareMixin(RelationMixin, AliasMixin): """ for field_name, field in cls.Meta.model_fields.items(): if ( - field_name not in new_kwargs - and field.has_default(use_server=False) - and not field.pydantic_only + field_name not in new_kwargs + and field.has_default(use_server=False) + and not field.pydantic_only ): new_kwargs[field_name] = field.get_default() # clear fields with server_default set as None diff --git a/ormar/models/model.py b/ormar/models/model.py index 894ae39..a0c2abc 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -69,6 +69,7 @@ class Model(ModelRow): :return: saved Model :rtype: Model """ + await self.signals.pre_save.send(sender=self.__class__, instance=self) self_fields = self._extract_model_db_fields() if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: @@ -82,8 +83,6 @@ class Model(ModelRow): } ) - await self.signals.pre_save.send(sender=self.__class__, instance=self) - self_fields = self.translate_columns_to_aliases(self_fields) expr = self.Meta.table.insert() expr = expr.values(**self_fields) @@ -216,7 +215,9 @@ class Model(ModelRow): "You cannot update not saved model! Use save or upsert method." ) - await self.signals.pre_update.send(sender=self.__class__, instance=self) + await self.signals.pre_update.send( + sender=self.__class__, instance=self, passed_args=kwargs + ) self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) self_fields = self.translate_columns_to_aliases(self_fields) diff --git a/ormar/queryset/actions/select_action.py b/ormar/queryset/actions/select_action.py index cbceef6..92e5991 100644 --- a/ormar/queryset/actions/select_action.py +++ b/ormar/queryset/actions/select_action.py @@ -1,10 +1,11 @@ -from typing import Callable, TYPE_CHECKING, Type +import decimal +from typing import Any, Callable, TYPE_CHECKING, Type import sqlalchemy -from ormar.queryset.actions.query_action import QueryAction +from ormar.queryset.actions.query_action import QueryAction # noqa: I202 -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from ormar import Model @@ -22,7 +23,7 @@ class SelectAction(QueryAction): self, select_str: str, model_cls: Type["Model"], alias: str = None ) -> None: super().__init__(query_str=select_str, model_cls=model_cls) - if alias: + if alias: # pragma: no cover self.table_prefix = alias def _split_value_into_parts(self, order_str: str) -> None: @@ -30,6 +31,13 @@ class SelectAction(QueryAction): self.field_name = parts[-1] self.related_parts = parts[:-1] + @property + def is_numeric(self) -> bool: + return self.get_target_field_type() in [int, float, decimal.Decimal] + + def get_target_field_type(self) -> Any: + return self.target_model.Meta.model_fields[self.field_name].__type__ + def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause: alias = f"{self.table_prefix}_" if self.table_prefix else "" return sqlalchemy.text(f"{alias}{self.field_name}") diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 4cde369..1828961 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -320,6 +320,48 @@ class SqlJoin: ) self.sorted_orders[clause] = clause.get_text_clause() + def _verify_allowed_order_field(self, order_by: str) -> None: + """ + Verifies if proper field string is used. + :param order_by: string with order by definition + :type order_by: str + """ + parts = order_by.split("__") + if len(parts) > 2 or parts[0] != self.target_field.through.get_name(): + raise ModelDefinitionError( + "You can order the relation only " "by related or link table columns!" + ) + + def _get_alias_and_model(self, order_by: str) -> Tuple[str, Type["Model"]]: + """ + Returns proper model and alias to be applied in the clause. + + :param order_by: string with order by definition + :type order_by: str + :return: alias and model to be used in clause + :rtype: Tuple[str, Type["Model"]] + """ + if self.target_field.is_multi and "__" in order_by: + self._verify_allowed_order_field(order_by=order_by) + alias = self.next_alias + model = self.target_field.owner + elif self.target_field.is_multi: + alias = self.alias_manager.resolve_relation_alias( + from_model=self.target_field.through, + relation_name=cast( + "ManyToManyField", self.target_field + ).default_target_field_name(), + ) + model = self.target_field.to + else: + alias = self.alias_manager.resolve_relation_alias( + from_model=self.target_field.owner, + relation_name=self.target_field.name, + ) + model = self.target_field.to + + return alias, model + def _get_order_bys(self) -> None: # noqa: CCR001 """ Triggers construction of order bys if they are given. @@ -339,44 +381,13 @@ class SqlJoin: self.already_sorted[ f"{self.next_alias}_{self.next_model.get_name()}" ] = condition - # TODO: refactor into smaller helper functions if self.target_field.orders_by and not current_table_sorted: current_table_sorted = True for order_by in self.target_field.orders_by: - if self.target_field.is_multi and "__" in order_by: - parts = order_by.split("__") - if ( - len(parts) > 2 - or parts[0] != self.target_field.through.get_name() - ): - raise ModelDefinitionError( - "You can order the relation only" - "by related or link table columns!" - ) - model = self.target_field.owner - clause = ormar.OrderAction( - order_str=order_by, model_cls=model, alias=alias, - ) - elif self.target_field.is_multi: - alias = self.alias_manager.resolve_relation_alias( - from_model=self.target_field.through, - relation_name=cast( - "ManyToManyField", self.target_field - ).default_target_field_name(), - ) - model = self.target_field.to - clause = ormar.OrderAction( - order_str=order_by, model_cls=model, alias=alias - ) - else: - alias = self.alias_manager.resolve_relation_alias( - from_model=self.target_field.owner, - relation_name=self.target_field.name, - ) - model = self.target_field.to - clause = ormar.OrderAction( - order_str=order_by, model_cls=model, alias=alias - ) + alias, model = self._get_alias_and_model(order_by=order_by) + clause = ormar.OrderAction( + order_str=order_by, model_cls=model, alias=alias + ) self.sorted_orders[clause] = clause.get_text_clause() self.already_sorted[f"{alias}_{model.get_name()}"] = clause diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index c758faa..5db6338 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -14,7 +14,6 @@ from typing import ( import databases import sqlalchemy from sqlalchemy import bindparam -from sqlalchemy.engine import ResultProxy import ormar # noqa I100 from ormar import MultipleMatches, NoMatch @@ -558,22 +557,24 @@ class QuerySet: expr = sqlalchemy.func.count().select().select_from(expr) return await self.database.fetch_val(expr) - async def _query_aggr_function(self, func_name: str, columns: List): + async def _query_aggr_function(self, func_name: str, columns: List) -> Any: func = getattr(sqlalchemy.func, func_name) select_actions = [ - SelectAction(select_str=column, model_cls=self.model) - for column in columns + SelectAction(select_str=column, model_cls=self.model) for column in columns ] + if func_name in ["sum", "avg"]: + if any(not x.is_numeric for x in select_actions): + raise QueryDefinitionError( + "You can use sum and svg only with" "numeric types of columns" + ) select_columns = [x.apply_func(func, use_label=True) for x in select_actions] expr = self.build_select_expression().alias(f"subquery_for_{func_name}") expr = sqlalchemy.select(select_columns).select_from(expr) # print("\n", expr.compile(compile_kwargs={"literal_binds": True})) result = await self.database.fetch_one(expr) - return result if len(result) > 1 else result[0] # type: ignore + return dict(result) if len(result) > 1 else result[0] # type: ignore - async def max( # noqa: A003 - self, columns: Union[str, List[str]] - ) -> Union[Any, ResultProxy]: + async def max(self, columns: Union[str, List[str]]) -> Any: # noqa: A003 """ Returns max value of columns for rows matching the given criteria (applied with `filter` and `exclude` if set before). @@ -585,9 +586,7 @@ class QuerySet: columns = [columns] return await self._query_aggr_function(func_name="max", columns=columns) - async def min( # noqa: A003 - self, columns: Union[str, List[str]] - ) -> Union[Any, ResultProxy]: + async def min(self, columns: Union[str, List[str]]) -> Any: # noqa: A003 """ Returns min value of columns for rows matching the given criteria (applied with `filter` and `exclude` if set before). @@ -599,9 +598,7 @@ class QuerySet: columns = [columns] return await self._query_aggr_function(func_name="min", columns=columns) - async def sum( # noqa: A003 - self, columns: Union[str, List[str]] - ) -> Union[Any, ResultProxy]: + async def sum(self, columns: Union[str, List[str]]) -> Any: # noqa: A003 """ Returns sum value of columns for rows matching the given criteria (applied with `filter` and `exclude` if set before). @@ -613,7 +610,7 @@ class QuerySet: columns = [columns] return await self._query_aggr_function(func_name="sum", columns=columns) - async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]: + async def avg(self, columns: Union[str, List[str]]) -> Any: """ Returns avg value of columns for rows matching the given criteria (applied with `filter` and `exclude` if set before). @@ -974,6 +971,7 @@ class QuerySet: "You cannot update unsaved objects. " f"{self.model.__name__} has to have {pk_name} filled." ) + new_kwargs = self.model.parse_non_db_fields(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs) new_kwargs = self.model.translate_columns_to_aliases(new_kwargs) new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 1268e21..5e0466d 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -12,9 +12,8 @@ from typing import ( # noqa: I100, I201 cast, ) -from sqlalchemy.engine import ResultProxy -import ormar +import ormar # noqa: I100, I202 from ormar.exceptions import ModelPersistenceError, QueryDefinitionError if TYPE_CHECKING: # pragma no cover @@ -118,7 +117,6 @@ class QuerysetProxy: :type child: Model """ model_cls = self.relation.through - # TODO: Add support for pk with default not only autoincrement id owner_column = self.related_field.default_target_field_name() # type: ignore child_column = self.related_field.default_source_field_name() # type: ignore rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk} @@ -129,10 +127,8 @@ class QuerysetProxy: f"model without primary key set! \n" f"Save the child model first." ) - expr = model_cls.Meta.table.insert() - expr = expr.values(**final_kwargs) - # print("\n", expr.compile(compile_kwargs={"literal_binds": True})) - await model_cls.Meta.database.execute(expr) + print('final kwargs', final_kwargs) + await model_cls(**final_kwargs).save() async def update_through_instance(self, child: "Model", **kwargs: Any) -> None: """ @@ -148,6 +144,7 @@ class QuerysetProxy: child_column = self.related_field.default_source_field_name() # type: ignore rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk} through_model = await model_cls.objects.get(**rel_kwargs) + print('update kwargs', kwargs) await through_model.update(**kwargs) async def delete_through_instance(self, child: "Model") -> None: @@ -188,9 +185,7 @@ class QuerysetProxy: """ return await self.queryset.count() - async def max( # noqa: A003 - self, columns: Union[str, List[str]] - ) -> Union[Any, ResultProxy]: + async def max(self, columns: Union[str, List[str]]) -> Any: # noqa: A003 """ Returns max value of columns for rows matching the given criteria (applied with `filter` and `exclude` if set before). @@ -200,9 +195,7 @@ class QuerysetProxy: """ return await self.queryset.max(columns=columns) - async def min( # noqa: A003 - self, columns: Union[str, List[str]] - ) -> Union[Any, ResultProxy]: + async def min(self, columns: Union[str, List[str]]) -> Any: # noqa: A003 """ Returns min value of columns for rows matching the given criteria (applied with `filter` and `exclude` if set before). @@ -212,9 +205,7 @@ class QuerysetProxy: """ return await self.queryset.min(columns=columns) - async def sum( # noqa: A003 - self, columns: Union[str, List[str]] - ) -> Union[Any, ResultProxy]: + async def sum(self, columns: Union[str, List[str]]) -> Any: # noqa: A003 """ Returns sum value of columns for rows matching the given criteria (applied with `filter` and `exclude` if set before). @@ -224,7 +215,7 @@ class QuerysetProxy: """ return await self.queryset.sum(columns=columns) - async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]: + async def avg(self, columns: Union[str, List[str]]) -> Any: """ Returns avg value of columns for rows matching the given criteria (applied with `filter` and `exclude` if set before). diff --git a/tests/test_aggr_functions.py b/tests/test_aggr_functions.py index f55b128..92c10a1 100644 --- a/tests/test_aggr_functions.py +++ b/tests/test_aggr_functions.py @@ -5,6 +5,7 @@ import pytest import sqlalchemy import ormar +from ormar.exceptions import QueryDefinitionError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL) @@ -67,8 +68,7 @@ async def test_min_method(): await sample_data() assert await Book.objects.min("year") == 1920 result = await Book.objects.min(["year", "ranking"]) - assert result == (1920, 1) - assert dict(result) == dict(year=1920, ranking=1) + assert result == dict(year=1920, ranking=1) assert await Book.objects.min("title") == "Book 1" @@ -76,8 +76,7 @@ async def test_min_method(): result = await Author.objects.select_related("books").min( ["books__year", "books__ranking"] ) - assert result == (1920, 1) - assert dict(result) == dict(books__year=1920, books__ranking=1) + assert result == dict(books__year=1920, books__ranking=1) assert ( await Author.objects.select_related("books") @@ -93,8 +92,7 @@ async def test_max_method(): await sample_data() assert await Book.objects.max("year") == 1930 result = await Book.objects.max(["year", "ranking"]) - assert result == (1930, 5) - assert dict(result) == dict(year=1930, ranking=5) + assert result == dict(year=1930, ranking=5) assert await Book.objects.max("title") == "Book 3" @@ -102,8 +100,7 @@ async def test_max_method(): result = await Author.objects.select_related("books").max( ["books__year", "books__ranking"] ) - assert result == (1930, 5) - assert dict(result) == dict(books__year=1930, books__ranking=5) + assert result == dict(books__year=1930, books__ranking=5) assert ( await Author.objects.select_related("books") @@ -119,17 +116,16 @@ async def test_sum_method(): await sample_data() assert await Book.objects.sum("year") == 5773 result = await Book.objects.sum(["year", "ranking"]) - assert result == (5773, 9) - assert dict(result) == dict(year=5773, ranking=9) + assert result == dict(year=5773, ranking=9) - assert await Book.objects.sum("title") == 0.0 + with pytest.raises(QueryDefinitionError): + await Book.objects.sum("title") assert await Author.objects.select_related("books").sum("books__year") == 5773 result = await Author.objects.select_related("books").sum( ["books__year", "books__ranking"] ) - assert result == (5773, 9) - assert dict(result) == dict(books__year=5773, books__ranking=9) + assert result == dict(books__year=5773, books__ranking=9) assert ( await Author.objects.select_related("books") @@ -143,24 +139,21 @@ async def test_sum_method(): async def test_avg_method(): async with database: await sample_data() - assert round(await Book.objects.avg("year"), 2) == 1924.33 + assert round(float(await Book.objects.avg("year")), 2) == 1924.33 result = await Book.objects.avg(["year", "ranking"]) - assert (round(result[0], 2), result[1]) == (1924.33, 3.0) - result_dict = dict(result) - assert round(result_dict.get("year"), 2) == 1924.33 - assert result_dict.get("ranking") == 3.0 + assert round(float(result.get("year")), 2) == 1924.33 + assert result.get("ranking") == 3.0 - assert await Book.objects.avg("title") == 0.0 + with pytest.raises(QueryDefinitionError): + await Book.objects.avg("title") result = await Author.objects.select_related("books").avg("books__year") - assert round(result, 2) == 1924.33 + assert round(float(result), 2) == 1924.33 result = await Author.objects.select_related("books").avg( ["books__year", "books__ranking"] ) - assert (round(result[0], 2), result[1]) == (1924.33, 3.0) - result_dict = dict(result) - assert round(result_dict.get("books__year"), 2) == 1924.33 - assert result_dict.get("books__ranking") == 3.0 + assert round(float(result.get("books__year")), 2) == 1924.33 + assert result.get("books__ranking") == 3.0 assert ( await Author.objects.select_related("books") @@ -179,4 +172,6 @@ async def test_queryset_method(): assert await author.books.max("year") == 1930 assert await author.books.sum("ranking") == 9 assert await author.books.avg("ranking") == 3.0 - assert await author.books.max(["year", "title"]) == (1930, "Book 3") + assert await author.books.max(["year", "title"]) == dict( + year=1930, title="Book 3" + ) diff --git a/tests/test_default_through_relation_order.py b/tests/test_default_through_relation_order.py index eb02ce4..2dbed22 100644 --- a/tests/test_default_through_relation_order.py +++ b/tests/test_default_through_relation_order.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, Dict, List, Type from uuid import UUID, uuid4 import databases @@ -6,6 +6,8 @@ import pytest import sqlalchemy import ormar +from ormar import ModelDefinitionError, Model, QuerySet, pre_update +from ormar import pre_save, pre_relation_add from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL) @@ -30,7 +32,7 @@ class Link(ormar.Model): class Meta(BaseMeta): tablename = "link_table" - id: int = ormar.Integer(primary_key=True) + id: UUID = ormar.UUID(primary_key=True, default=uuid4) animal_order: int = ormar.Integer(nullable=True) human_order: int = ormar.Integer(nullable=True) @@ -50,6 +52,17 @@ class Human(ormar.Model): ) +class Human2(ormar.Model): + class Meta(BaseMeta): + tablename = "humans2" + + id: UUID = ormar.UUID(primary_key=True, default=uuid4) + name: str = ormar.Text(default="") + favoriteAnimals: List[Animal] = ormar.ManyToMany( + Animal, related_name="favoriteHumans2", orders_by=["link__animal_order__fail"] + ) + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -59,9 +72,94 @@ def create_test_database(): metadata.drop_all(engine) +@pytest.mark.asyncio +async def test_ordering_by_through_fail(): + async with database: + alice = await Human2(name="Alice").save() + spot = await Animal(name="Spot").save() + await alice.favoriteAnimals.add(spot) + with pytest.raises(ModelDefinitionError): + await alice.load_all() + + +def get_filtered_query( + sender: Type[Model], instance: Model, to_class: Type[Model] +) -> QuerySet: + pk = getattr(instance, f"{to_class.get_name()}").pk + filter_kwargs = {f"{to_class.get_name()}": pk} + query = sender.objects.filter(**filter_kwargs) + return query + + +async def populate_order_on_insert( + sender: Type[Model], instance: Model, from_class: Type[Model], + to_class: Type[Model] +): + order_column = f"{from_class.get_name()}_order" + if getattr(instance, order_column) is None: + query = get_filtered_query(sender, instance, to_class) + max_order = await query.max(order_column) + max_order = max_order + 1 if max_order is not None else 0 + setattr(instance, order_column, max_order) + else: + await reorder_on_update(sender, instance, from_class, to_class, + passed_args={ + order_column: getattr(instance, order_column)}) + + +async def reorder_on_update( + sender: Type[Model], instance: Model, from_class: Type[Model], + to_class: Type[Model], passed_args: Dict +): + order = f"{from_class.get_name()}_order" + if order in passed_args: + query = get_filtered_query(sender, instance, to_class) + to_reorder = await query.exclude(pk=instance.pk).order_by(order).all() + old_order = getattr(instance, order) + new_order = passed_args.get(order) + if to_reorder: + for link in to_reorder: + setattr(link, order, getattr(link, order) + 1) + await sender.objects.bulk_update(to_reorder, columns=[order]) + check = await get_filtered_query(sender, instance, to_class).all() + print('reordered', check) + + +@pre_save(Link) +async def order_link_on_insert(sender: Type[Model], instance: Model, **kwargs: Any): + relations = list(instance.extract_related_names()) + rel_one = sender.Meta.model_fields[relations[0]].to + rel_two = sender.Meta.model_fields[relations[1]].to + await populate_order_on_insert(sender, instance, from_class=rel_one, + to_class=rel_two) + await populate_order_on_insert(sender, instance, from_class=rel_two, + to_class=rel_one) + + +@pre_update(Link) +async def reorder_links_on_update( + sender: Type[ormar.Model], instance: ormar.Model, passed_args: Dict, + **kwargs: Any +): + relations = list(instance.extract_related_names()) + rel_one = sender.Meta.model_fields[relations[0]].to + rel_two = sender.Meta.model_fields[relations[1]].to + await reorder_on_update(sender, instance, from_class=rel_one, to_class=rel_two, + passed_args=passed_args) + await reorder_on_update(sender, instance, from_class=rel_two, to_class=rel_one, + passed_args=passed_args) + + @pytest.mark.asyncio async def test_ordering_by_through_on_m2m_field(): async with database: + def verify_order(instance, expected): + field_name = ( + "favoriteAnimals" if isinstance(instance, + Human) else "favoriteHumans" + ) + assert [x.name for x in getattr(instance, field_name)] == expected + alice = await Human(name="Alice").save() bob = await Human(name="Bob").save() charlie = await Human(name="Charlie").save() @@ -70,98 +168,55 @@ async def test_ordering_by_through_on_m2m_field(): kitty = await Animal(name="Kitty").save() noodle = await Animal(name="Noodle").save() - # you need to add them in order anyway so can provide order explicitly - # if you have a lot of them a list with enumerate might be an option - await alice.favoriteAnimals.add(noodle, animal_order=0, human_order=0) - await alice.favoriteAnimals.add(spot, animal_order=1, human_order=0) - await alice.favoriteAnimals.add(kitty, animal_order=2, human_order=0) + await alice.favoriteAnimals.add(noodle) + await alice.favoriteAnimals.add(spot) + await alice.favoriteAnimals.add(kitty) - # you dont have to reload queries on queryset clears the existing related - # alice = await alice.reload() await alice.load_all() - assert [x.name for x in alice.favoriteAnimals] == ["Noodle", "Spot", "Kitty"] + verify_order(alice, ["Noodle", "Spot", "Kitty"]) - await bob.favoriteAnimals.add(noodle, animal_order=0, human_order=1) - await bob.favoriteAnimals.add(kitty, animal_order=1, human_order=1) - await bob.favoriteAnimals.add(spot, animal_order=2, human_order=1) + await bob.favoriteAnimals.add(noodle) + await bob.favoriteAnimals.add(kitty) + await bob.favoriteAnimals.add(spot) await bob.load_all() - assert [x.name for x in bob.favoriteAnimals] == ["Noodle", "Kitty", "Spot"] + verify_order(bob, ["Noodle", "Kitty", "Spot"]) - await charlie.favoriteAnimals.add(kitty, animal_order=0, human_order=2) - await charlie.favoriteAnimals.add(noodle, animal_order=1, human_order=2) - await charlie.favoriteAnimals.add(spot, animal_order=2, human_order=2) + await charlie.favoriteAnimals.add(kitty) + await charlie.favoriteAnimals.add(noodle) + await charlie.favoriteAnimals.add(spot) await charlie.load_all() - assert [x.name for x in charlie.favoriteAnimals] == ["Kitty", "Noodle", "Spot"] + verify_order(charlie, ["Kitty", "Noodle", "Spot"]) animals = [noodle, kitty, spot] for animal in animals: await animal.load_all() - assert [x.name for x in animal.favoriteHumans] == [ - "Alice", - "Bob", - "Charlie", - ] + verify_order(animal, ["Alice", "Bob", "Charlie"]) zack = await Human(name="Zack").save() - async def reorder_humans(animal, new_ordered_humans): - noodle_links = await Link.objects.filter(animal=animal).all() - for link in noodle_links: - link.human_order = next( - ( - i - for i, x in enumerate(new_ordered_humans) - if x.pk == link.human.pk - ), - None, - ) - await Link.objects.bulk_update(noodle_links, columns=["human_order"]) - await noodle.favoriteHumans.add(zack, animal_order=0, human_order=0) - await reorder_humans(noodle, [zack, alice, bob, charlie]) await noodle.load_all() - assert [x.name for x in noodle.favoriteHumans] == [ - "Zack", - "Alice", - "Bob", - "Charlie", - ] + verify_order(noodle, ["Zack", "Alice", "Bob", "Charlie"]) await zack.load_all() - assert [x.name for x in zack.favoriteAnimals] == ["Noodle"] + verify_order(zack, ["Noodle"]) - humans = noodle.favoriteHumans - humans.insert(1, humans.pop(0)) - await reorder_humans(noodle, humans) + await noodle.favoriteHumans.filter(name='Zack').update( + link=dict(human_order=1)) await noodle.load_all() - assert [x.name for x in noodle.favoriteHumans] == [ - "Alice", - "Zack", - "Bob", - "Charlie", - ] + verify_order(noodle, ["Alice", "Zack", "Bob", "Charlie"]) - humans.insert(2, humans.pop(1)) - await reorder_humans(noodle, humans) + await noodle.favoriteHumans.filter(name='Zack').update( + link=dict(human_order=2)) await noodle.load_all() - assert [x.name for x in noodle.favoriteHumans] == [ - "Alice", - "Bob", - "Zack", - "Charlie", - ] + verify_order(noodle, ["Alice", "Bob", "Zack", "Charlie"]) - humans.insert(3, humans.pop(2)) - await reorder_humans(noodle, humans) + await noodle.favoriteHumans.filter(name='Zack').update( + link=dict(human_order=3)) await noodle.load_all() - assert [x.name for x in noodle.favoriteHumans] == [ - "Alice", - "Bob", - "Charlie", - "Zack", - ] + verify_order(noodle, ["Alice", "Bob", "Charlie", "Zack"]) await kitty.favoriteHumans.remove(bob) await kitty.load_all() @@ -169,8 +224,9 @@ async def test_ordering_by_through_on_m2m_field(): bob = await noodle.favoriteHumans.get(pk=bob.pk) assert bob.link.human_order == 1 + await noodle.favoriteHumans.remove( await noodle.favoriteHumans.filter(link__human_order=2).get() ) await noodle.load_all() - assert [x.name for x in noodle.favoriteHumans] == ["Alice", "Bob", "Zack"] + verify_order(noodle, ["Alice", "Bob", "Zack"]) diff --git a/tests/test_proper_order_of_sorting_apply.py b/tests/test_proper_order_of_sorting_apply.py index d7506ed..a02f6be 100644 --- a/tests/test_proper_order_of_sorting_apply.py +++ b/tests/test_proper_order_of_sorting_apply.py @@ -49,8 +49,10 @@ def create_test_database(): @pytest.fixture(autouse=True, scope="function") async def cleanup(): - await Book.objects.delete(each=True) - await Author.objects.delete(each=True) + yield + async with database: + await Book.objects.delete(each=True) + await Author.objects.delete(each=True) @pytest.mark.asyncio diff --git a/tests/test_signals_for_relations.py b/tests/test_signals_for_relations.py index e0cc20e..e5cbf8f 100644 --- a/tests/test_signals_for_relations.py +++ b/tests/test_signals_for_relations.py @@ -70,7 +70,7 @@ def create_test_database(): metadata.drop_all(engine) -@pytest.fixture(scope="function") +@pytest.fixture(autouse=True, scope="function") async def cleanup(): yield async with database: