fix for minimu and maximum in validators, added validator for choices and choices param for BaseField, include tests

This commit is contained in:
collerek
2020-09-22 20:50:24 +02:00
parent 9620452b44
commit ebd812bf00
6 changed files with 202 additions and 15 deletions

BIN
.coverage

Binary file not shown.

View File

@ -1,7 +1,7 @@
from typing import Any, List, Optional, TYPE_CHECKING, Union from typing import Any, List, Optional, TYPE_CHECKING, Union
import sqlalchemy import sqlalchemy
from pydantic import Field from pydantic import Field, typing
from ormar import ModelDefinitionError # noqa I101 from ormar import ModelDefinitionError # noqa I101
@ -23,6 +23,7 @@ class BaseField:
unique: bool unique: bool
pydantic_only: bool pydantic_only: bool
virtual: bool = False virtual: bool = False
choices: typing.Sequence
default: Any default: Any
server_default: Any server_default: Any

View File

@ -18,7 +18,7 @@ def is_field_nullable(
class ModelFieldFactory: class ModelFieldFactory:
_bases = None _bases = BaseField
_type = None _type = None
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
@ -40,6 +40,7 @@ class ModelFieldFactory:
pydantic_only=kwargs.pop("pydantic_only", False), pydantic_only=kwargs.pop("pydantic_only", False),
autoincrement=kwargs.pop("autoincrement", False), autoincrement=kwargs.pop("autoincrement", False),
column_type=cls.get_column_type(**kwargs), column_type=cls.get_column_type(**kwargs),
choices=set(kwargs.pop("choices", [])),
**kwargs **kwargs
) )
return type(cls.__name__, cls._bases, namespace) return type(cls.__name__, cls._bases, namespace)
@ -117,6 +118,8 @@ class Integer(ModelFieldFactory):
if k not in ["cls", "__class__", "kwargs"] if k not in ["cls", "__class__", "kwargs"]
}, },
} }
kwargs["ge"] = kwargs["minimum"]
kwargs["le"] = kwargs["maximum"]
return super().__new__(cls, **kwargs) return super().__new__(cls, **kwargs)
@classmethod @classmethod
@ -166,6 +169,8 @@ class Float(ModelFieldFactory):
if k not in ["cls", "__class__", "kwargs"] if k not in ["cls", "__class__", "kwargs"]
}, },
} }
kwargs["ge"] = kwargs["minimum"]
kwargs["le"] = kwargs["maximum"]
return super().__new__(cls, **kwargs) return super().__new__(cls, **kwargs)
@classmethod @classmethod
@ -251,6 +256,19 @@ class Decimal(ModelFieldFactory):
if k not in ["cls", "__class__", "kwargs"] if k not in ["cls", "__class__", "kwargs"]
}, },
} }
kwargs["ge"] = kwargs["minimum"]
kwargs["le"] = kwargs["maximum"]
if kwargs.get("max_digits"):
kwargs["scale"] = kwargs["max_digits"]
elif kwargs.get("scale"):
kwargs["max_digits"] = kwargs["scale"]
if kwargs.get("decimal_places"):
kwargs["precision"] = kwargs["decimal_places"]
elif kwargs.get("precision"):
kwargs["decimal_places"] = kwargs["precision"]
return super().__new__(cls, **kwargs) return super().__new__(cls, **kwargs)
@classmethod @classmethod

View File

