add enum field (#626)

* add enum field

* add decorator for asyncio

* fix enum typing, additional tests, add docs

* add more tests

Co-authored-by: collerek <collerek@gmail.com>
This commit is contained in:
Ethon
2022-04-27 18:01:00 +08:00
committed by GitHub
parent 2caa17812a
commit ebf7c6e06f
10 changed files with 193 additions and 17 deletions

View File

@ -637,10 +637,11 @@ Available Model Fields (with required args - optional ones in docs):
* `Decimal(scale, precision)` * `Decimal(scale, precision)`
* `UUID()` * `UUID()`
* `LargeBinary(max_length)` * `LargeBinary(max_length)`
* `EnumField` - by passing `choices` to any other Field type * `Enum(enum_class)`
* `Enum` like Field - by passing `choices` to any other Field type
* `EncryptedString` - by passing `encrypt_secret` and `encrypt_backend` * `EncryptedString` - by passing `encrypt_secret` and `encrypt_backend`
* `ForeignKey(to)` * `ForeignKey(to)`
* `ManyToMany(to, through)` * `ManyToMany(to)`
### Available fields options ### Available fields options
The following keyword arguments are supported on all field types. The following keyword arguments are supported on all field types.

View File

@ -200,8 +200,23 @@ When loaded it's always python UUID so you can compare it and compare two format
### Enum ### Enum
Although there is no dedicated field type for Enums in `ormar` you can change any There are two ways to use enums in ormar -> one is a dedicated `Enum` field that uses `sqlalchemy.Enum` column type, while the other is setting `choices` on any field in ormar.
field into `Enum` like field by passing a `choices` list that is accepted by all Field types.
The Enum field uses the database dialect specific Enum column type if it's available, but fallback to varchar if this field type is not available.
The `choices` option always respect the database field type selected.
So which one to use depends on the backend you use and on the column/ data type you want in your Enum field.
#### Enum - Field
`Enum(enum_class=Type[Enum])` has a required `enum_class` parameter.
* Sqlalchemy column: `sqlalchemy.Enum`
* Type (used for pydantic): `Type[Enum]`
#### Choices
You can change any field into `Enum` like field by passing a `choices` list that is accepted by all Field types.
It will add both: validation in `pydantic` model and will display available options in schema, It will add both: validation in `pydantic` model and will display available options in schema,
therefore it will be available in docs of `fastapi`. therefore it will be available in docs of `fastapi`.
@ -210,7 +225,7 @@ If you still want to use `Enum` in your application you can do this by passing a
and later pass value of given option to a given field (note tha Enum is not JsonSerializable). and later pass value of given option to a given field (note tha Enum is not JsonSerializable).
```python ```python
# not that imports and endpoints declaration # note that imports and endpoints declaration
# is skipped here for brevity # is skipped here for brevity
from enum import Enum from enum import Enum
class TestEnum(Enum): class TestEnum(Enum):

View File

@ -646,10 +646,11 @@ Available Model Fields (with required args - optional ones in docs):
* `Decimal(scale, precision)` * `Decimal(scale, precision)`
* `UUID()` * `UUID()`
* `LargeBinary(max_length)` * `LargeBinary(max_length)`
* `EnumField` - by passing `choices` to any other Field type * `Enum(enum_class)`
* `Enum` like Field - by passing `choices` to any other Field type
* `EncryptedString` - by passing `encrypt_secret` and `encrypt_backend` * `EncryptedString` - by passing `encrypt_secret` and `encrypt_backend`
* `ForeignKey(to)` * `ForeignKey(to)`
* `ManyToMany(to, through)` * `ManyToMany(to)`
### Available fields options ### Available fields options
The following keyword arguments are supported on all field types. The following keyword arguments are supported on all field types.

View File

@ -53,6 +53,7 @@ from ormar.fields import (
Decimal, Decimal,
ENCODERS_MAP, ENCODERS_MAP,
EncryptBackends, EncryptBackends,
Enum,
Float, Float,
ForeignKey, ForeignKey,
ForeignKeyField, ForeignKeyField,
@ -97,6 +98,7 @@ __all__ = [
"DateTime", "DateTime",
"Date", "Date",
"Decimal", "Decimal",
"Enum",
"Float", "Float",
"ManyToMany", "ManyToMany",
"Model", "Model",

View File

@ -14,6 +14,7 @@ from ormar.fields.model_fields import (
Date, Date,
DateTime, DateTime,
Decimal, Decimal,
Enum,
Float, Float,
Integer, Integer,
JSON, JSON,
@ -43,6 +44,7 @@ __all__ = [
"Float", "Float",
"Time", "Time",
"UUID", "UUID",
"Enum",
"ForeignKey", "ForeignKey",
"ManyToMany", "ManyToMany",
"ManyToManyField", "ManyToManyField",

View File

@ -1,8 +1,8 @@
import datetime import datetime
import decimal import decimal
import uuid import uuid
from enum import Enum from enum import EnumMeta, Enum as E
from typing import Any, Optional, Set, TYPE_CHECKING, Type, Union, overload from typing import Any, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union, overload
import pydantic import pydantic
import sqlalchemy import sqlalchemy
@ -91,7 +91,7 @@ def convert_choices_if_needed(
:return: value, choices list :return: value, choices list
:rtype: Tuple[Any, Set] :rtype: Tuple[Any, Set]
""" """
choices = {o.value if isinstance(o, Enum) else o for o in choices} choices = {o.value if isinstance(o, E) else o for o in choices}
encoder = ormar.ENCODERS_MAP.get(field_type, lambda x: x) encoder = ormar.ENCODERS_MAP.get(field_type, lambda x: x)
if field_type == decimal.Decimal: if field_type == decimal.Decimal:
precision = scale precision = scale
@ -150,12 +150,14 @@ class ModelFieldFactory:
scale=kwargs.get("scale", None), scale=kwargs.get("scale", None),
represent_as_str=kwargs.get("represent_as_base64_str", False), represent_as_str=kwargs.get("represent_as_base64_str", False),
) )
enum_class = kwargs.get("enum_class", None)
field_type = cls._type if enum_class is None else enum_class
namespace = dict( namespace = dict(
__type__=cls._type, __type__=field_type,
__pydantic_type__=overwrite_pydantic_type __pydantic_type__=overwrite_pydantic_type
if overwrite_pydantic_type is not None if overwrite_pydantic_type is not None
else cls._type, else field_type,
__sample__=cls._sample, __sample__=cls._sample,
alias=kwargs.pop("name", None), alias=kwargs.pop("name", None),
name=None, name=None,
@ -803,3 +805,46 @@ class UUID(ModelFieldFactory, uuid.UUID):
""" """
uuid_format = kwargs.get("uuid_format", "hex") uuid_format = kwargs.get("uuid_format", "hex")
return sqlalchemy_uuid.UUID(uuid_format=uuid_format) return sqlalchemy_uuid.UUID(uuid_format=uuid_format)
if TYPE_CHECKING: # pragma: nocover
T = TypeVar("T", bound=E)
def Enum(enum_class: Type[T], **kwargs: Any) -> T:
pass
else:
class Enum(ModelFieldFactory):
"""
Enum field factory that construct Field classes and populated their values.
"""
_type = E
_sample = None
def __new__( # type: ignore # noqa CFQ002
cls, *, enum_class: Type[E], **kwargs: Any
) -> BaseField:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def validate(cls, **kwargs: Any) -> None:
enum_class = kwargs.get("enum_class")
if enum_class is None or not isinstance(enum_class, EnumMeta):
raise ModelDefinitionError("Enum Field choices must be EnumType")
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
enum_cls = kwargs.get("enum_class")
return sqlalchemy.Enum(enum_cls)

View File

@ -1,5 +1,6 @@
import base64 import base64
import uuid import uuid
from enum import Enum
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -73,6 +74,14 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
new_kwargs = cls.reconvert_str_to_bytes(new_kwargs) new_kwargs = cls.reconvert_str_to_bytes(new_kwargs)
new_kwargs = cls.dump_all_json_fields_to_str(new_kwargs) new_kwargs = cls.dump_all_json_fields_to_str(new_kwargs)
new_kwargs = cls.translate_columns_to_aliases(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
new_kwargs = cls.translate_enum_columns(new_kwargs)
return new_kwargs
@classmethod
def translate_enum_columns(cls, new_kwargs: dict) -> dict:
for k, v in new_kwargs.items():
if isinstance(v, Enum):
new_kwargs[k] = v.name
return new_kwargs return new_kwargs
@classmethod @classmethod

View File

@ -698,9 +698,7 @@ class QuerySet(Generic[T]):
expr = sqlalchemy.func.count().select().select_from(expr) expr = sqlalchemy.func.count().select().select_from(expr)
if distinct: if distinct:
pk_column_name = self.model.get_column_alias(self.model_meta.pkname) pk_column_name = self.model.get_column_alias(self.model_meta.pkname)
expr_distinct = expr.group_by(pk_column_name).alias( expr_distinct = expr.group_by(pk_column_name).alias("subquery_for_group")
"subquery_for_group"
)
expr = sqlalchemy.func.count().select().select_from(expr_distinct) expr = sqlalchemy.func.count().select().select_from(expr_distinct)
return await self.database.fetch_val(expr) return await self.database.fetch_val(expr)

View File

@ -1,4 +1,5 @@
import datetime import datetime
from enum import Enum
import databases import databases
import pydantic import pydantic
@ -6,6 +7,7 @@ import pytest
import sqlalchemy import sqlalchemy
import ormar import ormar
from ormar import ModelDefinitionError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -16,6 +18,11 @@ def time():
return datetime.datetime.now().time() return datetime.datetime.now().time()
class MyEnum(Enum):
SMALL = 1
BIG = 2
class Example(ormar.Model): class Example(ormar.Model):
class Meta: class Meta:
tablename = "example" tablename = "example"
@ -30,6 +37,17 @@ class Example(ormar.Model):
description: str = ormar.Text(nullable=True) description: str = ormar.Text(nullable=True)
value: float = ormar.Float(nullable=True) value: float = ormar.Float(nullable=True)
data: pydantic.Json = ormar.JSON(default={}) data: pydantic.Json = ormar.JSON(default={})
size: MyEnum = ormar.Enum(enum_class=MyEnum, default=MyEnum.SMALL)
class EnumExample(ormar.Model):
class Meta:
tablename = "enum_example"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
size: MyEnum = ormar.Enum(enum_class=MyEnum, default=MyEnum.SMALL)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -40,6 +58,49 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
def test_proper_enum_column_type():
assert Example.__fields__["size"].type_ == MyEnum
def test_accepts_only_proper_enums():
class WrongEnum(Enum):
A = 1
B = 2
with pytest.raises(pydantic.ValidationError):
Example(size=WrongEnum.A)
@pytest.mark.asyncio
async def test_enum_bulk_operations():
async with database:
examples = [EnumExample(), EnumExample()]
await EnumExample.objects.bulk_create(examples)
check = await EnumExample.objects.all()
assert all(x.size == MyEnum.SMALL for x in check)
for x in check:
x.size = MyEnum.BIG
await EnumExample.objects.bulk_update(check)
check2 = await EnumExample.objects.all()
assert all(x.size == MyEnum.BIG for x in check2)
@pytest.mark.asyncio
async def test_enum_filter():
async with database:
examples = [EnumExample(), EnumExample(size=MyEnum.BIG)]
await EnumExample.objects.bulk_create(examples)
check = await EnumExample.objects.all(size=MyEnum.SMALL)
assert len(check) == 1
check = await EnumExample.objects.all(size=MyEnum.BIG)
assert len(check) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(): async def test_model_crud():
async with database: async with database:
@ -52,6 +113,13 @@ async def test_model_crud():
assert example.description is None assert example.description is None
assert example.value is None assert example.value is None
assert example.data == {} assert example.data == {}
assert example.size == MyEnum.SMALL
await example.update(data={"foo": 123}, value=123.456, size=MyEnum.BIG)
await example.load()
assert example.value == 123.456
assert example.data == {"foo": 123}
assert example.size == MyEnum.BIG
await example.update(data={"foo": 123}, value=123.456) await example.update(data={"foo": 123}, value=123.456)
await example.load() await example.load()
@ -59,3 +127,18 @@ async def test_model_crud():
assert example.data == {"foo": 123} assert example.data == {"foo": 123}
await example.delete() await example.delete()
@pytest.mark.asyncio
async def test_invalid_enum_field():
async with database:
with pytest.raises(ModelDefinitionError):
class Example2(ormar.Model):
class Meta:
tablename = "example"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
size: MyEnum = ormar.Enum(enum_class=[])

View File

@ -1,4 +1,5 @@
from typing import List, Optional from enum import Enum
from typing import Optional
import databases import databases
import pydantic import pydantic
@ -19,6 +20,11 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
class MySize(Enum):
SMALL = 0
BIG = 1
class Book(ormar.Model): class Book(ormar.Model):
class Meta: class Meta:
tablename = "books" tablename = "books"
@ -45,6 +51,7 @@ class ToDo(ormar.Model):
text: str = ormar.String(max_length=500) text: str = ormar.String(max_length=500)
completed: bool = ormar.Boolean(default=False) completed: bool = ormar.Boolean(default=False)
pairs: pydantic.Json = ormar.JSON(default=[]) pairs: pydantic.Json = ormar.JSON(default=[])
size = ormar.Enum(enum_class=MySize, default=MySize.SMALL)
class Category(ormar.Model): class Category(ormar.Model):
@ -77,6 +84,7 @@ class ItemConfig(ormar.Model):
id: Optional[int] = ormar.Integer(primary_key=True) id: Optional[int] = ormar.Integer(primary_key=True)
item_id: str = ormar.String(max_length=32, index=True) item_id: str = ormar.String(max_length=32, index=True)
pairs: pydantic.Json = ormar.JSON(default=["2", "3"]) pairs: pydantic.Json = ormar.JSON(default=["2", "3"])
size = ormar.Enum(enum_class=MySize, default=MySize.SMALL)
class QuerySetCls(QuerySet): class QuerySetCls(QuerySet):
@ -343,6 +351,7 @@ async def test_bulk_update():
for todo in todoes: for todo in todoes:
todo.text = todo.text + "_1" todo.text = todo.text + "_1"
todo.completed = False todo.completed = False
todo.size = MySize.BIG
await ToDo.objects.bulk_update(todoes) await ToDo.objects.bulk_update(todoes)
@ -354,6 +363,7 @@ async def test_bulk_update():
for todo in todoes: for todo in todoes:
assert todo.text[-2:] == "_1" assert todo.text[-2:] == "_1"
assert todo.size == MySize.BIG
@pytest.mark.asyncio @pytest.mark.asyncio
@ -474,3 +484,13 @@ async def test_custom_queryset_cls():
await Customer(name="test").save() await Customer(name="test").save()
c = await Customer.objects.first_or_404(name="test") c = await Customer.objects.first_or_404(name="test")
assert c.name == "test" assert c.name == "test"
@pytest.mark.asyncio
async def test_filter_enum():
async with database:
it = ItemConfig(item_id="test_1")
await it.save()
it = await ItemConfig.objects.filter(size=MySize.SMALL).first()
assert it