diff --git a/.coverage b/.coverage index 550b22b..17ef238 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index fd70258..eae086c 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -15,6 +15,10 @@ FILTER_OPERATORS = { "iexact": "ilike", "contains": "like", "icontains": "ilike", + "startswith": "like", + "istartswith": "ilike", + "endswith": "like", + "iendswith": "ilike", "in": "in_", "gt": "__gt__", "gte": "__ge__", @@ -169,7 +173,14 @@ class QueryClause: ) -> Tuple[str, bool]: has_escaped_character = False - if op not in ["contains", "icontains"]: + if op not in [ + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + ]: return value, has_escaped_character if isinstance(value, ormar.Model): @@ -183,7 +194,9 @@ class QueryClause: # enable escape modifier for char in ESCAPE_CHARACTERS: value = value.replace(char, f"\\{char}") - value = f"%{value}%" + prefix = "%" if "start" not in op else "" + sufix = "%" if "end" not in op else "" + value = f"{prefix}{value}{sufix}" return value, has_escaped_character diff --git a/tests/test_models.py b/tests/test_models.py index e402c86..79ca0ed 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -374,3 +374,33 @@ async def test_model_choices(): await Country.objects.create( name=name, taxed=taxed, country_code=country_code ) + + +@pytest.mark.asyncio +async def test_start_and_end_filters(): + async with database: + async with database.transaction(force_rollback=True): + await User.objects.create(name="Markos Uj") + await User.objects.create(name="Maqua Bigo") + await User.objects.create(name="maqo quidid") + await User.objects.create(name="Louis Figo") + await User.objects.create(name="Loordi Kami") + await User.objects.create(name="Yuuki Sami") + + users = await User.objects.filter(name__startswith="Mar").all() + assert len(users) == 1 + + users = await User.objects.filter(name__istartswith="ma").all() + assert len(users) == 3 + + users = await User.objects.filter(name__istartswith="Maq").all() + assert len(users) == 2 + + users = await User.objects.filter(name__iendswith="AMI").all() + assert len(users) == 2 + + users = await User.objects.filter(name__endswith="Uj").all() + assert len(users) == 1 + + users = await User.objects.filter(name__endswith="igo").all() + assert len(users) == 2