fix #409 - nullable large binary fields

This commit is contained in:
collerek
2021-12-16 14:44:01 +01:00
parent 1f5d993716
commit 15be5170f3
6 changed files with 90 additions and 5 deletions

View File

@ -59,7 +59,11 @@ class BytesDescriptor:
def __get__(self, instance: "Model", owner: Type["Model"]) -> Any: def __get__(self, instance: "Model", owner: Type["Model"]) -> Any:
value = instance.__dict__.get(self.name, None) value = instance.__dict__.get(self.name, None)
field = instance.Meta.model_fields[self.name] field = instance.Meta.model_fields[self.name]
if field.represent_as_base64_str and not isinstance(value, str): if (
value is not None
and field.represent_as_base64_str
and not isinstance(value, str)
):
value = base64.b64encode(value).decode() value = base64.b64encode(value).decode()
return value return value

View File

@ -1,3 +1,4 @@
import base64
import uuid import uuid
from typing import ( from typing import (
Any, Any,
@ -51,6 +52,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
new_kwargs = cls._remove_not_ormar_fields(new_kwargs) new_kwargs = cls._remove_not_ormar_fields(new_kwargs)
new_kwargs = cls.substitute_models_with_pks(new_kwargs) new_kwargs = cls.substitute_models_with_pks(new_kwargs)
new_kwargs = cls.populate_default_values(new_kwargs) new_kwargs = cls.populate_default_values(new_kwargs)
new_kwargs = cls.reconvert_str_to_bytes(new_kwargs)
new_kwargs = cls.translate_columns_to_aliases(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
return new_kwargs return new_kwargs
@ -144,6 +146,36 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
model_dict.pop(field, None) model_dict.pop(field, None)
return model_dict return model_dict
@classmethod
def reconvert_str_to_bytes(cls, model_dict: Dict) -> Dict:
"""
Receives dictionary of model that is about to be saved and changes
all bytes fields that are represented as strings back into bytes.
:param model_dict: dictionary of model that is about to be saved
:type model_dict: Dict
:return: dictionary of model that is about to be saved
:rtype: Dict
"""
bytes_fields = {
name
for name, field in cls.Meta.model_fields.items()
if field.__type__ == bytes
}
bytes_base64_fields = {
name
for name, field in cls.Meta.model_fields.items()
if field.represent_as_base64_str
}
for key, value in model_dict.items():
if key in bytes_fields and isinstance(value, str):
model_dict[key] = (
value.encode("utf-8")
if key not in bytes_base64_fields
else base64.b64decode(value)
)
return model_dict
@classmethod @classmethod
def populate_default_values(cls, new_kwargs: Dict) -> Dict: def populate_default_values(cls, new_kwargs: Dict) -> Dict:
""" """

View File

@ -861,7 +861,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if column_name not in self._bytes_fields: if column_name not in self._bytes_fields:
return value return value
field = self.Meta.model_fields[column_name] field = self.Meta.model_fields[column_name]
if not isinstance(value, bytes): if not isinstance(value, bytes) and value is not None:
if field.represent_as_base64_str: if field.represent_as_base64_str:
value = base64.b64decode(value) value = base64.b64decode(value)
else: else:
@ -882,7 +882,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if column_name not in self._bytes_fields: if column_name not in self._bytes_fields:
return value return value
field = self.Meta.model_fields[column_name] field = self.Meta.model_fields[column_name]
if not isinstance(value, str) and field.represent_as_base64_str: if (
value is not None
and not isinstance(value, str)
and field.represent_as_base64_str
):
return base64.b64encode(value).decode() return base64.b64encode(value).decode()
return value return value

View File

@ -1096,6 +1096,7 @@ class QuerySet(Generic[T]):
) )
new_kwargs = self.model.parse_non_db_fields(new_kwargs) new_kwargs = self.model.parse_non_db_fields(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self.model.reconvert_str_to_bytes(new_kwargs)
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs) new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
ready_objects.append(new_kwargs) ready_objects.append(new_kwargs)

View File

@ -1,11 +1,9 @@
import base64 import base64
import json import json
import os
import uuid import uuid
from typing import List from typing import List
import databases import databases
import pydantic
import pytest import pytest
import sqlalchemy import sqlalchemy
from fastapi import FastAPI from fastapi import FastAPI

View File

@ -58,6 +58,21 @@ class LargeBinaryStr(ormar.Model):
) )
class LargeBinaryNullableStr(ormar.Model):
class Meta:
tablename = "my_str_blobs2"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
test_binary: str = ormar.LargeBinary(
max_length=100000,
choices=[blob3, blob4],
represent_as_base64_str=True,
nullable=True,
)
class UUIDSample(ormar.Model): class UUIDSample(ormar.Model):
class Meta: class Meta:
tablename = "uuids" tablename = "uuids"
@ -231,6 +246,37 @@ async def test_binary_str_column():
assert items[1].__dict__["test_binary"] == blob4 assert items[1].__dict__["test_binary"] == blob4
@pytest.mark.asyncio
async def test_binary_nullable_str_column():
async with database:
async with database.transaction(force_rollback=True):
await LargeBinaryNullableStr().save()
await LargeBinaryNullableStr.objects.create()
items = await LargeBinaryNullableStr.objects.all()
assert len(items) == 2
items[0].test_binary = blob3
items[1].test_binary = blob4
await LargeBinaryNullableStr.objects.bulk_update(items)
items = await LargeBinaryNullableStr.objects.all()
assert len(items) == 2
assert items[0].test_binary == base64.b64encode(blob3).decode()
items[0].test_binary = base64.b64encode(blob4).decode()
assert items[0].test_binary == base64.b64encode(blob4).decode()
assert items[1].test_binary == base64.b64encode(blob4).decode()
assert items[1].__dict__["test_binary"] == blob4
await LargeBinaryNullableStr.objects.bulk_create(
[LargeBinaryNullableStr(), LargeBinaryNullableStr(test_binary=blob3)]
)
items = await LargeBinaryNullableStr.objects.all()
assert len(items) == 4
await items[0].update(test_binary=blob4)
check_item = await LargeBinaryNullableStr.objects.get(id=items[0].id)
assert check_item.test_binary == base64.b64encode(blob4).decode()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_uuid_column(): async def test_uuid_column():
async with database: async with database: