diff --git a/ormar/__init__.py b/ormar/__init__.py index 9193543..1b3f89b 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -24,9 +24,13 @@ from ormar.decorators import ( # noqa: I100 post_delete, post_save, post_update, + post_relation_add, + post_relation_remove, pre_delete, pre_save, pre_update, + pre_relation_add, + pre_relation_remove, property_field, ) from ormar.exceptions import ( # noqa: I100 @@ -102,9 +106,13 @@ __all__ = [ "post_delete", "post_save", "post_update", + "post_relation_add", + "post_relation_remove", "pre_delete", "pre_save", "pre_update", + "pre_relation_remove", + "pre_relation_add", "Signal", "BaseField", "ManyToManyField", diff --git a/ormar/decorators/__init__.py b/ormar/decorators/__init__.py index 69925ce..d38cf2b 100644 --- a/ormar/decorators/__init__.py +++ b/ormar/decorators/__init__.py @@ -12,9 +12,13 @@ from ormar.decorators.signals import ( post_delete, post_save, post_update, + post_relation_add, + post_relation_remove, pre_delete, pre_save, pre_update, + pre_relation_add, + pre_relation_remove, ) __all__ = [ @@ -25,4 +29,8 @@ __all__ = [ "pre_delete", "pre_save", "pre_update", + "post_relation_remove", + "post_relation_add", + "pre_relation_remove", + "pre_relation_add", ] diff --git a/ormar/decorators/signals.py b/ormar/decorators/signals.py index 24f5ce4..8322f19 100644 --- a/ormar/decorators/signals.py +++ b/ormar/decorators/signals.py @@ -22,7 +22,7 @@ def receiver( def _decorator(func: Callable) -> Callable: """ - Internal decorator that does all the registeriing. + Internal decorator that does all the registering. :param func: function to register as receiver :type func: Callable @@ -117,3 +117,57 @@ def pre_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: :rtype: Callable """ return receiver(signal="pre_delete", senders=senders) + + +def pre_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: + """ + Connect given function to all senders for pre_relation_add signal. + + :param senders: one or a list of "Model" classes + that should have the signal receiver registered + :type senders: Union[Type["Model"], List[Type["Model"]]] + :return: returns the original function untouched + :rtype: Callable + """ + return receiver(signal="pre_relation_add", senders=senders) + + +def post_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: + """ + Connect given function to all senders for post_relation_add signal. + + :param senders: one or a list of "Model" classes + that should have the signal receiver registered + :type senders: Union[Type["Model"], List[Type["Model"]]] + :return: returns the original function untouched + :rtype: Callable + """ + return receiver(signal="post_relation_add", senders=senders) + + +def pre_relation_remove(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable: + """ + Connect given function to all senders for pre_relation_remove signal. + + :param senders: one or a list of "Model" classes + that should have the signal receiver registered + :type senders: Union[Type["Model"], List[Type["Model"]]] + :return: returns the original function untouched + :rtype: Callable + """ + return receiver(signal="pre_relation_remove", senders=senders) + + +def post_relation_remove( + senders: Union[Type["Model"], List[Type["Model"]]] +) -> Callable: + """ + Connect given function to all senders for post_relation_remove signal. + + :param senders: one or a list of "Model" classes + that should have the signal receiver registered + :type senders: Union[Type["Model"], List[Type["Model"]]] + :return: returns the original function untouched + :rtype: Callable + """ + return receiver(signal="post_relation_remove", senders=senders) diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 225ff32..3d14b52 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -140,6 +140,10 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001 signals.post_save = Signal() signals.post_update = Signal() signals.post_delete = Signal() + signals.pre_relation_add = Signal() + signals.post_relation_add = Signal() + signals.pre_relation_remove = Signal() + signals.post_relation_remove = Signal() new_model.Meta.signals = signals diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 8ffafc7..105aa87 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -216,6 +216,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass ) if isinstance(object.__getattribute__(self, "__dict__").get(name), list): # virtual foreign key or many to many + # TODO: Fix double items in dict, no effect on real action ugly repr + # if model.pk not in [x.pk for x in related_list]: object.__getattribute__(self, "__dict__")[name].append(model) else: # foreign key relation diff --git a/ormar/queryset/__init__.py b/ormar/queryset/__init__.py index 678e977..e75febf 100644 --- a/ormar/queryset/__init__.py +++ b/ormar/queryset/__init__.py @@ -1,7 +1,7 @@ """ Contains QuerySet and different Query classes to allow for constructing of sql queries. """ -from ormar.queryset.actions import FilterAction, OrderAction +from ormar.queryset.actions import FilterAction, OrderAction, SelectAction from ormar.queryset.clause import and_, or_ from ormar.queryset.filter_query import FilterQuery from ormar.queryset.limit_query import LimitQuery @@ -17,6 +17,7 @@ __all__ = [ "OrderQuery", "FilterAction", "OrderAction", + "SelectAction", "and_", "or_", ] diff --git a/ormar/queryset/actions/__init__.py b/ormar/queryset/actions/__init__.py index 088d68a..1fe1994 100644 --- a/ormar/queryset/actions/__init__.py +++ b/ormar/queryset/actions/__init__.py @@ -1,4 +1,5 @@ from ormar.queryset.actions.filter_action import FilterAction from ormar.queryset.actions.order_action import OrderAction +from ormar.queryset.actions.select_action import SelectAction -__all__ = ["FilterAction", "OrderAction"] +__all__ = ["FilterAction", "OrderAction", "SelectAction"] diff --git a/ormar/queryset/actions/select_action.py b/ormar/queryset/actions/select_action.py new file mode 100644 index 0000000..cbceef6 --- /dev/null +++ b/ormar/queryset/actions/select_action.py @@ -0,0 +1,44 @@ +from typing import Callable, TYPE_CHECKING, Type + +import sqlalchemy + +from ormar.queryset.actions.query_action import QueryAction + +if TYPE_CHECKING: + from ormar import Model + + +class SelectAction(QueryAction): + """ + Order Actions is populated by queryset when order_by() is called. + + All required params are extracted but kept raw until actual filter clause value + is required -> then the action is converted into text() clause. + + Extracted in order to easily change table prefixes on complex relations. + """ + + def __init__( + self, select_str: str, model_cls: Type["Model"], alias: str = None + ) -> None: + super().__init__(query_str=select_str, model_cls=model_cls) + if alias: + self.table_prefix = alias + + def _split_value_into_parts(self, order_str: str) -> None: + parts = order_str.split("__") + self.field_name = parts[-1] + self.related_parts = parts[:-1] + + def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause: + alias = f"{self.table_prefix}_" if self.table_prefix else "" + return sqlalchemy.text(f"{alias}{self.field_name}") + + def apply_func( + self, func: Callable, use_label: bool = True + ) -> sqlalchemy.sql.expression.TextClause: + result = func(self.get_text_clause()) + if use_label: + rel_prefix = f"{self.related_str}__" if self.related_str else "" + result = result.label(f"{rel_prefix}{self.field_name}") + return result diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 6547bb3..4cde369 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -6,7 +6,8 @@ from typing import ( Optional, TYPE_CHECKING, Tuple, - Type, cast, + Type, + cast, ) import sqlalchemy @@ -24,20 +25,20 @@ if TYPE_CHECKING: # pragma no cover class SqlJoin: def __init__( # noqa: CFQ002 - self, - used_aliases: List, - select_from: sqlalchemy.sql.select, - columns: List[sqlalchemy.Column], - excludable: "ExcludableItems", - order_columns: Optional[List["OrderAction"]], - sorted_orders: OrderedDict, - main_model: Type["Model"], - relation_name: str, - relation_str: str, - related_models: Any = None, - own_alias: str = "", - source_model: Type["Model"] = None, - already_sorted: Dict = None, + self, + used_aliases: List, + select_from: sqlalchemy.sql.select, + columns: List[sqlalchemy.Column], + excludable: "ExcludableItems", + order_columns: Optional[List["OrderAction"]], + sorted_orders: OrderedDict, + main_model: Type["Model"], + relation_name: str, + relation_str: str, + related_models: Any = None, + own_alias: str = "", + source_model: Type["Model"] = None, + already_sorted: Dict = None, ) -> None: self.relation_name = relation_name self.related_models = related_models or [] @@ -102,7 +103,7 @@ class SqlJoin: return self.next_model.Meta.table def _on_clause( - self, previous_alias: str, from_clause: str, to_clause: str, + self, previous_alias: str, from_clause: str, to_clause: str, ) -> text: """ Receives aliases and names of both ends of the join and combines them @@ -174,8 +175,8 @@ class SqlJoin: for related_name in self.related_models: remainder = None if ( - isinstance(self.related_models, dict) - and self.related_models[related_name] + isinstance(self.related_models, dict) + and self.related_models[related_name] ): remainder = self.related_models[related_name] self._process_deeper_join(related_name=related_name, remainder=remainder) @@ -257,18 +258,18 @@ class SqlJoin: """ target_field = self.target_field is_primary_self_ref = ( - target_field.self_reference - and self.relation_name == target_field.self_reference_primary + target_field.self_reference + and self.relation_name == target_field.self_reference_primary ) if (is_primary_self_ref and not reverse) or ( - not is_primary_self_ref and reverse + not is_primary_self_ref and reverse ): new_part = target_field.default_source_field_name() # type: ignore else: new_part = target_field.default_target_field_name() # type: ignore return new_part - def _process_join(self, ) -> None: # noqa: CFQ002 + def _process_join(self,) -> None: # noqa: CFQ002 """ Resolves to and from column names and table names. @@ -331,7 +332,7 @@ class SqlJoin: if self.order_columns: for condition in self.order_columns: if condition.check_if_filter_apply( - target_model=self.next_model, alias=alias + target_model=self.next_model, alias=alias ): current_table_sorted = True self.sorted_orders[condition] = condition.get_text_clause() @@ -345,8 +346,8 @@ class SqlJoin: if self.target_field.is_multi and "__" in order_by: parts = order_by.split("__") if ( - len(parts) > 2 - or parts[0] != self.target_field.through.get_name() + len(parts) > 2 + or parts[0] != self.target_field.through.get_name() ): raise ModelDefinitionError( "You can order the relation only" @@ -359,8 +360,9 @@ class SqlJoin: elif self.target_field.is_multi: alias = self.alias_manager.resolve_relation_alias( from_model=self.target_field.through, - relation_name=cast("ManyToManyField", - self.target_field).default_target_field_name(), + relation_name=cast( + "ManyToManyField", self.target_field + ).default_target_field_name(), ) model = self.target_field.to clause = ormar.OrderAction( diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index e84b67d..c758faa 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -14,11 +14,12 @@ from typing import ( import databases import sqlalchemy from sqlalchemy import bindparam +from sqlalchemy.engine import ResultProxy import ormar # noqa I100 from ormar import MultipleMatches, NoMatch from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError -from ormar.queryset import FilterQuery +from ormar.queryset import FilterQuery, SelectAction from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.clause import FilterGroup, QueryClause from ormar.queryset.prefetch_query import PrefetchQuery @@ -557,6 +558,73 @@ class QuerySet: expr = sqlalchemy.func.count().select().select_from(expr) return await self.database.fetch_val(expr) + async def _query_aggr_function(self, func_name: str, columns: List): + func = getattr(sqlalchemy.func, func_name) + select_actions = [ + SelectAction(select_str=column, model_cls=self.model) + for column in columns + ] + select_columns = [x.apply_func(func, use_label=True) for x in select_actions] + expr = self.build_select_expression().alias(f"subquery_for_{func_name}") + expr = sqlalchemy.select(select_columns).select_from(expr) + # print("\n", expr.compile(compile_kwargs={"literal_binds": True})) + result = await self.database.fetch_one(expr) + return result if len(result) > 1 else result[0] # type: ignore + + async def max( # noqa: A003 + self, columns: Union[str, List[str]] + ) -> Union[Any, ResultProxy]: + """ + Returns max value of columns for rows matching the given criteria + (applied with `filter` and `exclude` if set before). + + :return: max value of column(s) + :rtype: Any + """ + if not isinstance(columns, list): + columns = [columns] + return await self._query_aggr_function(func_name="max", columns=columns) + + async def min( # noqa: A003 + self, columns: Union[str, List[str]] + ) -> Union[Any, ResultProxy]: + """ + Returns min value of columns for rows matching the given criteria + (applied with `filter` and `exclude` if set before). + + :return: min value of column(s) + :rtype: Any + """ + if not isinstance(columns, list): + columns = [columns] + return await self._query_aggr_function(func_name="min", columns=columns) + + async def sum( # noqa: A003 + self, columns: Union[str, List[str]] + ) -> Union[Any, ResultProxy]: + """ + Returns sum value of columns for rows matching the given criteria + (applied with `filter` and `exclude` if set before). + + :return: sum value of columns + :rtype: int + """ + if not isinstance(columns, list): + columns = [columns] + return await self._query_aggr_function(func_name="sum", columns=columns) + + async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]: + """ + Returns avg value of columns for rows matching the given criteria + (applied with `filter` and `exclude` if set before). + + :return: avg value of columns + :rtype: Union[int, float, List] + """ + if not isinstance(columns, list): + columns = [columns] + return await self._query_aggr_function(func_name="avg", columns=columns) + async def update(self, each: bool = False, **kwargs: Any) -> int: """ Updates the model table after applying the filters from kwargs. diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index d90776a..1268e21 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -12,6 +12,8 @@ from typing import ( # noqa: I100, I201 cast, ) +from sqlalchemy.engine import ResultProxy + import ormar from ormar.exceptions import ModelPersistenceError, QueryDefinitionError @@ -116,6 +118,7 @@ class QuerysetProxy: :type child: Model """ model_cls = self.relation.through + # TODO: Add support for pk with default not only autoincrement id owner_column = self.related_field.default_target_field_name() # type: ignore child_column = self.related_field.default_source_field_name() # type: ignore rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk} @@ -185,6 +188,52 @@ class QuerysetProxy: """ return await self.queryset.count() + async def max( # noqa: A003 + self, columns: Union[str, List[str]] + ) -> Union[Any, ResultProxy]: + """ + Returns max value of columns for rows matching the given criteria + (applied with `filter` and `exclude` if set before). + + :return: max value of column(s) + :rtype: Any + """ + return await self.queryset.max(columns=columns) + + async def min( # noqa: A003 + self, columns: Union[str, List[str]] + ) -> Union[Any, ResultProxy]: + """ + Returns min value of columns for rows matching the given criteria + (applied with `filter` and `exclude` if set before). + + :return: min value of column(s) + :rtype: Any + """ + return await self.queryset.min(columns=columns) + + async def sum( # noqa: A003 + self, columns: Union[str, List[str]] + ) -> Union[Any, ResultProxy]: + """ + Returns sum value of columns for rows matching the given criteria + (applied with `filter` and `exclude` if set before). + + :return: sum value of columns + :rtype: int + """ + return await self.queryset.sum(columns=columns) + + async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]: + """ + Returns avg value of columns for rows matching the given criteria + (applied with `filter` and `exclude` if set before). + + :return: avg value of columns + :rtype: Union[int, float, List] + """ + return await self.queryset.avg(columns=columns) + async def clear(self, keep_reversed: bool = True) -> int: """ Removes all related models from given relation. diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index ce4b86f..20932b8 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -152,6 +152,12 @@ class RelationProxy(list): f"Object {self._owner.get_name()} has no " f"{item.get_name()} with given primary key!" ) + await self._owner.signals.pre_relation_remove.send( + sender=self._owner.__class__, + instance=self._owner, + child=item, + relation_name=self.field_name, + ) super().remove(item) relation_name = self.related_field_name relation = item._orm._get(relation_name) @@ -169,6 +175,12 @@ class RelationProxy(list): await item.update() else: await item.delete() + await self._owner.signals.post_relation_remove.send( + sender=self._owner.__class__, + instance=self._owner, + child=item, + relation_name=self.field_name, + ) async def add(self, item: "Model", **kwargs: Any) -> None: """ @@ -182,6 +194,13 @@ class RelationProxy(list): :type item: Model """ relation_name = self.related_field_name + await self._owner.signals.pre_relation_add.send( + sender=self._owner.__class__, + instance=self._owner, + child=item, + relation_name=self.field_name, + passed_kwargs=kwargs, + ) self._check_if_model_saved() if self.type_ == ormar.RelationType.MULTIPLE: await self.queryset_proxy.create_through_instance(item, **kwargs) @@ -189,3 +208,10 @@ class RelationProxy(list): else: setattr(item, relation_name, self._owner) await item.update() + await self._owner.signals.post_relation_add.send( + sender=self._owner.__class__, + instance=self._owner, + child=item, + relation_name=self.field_name, + passed_kwargs=kwargs, + ) diff --git a/tests/test_aggr_functions.py b/tests/test_aggr_functions.py new file mode 100644 index 0000000..f55b128 --- /dev/null +++ b/tests/test_aggr_functions.py @@ -0,0 +1,182 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL) +metadata = sqlalchemy.MetaData() + + +class BaseMeta(ormar.ModelMeta): + metadata = metadata + database = database + + +class Author(ormar.Model): + class Meta(BaseMeta): + tablename = "authors" + order_by = ["-name"] + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Book(ormar.Model): + class Meta(BaseMeta): + tablename = "books" + order_by = ["year", "-ranking"] + + id: int = ormar.Integer(primary_key=True) + author: Optional[Author] = ormar.ForeignKey(Author) + title: str = ormar.String(max_length=100) + year: int = ormar.Integer(nullable=True) + ranking: int = ormar.Integer(nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.fixture(autouse=True, scope="function") +async def cleanup(): + yield + async with database: + await Book.objects.delete(each=True) + await Author.objects.delete(each=True) + + +async def sample_data(): + author = await Author(name="Author 1").save() + await Book(title="Book 1", year=1920, ranking=3, author=author).save() + await Book(title="Book 2", year=1930, ranking=1, author=author).save() + await Book(title="Book 3", year=1923, ranking=5, author=author).save() + + +@pytest.mark.asyncio +async def test_min_method(): + async with database: + await sample_data() + assert await Book.objects.min("year") == 1920 + result = await Book.objects.min(["year", "ranking"]) + assert result == (1920, 1) + assert dict(result) == dict(year=1920, ranking=1) + + assert await Book.objects.min("title") == "Book 1" + + assert await Author.objects.select_related("books").min("books__year") == 1920 + result = await Author.objects.select_related("books").min( + ["books__year", "books__ranking"] + ) + assert result == (1920, 1) + assert dict(result) == dict(books__year=1920, books__ranking=1) + + assert ( + await Author.objects.select_related("books") + .filter(books__year__gt=1925) + .min("books__year") + == 1930 + ) + + +@pytest.mark.asyncio +async def test_max_method(): + async with database: + await sample_data() + assert await Book.objects.max("year") == 1930 + result = await Book.objects.max(["year", "ranking"]) + assert result == (1930, 5) + assert dict(result) == dict(year=1930, ranking=5) + + assert await Book.objects.max("title") == "Book 3" + + assert await Author.objects.select_related("books").max("books__year") == 1930 + result = await Author.objects.select_related("books").max( + ["books__year", "books__ranking"] + ) + assert result == (1930, 5) + assert dict(result) == dict(books__year=1930, books__ranking=5) + + assert ( + await Author.objects.select_related("books") + .filter(books__year__lt=1925) + .max("books__year") + == 1923 + ) + + +@pytest.mark.asyncio +async def test_sum_method(): + async with database: + await sample_data() + assert await Book.objects.sum("year") == 5773 + result = await Book.objects.sum(["year", "ranking"]) + assert result == (5773, 9) + assert dict(result) == dict(year=5773, ranking=9) + + assert await Book.objects.sum("title") == 0.0 + + assert await Author.objects.select_related("books").sum("books__year") == 5773 + result = await Author.objects.select_related("books").sum( + ["books__year", "books__ranking"] + ) + assert result == (5773, 9) + assert dict(result) == dict(books__year=5773, books__ranking=9) + + assert ( + await Author.objects.select_related("books") + .filter(books__year__lt=1925) + .sum("books__year") + == 3843 + ) + + +@pytest.mark.asyncio +async def test_avg_method(): + async with database: + await sample_data() + assert round(await Book.objects.avg("year"), 2) == 1924.33 + result = await Book.objects.avg(["year", "ranking"]) + assert (round(result[0], 2), result[1]) == (1924.33, 3.0) + result_dict = dict(result) + assert round(result_dict.get("year"), 2) == 1924.33 + assert result_dict.get("ranking") == 3.0 + + assert await Book.objects.avg("title") == 0.0 + + result = await Author.objects.select_related("books").avg("books__year") + assert round(result, 2) == 1924.33 + result = await Author.objects.select_related("books").avg( + ["books__year", "books__ranking"] + ) + assert (round(result[0], 2), result[1]) == (1924.33, 3.0) + result_dict = dict(result) + assert round(result_dict.get("books__year"), 2) == 1924.33 + assert result_dict.get("books__ranking") == 3.0 + + assert ( + await Author.objects.select_related("books") + .filter(books__year__lt=1925) + .avg("books__year") + == 1921.5 + ) + + +@pytest.mark.asyncio +async def test_queryset_method(): + async with database: + await sample_data() + author = await Author.objects.select_related("books").get() + assert await author.books.min("year") == 1920 + assert await author.books.max("year") == 1930 + assert await author.books.sum("ranking") == 9 + assert await author.books.avg("ranking") == 3.0 + assert await author.books.max(["year", "title"]) == (1930, "Book 3") diff --git a/tests/test_default_model_order.py b/tests/test_default_model_order.py index 721792e..c854bbd 100644 --- a/tests/test_default_model_order.py +++ b/tests/test_default_model_order.py @@ -48,8 +48,10 @@ def create_test_database(): @pytest.fixture(autouse=True, scope="function") async def cleanup(): - await Book.objects.delete(each=True) - await Author.objects.delete(each=True) + yield + async with database: + await Book.objects.delete(each=True) + await Author.objects.delete(each=True) @pytest.mark.asyncio diff --git a/tests/test_signals_for_relations.py b/tests/test_signals_for_relations.py new file mode 100644 index 0000000..e0cc20e --- /dev/null +++ b/tests/test_signals_for_relations.py @@ -0,0 +1,217 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from ormar import ( + post_relation_add, + post_relation_remove, + pre_relation_add, + pre_relation_remove, +) +import pydantic +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class AuditLog(ormar.Model): + class Meta: + tablename = "audits" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + event_type: str = ormar.String(max_length=100) + event_log: pydantic.Json = ormar.JSON() + + +class Cover(ormar.Model): + class Meta: + tablename = "covers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=100) + + +class Artist(ormar.Model): + class Meta: + tablename = "artists" + metadata = metadata + database = database + + id: int = ormar.Integer(name="artist_id", primary_key=True) + name: str = ormar.String(name="fname", max_length=100) + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=100) + cover: Optional[Cover] = ormar.ForeignKey(Cover) + artists = ormar.ManyToMany(Artist) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.fixture(scope="function") +async def cleanup(): + yield + async with database: + await AuditLog.objects.delete(each=True) + + +@pytest.mark.asyncio +async def test_relation_signal_functions(): + async with database: + async with database.transaction(force_rollback=True): + + @pre_relation_add([Album, Cover, Artist]) + async def before_relation_add( + sender, instance, child, relation_name, passed_kwargs, **kwargs + ): + await AuditLog.objects.create( + event_type="RELATION_PRE_ADD", + event_log=dict( + class_affected=sender.get_name(), + parent_id=instance.pk, + child_id=child.pk, + relation_name=relation_name, + kwargs=passed_kwargs, + ), + ) + + passed_kwargs.pop("dummy", None) + + @post_relation_add([Album, Cover, Artist]) + async def after_relation_add( + sender, instance, child, relation_name, passed_kwargs, **kwargs + ): + await AuditLog.objects.create( + event_type="RELATION_POST_ADD", + event_log=dict( + class_affected=sender.get_name(), + parent_id=instance.pk, + child_id=child.pk, + relation_name=relation_name, + kwargs=passed_kwargs, + ), + ) + + @pre_relation_remove([Album, Cover, Artist]) + async def before_relation_remove( + sender, instance, child, relation_name, **kwargs + ): + await AuditLog.objects.create( + event_type="RELATION_PRE_REMOVE", + event_log=dict( + class_affected=sender.get_name(), + parent_id=instance.pk, + child_id=child.pk, + relation_name=relation_name, + kwargs=kwargs, + ), + ) + + @post_relation_remove([Album, Cover, Artist]) + async def after_relation_remove( + sender, instance, child, relation_name, **kwargs + ): + await AuditLog.objects.create( + event_type="RELATION_POST_REMOVE", + event_log=dict( + class_affected=sender.get_name(), + parent_id=instance.pk, + child_id=child.pk, + relation_name=relation_name, + kwargs=kwargs, + ), + ) + + cover = await Cover(title="New").save() + artist = await Artist(name="Artist").save() + album = await Album(title="New Album").save() + + await cover.albums.add(album, index=0) + log = await AuditLog.objects.get(event_type="RELATION_PRE_ADD") + assert log.event_log.get("parent_id") == cover.pk + assert log.event_log.get("child_id") == album.pk + assert log.event_log.get("relation_name") == "albums" + assert log.event_log.get("kwargs") == dict(index=0) + + log2 = await AuditLog.objects.get(event_type="RELATION_POST_ADD") + assert log2.event_log.get("parent_id") == cover.pk + assert log2.event_log.get("child_id") == album.pk + assert log2.event_log.get("relation_name") == "albums" + assert log2.event_log.get("kwargs") == dict(index=0) + + await album.artists.add(artist, dummy="test") + + log3 = await AuditLog.objects.filter( + event_type="RELATION_PRE_ADD", id__gt=log2.pk + ).get() + assert log3.event_log.get("parent_id") == album.pk + assert log3.event_log.get("child_id") == artist.pk + assert log3.event_log.get("relation_name") == "artists" + assert log3.event_log.get("kwargs") == dict(dummy="test") + + log4 = await AuditLog.objects.get( + event_type="RELATION_POST_ADD", id__gt=log3.pk + ) + assert log4.event_log.get("parent_id") == album.pk + assert log4.event_log.get("child_id") == artist.pk + assert log4.event_log.get("relation_name") == "artists" + assert log4.event_log.get("kwargs") == dict() + + assert album.cover == cover + assert len(album.artists) == 1 + + await cover.albums.remove(album) + log = await AuditLog.objects.get(event_type="RELATION_PRE_REMOVE") + assert log.event_log.get("parent_id") == cover.pk + assert log.event_log.get("child_id") == album.pk + assert log.event_log.get("relation_name") == "albums" + assert log.event_log.get("kwargs") == dict() + + log2 = await AuditLog.objects.get(event_type="RELATION_POST_REMOVE") + assert log2.event_log.get("parent_id") == cover.pk + assert log2.event_log.get("child_id") == album.pk + assert log2.event_log.get("relation_name") == "albums" + assert log2.event_log.get("kwargs") == dict() + + await album.artists.remove(artist) + log3 = await AuditLog.objects.filter( + event_type="RELATION_PRE_REMOVE", id__gt=log2.pk + ).get() + assert log3.event_log.get("parent_id") == album.pk + assert log3.event_log.get("child_id") == artist.pk + assert log3.event_log.get("relation_name") == "artists" + assert log3.event_log.get("kwargs") == dict() + + log4 = await AuditLog.objects.get( + event_type="RELATION_POST_REMOVE", id__gt=log3.pk + ) + assert log4.event_log.get("parent_id") == album.pk + assert log4.event_log.get("child_id") == artist.pk + assert log4.event_log.get("relation_name") == "artists" + assert log4.event_log.get("kwargs") == dict() + + await album.load_all() + assert len(album.artists) == 0 + assert album.cover is None