fix everything?

This commit is contained in:
collerek
2021-03-07 17:48:26 +01:00
parent f85fa7b8a7
commit d388b9f745
5 changed files with 61 additions and 51 deletions

View File

@ -20,7 +20,7 @@ class OrderAction(QueryAction):
""" """
def __init__( def __init__(
self, order_str: str, model_cls: Type["Model"], alias: str = None self, order_str: str, model_cls: Type["Model"], alias: str = None
) -> None: ) -> None:
self.direction: str = "" self.direction: str = ""
super().__init__(query_str=order_str, model_cls=model_cls) super().__init__(query_str=order_str, model_cls=model_cls)
@ -46,9 +46,16 @@ class OrderAction(QueryAction):
prefix = f"{self.table_prefix}_" if self.table_prefix else "" prefix = f"{self.table_prefix}_" if self.table_prefix else ""
return f"{prefix}{self.table}" f".{self.field_alias}" return f"{prefix}{self.table}" f".{self.field_alias}"
def get_min_or_max(self): def get_min_or_max(self) -> sqlalchemy.sql.expression.TextClause:
"""
Used in limit sub queries where you need to use aggregated functions
in order to order by columns not included in group by.
:return: min or max function to order
:rtype: sqlalchemy.sql.elements.TextClause
"""
prefix = f"{self.table_prefix}_" if self.table_prefix else "" prefix = f"{self.table_prefix}_" if self.table_prefix else ""
if self.direction == '': if self.direction == "":
return text(f"min({prefix}{self.table}" f".{self.field_alias})") return text(f"min({prefix}{self.table}" f".{self.field_alias})")
else: else:
return text(f"max({prefix}{self.table}" f".{self.field_alias}) desc") return text(f"max({prefix}{self.table}" f".{self.field_alias}) desc")

View File

