add 4 new relation signales, add 4 new aggr methods, wip to cleanup

This commit is contained in:
collerek
2021-03-12 12:13:08 +01:00
parent 0ae340100e
commit ff9d412508
15 changed files with 701 additions and 33 deletions

View File

@ -24,9 +24,13 @@ from ormar.decorators import ( # noqa: I100
post_delete, post_delete,
post_save, post_save,
post_update, post_update,
post_relation_add,
post_relation_remove,
pre_delete, pre_delete,
pre_save, pre_save,
pre_update, pre_update,
pre_relation_add,
pre_relation_remove,
property_field, property_field,
) )
from ormar.exceptions import ( # noqa: I100 from ormar.exceptions import ( # noqa: I100
@ -102,9 +106,13 @@ __all__ = [
"post_delete", "post_delete",
"post_save", "post_save",
"post_update", "post_update",
"post_relation_add",
"post_relation_remove",
"pre_delete", "pre_delete",
"pre_save", "pre_save",
"pre_update", "pre_update",
"pre_relation_remove",
"pre_relation_add",
"Signal", "Signal",
"BaseField", "BaseField",
"ManyToManyField", "ManyToManyField",

View File

@ -12,9 +12,13 @@ from ormar.decorators.signals import (
post_delete, post_delete,
post_save, post_save,
post_update, post_update,
post_relation_add,
post_relation_remove,
pre_delete, pre_delete,
pre_save, pre_save,
pre_update, pre_update,
pre_relation_add,
pre_relation_remove,
) )
__all__ = [ __all__ = [
@ -25,4 +29,8 @@ __all__ = [
"pre_delete", "pre_delete",
"pre_save", "pre_save",
"pre_update", "pre_update",
"post_relation_remove",
"post_relation_add",
"pre_relation_remove",
"pre_relation_add",
] ]

View File

@ -22,7 +22,7 @@ def receiver(
def _decorator(func: Callable) -> Callable: def _decorator(func: Callable) -> Callable:
""" """
Internal decorator that does all the registeriing. Internal decorator that does all the registering.
:param func: function to register as receiver :param func: function to register as receiver
:type func: Callable :type func: Callable
@ -117,3 +117,57 @@ def pre_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
:rtype: Callable :rtype: Callable
""" """
return receiver(signal="pre_delete", senders=senders) return receiver(signal="pre_delete", senders=senders)
def pre_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
"""
Connect given function to all senders for pre_relation_add signal.
:param senders: one or a list of "Model" classes
that should have the signal receiver registered
:type senders: Union[Type["Model"], List[Type["Model"]]]
:return: returns the original function untouched
:rtype: Callable
"""
return receiver(signal="pre_relation_add", senders=senders)
def post_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
"""
Connect given function to all senders for post_relation_add signal.
:param senders: one or a list of "Model" classes
that should have the signal receiver registered
:type senders: Union[Type["Model"], List[Type["Model"]]]
:return: returns the original function untouched
:rtype: Callable
"""
return receiver(signal="post_relation_add", senders=senders)
def pre_relation_remove(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
"""
Connect given function to all senders for pre_relation_remove signal.
:param senders: one or a list of "Model" classes
that should have the signal receiver registered
:type senders: Union[Type["Model"], List[Type["Model"]]]
:return: returns the original function untouched
:rtype: Callable
"""
return receiver(signal="pre_relation_remove", senders=senders)
def post_relation_remove(
senders: Union[Type["Model"], List[Type["Model"]]]
) -> Callable:
"""
Connect given function to all senders for post_relation_remove signal.
:param senders: one or a list of "Model" classes
that should have the signal receiver registered
:type senders: Union[Type["Model"], List[Type["Model"]]]
:return: returns the original function untouched
:rtype: Callable
"""
return receiver(signal="post_relation_remove", senders=senders)

View File

@ -140,6 +140,10 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
signals.post_save = Signal() signals.post_save = Signal()
signals.post_update = Signal() signals.post_update = Signal()
signals.post_delete = Signal() signals.post_delete = Signal()
signals.pre_relation_add = Signal()
signals.post_relation_add = Signal()
signals.pre_relation_remove = Signal()
signals.post_relation_remove = Signal()
new_model.Meta.signals = signals new_model.Meta.signals = signals

View File

@ -216,6 +216,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
) )
if isinstance(object.__getattribute__(self, "__dict__").get(name), list): if isinstance(object.__getattribute__(self, "__dict__").get(name), list):
# virtual foreign key or many to many # virtual foreign key or many to many
# TODO: Fix double items in dict, no effect on real action ugly repr
# if model.pk not in [x.pk for x in related_list]:
object.__getattribute__(self, "__dict__")[name].append(model) object.__getattribute__(self, "__dict__")[name].append(model)
else: else:
# foreign key relation # foreign key relation

View File

@ -1,7 +1,7 @@
""" """
Contains QuerySet and different Query classes to allow for constructing of sql queries. Contains QuerySet and different Query classes to allow for constructing of sql queries.
""" """
from ormar.queryset.actions import FilterAction, OrderAction from ormar.queryset.actions import FilterAction, OrderAction, SelectAction
from ormar.queryset.clause import and_, or_ from ormar.queryset.clause import and_, or_
from ormar.queryset.filter_query import FilterQuery from ormar.queryset.filter_query import FilterQuery
from ormar.queryset.limit_query import LimitQuery from ormar.queryset.limit_query import LimitQuery
@ -17,6 +17,7 @@ __all__ = [
"OrderQuery", "OrderQuery",
"FilterAction", "FilterAction",
"OrderAction", "OrderAction",
"SelectAction",
"and_", "and_",
"or_", "or_",
] ]

View File

@ -1,4 +1,5 @@
from ormar.queryset.actions.filter_action import FilterAction from ormar.queryset.actions.filter_action import FilterAction
from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.actions.order_action import OrderAction
from ormar.queryset.actions.select_action import SelectAction
__all__ = ["FilterAction", "OrderAction"] __all__ = ["FilterAction", "OrderAction", "SelectAction"]

View File

@ -0,0 +1,44 @@
from typing import Callable, TYPE_CHECKING, Type
import sqlalchemy
from ormar.queryset.actions.query_action import QueryAction
if TYPE_CHECKING:
from ormar import Model
class SelectAction(QueryAction):
"""
Order Actions is populated by queryset when order_by() is called.
All required params are extracted but kept raw until actual filter clause value
is required -> then the action is converted into text() clause.
Extracted in order to easily change table prefixes on complex relations.
"""
def __init__(
self, select_str: str, model_cls: Type["Model"], alias: str = None
) -> None:
super().__init__(query_str=select_str, model_cls=model_cls)
if alias:
self.table_prefix = alias
def _split_value_into_parts(self, order_str: str) -> None:
parts = order_str.split("__")
self.field_name = parts[-1]
self.related_parts = parts[:-1]
def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause:
alias = f"{self.table_prefix}_" if self.table_prefix else ""
return sqlalchemy.text(f"{alias}{self.field_name}")
def apply_func(
self, func: Callable, use_label: bool = True
) -> sqlalchemy.sql.expression.TextClause:
result = func(self.get_text_clause())
if use_label:
rel_prefix = f"{self.related_str}__" if self.related_str else ""
result = result.label(f"{rel_prefix}{self.field_name}")
return result

View File

@ -6,7 +6,8 @@ from typing import (
Optional, Optional,
TYPE_CHECKING, TYPE_CHECKING,
Tuple, Tuple,
Type, cast, Type,
cast,
) )
import sqlalchemy import sqlalchemy
@ -24,20 +25,20 @@ if TYPE_CHECKING: # pragma no cover
class SqlJoin: class SqlJoin:
def __init__( # noqa: CFQ002 def __init__( # noqa: CFQ002
self, self,
used_aliases: List, used_aliases: List,
select_from: sqlalchemy.sql.select, select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column], columns: List[sqlalchemy.Column],
excludable: "ExcludableItems", excludable: "ExcludableItems",
order_columns: Optional[List["OrderAction"]], order_columns: Optional[List["OrderAction"]],
sorted_orders: OrderedDict, sorted_orders: OrderedDict,
main_model: Type["Model"], main_model: Type["Model"],
relation_name: str, relation_name: str,
relation_str: str, relation_str: str,
related_models: Any = None, related_models: Any = None,
own_alias: str = "", own_alias: str = "",
source_model: Type["Model"] = None, source_model: Type["Model"] = None,
already_sorted: Dict = None, already_sorted: Dict = None,
) -> None: ) -> None:
self.relation_name = relation_name self.relation_name = relation_name
self.related_models = related_models or [] self.related_models = related_models or []
@ -102,7 +103,7 @@ class SqlJoin:
return self.next_model.Meta.table return self.next_model.Meta.table
def _on_clause( def _on_clause(
self, previous_alias: str, from_clause: str, to_clause: str, self, previous_alias: str, from_clause: str, to_clause: str,
) -> text: ) -> text:
""" """
Receives aliases and names of both ends of the join and combines them Receives aliases and names of both ends of the join and combines them
@ -174,8 +175,8 @@ class SqlJoin:
for related_name in self.related_models: for related_name in self.related_models:
remainder = None remainder = None
if ( if (
isinstance(self.related_models, dict) isinstance(self.related_models, dict)
and self.related_models[related_name] and self.related_models[related_name]
): ):
remainder = self.related_models[related_name] remainder = self.related_models[related_name]
self._process_deeper_join(related_name=related_name, remainder=remainder) self._process_deeper_join(related_name=related_name, remainder=remainder)
@ -257,18 +258,18 @@ class SqlJoin:
""" """
target_field = self.target_field target_field = self.target_field
is_primary_self_ref = ( is_primary_self_ref = (
target_field.self_reference target_field.self_reference
and self.relation_name == target_field.self_reference_primary and self.relation_name == target_field.self_reference_primary
) )
if (is_primary_self_ref and not reverse) or ( if (is_primary_self_ref and not reverse) or (
not is_primary_self_ref and reverse not is_primary_self_ref and reverse
): ):
new_part = target_field.default_source_field_name() # type: ignore new_part = target_field.default_source_field_name() # type: ignore
else: else:
new_part = target_field.default_target_field_name() # type: ignore new_part = target_field.default_target_field_name() # type: ignore
return new_part return new_part
def _process_join(self, ) -> None: # noqa: CFQ002 def _process_join(self,) -> None: # noqa: CFQ002
""" """
Resolves to and from column names and table names. Resolves to and from column names and table names.
@ -331,7 +332,7 @@ class SqlJoin:
if self.order_columns: if self.order_columns:
for condition in self.order_columns: for condition in self.order_columns:
if condition.check_if_filter_apply( if condition.check_if_filter_apply(
target_model=self.next_model, alias=alias target_model=self.next_model, alias=alias
): ):
current_table_sorted = True current_table_sorted = True
self.sorted_orders[condition] = condition.get_text_clause() self.sorted_orders[condition] = condition.get_text_clause()
@ -345,8 +346,8 @@ class SqlJoin:
if self.target_field.is_multi and "__" in order_by: if self.target_field.is_multi and "__" in order_by:
parts = order_by.split("__") parts = order_by.split("__")
if ( if (
len(parts) > 2 len(parts) > 2
or parts[0] != self.target_field.through.get_name() or parts[0] != self.target_field.through.get_name()
): ):
raise ModelDefinitionError( raise ModelDefinitionError(
"You can order the relation only" "You can order the relation only"
@ -359,8 +360,9 @@ class SqlJoin:
elif self.target_field.is_multi: elif self.target_field.is_multi:
alias = self.alias_manager.resolve_relation_alias( alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.through, from_model=self.target_field.through,
relation_name=cast("ManyToManyField", relation_name=cast(
self.target_field).default_target_field_name(), "ManyToManyField", self.target_field
).default_target_field_name(),
) )
model = self.target_field.to model = self.target_field.to
clause = ormar.OrderAction( clause = ormar.OrderAction(

View File

@ -14,11 +14,12 @@ from typing import (
import databases import databases
import sqlalchemy import sqlalchemy
from sqlalchemy import bindparam from sqlalchemy import bindparam
from sqlalchemy.engine import ResultProxy
import ormar # noqa I100 import ormar # noqa I100
from ormar import MultipleMatches, NoMatch from ormar import MultipleMatches, NoMatch
from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError
from ormar.queryset import FilterQuery from ormar.queryset import FilterQuery, SelectAction
from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.actions.order_action import OrderAction
from ormar.queryset.clause import FilterGroup, QueryClause from ormar.queryset.clause import FilterGroup, QueryClause
from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.prefetch_query import PrefetchQuery
@ -557,6 +558,73 @@ class QuerySet:
expr = sqlalchemy.func.count().select().select_from(expr) expr = sqlalchemy.func.count().select().select_from(expr)
return await self.database.fetch_val(expr) return await self.database.fetch_val(expr)
async def _query_aggr_function(self, func_name: str, columns: List):
func = getattr(sqlalchemy.func, func_name)
select_actions = [
SelectAction(select_str=column, model_cls=self.model)
for column in columns
]
select_columns = [x.apply_func(func, use_label=True) for x in select_actions]
expr = self.build_select_expression().alias(f"subquery_for_{func_name}")
expr = sqlalchemy.select(select_columns).select_from(expr)
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
result = await self.database.fetch_one(expr)
return result if len(result) > 1 else result[0] # type: ignore
async def max( # noqa: A003
self, columns: Union[str, List[str]]
) -> Union[Any, ResultProxy]:
"""
Returns max value of columns for rows matching the given criteria
(applied with `filter` and `exclude` if set before).
:return: max value of column(s)
:rtype: Any
"""
if not isinstance(columns, list):
columns = [columns]
return await self._query_aggr_function(func_name="max", columns=columns)
async def min( # noqa: A003
self, columns: Union[str, List[str]]
) -> Union[Any, ResultProxy]:
"""
Returns min value of columns for rows matching the given criteria
(applied with `filter` and `exclude` if set before).
:return: min value of column(s)
:rtype: Any
"""
if not isinstance(columns, list):
columns = [columns]
return await self._query_aggr_function(func_name="min", columns=columns)
async def sum( # noqa: A003
self, columns: Union[str, List[str]]
) -> Union[Any, ResultProxy]:
"""
Returns sum value of columns for rows matching the given criteria
(applied with `filter` and `exclude` if set before).
:return: sum value of columns
:rtype: int
"""
if not isinstance(columns, list):
columns = [columns]
return await self._query_aggr_function(func_name="sum", columns=columns)
async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]:
"""
Returns avg value of columns for rows matching the given criteria
(applied with `filter` and `exclude` if set before).
:return: avg value of columns
:rtype: Union[int, float, List]
"""
if not isinstance(columns, list):
columns = [columns]
return await self._query_aggr_function(func_name="avg", columns=columns)
async def update(self, each: bool = False, **kwargs: Any) -> int: async def update(self, each: bool = False, **kwargs: Any) -> int:
""" """
Updates the model table after applying the filters from kwargs. Updates the model table after applying the filters from kwargs.

View File

@ -12,6 +12,8 @@ from typing import ( # noqa: I100, I201
cast, cast,
) )
from sqlalchemy.engine import ResultProxy
import ormar import ormar
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
@ -116,6 +118,7 @@ class QuerysetProxy:
:type child: Model :type child: Model
""" """
model_cls = self.relation.through model_cls = self.relation.through
# TODO: Add support for pk with default not only autoincrement id
owner_column = self.related_field.default_target_field_name() # type: ignore owner_column = self.related_field.default_target_field_name() # type: ignore
child_column = self.related_field.default_source_field_name() # type: ignore child_column = self.related_field.default_source_field_name() # type: ignore
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk} rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
@ -185,6 +188,52 @@ class QuerysetProxy:
""" """
return await self.queryset.count() return await self.queryset.count()
async def max( # noqa: A003
self, columns: Union[str, List[str]]
) -> Union[Any, ResultProxy]:
"""
Returns max value of columns for rows matching the given criteria
(applied with `filter` and `exclude` if set before).
:return: max value of column(s)
:rtype: Any
"""
return await self.queryset.max(columns=columns)
async def min( # noqa: A003
self, columns: Union[str, List[str]]
) -> Union[Any, ResultProxy]:
"""
Returns min value of columns for rows matching the given criteria
(applied with `filter` and `exclude` if set before).
:return: min value of column(s)
:rtype: Any
"""
return await self.queryset.min(columns=columns)
async def sum( # noqa: A003
self, columns: Union[str, List[str]]
) -> Union[Any, ResultProxy]:
"""
Returns sum value of columns for rows matching the given criteria
(applied with `filter` and `exclude` if set before).
:return: sum value of columns
:rtype: int
"""
return await self.queryset.sum(columns=columns)
async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]:
"""
Returns avg value of columns for rows matching the given criteria
(applied with `filter` and `exclude` if set before).
:return: avg value of columns
:rtype: Union[int, float, List]
"""
return await self.queryset.avg(columns=columns)
async def clear(self, keep_reversed: bool = True) -> int: async def clear(self, keep_reversed: bool = True) -> int:
""" """
Removes all related models from given relation. Removes all related models from given relation.

