From ebf7c6e06f580ccbe66abacf3cb6b9d50d7f0f07 Mon Sep 17 00:00:00 2001 From: Ethon Date: Wed, 27 Apr 2022 18:01:00 +0800 Subject: [PATCH] add enum field (#626) * add enum field * add decorator for asyncio * fix enum typing, additional tests, add docs * add more tests Co-authored-by: collerek --- README.md | 5 +- docs/fields/field-types.md | 23 ++++- docs/index.md | 5 +- ormar/__init__.py | 2 + ormar/fields/__init__.py | 2 + ormar/fields/model_fields.py | 55 ++++++++++-- ormar/models/mixins/save_mixin.py | 9 ++ ormar/queryset/queryset.py | 4 +- tests/test_model_definition/test_columns.py | 83 +++++++++++++++++++ .../test_queryset_level_methods.py | 22 ++++- 10 files changed, 193 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 79d96f9..525c87a 100644 --- a/README.md +++ b/README.md @@ -637,10 +637,11 @@ Available Model Fields (with required args - optional ones in docs): * `Decimal(scale, precision)` * `UUID()` * `LargeBinary(max_length)` -* `EnumField` - by passing `choices` to any other Field type +* `Enum(enum_class)` +* `Enum` like Field - by passing `choices` to any other Field type * `EncryptedString` - by passing `encrypt_secret` and `encrypt_backend` * `ForeignKey(to)` -* `ManyToMany(to, through)` +* `ManyToMany(to)` ### Available fields options The following keyword arguments are supported on all field types. diff --git a/docs/fields/field-types.md b/docs/fields/field-types.md index 99c5409..37be286 100644 --- a/docs/fields/field-types.md +++ b/docs/fields/field-types.md @@ -200,8 +200,23 @@ When loaded it's always python UUID so you can compare it and compare two format ### Enum -Although there is no dedicated field type for Enums in `ormar` you can change any -field into `Enum` like field by passing a `choices` list that is accepted by all Field types. +There are two ways to use enums in ormar -> one is a dedicated `Enum` field that uses `sqlalchemy.Enum` column type, while the other is setting `choices` on any field in ormar. + +The Enum field uses the database dialect specific Enum column type if it's available, but fallback to varchar if this field type is not available. + +The `choices` option always respect the database field type selected. + +So which one to use depends on the backend you use and on the column/ data type you want in your Enum field. + +#### Enum - Field + +`Enum(enum_class=Type[Enum])` has a required `enum_class` parameter. + +* Sqlalchemy column: `sqlalchemy.Enum` +* Type (used for pydantic): `Type[Enum]` + +#### Choices +You can change any field into `Enum` like field by passing a `choices` list that is accepted by all Field types. It will add both: validation in `pydantic` model and will display available options in schema, therefore it will be available in docs of `fastapi`. @@ -210,7 +225,7 @@ If you still want to use `Enum` in your application you can do this by passing a and later pass value of given option to a given field (note tha Enum is not JsonSerializable). ```python -# not that imports and endpoints declaration +# note that imports and endpoints declaration # is skipped here for brevity from enum import Enum class TestEnum(Enum): @@ -244,4 +259,4 @@ response = client.post( [relations]: ../relations/index.md [queries]: ../queries.md [pydantic]: https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation -[server default]: https://docs.sqlalchemy.org/en/13/core/defaults.html#server-invoked-ddl-explicit-default-expressions \ No newline at end of file +[server default]: https://docs.sqlalchemy.org/en/13/core/defaults.html#server-invoked-ddl-explicit-default-expressions diff --git a/docs/index.md b/docs/index.md index fa6ce96..6ea70e8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -646,10 +646,11 @@ Available Model Fields (with required args - optional ones in docs): * `Decimal(scale, precision)` * `UUID()` * `LargeBinary(max_length)` -* `EnumField` - by passing `choices` to any other Field type +* `Enum(enum_class)` +* `Enum` like Field - by passing `choices` to any other Field type * `EncryptedString` - by passing `encrypt_secret` and `encrypt_backend` * `ForeignKey(to)` -* `ManyToMany(to, through)` +* `ManyToMany(to)` ### Available fields options The following keyword arguments are supported on all field types. diff --git a/ormar/__init__.py b/ormar/__init__.py index 4187d60..1b61eb6 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -53,6 +53,7 @@ from ormar.fields import ( Decimal, ENCODERS_MAP, EncryptBackends, + Enum, Float, ForeignKey, ForeignKeyField, @@ -97,6 +98,7 @@ __all__ = [ "DateTime", "Date", "Decimal", + "Enum", "Float", "ManyToMany", "Model", diff --git a/ormar/fields/__init__.py b/ormar/fields/__init__.py index e90b5df..e55f2d3 100644 --- a/ormar/fields/__init__.py +++ b/ormar/fields/__init__.py @@ -14,6 +14,7 @@ from ormar.fields.model_fields import ( Date, DateTime, Decimal, + Enum, Float, Integer, JSON, @@ -43,6 +44,7 @@ __all__ = [ "Float", "Time", "UUID", + "Enum", "ForeignKey", "ManyToMany", "ManyToManyField", diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 0b565aa..4ad3faf 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,8 +1,8 @@ import datetime import decimal import uuid -from enum import Enum -from typing import Any, Optional, Set, TYPE_CHECKING, Type, Union, overload +from enum import EnumMeta, Enum as E +from typing import Any, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union, overload import pydantic import sqlalchemy @@ -91,7 +91,7 @@ def convert_choices_if_needed( :return: value, choices list :rtype: Tuple[Any, Set] """ - choices = {o.value if isinstance(o, Enum) else o for o in choices} + choices = {o.value if isinstance(o, E) else o for o in choices} encoder = ormar.ENCODERS_MAP.get(field_type, lambda x: x) if field_type == decimal.Decimal: precision = scale @@ -150,12 +150,14 @@ class ModelFieldFactory: scale=kwargs.get("scale", None), represent_as_str=kwargs.get("represent_as_base64_str", False), ) + enum_class = kwargs.get("enum_class", None) + field_type = cls._type if enum_class is None else enum_class namespace = dict( - __type__=cls._type, + __type__=field_type, __pydantic_type__=overwrite_pydantic_type if overwrite_pydantic_type is not None - else cls._type, + else field_type, __sample__=cls._sample, alias=kwargs.pop("name", None), name=None, @@ -803,3 +805,46 @@ class UUID(ModelFieldFactory, uuid.UUID): """ uuid_format = kwargs.get("uuid_format", "hex") return sqlalchemy_uuid.UUID(uuid_format=uuid_format) + + +if TYPE_CHECKING: # pragma: nocover + + T = TypeVar("T", bound=E) + + def Enum(enum_class: Type[T], **kwargs: Any) -> T: + pass + +else: + + class Enum(ModelFieldFactory): + """ + Enum field factory that construct Field classes and populated their values. + """ + + _type = E + _sample = None + + def __new__( # type: ignore # noqa CFQ002 + cls, *, enum_class: Type[E], **kwargs: Any + ) -> BaseField: + + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def validate(cls, **kwargs: Any) -> None: + enum_class = kwargs.get("enum_class") + if enum_class is None or not isinstance(enum_class, EnumMeta): + raise ModelDefinitionError("Enum Field choices must be EnumType") + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + enum_cls = kwargs.get("enum_class") + return sqlalchemy.Enum(enum_cls) diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 66922c9..4408fa7 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -1,5 +1,6 @@ import base64 import uuid +from enum import Enum from typing import ( Any, Callable, @@ -73,6 +74,14 @@ class SavePrepareMixin(RelationMixin, AliasMixin): new_kwargs = cls.reconvert_str_to_bytes(new_kwargs) new_kwargs = cls.dump_all_json_fields_to_str(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs) + new_kwargs = cls.translate_enum_columns(new_kwargs) + return new_kwargs + + @classmethod + def translate_enum_columns(cls, new_kwargs: dict) -> dict: + for k, v in new_kwargs.items(): + if isinstance(v, Enum): + new_kwargs[k] = v.name return new_kwargs @classmethod diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 1a6215a..bebccd8 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -698,9 +698,7 @@ class QuerySet(Generic[T]): expr = sqlalchemy.func.count().select().select_from(expr) if distinct: pk_column_name = self.model.get_column_alias(self.model_meta.pkname) - expr_distinct = expr.group_by(pk_column_name).alias( - "subquery_for_group" - ) + expr_distinct = expr.group_by(pk_column_name).alias("subquery_for_group") expr = sqlalchemy.func.count().select().select_from(expr_distinct) return await self.database.fetch_val(expr) diff --git a/tests/test_model_definition/test_columns.py b/tests/test_model_definition/test_columns.py index c4726fa..56c9c74 100644 --- a/tests/test_model_definition/test_columns.py +++ b/tests/test_model_definition/test_columns.py @@ -1,4 +1,5 @@ import datetime +from enum import Enum import databases import pydantic @@ -6,6 +7,7 @@ import pytest import sqlalchemy import ormar +from ormar import ModelDefinitionError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -16,6 +18,11 @@ def time(): return datetime.datetime.now().time() +class MyEnum(Enum): + SMALL = 1 + BIG = 2 + + class Example(ormar.Model): class Meta: tablename = "example" @@ -30,6 +37,17 @@ class Example(ormar.Model): description: str = ormar.Text(nullable=True) value: float = ormar.Float(nullable=True) data: pydantic.Json = ormar.JSON(default={}) + size: MyEnum = ormar.Enum(enum_class=MyEnum, default=MyEnum.SMALL) + + +class EnumExample(ormar.Model): + class Meta: + tablename = "enum_example" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + size: MyEnum = ormar.Enum(enum_class=MyEnum, default=MyEnum.SMALL) @pytest.fixture(autouse=True, scope="module") @@ -40,6 +58,49 @@ def create_test_database(): metadata.drop_all(engine) +def test_proper_enum_column_type(): + assert Example.__fields__["size"].type_ == MyEnum + + +def test_accepts_only_proper_enums(): + class WrongEnum(Enum): + A = 1 + B = 2 + + with pytest.raises(pydantic.ValidationError): + Example(size=WrongEnum.A) + + +@pytest.mark.asyncio +async def test_enum_bulk_operations(): + async with database: + examples = [EnumExample(), EnumExample()] + await EnumExample.objects.bulk_create(examples) + + check = await EnumExample.objects.all() + assert all(x.size == MyEnum.SMALL for x in check) + + for x in check: + x.size = MyEnum.BIG + + await EnumExample.objects.bulk_update(check) + check2 = await EnumExample.objects.all() + assert all(x.size == MyEnum.BIG for x in check2) + + +@pytest.mark.asyncio +async def test_enum_filter(): + async with database: + examples = [EnumExample(), EnumExample(size=MyEnum.BIG)] + await EnumExample.objects.bulk_create(examples) + + check = await EnumExample.objects.all(size=MyEnum.SMALL) + assert len(check) == 1 + + check = await EnumExample.objects.all(size=MyEnum.BIG) + assert len(check) == 1 + + @pytest.mark.asyncio async def test_model_crud(): async with database: @@ -52,6 +113,13 @@ async def test_model_crud(): assert example.description is None assert example.value is None assert example.data == {} + assert example.size == MyEnum.SMALL + + await example.update(data={"foo": 123}, value=123.456, size=MyEnum.BIG) + await example.load() + assert example.value == 123.456 + assert example.data == {"foo": 123} + assert example.size == MyEnum.BIG await example.update(data={"foo": 123}, value=123.456) await example.load() @@ -59,3 +127,18 @@ async def test_model_crud(): assert example.data == {"foo": 123} await example.delete() + + +@pytest.mark.asyncio +async def test_invalid_enum_field(): + async with database: + with pytest.raises(ModelDefinitionError): + + class Example2(ormar.Model): + class Meta: + tablename = "example" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + size: MyEnum = ormar.Enum(enum_class=[]) diff --git a/tests/test_queries/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py index 7bf3904..a85cd56 100644 --- a/tests/test_queries/test_queryset_level_methods.py +++ b/tests/test_queries/test_queryset_level_methods.py @@ -1,4 +1,5 @@ -from typing import List, Optional +from enum import Enum +from typing import Optional import databases import pydantic @@ -19,6 +20,11 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class MySize(Enum): + SMALL = 0 + BIG = 1 + + class Book(ormar.Model): class Meta: tablename = "books" @@ -45,6 +51,7 @@ class ToDo(ormar.Model): text: str = ormar.String(max_length=500) completed: bool = ormar.Boolean(default=False) pairs: pydantic.Json = ormar.JSON(default=[]) + size = ormar.Enum(enum_class=MySize, default=MySize.SMALL) class Category(ormar.Model): @@ -77,6 +84,7 @@ class ItemConfig(ormar.Model): id: Optional[int] = ormar.Integer(primary_key=True) item_id: str = ormar.String(max_length=32, index=True) pairs: pydantic.Json = ormar.JSON(default=["2", "3"]) + size = ormar.Enum(enum_class=MySize, default=MySize.SMALL) class QuerySetCls(QuerySet): @@ -343,6 +351,7 @@ async def test_bulk_update(): for todo in todoes: todo.text = todo.text + "_1" todo.completed = False + todo.size = MySize.BIG await ToDo.objects.bulk_update(todoes) @@ -354,6 +363,7 @@ async def test_bulk_update(): for todo in todoes: assert todo.text[-2:] == "_1" + assert todo.size == MySize.BIG @pytest.mark.asyncio @@ -474,3 +484,13 @@ async def test_custom_queryset_cls(): await Customer(name="test").save() c = await Customer.objects.first_or_404(name="test") assert c.name == "test" + + +@pytest.mark.asyncio +async def test_filter_enum(): + async with database: + it = ItemConfig(item_id="test_1") + await it.save() + + it = await ItemConfig.objects.filter(size=MySize.SMALL).first() + assert it