@ -26,7 +26,7 @@ class UndefinedType: # pragma no cover
|
|||||||
|
|
||||||
Undefined = UndefinedType()
|
Undefined = UndefinedType()
|
||||||
|
|
||||||
__version__ = "0.3.1"
|
__version__ = "0.3.2"
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Integer",
|
"Integer",
|
||||||
"BigInteger",
|
"BigInteger",
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from typing import Any, List, Optional, TYPE_CHECKING, Union
|
from typing import Any, List, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from pydantic import Field
|
from pydantic import Field, typing
|
||||||
|
|
||||||
from ormar import ModelDefinitionError # noqa I101
|
from ormar import ModelDefinitionError # noqa I101
|
||||||
|
|
||||||
@ -23,6 +23,7 @@ class BaseField:
|
|||||||
unique: bool
|
unique: bool
|
||||||
pydantic_only: bool
|
pydantic_only: bool
|
||||||
virtual: bool = False
|
virtual: bool = False
|
||||||
|
choices: typing.Sequence
|
||||||
|
|
||||||
default: Any
|
default: Any
|
||||||
server_default: Any
|
server_default: Any
|
||||||
|
|||||||
@ -18,7 +18,7 @@ def is_field_nullable(
|
|||||||
|
|
||||||
|
|
||||||
class ModelFieldFactory:
|
class ModelFieldFactory:
|
||||||
_bases = None
|
_bases = BaseField
|
||||||
_type = None
|
_type = None
|
||||||
|
|
||||||
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
|
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
|
||||||
@ -40,6 +40,7 @@ class ModelFieldFactory:
|
|||||||
pydantic_only=kwargs.pop("pydantic_only", False),
|
pydantic_only=kwargs.pop("pydantic_only", False),
|
||||||
autoincrement=kwargs.pop("autoincrement", False),
|
autoincrement=kwargs.pop("autoincrement", False),
|
||||||
column_type=cls.get_column_type(**kwargs),
|
column_type=cls.get_column_type(**kwargs),
|
||||||
|
choices=set(kwargs.pop("choices", [])),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
return type(cls.__name__, cls._bases, namespace)
|
return type(cls.__name__, cls._bases, namespace)
|
||||||
@ -117,6 +118,8 @@ class Integer(ModelFieldFactory):
|
|||||||
if k not in ["cls", "__class__", "kwargs"]
|
if k not in ["cls", "__class__", "kwargs"]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
kwargs["ge"] = kwargs["minimum"]
|
||||||
|
kwargs["le"] = kwargs["maximum"]
|
||||||
return super().__new__(cls, **kwargs)
|
return super().__new__(cls, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -166,6 +169,8 @@ class Float(ModelFieldFactory):
|
|||||||
if k not in ["cls", "__class__", "kwargs"]
|
if k not in ["cls", "__class__", "kwargs"]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
kwargs["ge"] = kwargs["minimum"]
|
||||||
|
kwargs["le"] = kwargs["maximum"]
|
||||||
return super().__new__(cls, **kwargs)
|
return super().__new__(cls, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -251,6 +256,19 @@ class Decimal(ModelFieldFactory):
|
|||||||
if k not in ["cls", "__class__", "kwargs"]
|
if k not in ["cls", "__class__", "kwargs"]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
kwargs["ge"] = kwargs["minimum"]
|
||||||
|
kwargs["le"] = kwargs["maximum"]
|
||||||
|
|
||||||
|
if kwargs.get("max_digits"):
|
||||||
|
kwargs["scale"] = kwargs["max_digits"]
|
||||||
|
elif kwargs.get("scale"):
|
||||||
|
kwargs["max_digits"] = kwargs["scale"]
|
||||||
|
|
||||||
|
if kwargs.get("decimal_places"):
|
||||||
|
kwargs["precision"] = kwargs["decimal_places"]
|
||||||
|
elif kwargs.get("precision"):
|
||||||
|
kwargs["decimal_places"] = kwargs["precision"]
|
||||||
|
|
||||||
return super().__new__(cls, **kwargs)
|
return super().__new__(cls, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import sqlalchemy
|
|||||||
from pydantic import BaseConfig
|
from pydantic import BaseConfig
|
||||||
from pydantic.fields import FieldInfo, ModelField
|
from pydantic.fields import FieldInfo, ModelField
|
||||||
|
|
||||||
|
import ormar # noqa I100
|
||||||
from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100
|
from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100
|
||||||
from ormar.fields import BaseField
|
from ormar.fields import BaseField
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
@ -108,7 +109,7 @@ def create_pydantic_field(
|
|||||||
type_=Optional[model],
|
type_=Optional[model],
|
||||||
model_config=model.__config__,
|
model_config=model.__config__,
|
||||||
required=False,
|
required=False,
|
||||||
class_validators=model.__validators__,
|
class_validators={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -252,9 +253,43 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
|
|||||||
return Config
|
return Config
|
||||||
|
|
||||||
|
|
||||||
|
def check_if_field_has_choices(field: BaseField) -> bool:
|
||||||
|
return hasattr(field, "choices") and field.choices
|
||||||
|
|
||||||
|
|
||||||
|
def model_initialized_and_has_model_fields(model: Type["Model"]) -> bool:
|
||||||
|
return hasattr(model, "Meta") and hasattr(model.Meta, "model_fields")
|
||||||
|
|
||||||
|
|
||||||
|
def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
for field_name, field in cls.Meta.model_fields.items():
|
||||||
|
if check_if_field_has_choices(field):
|
||||||
|
value = values.get(field_name, ormar.Undefined)
|
||||||
|
if value is not ormar.Undefined and value not in field.choices:
|
||||||
|
raise ValueError(
|
||||||
|
f"{field_name}: '{values.get(field_name)}' "
|
||||||
|
f"not in allowed choices set:"
|
||||||
|
f" {field.choices}"
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def populate_choices_validators( # noqa CCR001
|
||||||
|
model: Type["Model"], attrs: Dict
|
||||||
|
) -> None:
|
||||||
|
if model_initialized_and_has_model_fields(model):
|
||||||
|
for _, field in model.Meta.model_fields.items():
|
||||||
|
if check_if_field_has_choices(field):
|
||||||
|
validators = attrs.get("__pre_root_validators__", [])
|
||||||
|
if choices_validator not in validators:
|
||||||
|
validators.append(choices_validator)
|
||||||
|
attrs["__pre_root_validators__"] = validators
|
||||||
|
|
||||||
|
|
||||||
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||||
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
||||||
attrs["Config"] = get_pydantic_base_orm_config()
|
attrs["Config"] = get_pydantic_base_orm_config()
|
||||||
|
attrs["__name__"] = name
|
||||||
attrs = extract_annotations_and_default_vals(attrs, bases)
|
attrs = extract_annotations_and_default_vals(attrs, bases)
|
||||||
new_model = super().__new__( # type: ignore
|
new_model = super().__new__( # type: ignore
|
||||||
mcs, name, bases, attrs
|
mcs, name, bases, attrs
|
||||||
@ -262,11 +297,11 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
|||||||
# breakpoint()
|
# breakpoint()
|
||||||
|
|
||||||
if hasattr(new_model, "Meta"):
|
if hasattr(new_model, "Meta"):
|
||||||
# attrs = extract_annotations_and_default_vals(attrs, bases)
|
|
||||||
new_model = populate_meta_orm_model_fields(attrs, new_model)
|
new_model = populate_meta_orm_model_fields(attrs, new_model)
|
||||||
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
||||||
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
|
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
|
||||||
expand_reverse_relationships(new_model)
|
expand_reverse_relationships(new_model)
|
||||||
|
populate_choices_validators(new_model, attrs)
|
||||||
|
|
||||||
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
||||||
field_name = new_model.Meta.pkname
|
field_name = new_model.Meta.pkname
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
from pydantic import root_validator, validator
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
|
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
|
||||||
@ -50,7 +53,7 @@ class Organisation(ormar.Model):
|
|||||||
database = database
|
database = database
|
||||||
|
|
||||||
id: ormar.Integer(primary_key=True)
|
id: ormar.Integer(primary_key=True)
|
||||||
ident: ormar.String(max_length=100)
|
ident: ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
|
||||||
|
|
||||||
|
|
||||||
class Team(ormar.Model):
|
class Team(ormar.Model):
|
||||||
@ -106,6 +109,7 @@ async def test_related_name():
|
|||||||
await Cover.objects.create(album=album, title="The cover file")
|
await Cover.objects.create(album=album, title="The cover file")
|
||||||
assert len(album.cover_pictures) == 1
|
assert len(album.cover_pictures) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_model_crud():
|
async def test_model_crud():
|
||||||
async with database:
|
async with database:
|
||||||
@ -209,7 +213,6 @@ async def test_model_removal_from_relations():
|
|||||||
assert len(album.tracks) == 1
|
assert len(album.tracks) == 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fk_filter():
|
async def test_fk_filter():
|
||||||
async with database:
|
async with database:
|
||||||
@ -223,9 +226,13 @@ async def test_fk_filter():
|
|||||||
await Track.objects.create(album=malibu, title="The Waters", position=3)
|
await Track.objects.create(album=malibu, title="The Waters", position=3)
|
||||||
|
|
||||||
fantasies = await Album.objects.create(name="Fantasies")
|
fantasies = await Album.objects.create(name="Fantasies")
|
||||||
await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
|
await Track.objects.create(
|
||||||
|
album=fantasies, title="Help I'm Alive", position=1
|
||||||
|
)
|
||||||
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
|
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
|
||||||
await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)
|
await Track.objects.create(
|
||||||
|
album=fantasies, title="Satellite Mind", position=3
|
||||||
|
)
|
||||||
|
|
||||||
tracks = (
|
tracks = (
|
||||||
await Track.objects.select_related("album")
|
await Track.objects.select_related("album")
|
||||||
@ -253,7 +260,9 @@ async def test_fk_filter():
|
|||||||
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
|
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
|
||||||
assert len(tracks) == 3
|
assert len(tracks) == 3
|
||||||
|
|
||||||
tracks = await Track.objects.filter(album=malibu).select_related("album").all()
|
tracks = (
|
||||||
|
await Track.objects.filter(album=malibu).select_related("album").all()
|
||||||
|
)
|
||||||
assert len(tracks) == 3
|
assert len(tracks) == 3
|
||||||
for track in tracks:
|
for track in tracks:
|
||||||
assert track.album.name == "Malibu%"
|
assert track.album.name == "Malibu%"
|
||||||
@ -290,15 +299,27 @@ async def test_multiple_fk():
|
|||||||
assert member.team.org.ident == "ACME Ltd"
|
assert member.team.org.ident == "ACME Ltd"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wrong_choices():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await Organisation.objects.create(ident="Test 1")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pk_filter():
|
async def test_pk_filter():
|
||||||
async with database:
|
async with database:
|
||||||
async with database.transaction(force_rollback=True):
|
async with database.transaction(force_rollback=True):
|
||||||
fantasies = await Album.objects.create(name="Test")
|
fantasies = await Album.objects.create(name="Test")
|
||||||
track = await Track.objects.create(album=fantasies, title="Test1", position=1)
|
track = await Track.objects.create(
|
||||||
|
album=fantasies, title="Test1", position=1
|
||||||
|
)
|
||||||
await Track.objects.create(album=fantasies, title="Test2", position=2)
|
await Track.objects.create(album=fantasies, title="Test2", position=2)
|
||||||
await Track.objects.create(album=fantasies, title="Test3", position=3)
|
await Track.objects.create(album=fantasies, title="Test3", position=3)
|
||||||
tracks = await Track.objects.select_related("album").filter(pk=track.pk).all()
|
tracks = (
|
||||||
|
await Track.objects.select_related("album").filter(pk=track.pk).all()
|
||||||
|
)
|
||||||
assert len(tracks) == 1
|
assert len(tracks) == 1
|
||||||
|
|
||||||
tracks = (
|
tracks = (
|
||||||
@ -314,7 +335,9 @@ async def test_limit_and_offset():
|
|||||||
async with database:
|
async with database:
|
||||||
async with database.transaction(force_rollback=True):
|
async with database.transaction(force_rollback=True):
|
||||||
fantasies = await Album.objects.create(name="Limitless")
|
fantasies = await Album.objects.create(name="Limitless")
|
||||||
await Track.objects.create(id=None, album=fantasies, title="Sample", position=1)
|
await Track.objects.create(
|
||||||
|
id=None, album=fantasies, title="Sample", position=1
|
||||||
|
)
|
||||||
await Track.objects.create(album=fantasies, title="Sample2", position=2)
|
await Track.objects.create(album=fantasies, title="Sample2", position=2)
|
||||||
await Track.objects.create(album=fantasies, title="Sample3", position=3)
|
await Track.objects.create(album=fantasies, title="Sample3", position=3)
|
||||||
|
|
||||||
|
|||||||
@ -126,7 +126,9 @@ async def test_quering_of_the_m2m_models(cleanup):
|
|||||||
category = await Category.objects.filter(posts__author=guido).get()
|
category = await Category.objects.filter(posts__author=guido).get()
|
||||||
assert category == news
|
assert category == news
|
||||||
# or:
|
# or:
|
||||||
category2 = await Category.objects.filter(posts__author__first_name="Guido").get()
|
category2 = await Category.objects.filter(
|
||||||
|
posts__author__first_name="Guido"
|
||||||
|
).get()
|
||||||
assert category2 == news
|
assert category2 == news
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -28,6 +28,7 @@ class ExampleModel(Model):
|
|||||||
test_json: fields.JSON(default={})
|
test_json: fields.JSON(default={})
|
||||||
test_bigint: fields.BigInteger(default=0)
|
test_bigint: fields.BigInteger(default=0)
|
||||||
test_decimal: fields.Decimal(scale=10, precision=2)
|
test_decimal: fields.Decimal(scale=10, precision=2)
|
||||||
|
test_decimal2: fields.Decimal(max_digits=10, decimal_places=2)
|
||||||
|
|
||||||
|
|
||||||
fields_to_check = [
|
fields_to_check = [
|
||||||
@ -55,7 +56,11 @@ class ExampleModel2(Model):
|
|||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def example():
|
def example():
|
||||||
return ExampleModel(
|
return ExampleModel(
|
||||||
pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)
|
pk=1,
|
||||||
|
test_string="test",
|
||||||
|
test_bool=True,
|
||||||
|
test_decimal=decimal.Decimal(3.5),
|
||||||
|
test_decimal2=decimal.Decimal(5.5),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -73,6 +78,8 @@ def test_model_attribute_access(example):
|
|||||||
assert example.test_float is None
|
assert example.test_float is None
|
||||||
assert example.test_bigint == 0
|
assert example.test_bigint == 0
|
||||||
assert example.test_json == {}
|
assert example.test_json == {}
|
||||||
|
assert example.test_decimal == 3.5
|
||||||
|
assert example.test_decimal2 == 5.5
|
||||||
|
|
||||||
example.test = 12
|
example.test = 12
|
||||||
assert example.test == 12
|
assert example.test == 12
|
||||||
@ -112,6 +119,7 @@ def test_sqlalchemy_table_is_created(example):
|
|||||||
|
|
||||||
def test_no_pk_in_model_definition():
|
def test_no_pk_in_model_definition():
|
||||||
with pytest.raises(ModelDefinitionError):
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
class ExampleModel2(Model):
|
class ExampleModel2(Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
tablename = "example3"
|
tablename = "example3"
|
||||||
@ -122,6 +130,7 @@ def test_no_pk_in_model_definition():
|
|||||||
|
|
||||||
def test_two_pks_in_model_definition():
|
def test_two_pks_in_model_definition():
|
||||||
with pytest.raises(ModelDefinitionError):
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
class ExampleModel2(Model):
|
class ExampleModel2(Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
tablename = "example3"
|
tablename = "example3"
|
||||||
@ -133,6 +142,7 @@ def test_two_pks_in_model_definition():
|
|||||||
|
|
||||||
def test_setting_pk_column_as_pydantic_only_in_model_definition():
|
def test_setting_pk_column_as_pydantic_only_in_model_definition():
|
||||||
with pytest.raises(ModelDefinitionError):
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
class ExampleModel2(Model):
|
class ExampleModel2(Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
tablename = "example4"
|
tablename = "example4"
|
||||||
@ -143,6 +153,7 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition():
|
|||||||
|
|
||||||
def test_decimal_error_in_model_definition():
|
def test_decimal_error_in_model_definition():
|
||||||
with pytest.raises(ModelDefinitionError):
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
class ExampleModel2(Model):
|
class ExampleModel2(Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
tablename = "example5"
|
tablename = "example5"
|
||||||
@ -153,6 +164,7 @@ def test_decimal_error_in_model_definition():
|
|||||||
|
|
||||||
def test_string_error_in_model_definition():
|
def test_string_error_in_model_definition():
|
||||||
with pytest.raises(ModelDefinitionError):
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
class ExampleModel2(Model):
|
class ExampleModel2(Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
tablename = "example6"
|
tablename = "example6"
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pydantic
|
import pydantic
|
||||||
@ -47,6 +48,27 @@ class Product(ormar.Model):
|
|||||||
last_delivery: ormar.Date(default=datetime.now)
|
last_delivery: ormar.Date(default=datetime.now)
|
||||||
|
|
||||||
|
|
||||||
|
country_name_choices = ("Canada", "Algeria", "United States")
|
||||||
|
country_taxed_choices = (True,)
|
||||||
|
country_country_code_choices = (-10, 1, 213, 1200)
|
||||||
|
|
||||||
|
|
||||||
|
class Country(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "country"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(
|
||||||
|
max_length=9, choices=country_name_choices, default="Canada",
|
||||||
|
)
|
||||||
|
taxed: ormar.Boolean(choices=country_taxed_choices, default=True)
|
||||||
|
country_code: ormar.Integer(
|
||||||
|
minimum=0, maximum=1000, choices=country_country_code_choices, default=1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def event_loop():
|
def event_loop():
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@ -242,7 +264,9 @@ async def test_model_limit_with_filter():
|
|||||||
await User.objects.create(name="Tom")
|
await User.objects.create(name="Tom")
|
||||||
await User.objects.create(name="Tom")
|
await User.objects.create(name="Tom")
|
||||||
|
|
||||||
assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2
|
assert (
|
||||||
|
len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -268,3 +292,85 @@ async def test_model_first():
|
|||||||
assert await User.objects.filter(name="Jane").first() == jane
|
assert await User.objects.filter(name="Jane").first() == jane
|
||||||
with pytest.raises(NoMatch):
|
with pytest.raises(NoMatch):
|
||||||
await User.objects.filter(name="Lucy").first()
|
await User.objects.filter(name="Lucy").first()
|
||||||
|
|
||||||
|
|
||||||
|
def not_contains(a, b):
|
||||||
|
return a not in b
|
||||||
|
|
||||||
|
|
||||||
|
def contains(a, b):
|
||||||
|
return a in b
|
||||||
|
|
||||||
|
|
||||||
|
def check_choices(values: tuple, ops: List):
|
||||||
|
ops_dict = {"in": contains, "out": not_contains}
|
||||||
|
checks = (country_name_choices, country_taxed_choices, country_country_code_choices)
|
||||||
|
assert all(
|
||||||
|
[ops_dict[op](value, check) for value, op, check in zip(values, ops, checks)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_choices():
|
||||||
|
"""Test that choices work properly for various types of fields."""
|
||||||
|
async with database:
|
||||||
|
# Test valid choices.
|
||||||
|
await asyncio.gather(
|
||||||
|
Country.objects.create(name="Canada", taxed=True, country_code=1),
|
||||||
|
Country.objects.create(name="Algeria", taxed=True, country_code=213),
|
||||||
|
Country.objects.create(name="Algeria"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
name, taxed, country_code = "Saudi Arabia", True, 1
|
||||||
|
check_choices((name, taxed, country_code), ["out", "in", "in"])
|
||||||
|
await Country.objects.create(
|
||||||
|
name=name, taxed=taxed, country_code=country_code
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
name, taxed, country_code = "Algeria", False, 1
|
||||||
|
check_choices((name, taxed, country_code), ["in", "out", "in"])
|
||||||
|
await Country.objects.create(
|
||||||
|
name=name, taxed=taxed, country_code=country_code
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
name, taxed, country_code = "Algeria", True, 967
|
||||||
|
check_choices((name, taxed, country_code), ["in", "in", "out"])
|
||||||
|
await Country.objects.create(
|
||||||
|
name=name, taxed=taxed, country_code=country_code
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
name, taxed, country_code = (
|
||||||
|
"United States",
|
||||||
|
True,
|
||||||
|
1,
|
||||||
|
) # name is too long but is a valid choice
|
||||||
|
check_choices((name, taxed, country_code), ["in", "in", "in"])
|
||||||
|
await Country.objects.create(
|
||||||
|
name=name, taxed=taxed, country_code=country_code
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
name, taxed, country_code = (
|
||||||
|
"Algeria",
|
||||||
|
True,
|
||||||
|
-10,
|
||||||
|
) # country code is too small but is a valid choice
|
||||||
|
check_choices((name, taxed, country_code), ["in", "in", "in"])
|
||||||
|
await Country.objects.create(
|
||||||
|
name=name, taxed=taxed, country_code=country_code
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
name, taxed, country_code = (
|
||||||
|
"Algeria",
|
||||||
|
True,
|
||||||
|
1200,
|
||||||
|
) # country code is too large but is a valid choice
|
||||||
|
check_choices((name, taxed, country_code), ["in", "in", "in"])
|
||||||
|
await Country.objects.create(
|
||||||
|
name=name, taxed=taxed, country_code=country_code
|
||||||
|
)
|
||||||
|
|||||||
@ -94,6 +94,7 @@ async def create_data():
|
|||||||
await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
|
await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
|
||||||
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
|
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_model_multiple_instances_of_same_table_in_schema():
|
async def test_model_multiple_instances_of_same_table_in_schema():
|
||||||
async with database:
|
async with database:
|
||||||
|
|||||||
@ -44,4 +44,4 @@ async def test_pk_1():
|
|||||||
async def test_pk_2():
|
async def test_pk_2():
|
||||||
async with database:
|
async with database:
|
||||||
model = await Model.objects.create(name="NAME")
|
model = await Model.objects.create(name="NAME")
|
||||||
assert await Model.objects.all() == [model]
|
assert await Model.objects.all() == [model]
|
||||||
|
|||||||
@ -112,10 +112,14 @@ async def test_model_multiple_instances_of_same_table_in_schema():
|
|||||||
assert classes[0].students[0].schoolclass.name == "Math"
|
assert classes[0].students[0].schoolclass.name == "Math"
|
||||||
assert classes[0].students[0].schoolclass.department.name is None
|
assert classes[0].students[0].schoolclass.department.name is None
|
||||||
await classes[0].students[0].schoolclass.department.load()
|
await classes[0].students[0].schoolclass.department.load()
|
||||||
assert classes[0].students[0].schoolclass.department.name == "Math Department"
|
assert (
|
||||||
|
classes[0].students[0].schoolclass.department.name == "Math Department"
|
||||||
|
)
|
||||||
|
|
||||||
await classes[1].students[0].schoolclass.department.load()
|
await classes[1].students[0].schoolclass.department.load()
|
||||||
assert classes[1].students[0].schoolclass.department.name == "Law Department"
|
assert (
|
||||||
|
classes[1].students[0].schoolclass.department.name == "Law Department"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user