fix #409 - nullable large binary fields
This commit is contained in:
@ -59,7 +59,11 @@ class BytesDescriptor:
|
|||||||
def __get__(self, instance: "Model", owner: Type["Model"]) -> Any:
|
def __get__(self, instance: "Model", owner: Type["Model"]) -> Any:
|
||||||
value = instance.__dict__.get(self.name, None)
|
value = instance.__dict__.get(self.name, None)
|
||||||
field = instance.Meta.model_fields[self.name]
|
field = instance.Meta.model_fields[self.name]
|
||||||
if field.represent_as_base64_str and not isinstance(value, str):
|
if (
|
||||||
|
value is not None
|
||||||
|
and field.represent_as_base64_str
|
||||||
|
and not isinstance(value, str)
|
||||||
|
):
|
||||||
value = base64.b64encode(value).decode()
|
value = base64.b64encode(value).decode()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import base64
|
||||||
import uuid
|
import uuid
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -51,6 +52,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
|
|||||||
new_kwargs = cls._remove_not_ormar_fields(new_kwargs)
|
new_kwargs = cls._remove_not_ormar_fields(new_kwargs)
|
||||||
new_kwargs = cls.substitute_models_with_pks(new_kwargs)
|
new_kwargs = cls.substitute_models_with_pks(new_kwargs)
|
||||||
new_kwargs = cls.populate_default_values(new_kwargs)
|
new_kwargs = cls.populate_default_values(new_kwargs)
|
||||||
|
new_kwargs = cls.reconvert_str_to_bytes(new_kwargs)
|
||||||
new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
|
new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
|
||||||
return new_kwargs
|
return new_kwargs
|
||||||
|
|
||||||
@ -144,6 +146,36 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
|
|||||||
model_dict.pop(field, None)
|
model_dict.pop(field, None)
|
||||||
return model_dict
|
return model_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reconvert_str_to_bytes(cls, model_dict: Dict) -> Dict:
|
||||||
|
"""
|
||||||
|
Receives dictionary of model that is about to be saved and changes
|
||||||
|
all bytes fields that are represented as strings back into bytes.
|
||||||
|
|
||||||
|
:param model_dict: dictionary of model that is about to be saved
|
||||||
|
:type model_dict: Dict
|
||||||
|
:return: dictionary of model that is about to be saved
|
||||||
|
:rtype: Dict
|
||||||
|
"""
|
||||||
|
bytes_fields = {
|
||||||
|
name
|
||||||
|
for name, field in cls.Meta.model_fields.items()
|
||||||
|
if field.__type__ == bytes
|
||||||
|
}
|
||||||
|
bytes_base64_fields = {
|
||||||
|
name
|
||||||
|
for name, field in cls.Meta.model_fields.items()
|
||||||
|
if field.represent_as_base64_str
|
||||||
|
}
|
||||||
|
for key, value in model_dict.items():
|
||||||
|
if key in bytes_fields and isinstance(value, str):
|
||||||
|
model_dict[key] = (
|
||||||
|
value.encode("utf-8")
|
||||||
|
if key not in bytes_base64_fields
|
||||||
|
else base64.b64decode(value)
|
||||||
|
)
|
||||||
|
return model_dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def populate_default_values(cls, new_kwargs: Dict) -> Dict:
|
def populate_default_values(cls, new_kwargs: Dict) -> Dict:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -861,7 +861,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
if column_name not in self._bytes_fields:
|
if column_name not in self._bytes_fields:
|
||||||
return value
|
return value
|
||||||
field = self.Meta.model_fields[column_name]
|
field = self.Meta.model_fields[column_name]
|
||||||
if not isinstance(value, bytes):
|
if not isinstance(value, bytes) and value is not None:
|
||||||
if field.represent_as_base64_str:
|
if field.represent_as_base64_str:
|
||||||
value = base64.b64decode(value)
|
value = base64.b64decode(value)
|
||||||
else:
|
else:
|
||||||
@ -882,7 +882,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
if column_name not in self._bytes_fields:
|
if column_name not in self._bytes_fields:
|
||||||
return value
|
return value
|
||||||
field = self.Meta.model_fields[column_name]
|
field = self.Meta.model_fields[column_name]
|
||||||
if not isinstance(value, str) and field.represent_as_base64_str:
|
if (
|
||||||
|
value is not None
|
||||||
|
and not isinstance(value, str)
|
||||||
|
and field.represent_as_base64_str
|
||||||
|
):
|
||||||
return base64.b64encode(value).decode()
|
return base64.b64encode(value).decode()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|||||||
@ -1096,6 +1096,7 @@ class QuerySet(Generic[T]):
|
|||||||
)
|
)
|
||||||
new_kwargs = self.model.parse_non_db_fields(new_kwargs)
|
new_kwargs = self.model.parse_non_db_fields(new_kwargs)
|
||||||
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
||||||
|
new_kwargs = self.model.reconvert_str_to_bytes(new_kwargs)
|
||||||
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
|
new_kwargs = self.model.translate_columns_to_aliases(new_kwargs)
|
||||||
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
|
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
|
||||||
ready_objects.append(new_kwargs)
|
ready_objects.append(new_kwargs)
|
||||||
|
|||||||
@ -1,11 +1,9 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pydantic
|
|
||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|||||||
@ -58,6 +58,21 @@ class LargeBinaryStr(ormar.Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LargeBinaryNullableStr(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "my_str_blobs2"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
test_binary: str = ormar.LargeBinary(
|
||||||
|
max_length=100000,
|
||||||
|
choices=[blob3, blob4],
|
||||||
|
represent_as_base64_str=True,
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UUIDSample(ormar.Model):
|
class UUIDSample(ormar.Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
tablename = "uuids"
|
tablename = "uuids"
|
||||||
@ -231,6 +246,37 @@ async def test_binary_str_column():
|
|||||||
assert items[1].__dict__["test_binary"] == blob4
|
assert items[1].__dict__["test_binary"] == blob4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_binary_nullable_str_column():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
await LargeBinaryNullableStr().save()
|
||||||
|
await LargeBinaryNullableStr.objects.create()
|
||||||
|
items = await LargeBinaryNullableStr.objects.all()
|
||||||
|
assert len(items) == 2
|
||||||
|
|
||||||
|
items[0].test_binary = blob3
|
||||||
|
items[1].test_binary = blob4
|
||||||
|
|
||||||
|
await LargeBinaryNullableStr.objects.bulk_update(items)
|
||||||
|
items = await LargeBinaryNullableStr.objects.all()
|
||||||
|
assert len(items) == 2
|
||||||
|
assert items[0].test_binary == base64.b64encode(blob3).decode()
|
||||||
|
items[0].test_binary = base64.b64encode(blob4).decode()
|
||||||
|
assert items[0].test_binary == base64.b64encode(blob4).decode()
|
||||||
|
assert items[1].test_binary == base64.b64encode(blob4).decode()
|
||||||
|
assert items[1].__dict__["test_binary"] == blob4
|
||||||
|
|
||||||
|
await LargeBinaryNullableStr.objects.bulk_create(
|
||||||
|
[LargeBinaryNullableStr(), LargeBinaryNullableStr(test_binary=blob3)]
|
||||||
|
)
|
||||||
|
items = await LargeBinaryNullableStr.objects.all()
|
||||||
|
assert len(items) == 4
|
||||||
|
await items[0].update(test_binary=blob4)
|
||||||
|
check_item = await LargeBinaryNullableStr.objects.get(id=items[0].id)
|
||||||
|
assert check_item.test_binary == base64.b64encode(blob4).decode()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_uuid_column():
|
async def test_uuid_column():
|
||||||
async with database:
|
async with database:
|
||||||
|
|||||||
Reference in New Issue
Block a user