@ -18,16 +18,16 @@ if TYPE_CHECKING: # pragma no cover
class Query: class Query:
def __init__( # noqa CFQ002 def __init__( # noqa CFQ002
self, self,
model_cls: Type["Model"], model_cls: Type["Model"],
filter_clauses: List[FilterAction], filter_clauses: List[FilterAction],
exclude_clauses: List[FilterAction], exclude_clauses: List[FilterAction],
select_related: List, select_related: List,
limit_count: Optional[int], limit_count: Optional[int],
offset: Optional[int], offset: Optional[int],
excludable: "ExcludableItems", excludable: "ExcludableItems",
order_bys: Optional[List["OrderAction"]], order_bys: Optional[List["OrderAction"]],
limit_raw_sql: bool, limit_raw_sql: bool,
) -> None: ) -> None:
self.query_offset = offset self.query_offset = offset
self.limit_count = limit_count self.limit_count = limit_count
@ -153,9 +153,10 @@ class Query:
return expr return expr
def _build_pagination_condition( def _build_pagination_condition(
self self,
) -> Tuple[ ) -> Tuple[
sqlalchemy.sql.expression.TextClause, sqlalchemy.sql.expression.TextClause]: sqlalchemy.sql.expression.TextClause, sqlalchemy.sql.expression.TextClause
]:
""" """
In order to apply limit and offset on main table in join only In order to apply limit and offset on main table in join only
(otherwise you can get only partially constructed main model (otherwise you can get only partially constructed main model
@ -183,10 +184,9 @@ class Query:
limit_qry = sqlalchemy.sql.select([qry_text]) limit_qry = sqlalchemy.sql.select([qry_text])
limit_qry = limit_qry.select_from(self.select_from) limit_qry = limit_qry.select_from(self.select_from)
limit_qry = FilterQuery(filter_clauses=self.filter_clauses).apply(limit_qry) limit_qry = FilterQuery(filter_clauses=self.filter_clauses).apply(limit_qry)
limit_qry = FilterQuery(filter_clauses=self.exclude_clauses, limit_qry = FilterQuery(
exclude=True).apply( filter_clauses=self.exclude_clauses, exclude=True
limit_qry ).apply(limit_qry)
)
limit_qry = limit_qry.group_by(qry_text) limit_qry = limit_qry.group_by(qry_text)
for order_by in maxes.values(): for order_by in maxes.values():
limit_qry = limit_qry.order_by(order_by) limit_qry = limit_qry.order_by(order_by)
@ -194,11 +194,12 @@ class Query:
limit_qry = OffsetQuery(query_offset=self.query_offset).apply(limit_qry) limit_qry = OffsetQuery(query_offset=self.query_offset).apply(limit_qry)
limit_qry = limit_qry.alias("limit_query") limit_qry = limit_qry.alias("limit_query")
on_clause = sqlalchemy.text( on_clause = sqlalchemy.text(
f"limit_query.{pk_alias}={self.table.name}.{pk_alias}") f"limit_query.{pk_alias}={self.table.name}.{pk_alias}"
)
return limit_qry, on_clause return limit_qry, on_clause
def _apply_expression_modifiers( def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select: ) -> sqlalchemy.sql.select:
""" """
Receives the select query (might be join) and applies: Receives the select query (might be join) and applies:

View File

@ -281,7 +281,7 @@ class QuerySet:
limit_raw_sql=self.limit_sql_raw, limit_raw_sql=self.limit_sql_raw,
) )
exp = qry.build_select_expression() exp = qry.build_select_expression()
print("\n", exp.compile(compile_kwargs={"literal_binds": True})) # print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
return exp return exp
def filter( # noqa: A003 def filter( # noqa: A003

View File

@ -153,8 +153,7 @@ async def test_or_filters():
assert books[0].title == "The Witcher" assert books[0].title == "The Witcher"
with pytest.raises(QueryDefinitionError): with pytest.raises(QueryDefinitionError):
await Book.objects.select_related("author").filter('wrong').all() await Book.objects.select_related("author").filter("wrong").all()
# TODO: Check / modify # TODO: Check / modify

View File

@ -174,27 +174,27 @@ async def test_sort_order_on_related_model():
owner = ( owner = (
await Owner.objects.select_related("toys") await Owner.objects.select_related("toys")
.order_by("toys__name") .order_by("toys__name")
.filter(name="Zeus") .filter(name="Zeus")
.get() .get()
) )
assert owner.toys[0].name == "Toy 1" assert owner.toys[0].name == "Toy 1"
assert owner.toys[1].name == "Toy 4" assert owner.toys[1].name == "Toy 4"
owner = ( owner = (
await Owner.objects.select_related("toys") await Owner.objects.select_related("toys")
.order_by("-toys__name") .order_by("-toys__name")
.filter(name="Zeus") .filter(name="Zeus")
.get() .get()
) )
assert owner.toys[0].name == "Toy 4" assert owner.toys[0].name == "Toy 4"
assert owner.toys[1].name == "Toy 1" assert owner.toys[1].name == "Toy 1"
owners = ( owners = (
await Owner.objects.select_related("toys") await Owner.objects.select_related("toys")
.order_by("-toys__name") .order_by("-toys__name")
.filter(name__in=["Zeus", "Hermes"]) .filter(name__in=["Zeus", "Hermes"])
.all() .all()
) )
assert owners[0].toys[0].name == "Toy 6" assert owners[0].toys[0].name == "Toy 6"
assert owners[0].toys[1].name == "Toy 5" assert owners[0].toys[1].name == "Toy 5"
@ -208,9 +208,9 @@ async def test_sort_order_on_related_model():
owners = ( owners = (
await Owner.objects.select_related("toys") await Owner.objects.select_related("toys")
.order_by("-toys__name") .order_by("-toys__name")
.filter(name__in=["Zeus", "Hermes"]) .filter(name__in=["Zeus", "Hermes"])
.all() .all()
) )
assert owners[0].toys[0].name == "Toy 7" assert owners[0].toys[0].name == "Toy 7"
assert owners[0].toys[1].name == "Toy 4" assert owners[0].toys[1].name == "Toy 4"
@ -221,12 +221,15 @@ async def test_sort_order_on_related_model():
assert owners[1].toys[1].name == "Toy 5" assert owners[1].toys[1].name == "Toy 5"
assert owners[1].name == "Hermes" assert owners[1].name == "Hermes"
toys = await Toy.objects.select_related('owner').order_by( toys = (
['owner__name', 'name']).limit( await Toy.objects.select_related("owner")
2).all() .order_by(["owner__name", "name"])
.limit(2)
.all()
)
assert len(toys) == 2 assert len(toys) == 2
assert toys[0].name == 'Toy 2' assert toys[0].name == "Toy 2"
assert toys[1].name == 'Toy 3' assert toys[1].name == "Toy 3"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -257,9 +260,9 @@ async def test_sort_order_on_many_to_many():
user = ( user = (
await User.objects.select_related("cars") await User.objects.select_related("cars")
.filter(name="Mark") .filter(name="Mark")
.order_by("cars__name") .order_by("cars__name")
.get() .get()
) )
assert user.cars[0].name == "Buggy" assert user.cars[0].name == "Buggy"
assert user.cars[1].name == "Ferrari" assert user.cars[1].name == "Ferrari"
@ -268,9 +271,9 @@ async def test_sort_order_on_many_to_many():
user = ( user = (
await User.objects.select_related("cars") await User.objects.select_related("cars")
.filter(name="Mark") .filter(name="Mark")
.order_by("-cars__name") .order_by("-cars__name")
.get() .get()
) )
assert user.cars[3].name == "Buggy" assert user.cars[3].name == "Buggy"
assert user.cars[2].name == "Ferrari" assert user.cars[2].name == "Ferrari"
@ -286,8 +289,8 @@ async def test_sort_order_on_many_to_many():
users = ( users = (
await User.objects.select_related(["cars__factory"]) await User.objects.select_related(["cars__factory"])
.order_by(["-cars__factory__name", "cars__name"]) .order_by(["-cars__factory__name", "cars__name"])
.all() .all()
) )
assert users[0].name == "Julie" assert users[0].name == "Julie"
@ -333,8 +336,8 @@ async def test_sort_order_with_aliases():
aliases = ( aliases = (
await AliasTest.objects.select_related("nested") await AliasTest.objects.select_related("nested")
.order_by("-nested__name") .order_by("-nested__name")
.all() .all()
) )
assert aliases[0].nested.name == "Try4" assert aliases[0].nested.name == "Try4"
assert aliases[1].nested.name == "Try3" assert aliases[1].nested.name == "Try3"