From 15be5170f3116e9b878c04ef5f57ace3bcc9c8e6 Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 16 Dec 2021 14:44:01 +0100 Subject: [PATCH] fix #409 - nullable large binary fields --- ormar/models/descriptors/descriptors.py | 6 ++- ormar/models/mixins/save_mixin.py | 32 +++++++++++++++ ormar/models/newbasemodel.py | 8 +++- ormar/queryset/queryset.py | 1 + tests/test_fastapi/test_binary_fields.py | 2 - tests/test_model_definition/test_models.py | 46 ++++++++++++++++++++++ 6 files changed, 90 insertions(+), 5 deletions(-) diff --git a/ormar/models/descriptors/descriptors.py b/ormar/models/descriptors/descriptors.py index dcabd98..3d168e6 100644 --- a/ormar/models/descriptors/descriptors.py +++ b/ormar/models/descriptors/descriptors.py @@ -59,7 +59,11 @@ class BytesDescriptor: def __get__(self, instance: "Model", owner: Type["Model"]) -> Any: value = instance.__dict__.get(self.name, None) field = instance.Meta.model_fields[self.name] - if field.represent_as_base64_str and not isinstance(value, str): + if ( + value is not None + and field.represent_as_base64_str + and not isinstance(value, str) + ): value = base64.b64encode(value).decode() return value diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 1900a65..e55f043 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -1,3 +1,4 @@ +import base64 import uuid from typing import ( Any, @@ -51,6 +52,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): new_kwargs = cls._remove_not_ormar_fields(new_kwargs) new_kwargs = cls.substitute_models_with_pks(new_kwargs) new_kwargs = cls.populate_default_values(new_kwargs) + new_kwargs = cls.reconvert_str_to_bytes(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs) return new_kwargs @@ -144,6 +146,36 @@ class SavePrepareMixin(RelationMixin, AliasMixin): model_dict.pop(field, None) return model_dict + @classmethod + def reconvert_str_to_bytes(cls, model_dict: Dict) -> Dict: + """ + Receives dictionary of model that is about to be saved and changes + all bytes fields that are represented as strings back into bytes. + + :param model_dict: dictionary of model that is about to be saved + :type model_dict: Dict + :return: dictionary of model that is about to be saved + :rtype: Dict + """ + bytes_fields = { + name + for name, field in cls.Meta.model_fields.items() + if field.__type__ == bytes + } + bytes_base64_fields = { + name + for name, field in cls.Meta.model_fields.items() + if field.represent_as_base64_str + } + for key, value in model_dict.items(): + if key in bytes_fields and isinstance(value, str): + model_dict[key] = ( + value.encode("utf-8") + if key not in bytes_base64_fields + else base64.b64decode(value) + ) + return model_dict + @classmethod def populate_default_values(cls, new_kwargs: Dict) -> Dict: """ diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 1fbf109..599ccdf 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -861,7 +861,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass if column_name not in self._bytes_fields: return value field = self.Meta.model_fields[column_name] - if not isinstance(value, bytes): + if not isinstance(value, bytes) and value is not None: if field.represent_as_base64_str: value = base64.b64decode(value) else: @@ -882,7 +882,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass if column_name not in self._bytes_fields: return value field = self.Meta.model_fields[column_name] - if not isinstance(value, str) and field.represent_as_base64_str: + if ( + value is not None + and not isinstance(value, str) + and field.represent_as_base64_str + ): return base64.b64encode(value).decode() return value diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index f439178..886d150 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1096,6 +1096,7 @@ class QuerySet(Generic[T]): ) new_kwargs = self.model.parse_non_db_fields(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs) + new_kwargs = self.model.reconvert_str_to_bytes(new_kwargs) new_kwargs = self.model.translate_columns_to_aliases(new_kwargs) new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} ready_objects.append(new_kwargs) diff --git a/tests/test_fastapi/test_binary_fields.py b/tests/test_fastapi/test_binary_fields.py index 8108e22..0863733 100644 --- a/tests/test_fastapi/test_binary_fields.py +++ b/tests/test_fastapi/test_binary_fields.py @@ -1,11 +1,9 @@ import base64 import json -import os import uuid from typing import List import databases -import pydantic import pytest import sqlalchemy from fastapi import FastAPI diff --git a/tests/test_model_definition/test_models.py b/tests/test_model_definition/test_models.py index 650d88c..29d0b81 100644 --- a/tests/test_model_definition/test_models.py +++ b/tests/test_model_definition/test_models.py @@ -58,6 +58,21 @@ class LargeBinaryStr(ormar.Model): ) +class LargeBinaryNullableStr(ormar.Model): + class Meta: + tablename = "my_str_blobs2" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + test_binary: str = ormar.LargeBinary( + max_length=100000, + choices=[blob3, blob4], + represent_as_base64_str=True, + nullable=True, + ) + + class UUIDSample(ormar.Model): class Meta: tablename = "uuids" @@ -231,6 +246,37 @@ async def test_binary_str_column(): assert items[1].__dict__["test_binary"] == blob4 +@pytest.mark.asyncio +async def test_binary_nullable_str_column(): + async with database: + async with database.transaction(force_rollback=True): + await LargeBinaryNullableStr().save() + await LargeBinaryNullableStr.objects.create() + items = await LargeBinaryNullableStr.objects.all() + assert len(items) == 2 + + items[0].test_binary = blob3 + items[1].test_binary = blob4 + + await LargeBinaryNullableStr.objects.bulk_update(items) + items = await LargeBinaryNullableStr.objects.all() + assert len(items) == 2 + assert items[0].test_binary == base64.b64encode(blob3).decode() + items[0].test_binary = base64.b64encode(blob4).decode() + assert items[0].test_binary == base64.b64encode(blob4).decode() + assert items[1].test_binary == base64.b64encode(blob4).decode() + assert items[1].__dict__["test_binary"] == blob4 + + await LargeBinaryNullableStr.objects.bulk_create( + [LargeBinaryNullableStr(), LargeBinaryNullableStr(test_binary=blob3)] + ) + items = await LargeBinaryNullableStr.objects.all() + assert len(items) == 4 + await items[0].update(test_binary=blob4) + check_item = await LargeBinaryNullableStr.objects.get(id=items[0].id) + assert check_item.test_binary == base64.b64encode(blob4).decode() + + @pytest.mark.asyncio async def test_uuid_column(): async with database: