From 6299ea43831a295db61d22b53fc268e3a8abf239 Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 18 Jan 2022 16:41:22 +0800 Subject: [PATCH] can custom query_cls --- ormar/models/metaclass.py | 4 ++- .../test_queryset_level_methods.py | 33 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index cc1dede..1e51252 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -614,6 +614,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): return new_model + __queryset_cls__ = QuerySet + @property def objects(cls: Type["T"]) -> "QuerySet[T]": # type: ignore if cls.Meta.requires_ref_update: @@ -622,7 +624,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): f"ForwardRefs. \nBefore using the model you " f"need to call update_forward_refs()." ) - return QuerySet(model_cls=cls) + return cls.__queryset_cls__(model_cls=cls) def __getattr__(self, item: str) -> Any: """ diff --git a/tests/test_queries/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py index 9aac707..89f69ef 100644 --- a/tests/test_queries/test_queryset_level_methods.py +++ b/tests/test_queries/test_queryset_level_methods.py @@ -6,6 +6,7 @@ import pytest import sqlalchemy import ormar +from ormar import QuerySet from ormar.exceptions import ( ModelPersistenceError, QueryDefinitionError, @@ -77,6 +78,27 @@ class ItemConfig(ormar.Model): pairs: pydantic.Json = ormar.JSON(default=["2", "3"]) +class Customer(ormar.Model): + class Meta: + metadata = metadata + database = database + tablename = "customer" + + class QuerySetCls(QuerySet): + + async def first_or_404(self, *args, **kwargs): + entity = await self.get_or_none(*args, **kwargs) + if not entity: + # maybe HTTPException in fastapi + raise ValueError("customer not found") + return entity + + __queryset_cls__ = QuerySetCls + + id: Optional[int] = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=32) + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -349,3 +371,14 @@ async def test_bulk_operations_with_json(): await ItemConfig.objects.bulk_update(items) items = await ItemConfig.objects.all() assert all(x.pairs == ["1"] for x in items) + + +@pytest.mark.asyncio +async def test_custom_queryset_cls(): + async with database: + with pytest.raises(ValueError): + await Customer.objects.first_or_404(id=1) + + await Customer(name="test").save() + c = await Customer.objects.first_or_404(name="test") + assert c.name == "test"