Merge pull request #203 from collerek/check_defaults

Bug fixes
This commit is contained in:
collerek
2021-05-18 16:38:19 +02:00
committed by GitHub
7 changed files with 275 additions and 55 deletions

View File

@ -1,3 +1,12 @@
# 0.10.8
## 🐛 Fixes
* Fix populating default values in pk_only child models [#202](https://github.com/collerek/ormar/issues/202)
* Fix mypy for LargeBinary fields with base64 str representation [#199](https://github.com/collerek/ormar/issues/199)
* Fix OpenAPI schema format for LargeBinary fields with base64 str representation [#199](https://github.com/collerek/ormar/issues/199)
* Fix OpenAPI choices encoding for LargeBinary fields with base64 str representation
# 0.10.7 # 0.10.7
## ✨ Features ## ✨ Features

View File

@ -76,7 +76,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.10.7" __version__ = "0.10.8"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -1,7 +1,7 @@
import datetime import datetime
import decimal import decimal
import uuid import uuid
from typing import Any, Optional, TYPE_CHECKING from typing import Any, Optional, TYPE_CHECKING, Union, overload
import pydantic import pydantic
import sqlalchemy import sqlalchemy
@ -11,6 +11,11 @@ from ormar.fields import sqlalchemy_uuid
from ormar.fields.base import BaseField # noqa I101 from ormar.fields.base import BaseField # noqa I101
from ormar.fields.sqlalchemy_encrypted import EncryptBackends from ormar.fields.sqlalchemy_encrypted import EncryptBackends
try:
from typing import Literal
except ImportError: # pragma: no cover
from typing_extensions import Literal # type: ignore
def is_field_nullable( def is_field_nullable(
nullable: Optional[bool], nullable: Optional[bool],
@ -426,52 +431,85 @@ class JSON(ModelFieldFactory, pydantic.Json):
return sqlalchemy.JSON() return sqlalchemy.JSON()
class LargeBinary(ModelFieldFactory, bytes): if TYPE_CHECKING: # pragma: nocover # noqa: C901
"""
LargeBinary field factory that construct Field classes and populated their values.
"""
_type = bytes @overload
_sample = "bytes" def LargeBinary(
max_length: int, *, represent_as_base64_str: Literal[True], **kwargs: Any
) -> str:
...
def __new__( # type: ignore # noqa CFQ002 @overload
cls, *, max_length: int, represent_as_base64_str: bool = False, **kwargs: Any def LargeBinary(
) -> BaseField: # type: ignore max_length: int, *, represent_as_base64_str: Literal[False], **kwargs: Any
kwargs = { ) -> bytes:
**kwargs, ...
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod @overload
def get_column_type(cls, **kwargs: Any) -> Any: def LargeBinary(
max_length: int, represent_as_base64_str: Literal[False] = ..., **kwargs: Any
) -> bytes:
...
def LargeBinary(
max_length: int, represent_as_base64_str: bool = False, **kwargs: Any
) -> Union[str, bytes]:
pass
else:
class LargeBinary(ModelFieldFactory, bytes):
"""
LargeBinary field factory that construct Field classes
and populated their values.
""" """
Return proper type of db column for given field type.
Accepts required and optional parameters that each column type accepts.
:param kwargs: key, value pairs of sqlalchemy options _type = bytes
:type kwargs: Any _sample = "bytes"
:return: initialized column with proper options
:rtype: sqlalchemy Column
"""
return sqlalchemy.LargeBinary(length=kwargs.get("max_length"))
@classmethod def __new__( # type: ignore # noqa CFQ002
def validate(cls, **kwargs: Any) -> None: cls,
""" *,
Used to validate if all required parameters on a given field type are set. max_length: int,
:param kwargs: all params passed during construction represent_as_base64_str: bool = False,
:type kwargs: Any **kwargs: Any
""" ) -> BaseField: # type: ignore
max_length = kwargs.get("max_length", None) kwargs = {
if max_length <= 0: **kwargs,
raise ModelDefinitionError( **{
"Parameter max_length is required for field LargeBinary" k: v
) for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
"""
Return proper type of db column for given field type.
Accepts required and optional parameters that each column type accepts.
:param kwargs: key, value pairs of sqlalchemy options
:type kwargs: Any
:return: initialized column with proper options
:rtype: sqlalchemy Column
"""
return sqlalchemy.LargeBinary(length=kwargs.get("max_length"))
@classmethod
def validate(cls, **kwargs: Any) -> None:
"""
Used to validate if all required parameters on a given field type are set.
:param kwargs: all params passed during construction
:type kwargs: Any
"""
max_length = kwargs.get("max_length", None)
if max_length <= 0:
raise ModelDefinitionError(
"Parameter max_length is required for field LargeBinary"
)
class BigInteger(Integer, int): class BigInteger(Integer, int):

View File

@ -142,7 +142,8 @@ def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> D
) )
for name, field in model.Meta.model_fields.items(): for name, field in model.Meta.model_fields.items():
if not field.is_relation: if not field.is_relation:
example[name] = field.__sample__ is_bytes_str = field.__type__ == bytes and field.represent_as_base64_str
example[name] = field.__sample__ if not is_bytes_str else "string"
elif isinstance(relation_map, dict) and name in relation_map: elif isinstance(relation_map, dict) and name in relation_map:
example[name] = get_nested_model_example( example[name] = get_nested_model_example(
name=name, field=field, relation_map=relation_map name=name, field=field, relation_map=relation_map
@ -217,6 +218,44 @@ def get_pydantic_example_repr(type_: Any) -> Any:
return "string" return "string"
def overwrite_example_and_description(
schema: Dict[str, Any], model: Type["Model"]
) -> None:
"""
Overwrites the example with properly nested children models.
Overwrites the description if it's taken from ormar.Model.
:param schema: schema of current model
:type schema: Dict[str, Any]
:param model: model class
:type model: Type["Model"]
"""
schema["example"] = generate_model_example(model=model)
if "Main base class of ormar Model." in schema.get("description", ""):
schema["description"] = f"{model.__name__}"
def overwrite_binary_format(schema: Dict[str, Any], model: Type["Model"]) -> None:
"""
Overwrites format of the field if it's a LargeBinary field with
a flag to represent the field as base64 encoded string.
:param schema: schema of current model
:type schema: Dict[str, Any]
:param model: model class
:type model: Type["Model"]
"""
for field_id, prop in schema.get("properties", {}).items():
if (
field_id in model._bytes_fields
and model.Meta.model_fields[field_id].represent_as_base64_str
):
prop["format"] = "base64"
prop["enum"] = [
base64.b64encode(choice).decode() for choice in prop["enum"]
]
def construct_modify_schema_function(fields_with_choices: List) -> SchemaExtraCallable: def construct_modify_schema_function(fields_with_choices: List) -> SchemaExtraCallable:
""" """
Modifies the schema to include fields with choices validator. Modifies the schema to include fields with choices validator.
@ -237,9 +276,8 @@ def construct_modify_schema_function(fields_with_choices: List) -> SchemaExtraCa
if field_id in fields_with_choices: if field_id in fields_with_choices:
prop["enum"] = list(model.Meta.model_fields[field_id].choices) prop["enum"] = list(model.Meta.model_fields[field_id].choices)
prop["description"] = prop.get("description", "") + "An enumeration." prop["description"] = prop.get("description", "") + "An enumeration."
schema["example"] = generate_model_example(model=model) overwrite_example_and_description(schema=schema, model=model)
if "Main base class of ormar Model." in schema.get("description", ""): overwrite_binary_format(schema=schema, model=model)
schema["description"] = f"{model.__name__}"
return staticmethod(schema_extra) # type: ignore return staticmethod(schema_extra) # type: ignore
@ -256,9 +294,8 @@ def construct_schema_function_without_choices() -> SchemaExtraCallable:
""" """
def schema_extra(schema: Dict[str, Any], model: Type["Model"]) -> None: def schema_extra(schema: Dict[str, Any], model: Type["Model"]) -> None:
schema["example"] = generate_model_example(model=model) overwrite_example_and_description(schema=schema, model=model)
if "Main base class of ormar Model." in schema.get("description", ""): overwrite_binary_format(schema=schema, model=model)
schema["description"] = f"{model.__name__}"
return staticmethod(schema_extra) # type: ignore return staticmethod(schema_extra) # type: ignore

View File

@ -132,11 +132,15 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
new_kwargs, through_tmp_dict = self._process_kwargs(kwargs) new_kwargs, through_tmp_dict = self._process_kwargs(kwargs)
values, fields_set, validation_error = pydantic.validate_model( if not pk_only:
self, new_kwargs # type: ignore values, fields_set, validation_error = pydantic.validate_model(
) self, new_kwargs # type: ignore
if validation_error and not pk_only: )
raise validation_error if validation_error:
raise validation_error
else:
fields_set = {self.Meta.pkname}
values = new_kwargs
object.__setattr__(self, "__dict__", values) object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set) object.__setattr__(self, "__fields_set__", fields_set)

View File

@ -52,7 +52,7 @@ class BinaryThing(ormar.Model):
id: uuid.UUID = ormar.UUID(primary_key=True, default=uuid.uuid4) id: uuid.UUID = ormar.UUID(primary_key=True, default=uuid.uuid4)
name: str = ormar.Text(default="") name: str = ormar.Text(default="")
bt: bytes = ormar.LargeBinary( bt: str = ormar.LargeBinary(
max_length=1000, max_length=1000,
choices=[blob3, blob4, blob5, blob6], choices=[blob3, blob4, blob5, blob6],
represent_as_base64_str=True, represent_as_base64_str=True,
@ -89,3 +89,14 @@ def test_read_main():
assert response.json()[0]["bt"] == base64.b64encode(blob3).decode() assert response.json()[0]["bt"] == base64.b64encode(blob3).decode()
thing = BinaryThing(**response.json()[0]) thing = BinaryThing(**response.json()[0])
assert thing.__dict__["bt"] == blob3 assert thing.__dict__["bt"] == blob3
def test_schema():
schema = BinaryThing.schema()
assert schema["properties"]["bt"]["format"] == "base64"
converted_choices = ["7g==", "/w==", "8CiMKA==", "wyg="]
assert len(schema["properties"]["bt"]["enum"]) == 4
assert all(
choice in schema["properties"]["bt"]["enum"] for choice in converted_choices
)
assert schema["example"]["bt"] == "string"

View File

@ -0,0 +1,121 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
app = FastAPI()
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class BaseMeta(ormar.ModelMeta):
metadata = metadata
database = database
class Country(ormar.Model):
class Meta(BaseMeta):
tablename = "countries"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, default="Poland")
class Author(ormar.Model):
class Meta(BaseMeta):
tablename = "authors"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
rating: int = ormar.Integer(default=0)
country: Optional[Country] = ormar.ForeignKey(Country)
class Book(ormar.Model):
class Meta(BaseMeta):
tablename = "books"
id: int = ormar.Integer(primary_key=True)
author: Optional[Author] = ormar.ForeignKey(Author)
title: str = ormar.String(max_length=100)
year: int = ormar.Integer(nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture()
async def sample_data():
async with database:
country = await Country(id=1, name="USA").save()
author = await Author(id=1, name="bug", rating=5, country=country).save()
await Book(
id=1, author=author, title="Bug caused by default value", year=2021
).save()
@app.get("/books/{book_id}", response_model=Book)
async def get_book_by_id(book_id: int):
book = await Book.objects.get(id=book_id)
return book
@app.get("/books_with_author/{book_id}", response_model=Book)
async def get_book_with_author_by_id(book_id: int):
book = await Book.objects.select_related("author").get(id=book_id)
return book
def test_related_with_defaults(sample_data):
client = TestClient(app)
with client as client:
response = client.get("/books/1")
assert response.json() == {
"author": {"id": 1},
"id": 1,
"title": "Bug caused by default value",
"year": 2021,
}
response = client.get("/books_with_author/1")
assert response.json() == {
"author": {
"books": [
{"id": 1, "title": "Bug caused by default value", "year": 2021}
],
"country": {"id": 1},
"id": 1,
"name": "bug",
"rating": 5,
},
"id": 1,
"title": "Bug caused by default value",
"year": 2021,
}