wip pc problems backup
This commit is contained in:
@ -22,15 +22,15 @@ And what's a better name for python ORM than snakes cabinet :)
|
|||||||
from ormar.protocols import QuerySetProtocol, RelationProtocol # noqa: I100
|
from ormar.protocols import QuerySetProtocol, RelationProtocol # noqa: I100
|
||||||
from ormar.decorators import ( # noqa: I100
|
from ormar.decorators import ( # noqa: I100
|
||||||
post_delete,
|
post_delete,
|
||||||
post_save,
|
|
||||||
post_update,
|
|
||||||
post_relation_add,
|
post_relation_add,
|
||||||
post_relation_remove,
|
post_relation_remove,
|
||||||
|
post_save,
|
||||||
|
post_update,
|
||||||
pre_delete,
|
pre_delete,
|
||||||
pre_save,
|
|
||||||
pre_update,
|
|
||||||
pre_relation_add,
|
pre_relation_add,
|
||||||
pre_relation_remove,
|
pre_relation_remove,
|
||||||
|
pre_save,
|
||||||
|
pre_update,
|
||||||
property_field,
|
property_field,
|
||||||
)
|
)
|
||||||
from ormar.exceptions import ( # noqa: I100
|
from ormar.exceptions import ( # noqa: I100
|
||||||
|
|||||||
@ -10,15 +10,15 @@ Currently only:
|
|||||||
from ormar.decorators.property_field import property_field
|
from ormar.decorators.property_field import property_field
|
||||||
from ormar.decorators.signals import (
|
from ormar.decorators.signals import (
|
||||||
post_delete,
|
post_delete,
|
||||||
post_save,
|
|
||||||
post_update,
|
|
||||||
post_relation_add,
|
post_relation_add,
|
||||||
post_relation_remove,
|
post_relation_remove,
|
||||||
|
post_save,
|
||||||
|
post_update,
|
||||||
pre_delete,
|
pre_delete,
|
||||||
pre_save,
|
|
||||||
pre_update,
|
|
||||||
pre_relation_add,
|
pre_relation_add,
|
||||||
pre_relation_remove,
|
pre_relation_remove,
|
||||||
|
pre_save,
|
||||||
|
pre_update,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
|
import uuid
|
||||||
from typing import Dict, Optional, Set, TYPE_CHECKING
|
from typing import Dict, Optional, Set, TYPE_CHECKING
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar.exceptions import ModelPersistenceError
|
from ormar.exceptions import ModelPersistenceError
|
||||||
from ormar.models.helpers.validation import validate_choices
|
from ormar.models.helpers.validation import validate_choices
|
||||||
@ -55,6 +58,25 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
|
|||||||
del new_kwargs[pkname]
|
del new_kwargs[pkname]
|
||||||
return new_kwargs
|
return new_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_non_db_fields(cls, model_dict: Dict) -> Dict:
|
||||||
|
"""
|
||||||
|
Receives dictionary of model that is about to be saved and changes uuid fields
|
||||||
|
to strings in bulk_update.
|
||||||
|
|
||||||
|
:param model_dict: dictionary of model that is about to be saved
|
||||||
|
:type model_dict: Dict
|
||||||
|
:return: dictionary of model that is about to be saved
|
||||||
|
:rtype: Dict
|
||||||
|
"""
|
||||||
|
for name, field in cls.Meta.model_fields.items():
|
||||||
|
if field.__type__ == uuid.UUID and name in model_dict:
|
||||||
|
if field.column_type.uuid_format == "string":
|
||||||
|
model_dict[name] = str(model_dict[name])
|
||||||
|
else:
|
||||||
|
model_dict[name] = "%.32x" % model_dict[name].int
|
||||||
|
return model_dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: # noqa CCR001
|
def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: # noqa CCR001
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -69,6 +69,7 @@ class Model(ModelRow):
|
|||||||
:return: saved Model
|
:return: saved Model
|
||||||
:rtype: Model
|
:rtype: Model
|
||||||
"""
|
"""
|
||||||
|
await self.signals.pre_save.send(sender=self.__class__, instance=self)
|
||||||
self_fields = self._extract_model_db_fields()
|
self_fields = self._extract_model_db_fields()
|
||||||
|
|
||||||
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
||||||
@ -82,8 +83,6 @@ class Model(ModelRow):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.signals.pre_save.send(sender=self.__class__, instance=self)
|
|
||||||
|
|
||||||
self_fields = self.translate_columns_to_aliases(self_fields)
|
self_fields = self.translate_columns_to_aliases(self_fields)
|
||||||
expr = self.Meta.table.insert()
|
expr = self.Meta.table.insert()
|
||||||
expr = expr.values(**self_fields)
|
expr = expr.values(**self_fields)
|
||||||
@ -216,7 +215,9 @@ class Model(ModelRow):
|
|||||||
"You cannot update not saved model! Use save or upsert method."
|
"You cannot update not saved model! Use save or upsert method."
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.signals.pre_update.send(sender=self.__class__, instance=self)
|
await self.signals.pre_update.send(
|
||||||
|
sender=self.__class__, instance=self, passed_args=kwargs
|
||||||
|
)
|
||||||
self_fields = self._extract_model_db_fields()
|
self_fields = self._extract_model_db_fields()
|
||||||
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
|
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
|
||||||
self_fields = self.translate_columns_to_aliases(self_fields)
|
self_fields = self.translate_columns_to_aliases(self_fields)
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from typing import Callable, TYPE_CHECKING, Type
|
import decimal
|
||||||
|
from typing import Any, Callable, TYPE_CHECKING, Type
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
from ormar.queryset.actions.query_action import QueryAction
|
from ormar.queryset.actions.query_action import QueryAction # noqa: I202
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
|
|
||||||
|
|
||||||
@ -22,7 +23,7 @@ class SelectAction(QueryAction):
|
|||||||
self, select_str: str, model_cls: Type["Model"], alias: str = None
|
self, select_str: str, model_cls: Type["Model"], alias: str = None
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(query_str=select_str, model_cls=model_cls)
|
super().__init__(query_str=select_str, model_cls=model_cls)
|
||||||
if alias:
|
if alias: # pragma: no cover
|
||||||
self.table_prefix = alias
|
self.table_prefix = alias
|
||||||
|
|
||||||
def _split_value_into_parts(self, order_str: str) -> None:
|
def _split_value_into_parts(self, order_str: str) -> None:
|
||||||
@ -30,6 +31,13 @@ class SelectAction(QueryAction):
|
|||||||
self.field_name = parts[-1]
|
self.field_name = parts[-1]
|
||||||
self.related_parts = parts[:-1]
|
self.related_parts = parts[:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_numeric(self) -> bool:
|
||||||
|
return self.get_target_field_type() in [int, float, decimal.Decimal]
|
||||||
|
|
||||||
|
def get_target_field_type(self) -> Any:
|
||||||
|
return self.target_model.Meta.model_fields[self.field_name].__type__
|
||||||
|
|
||||||
def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause:
|
def get_text_clause(self) -> sqlalchemy.sql.expression.TextClause:
|
||||||
alias = f"{self.table_prefix}_" if self.table_prefix else ""
|
alias = f"{self.table_prefix}_" if self.table_prefix else ""
|
||||||
return sqlalchemy.text(f"{alias}{self.field_name}")
|
return sqlalchemy.text(f"{alias}{self.field_name}")
|
||||||
|
|||||||
@ -320,6 +320,48 @@ class SqlJoin:
|
|||||||
)
|
)
|
||||||
self.sorted_orders[clause] = clause.get_text_clause()
|
self.sorted_orders[clause] = clause.get_text_clause()
|
||||||
|
|
||||||
|
def _verify_allowed_order_field(self, order_by: str) -> None:
|
||||||
|
"""
|
||||||
|
Verifies if proper field string is used.
|
||||||
|
:param order_by: string with order by definition
|
||||||
|
:type order_by: str
|
||||||
|
"""
|
||||||
|
parts = order_by.split("__")
|
||||||
|
if len(parts) > 2 or parts[0] != self.target_field.through.get_name():
|
||||||
|
raise ModelDefinitionError(
|
||||||
|
"You can order the relation only " "by related or link table columns!"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_alias_and_model(self, order_by: str) -> Tuple[str, Type["Model"]]:
|
||||||
|
"""
|
||||||
|
Returns proper model and alias to be applied in the clause.
|
||||||
|
|
||||||
|
:param order_by: string with order by definition
|
||||||
|
:type order_by: str
|
||||||
|
:return: alias and model to be used in clause
|
||||||
|
:rtype: Tuple[str, Type["Model"]]
|
||||||
|
"""
|
||||||
|
if self.target_field.is_multi and "__" in order_by:
|
||||||
|
self._verify_allowed_order_field(order_by=order_by)
|
||||||
|
alias = self.next_alias
|
||||||
|
model = self.target_field.owner
|
||||||
|
elif self.target_field.is_multi:
|
||||||
|
alias = self.alias_manager.resolve_relation_alias(
|
||||||
|
from_model=self.target_field.through,
|
||||||
|
relation_name=cast(
|
||||||
|
"ManyToManyField", self.target_field
|
||||||
|
).default_target_field_name(),
|
||||||
|
)
|
||||||
|
model = self.target_field.to
|
||||||
|
else:
|
||||||
|
alias = self.alias_manager.resolve_relation_alias(
|
||||||
|
from_model=self.target_field.owner,
|
||||||
|
relation_name=self.target_field.name,
|
||||||
|
)
|
||||||
|
model = self.target_field.to
|
||||||
|
|
||||||
|
return alias, model
|
||||||
|
|
||||||
def _get_order_bys(self) -> None: # noqa: CCR001
|
def _get_order_bys(self) -> None: # noqa: CCR001
|
||||||
"""
|
"""
|
||||||
Triggers construction of order bys if they are given.
|
Triggers construction of order bys if they are given.
|
||||||
@ -339,41 +381,10 @@ class SqlJoin:
|
|||||||
self.already_sorted[
|
self.already_sorted[
|
||||||
f"{self.next_alias}_{self.next_model.get_name()}"
|
f"{self.next_alias}_{self.next_model.get_name()}"
|
||||||
] = condition
|
] = condition
|
||||||
# TODO: refactor into smaller helper functions
|
|
||||||
if self.target_field.orders_by and not current_table_sorted:
|
if self.target_field.orders_by and not current_table_sorted:
|
||||||
current_table_sorted = True
|
current_table_sorted = True
|
||||||
for order_by in self.target_field.orders_by:
|
for order_by in self.target_field.orders_by:
|
||||||
if self.target_field.is_multi and "__" in order_by:
|
alias, model = self._get_alias_and_model(order_by=order_by)
|
||||||
parts = order_by.split("__")
|
|
||||||
if (
|
|
||||||
len(parts) > 2
|
|
||||||
or parts[0] != self.target_field.through.get_name()
|
|
||||||
):
|
|
||||||
raise ModelDefinitionError(
|
|
||||||
"You can order the relation only"
|
|
||||||
"by related or link table columns!"
|
|
||||||
)
|
|
||||||
model = self.target_field.owner
|
|
||||||
clause = ormar.OrderAction(
|
|
||||||
order_str=order_by, model_cls=model, alias=alias,
|
|
||||||
)
|
|
||||||
elif self.target_field.is_multi:
|
|
||||||
alias = self.alias_manager.resolve_relation_alias(
|
|
||||||
from_model=self.target_field.through,
|
|
||||||
relation_name=cast(
|
|
||||||
"ManyToManyField", self.target_field
|
|
||||||
).default_target_field_name(),
|
|
||||||
)
|
|
||||||
model = self.target_field.to
|
|
||||||
clause = ormar.OrderAction(
|
|
||||||
order_str=order_by, model_cls=model, alias=alias
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
alias = self.alias_manager.resolve_relation_alias(
|
|
||||||
from_model=self.target_field.owner,
|
|
||||||
relation_name=self.target_field.name,
|
|
||||||
)
|
|
||||||
model = self.target_field.to
|
|
||||||
clause = ormar.OrderAction(
|
clause = ormar.OrderAction(
|
||||||
order_str=order_by, model_cls=model, alias=alias
|
order_str=order_by, model_cls=model, alias=alias
|
||||||
)
|
)
|
||||||
|
|||||||
@ -14,7 +14,6 @@ from typing import (
|
|||||||
import databases
|
import databases
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import bindparam
|
from sqlalchemy import bindparam
|
||||||
from sqlalchemy.engine import ResultProxy
|
|
||||||
|
|
||||||
import ormar # noqa I100
|
import ormar # noqa I100
|
||||||
from ormar import MultipleMatches, NoMatch
|
from ormar import MultipleMatches, NoMatch
|
||||||
@ -558,22 +557,24 @@ class QuerySet:
|
|||||||
expr = sqlalchemy.func.count().select().select_from(expr)
|
expr = sqlalchemy.func.count().select().select_from(expr)
|
||||||
return await self.database.fetch_val(expr)
|
return await self.database.fetch_val(expr)
|
||||||
|
|
||||||
async def _query_aggr_function(self, func_name: str, columns: List):
|
async def _query_aggr_function(self, func_name: str, columns: List) -> Any:
|
||||||
func = getattr(sqlalchemy.func, func_name)
|
func = getattr(sqlalchemy.func, func_name)
|
||||||
select_actions = [
|
select_actions = [
|
||||||
SelectAction(select_str=column, model_cls=self.model)
|
SelectAction(select_str=column, model_cls=self.model) for column in columns
|
||||||
for column in columns
|
|
||||||
]
|
]
|
||||||
|
if func_name in ["sum", "avg"]:
|
||||||
|
if any(not x.is_numeric for x in select_actions):
|
||||||
|
raise QueryDefinitionError(
|
||||||
|
"You can use sum and svg only with" "numeric types of columns"
|
||||||
|
)
|
||||||
select_columns = [x.apply_func(func, use_label=True) for x in select_actions]
|
select_columns = [x.apply_func(func, use_label=True) for x in select_actions]
|
||||||
expr = self.build_select_expression().alias(f"subquery_for_{func_name}")
|
expr = self.build_select_expression().alias(f"subquery_for_{func_name}")
|
||||||
expr = sqlalchemy.select(select_columns).select_from(expr)
|
expr = sqlalchemy.select(select_columns).select_from(expr)
|
||||||
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
|
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
|
||||||
result = await self.database.fetch_one(expr)
|
result = await self.database.fetch_one(expr)
|
||||||
return result if len(result) > 1 else result[0] # type: ignore
|
return dict(result) if len(result) > 1 else result[0] # type: ignore
|
||||||
|
|
||||||
async def max( # noqa: A003
|
async def max(self, columns: Union[str, List[str]]) -> Any: # noqa: A003
|
||||||
self, columns: Union[str, List[str]]
|
|
||||||
) -> Union[Any, ResultProxy]:
|
|
||||||
"""
|
"""
|
||||||
Returns max value of columns for rows matching the given criteria
|
Returns max value of columns for rows matching the given criteria
|
||||||
(applied with `filter` and `exclude` if set before).
|
(applied with `filter` and `exclude` if set before).
|
||||||
@ -585,9 +586,7 @@ class QuerySet:
|
|||||||
columns = [columns]
|
columns = [columns]
|
||||||
return await self._query_aggr_function(func_name="max", columns=columns)
|
return await self._query_aggr_function(func_name="max", columns=columns)
|
||||||
|
|
||||||
async def min( # noqa: A003
|
async def min(self, columns: Union[str, List[str]]) -> Any: # noqa: A003
|
||||||
self, columns: Union[str, List[str]]
|
|
||||||
) -> Union[Any, ResultProxy]:
|
|
||||||
"""
|
"""
|
||||||
Returns min value of columns for rows matching the given criteria
|
Returns min value of columns for rows matching the given criteria
|
||||||
(applied with `filter` and `exclude` if set before).
|
(applied with `filter` and `exclude` if set before).
|
||||||
@ -599,9 +598,7 @@ class QuerySet:
|
|||||||
columns = [columns]
|
columns = [columns]
|
||||||
return await self._query_aggr_function(func_name="min", columns=columns)
|
return await self._query_aggr_function(func_name="min", columns=columns)
|
||||||
|
|
||||||
async def sum( # noqa: A003
|
async def sum(self, columns: Union[str, List[str]]) -> Any: # noqa: A003
|
||||||
self, columns: Union[str, List[str]]
|
|
||||||
) -> Union[Any, ResultProxy]:
|
|
||||||
"""
|
"""
|
||||||
Returns sum value of columns for rows matching the given criteria
|
Returns sum value of columns for rows matching the given criteria
|
||||||
(applied with `filter` and `exclude` if set before).
|
(applied with `filter` and `exclude` if set before).
|
||||||
@ -613,7 +610,7 @@ class QuerySet:
|
|||||||
columns = [columns]
|
columns = [columns]
|
||||||
return await self._query_aggr_function(func_name="sum", columns=columns)
|
return await self._query_aggr_function(func_name="sum", columns=columns)
|
||||||
|
|
||||||
async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]:
|
async def avg(self, columns: Union[str, List[str]]) -> Any:
|
||||||
"""
|
"""
|
||||||
Returns avg value of columns for rows matching the given criteria
|
Returns avg value of columns for rows matching the given criteria
|
||||||
(applied with `filter` and `exclude` if set before).
|
(applied with `filter` and `exclude` if set before).
|
||||||
@ -974,6 +971,7 @@ class QuerySet:
|
|||||||
"You cannot update unsaved objects. "
|
"You cannot update unsaved objects. "
|
||||||
f"{self.model.__name__} has to have {pk_name} filled."
|
f"{self.model.__name__} has to have {pk_name} filled."
|
||||||
)
|
)
|
||||||
|
new_kwargs = self.model.parse_non_db_fields(new_kwargs)
|
||||||
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
||||||
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
|
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
|
||||||
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
|
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
|
||||||
|
|||||||
@ -12,9 +12,8 @@ from typing import ( # noqa: I100, I201
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sqlalchemy.engine import ResultProxy
|
|
||||||
|
|
||||||
import ormar
|
import ormar # noqa: I100, I202
|
||||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
@ -118,7 +117,6 @@ class QuerysetProxy:
|
|||||||
:type child: Model
|
:type child: Model
|
||||||
"""
|
"""
|
||||||
model_cls = self.relation.through
|
model_cls = self.relation.through
|
||||||
# TODO: Add support for pk with default not only autoincrement id
|
|
||||||
owner_column = self.related_field.default_target_field_name() # type: ignore
|
owner_column = self.related_field.default_target_field_name() # type: ignore
|
||||||
child_column = self.related_field.default_source_field_name() # type: ignore
|
child_column = self.related_field.default_source_field_name() # type: ignore
|
||||||
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
||||||
@ -129,10 +127,8 @@ class QuerysetProxy:
|
|||||||
f"model without primary key set! \n"
|
f"model without primary key set! \n"
|
||||||
f"Save the child model first."
|
f"Save the child model first."
|
||||||
)
|
)
|
||||||
expr = model_cls.Meta.table.insert()
|
print('final kwargs', final_kwargs)
|
||||||
expr = expr.values(**final_kwargs)
|
await model_cls(**final_kwargs).save()
|
||||||
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
|
|
||||||
await model_cls.Meta.database.execute(expr)
|
|
||||||
|
|
||||||
async def update_through_instance(self, child: "Model", **kwargs: Any) -> None:
|
async def update_through_instance(self, child: "Model", **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
@ -148,6 +144,7 @@ class QuerysetProxy:
|
|||||||
child_column = self.related_field.default_source_field_name() # type: ignore
|
child_column = self.related_field.default_source_field_name() # type: ignore
|
||||||
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
||||||
through_model = await model_cls.objects.get(**rel_kwargs)
|
through_model = await model_cls.objects.get(**rel_kwargs)
|
||||||
|
print('update kwargs', kwargs)
|
||||||
await through_model.update(**kwargs)
|
await through_model.update(**kwargs)
|
||||||
|
|
||||||
async def delete_through_instance(self, child: "Model") -> None:
|
async def delete_through_instance(self, child: "Model") -> None:
|
||||||
@ -188,9 +185,7 @@ class QuerysetProxy:
|
|||||||
"""
|
"""
|
||||||
return await self.queryset.count()
|
return await self.queryset.count()
|
||||||
|
|
||||||
async def max( # noqa: A003
|
async def max(self, columns: Union[str, List[str]]) -> Any: # noqa: A003
|
||||||
self, columns: Union[str, List[str]]
|
|
||||||
) -> Union[Any, ResultProxy]:
|
|
||||||
"""
|
"""
|
||||||
Returns max value of columns for rows matching the given criteria
|
Returns max value of columns for rows matching the given criteria
|
||||||
(applied with `filter` and `exclude` if set before).
|
(applied with `filter` and `exclude` if set before).
|
||||||
@ -200,9 +195,7 @@ class QuerysetProxy:
|
|||||||
"""
|
"""
|
||||||
return await self.queryset.max(columns=columns)
|
return await self.queryset.max(columns=columns)
|
||||||
|
|
||||||
async def min( # noqa: A003
|
async def min(self, columns: Union[str, List[str]]) -> Any: # noqa: A003
|
||||||
self, columns: Union[str, List[str]]
|
|
||||||
) -> Union[Any, ResultProxy]:
|
|
||||||
"""
|
"""
|
||||||
Returns min value of columns for rows matching the given criteria
|
Returns min value of columns for rows matching the given criteria
|
||||||
(applied with `filter` and `exclude` if set before).
|
(applied with `filter` and `exclude` if set before).
|
||||||
@ -212,9 +205,7 @@ class QuerysetProxy:
|
|||||||
"""
|
"""
|
||||||
return await self.queryset.min(columns=columns)
|
return await self.queryset.min(columns=columns)
|
||||||
|
|
||||||
async def sum( # noqa: A003
|
async def sum(self, columns: Union[str, List[str]]) -> Any: # noqa: A003
|
||||||
self, columns: Union[str, List[str]]
|
|
||||||
) -> Union[Any, ResultProxy]:
|
|
||||||
"""
|
"""
|
||||||
Returns sum value of columns for rows matching the given criteria
|
Returns sum value of columns for rows matching the given criteria
|
||||||
(applied with `filter` and `exclude` if set before).
|
(applied with `filter` and `exclude` if set before).
|
||||||
@ -224,7 +215,7 @@ class QuerysetProxy:
|
|||||||
"""
|
"""
|
||||||
return await self.queryset.sum(columns=columns)
|
return await self.queryset.sum(columns=columns)
|
||||||
|
|
||||||
async def avg(self, columns: Union[str, List[str]]) -> Union[Any, ResultProxy]:
|
async def avg(self, columns: Union[str, List[str]]) -> Any:
|
||||||
"""
|
"""
|
||||||
Returns avg value of columns for rows matching the given criteria
|
Returns avg value of columns for rows matching the given criteria
|
||||||
(applied with `filter` and `exclude` if set before).
|
(applied with `filter` and `exclude` if set before).
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import pytest
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
|
from ormar.exceptions import QueryDefinitionError
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
database = databases.Database(DATABASE_URL)
|
database = databases.Database(DATABASE_URL)
|
||||||
@ -67,8 +68,7 @@ async def test_min_method():
|
|||||||
await sample_data()
|
await sample_data()
|
||||||
assert await Book.objects.min("year") == 1920
|
assert await Book.objects.min("year") == 1920
|
||||||
result = await Book.objects.min(["year", "ranking"])
|
result = await Book.objects.min(["year", "ranking"])
|
||||||
assert result == (1920, 1)
|
assert result == dict(year=1920, ranking=1)
|
||||||
assert dict(result) == dict(year=1920, ranking=1)
|
|
||||||
|
|
||||||
assert await Book.objects.min("title") == "Book 1"
|
assert await Book.objects.min("title") == "Book 1"
|
||||||
|
|
||||||
@ -76,8 +76,7 @@ async def test_min_method():
|
|||||||
result = await Author.objects.select_related("books").min(
|
result = await Author.objects.select_related("books").min(
|
||||||
["books__year", "books__ranking"]
|
["books__year", "books__ranking"]
|
||||||
)
|
)
|
||||||
assert result == (1920, 1)
|
assert result == dict(books__year=1920, books__ranking=1)
|
||||||
assert dict(result) == dict(books__year=1920, books__ranking=1)
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
await Author.objects.select_related("books")
|
await Author.objects.select_related("books")
|
||||||
@ -93,8 +92,7 @@ async def test_max_method():
|
|||||||
await sample_data()
|
await sample_data()
|
||||||
assert await Book.objects.max("year") == 1930
|
assert await Book.objects.max("year") == 1930
|
||||||
result = await Book.objects.max(["year", "ranking"])
|
result = await Book.objects.max(["year", "ranking"])
|
||||||
assert result == (1930, 5)
|
assert result == dict(year=1930, ranking=5)
|
||||||
assert dict(result) == dict(year=1930, ranking=5)
|
|
||||||
|
|
||||||
assert await Book.objects.max("title") == "Book 3"
|
assert await Book.objects.max("title") == "Book 3"
|
||||||
|
|
||||||
@ -102,8 +100,7 @@ async def test_max_method():
|
|||||||
result = await Author.objects.select_related("books").max(
|
result = await Author.objects.select_related("books").max(
|
||||||
["books__year", "books__ranking"]
|
["books__year", "books__ranking"]
|
||||||
)
|
)
|
||||||
assert result == (1930, 5)
|
assert result == dict(books__year=1930, books__ranking=5)
|
||||||
assert dict(result) == dict(books__year=1930, books__ranking=5)
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
await Author.objects.select_related("books")
|
await Author.objects.select_related("books")
|
||||||
@ -119,17 +116,16 @@ async def test_sum_method():
|
|||||||
await sample_data()
|
await sample_data()
|
||||||
assert await Book.objects.sum("year") == 5773
|
assert await Book.objects.sum("year") == 5773
|
||||||
result = await Book.objects.sum(["year", "ranking"])
|
result = await Book.objects.sum(["year", "ranking"])
|
||||||
assert result == (5773, 9)
|
assert result == dict(year=5773, ranking=9)
|
||||||
assert dict(result) == dict(year=5773, ranking=9)
|
|
||||||
|
|
||||||
assert await Book.objects.sum("title") == 0.0
|
with pytest.raises(QueryDefinitionError):
|
||||||
|
await Book.objects.sum("title")
|
||||||
|
|
||||||
assert await Author.objects.select_related("books").sum("books__year") == 5773
|
assert await Author.objects.select_related("books").sum("books__year") == 5773
|
||||||
result = await Author.objects.select_related("books").sum(
|
result = await Author.objects.select_related("books").sum(
|
||||||
["books__year", "books__ranking"]
|
["books__year", "books__ranking"]
|
||||||
)
|
)
|
||||||
assert result == (5773, 9)
|
assert result == dict(books__year=5773, books__ranking=9)
|
||||||
assert dict(result) == dict(books__year=5773, books__ranking=9)
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
await Author.objects.select_related("books")
|
await Author.objects.select_related("books")
|
||||||
@ -143,24 +139,21 @@ async def test_sum_method():
|
|||||||
async def test_avg_method():
|
async def test_avg_method():
|
||||||
async with database:
|
async with database:
|
||||||
await sample_data()
|
await sample_data()
|
||||||
assert round(await Book.objects.avg("year"), 2) == 1924.33
|
assert round(float(await Book.objects.avg("year")), 2) == 1924.33
|
||||||
result = await Book.objects.avg(["year", "ranking"])
|
result = await Book.objects.avg(["year", "ranking"])
|
||||||
assert (round(result[0], 2), result[1]) == (1924.33, 3.0)
|
assert round(float(result.get("year")), 2) == 1924.33
|
||||||
result_dict = dict(result)
|
assert result.get("ranking") == 3.0
|
||||||
assert round(result_dict.get("year"), 2) == 1924.33
|
|
||||||
assert result_dict.get("ranking") == 3.0
|
|
||||||
|
|
||||||
assert await Book.objects.avg("title") == 0.0
|
with pytest.raises(QueryDefinitionError):
|
||||||
|
await Book.objects.avg("title")
|
||||||
|
|
||||||
result = await Author.objects.select_related("books").avg("books__year")
|
result = await Author.objects.select_related("books").avg("books__year")
|
||||||
assert round(result, 2) == 1924.33
|
assert round(float(result), 2) == 1924.33
|
||||||
result = await Author.objects.select_related("books").avg(
|
result = await Author.objects.select_related("books").avg(
|
||||||
["books__year", "books__ranking"]
|
["books__year", "books__ranking"]
|
||||||
)
|
)
|
||||||
assert (round(result[0], 2), result[1]) == (1924.33, 3.0)
|
assert round(float(result.get("books__year")), 2) == 1924.33
|
||||||
result_dict = dict(result)
|
assert result.get("books__ranking") == 3.0
|
||||||
assert round(result_dict.get("books__year"), 2) == 1924.33
|
|
||||||
assert result_dict.get("books__ranking") == 3.0
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
await Author.objects.select_related("books")
|
await Author.objects.select_related("books")
|
||||||
@ -179,4 +172,6 @@ async def test_queryset_method():
|
|||||||
assert await author.books.max("year") == 1930
|
assert await author.books.max("year") == 1930
|
||||||
assert await author.books.sum("ranking") == 9
|
assert await author.books.sum("ranking") == 9
|
||||||
assert await author.books.avg("ranking") == 3.0
|
assert await author.books.avg("ranking") == 3.0
|
||||||
assert await author.books.max(["year", "title"]) == (1930, "Book 3")
|
assert await author.books.max(["year", "title"]) == dict(
|
||||||
|
year=1930, title="Book 3"
|
||||||
|
)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import Any, Dict, List, Type
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
@ -6,6 +6,8 @@ import pytest
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
|
from ormar import ModelDefinitionError, Model, QuerySet, pre_update
|
||||||
|
from ormar import pre_save, pre_relation_add
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
database = databases.Database(DATABASE_URL)
|
database = databases.Database(DATABASE_URL)
|
||||||
@ -30,7 +32,7 @@ class Link(ormar.Model):
|
|||||||
class Meta(BaseMeta):
|
class Meta(BaseMeta):
|
||||||
tablename = "link_table"
|
tablename = "link_table"
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
id: UUID = ormar.UUID(primary_key=True, default=uuid4)
|
||||||
animal_order: int = ormar.Integer(nullable=True)
|
animal_order: int = ormar.Integer(nullable=True)
|
||||||
human_order: int = ormar.Integer(nullable=True)
|
human_order: int = ormar.Integer(nullable=True)
|
||||||
|
|
||||||
@ -50,6 +52,17 @@ class Human(ormar.Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Human2(ormar.Model):
|
||||||
|
class Meta(BaseMeta):
|
||||||
|
tablename = "humans2"
|
||||||
|
|
||||||
|
id: UUID = ormar.UUID(primary_key=True, default=uuid4)
|
||||||
|
name: str = ormar.Text(default="")
|
||||||
|
favoriteAnimals: List[Animal] = ormar.ManyToMany(
|
||||||
|
Animal, related_name="favoriteHumans2", orders_by=["link__animal_order__fail"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True, scope="module")
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
def create_test_database():
|
def create_test_database():
|
||||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||||
@ -59,9 +72,94 @@ def create_test_database():
|
|||||||
metadata.drop_all(engine)
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ordering_by_through_fail():
|
||||||
|
async with database:
|
||||||
|
alice = await Human2(name="Alice").save()
|
||||||
|
spot = await Animal(name="Spot").save()
|
||||||
|
await alice.favoriteAnimals.add(spot)
|
||||||
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
await alice.load_all()
|
||||||
|
|
||||||
|
|
||||||
|
def get_filtered_query(
|
||||||
|
sender: Type[Model], instance: Model, to_class: Type[Model]
|
||||||
|
) -> QuerySet:
|
||||||
|
pk = getattr(instance, f"{to_class.get_name()}").pk
|
||||||
|
filter_kwargs = {f"{to_class.get_name()}": pk}
|
||||||
|
query = sender.objects.filter(**filter_kwargs)
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
async def populate_order_on_insert(
|
||||||
|
sender: Type[Model], instance: Model, from_class: Type[Model],
|
||||||
|
to_class: Type[Model]
|
||||||
|
):
|
||||||
|
order_column = f"{from_class.get_name()}_order"
|
||||||
|
if getattr(instance, order_column) is None:
|
||||||
|
query = get_filtered_query(sender, instance, to_class)
|
||||||
|
max_order = await query.max(order_column)
|
||||||
|
max_order = max_order + 1 if max_order is not None else 0
|
||||||
|
setattr(instance, order_column, max_order)
|
||||||
|
else:
|
||||||
|
await reorder_on_update(sender, instance, from_class, to_class,
|
||||||
|
passed_args={
|
||||||
|
order_column: getattr(instance, order_column)})
|
||||||
|
|
||||||
|
|
||||||
|
async def reorder_on_update(
|
||||||
|
sender: Type[Model], instance: Model, from_class: Type[Model],
|
||||||
|
to_class: Type[Model], passed_args: Dict
|
||||||
|
):
|
||||||
|
order = f"{from_class.get_name()}_order"
|
||||||
|
if order in passed_args:
|
||||||
|
query = get_filtered_query(sender, instance, to_class)
|
||||||
|
to_reorder = await query.exclude(pk=instance.pk).order_by(order).all()
|
||||||
|
old_order = getattr(instance, order)
|
||||||
|
new_order = passed_args.get(order)
|
||||||
|
if to_reorder:
|
||||||
|
for link in to_reorder:
|
||||||
|
setattr(link, order, getattr(link, order) + 1)
|
||||||
|
await sender.objects.bulk_update(to_reorder, columns=[order])
|
||||||
|
check = await get_filtered_query(sender, instance, to_class).all()
|
||||||
|
print('reordered', check)
|
||||||
|
|
||||||
|
|
||||||
|
@pre_save(Link)
|
||||||
|
async def order_link_on_insert(sender: Type[Model], instance: Model, **kwargs: Any):
|
||||||
|
relations = list(instance.extract_related_names())
|
||||||
|
rel_one = sender.Meta.model_fields[relations[0]].to
|
||||||
|
rel_two = sender.Meta.model_fields[relations[1]].to
|
||||||
|
await populate_order_on_insert(sender, instance, from_class=rel_one,
|
||||||
|
to_class=rel_two)
|
||||||
|
await populate_order_on_insert(sender, instance, from_class=rel_two,
|
||||||
|
to_class=rel_one)
|
||||||
|
|
||||||
|
|
||||||
|
@pre_update(Link)
|
||||||
|
async def reorder_links_on_update(
|
||||||
|
sender: Type[ormar.Model], instance: ormar.Model, passed_args: Dict,
|
||||||
|
**kwargs: Any
|
||||||
|
):
|
||||||
|
relations = list(instance.extract_related_names())
|
||||||
|
rel_one = sender.Meta.model_fields[relations[0]].to
|
||||||
|
rel_two = sender.Meta.model_fields[relations[1]].to
|
||||||
|
await reorder_on_update(sender, instance, from_class=rel_one, to_class=rel_two,
|
||||||
|
passed_args=passed_args)
|
||||||
|
await reorder_on_update(sender, instance, from_class=rel_two, to_class=rel_one,
|
||||||
|
passed_args=passed_args)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ordering_by_through_on_m2m_field():
|
async def test_ordering_by_through_on_m2m_field():
|
||||||
async with database:
|
async with database:
|
||||||
|
def verify_order(instance, expected):
|
||||||
|
field_name = (
|
||||||
|
"favoriteAnimals" if isinstance(instance,
|
||||||
|
Human) else "favoriteHumans"
|
||||||
|
)
|
||||||
|
assert [x.name for x in getattr(instance, field_name)] == expected
|
||||||
|
|
||||||
alice = await Human(name="Alice").save()
|
alice = await Human(name="Alice").save()
|
||||||
bob = await Human(name="Bob").save()
|
bob = await Human(name="Bob").save()
|
||||||
charlie = await Human(name="Charlie").save()
|
charlie = await Human(name="Charlie").save()
|
||||||
@ -70,98 +168,55 @@ async def test_ordering_by_through_on_m2m_field():
|
|||||||
kitty = await Animal(name="Kitty").save()
|
kitty = await Animal(name="Kitty").save()
|
||||||
noodle = await Animal(name="Noodle").save()
|
noodle = await Animal(name="Noodle").save()
|
||||||
|
|
||||||
# you need to add them in order anyway so can provide order explicitly
|
await alice.favoriteAnimals.add(noodle)
|
||||||
# if you have a lot of them a list with enumerate might be an option
|
await alice.favoriteAnimals.add(spot)
|
||||||
await alice.favoriteAnimals.add(noodle, animal_order=0, human_order=0)
|
await alice.favoriteAnimals.add(kitty)
|
||||||
await alice.favoriteAnimals.add(spot, animal_order=1, human_order=0)
|
|
||||||
await alice.favoriteAnimals.add(kitty, animal_order=2, human_order=0)
|
|
||||||
|
|
||||||
# you dont have to reload queries on queryset clears the existing related
|
|
||||||
# alice = await alice.reload()
|
|
||||||
await alice.load_all()
|
await alice.load_all()
|
||||||
assert [x.name for x in alice.favoriteAnimals] == ["Noodle", "Spot", "Kitty"]
|
verify_order(alice, ["Noodle", "Spot", "Kitty"])
|
||||||
|
|
||||||
await bob.favoriteAnimals.add(noodle, animal_order=0, human_order=1)
|
await bob.favoriteAnimals.add(noodle)
|
||||||
await bob.favoriteAnimals.add(kitty, animal_order=1, human_order=1)
|
await bob.favoriteAnimals.add(kitty)
|
||||||
await bob.favoriteAnimals.add(spot, animal_order=2, human_order=1)
|
await bob.favoriteAnimals.add(spot)
|
||||||
|
|
||||||
await bob.load_all()
|
await bob.load_all()
|
||||||
assert [x.name for x in bob.favoriteAnimals] == ["Noodle", "Kitty", "Spot"]
|
verify_order(bob, ["Noodle", "Kitty", "Spot"])
|
||||||
|
|
||||||
await charlie.favoriteAnimals.add(kitty, animal_order=0, human_order=2)
|
await charlie.favoriteAnimals.add(kitty)
|
||||||
await charlie.favoriteAnimals.add(noodle, animal_order=1, human_order=2)
|
await charlie.favoriteAnimals.add(noodle)
|
||||||
await charlie.favoriteAnimals.add(spot, animal_order=2, human_order=2)
|
await charlie.favoriteAnimals.add(spot)
|
||||||
|
|
||||||
await charlie.load_all()
|
await charlie.load_all()
|
||||||
assert [x.name for x in charlie.favoriteAnimals] == ["Kitty", "Noodle", "Spot"]
|
verify_order(charlie, ["Kitty", "Noodle", "Spot"])
|
||||||
|
|
||||||
animals = [noodle, kitty, spot]
|
animals = [noodle, kitty, spot]
|
||||||
for animal in animals:
|
for animal in animals:
|
||||||
await animal.load_all()
|
await animal.load_all()
|
||||||
assert [x.name for x in animal.favoriteHumans] == [
|
verify_order(animal, ["Alice", "Bob", "Charlie"])
|
||||||
"Alice",
|
|
||||||
"Bob",
|
|
||||||
"Charlie",
|
|
||||||
]
|
|
||||||
|
|
||||||
zack = await Human(name="Zack").save()
|
zack = await Human(name="Zack").save()
|
||||||
|
|
||||||
async def reorder_humans(animal, new_ordered_humans):
|
|
||||||
noodle_links = await Link.objects.filter(animal=animal).all()
|
|
||||||
for link in noodle_links:
|
|
||||||
link.human_order = next(
|
|
||||||
(
|
|
||||||
i
|
|
||||||
for i, x in enumerate(new_ordered_humans)
|
|
||||||
if x.pk == link.human.pk
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
await Link.objects.bulk_update(noodle_links, columns=["human_order"])
|
|
||||||
|
|
||||||
await noodle.favoriteHumans.add(zack, animal_order=0, human_order=0)
|
await noodle.favoriteHumans.add(zack, animal_order=0, human_order=0)
|
||||||
await reorder_humans(noodle, [zack, alice, bob, charlie])
|
|
||||||
await noodle.load_all()
|
await noodle.load_all()
|
||||||
assert [x.name for x in noodle.favoriteHumans] == [
|
verify_order(noodle, ["Zack", "Alice", "Bob", "Charlie"])
|
||||||
"Zack",
|
|
||||||
"Alice",
|
|
||||||
"Bob",
|
|
||||||
"Charlie",
|
|
||||||
]
|
|
||||||
|
|
||||||
await zack.load_all()
|
await zack.load_all()
|
||||||
assert [x.name for x in zack.favoriteAnimals] == ["Noodle"]
|
verify_order(zack, ["Noodle"])
|
||||||
|
|
||||||
humans = noodle.favoriteHumans
|
await noodle.favoriteHumans.filter(name='Zack').update(
|
||||||
humans.insert(1, humans.pop(0))
|
link=dict(human_order=1))
|
||||||
await reorder_humans(noodle, humans)
|
|
||||||
await noodle.load_all()
|
await noodle.load_all()
|
||||||
assert [x.name for x in noodle.favoriteHumans] == [
|
verify_order(noodle, ["Alice", "Zack", "Bob", "Charlie"])
|
||||||
"Alice",
|
|
||||||
"Zack",
|
|
||||||
"Bob",
|
|
||||||
"Charlie",
|
|
||||||
]
|
|
||||||
|
|
||||||
humans.insert(2, humans.pop(1))
|
await noodle.favoriteHumans.filter(name='Zack').update(
|
||||||
await reorder_humans(noodle, humans)
|
link=dict(human_order=2))
|
||||||
await noodle.load_all()
|
await noodle.load_all()
|
||||||
assert [x.name for x in noodle.favoriteHumans] == [
|
verify_order(noodle, ["Alice", "Bob", "Zack", "Charlie"])
|
||||||
"Alice",
|
|
||||||
"Bob",
|
|
||||||
"Zack",
|
|
||||||
"Charlie",
|
|
||||||
]
|
|
||||||
|
|
||||||
humans.insert(3, humans.pop(2))
|
await noodle.favoriteHumans.filter(name='Zack').update(
|
||||||
await reorder_humans(noodle, humans)
|
link=dict(human_order=3))
|
||||||
await noodle.load_all()
|
await noodle.load_all()
|
||||||
assert [x.name for x in noodle.favoriteHumans] == [
|
verify_order(noodle, ["Alice", "Bob", "Charlie", "Zack"])
|
||||||
"Alice",
|
|
||||||
"Bob",
|
|
||||||
"Charlie",
|
|
||||||
"Zack",
|
|
||||||
]
|
|
||||||
|
|
||||||
await kitty.favoriteHumans.remove(bob)
|
await kitty.favoriteHumans.remove(bob)
|
||||||
await kitty.load_all()
|
await kitty.load_all()
|
||||||
@ -169,8 +224,9 @@ async def test_ordering_by_through_on_m2m_field():
|
|||||||
|
|
||||||
bob = await noodle.favoriteHumans.get(pk=bob.pk)
|
bob = await noodle.favoriteHumans.get(pk=bob.pk)
|
||||||
assert bob.link.human_order == 1
|
assert bob.link.human_order == 1
|
||||||
|
|
||||||
await noodle.favoriteHumans.remove(
|
await noodle.favoriteHumans.remove(
|
||||||
await noodle.favoriteHumans.filter(link__human_order=2).get()
|
await noodle.favoriteHumans.filter(link__human_order=2).get()
|
||||||
)
|
)
|
||||||
await noodle.load_all()
|
await noodle.load_all()
|
||||||
assert [x.name for x in noodle.favoriteHumans] == ["Alice", "Bob", "Zack"]
|
verify_order(noodle, ["Alice", "Bob", "Zack"])
|
||||||
|
|||||||
@ -49,6 +49,8 @@ def create_test_database():
|
|||||||
|
|
||||||
@pytest.fixture(autouse=True, scope="function")
|
@pytest.fixture(autouse=True, scope="function")
|
||||||
async def cleanup():
|
async def cleanup():
|
||||||
|
yield
|
||||||
|
async with database:
|
||||||
await Book.objects.delete(each=True)
|
await Book.objects.delete(each=True)
|
||||||
await Author.objects.delete(each=True)
|
await Author.objects.delete(each=True)
|
||||||
|
|
||||||
|
|||||||
@ -70,7 +70,7 @@ def create_test_database():
|
|||||||
metadata.drop_all(engine)
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(autouse=True, scope="function")
|
||||||
async def cleanup():
|
async def cleanup():
|
||||||
yield
|
yield
|
||||||
async with database:
|
async with database:
|
||||||
|
|||||||
Reference in New Issue
Block a user