View File

@ -152,6 +152,12 @@ class RelationProxy(list):
f"Object {self._owner.get_name()} has no " f"Object {self._owner.get_name()} has no "
f"{item.get_name()} with given primary key!" f"{item.get_name()} with given primary key!"
) )
await self._owner.signals.pre_relation_remove.send(
sender=self._owner.__class__,
instance=self._owner,
child=item,
relation_name=self.field_name,
)
super().remove(item) super().remove(item)
relation_name = self.related_field_name relation_name = self.related_field_name
relation = item._orm._get(relation_name) relation = item._orm._get(relation_name)
@ -169,6 +175,12 @@ class RelationProxy(list):
await item.update() await item.update()
else: else:
await item.delete() await item.delete()
await self._owner.signals.post_relation_remove.send(
sender=self._owner.__class__,
instance=self._owner,
child=item,
relation_name=self.field_name,
)
async def add(self, item: "Model", **kwargs: Any) -> None: async def add(self, item: "Model", **kwargs: Any) -> None:
""" """
@ -182,6 +194,13 @@ class RelationProxy(list):
:type item: Model :type item: Model
""" """
relation_name = self.related_field_name relation_name = self.related_field_name
await self._owner.signals.pre_relation_add.send(
sender=self._owner.__class__,
instance=self._owner,
child=item,
relation_name=self.field_name,
passed_kwargs=kwargs,
)
self._check_if_model_saved() self._check_if_model_saved()
if self.type_ == ormar.RelationType.MULTIPLE: if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item, **kwargs) await self.queryset_proxy.create_through_instance(item, **kwargs)
@ -189,3 +208,10 @@ class RelationProxy(list):
else: else:
setattr(item, relation_name, self._owner) setattr(item, relation_name, self._owner)
await item.update() await item.update()
await self._owner.signals.post_relation_add.send(
sender=self._owner.__class__,
instance=self._owner,
child=item,
relation_name=self.field_name,
passed_kwargs=kwargs,
)

