check if data binding not work only in sqlite
This commit is contained in:
@ -53,6 +53,7 @@ from ormar.fields import (
|
||||
Time,
|
||||
UUID,
|
||||
UniqueColumns,
|
||||
EncryptBackends
|
||||
) # noqa: I100
|
||||
from ormar.models import ExcludableItems, Model
|
||||
from ormar.models.metaclass import ModelMeta
|
||||
@ -68,7 +69,7 @@ class UndefinedType: # pragma no cover
|
||||
|
||||
Undefined = UndefinedType()
|
||||
|
||||
__version__ = "0.9.7"
|
||||
__version__ = "0.9.8"
|
||||
__all__ = [
|
||||
"Integer",
|
||||
"BigInteger",
|
||||
@ -110,4 +111,5 @@ __all__ = [
|
||||
"ExcludableItems",
|
||||
"and_",
|
||||
"or_",
|
||||
"EncryptBackends"
|
||||
]
|
||||
|
||||
@ -22,6 +22,7 @@ from ormar.fields.model_fields import (
|
||||
UUID,
|
||||
)
|
||||
from ormar.fields.through_field import Through, ThroughField
|
||||
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends
|
||||
|
||||
__all__ = [
|
||||
"Decimal",
|
||||
@ -44,4 +45,6 @@ __all__ = [
|
||||
"ForeignKeyField",
|
||||
"ThroughField",
|
||||
"Through",
|
||||
"EncryptBackends",
|
||||
"EncryptBackend"
|
||||
]
|
||||
|
||||
@ -5,6 +5,9 @@ from pydantic import Field, Json, typing
|
||||
from pydantic.fields import FieldInfo, Required, Undefined
|
||||
|
||||
import ormar # noqa I101
|
||||
from ormar import ModelDefinitionError
|
||||
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends, \
|
||||
EncryptedString
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.models import Model
|
||||
@ -49,6 +52,11 @@ class BaseField(FieldInfo):
|
||||
self_reference: bool = False
|
||||
self_reference_primary: Optional[str] = None
|
||||
|
||||
encrypt_secret: str
|
||||
encrypt_backend: EncryptBackends = EncryptBackends.NONE
|
||||
encrypt_custom_backend: Type[EncryptBackend] = None
|
||||
encrypt_max_length: int = 5000
|
||||
|
||||
default: Any
|
||||
server_default: Any
|
||||
|
||||
@ -93,10 +101,11 @@ class BaseField(FieldInfo):
|
||||
:rtype: bool
|
||||
"""
|
||||
return (
|
||||
field_name not in ["default", "default_factory", "alias", "allow_mutation"]
|
||||
and not field_name.startswith("__")
|
||||
and hasattr(cls, field_name)
|
||||
and not callable(getattr(cls, field_name))
|
||||
field_name not in ["default", "default_factory", "alias",
|
||||
"allow_mutation"]
|
||||
and not field_name.startswith("__")
|
||||
and hasattr(cls, field_name)
|
||||
and not callable(getattr(cls, field_name))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -205,7 +214,7 @@ class BaseField(FieldInfo):
|
||||
:rtype: bool
|
||||
"""
|
||||
return cls.default is not None or (
|
||||
cls.server_default is not None and use_server
|
||||
cls.server_default is not None and use_server
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -238,7 +247,7 @@ class BaseField(FieldInfo):
|
||||
ondelete=con.ondelete,
|
||||
onupdate=con.onupdate,
|
||||
name=f"fk_{cls.owner.Meta.tablename}_{cls.to.Meta.tablename}"
|
||||
f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}",
|
||||
f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}",
|
||||
)
|
||||
for con in cls.constraints
|
||||
]
|
||||
@ -256,25 +265,46 @@ class BaseField(FieldInfo):
|
||||
:return: actual definition of the database column as sqlalchemy requires.
|
||||
:rtype: sqlalchemy.Column
|
||||
"""
|
||||
column = sqlalchemy.Column(
|
||||
cls.alias or name,
|
||||
cls.column_type,
|
||||
*cls.construct_constraints(),
|
||||
primary_key=cls.primary_key,
|
||||
nullable=cls.nullable and not cls.primary_key,
|
||||
index=cls.index,
|
||||
unique=cls.unique,
|
||||
default=cls.default,
|
||||
server_default=cls.server_default,
|
||||
)
|
||||
if cls.encrypt_backend == EncryptBackends.NONE:
|
||||
column = sqlalchemy.Column(
|
||||
cls.alias or name,
|
||||
cls.column_type,
|
||||
*cls.construct_constraints(),
|
||||
primary_key=cls.primary_key,
|
||||
nullable=cls.nullable and not cls.primary_key,
|
||||
index=cls.index,
|
||||
unique=cls.unique,
|
||||
default=cls.default,
|
||||
server_default=cls.server_default,
|
||||
)
|
||||
else:
|
||||
if cls.primary_key or cls.is_relation:
|
||||
raise ModelDefinitionError("Primary key field and relations fields"
|
||||
"cannot be encrypted!")
|
||||
column = sqlalchemy.Column(
|
||||
cls.alias or name,
|
||||
EncryptedString(
|
||||
_field_type=cls,
|
||||
encrypt_secret=cls.encrypt_secret,
|
||||
encrypt_backend=cls.encrypt_backend,
|
||||
encrypt_custom_backend=cls.encrypt_custom_backend,
|
||||
encrypt_max_length=cls.encrypt_max_length
|
||||
),
|
||||
nullable=cls.nullable,
|
||||
index=cls.index,
|
||||
unique=cls.unique,
|
||||
default=cls.default,
|
||||
server_default=cls.server_default,
|
||||
)
|
||||
|
||||
return column
|
||||
|
||||
@classmethod
|
||||
def expand_relationship(
|
||||
cls,
|
||||
value: Any,
|
||||
child: Union["Model", "NewBaseModel"],
|
||||
to_register: bool = True,
|
||||
cls,
|
||||
value: Any,
|
||||
child: Union["Model", "NewBaseModel"],
|
||||
to_register: bool = True,
|
||||
) -> Any:
|
||||
"""
|
||||
Function overwritten for relations, in basic field the value is returned as is.
|
||||
@ -302,7 +332,7 @@ class BaseField(FieldInfo):
|
||||
:rtype: None
|
||||
"""
|
||||
if cls.owner is not None and (
|
||||
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
|
||||
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
|
||||
):
|
||||
cls.self_reference = True
|
||||
cls.self_reference_primary = cls.name
|
||||
|
||||
@ -9,6 +9,7 @@ import sqlalchemy
|
||||
from ormar import ModelDefinitionError # noqa I101
|
||||
from ormar.fields import sqlalchemy_uuid
|
||||
from ormar.fields.base import BaseField # noqa I101
|
||||
from ormar.fields.sqlalchemy_encrypted import EncryptBackends
|
||||
|
||||
|
||||
def is_field_nullable(
|
||||
@ -73,6 +74,12 @@ class ModelFieldFactory:
|
||||
primary_key = kwargs.pop("primary_key", False)
|
||||
autoincrement = kwargs.pop("autoincrement", False)
|
||||
|
||||
encrypt_secret = kwargs.pop("encrypt_secret", None)
|
||||
encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE)
|
||||
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend",
|
||||
None)
|
||||
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
|
||||
|
||||
namespace = dict(
|
||||
__type__=cls._type,
|
||||
alias=kwargs.pop("name", None),
|
||||
@ -88,6 +95,10 @@ class ModelFieldFactory:
|
||||
autoincrement=autoincrement,
|
||||
column_type=cls.get_column_type(**kwargs),
|
||||
choices=set(kwargs.pop("choices", [])),
|
||||
encrypt_secret=encrypt_secret,
|
||||
encrypt_backend=encrypt_backend,
|
||||
encrypt_custom_backend=encrypt_custom_backend,
|
||||
encrypt_max_length=encrypt_max_length,
|
||||
**kwargs
|
||||
)
|
||||
return type(cls.__name__, cls._bases, namespace)
|
||||
|
||||
219
ormar/fields/sqlalchemy_encrypted.py
Normal file
219
ormar/fields/sqlalchemy_encrypted.py
Normal file
@ -0,0 +1,219 @@
|
||||
# inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils)
|
||||
import abc
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, TYPE_CHECKING, Type, Union
|
||||
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.engine.default import DefaultDialect
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
from ormar import ModelDefinitionError
|
||||
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ormar import BaseField
|
||||
|
||||
|
||||
class EncryptBackend(abc.ABC):
|
||||
|
||||
def _update_key(self, key):
|
||||
if isinstance(key, str):
|
||||
key = key.encode()
|
||||
digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
|
||||
digest.update(key)
|
||||
engine_key = digest.finalize()
|
||||
|
||||
self._initialize_engine(engine_key)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _initialize_engine(self, secret_key: bytes):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def encrypt(self, value: Any) -> str:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def decrypt(self, value: Any) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class HashBackend(EncryptBackend):
|
||||
"""
|
||||
One-way hashing - in example for passwords, no way to decrypt the value!
|
||||
"""
|
||||
|
||||
def _initialize_engine(self, secret_key: bytes):
|
||||
self.secret_key = base64.urlsafe_b64encode(secret_key)
|
||||
|
||||
def encrypt(self, value: Any) -> str:
|
||||
if not isinstance(value, str):
|
||||
value = repr(value)
|
||||
value = value.encode()
|
||||
digest = hashes.Hash(hashes.SHA512(), backend=default_backend())
|
||||
digest.update(self.secret_key)
|
||||
digest.update(value)
|
||||
hashed_value = digest.finalize()
|
||||
return hashed_value.hex()
|
||||
|
||||
def decrypt(self, value: Any) -> str:
|
||||
if not isinstance(value, str):
|
||||
value = str(value)
|
||||
return value
|
||||
|
||||
|
||||
class FernetBackend(EncryptBackend):
|
||||
"""
|
||||
Two-way encryption, data stored in db are encrypted but decrypted during query.
|
||||
"""
|
||||
|
||||
def _initialize_engine(self, secret_key: bytes):
|
||||
self.secret_key = base64.urlsafe_b64encode(secret_key)
|
||||
self.fernet = Fernet(self.secret_key)
|
||||
|
||||
def encrypt(self, value: Any) -> str:
|
||||
if not isinstance(value, str):
|
||||
value = repr(value)
|
||||
value = value.encode()
|
||||
encrypted = self.fernet.encrypt(value)
|
||||
return encrypted.decode('utf-8')
|
||||
|
||||
def decrypt(self, value: Any) -> str:
|
||||
if not isinstance(value, str):
|
||||
value = str(value)
|
||||
decrypted = self.fernet.decrypt(value.encode())
|
||||
if not isinstance(decrypted, str):
|
||||
decrypted = decrypted.decode('utf-8')
|
||||
return decrypted
|
||||
|
||||
|
||||
class EncryptBackends(Enum):
|
||||
NONE = 0
|
||||
FERNET = 1
|
||||
HASH = 2
|
||||
CUSTOM = 3
|
||||
|
||||
|
||||
backends_map = {
|
||||
EncryptBackends.FERNET: FernetBackend,
|
||||
EncryptBackends.HASH: HashBackend,
|
||||
EncryptBackends.CUSTOM: None
|
||||
}
|
||||
|
||||
|
||||
class EncryptedString(TypeDecorator): # pragma nocover
|
||||
"""
|
||||
Used to store encrypted values in a database
|
||||
"""
|
||||
|
||||
impl = String
|
||||
|
||||
def __init__(self,
|
||||
*args: Any,
|
||||
encrypt_secret: Union[str, Callable],
|
||||
_field_type: Type["BaseField"],
|
||||
encrypt_max_length: int = 5000,
|
||||
encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
|
||||
encrypt_custom_backend: Type[EncryptBackend] = None,
|
||||
**kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
if not cryptography:
|
||||
raise ModelDefinitionError(
|
||||
"In order to encrypt a column 'cryptography' is required!"
|
||||
)
|
||||
backend = backends_map.get(encrypt_backend, encrypt_custom_backend)
|
||||
if not backend or not issubclass(backend, EncryptBackend):
|
||||
raise ModelDefinitionError("Wrong or no encrypt backend provided!")
|
||||
self.backend = backend()
|
||||
self._field_type = _field_type
|
||||
self._underlying_type = _field_type.column_type
|
||||
self._key = encrypt_secret
|
||||
self.max_length = encrypt_max_length
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"String({self.max_length})"
|
||||
#
|
||||
# def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
|
||||
# dialect.type_descriptor(VARCHAR(self.max_length))
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
return self._key
|
||||
|
||||
@key.setter
|
||||
def key(self, value):
|
||||
self._key = value
|
||||
|
||||
def _update_key(self):
|
||||
key = self._key() if callable(self._key) else self._key
|
||||
self.backend._update_key(key)
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
"""Encrypt a value on the way in."""
|
||||
if value is not None:
|
||||
self._update_key()
|
||||
|
||||
try:
|
||||
value = self._underlying_type.process_bind_param(
|
||||
value, dialect
|
||||
)
|
||||
|
||||
except AttributeError:
|
||||
# Doesn't have 'process_bind_param'
|
||||
type_ = self._field_type.__type__
|
||||
if issubclass(type_, bool):
|
||||
value = 'true' if value else 'false'
|
||||
|
||||
elif issubclass(type_, (datetime.date, datetime.time)):
|
||||
value = value.isoformat()
|
||||
|
||||
# elif issubclass(type_, JSONType):
|
||||
# value = json.dumps(value)
|
||||
|
||||
return self.backend.encrypt(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
"""Decrypt value on the way out."""
|
||||
if value is not None:
|
||||
self._update_key()
|
||||
decrypted_value = self.backend.decrypt(value)
|
||||
|
||||
try:
|
||||
return self.underlying_type.process_result_value(
|
||||
decrypted_value, dialect
|
||||
)
|
||||
|
||||
except AttributeError:
|
||||
# Doesn't have 'process_result_value'
|
||||
|
||||
# Handle 'boolean' and 'dates'
|
||||
type_ = self._field_type.__type__
|
||||
# date_types = [datetime.datetime, datetime.time, datetime.date]
|
||||
|
||||
if issubclass(type_, bool):
|
||||
return decrypted_value == 'true'
|
||||
|
||||
# elif type_ in date_types:
|
||||
# return DatetimeHandler.process_value(
|
||||
# decrypted_value, type_
|
||||
# )
|
||||
|
||||
# elif issubclass(type_, JSONType):
|
||||
# return json.loads(decrypted_value)
|
||||
|
||||
# Handle all others
|
||||
return self.underlying_type.python_type(decrypted_value)
|
||||
|
||||
def _coerce(self, value):
|
||||
return self.underlying_type._coerce(value)
|
||||
|
||||
@ -289,7 +289,7 @@ def populate_meta_sqlalchemy_table_if_required(meta: "ModelMeta") -> None:
|
||||
f'{"_".join([str(col) for col in constraint._pending_colargs])}'
|
||||
)
|
||||
table = sqlalchemy.Table(
|
||||
meta.tablename, meta.metadata, *meta.columns, *meta.constraints,
|
||||
meta.tablename, meta.metadata, *meta.columns, *meta.constraints
|
||||
)
|
||||
meta.table = table
|
||||
|
||||
|
||||
60
tests/test_encrypted_columns.py
Normal file
60
tests/test_encrypted_columns.py
Normal file
@ -0,0 +1,60 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import QueryDefinitionError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class BaseMeta(ormar.ModelMeta):
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
|
||||
class Author(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "authors"
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100,
|
||||
encrypt_secret='asd123',
|
||||
encrypt_backend=ormar.EncryptBackends.FERNET)
|
||||
uuid_test = ormar.UUID(default=uuid.uuid4, uuid_format='string')
|
||||
password: str = ormar.String(max_length=100,
|
||||
encrypt_secret='udxc32',
|
||||
encrypt_backend=ormar.EncryptBackends.HASH)
|
||||
birth_year: int = ormar.Integer(nullable=True,
|
||||
encrypt_secret='secure89key%^&psdijfipew',
|
||||
encrypt_max_length=200,
|
||||
encrypt_backend=ormar.EncryptBackends.FERNET)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.drop_all(engine)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
def test_db_structure():
|
||||
assert Author.Meta.table.c.get('name').type.impl.__class__ == sqlalchemy.NVARCHAR
|
||||
assert Author.Meta.table.c.get('birth_year').type.max_length == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_query_foreign_key_type():
|
||||
async with database:
|
||||
await Author(name='Test', birth_year=1988, password='test123').save()
|
||||
author = await Author.objects.get()
|
||||
|
||||
assert author.name == 'Test'
|
||||
assert author.birth_year == 1988
|
||||
Reference in New Issue
Block a user