add enum field (#626)

* add enum field

* add decorator for asyncio

* fix enum typing, additional tests, add docs

* add more tests

Co-authored-by: collerek <collerek@gmail.com>
This commit is contained in:
Ethon
2022-04-27 18:01:00 +08:00
committed by GitHub
parent 2caa17812a
commit ebf7c6e06f
10 changed files with 193 additions and 17 deletions

View File

@ -637,10 +637,11 @@ Available Model Fields (with required args - optional ones in docs):
* `Decimal(scale, precision)`
* `UUID()`
* `LargeBinary(max_length)`
* `EnumField` - by passing `choices` to any other Field type
* `Enum(enum_class)`
* `Enum` like Field - by passing `choices` to any other Field type
* `EncryptedString` - by passing `encrypt_secret` and `encrypt_backend`
* `ForeignKey(to)`
* `ManyToMany(to, through)`
* `ManyToMany(to)`
### Available fields options
The following keyword arguments are supported on all field types.

View File

@ -200,8 +200,23 @@ When loaded it's always python UUID so you can compare it and compare two format
### Enum
Although there is no dedicated field type for Enums in `ormar` you can change any
field into `Enum` like field by passing a `choices` list that is accepted by all Field types.
There are two ways to use enums in ormar -> one is a dedicated `Enum` field that uses `sqlalchemy.Enum` column type, while the other is setting `choices` on any field in ormar.
The Enum field uses the database dialect specific Enum column type if it's available, but fallback to varchar if this field type is not available.
The `choices` option always respect the database field type selected.
So which one to use depends on the backend you use and on the column/ data type you want in your Enum field.
#### Enum - Field
`Enum(enum_class=Type[Enum])` has a required `enum_class` parameter.
* Sqlalchemy column: `sqlalchemy.Enum`
* Type (used for pydantic): `Type[Enum]`
#### Choices
You can change any field into `Enum` like field by passing a `choices` list that is accepted by all Field types.
It will add both: validation in `pydantic` model and will display available options in schema,
therefore it will be available in docs of `fastapi`.
@ -210,7 +225,7 @@ If you still want to use `Enum` in your application you can do this by passing a
and later pass value of given option to a given field (note tha Enum is not JsonSerializable).
```python
# not that imports and endpoints declaration
# note that imports and endpoints declaration
# is skipped here for brevity
from enum import Enum
class TestEnum(Enum):
@ -244,4 +259,4 @@ response = client.post(
[relations]: ../relations/index.md
[queries]: ../queries.md
[pydantic]: https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation
[server default]: https://docs.sqlalchemy.org/en/13/core/defaults.html#server-invoked-ddl-explicit-default-expressions
[server default]: https://docs.sqlalchemy.org/en/13/core/defaults.html#server-invoked-ddl-explicit-default-expressions

View File

@ -646,10 +646,11 @@ Available Model Fields (with required args - optional ones in docs):
* `Decimal(scale, precision)`
* `UUID()`
* `LargeBinary(max_length)`
* `EnumField` - by passing `choices` to any other Field type
* `Enum(enum_class)`
* `Enum` like Field - by passing `choices` to any other Field type
* `EncryptedString` - by passing `encrypt_secret` and `encrypt_backend`
* `ForeignKey(to)`
* `ManyToMany(to, through)`
* `ManyToMany(to)`
### Available fields options
The following keyword arguments are supported on all field types.

View File

@ -53,6 +53,7 @@ from ormar.fields import (
Decimal,
ENCODERS_MAP,
EncryptBackends,
Enum,
Float,
ForeignKey,
ForeignKeyField,
@ -97,6 +98,7 @@ __all__ = [
"DateTime",
"Date",
"Decimal",
"Enum",
"Float",
"ManyToMany",
"Model",

View File

@ -14,6 +14,7 @@ from ormar.fields.model_fields import (
Date,
DateTime,
Decimal,
Enum,
Float,
Integer,
JSON,
@ -43,6 +44,7 @@ __all__ = [
"Float",
"Time",
"UUID",
"Enum",
"ForeignKey",
"ManyToMany",
"ManyToManyField",

View File

@ -1,8 +1,8 @@
import datetime
import decimal
import uuid
from enum import Enum
from typing import Any, Optional, Set, TYPE_CHECKING, Type, Union, overload
from enum import EnumMeta, Enum as E
from typing import Any, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union, overload
import pydantic
import sqlalchemy
@ -91,7 +91,7 @@ def convert_choices_if_needed(
:return: value, choices list
:rtype: Tuple[Any, Set]
"""
choices = {o.value if isinstance(o, Enum) else o for o in choices}
choices = {o.value if isinstance(o, E) else o for o in choices}
encoder = ormar.ENCODERS_MAP.get(field_type, lambda x: x)
if field_type == decimal.Decimal:
precision = scale
@ -150,12 +150,14 @@ class ModelFieldFactory:
scale=kwargs.get("scale", None),
represent_as_str=kwargs.get("represent_as_base64_str", False),
)
enum_class = kwargs.get("enum_class", None)
field_type = cls._type if enum_class is None else enum_class
namespace = dict(
__type__=cls._type,
__type__=field_type,
__pydantic_type__=overwrite_pydantic_type
if overwrite_pydantic_type is not None
else cls._type,
else field_type,
__sample__=cls._sample,
alias=kwargs.pop("name", None),
name=None,
@ -803,3 +805,46 @@ class UUID(ModelFieldFactory, uuid.UUID):
"""
uuid_format = kwargs.get("uuid_format", "hex")
return sqlalchemy_uuid.UUID(uuid_format=uuid_format)
if TYPE_CHECKING: # pragma: nocover
T = TypeVar("T", bound=E)
def Enum(enum_class: Type[T], **kwargs: Any) -> T:
pass
else:
class Enum(ModelFieldFactory):
"""
Enum field factory that construct Field classes and populated their values.
"""
_type = E
_sample = None
def __new__( # type: ignore # noqa CFQ002
cls, *, enum_class: Type[E], **kwargs: Any
) -> BaseField:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def validate(cls, **kwargs: Any) -> None:
enum_class = kwargs.get("enum_class")
if enum_class is None or not isinstance(enum_class, EnumMeta):
raise ModelDefinitionError("Enum Field choices must be EnumType")
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
enum_cls = kwargs.get("enum_class")
return sqlalchemy.Enum(enum_cls)

View File

@ -1,5 +1,6 @@
import base64
import uuid
from enum import Enum
from typing import (
Any,
Callable,
@ -73,6 +74,14 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
new_kwargs = cls.reconvert_str_to_bytes(new_kwargs)
new_kwargs = cls.dump_all_json_fields_to_str(new_kwargs)
new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
new_kwargs = cls.translate_enum_columns(new_kwargs)
return new_kwargs
@classmethod
def translate_enum_columns(cls, new_kwargs: dict) -> dict:
for k, v in new_kwargs.items():
if isinstance(v, Enum):
new_kwargs[k] = v.name
return new_kwargs
@classmethod

View File

@ -698,9 +698,7 @@ class QuerySet(Generic[T]):
expr = sqlalchemy.func.count().select().select_from(expr)
if distinct:
pk_column_name = self.model.get_column_alias(self.model_meta.pkname)
expr_distinct = expr.group_by(pk_column_name).alias(
"subquery_for_group"
)
expr_distinct = expr.group_by(pk_column_name).alias("subquery_for_group")
expr = sqlalchemy.func.count().select().select_from(expr_distinct)
return await self.database.fetch_val(expr)

View File

@ -1,4 +1,5 @@
import datetime
from enum import Enum
import databases
import pydantic
@ -6,6 +7,7 @@ import pytest
import sqlalchemy
import ormar
from ormar import ModelDefinitionError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
@ -16,6 +18,11 @@ def time():
return datetime.datetime.now().time()
class MyEnum(Enum):
SMALL = 1
BIG = 2
class Example(ormar.Model):
class Meta:
tablename = "example"
@ -30,6 +37,17 @@ class Example(ormar.Model):
description: str = ormar.Text(nullable=True)
value: float = ormar.Float(nullable=True)
data: pydantic.Json = ormar.JSON(default={})
size: MyEnum = ormar.Enum(enum_class=MyEnum, default=MyEnum.SMALL)
class EnumExample(ormar.Model):
class Meta:
tablename = "enum_example"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
size: MyEnum = ormar.Enum(enum_class=MyEnum, default=MyEnum.SMALL)
@pytest.fixture(autouse=True, scope="module")
@ -40,6 +58,49 @@ def create_test_database():
metadata.drop_all(engine)
def test_proper_enum_column_type():
assert Example.__fields__["size"].type_ == MyEnum
def test_accepts_only_proper_enums():
class WrongEnum(Enum):
A = 1
B = 2
with pytest.raises(pydantic.ValidationError):
Example(size=WrongEnum.A)
@pytest.mark.asyncio
async def test_enum_bulk_operations():
async with database:
examples = [EnumExample(), EnumExample()]
await EnumExample.objects.bulk_create(examples)
check = await EnumExample.objects.all()
assert all(x.size == MyEnum.SMALL for x in check)
for x in check:
x.size = MyEnum.BIG
await EnumExample.objects.bulk_update(check)
check2 = await EnumExample.objects.all()
assert all(x.size == MyEnum.BIG for x in check2)
@pytest.mark.asyncio
async def test_enum_filter():
async with database:
examples = [EnumExample(), EnumExample(size=MyEnum.BIG)]
await EnumExample.objects.bulk_create(examples)
check = await EnumExample.objects.all(size=MyEnum.SMALL)
assert len(check) == 1
check = await EnumExample.objects.all(size=MyEnum.BIG)
assert len(check) == 1
@pytest.mark.asyncio
async def test_model_crud():
async with database:
@ -52,6 +113,13 @@ async def test_model_crud():
assert example.description is None
assert example.value is None
assert example.data == {}
assert example.size == MyEnum.SMALL
await example.update(data={"foo": 123}, value=123.456, size=MyEnum.BIG)
await example.load()
assert example.value == 123.456
assert example.data == {"foo": 123}
assert example.size == MyEnum.BIG
await example.update(data={"foo": 123}, value=123.456)
await example.load()
@ -59,3 +127,18 @@ async def test_model_crud():
assert example.data == {"foo": 123}
await example.delete()
@pytest.mark.asyncio
async def test_invalid_enum_field():
async with database:
with pytest.raises(ModelDefinitionError):
class Example2(ormar.Model):
class Meta:
tablename = "example"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
size: MyEnum = ormar.Enum(enum_class=[])

View File

@ -1,4 +1,5 @@
from typing import List, Optional
from enum import Enum
from typing import Optional
import databases
import pydantic
@ -19,6 +20,11 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class MySize(Enum):
SMALL = 0
BIG = 1
class Book(ormar.Model):
class Meta:
tablename = "books"
@ -45,6 +51,7 @@ class ToDo(ormar.Model):
text: str = ormar.String(max_length=500)
completed: bool = ormar.Boolean(default=False)
pairs: pydantic.Json = ormar.JSON(default=[])
size = ormar.Enum(enum_class=MySize, default=MySize.SMALL)
class Category(ormar.Model):
@ -77,6 +84,7 @@ class ItemConfig(ormar.Model):
id: Optional[int] = ormar.Integer(primary_key=True)
item_id: str = ormar.String(max_length=32, index=True)
pairs: pydantic.Json = ormar.JSON(default=["2", "3"])
size = ormar.Enum(enum_class=MySize, default=MySize.SMALL)
class QuerySetCls(QuerySet):
@ -343,6 +351,7 @@ async def test_bulk_update():
for todo in todoes:
todo.text = todo.text + "_1"
todo.completed = False
todo.size = MySize.BIG
await ToDo.objects.bulk_update(todoes)
@ -354,6 +363,7 @@ async def test_bulk_update():
for todo in todoes:
assert todo.text[-2:] == "_1"
assert todo.size == MySize.BIG
@pytest.mark.asyncio
@ -474,3 +484,13 @@ async def test_custom_queryset_cls():
await Customer(name="test").save()
c = await Customer.objects.first_or_404(name="test")
assert c.name == "test"
@pytest.mark.asyncio
async def test_filter_enum():
async with database:
it = ItemConfig(item_id="test_1")
await it.save()
it = await ItemConfig.objects.filter(size=MySize.SMALL).first()
assert it