View File

@ -0,0 +1,182 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
metadata = metadata
database = database
class Author(ormar.Model):
class Meta(BaseMeta):
tablename = "authors"
order_by = ["-name"]
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Book(ormar.Model):
class Meta(BaseMeta):
tablename = "books"
order_by = ["year", "-ranking"]
id: int = ormar.Integer(primary_key=True)
author: Optional[Author] = ormar.ForeignKey(Author)
title: str = ormar.String(max_length=100)
year: int = ormar.Integer(nullable=True)
ranking: int = ormar.Integer(nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture(autouse=True, scope="function")
async def cleanup():
yield
async with database:
await Book.objects.delete(each=True)
await Author.objects.delete(each=True)
async def sample_data():
author = await Author(name="Author 1").save()
await Book(title="Book 1", year=1920, ranking=3, author=author).save()
await Book(title="Book 2", year=1930, ranking=1, author=author).save()
await Book(title="Book 3", year=1923, ranking=5, author=author).save()
@pytest.mark.asyncio
async def test_min_method():
async with database:
await sample_data()
assert await Book.objects.min("year") == 1920
result = await Book.objects.min(["year", "ranking"])
assert result == (1920, 1)
assert dict(result) == dict(year=1920, ranking=1)
assert await Book.objects.min("title") == "Book 1"
assert await Author.objects.select_related("books").min("books__year") == 1920
result = await Author.objects.select_related("books").min(
["books__year", "books__ranking"]
)
assert result == (1920, 1)
assert dict(result) == dict(books__year=1920, books__ranking=1)
assert (
await Author.objects.select_related("books")
.filter(books__year__gt=1925)
.min("books__year")
== 1930
)
@pytest.mark.asyncio
async def test_max_method():
async with database:
await sample_data()
assert await Book.objects.max("year") == 1930
result = await Book.objects.max(["year", "ranking"])
assert result == (1930, 5)
assert dict(result) == dict(year=1930, ranking=5)
assert await Book.objects.max("title") == "Book 3"
assert await Author.objects.select_related("books").max("books__year") == 1930
result = await Author.objects.select_related("books").max(
["books__year", "books__ranking"]
)
assert result == (1930, 5)
assert dict(result) == dict(books__year=1930, books__ranking=5)
assert (
await Author.objects.select_related("books")
.filter(books__year__lt=1925)
.max("books__year")
== 1923
)
@pytest.mark.asyncio
async def test_sum_method():
async with database:
await sample_data()
assert await Book.objects.sum("year") == 5773
result = await Book.objects.sum(["year", "ranking"])
assert result == (5773, 9)
assert dict(result) == dict(year=5773, ranking=9)
assert await Book.objects.sum("title") == 0.0
assert await Author.objects.select_related("books").sum("books__year") == 5773
result = await Author.objects.select_related("books").sum(
["books__year", "books__ranking"]
)
assert result == (5773, 9)
assert dict(result) == dict(books__year=5773, books__ranking=9)
assert (
await Author.objects.select_related("books")
.filter(books__year__lt=1925)
.sum("books__year")
== 3843
)
@pytest.mark.asyncio
async def test_avg_method():
async with database:
await sample_data()
assert round(await Book.objects.avg("year"), 2) == 1924.33
result = await Book.objects.avg(["year", "ranking"])
assert (round(result[0], 2), result[1]) == (1924.33, 3.0)
result_dict = dict(result)
assert round(result_dict.get("year"), 2) == 1924.33
assert result_dict.get("ranking") == 3.0
assert await Book.objects.avg("title") == 0.0
result = await Author.objects.select_related("books").avg("books__year")
assert round(result, 2) == 1924.33
result = await Author.objects.select_related("books").avg(
["books__year", "books__ranking"]
)
assert (round(result[0], 2), result[1]) == (1924.33, 3.0)
result_dict = dict(result)
assert round(result_dict.get("books__year"), 2) == 1924.33
assert result_dict.get("books__ranking") == 3.0
assert (
await Author.objects.select_related("books")
.filter(books__year__lt=1925)
.avg("books__year")
== 1921.5
)
@pytest.mark.asyncio
async def test_queryset_method():
async with database:
await sample_data()
author = await Author.objects.select_related("books").get()
assert await author.books.min("year") == 1920
assert await author.books.max("year") == 1930
assert await author.books.sum("ranking") == 9
assert await author.books.avg("ranking") == 3.0
assert await author.books.max(["year", "title"]) == (1930, "Book 3")

View File

@ -48,8 +48,10 @@ def create_test_database():
@pytest.fixture(autouse=True, scope="function") @pytest.fixture(autouse=True, scope="function")
async def cleanup(): async def cleanup():
await Book.objects.delete(each=True) yield
await Author.objects.delete(each=True) async with database:
await Book.objects.delete(each=True)
await Author.objects.delete(each=True)
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -0,0 +1,217 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from ormar import (
post_relation_add,
post_relation_remove,
pre_relation_add,
pre_relation_remove,
)
import pydantic
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class AuditLog(ormar.Model):
class Meta:
tablename = "audits"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
event_type: str = ormar.String(max_length=100)
event_log: pydantic.Json = ormar.JSON()
class Cover(ormar.Model):
class Meta:
tablename = "covers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=100)
class Artist(ormar.Model):
class Meta:
tablename = "artists"
metadata = metadata
database = database
id: int = ormar.Integer(name="artist_id", primary_key=True)
name: str = ormar.String(name="fname", max_length=100)
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=100)
cover: Optional[Cover] = ormar.ForeignKey(Cover)
artists = ormar.ManyToMany(Artist)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture(scope="function")
async def cleanup():
yield
async with database:
await AuditLog.objects.delete(each=True)
@pytest.mark.asyncio
async def test_relation_signal_functions():
async with database:
async with database.transaction(force_rollback=True):
@pre_relation_add([Album, Cover, Artist])
async def before_relation_add(
sender, instance, child, relation_name, passed_kwargs, **kwargs
):
await AuditLog.objects.create(
event_type="RELATION_PRE_ADD",
event_log=dict(
class_affected=sender.get_name(),
parent_id=instance.pk,
child_id=child.pk,
relation_name=relation_name,
kwargs=passed_kwargs,
),
)
passed_kwargs.pop("dummy", None)
@post_relation_add([Album, Cover, Artist])
async def after_relation_add(
sender, instance, child, relation_name, passed_kwargs, **kwargs
):
await AuditLog.objects.create(
event_type="RELATION_POST_ADD",
event_log=dict(
class_affected=sender.get_name(),
parent_id=instance.pk,
child_id=child.pk,
relation_name=relation_name,
kwargs=passed_kwargs,
),
)
@pre_relation_remove([Album, Cover, Artist])
async def before_relation_remove(
sender, instance, child, relation_name, **kwargs
):
await AuditLog.objects.create(
event_type="RELATION_PRE_REMOVE",
event_log=dict(
class_affected=sender.get_name(),
parent_id=instance.pk,
child_id=child.pk,
relation_name=relation_name,
kwargs=kwargs,
),
)
@post_relation_remove([Album, Cover, Artist])
async def after_relation_remove(
sender, instance, child, relation_name, **kwargs
):
await AuditLog.objects.create(
event_type="RELATION_POST_REMOVE",
event_log=dict(
class_affected=sender.get_name(),
parent_id=instance.pk,
child_id=child.pk,
relation_name=relation_name,
kwargs=kwargs,
),
)
cover = await Cover(title="New").save()
artist = await Artist(name="Artist").save()
album = await Album(title="New Album").save()
await cover.albums.add(album, index=0)
log = await AuditLog.objects.get(event_type="RELATION_PRE_ADD")
assert log.event_log.get("parent_id") == cover.pk
assert log.event_log.get("child_id") == album.pk
assert log.event_log.get("relation_name") == "albums"
assert log.event_log.get("kwargs") == dict(index=0)
log2 = await AuditLog.objects.get(event_type="RELATION_POST_ADD")
assert log2.event_log.get("parent_id") == cover.pk
assert log2.event_log.get("child_id") == album.pk
assert log2.event_log.get("relation_name") == "albums"
assert log2.event_log.get("kwargs") == dict(index=0)
await album.artists.add(artist, dummy="test")
log3 = await AuditLog.objects.filter(
event_type="RELATION_PRE_ADD", id__gt=log2.pk
).get()
assert log3.event_log.get("parent_id") == album.pk
assert log3.event_log.get("child_id") == artist.pk
assert log3.event_log.get("relation_name") == "artists"
assert log3.event_log.get("kwargs") == dict(dummy="test")
log4 = await AuditLog.objects.get(
event_type="RELATION_POST_ADD", id__gt=log3.pk
)
assert log4.event_log.get("parent_id") == album.pk
assert log4.event_log.get("child_id") == artist.pk
assert log4.event_log.get("relation_name") == "artists"
assert log4.event_log.get("kwargs") == dict()
assert album.cover == cover
assert len(album.artists) == 1
await cover.albums.remove(album)
log = await AuditLog.objects.get(event_type="RELATION_PRE_REMOVE")
assert log.event_log.get("parent_id") == cover.pk
assert log.event_log.get("child_id") == album.pk
assert log.event_log.get("relation_name") == "albums"
assert log.event_log.get("kwargs") == dict()
log2 = await AuditLog.objects.get(event_type="RELATION_POST_REMOVE")
assert log2.event_log.get("parent_id") == cover.pk
assert log2.event_log.get("child_id") == album.pk
assert log2.event_log.get("relation_name") == "albums"
assert log2.event_log.get("kwargs") == dict()
await album.artists.remove(artist)
log3 = await AuditLog.objects.filter(
event_type="RELATION_PRE_REMOVE", id__gt=log2.pk
).get()
assert log3.event_log.get("parent_id") == album.pk
assert log3.event_log.get("child_id") == artist.pk
assert log3.event_log.get("relation_name") == "artists"
assert log3.event_log.get("kwargs") == dict()
log4 = await AuditLog.objects.get(
event_type="RELATION_POST_REMOVE", id__gt=log3.pk
)
assert log4.event_log.get("parent_id") == album.pk
assert log4.event_log.get("child_id") == artist.pk
assert log4.event_log.get("relation_name") == "artists"
assert log4.event_log.get("kwargs") == dict()
await album.load_all()
assert len(album.artists) == 0
assert album.cover is None