From ebd812bf00410a735a2037dbd07ff3cab8be4f05 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 22 Sep 2020 20:50:24 +0200 Subject: [PATCH 1/3] fix for minimu and maximum in validators, added validator for choices and choices param for BaseField, include tests --- .coverage | Bin 53248 -> 53248 bytes ormar/fields/base.py | 3 +- ormar/fields/model_fields.py | 20 ++++- ormar/models/metaclass.py | 37 +++++++- tests/test_foreign_keys.py | 152 ++++++++++++++++++++++++++++++--- tests/test_model_definition.py | 5 +- 6 files changed, 202 insertions(+), 15 deletions(-) diff --git a/.coverage b/.coverage index 6cf1da0ed99ceec5070504274b047b8b2788f153..9f9dda5dc04a39e78520795f599497b87aac6e2c 100644 GIT binary patch delta 336 zcmV-W0k8gmpaX!Q1F$nM3oivpz4#P*TAF5BU%358MyY55*6(51S8; z4`UBo4>u184*(AK4($%)4%-g74ww#w4s8x(4pqL34sP7dN<#E&v(B0&UeoD@#pjTvz`NUU=GX>m?1DjVEn(|Zv4N0oB#LjZg>9c zojdpTp8vaZe*Smz@6NZ+AG5`cNdYe{3k2^s0&Y%GR delta 313 zcmV-90mlA-paX!Q1F$nM3o$n!F*!OgG&(gjvpz4#P*T_c5BU%358MyY55*6(51S8; z4`mNs4>=DG4+akW4)YG`4&e^J4xSE+4s#A_4qgsT4n7Vm4jK*(4fGA$4bu(94Ur9T z4O9&)4HXRk4Dt-vvk?%043oi(CIQ)#=!{Jc@_#ya?%cU^=g#H-Pv0Mtj*l=0+yDOV L`+q*O#g98c^1_+? diff --git a/ormar/fields/base.py b/ormar/fields/base.py index da9a0d8..1be4b63 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -1,7 +1,7 @@ from typing import Any, List, Optional, TYPE_CHECKING, Union import sqlalchemy -from pydantic import Field +from pydantic import Field, typing from ormar import ModelDefinitionError # noqa I101 @@ -23,6 +23,7 @@ class BaseField: unique: bool pydantic_only: bool virtual: bool = False + choices: typing.Sequence default: Any server_default: Any diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 616b37f..388f229 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -18,7 +18,7 @@ def is_field_nullable( class ModelFieldFactory: - _bases = None + _bases = BaseField _type = None def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: @@ -40,6 +40,7 @@ class ModelFieldFactory: pydantic_only=kwargs.pop("pydantic_only", False), autoincrement=kwargs.pop("autoincrement", False), column_type=cls.get_column_type(**kwargs), + choices=set(kwargs.pop("choices", [])), **kwargs ) return type(cls.__name__, cls._bases, namespace) @@ -117,6 +118,8 @@ class Integer(ModelFieldFactory): if k not in ["cls", "__class__", "kwargs"] }, } + kwargs["ge"] = kwargs["minimum"] + kwargs["le"] = kwargs["maximum"] return super().__new__(cls, **kwargs) @classmethod @@ -166,6 +169,8 @@ class Float(ModelFieldFactory): if k not in ["cls", "__class__", "kwargs"] }, } + kwargs["ge"] = kwargs["minimum"] + kwargs["le"] = kwargs["maximum"] return super().__new__(cls, **kwargs) @classmethod @@ -251,6 +256,19 @@ class Decimal(ModelFieldFactory): if k not in ["cls", "__class__", "kwargs"] }, } + kwargs["ge"] = kwargs["minimum"] + kwargs["le"] = kwargs["maximum"] + + if kwargs.get("max_digits"): + kwargs["scale"] = kwargs["max_digits"] + elif kwargs.get("scale"): + kwargs["max_digits"] = kwargs["scale"] + + if kwargs.get("decimal_places"): + kwargs["precision"] = kwargs["decimal_places"] + elif kwargs.get("precision"): + kwargs["decimal_places"] = kwargs["precision"] + return super().__new__(cls, **kwargs) @classmethod diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 08475ae..b4d88ee 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -7,6 +7,7 @@ import sqlalchemy from pydantic import BaseConfig from pydantic.fields import FieldInfo, ModelField +import ormar # noqa I100 from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100 from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField @@ -108,7 +109,7 @@ def create_pydantic_field( type_=Optional[model], model_config=model.__config__, required=False, - class_validators=model.__validators__, + class_validators={}, ) @@ -252,9 +253,41 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]: return Config +def check_if_field_has_choices(field: BaseField) -> bool: + return hasattr(field, "choices") and field.choices + + +def model_initialized_and_has_model_fields(model: Type["Model"]) -> bool: + return hasattr(model, "Meta") and hasattr(model.Meta, "model_fields") + + +def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, Any]: + for field_name, field in cls.Meta.model_fields.items(): + if check_if_field_has_choices(field): + value = values.get(field_name, ormar.Undefined) + if value is not ormar.Undefined and value not in field.choices: + raise ValueError( + f"{field_name}: '{values.get(field_name)}' " + f"not in allowed choices set:" + f" {field.choices}" + ) + return values + + +def populate_choices_validators(model: Type["Model"], attrs: Dict) -> None: # noqa CCR001 + if model_initialized_and_has_model_fields(model): + for _, field in model.Meta.model_fields.items(): + if check_if_field_has_choices(field): + validators = attrs.get("__pre_root_validators__", []) + if choices_validator not in validators: + validators.append(choices_validator) + attrs["__pre_root_validators__"] = validators + + class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: attrs["Config"] = get_pydantic_base_orm_config() + attrs["__name__"] = name attrs = extract_annotations_and_default_vals(attrs, bases) new_model = super().__new__( # type: ignore mcs, name, bases, attrs @@ -262,11 +295,11 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): # breakpoint() if hasattr(new_model, "Meta"): - # attrs = extract_annotations_and_default_vals(attrs, bases) new_model = populate_meta_orm_model_fields(attrs, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model) expand_reverse_relationships(new_model) + populate_choices_validators(new_model, attrs) if new_model.Meta.pkname not in attrs["__annotations__"]: field_name = new_model.Meta.pkname diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 75bb1fd..6a59a12 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,6 +1,9 @@ +import asyncio + import databases import pytest import sqlalchemy +from pydantic import root_validator, validator import ormar from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError @@ -50,7 +53,7 @@ class Organisation(ormar.Model): database = database id: ormar.Integer(primary_key=True) - ident: ormar.String(max_length=100) + ident: ormar.String(max_length=100, choices=['ACME Ltd', 'Other ltd']) class Team(ormar.Model): @@ -75,6 +78,25 @@ class Member(ormar.Model): email: ormar.String(max_length=100) +country_name_choices = ("Canada", "Algeria", "United States") +country_taxed_choices = (True,) +country_country_code_choices = (-10, 1, 213, 1200) + + +class Country(ormar.Model): + class Meta: + tablename = "country" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=9, choices=country_name_choices, default="Canada", ) + taxed: ormar.Boolean(choices=country_taxed_choices, default=True) + country_code: ormar.Integer( + minimum=0, maximum=1000, choices=country_country_code_choices, default=1 + ) + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -106,6 +128,109 @@ async def test_related_name(): await Cover.objects.create(album=album, title="The cover file") assert len(album.cover_pictures) == 1 + +@pytest.mark.asyncio +async def test_model_choices(): + """Test that choices work properly for various types of fields.""" + async with database: + # Test valid choices. + await asyncio.gather( + Country.objects.create(name="Canada", taxed=True, country_code=1), + Country.objects.create(name="Algeria", taxed=True, country_code=213), + Country.objects.create(name="Algeria"), + ) + + with pytest.raises(ValueError): + name, taxed, country_code = "Saudi Arabia", True, 1 + assert all( + ( + name not in country_name_choices, + taxed in country_taxed_choices, + country_code in country_country_code_choices, + ) + ) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = "Algeria", False, 1 + assert all( + ( + name in country_name_choices, + taxed not in country_taxed_choices, + country_code in country_country_code_choices, + ) + ) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = "Algeria", True, 967 + assert all( + ( + name in country_name_choices, + taxed in country_taxed_choices, + country_code not in country_country_code_choices, + ) + ) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = ( + "United States", + True, + 1, + ) # name is too long but is a valid choice + assert all( + ( + name in country_name_choices, + taxed in country_taxed_choices, + country_code in country_country_code_choices, + ) + ) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = ( + "Algeria", + True, + -10, + ) # country code is too small but is a valid choice + assert all( + ( + name in country_name_choices, + taxed in country_taxed_choices, + country_code in country_country_code_choices, + ) + ) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = ( + "Algeria", + True, + 1200, + ) # country code is too large but is a valid choice + assert all( + ( + name in country_name_choices, + taxed in country_taxed_choices, + country_code in country_country_code_choices, + ) + ) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + @pytest.mark.asyncio async def test_model_crud(): async with database: @@ -209,7 +334,6 @@ async def test_model_removal_from_relations(): assert len(album.tracks) == 1 - @pytest.mark.asyncio async def test_fk_filter(): async with database: @@ -229,8 +353,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -238,8 +362,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -282,14 +406,22 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: assert member.team.org.ident == "ACME Ltd" +@pytest.mark.asyncio +async def test_wrong_choices(): + async with database: + async with database.transaction(force_rollback=True): + with pytest.raises(ValueError): + await Organisation.objects.create(ident="Test 1") + + @pytest.mark.asyncio async def test_pk_filter(): async with database: @@ -303,8 +435,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index c374585..6cc266e 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -28,6 +28,7 @@ class ExampleModel(Model): test_json: fields.JSON(default={}) test_bigint: fields.BigInteger(default=0) test_decimal: fields.Decimal(scale=10, precision=2) + test_decimal2: fields.Decimal(max_digits=10, decimal_places=2) fields_to_check = [ @@ -55,7 +56,7 @@ class ExampleModel2(Model): @pytest.fixture() def example(): return ExampleModel( - pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5) + pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5), test_decimal2=decimal.Decimal(5.5) ) @@ -73,6 +74,8 @@ def test_model_attribute_access(example): assert example.test_float is None assert example.test_bigint == 0 assert example.test_json == {} + assert example.test_decimal == 3.5 + assert example.test_decimal2 == 5.5 example.test = 12 assert example.test == 12 From 798475ae5f1ce96eb1bf85ece8ec2f74dbac79db Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 23 Sep 2020 10:09:15 +0200 Subject: [PATCH 2/3] refactor and move tests --- .coverage | Bin 53248 -> 53248 bytes ormar/models/metaclass.py | 4 +- tests/test_foreign_keys.py | 163 +++++----------------------- tests/test_many_to_many.py | 4 +- tests/test_model_definition.py | 11 +- tests/test_models.py | 108 +++++++++++++++++- tests/test_more_same_table_joins.py | 1 + tests/test_non_integer_pkey.py | 2 +- tests/test_same_table_joins.py | 8 +- 9 files changed, 158 insertions(+), 143 deletions(-) diff --git a/.coverage b/.coverage index 9f9dda5dc04a39e78520795f599497b87aac6e2c..550b22b2efd8e90ea4e14326f3bf49df1cbfc4df 100644 GIT binary patch delta 174 zcmV;f08#&dpaX!Q1F$qN3Ns)vFgh?dIyEt~J}<~n0mPFBfI<(74QdTd4Hpdo4C@Tu z48*e$5Nr&S!HiQjG$aH832G$n-F)-S_c`Ck=g*(#_x!$n=XZDhw)?hi^KE|q{LNqf z@|)jh=3zab&u2Kp84i;@jXgCpAp`*lXd!wx-+a$^zWL5~{(L@v{@-sm{@=gN|9f}0 cJOB00oqK!F|J^x1|2z42=iBEGv%if&LbME4Pyhe` delta 177 zcmV;i08amapaX!Q1F$qN3Nj!vFgi3iIy5k|J}<~n0l||8fI<&?4O$I74HOLi4CoBo z48gMz5NQmP#EerZCnN*`2~Z^VZoc{E`+WZVd4A9D+jo9<=Wn}j+cw|k=g;5#^ZB!$19M;w%n+C%FhgMczu#{Bzki$m f_wH_Y{_CAP_x7IuyK{d2ck=Jfx6dE5zl}jcum@Q< diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index b4d88ee..bf8c245 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -274,7 +274,9 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A return values -def populate_choices_validators(model: Type["Model"], attrs: Dict) -> None: # noqa CCR001 +def populate_choices_validators( # noqa CCR001 + model: Type["Model"], attrs: Dict +) -> None: if model_initialized_and_has_model_fields(model): for _, field in model.Meta.model_fields.items(): if check_if_field_has_choices(field): diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 6a59a12..b2efbc6 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -53,7 +53,7 @@ class Organisation(ormar.Model): database = database id: ormar.Integer(primary_key=True) - ident: ormar.String(max_length=100, choices=['ACME Ltd', 'Other ltd']) + ident: ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"]) class Team(ormar.Model): @@ -78,25 +78,6 @@ class Member(ormar.Model): email: ormar.String(max_length=100) -country_name_choices = ("Canada", "Algeria", "United States") -country_taxed_choices = (True,) -country_country_code_choices = (-10, 1, 213, 1200) - - -class Country(ormar.Model): - class Meta: - tablename = "country" - metadata = metadata - database = database - - id: ormar.Integer(primary_key=True) - name: ormar.String(max_length=9, choices=country_name_choices, default="Canada", ) - taxed: ormar.Boolean(choices=country_taxed_choices, default=True) - country_code: ormar.Integer( - minimum=0, maximum=1000, choices=country_country_code_choices, default=1 - ) - - @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -129,108 +110,6 @@ async def test_related_name(): assert len(album.cover_pictures) == 1 -@pytest.mark.asyncio -async def test_model_choices(): - """Test that choices work properly for various types of fields.""" - async with database: - # Test valid choices. - await asyncio.gather( - Country.objects.create(name="Canada", taxed=True, country_code=1), - Country.objects.create(name="Algeria", taxed=True, country_code=213), - Country.objects.create(name="Algeria"), - ) - - with pytest.raises(ValueError): - name, taxed, country_code = "Saudi Arabia", True, 1 - assert all( - ( - name not in country_name_choices, - taxed in country_taxed_choices, - country_code in country_country_code_choices, - ) - ) - await Country.objects.create( - name=name, taxed=taxed, country_code=country_code - ) - - with pytest.raises(ValueError): - name, taxed, country_code = "Algeria", False, 1 - assert all( - ( - name in country_name_choices, - taxed not in country_taxed_choices, - country_code in country_country_code_choices, - ) - ) - await Country.objects.create( - name=name, taxed=taxed, country_code=country_code - ) - - with pytest.raises(ValueError): - name, taxed, country_code = "Algeria", True, 967 - assert all( - ( - name in country_name_choices, - taxed in country_taxed_choices, - country_code not in country_country_code_choices, - ) - ) - await Country.objects.create( - name=name, taxed=taxed, country_code=country_code - ) - - with pytest.raises(ValueError): - name, taxed, country_code = ( - "United States", - True, - 1, - ) # name is too long but is a valid choice - assert all( - ( - name in country_name_choices, - taxed in country_taxed_choices, - country_code in country_country_code_choices, - ) - ) - await Country.objects.create( - name=name, taxed=taxed, country_code=country_code - ) - - with pytest.raises(ValueError): - name, taxed, country_code = ( - "Algeria", - True, - -10, - ) # country code is too small but is a valid choice - assert all( - ( - name in country_name_choices, - taxed in country_taxed_choices, - country_code in country_country_code_choices, - ) - ) - await Country.objects.create( - name=name, taxed=taxed, country_code=country_code - ) - - with pytest.raises(ValueError): - name, taxed, country_code = ( - "Algeria", - True, - 1200, - ) # country code is too large but is a valid choice - assert all( - ( - name in country_name_choices, - taxed in country_taxed_choices, - country_code in country_country_code_choices, - ) - ) - await Country.objects.create( - name=name, taxed=taxed, country_code=country_code - ) - - @pytest.mark.asyncio async def test_model_crud(): async with database: @@ -347,14 +226,18 @@ async def test_fk_filter(): await Track.objects.create(album=malibu, title="The Waters", position=3) fantasies = await Album.objects.create(name="Fantasies") - await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create( + album=fantasies, title="Help I'm Alive", position=1 + ) await Track.objects.create(album=fantasies, title="Sick Muse", position=2) - await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + await Track.objects.create( + album=fantasies, title="Satellite Mind", position=3 + ) tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -362,8 +245,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -377,7 +260,9 @@ async def test_fk_filter(): tracks = await Track.objects.filter(album__name__contains="Malibu%").all() assert len(tracks) == 3 - tracks = await Track.objects.filter(album=malibu).select_related("album").all() + tracks = ( + await Track.objects.filter(album=malibu).select_related("album").all() + ) assert len(tracks) == 3 for track in tracks: assert track.album.name == "Malibu%" @@ -406,8 +291,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -427,16 +312,20 @@ async def test_pk_filter(): async with database: async with database.transaction(force_rollback=True): fantasies = await Album.objects.create(name="Test") - track = await Track.objects.create(album=fantasies, title="Test1", position=1) + track = await Track.objects.create( + album=fantasies, title="Test1", position=1 + ) await Track.objects.create(album=fantasies, title="Test2", position=2) await Track.objects.create(album=fantasies, title="Test3", position=3) - tracks = await Track.objects.select_related("album").filter(pk=track.pk).all() + tracks = ( + await Track.objects.select_related("album").filter(pk=track.pk).all() + ) assert len(tracks) == 1 tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 @@ -446,7 +335,9 @@ async def test_limit_and_offset(): async with database: async with database.transaction(force_rollback=True): fantasies = await Album.objects.create(name="Limitless") - await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) + await Track.objects.create( + id=None, album=fantasies, title="Sample", position=1 + ) await Track.objects.create(album=fantasies, title="Sample2", position=2) await Track.objects.create(album=fantasies, title="Sample3", position=3) diff --git a/tests/test_many_to_many.py b/tests/test_many_to_many.py index 4111458..10347bf 100644 --- a/tests/test_many_to_many.py +++ b/tests/test_many_to_many.py @@ -126,7 +126,9 @@ async def test_quering_of_the_m2m_models(cleanup): category = await Category.objects.filter(posts__author=guido).get() assert category == news # or: - category2 = await Category.objects.filter(posts__author__first_name="Guido").get() + category2 = await Category.objects.filter( + posts__author__first_name="Guido" + ).get() assert category2 == news diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index 6cc266e..2f64259 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -56,7 +56,11 @@ class ExampleModel2(Model): @pytest.fixture() def example(): return ExampleModel( - pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5), test_decimal2=decimal.Decimal(5.5) + pk=1, + test_string="test", + test_bool=True, + test_decimal=decimal.Decimal(3.5), + test_decimal2=decimal.Decimal(5.5), ) @@ -115,6 +119,7 @@ def test_sqlalchemy_table_is_created(example): def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example3" @@ -125,6 +130,7 @@ def test_no_pk_in_model_definition(): def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example3" @@ -136,6 +142,7 @@ def test_two_pks_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example4" @@ -146,6 +153,7 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example5" @@ -156,6 +164,7 @@ def test_decimal_error_in_model_definition(): def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example6" diff --git a/tests/test_models.py b/tests/test_models.py index 9eece9a..e402c86 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,5 +1,6 @@ import asyncio from datetime import datetime +from typing import List import databases import pydantic @@ -47,6 +48,27 @@ class Product(ormar.Model): last_delivery: ormar.Date(default=datetime.now) +country_name_choices = ("Canada", "Algeria", "United States") +country_taxed_choices = (True,) +country_country_code_choices = (-10, 1, 213, 1200) + + +class Country(ormar.Model): + class Meta: + tablename = "country" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String( + max_length=9, choices=country_name_choices, default="Canada", + ) + taxed: ormar.Boolean(choices=country_taxed_choices, default=True) + country_code: ormar.Integer( + minimum=0, maximum=1000, choices=country_country_code_choices, default=1 + ) + + @pytest.fixture(scope="module") def event_loop(): loop = asyncio.get_event_loop() @@ -242,7 +264,9 @@ async def test_model_limit_with_filter(): await User.objects.create(name="Tom") await User.objects.create(name="Tom") - assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 + assert ( + len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 + ) @pytest.mark.asyncio @@ -268,3 +292,85 @@ async def test_model_first(): assert await User.objects.filter(name="Jane").first() == jane with pytest.raises(NoMatch): await User.objects.filter(name="Lucy").first() + + +def not_contains(a, b): + return a not in b + + +def contains(a, b): + return a in b + + +def check_choices(values: tuple, ops: List): + ops_dict = {"in": contains, "out": not_contains} + checks = (country_name_choices, country_taxed_choices, country_country_code_choices) + assert all( + [ops_dict[op](value, check) for value, op, check in zip(values, ops, checks)] + ) + + +@pytest.mark.asyncio +async def test_model_choices(): + """Test that choices work properly for various types of fields.""" + async with database: + # Test valid choices. + await asyncio.gather( + Country.objects.create(name="Canada", taxed=True, country_code=1), + Country.objects.create(name="Algeria", taxed=True, country_code=213), + Country.objects.create(name="Algeria"), + ) + + with pytest.raises(ValueError): + name, taxed, country_code = "Saudi Arabia", True, 1 + check_choices((name, taxed, country_code), ["out", "in", "in"]) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = "Algeria", False, 1 + check_choices((name, taxed, country_code), ["in", "out", "in"]) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = "Algeria", True, 967 + check_choices((name, taxed, country_code), ["in", "in", "out"]) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = ( + "United States", + True, + 1, + ) # name is too long but is a valid choice + check_choices((name, taxed, country_code), ["in", "in", "in"]) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = ( + "Algeria", + True, + -10, + ) # country code is too small but is a valid choice + check_choices((name, taxed, country_code), ["in", "in", "in"]) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) + + with pytest.raises(ValueError): + name, taxed, country_code = ( + "Algeria", + True, + 1200, + ) # country code is too large but is a valid choice + check_choices((name, taxed, country_code), ["in", "in", "in"]) + await Country.objects.create( + name=name, taxed=taxed, country_code=country_code + ) diff --git a/tests/test_more_same_table_joins.py b/tests/test_more_same_table_joins.py index 5c4ac3d..0492d91 100644 --- a/tests/test_more_same_table_joins.py +++ b/tests/test_more_same_table_joins.py @@ -94,6 +94,7 @@ async def create_data(): await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: diff --git a/tests/test_non_integer_pkey.py b/tests/test_non_integer_pkey.py index 3face9c..d601e10 100644 --- a/tests/test_non_integer_pkey.py +++ b/tests/test_non_integer_pkey.py @@ -44,4 +44,4 @@ async def test_pk_1(): async def test_pk_2(): async with database: model = await Model.objects.create(name="NAME") - assert await Model.objects.all() == [model] \ No newline at end of file + assert await Model.objects.all() == [model] diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 9eda9d7..33d2677 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -112,10 +112,14 @@ async def test_model_multiple_instances_of_same_table_in_schema(): assert classes[0].students[0].schoolclass.name == "Math" assert classes[0].students[0].schoolclass.department.name is None await classes[0].students[0].schoolclass.department.load() - assert classes[0].students[0].schoolclass.department.name == "Math Department" + assert ( + classes[0].students[0].schoolclass.department.name == "Math Department" + ) await classes[1].students[0].schoolclass.department.load() - assert classes[1].students[0].schoolclass.department.name == "Law Department" + assert ( + classes[1].students[0].schoolclass.department.name == "Law Department" + ) @pytest.mark.asyncio From 9ce50280ae2d739d5c6fb4fde7a031213b447c41 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 23 Sep 2020 10:11:42 +0200 Subject: [PATCH 3/3] bump version --- ormar/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ormar/__init__.py b/ormar/__init__.py index 0403f31..aea7bfd 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -26,7 +26,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.3.1" +__version__ = "0.3.2" __all__ = [ "Integer", "BigInteger",