diff --git a/docs/releases.md b/docs/releases.md index f2328cc..008c138 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,9 @@ +# 0.10.4 + +## ✨ Features + +* Add possibility to `filter` and `order_by` with field access instead of dunder separated strings. + # 0.10.3 ## ✨ Features diff --git a/ormar/__init__.py b/ormar/__init__.py index 7436730..a57b83f 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -75,7 +75,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.10.3" +__version__ = "0.10.4" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index bc80332..1ed1cfa 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -39,7 +39,7 @@ from ormar.models.helpers import ( sqlalchemy_columns_from_model_fields, ) from ormar.models.quick_access_views import quick_access_set -from ormar.queryset import QuerySet +from ormar.queryset import FieldAccessor, QuerySet from ormar.relations.alias_manager import AliasManager from ormar.signals import Signal, SignalEmitter @@ -561,3 +561,14 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): f"need to call update_forward_refs()." ) return QuerySet(model_cls=cls) + + def __getattr__(self, item: str) -> Any: + if item in object.__getattribute__(self, "Meta").model_fields: + field = self.Meta.model_fields.get(item) + if field.is_relation: + return FieldAccessor( + source_model=self, model=field.to, access_chain=item + ) + else: + return FieldAccessor(source_model=self, field=field, access_chain=item) + return object.__getattribute__(self, item) diff --git a/ormar/queryset/__init__.py b/ormar/queryset/__init__.py index e75febf..909cb19 100644 --- a/ormar/queryset/__init__.py +++ b/ormar/queryset/__init__.py @@ -3,6 +3,7 @@ Contains QuerySet and different Query classes to allow for constructing of sql q """ from ormar.queryset.actions import FilterAction, OrderAction, SelectAction from ormar.queryset.clause import and_, or_ +from ormar.queryset.field_accessor import FieldAccessor from ormar.queryset.filter_query import FilterQuery from ormar.queryset.limit_query import LimitQuery from ormar.queryset.offset_query import OffsetQuery @@ -20,4 +21,5 @@ __all__ = [ "SelectAction", "and_", "or_", + "FieldAccessor", ] diff --git a/ormar/queryset/actions/filter_action.py b/ormar/queryset/actions/filter_action.py index 279e0fa..2917063 100644 --- a/ormar/queryset/actions/filter_action.py +++ b/ormar/queryset/actions/filter_action.py @@ -26,6 +26,23 @@ FILTER_OPERATORS = { "lt": "__lt__", "lte": "__le__", } +METHODS_TO_OPERATORS = { + "__eq__": "exact", + "__mod__": "contains", + "__gt__": "gt", + "__ge__": "gte", + "__lt__": "lt", + "__le__": "lte", + "iexact": "iexact", + "contains": "contains", + "icontains": "icontains", + "startswith": "startswith", + "istartswith": "istartswith", + "endswith": "endswith", + "iendswith": "iendswith", + "isnull": "isnull", + "in": "in", +} ESCAPE_CHARACTERS = ["%", "_"] diff --git a/ormar/queryset/field_accessor.py b/ormar/queryset/field_accessor.py new file mode 100644 index 0000000..02f15e6 --- /dev/null +++ b/ormar/queryset/field_accessor.py @@ -0,0 +1,102 @@ +from typing import Any + +from ormar.queryset.actions import OrderAction +from ormar.queryset.actions import FilterAction +from ormar.queryset.actions.filter_action import METHODS_TO_OPERATORS + + +class FieldAccessor: + def __init__( + self, source_model=None, field=None, model=None, access_chain: str = "" + ): + self._source_model = source_model + self._field = field + self._model = model + self._access_chain = access_chain + + def __getattr__(self, item: str) -> Any: + if self._field and item == self._field.name: + return self._field + + if item in self._model.Meta.model_fields: + field = self._model.Meta.model_fields.get(item) + if field.is_relation: + return FieldAccessor( + source_model=self._source_model, + model=field.to, + access_chain=self._access_chain + f"__{item}", + ) + else: + return FieldAccessor( + source_model=self._source_model, + field=field, + access_chain=self._access_chain + f"__{item}", + ) + return object.__getattribute__(self, item) + + def _check_field(self) -> None: + if not self._field: + raise AttributeError( + "Cannot filter by Model, you need to provide model name" + ) + + def _select_operator(self, op: str, other: Any) -> FilterAction: + self._check_field() + return FilterAction( + filter_str=self._access_chain + f"__{METHODS_TO_OPERATORS[op]}", + value=other, + model_cls=self._source_model, + ) + + def __eq__(self, other: Any) -> FilterAction: # type: ignore + return self._select_operator(op="__eq__", other=other) + + def __ge__(self, other: Any) -> FilterAction: + return self._select_operator(op="__ge__", other=other) + + def __gt__(self, other: Any) -> FilterAction: + return self._select_operator(op="__gt__", other=other) + + def __le__(self, other: Any) -> FilterAction: + return self._select_operator(op="__le__", other=other) + + def __lt__(self, other) -> FilterAction: + return self._select_operator(op="__lt__", other=other) + + def __mod__(self, other) -> FilterAction: + return self._select_operator(op="__mod__", other=other) + + def __contains__(self, item) -> FilterAction: + return self._select_operator(op="in", other=item) + + def iexact(self, other) -> FilterAction: + return self._select_operator(op="iexact", other=other) + + def contains(self, other) -> FilterAction: + return self._select_operator(op="contains", other=other) + + def icontains(self, other) -> FilterAction: + return self._select_operator(op="icontains", other=other) + + def startswith(self, other) -> FilterAction: + return self._select_operator(op="startswith", other=other) + + def istartswith(self, other) -> FilterAction: + return self._select_operator(op="istartswith", other=other) + + def endswith(self, other) -> FilterAction: + return self._select_operator(op="endswith", other=other) + + def iendswith(self, other) -> FilterAction: + return self._select_operator(op="iendswith", other=other) + + def isnull(self, other) -> FilterAction: + return self._select_operator(op="isnull", other=other) + + def asc(self) -> OrderAction: + return OrderAction(order_str=self._access_chain, model_cls=self._source_model) + + def desc(self) -> OrderAction: + return OrderAction( + order_str="-" + self._access_chain, model_cls=self._source_model + ) diff --git a/tests/test_model_definition/test_fields_access.py b/tests/test_model_definition/test_fields_access.py new file mode 100644 index 0000000..22d9941 --- /dev/null +++ b/tests/test_model_definition/test_fields_access.py @@ -0,0 +1,165 @@ +import databases +import pytest +import sqlalchemy + +import ormar +from ormar import BaseField +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class BaseMeta(ormar.ModelMeta): + metadata = metadata + database = database + + +class PriceList(ormar.Model): + class Meta(BaseMeta): + tablename = "price_lists" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Category(ormar.Model): + class Meta(BaseMeta): + tablename = "categories" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + price_lists = ormar.ManyToMany(PriceList, related_name="categories") + + +class Product(ormar.Model): + class Meta(BaseMeta): + tablename = "product" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + rating: float = ormar.Float(minimum=1, maximum=5) + category = ormar.ForeignKey(Category) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_fields_access(): + # basic access + assert Product.id._field == Product.Meta.model_fields["id"] + assert isinstance(Product.id._field, BaseField) + assert Product.id._access_chain == "id" + assert Product.id._source_model == Product + + # nested models + curr_field = Product.category.name + assert curr_field._field == Category.Meta.model_fields["name"] + assert curr_field._access_chain == "category__name" + assert curr_field._source_model == Product + + # deeper nesting + curr_field = Product.category.price_lists.name + assert curr_field._field == PriceList.Meta.model_fields["name"] + assert curr_field._access_chain == "category__price_lists__name" + assert curr_field._source_model == Product + + # reverse nesting + curr_field = PriceList.categories.products.rating + assert curr_field._field == Product.Meta.model_fields["rating"] + assert curr_field._access_chain == "categories__products__rating" + assert curr_field._source_model == PriceList + + +@pytest.mark.parametrize( + "method, expected, expected_value", + [ + ("__eq__", "exact", "Test"), + ("__lt__", "lt", "Test"), + ("__le__", "lte", "Test"), + ("__ge__", "gte", "Test"), + ("__gt__", "gt", "Test"), + ("iexact", "iexact", "Test"), + ("contains", "contains", "%Test%"), + ("icontains", "icontains", "%Test%"), + ("startswith", "startswith", "Test%"), + ("istartswith", "istartswith", "Test%"), + ("endswith", "endswith", "%Test"), + ("iendswith", "iendswith", "%Test"), + ("isnull", "isnull", "Test"), + ("__contains__", "in", "Test"), + ("__mod__", "contains", "%Test%"), + ], +) +def test_operator_return_proper_filter_action(method, expected, expected_value): + action = getattr(Product.name, method)("Test") + assert action.source_model == Product + assert action.target_model == Product + assert action.operator == expected + assert action.filter_value == expected_value + + action = getattr(Product.category.name, method)("Test") + assert action.source_model == Product + assert action.target_model == Category + assert action.operator == expected + assert action.filter_value == expected_value + + action = getattr(PriceList.categories.products.rating, method)("Test") + assert action.source_model == PriceList + assert action.target_model == Product + assert action.operator == expected + assert action.filter_value == expected_value + + +@pytest.mark.parametrize("method, expected_direction", [("asc", ""), ("desc", "desc"),]) +def test_operator_return_proper_order_action(method, expected_direction): + action = getattr(Product.name, method)() + assert action.source_model == Product + assert action.target_model == Product + assert action.direction == expected_direction + assert action.is_source_model_order + + action = getattr(Product.category.name, method)() + assert action.source_model == Product + assert action.target_model == Category + assert action.direction == expected_direction + assert not action.is_source_model_order + + action = getattr(PriceList.categories.products.rating, method)() + assert action.source_model == PriceList + assert action.target_model == Product + assert action.direction == expected_direction + assert not action.is_source_model_order + + +# @pytest.mark.asyncio +# async def test_filtering_by_field_access(): +# async with database: +# async with database.transaction(force_rollback=True): +# category = await Category(name='Toys').save() +# product1 = await Product(name="G.I Joe", +# rating=4.7, +# category=category).save() +# product2 = await Product(name="My Little Pony", +# rating=3.8, +# category=category).save() +# +# check = Product.object.get(Product.name == "My Little Pony") +# assert check == product2 + +# TODO: Finish implementation +# * overload operators and add missing functions that return FilterAction (V) +# * return OrderAction for desc() and asc() (V) + +# * accept args in all functions that accept filters? or only filter and exclude? +# all functions: delete, first, get, get_or_none, get_or_create, all, filter, exclude +# and same from queryset, should they also accept filter groups? +# * create filter groups for & and | (and ~ - NOT?) +# * accept OrderActions in order_by +#