diff --git a/ormar/models/helpers/validation.py b/ormar/models/helpers/validation.py index 617381f..e7d66fe 100644 --- a/ormar/models/helpers/validation.py +++ b/ormar/models/helpers/validation.py @@ -52,8 +52,8 @@ def convert_choices_if_needed( # noqa: CCR001 :param field: ormar field to check with choices :type field: BaseField - :param values: current values of the model to verify - :type values: Dict + :param value: current values of the model to verify + :type value: Dict :return: value, choices list :rtype: Tuple[Any, List] """ @@ -97,6 +97,8 @@ def validate_choices(field: "BaseField", value: Any) -> None: :type value: Any """ value, choices = convert_choices_if_needed(field=field, value=value) + if field.nullable: + choices.append(None) if value is not ormar.Undefined and value not in choices: raise ValueError( f"{field.name}: '{value}' " f"not in allowed choices set:" f" {choices}" diff --git a/tests/test_model_definition/test_models.py b/tests/test_model_definition/test_models.py index 71d12aa..583ba9d 100644 --- a/tests/test_model_definition/test_models.py +++ b/tests/test_model_definition/test_models.py @@ -134,6 +134,26 @@ class Country(ormar.Model): ) +class NullableCountry(ormar.Model): + class Meta: + tablename = "country2" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=9, choices=country_name_choices, nullable=True) + + +class NotNullableCountry(ormar.Model): + class Meta: + tablename = "country3" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=9, choices=country_name_choices, nullable=False) + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -538,6 +558,17 @@ async def test_model_choices(): await Country.objects.filter(name="Belize").update(name="Vietnam") +@pytest.mark.asyncio +async def test_nullable_field_model_choices(): + """Test that choices work properly for according to nullable setting""" + async with database: + c1 = await NullableCountry(name=None).save() + assert c1.name is None + + with pytest.raises(ValueError): + await NotNullableCountry(name=None).save() + + @pytest.mark.asyncio async def test_start_and_end_filters(): async with database: