diff --git a/.codeclimate.yml b/.codeclimate.yml index b62e2ab..e893c57 100644 --- a/.codeclimate.yml +++ b/.codeclimate.yml @@ -3,6 +3,9 @@ checks: method-complexity: config: threshold: 8 + file-lines: + config: + threshold: 500 engines: bandit: enabled: true diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 2d0bb55..6af4c57 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,3 +1,6 @@ +import datetime +import decimal +import uuid from typing import ( Any, Dict, @@ -14,6 +17,7 @@ from typing import ( import databases import pydantic import sqlalchemy +from pydantic.main import SchemaExtraCallable from sqlalchemy.sql.schema import ColumnCollectionConstraint import ormar # noqa I100 @@ -84,6 +88,40 @@ def check_if_field_has_choices(field: Type[BaseField]) -> bool: return hasattr(field, "choices") and bool(field.choices) +def convert_choices_if_needed( + field: Type["BaseField"], values: Dict +) -> Tuple[Any, List]: + """ + Converts dates to isoformat as fastapi can check this condition in routes + and the fields are not yet parsed. + + :param field: ormar field to check with choices + :type field: Type[BaseField] + :param values: current values of the model to verify + :type values: Dict + :return: value, choices list + :rtype: Tuple[Any, List] + """ + value = values.get(field.name, ormar.Undefined) + choices = list(field.choices) + if field.__type__ in [datetime.datetime, datetime.date, datetime.time]: + value = value.isoformat() if not isinstance(value, str) else value + choices = [o.isoformat() for o in field.choices] + elif field.__type__ == uuid.UUID: + value = str(value) if not isinstance(value, str) else value + choices = [str(o) for o in field.choices] + elif field.__type__ == decimal.Decimal: + precision = field.precision # type: ignore + value = ( + round(float(value), precision) + if isinstance(value, decimal.Decimal) + else value + ) + choices = [round(float(o), precision) for o in choices] + + return value, choices + + def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, Any]: """ Validator that is attached to pydantic model pre root validators. @@ -99,16 +137,26 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A """ for field_name, field in cls.Meta.model_fields.items(): if check_if_field_has_choices(field): - value = values.get(field_name, ormar.Undefined) - if value is not ormar.Undefined and value not in field.choices: + value, choices = convert_choices_if_needed(field=field, values=values) + if value is not ormar.Undefined and value not in choices: raise ValueError( f"{field_name}: '{values.get(field_name)}' " f"not in allowed choices set:" - f" {field.choices}" + f" {choices}" ) return values +def construct_modify_schema_function(fields_with_choices: List) -> SchemaExtraCallable: + def schema_extra(schema: Dict[str, Any], model: Type["Model"]) -> None: + for field_id, prop in schema.get("properties", {}).items(): + if field_id in fields_with_choices: + prop["enum"] = list(model.Meta.model_fields[field_id].choices) + prop["description"] = prop.get("description", "") + "An enumeration." + + return staticmethod(schema_extra) # type: ignore + + def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001 """ Checks if Model has any fields with choices set. @@ -117,14 +165,21 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001 :param model: newly constructed Model :type model: Model class """ + fields_with_choices = [] if not meta_field_not_set(model=model, field_name="model_fields"): - for _, field in model.Meta.model_fields.items(): + for name, field in model.Meta.model_fields.items(): if check_if_field_has_choices(field): + fields_with_choices.append(name) validators = getattr(model, "__pre_root_validators__", []) if choices_validator not in validators: validators.append(choices_validator) model.__pre_root_validators__ = validators + if fields_with_choices: + model.Config.schema_extra = construct_modify_schema_function( + fields_with_choices=fields_with_choices + ) + def add_cached_properties(new_model: Type["Model"]) -> None: """ diff --git a/tests/test_choices_schema.py b/tests/test_choices_schema.py new file mode 100644 index 0000000..63456bd --- /dev/null +++ b/tests/test_choices_schema.py @@ -0,0 +1,142 @@ +import datetime +import decimal +import uuid + +import databases +import pydantic +import pytest +import sqlalchemy +from fastapi import FastAPI +from starlette.testclient import TestClient + +import ormar +from tests.settings import DATABASE_URL + +app = FastAPI() +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() +app.state.database = database + +uuid1 = uuid.uuid4() +uuid2 = uuid.uuid4() + + +class Organisation(ormar.Model): + class Meta: + tablename = "org" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"]) + priority: int = ormar.Integer(choices=[1, 2, 3, 4, 5]) + priority2: int = ormar.BigInteger(choices=[1, 2, 3, 4, 5]) + expire_date: datetime.date = ormar.Date( + choices=[datetime.date(2021, 1, 1), datetime.date(2022, 5, 1)] + ) + expire_time: datetime.time = ormar.Time( + choices=[datetime.time(10, 0, 0), datetime.time(12, 30)] + ) + + expire_datetime: datetime.datetime = ormar.DateTime( + choices=[datetime.datetime(2021, 1, 1, 10, 0, 0), + datetime.datetime(2022, 5, 1, 12, 30)] + ) + random_val: float = ormar.Float(choices=[2.0, 3.5]) + random_decimal: decimal.Decimal = ormar.Decimal(scale=4, precision=2, + choices=[decimal.Decimal(12.4), decimal.Decimal(58.2)] + ) + random_json: pydantic.Json = ormar.JSON( + choices=["aa", "{\"aa\":\"bb\"}"] + ) + random_uuid: uuid.UUID = ormar.UUID( + choices=[uuid1, uuid2]) + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@app.post("/items/", response_model=Organisation) +async def create_item(item: Organisation): + await item.save() + return item + + +def test_all_endpoints(): + client = TestClient(app) + with client as client: + response = client.post( + "/items/", + json={"id": 1, + "ident": "", + "priority": 4, + "expire_date": "2022-05-01"}, + ) + + assert response.status_code == 422 + response = client.post( + "/items/", + json={ + "id": 1, + "ident": "ACME Ltd", + "priority": 4, + "priority2": 2, + "expire_date": "2022-05-01", + "expire_time": "10:00:00", + "expire_datetime": "2022-05-01T12:30:00", + "random_val": 3.5, + "random_decimal": 12.4, + "random_json": "{\"aa\":\"bb\"}", + "random_uuid": str(uuid1) + }, + ) + + assert response.status_code == 200 + item = Organisation(**response.json()) + assert item.pk is not None + response = client.get("/docs/") + assert response.status_code == 200 + assert b"FastAPI - Swagger UI" in response.content + + +def test_schema_modification(): + schema = Organisation.schema() + for field in ["ident", "priority", "expire_date"]: + assert field in schema["properties"] + assert schema["properties"].get(field).get("enum") == list( + Organisation.Meta.model_fields.get(field).choices + ) + assert "An enumeration." in schema["properties"].get(field).get("description") + + +def test_schema_gen(): + schema = app.openapi() + assert "Organisation" in schema["components"]["schemas"] + props = schema["components"]["schemas"]["Organisation"]["properties"] + for field in ["ident", "priority", "expire_date"]: + assert "enum" in props.get(field) + choices = Organisation.Meta.model_fields.get(field).choices + assert props.get(field).get("enum") == [ + str(x) if isinstance(x, datetime.date) else x for x in choices + ] + assert "description" in props.get(field) + assert "An enumeration." in props.get(field).get("description") diff --git a/tests/test_fastapi_docs.py b/tests/test_fastapi_docs.py index 9fa503f..08118ea 100644 --- a/tests/test_fastapi_docs.py +++ b/tests/test_fastapi_docs.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional +from typing import List import databases import pytest