revert to use tables and columns with labels and aliases instead of text clauses, add encryption, mostly working encryption column type with configurable backends
This commit is contained in:
@ -38,9 +38,12 @@ from ormar.fields import (
|
|||||||
BaseField,
|
BaseField,
|
||||||
BigInteger,
|
BigInteger,
|
||||||
Boolean,
|
Boolean,
|
||||||
|
DECODERS_MAP,
|
||||||
Date,
|
Date,
|
||||||
DateTime,
|
DateTime,
|
||||||
Decimal,
|
Decimal,
|
||||||
|
ENCODERS_MAP,
|
||||||
|
EncryptBackends,
|
||||||
Float,
|
Float,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
ForeignKeyField,
|
ForeignKeyField,
|
||||||
@ -53,7 +56,6 @@ from ormar.fields import (
|
|||||||
Time,
|
Time,
|
||||||
UUID,
|
UUID,
|
||||||
UniqueColumns,
|
UniqueColumns,
|
||||||
EncryptBackends
|
|
||||||
) # noqa: I100
|
) # noqa: I100
|
||||||
from ormar.models import ExcludableItems, Model
|
from ormar.models import ExcludableItems, Model
|
||||||
from ormar.models.metaclass import ModelMeta
|
from ormar.models.metaclass import ModelMeta
|
||||||
@ -111,5 +113,7 @@ __all__ = [
|
|||||||
"ExcludableItems",
|
"ExcludableItems",
|
||||||
"and_",
|
"and_",
|
||||||
"or_",
|
"or_",
|
||||||
"EncryptBackends"
|
"EncryptBackends",
|
||||||
|
"ENCODERS_MAP",
|
||||||
|
"DECODERS_MAP",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -21,8 +21,9 @@ from ormar.fields.model_fields import (
|
|||||||
Time,
|
Time,
|
||||||
UUID,
|
UUID,
|
||||||
)
|
)
|
||||||
from ormar.fields.through_field import Through, ThroughField
|
from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP
|
||||||
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends
|
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends
|
||||||
|
from ormar.fields.through_field import Through, ThroughField
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Decimal",
|
"Decimal",
|
||||||
@ -46,5 +47,7 @@ __all__ = [
|
|||||||
"ThroughField",
|
"ThroughField",
|
||||||
"Through",
|
"Through",
|
||||||
"EncryptBackends",
|
"EncryptBackends",
|
||||||
"EncryptBackend"
|
"EncryptBackend",
|
||||||
|
"DECODERS_MAP",
|
||||||
|
"ENCODERS_MAP",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -6,8 +6,11 @@ from pydantic.fields import FieldInfo, Required, Undefined
|
|||||||
|
|
||||||
import ormar # noqa I101
|
import ormar # noqa I101
|
||||||
from ormar import ModelDefinitionError
|
from ormar import ModelDefinitionError
|
||||||
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends, \
|
from ormar.fields.sqlalchemy_encrypted import (
|
||||||
EncryptedString
|
EncryptBackend,
|
||||||
|
EncryptBackends,
|
||||||
|
EncryptedString,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar.models import Model
|
from ormar.models import Model
|
||||||
@ -54,7 +57,7 @@ class BaseField(FieldInfo):
|
|||||||
|
|
||||||
encrypt_secret: str
|
encrypt_secret: str
|
||||||
encrypt_backend: EncryptBackends = EncryptBackends.NONE
|
encrypt_backend: EncryptBackends = EncryptBackends.NONE
|
||||||
encrypt_custom_backend: Type[EncryptBackend] = None
|
encrypt_custom_backend: Optional[Type[EncryptBackend]] = None
|
||||||
encrypt_max_length: int = 5000
|
encrypt_max_length: int = 5000
|
||||||
|
|
||||||
default: Any
|
default: Any
|
||||||
@ -101,8 +104,7 @@ class BaseField(FieldInfo):
|
|||||||
:rtype: bool
|
:rtype: bool
|
||||||
"""
|
"""
|
||||||
return (
|
return (
|
||||||
field_name not in ["default", "default_factory", "alias",
|
field_name not in ["default", "default_factory", "alias", "allow_mutation"]
|
||||||
"allow_mutation"]
|
|
||||||
and not field_name.startswith("__")
|
and not field_name.startswith("__")
|
||||||
and hasattr(cls, field_name)
|
and hasattr(cls, field_name)
|
||||||
and not callable(getattr(cls, field_name))
|
and not callable(getattr(cls, field_name))
|
||||||
@ -278,9 +280,23 @@ class BaseField(FieldInfo):
|
|||||||
server_default=cls.server_default,
|
server_default=cls.server_default,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
column = cls._get_encrypted_column(name=name)
|
||||||
|
return column
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_encrypted_column(cls, name: str) -> sqlalchemy.Column:
|
||||||
|
"""
|
||||||
|
Returns EncryptedString column type instead of actual column.
|
||||||
|
|
||||||
|
:param name: column name
|
||||||
|
:type name: str
|
||||||
|
:return: newly defined column
|
||||||
|
:rtype: sqlalchemy.Column
|
||||||
|
"""
|
||||||
if cls.primary_key or cls.is_relation:
|
if cls.primary_key or cls.is_relation:
|
||||||
raise ModelDefinitionError("Primary key field and relations fields"
|
raise ModelDefinitionError(
|
||||||
"cannot be encrypted!")
|
"Primary key field and relations fields" "cannot be encrypted!"
|
||||||
|
)
|
||||||
column = sqlalchemy.Column(
|
column = sqlalchemy.Column(
|
||||||
cls.alias or name,
|
cls.alias or name,
|
||||||
EncryptedString(
|
EncryptedString(
|
||||||
@ -288,7 +304,7 @@ class BaseField(FieldInfo):
|
|||||||
encrypt_secret=cls.encrypt_secret,
|
encrypt_secret=cls.encrypt_secret,
|
||||||
encrypt_backend=cls.encrypt_backend,
|
encrypt_backend=cls.encrypt_backend,
|
||||||
encrypt_custom_backend=cls.encrypt_custom_backend,
|
encrypt_custom_backend=cls.encrypt_custom_backend,
|
||||||
encrypt_max_length=cls.encrypt_max_length
|
encrypt_max_length=cls.encrypt_max_length,
|
||||||
),
|
),
|
||||||
nullable=cls.nullable,
|
nullable=cls.nullable,
|
||||||
index=cls.index,
|
index=cls.index,
|
||||||
@ -296,7 +312,6 @@ class BaseField(FieldInfo):
|
|||||||
default=cls.default,
|
default=cls.default,
|
||||||
server_default=cls.server_default,
|
server_default=cls.server_default,
|
||||||
)
|
)
|
||||||
|
|
||||||
return column
|
return column
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -184,10 +184,25 @@ def ForeignKey( # noqa CFQ002
|
|||||||
|
|
||||||
owner = kwargs.pop("owner", None)
|
owner = kwargs.pop("owner", None)
|
||||||
self_reference = kwargs.pop("self_reference", False)
|
self_reference = kwargs.pop("self_reference", False)
|
||||||
|
|
||||||
default = kwargs.pop("default", None)
|
default = kwargs.pop("default", None)
|
||||||
if default is not None:
|
encrypt_secret = kwargs.pop("encrypt_secret", None)
|
||||||
|
encrypt_backend = kwargs.pop("encrypt_backend", None)
|
||||||
|
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
|
||||||
|
encrypt_max_length = kwargs.pop("encrypt_max_length", None)
|
||||||
|
|
||||||
|
not_supported = [
|
||||||
|
default,
|
||||||
|
encrypt_secret,
|
||||||
|
encrypt_backend,
|
||||||
|
encrypt_custom_backend,
|
||||||
|
encrypt_max_length,
|
||||||
|
]
|
||||||
|
if any(x is not None for x in not_supported):
|
||||||
raise ModelDefinitionError(
|
raise ModelDefinitionError(
|
||||||
"Argument 'default' is not supported " "on relation fields!"
|
f"Argument {next((x for x in not_supported if x is not None))} "
|
||||||
|
f"is not supported "
|
||||||
|
"on relation fields!"
|
||||||
)
|
)
|
||||||
|
|
||||||
if to.__class__ == ForwardRef:
|
if to.__class__ == ForwardRef:
|
||||||
@ -386,8 +401,6 @@ class ForeignKeyField(BaseField):
|
|||||||
:return: (if needed) registered Model
|
:return: (if needed) registered Model
|
||||||
:rtype: Model
|
:rtype: Model
|
||||||
"""
|
"""
|
||||||
if cls.to.pk_type() == uuid.UUID and isinstance(value, str):
|
|
||||||
value = uuid.UUID(value)
|
|
||||||
if not isinstance(value, cls.to.pk_type()):
|
if not isinstance(value, cls.to.pk_type()):
|
||||||
raise RelationshipInstanceError(
|
raise RelationshipInstanceError(
|
||||||
f"Relationship error - ForeignKey {cls.to.__name__} "
|
f"Relationship error - ForeignKey {cls.to.__name__} "
|
||||||
|
|||||||
@ -97,9 +97,23 @@ def ManyToMany(
|
|||||||
forbid_through_relations(cast(Type["Model"], through))
|
forbid_through_relations(cast(Type["Model"], through))
|
||||||
|
|
||||||
default = kwargs.pop("default", None)
|
default = kwargs.pop("default", None)
|
||||||
if default is not None:
|
encrypt_secret = kwargs.pop("encrypt_secret", None)
|
||||||
|
encrypt_backend = kwargs.pop("encrypt_backend", None)
|
||||||
|
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
|
||||||
|
encrypt_max_length = kwargs.pop("encrypt_max_length", None)
|
||||||
|
|
||||||
|
not_supported = [
|
||||||
|
default,
|
||||||
|
encrypt_secret,
|
||||||
|
encrypt_backend,
|
||||||
|
encrypt_custom_backend,
|
||||||
|
encrypt_max_length,
|
||||||
|
]
|
||||||
|
if any(x is not None for x in not_supported):
|
||||||
raise ModelDefinitionError(
|
raise ModelDefinitionError(
|
||||||
"Argument 'default' is not supported " "on relation fields!"
|
f"Argument {next((x for x in not_supported if x is not None))} "
|
||||||
|
f"is not supported "
|
||||||
|
"on relation fields!"
|
||||||
)
|
)
|
||||||
|
|
||||||
if to.__class__ == ForwardRef:
|
if to.__class__ == ForwardRef:
|
||||||
|
|||||||
@ -76,8 +76,7 @@ class ModelFieldFactory:
|
|||||||
|
|
||||||
encrypt_secret = kwargs.pop("encrypt_secret", None)
|
encrypt_secret = kwargs.pop("encrypt_secret", None)
|
||||||
encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE)
|
encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE)
|
||||||
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend",
|
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
|
||||||
None)
|
|
||||||
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
|
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
|
||||||
|
|
||||||
namespace = dict(
|
namespace = dict(
|
||||||
|
|||||||
44
ormar/fields/parsers.py
Normal file
44
ormar/fields/parsers.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import datetime
|
||||||
|
import decimal
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
from pydantic.datetime_parse import parse_date, parse_datetime, parse_time
|
||||||
|
|
||||||
|
try:
|
||||||
|
import orjson as json
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
import json # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def parse_bool(value: str) -> bool:
|
||||||
|
return value == "true"
|
||||||
|
|
||||||
|
|
||||||
|
def encode_bool(value: bool) -> str:
|
||||||
|
return "true" if value else "false"
|
||||||
|
|
||||||
|
|
||||||
|
def encode_json(value: Any) -> str:
|
||||||
|
value = json.dumps(value) if not isinstance(value, str) else value
|
||||||
|
value = value.decode("utf-8") if isinstance(value, bytes) else value
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
ENCODERS_MAP = {
|
||||||
|
bool: encode_bool,
|
||||||
|
datetime.datetime: lambda x: x.isoformat(),
|
||||||
|
datetime.date: lambda x: x.isoformat(),
|
||||||
|
datetime.time: lambda x: x.isoformat(),
|
||||||
|
pydantic.Json: encode_json,
|
||||||
|
decimal.Decimal: float,
|
||||||
|
}
|
||||||
|
|
||||||
|
DECODERS_MAP = {
|
||||||
|
bool: parse_bool,
|
||||||
|
datetime.datetime: parse_datetime,
|
||||||
|
datetime.date: parse_date,
|
||||||
|
datetime.time: parse_time,
|
||||||
|
pydantic.Json: json.loads,
|
||||||
|
decimal.Decimal: decimal.Decimal,
|
||||||
|
}
|
||||||
@ -1,50 +1,49 @@
|
|||||||
# inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils)
|
# inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils)
|
||||||
import abc
|
import abc
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, TYPE_CHECKING, Type, Union
|
from typing import Any, Callable, Optional, TYPE_CHECKING, Type, Union
|
||||||
|
|
||||||
import sqlalchemy.types as types
|
import sqlalchemy.types as types
|
||||||
|
from pydantic.utils import lenient_issubclass
|
||||||
from sqlalchemy.engine.default import DefaultDialect
|
from sqlalchemy.engine.default import DefaultDialect
|
||||||
|
|
||||||
from ormar import ModelDefinitionError
|
import ormar # noqa: I100, I202
|
||||||
|
from ormar import ModelDefinitionError # noqa: I202, I100
|
||||||
|
|
||||||
cryptography = None
|
cryptography = None
|
||||||
try:
|
try: # pragma: nocover
|
||||||
import cryptography
|
import cryptography # type: ignore
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives import hashes
|
from cryptography.hazmat.primitives import hashes
|
||||||
except ImportError:
|
except ImportError: # pragma: nocover
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING: # pragma: nocover
|
||||||
from ormar import BaseField
|
from ormar import BaseField
|
||||||
|
|
||||||
|
|
||||||
class EncryptBackend(abc.ABC):
|
class EncryptBackend(abc.ABC):
|
||||||
|
def _refresh(self, key: Union[str, bytes]) -> None:
|
||||||
def _update_key(self, key):
|
|
||||||
if isinstance(key, str):
|
if isinstance(key, str):
|
||||||
key = key.encode()
|
key = key.encode()
|
||||||
digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
|
digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
|
||||||
digest.update(key)
|
digest.update(key)
|
||||||
engine_key = digest.finalize()
|
engine_key = digest.finalize()
|
||||||
|
|
||||||
self._initialize_engine(engine_key)
|
self._initialize_backend(engine_key)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _initialize_engine(self, secret_key: bytes):
|
def _initialize_backend(self, secret_key: bytes) -> None: # pragma: nocover
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def encrypt(self, value: Any) -> str:
|
def encrypt(self, value: Any) -> str: # pragma: nocover
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def decrypt(self, value: Any) -> str:
|
def decrypt(self, value: Any) -> str: # pragma: nocover
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -53,11 +52,11 @@ class HashBackend(EncryptBackend):
|
|||||||
One-way hashing - in example for passwords, no way to decrypt the value!
|
One-way hashing - in example for passwords, no way to decrypt the value!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _initialize_engine(self, secret_key: bytes):
|
def _initialize_backend(self, secret_key: bytes) -> None:
|
||||||
self.secret_key = base64.urlsafe_b64encode(secret_key)
|
self.secret_key = base64.urlsafe_b64encode(secret_key)
|
||||||
|
|
||||||
def encrypt(self, value: Any) -> str:
|
def encrypt(self, value: Any) -> str:
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str): # pragma: nocover
|
||||||
value = repr(value)
|
value = repr(value)
|
||||||
value = value.encode()
|
value = value.encode()
|
||||||
digest = hashes.Hash(hashes.SHA512(), backend=default_backend())
|
digest = hashes.Hash(hashes.SHA512(), backend=default_backend())
|
||||||
@ -67,7 +66,7 @@ class HashBackend(EncryptBackend):
|
|||||||
return hashed_value.hex()
|
return hashed_value.hex()
|
||||||
|
|
||||||
def decrypt(self, value: Any) -> str:
|
def decrypt(self, value: Any) -> str:
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str): # pragma: nocover
|
||||||
value = str(value)
|
value = str(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -77,7 +76,7 @@ class FernetBackend(EncryptBackend):
|
|||||||
Two-way encryption, data stored in db are encrypted but decrypted during query.
|
Two-way encryption, data stored in db are encrypted but decrypted during query.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _initialize_engine(self, secret_key: bytes):
|
def _initialize_backend(self, secret_key: bytes) -> None:
|
||||||
self.secret_key = base64.urlsafe_b64encode(secret_key)
|
self.secret_key = base64.urlsafe_b64encode(secret_key)
|
||||||
self.fernet = Fernet(self.secret_key)
|
self.fernet = Fernet(self.secret_key)
|
||||||
|
|
||||||
@ -86,14 +85,14 @@ class FernetBackend(EncryptBackend):
|
|||||||
value = repr(value)
|
value = repr(value)
|
||||||
value = value.encode()
|
value = value.encode()
|
||||||
encrypted = self.fernet.encrypt(value)
|
encrypted = self.fernet.encrypt(value)
|
||||||
return encrypted.decode('utf-8')
|
return encrypted.decode("utf-8")
|
||||||
|
|
||||||
def decrypt(self, value: Any) -> str:
|
def decrypt(self, value: Any) -> str:
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str): # pragma: nocover
|
||||||
value = str(value)
|
value = str(value)
|
||||||
decrypted = self.fernet.decrypt(value.encode())
|
decrypted: Union[str, bytes] = self.fernet.decrypt(value.encode())
|
||||||
if not isinstance(decrypted, str):
|
if not isinstance(decrypted, str):
|
||||||
decrypted = decrypted.decode('utf-8')
|
decrypted = decrypted.decode("utf-8")
|
||||||
return decrypted
|
return decrypted
|
||||||
|
|
||||||
|
|
||||||
@ -104,115 +103,82 @@ class EncryptBackends(Enum):
|
|||||||
CUSTOM = 3
|
CUSTOM = 3
|
||||||
|
|
||||||
|
|
||||||
backends_map = {
|
BACKENDS_MAP = {
|
||||||
EncryptBackends.FERNET: FernetBackend,
|
EncryptBackends.FERNET: FernetBackend,
|
||||||
EncryptBackends.HASH: HashBackend,
|
EncryptBackends.HASH: HashBackend,
|
||||||
EncryptBackends.CUSTOM: None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class EncryptedString(types.TypeDecorator): # pragma nocover
|
class EncryptedString(types.TypeDecorator):
|
||||||
"""
|
"""
|
||||||
Used to store encrypted values in a database
|
Used to store encrypted values in a database
|
||||||
"""
|
"""
|
||||||
|
|
||||||
impl = types.TypeEngine
|
impl = types.TypeEngine
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
*args: Any,
|
self,
|
||||||
encrypt_secret: Union[str, Callable],
|
encrypt_secret: Union[str, Callable],
|
||||||
_field_type: Type["BaseField"],
|
|
||||||
encrypt_max_length: int = 5000,
|
|
||||||
encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
|
encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
|
||||||
encrypt_custom_backend: Type[EncryptBackend] = None,
|
encrypt_custom_backend: Type[EncryptBackend] = None,
|
||||||
**kwargs: Any) -> None:
|
**kwargs: Any,
|
||||||
super().__init__(*args, **kwargs)
|
) -> None:
|
||||||
if not cryptography:
|
_field_type = kwargs.pop("_field_type")
|
||||||
|
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
|
||||||
|
super().__init__()
|
||||||
|
if not cryptography: # pragma: nocover
|
||||||
raise ModelDefinitionError(
|
raise ModelDefinitionError(
|
||||||
"In order to encrypt a column 'cryptography' is required!"
|
"In order to encrypt a column 'cryptography' is required!"
|
||||||
)
|
)
|
||||||
backend = backends_map.get(encrypt_backend, encrypt_custom_backend)
|
backend = BACKENDS_MAP.get(encrypt_backend, encrypt_custom_backend)
|
||||||
if not backend or not issubclass(backend, EncryptBackend):
|
if not backend or not lenient_issubclass(backend, EncryptBackend):
|
||||||
raise ModelDefinitionError("Wrong or no encrypt backend provided!")
|
raise ModelDefinitionError("Wrong or no encrypt backend provided!")
|
||||||
self.backend = backend()
|
|
||||||
self._field_type = _field_type
|
|
||||||
self._underlying_type = _field_type.column_type
|
|
||||||
self._key = encrypt_secret
|
|
||||||
self.max_length = encrypt_max_length
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
self.backend: EncryptBackend = backend()
|
||||||
|
self._field_type: Type["BaseField"] = _field_type
|
||||||
|
self._underlying_type: Any = _field_type.column_type
|
||||||
|
self._key: Union[str, Callable] = encrypt_secret
|
||||||
|
self.max_length: int = encrypt_max_length
|
||||||
|
type_ = self._field_type.__type__
|
||||||
|
if type_ is None: # pragma: nocover
|
||||||
|
raise ModelDefinitionError(
|
||||||
|
f"Improperly configured field " f"{self._field_type.name}"
|
||||||
|
)
|
||||||
|
self.type_: Any = type_
|
||||||
|
|
||||||
|
def __repr__(self) -> str: # pragma: nocover
|
||||||
return f"VARCHAR({self.max_length})"
|
return f"VARCHAR({self.max_length})"
|
||||||
|
|
||||||
def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
|
def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
|
||||||
return dialect.type_descriptor(types.VARCHAR(self.max_length))
|
return dialect.type_descriptor(types.VARCHAR(self.max_length))
|
||||||
|
|
||||||
@property
|
def _refresh(self) -> None:
|
||||||
def key(self):
|
|
||||||
return self._key
|
|
||||||
|
|
||||||
@key.setter
|
|
||||||
def key(self, value):
|
|
||||||
self._key = value
|
|
||||||
|
|
||||||
def _update_key(self):
|
|
||||||
key = self._key() if callable(self._key) else self._key
|
key = self._key() if callable(self._key) else self._key
|
||||||
self.backend._update_key(key)
|
self.backend._refresh(key)
|
||||||
|
|
||||||
def process_bind_param(self, value, dialect):
|
|
||||||
"""Encrypt a value on the way in."""
|
|
||||||
if value is not None:
|
|
||||||
self._update_key()
|
|
||||||
|
|
||||||
|
def process_bind_param(self, value: Any, dialect: DefaultDialect) -> Optional[str]:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
self._refresh()
|
||||||
try:
|
try:
|
||||||
value = self._underlying_type.process_bind_param(
|
value = self._underlying_type.process_bind_param(value, dialect)
|
||||||
value, dialect
|
|
||||||
)
|
|
||||||
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# Doesn't have 'process_bind_param'
|
encoder = ormar.ENCODERS_MAP.get(self.type_, None)
|
||||||
type_ = self._field_type.__type__
|
if encoder:
|
||||||
if issubclass(type_, bool):
|
value = encoder(value) # type: ignore
|
||||||
value = 'true' if value else 'false'
|
|
||||||
|
|
||||||
elif issubclass(type_, (datetime.date, datetime.time)):
|
|
||||||
value = value.isoformat()
|
|
||||||
|
|
||||||
# elif issubclass(type_, JSONType):
|
|
||||||
# value = json.dumps(value)
|
|
||||||
|
|
||||||
return self.backend.encrypt(value)
|
return self.backend.encrypt(value)
|
||||||
|
|
||||||
def process_result_value(self, value, dialect):
|
def process_result_value(self, value: Any, dialect: DefaultDialect) -> Any:
|
||||||
"""Decrypt value on the way out."""
|
if value is None:
|
||||||
if value is not None:
|
return value
|
||||||
self._update_key()
|
self._refresh()
|
||||||
decrypted_value = self.backend.decrypt(value)
|
decrypted_value = self.backend.decrypt(value)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self.underlying_type.process_result_value(
|
return self._underlying_type.process_result_value(decrypted_value, dialect)
|
||||||
decrypted_value, dialect
|
|
||||||
)
|
|
||||||
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# Doesn't have 'process_result_value'
|
decoder = ormar.DECODERS_MAP.get(self.type_, None)
|
||||||
|
if decoder:
|
||||||
|
return decoder(decrypted_value) # type: ignore
|
||||||
|
|
||||||
# Handle 'boolean' and 'dates'
|
return self._field_type.__type__(decrypted_value) # type: ignore
|
||||||
type_ = self._field_type.__type__
|
|
||||||
# date_types = [datetime.datetime, datetime.time, datetime.date]
|
|
||||||
|
|
||||||
if issubclass(type_, bool):
|
|
||||||
return decrypted_value == 'true'
|
|
||||||
|
|
||||||
# elif type_ in date_types:
|
|
||||||
# return DatetimeHandler.process_value(
|
|
||||||
# decrypted_value, type_
|
|
||||||
# )
|
|
||||||
|
|
||||||
# elif issubclass(type_, JSONType):
|
|
||||||
# return json.loads(decrypted_value)
|
|
||||||
|
|
||||||
# Handle all others
|
|
||||||
return self.underlying_type.python_type(decrypted_value)
|
|
||||||
|
|
||||||
def _coerce(self, value):
|
|
||||||
return self.underlying_type._coerce(value)
|
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional
|
||||||
|
|
||||||
from sqlalchemy import CHAR
|
from sqlalchemy import CHAR
|
||||||
from sqlalchemy.engine.default import DefaultDialect
|
from sqlalchemy.engine.default import DefaultDialect
|
||||||
from sqlalchemy.types import TypeDecorator
|
from sqlalchemy.types import TypeDecorator
|
||||||
|
|
||||||
|
|
||||||
class UUID(TypeDecorator): # pragma nocover
|
class UUID(TypeDecorator):
|
||||||
"""
|
"""
|
||||||
Platform-independent GUID type.
|
Platform-independent GUID type.
|
||||||
Uses CHAR(36) if in a string mode, otherwise uses CHAR(32), to store UUID.
|
Uses CHAR(36) if in a string mode, otherwise uses CHAR(32), to store UUID.
|
||||||
@ -20,31 +20,11 @@ class UUID(TypeDecorator): # pragma nocover
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.uuid_format = uuid_format
|
self.uuid_format = uuid_format
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str: # pragma: nocover
|
||||||
if self.uuid_format == "string":
|
if self.uuid_format == "string":
|
||||||
return "CHAR(36)"
|
return "CHAR(36)"
|
||||||
return "CHAR(32)"
|
return "CHAR(32)"
|
||||||
|
|
||||||
def _cast_to_uuid(self, value: Union[str, int, bytes]) -> uuid.UUID:
|
|
||||||
"""
|
|
||||||
Parses given value into uuid.UUID field.
|
|
||||||
|
|
||||||
:param value: value to be parsed
|
|
||||||
:type value: Union[str, int, bytes]
|
|
||||||
:return: initialized uuid
|
|
||||||
:rtype: uuid.UUID
|
|
||||||
"""
|
|
||||||
if not isinstance(value, uuid.UUID):
|
|
||||||
if isinstance(value, bytes):
|
|
||||||
ret_value = uuid.UUID(bytes=value)
|
|
||||||
elif isinstance(value, int):
|
|
||||||
ret_value = uuid.UUID(int=value)
|
|
||||||
elif isinstance(value, str):
|
|
||||||
ret_value = uuid.UUID(value)
|
|
||||||
else:
|
|
||||||
ret_value = value
|
|
||||||
return ret_value
|
|
||||||
|
|
||||||
def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
|
def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
|
||||||
return (
|
return (
|
||||||
dialect.type_descriptor(CHAR(36))
|
dialect.type_descriptor(CHAR(36))
|
||||||
@ -53,12 +33,10 @@ class UUID(TypeDecorator): # pragma nocover
|
|||||||
)
|
)
|
||||||
|
|
||||||
def process_bind_param(
|
def process_bind_param(
|
||||||
self, value: Union[str, int, bytes, uuid.UUID, None], dialect: DefaultDialect
|
self, value: uuid.UUID, dialect: DefaultDialect
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
if not isinstance(value, uuid.UUID):
|
|
||||||
value = self._cast_to_uuid(value)
|
|
||||||
return str(value) if self.uuid_format == "string" else "%.32x" % value.int
|
return str(value) if self.uuid_format == "string" else "%.32x" % value.int
|
||||||
|
|
||||||
def process_result_value(
|
def process_result_value(
|
||||||
@ -68,4 +46,4 @@ class UUID(TypeDecorator): # pragma nocover
|
|||||||
return value
|
return value
|
||||||
if not isinstance(value, uuid.UUID):
|
if not isinstance(value, uuid.UUID):
|
||||||
return uuid.UUID(value)
|
return uuid.UUID(value)
|
||||||
return value
|
return value # pragma: nocover
|
||||||
|
|||||||
@ -53,6 +53,7 @@ def convert_choices_if_needed( # noqa: CCR001
|
|||||||
:return: value, choices list
|
:return: value, choices list
|
||||||
:rtype: Tuple[Any, List]
|
:rtype: Tuple[Any, List]
|
||||||
"""
|
"""
|
||||||
|
# TODO use same maps as with EncryptedString
|
||||||
choices = [o.value if isinstance(o, Enum) else o for o in field.choices]
|
choices = [o.value if isinstance(o, Enum) else o for o in field.choices]
|
||||||
|
|
||||||
if field.__type__ in [datetime.datetime, datetime.date, datetime.time]:
|
if field.__type__ in [datetime.datetime, datetime.date, datetime.time]:
|
||||||
|
|||||||
@ -88,13 +88,13 @@ class SqlJoin:
|
|||||||
return self.main_model.Meta.alias_manager
|
return self.main_model.Meta.alias_manager
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def to_table(self) -> str:
|
def to_table(self) -> sqlalchemy.Table:
|
||||||
"""
|
"""
|
||||||
Shortcut to table name of the next model
|
Shortcut to table name of the next model
|
||||||
:return: name of the target table
|
:return: name of the target table
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
return self.next_model.Meta.table.name
|
return self.next_model.Meta.table
|
||||||
|
|
||||||
def _on_clause(
|
def _on_clause(
|
||||||
self, previous_alias: str, from_clause: str, to_clause: str,
|
self, previous_alias: str, from_clause: str, to_clause: str,
|
||||||
@ -282,7 +282,7 @@ class SqlJoin:
|
|||||||
on_clause = self._on_clause(
|
on_clause = self._on_clause(
|
||||||
previous_alias=self.own_alias,
|
previous_alias=self.own_alias,
|
||||||
from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}",
|
from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}",
|
||||||
to_clause=f"{self.to_table}.{to_key}",
|
to_clause=f"{self.to_table.name}.{to_key}",
|
||||||
)
|
)
|
||||||
target_table = self.alias_manager.prefixed_table_name(
|
target_table = self.alias_manager.prefixed_table_name(
|
||||||
self.next_alias, self.to_table
|
self.next_alias, self.to_table
|
||||||
@ -301,7 +301,7 @@ class SqlJoin:
|
|||||||
)
|
)
|
||||||
self.columns.extend(
|
self.columns.extend(
|
||||||
self.alias_manager.prefixed_columns(
|
self.alias_manager.prefixed_columns(
|
||||||
self.next_alias, self.next_model.Meta.table, self_related_fields
|
self.next_alias, target_table, self_related_fields
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.used_aliases.append(self.next_alias)
|
self.used_aliases.append(self.next_alias)
|
||||||
|
|||||||
@ -67,24 +67,21 @@ class AliasManager:
|
|||||||
if not fields
|
if not fields
|
||||||
else [col for col in table.columns if col.name in fields]
|
else [col for col in table.columns if col.name in fields]
|
||||||
)
|
)
|
||||||
return [
|
return [column.label(f"{alias}{column.name}") for column in all_columns]
|
||||||
text(f"{alias}{table.name}.{column.name} as {alias}{column.name}")
|
|
||||||
for column in all_columns
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prefixed_table_name(alias: str, name: str) -> text:
|
def prefixed_table_name(alias: str, table: sqlalchemy.Table) -> text:
|
||||||
"""
|
"""
|
||||||
Creates text clause with table name with aliased name.
|
Creates text clause with table name with aliased name.
|
||||||
|
|
||||||
:param alias: alias of given table
|
:param alias: alias of given table
|
||||||
:type alias: str
|
:type alias: str
|
||||||
:param name: table name
|
:param table: table
|
||||||
:type name: str
|
:type table: sqlalchemy.Table
|
||||||
:return: sqlalchemy text clause as "table_name aliased_name"
|
:return: sqlalchemy text clause as "table_name aliased_name"
|
||||||
:rtype: sqlalchemy text clause
|
:rtype: sqlalchemy text clause
|
||||||
"""
|
"""
|
||||||
return text(f"{name} {alias}_{name}")
|
return table.alias(f"{alias}_{table.name}")
|
||||||
|
|
||||||
def add_relation_type(
|
def add_relation_type(
|
||||||
self, source_model: Type["Model"], relation_name: str, reverse_name: str = None,
|
self, source_model: Type["Model"], relation_name: str, reverse_name: str = None,
|
||||||
|
|||||||
@ -1,12 +1,16 @@
|
|||||||
|
# type: ignore
|
||||||
|
import decimal
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar.exceptions import QueryDefinitionError
|
from ormar import ModelDefinitionError
|
||||||
|
from ormar.fields.sqlalchemy_encrypted import EncryptedString
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
database = databases.Database(DATABASE_URL)
|
database = databases.Database(DATABASE_URL)
|
||||||
@ -18,22 +22,58 @@ class BaseMeta(ormar.ModelMeta):
|
|||||||
database = database
|
database = database
|
||||||
|
|
||||||
|
|
||||||
|
default_fernet = dict(
|
||||||
|
encrypt_secret="asd123", encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyBackend(ormar.fields.EncryptBackend):
|
||||||
|
def _initialize_backend(self, secret_key: bytes) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def encrypt(self, value: Any) -> str:
|
||||||
|
return value
|
||||||
|
|
||||||
|
def decrypt(self, value: Any) -> str:
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class Author(ormar.Model):
|
class Author(ormar.Model):
|
||||||
class Meta(BaseMeta):
|
class Meta(BaseMeta):
|
||||||
tablename = "authors"
|
tablename = "authors"
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
id: int = ormar.Integer(primary_key=True)
|
||||||
name: str = ormar.String(max_length=100,
|
name: str = ormar.String(max_length=100, **default_fernet)
|
||||||
encrypt_secret='asd123',
|
uuid_test = ormar.UUID(default=uuid.uuid4, uuid_format="string")
|
||||||
encrypt_backend=ormar.EncryptBackends.FERNET)
|
uuid_test2 = ormar.UUID(nullable=True, uuid_format="string")
|
||||||
uuid_test = ormar.UUID(default=uuid.uuid4, uuid_format='string')
|
password: str = ormar.String(
|
||||||
password: str = ormar.String(max_length=100,
|
max_length=128,
|
||||||
encrypt_secret='udxc32',
|
encrypt_secret="udxc32",
|
||||||
encrypt_backend=ormar.EncryptBackends.HASH)
|
encrypt_backend=ormar.EncryptBackends.HASH,
|
||||||
birth_year: int = ormar.Integer(nullable=True,
|
)
|
||||||
encrypt_secret='secure89key%^&psdijfipew',
|
birth_year: int = ormar.Integer(
|
||||||
|
nullable=True,
|
||||||
|
encrypt_secret="secure89key%^&psdijfipew",
|
||||||
encrypt_max_length=200,
|
encrypt_max_length=200,
|
||||||
encrypt_backend=ormar.EncryptBackends.FERNET)
|
encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||||
|
)
|
||||||
|
test_text: str = ormar.Text(default="", **default_fernet)
|
||||||
|
test_bool: bool = ormar.Boolean(nullable=False, **default_fernet)
|
||||||
|
test_float: float = ormar.Float(**default_fernet)
|
||||||
|
test_float2: float = ormar.Float(nullable=True, **default_fernet)
|
||||||
|
test_datetime = ormar.DateTime(default=datetime.datetime.now, **default_fernet)
|
||||||
|
test_date = ormar.Date(default=datetime.date.today, **default_fernet)
|
||||||
|
test_time = ormar.Time(default=datetime.time, **default_fernet)
|
||||||
|
test_json = ormar.JSON(default={}, **default_fernet)
|
||||||
|
test_bigint: int = ormar.BigInteger(default=0, **default_fernet)
|
||||||
|
test_decimal = ormar.Decimal(scale=2, precision=10, **default_fernet)
|
||||||
|
test_decimal2 = ormar.Decimal(max_digits=10, decimal_places=2, **default_fernet)
|
||||||
|
custom_backend: str = ormar.String(
|
||||||
|
max_length=200,
|
||||||
|
encrypt_secret="asda8",
|
||||||
|
encrypt_backend=ormar.EncryptBackends.CUSTOM,
|
||||||
|
encrypt_custom_backend=DummyBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True, scope="module")
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
@ -45,16 +85,104 @@ def create_test_database():
|
|||||||
metadata.drop_all(engine)
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_on_encrypted_pk():
|
||||||
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
|
class Wrong(ormar.Model):
|
||||||
|
class Meta(BaseMeta):
|
||||||
|
tablename = "wrongs"
|
||||||
|
|
||||||
|
id: int = ormar.Integer(
|
||||||
|
primary_key=True,
|
||||||
|
encrypt_secret="asd123",
|
||||||
|
encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_on_encrypted_relation():
|
||||||
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
|
class Wrong2(ormar.Model):
|
||||||
|
class Meta(BaseMeta):
|
||||||
|
tablename = "wrongs2"
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
author = ormar.ForeignKey(
|
||||||
|
Author,
|
||||||
|
encrypt_secret="asd123",
|
||||||
|
encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_on_encrypted_m2m_relation():
|
||||||
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
|
class Wrong3(ormar.Model):
|
||||||
|
class Meta(BaseMeta):
|
||||||
|
tablename = "wrongs3"
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
author = ormar.ManyToMany(
|
||||||
|
Author,
|
||||||
|
encrypt_secret="asd123",
|
||||||
|
encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_backend():
|
||||||
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
|
||||||
|
class Wrong3(ormar.Model):
|
||||||
|
class Meta(BaseMeta):
|
||||||
|
tablename = "wrongs3"
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
author = ormar.Integer(
|
||||||
|
encrypt_secret="asd123",
|
||||||
|
encrypt_backend=ormar.EncryptBackends.CUSTOM,
|
||||||
|
encrypt_custom_backend="aa",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_db_structure():
|
def test_db_structure():
|
||||||
assert Author.Meta.table.c.get('name').type.impl.__class__ == sqlalchemy.NVARCHAR
|
assert Author.Meta.table.c.get("name").type.__class__ == EncryptedString
|
||||||
assert Author.Meta.table.c.get('birth_year').type.max_length == 200
|
assert Author.Meta.table.c.get("birth_year").type.max_length == 200
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_wrong_query_foreign_key_type():
|
async def test_save_and_retrieve():
|
||||||
async with database:
|
async with database:
|
||||||
await Author(name='Test', birth_year=1988, password='test123').save()
|
test_uuid = uuid.uuid4()
|
||||||
|
await Author(
|
||||||
|
name="Test",
|
||||||
|
birth_year=1988,
|
||||||
|
password="test123",
|
||||||
|
uuid_test=test_uuid,
|
||||||
|
test_float=1.2,
|
||||||
|
test_bool=True,
|
||||||
|
test_decimal=decimal.Decimal(3.5),
|
||||||
|
test_decimal2=decimal.Decimal(5.5),
|
||||||
|
test_json=dict(aa=12),
|
||||||
|
custom_backend="test12",
|
||||||
|
).save()
|
||||||
author = await Author.objects.get()
|
author = await Author.objects.get()
|
||||||
|
|
||||||
assert author.name == 'Test'
|
assert author.name == "Test"
|
||||||
assert author.birth_year == 1988
|
assert author.birth_year == 1988
|
||||||
|
password = (
|
||||||
|
"03e4a4d513e99cb3fe4ee3db282c053daa3f3572b849c3868939a306944ad5c08"
|
||||||
|
"22b50d4886e10f4cd418c3f2df3ceb02e2e7ac6e920ae0c90f2dedfc8fa16e2"
|
||||||
|
)
|
||||||
|
assert author.password == password
|
||||||
|
assert author.uuid_test == test_uuid
|
||||||
|
assert author.uuid_test2 is None
|
||||||
|
assert author.test_datetime.date() == datetime.date.today()
|
||||||
|
assert author.test_date == datetime.date.today()
|
||||||
|
assert author.test_text == ""
|
||||||
|
assert author.test_float == 1.2
|
||||||
|
assert author.test_float2 is None
|
||||||
|
assert author.test_bigint == 0
|
||||||
|
assert author.test_json == {"aa": 12}
|
||||||
|
assert author.test_decimal == 3.5
|
||||||
|
assert author.test_decimal2 == 5.5
|
||||||
|
assert author.custom_backend == "test12"
|
||||||
|
|||||||
Reference in New Issue
Block a user