diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index beae8d7..16f9bff 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -678,7 +678,7 @@ class QuerySet(Generic[T]): expr = sqlalchemy.exists(expr).select() return await self.database.fetch_val(expr) - async def count(self) -> int: + async def count(self, distinct: bool = True) -> int: """ Returns number of rows matching the given criteria (applied with `filter` and `exclude` if set before). @@ -688,6 +688,9 @@ class QuerySet(Generic[T]): """ expr = self.build_select_expression().alias("subquery_for_count") expr = sqlalchemy.func.count().select().select_from(expr) + if distinct: + expr_distinct = expr.group_by(self.model_meta.pkname).alias("subquery_for_group") + expr = sqlalchemy.func.count().select().select_from(expr_distinct) return await self.database.fetch_val(expr) async def _query_aggr_function(self, func_name: str, columns: List) -> Any: diff --git a/tests/test_queries/test_aggr_functions.py b/tests/test_queries/test_aggr_functions.py index 92c10a1..8cc4ae0 100644 --- a/tests/test_queries/test_aggr_functions.py +++ b/tests/test_queries/test_aggr_functions.py @@ -175,3 +175,15 @@ async def test_queryset_method(): assert await author.books.max(["year", "title"]) == dict( year=1930, title="Book 3" ) + +@pytest.mark.asyncio +async def test_count_method(): + async with database: + await sample_data() + + count = await Author.objects.select_related("books").count() + assert count == 1 + + # The legacy functionality + count = await Author.objects.select_related("books").count(distinct=False) + assert count == 3