@ -7,6 +7,7 @@ import sqlalchemy
from pydantic import BaseConfig from pydantic import BaseConfig
from pydantic.fields import FieldInfo, ModelField from pydantic.fields import FieldInfo, ModelField
import ormar # noqa I100
from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100 from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
@ -108,7 +109,7 @@ def create_pydantic_field(
type_=Optional[model], type_=Optional[model],
model_config=model.__config__, model_config=model.__config__,
required=False, required=False,
class_validators=model.__validators__, class_validators={},
) )
@ -252,9 +253,41 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
return Config return Config
def check_if_field_has_choices(field: BaseField) -> bool:
return hasattr(field, "choices") and field.choices
def model_initialized_and_has_model_fields(model: Type["Model"]) -> bool:
return hasattr(model, "Meta") and hasattr(model.Meta, "model_fields")
def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, Any]:
for field_name, field in cls.Meta.model_fields.items():
if check_if_field_has_choices(field):
value = values.get(field_name, ormar.Undefined)
if value is not ormar.Undefined and value not in field.choices:
raise ValueError(
f"{field_name}: '{values.get(field_name)}' "
f"not in allowed choices set:"
f" {field.choices}"
)
return values
def populate_choices_validators(model: Type["Model"], attrs: Dict) -> None: # noqa CCR001
if model_initialized_and_has_model_fields(model):
for _, field in model.Meta.model_fields.items():
if check_if_field_has_choices(field):
validators = attrs.get("__pre_root_validators__", [])
if choices_validator not in validators:
validators.append(choices_validator)
attrs["__pre_root_validators__"] = validators
class ModelMetaclass(pydantic.main.ModelMetaclass): class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
attrs["Config"] = get_pydantic_base_orm_config() attrs["Config"] = get_pydantic_base_orm_config()
attrs["__name__"] = name
attrs = extract_annotations_and_default_vals(attrs, bases) attrs = extract_annotations_and_default_vals(attrs, bases)
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
@ -262,11 +295,11 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
# breakpoint() # breakpoint()
if hasattr(new_model, "Meta"): if hasattr(new_model, "Meta"):
# attrs = extract_annotations_and_default_vals(attrs, bases)
new_model = populate_meta_orm_model_fields(attrs, new_model) new_model = populate_meta_orm_model_fields(attrs, new_model)
new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model)
new_model = populate_meta_sqlalchemy_table_if_required(new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model)
expand_reverse_relationships(new_model) expand_reverse_relationships(new_model)
populate_choices_validators(new_model, attrs)
if new_model.Meta.pkname not in attrs["__annotations__"]: if new_model.Meta.pkname not in attrs["__annotations__"]:
field_name = new_model.Meta.pkname field_name = new_model.Meta.pkname

View File

