add 4 new relation signales, add 4 new aggr methods, wip to cleanup
This commit is contained in:
@ -24,9 +24,13 @@ from ormar.decorators import ( # noqa: I100
|
||||
post_delete,
|
||||
post_save,
|
||||
post_update,
|
||||
post_relation_add,
|
||||
post_relation_remove,
|
||||
pre_delete,
|
||||
pre_save,
|
||||
pre_update,
|
||||
pre_relation_add,
|
||||
pre_relation_remove,
|
||||
property_field,
|
||||
)
|
||||
from ormar.exceptions import ( # noqa: I100
|
||||
@ -102,9 +106,13 @@ __all__ = [
|
||||
"post_delete",
|
||||
"post_save",
|
||||
"post_update",
|
||||
"post_relation_add",
|
||||
"post_relation_remove",
|
||||
"pre_delete",
|
||||
"pre_save",
|
||||
"pre_update",
|
||||
"pre_relation_remove",
|
||||
"pre_relation_add",
|
||||
"Signal",
|
||||
"BaseField",
|
||||
"ManyToManyField",
|
||||
|
||||
@ -12,9 +12,13 @@ from ormar.decorators.signals import (
|
||||
post_delete,
|
||||
post_save,
|
||||
post_update,
|
||||
post_relation_add,
|
||||
post_relation_remove,
|
||||
pre_delete,
|
||||
pre_save,
|
||||
pre_update,
|
||||
pre_relation_add,
|
||||
pre_relation_remove,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -25,4 +29,8 @@ __all__ = [
|
||||
"pre_delete",
|
||||
"pre_save",
|
||||
"pre_update",
|
||||
"post_relation_remove",
|
||||
"post_relation_add",
|
||||
"pre_relation_remove",
|
||||
"pre_relation_add",
|
||||
]
|
||||
|
||||
@ -22,7 +22,7 @@ def receiver(
|
||||
def _decorator(func: Callable) -> Callable:
|
||||
"""
|
||||
|
||||
Internal decorator that does all the registeriing.
|
||||
Internal decorator that does all the registering.
|
||||
|
||||
:param func: function to register as receiver
|
||||
:type func: Callable
|
||||
@ -117,3 +117,57 @@ def pre_delete(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
:rtype: Callable
|
||||
"""
|
||||
return receiver(signal="pre_delete", senders=senders)
|
||||
|
||||
|
||||
def pre_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for pre_relation_add signal.
|
||||
|
||||
:param senders: one or a list of "Model" classes
|
||||
that should have the signal receiver registered
|
||||
:type senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
:return: returns the original function untouched
|
||||
:rtype: Callable
|
||||
"""
|
||||
return receiver(signal="pre_relation_add", senders=senders)
|
||||
|
||||
|
||||
def post_relation_add(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for post_relation_add signal.
|
||||
|
||||
:param senders: one or a list of "Model" classes
|
||||
that should have the signal receiver registered
|
||||
:type senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
:return: returns the original function untouched
|
||||
:rtype: Callable
|
||||
"""
|
||||
return receiver(signal="post_relation_add", senders=senders)
|
||||
|
||||
|
||||
def pre_relation_remove(senders: Union[Type["Model"], List[Type["Model"]]]) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for pre_relation_remove signal.
|
||||
|
||||
:param senders: one or a list of "Model" classes
|
||||
that should have the signal receiver registered
|
||||
:type senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
:return: returns the original function untouched
|
||||
:rtype: Callable
|
||||
"""
|
||||
return receiver(signal="pre_relation_remove", senders=senders)
|
||||
|
||||
|
||||
def post_relation_remove(
|
||||
senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
) -> Callable:
|
||||
"""
|
||||
Connect given function to all senders for post_relation_remove signal.
|
||||
|
||||
:param senders: one or a list of "Model" classes
|
||||
that should have the signal receiver registered
|
||||
:type senders: Union[Type["Model"], List[Type["Model"]]]
|
||||
:return: returns the original function untouched
|
||||
:rtype: Callable
|
||||
"""
|
||||
return receiver(signal="post_relation_remove", senders=senders)
|
||||
|
||||
@ -140,6 +140,10 @@ def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
|
||||
signals.post_save = Signal()
|
||||
signals.post_update = Signal()
|
||||
signals.post_delete = Signal()
|
||||
signals.pre_relation_add = Signal()
|
||||
signals.post_relation_add = Signal()
|
||||
signals.pre_relation_remove = Signal()
|
||||
signals.post_relation_remove = Signal()
|
||||
new_model.Meta.signals = signals
|
||||
|
||||
|
||||
|
||||
@ -216,6 +216,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
)
|
||||
if isinstance(object.__getattribute__(self, "__dict__").get(name), list):
|
||||
# virtual foreign key or many to many
|
||||
# TODO: Fix double items in dict, no effect on real action ugly repr
|
||||
# if model.pk not in [x.pk for x in related_list]:
|
||||
object.__getattribute__(self, "__dict__")[name].append(model)
|
||||
else:
|
||||
# foreign key relation
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""
|
||||
Contains QuerySet and different Query classes to allow for constructing of sql queries.
|
||||
"""
|
||||
from ormar.queryset.actions import FilterAction, OrderAction
|
||||
from ormar.queryset.actions import FilterAction, OrderAction, SelectAction
|
||||
from ormar.queryset.clause import and_, or_
|
||||
from ormar.queryset.filter_query import FilterQuery
|
||||
from ormar.queryset.limit_query import LimitQuery
|
||||
@ -17,6 +17,7 @@ __all__ = [
|
||||
"OrderQuery",
|
||||
"FilterAction",
|
||||
"OrderAction",
|
||||
"SelectAction",
|
||||
"and_",
|
||||
"or_",
|
||||
]
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from ormar.queryset.actions.filter_action import FilterAction
|
||||
from ormar.queryset.actions.order_action import OrderAction
|
||||
from ormar.queryset.actions.select_action import SelectAction
|
||||
|
||||
__all__ = ["FilterAction", "OrderAction"]
|
||||
__all__ = ["FilterAction", "OrderAction", "SelectAction"]
|
||||
|
||||
44
ormar/queryset/actions/select_action.py
Normal file
44
ormar/queryset/actions/select_action.py
Normal file
@ -0,0 +1,44 @@
|
||||
from typing import Callable, TYPE_CHECKING, Type
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ormar.queryset.actions.query_action import QueryAction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ormar import Model
|
||||
|
||||
|
||||
class SelectAction(QueryAction):
|
||||
"""
|
||||
Order Actions is populated by queryset when order_by() is called.
|
||||
|
||||
All required params are extracted but kept raw until actual filter clause value
|
||||
is required -> then the action is converted into text() clause.
|
||||
|
||||
Extracted in order to easily change table prefixes on complex relations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, select_str: str, model_cls: Type["Model"], alias: str = None
|
||||
) -> None:
|
||||
super().__init__(query_str=select_str, model_cls=model_cls)
|
||||
if alias:
|
||||
self.table_prefix = alias
|
||||
|
||||
def _split_value_into_parts(self, order_str: str) -> None:
|
||||
parts = order_str.split("__")
|
||||
self.field_name = parts[-1]
|
||||
self.related_parts = parts[:-1]
|
||||
|
||||
def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause:
|
||||
alias = f"{self.table_prefix}_" if self.table_prefix else ""
|
||||
return sqlalchemy.text(f"{alias}{self.field_name}")
|
||||
|
||||
def apply_func(
|
||||
self, func: Callable, use_label: bool = True
|
||||
) -> sqlalchemy.sql.expression.TextClause:
|
||||
result = func(self.get_text_clause())
|
||||
if use_label:
|
||||
rel_prefix = f"{self.related_str}__" if self.related_str else ""
|
||||
result = result.label(f"{rel_prefix}{self.field_name}")
|
||||
return result
|
||||
@ -6,7 +6,8 @@ from typing import (
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type, cast,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
import sqlalchemy
|
||||
@ -24,20 +25,20 @@ if TYPE_CHECKING: # pragma no cover
|
||||
|
||||
class SqlJoin:
|
||||
def __init__( # noqa: CFQ002
|
||||
self,
|
||||
used_aliases: List,
|
||||
select_from: sqlalchemy.sql.select,
|
||||
columns: List[sqlalchemy.Column],
|
||||
excludable: "ExcludableItems",
|
||||
order_columns: Optional[List["OrderAction"]],
|
||||
sorted_orders: OrderedDict,
|
||||
main_model: Type["Model"],
|
||||
relation_name: str,
|
||||
relation_str: str,
|
||||
related_models: Any = None,
|
||||
own_alias: str = "",
|
||||
source_model: Type["Model"] = None,
|
||||
already_sorted: Dict = None,
|
||||
self,
|
||||
used_aliases: List,
|
||||
select_from: sqlalchemy.sql.select,
|
||||
columns: List[sqlalchemy.Column],
|
||||
excludable: "ExcludableItems",
|
||||
order_columns: Optional[List["OrderAction"]],
|
||||
sorted_orders: OrderedDict,
|
||||
main_model: Type["Model"],
|
||||
relation_name: str,
|
||||
relation_str: str,
|
||||
related_models: Any = None,
|
||||
own_alias: str = "",
|
||||
source_model: Type["Model"] = None,
|
||||
already_sorted: Dict = None,
|
||||
) -> None:
|
||||
self.relation_name = relation_name
|
||||
self.related_models = related_models or []
|
||||
@ -102,7 +103,7 @@ class SqlJoin:
|
||||
return self.next_model.Meta.table
|
||||
|
||||
def _on_clause(
|
||||
self, previous_alias: str, from_clause: str, to_clause: str,
|
||||
self, previous_alias: str, from_clause: str, to_clause: str,
|
||||
) -> text:
|
||||
"""
|
||||
Receives aliases and names of both ends of the join and combines them
|
||||
@ -174,8 +175,8 @@ class SqlJoin:
|
||||
for related_name in self.related_models:
|
||||
remainder = None
|
||||
if (
|
||||
isinstance(self.related_models, dict)
|
||||
and self.related_models[related_name]
|
||||
isinstance(self.related_models, dict)
|
||||
and self.related_models[related_name]
|
||||
):
|
||||
remainder = self.related_models[related_name]
|
||||
self._process_deeper_join(related_name=related_name, remainder=remainder)
|
||||
@ -257,18 +258,18 @@ class SqlJoin:
|
||||
"""
|
||||
target_field = self.target_field
|
||||
is_primary_self_ref = (
|
||||
target_field.self_reference
|
||||
and self.relation_name == target_field.self_reference_primary
|
||||
target_field.self_reference
|
||||
and self.relation_name == target_field.self_reference_primary
|
||||
)
|
||||
if (is_primary_self_ref and not reverse) or (
|
||||
not is_primary_self_ref and reverse
|
||||
not is_primary_self_ref and reverse
|
||||
):
|
||||
new_part = target_field.default_source_field_name() # type: ignore
|
||||
else:
|
||||
new_part = target_field.default_target_field_name() # type: ignore
|
||||
return new_part
|
||||
|
||||
def _process_join(self, ) -> None: # noqa: CFQ002
|
||||
def _process_join(self,) -> None: # noqa: CFQ002
|
||||
"""
|
||||
Resolves to and from column names and table names.
|
||||
|
||||
@ -331,7 +332,7 @@ class SqlJoin:
|
||||
if self.order_columns:
|
||||
for condition in self.order_columns:
|
||||
if condition.check_if_filter_apply(
|
||||
target_model=self.next_model, alias=alias
|
||||
target_model=self.next_model, alias=alias
|
||||
):
|
||||
current_table_sorted = True
|
||||
self.sorted_orders[condition] = condition.get_text_clause()
|
||||
@ -345,8 +346,8 @@ class SqlJoin:
|
||||
if self.target_field.is_multi and "__" in order_by:
|
||||
parts = order_by.split("__")
|
||||
if (
|
||||
len(parts) > 2
|
||||
or parts[0] != self.target_field.through.get_name()
|
||||
len(parts) > 2
|
||||
or parts[0] != self.target_field.through.get_name()
|
||||
):
|
||||
raise ModelDefinitionError(
|
||||
"You can order the relation only"
|
||||
@ -359,8 +360,9 @@ class SqlJoin:
|
||||
elif self.target_field.is_multi:
|
||||
alias = self.alias_manager.resolve_relation_alias(
|
||||
from_model=self.target_field.through,
|
||||
relation_name=cast("ManyToManyField",
|
||||
self.target_field).default_target_field_name(),
|
||||
relation_name=cast(
|
||||
"ManyToManyField", self.target_field
|
||||
).default_target_field_name(),
|
||||
)
|
||||
model = self.target_field.to
|
||||
clause = ormar.OrderAction(
|
||||
|
||||
@ -14,11 +14,12 @@ from typing import (
|
||||
import databases
|
||||
import sqlalchemy
|
||||
from sqlalchemy import bindparam
|
||||
from sqlalchemy.engine import ResultProxy
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar import MultipleMatches, NoMatch
|
||||
from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError
|
||||
from ormar.queryset import FilterQuery
|
||||
from ormar.queryset import FilterQuery, SelectAction
|
||||
from ormar.queryset.actions.order_action import OrderAction
|
||||
from ormar.queryset.clause import FilterGroup, QueryClause
|
||||
from ormar.queryset.prefetch_query import PrefetchQuery
|
||||
@ -557,6 +558,73 @@ class QuerySet:
|
||||
expr = sqlalchemy.func.count().select().select_from(expr)
|
||||
return await self.database.fetch_val(expr)
|
||||
|
||||
async def _query_aggr_function(self, func_name: str, columns: List):
|
||||
func = getattr(sqlalchemy.func, func_name)
|
||||
select_actions = [
|
||||
SelectAction(select_str=column, model_cls=self.model)
|
||||
for column in columns
|
||||
]
|
||||
select_columns = [x.apply_func(func, use_label=True) for x in select_actions]
|
||||
expr = self.build_select_expression().alias(f"subquery_for_{func_name}")
|
||||
expr = sqlalchemy.select(select_columns).select_from(expr)
|
||||
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
|
||||
result = await self.database.fetch_one(expr)
|
||||
return result if len(result) > 1 else result[0] # type: ignore
|
||||
|
||||
async def max( # noqa: A003
|
||||
self, columns: Union[str, List[str]]
|
||||
) -> Union[Any, ResultProxy]:
|
||||
"""
|
||||
Returns max value of columns for rows matching the given criteria
|
||||
(applied with `filter` and `exclude` if set before).
|
||||
|
||||
:return: max value of column(s)
|
||||
:rtype: Any
|
||||
"""
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
return await self._query_aggr_function(func_name="max", columns=columns)
|
||||
|
||||
async def min( # noqa: A003
|
||||
self, columns: Union[str, List[str]]
|
||||
) -> Union[Any, ResultProxy]:
|
||||
"""
|
||||
Returns min value of columns for rows matching the given criteria
|
||||
(applied with `filter` and `exclude` if set before).
|
||||
|
||||
:return: min value of column(s)
|
||||
:rtype: Any
|
||||
"""
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
return await self._query_aggr_function(func_name="min", columns=columns)
|
||||
|
||||
async def sum( # noqa: A003
|
||||
self, columns: Union[str, List[str]]
|
||||
) -> Union[Any, ResultProxy]:
|
||||
"""
|
||||
Returns sum value of columns for rows matching the given criteria
|
||||
(applied with `filter` and `exclude` if set before).
|
||||
|
||||
:return: sum value of columns
|
||||
:rtype: int
|
||||
"""
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
return await self._query_aggr_function(func_name="sum", columns=columns)
|
||||
|
||||
async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]:
|
||||
"""
|
||||
Returns avg value of columns for rows matching the given criteria
|
||||
(applied with `filter` and `exclude` if set before).
|
||||
|
||||
:return: avg value of columns
|
||||
:rtype: Union[int, float, List]
|
||||
"""
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
return await self._query_aggr_function(func_name="avg", columns=columns)
|
||||
|
||||
async def update(self, each: bool = False, **kwargs: Any) -> int:
|
||||
"""
|
||||
Updates the model table after applying the filters from kwargs.
|
||||
|
||||
@ -12,6 +12,8 @@ from typing import ( # noqa: I100, I201
|
||||
cast,
|
||||
)
|
||||
|
||||
from sqlalchemy.engine import ResultProxy
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
||||
|
||||
@ -116,6 +118,7 @@ class QuerysetProxy:
|
||||
:type child: Model
|
||||
"""
|
||||
model_cls = self.relation.through
|
||||
# TODO: Add support for pk with default not only autoincrement id
|
||||
owner_column = self.related_field.default_target_field_name() # type: ignore
|
||||
child_column = self.related_field.default_source_field_name() # type: ignore
|
||||
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
||||
@ -185,6 +188,52 @@ class QuerysetProxy:
|
||||
"""
|
||||
return await self.queryset.count()
|
||||
|
||||
async def max( # noqa: A003
|
||||
self, columns: Union[str, List[str]]
|
||||
) -> Union[Any, ResultProxy]:
|
||||
"""
|
||||
Returns max value of columns for rows matching the given criteria
|
||||
(applied with `filter` and `exclude` if set before).
|
||||
|
||||
:return: max value of column(s)
|
||||
:rtype: Any
|
||||
"""
|
||||
return await self.queryset.max(columns=columns)
|
||||
|
||||
async def min( # noqa: A003
|
||||
self, columns: Union[str, List[str]]
|
||||
) -> Union[Any, ResultProxy]:
|
||||
"""
|
||||
Returns min value of columns for rows matching the given criteria
|
||||
(applied with `filter` and `exclude` if set before).
|
||||
|
||||
:return: min value of column(s)
|
||||
:rtype: Any
|
||||
"""
|
||||
return await self.queryset.min(columns=columns)
|
||||
|
||||
async def sum( # noqa: A003
|
||||
self, columns: Union[str, List[str]]
|
||||
) -> Union[Any, ResultProxy]:
|
||||
"""
|
||||
Returns sum value of columns for rows matching the given criteria
|
||||
(applied with `filter` and `exclude` if set before).
|
||||
|
||||
:return: sum value of columns
|
||||
:rtype: int
|
||||
"""
|
||||
return await self.queryset.sum(columns=columns)
|
||||
|
||||
async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]:
|
||||
"""
|
||||
Returns avg value of columns for rows matching the given criteria
|
||||
(applied with `filter` and `exclude` if set before).
|
||||
|
||||
:return: avg value of columns
|
||||
:rtype: Union[int, float, List]
|
||||
"""
|
||||
return await self.queryset.avg(columns=columns)
|
||||
|
||||
async def clear(self, keep_reversed: bool = True) -> int:
|
||||
"""
|
||||
Removes all related models from given relation.
|
||||
|
||||
@ -152,6 +152,12 @@ class RelationProxy(list):
|
||||
f"Object {self._owner.get_name()} has no "
|
||||
f"{item.get_name()} with given primary key!"
|
||||
)
|
||||
await self._owner.signals.pre_relation_remove.send(
|
||||
sender=self._owner.__class__,
|
||||
instance=self._owner,
|
||||
child=item,
|
||||
relation_name=self.field_name,
|
||||
)
|
||||
super().remove(item)
|
||||
relation_name = self.related_field_name
|
||||
relation = item._orm._get(relation_name)
|
||||
@ -169,6 +175,12 @@ class RelationProxy(list):
|
||||
await item.update()
|
||||
else:
|
||||
await item.delete()
|
||||
await self._owner.signals.post_relation_remove.send(
|
||||
sender=self._owner.__class__,
|
||||
instance=self._owner,
|
||||
child=item,
|
||||
relation_name=self.field_name,
|
||||
)
|
||||
|
||||
async def add(self, item: "Model", **kwargs: Any) -> None:
|
||||
"""
|
||||
@ -182,6 +194,13 @@ class RelationProxy(list):
|
||||
:type item: Model
|
||||
"""
|
||||
relation_name = self.related_field_name
|
||||
await self._owner.signals.pre_relation_add.send(
|
||||
sender=self._owner.__class__,
|
||||
instance=self._owner,
|
||||
child=item,
|
||||
relation_name=self.field_name,
|
||||
passed_kwargs=kwargs,
|
||||
)
|
||||
self._check_if_model_saved()
|
||||
if self.type_ == ormar.RelationType.MULTIPLE:
|
||||
await self.queryset_proxy.create_through_instance(item, **kwargs)
|
||||
@ -189,3 +208,10 @@ class RelationProxy(list):
|
||||
else:
|
||||
setattr(item, relation_name, self._owner)
|
||||
await item.update()
|
||||
await self._owner.signals.post_relation_add.send(
|
||||
sender=self._owner.__class__,
|
||||
instance=self._owner,
|
||||
child=item,
|
||||
relation_name=self.field_name,
|
||||
passed_kwargs=kwargs,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user