@ -1,6 +1,9 @@
import asyncio
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from pydantic import root_validator, validator
import ormar import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
@ -50,7 +53,7 @@ class Organisation(ormar.Model):
database = database database = database
id: ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
ident: ormar.String(max_length=100) ident: ormar.String(max_length=100, choices=['ACME Ltd', 'Other ltd'])
class Team(ormar.Model): class Team(ormar.Model):
@ -75,6 +78,25 @@ class Member(ormar.Model):
email: ormar.String(max_length=100) email: ormar.String(max_length=100)
country_name_choices = ("Canada", "Algeria", "United States")
country_taxed_choices = (True,)
country_country_code_choices = (-10, 1, 213, 1200)
class Country(ormar.Model):
class Meta:
tablename = "country"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=9, choices=country_name_choices, default="Canada", )
taxed: ormar.Boolean(choices=country_taxed_choices, default=True)
country_code: ormar.Integer(
minimum=0, maximum=1000, choices=country_country_code_choices, default=1
)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
@ -106,6 +128,109 @@ async def test_related_name():
await Cover.objects.create(album=album, title="The cover file") await Cover.objects.create(album=album, title="The cover file")
assert len(album.cover_pictures) == 1 assert len(album.cover_pictures) == 1
@pytest.mark.asyncio
async def test_model_choices():
"""Test that choices work properly for various types of fields."""
async with database:
# Test valid choices.
await asyncio.gather(
Country.objects.create(name="Canada", taxed=True, country_code=1),
Country.objects.create(name="Algeria", taxed=True, country_code=213),
Country.objects.create(name="Algeria"),
)
with pytest.raises(ValueError):
name, taxed, country_code = "Saudi Arabia", True, 1
assert all(
(
name not in country_name_choices,
taxed in country_taxed_choices,
country_code in country_country_code_choices,
)
)
await Country.objects.create(
name=name, taxed=taxed, country_code=country_code
)
with pytest.raises(ValueError):
name, taxed, country_code = "Algeria", False, 1
assert all(
(
name in country_name_choices,
taxed not in country_taxed_choices,
country_code in country_country_code_choices,
)
)
await Country.objects.create(
name=name, taxed=taxed, country_code=country_code
)
with pytest.raises(ValueError):
name, taxed, country_code = "Algeria", True, 967
assert all(
(
name in country_name_choices,
taxed in country_taxed_choices,
country_code not in country_country_code_choices,
)
)
await Country.objects.create(
name=name, taxed=taxed, country_code=country_code
)
with pytest.raises(ValueError):
name, taxed, country_code = (
"United States",
True,
1,
) # name is too long but is a valid choice
assert all(
(
name in country_name_choices,
taxed in country_taxed_choices,
country_code in country_country_code_choices,
)
)
await Country.objects.create(
name=name, taxed=taxed, country_code=country_code
)
with pytest.raises(ValueError):
name, taxed, country_code = (
"Algeria",
True,
-10,
) # country code is too small but is a valid choice
assert all(
(
name in country_name_choices,
taxed in country_taxed_choices,
country_code in country_country_code_choices,
)
)
await Country.objects.create(
name=name, taxed=taxed, country_code=country_code
)
with pytest.raises(ValueError):
name, taxed, country_code = (
"Algeria",
True,
1200,
) # country code is too large but is a valid choice
assert all(
(
name in country_name_choices,
taxed in country_taxed_choices,
country_code in country_country_code_choices,
)
)
await Country.objects.create(
name=name, taxed=taxed, country_code=country_code
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(): async def test_model_crud():
async with database: async with database:
@ -209,7 +334,6 @@ async def test_model_removal_from_relations():
assert len(album.tracks) == 1 assert len(album.tracks) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fk_filter(): async def test_fk_filter():
async with database: async with database:
@ -229,8 +353,8 @@ async def test_fk_filter():
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(album__name="Fantasies") .filter(album__name="Fantasies")
.all() .all()
) )
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
@ -238,8 +362,8 @@ async def test_fk_filter():
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(album__name__icontains="fan") .filter(album__name__icontains="fan")
.all() .all()
) )
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
@ -282,14 +406,22 @@ async def test_multiple_fk():
members = ( members = (
await Member.objects.select_related("team__org") await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd") .filter(team__org__ident="ACME Ltd")
.all() .all()
) )
assert len(members) == 4 assert len(members) == 4
for member in members: for member in members:
assert member.team.org.ident == "ACME Ltd" assert member.team.org.ident == "ACME Ltd"
@pytest.mark.asyncio
async def test_wrong_choices():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(ValueError):
await Organisation.objects.create(ident="Test 1")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pk_filter(): async def test_pk_filter():
async with database: async with database:
@ -303,8 +435,8 @@ async def test_pk_filter():
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(position=2, album__name="Test") .filter(position=2, album__name="Test")
.all() .all()
) )
assert len(tracks) == 1 assert len(tracks) == 1

View File

@ -28,6 +28,7 @@ class ExampleModel(Model):
test_json: fields.JSON(default={}) test_json: fields.JSON(default={})
test_bigint: fields.BigInteger(default=0) test_bigint: fields.BigInteger(default=0)
test_decimal: fields.Decimal(scale=10, precision=2) test_decimal: fields.Decimal(scale=10, precision=2)
test_decimal2: fields.Decimal(max_digits=10, decimal_places=2)
fields_to_check = [ fields_to_check = [
@ -55,7 +56,7 @@ class ExampleModel2(Model):
@pytest.fixture() @pytest.fixture()
def example(): def example():
return ExampleModel( return ExampleModel(
pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5) pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5), test_decimal2=decimal.Decimal(5.5)
) )
@ -73,6 +74,8 @@ def test_model_attribute_access(example):
assert example.test_float is None assert example.test_float is None
assert example.test_bigint == 0 assert example.test_bigint == 0
assert example.test_json == {} assert example.test_json == {}
assert example.test_decimal == 3.5
assert example.test_decimal2 == 5.5
example.test = 12 example.test = 12
assert example.test == 12 assert example.test == 12