revert to use tables and columns with labels and aliases instead of text clauses, add encryption, mostly working encryption column type with configurable backends
This commit is contained in:
@ -38,9 +38,12 @@ from ormar.fields import (
|
||||
BaseField,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
DECODERS_MAP,
|
||||
Date,
|
||||
DateTime,
|
||||
Decimal,
|
||||
ENCODERS_MAP,
|
||||
EncryptBackends,
|
||||
Float,
|
||||
ForeignKey,
|
||||
ForeignKeyField,
|
||||
@ -53,7 +56,6 @@ from ormar.fields import (
|
||||
Time,
|
||||
UUID,
|
||||
UniqueColumns,
|
||||
EncryptBackends
|
||||
) # noqa: I100
|
||||
from ormar.models import ExcludableItems, Model
|
||||
from ormar.models.metaclass import ModelMeta
|
||||
@ -111,5 +113,7 @@ __all__ = [
|
||||
"ExcludableItems",
|
||||
"and_",
|
||||
"or_",
|
||||
"EncryptBackends"
|
||||
"EncryptBackends",
|
||||
"ENCODERS_MAP",
|
||||
"DECODERS_MAP",
|
||||
]
|
||||
|
||||
@ -21,8 +21,9 @@ from ormar.fields.model_fields import (
|
||||
Time,
|
||||
UUID,
|
||||
)
|
||||
from ormar.fields.through_field import Through, ThroughField
|
||||
from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP
|
||||
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends
|
||||
from ormar.fields.through_field import Through, ThroughField
|
||||
|
||||
__all__ = [
|
||||
"Decimal",
|
||||
@ -46,5 +47,7 @@ __all__ = [
|
||||
"ThroughField",
|
||||
"Through",
|
||||
"EncryptBackends",
|
||||
"EncryptBackend"
|
||||
"EncryptBackend",
|
||||
"DECODERS_MAP",
|
||||
"ENCODERS_MAP",
|
||||
]
|
||||
|
||||
@ -6,8 +6,11 @@ from pydantic.fields import FieldInfo, Required, Undefined
|
||||
|
||||
import ormar # noqa I101
|
||||
from ormar import ModelDefinitionError
|
||||
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends, \
|
||||
EncryptedString
|
||||
from ormar.fields.sqlalchemy_encrypted import (
|
||||
EncryptBackend,
|
||||
EncryptBackends,
|
||||
EncryptedString,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.models import Model
|
||||
@ -54,7 +57,7 @@ class BaseField(FieldInfo):
|
||||
|
||||
encrypt_secret: str
|
||||
encrypt_backend: EncryptBackends = EncryptBackends.NONE
|
||||
encrypt_custom_backend: Type[EncryptBackend] = None
|
||||
encrypt_custom_backend: Optional[Type[EncryptBackend]] = None
|
||||
encrypt_max_length: int = 5000
|
||||
|
||||
default: Any
|
||||
@ -101,11 +104,10 @@ class BaseField(FieldInfo):
|
||||
:rtype: bool
|
||||
"""
|
||||
return (
|
||||
field_name not in ["default", "default_factory", "alias",
|
||||
"allow_mutation"]
|
||||
and not field_name.startswith("__")
|
||||
and hasattr(cls, field_name)
|
||||
and not callable(getattr(cls, field_name))
|
||||
field_name not in ["default", "default_factory", "alias", "allow_mutation"]
|
||||
and not field_name.startswith("__")
|
||||
and hasattr(cls, field_name)
|
||||
and not callable(getattr(cls, field_name))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -214,7 +216,7 @@ class BaseField(FieldInfo):
|
||||
:rtype: bool
|
||||
"""
|
||||
return cls.default is not None or (
|
||||
cls.server_default is not None and use_server
|
||||
cls.server_default is not None and use_server
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -247,7 +249,7 @@ class BaseField(FieldInfo):
|
||||
ondelete=con.ondelete,
|
||||
onupdate=con.onupdate,
|
||||
name=f"fk_{cls.owner.Meta.tablename}_{cls.to.Meta.tablename}"
|
||||
f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}",
|
||||
f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}",
|
||||
)
|
||||
for con in cls.constraints
|
||||
]
|
||||
@ -278,33 +280,46 @@ class BaseField(FieldInfo):
|
||||
server_default=cls.server_default,
|
||||
)
|
||||
else:
|
||||
if cls.primary_key or cls.is_relation:
|
||||
raise ModelDefinitionError("Primary key field and relations fields"
|
||||
"cannot be encrypted!")
|
||||
column = sqlalchemy.Column(
|
||||
cls.alias or name,
|
||||
EncryptedString(
|
||||
_field_type=cls,
|
||||
encrypt_secret=cls.encrypt_secret,
|
||||
encrypt_backend=cls.encrypt_backend,
|
||||
encrypt_custom_backend=cls.encrypt_custom_backend,
|
||||
encrypt_max_length=cls.encrypt_max_length
|
||||
),
|
||||
nullable=cls.nullable,
|
||||
index=cls.index,
|
||||
unique=cls.unique,
|
||||
default=cls.default,
|
||||
server_default=cls.server_default,
|
||||
)
|
||||
column = cls._get_encrypted_column(name=name)
|
||||
return column
|
||||
|
||||
@classmethod
|
||||
def _get_encrypted_column(cls, name: str) -> sqlalchemy.Column:
|
||||
"""
|
||||
Returns EncryptedString column type instead of actual column.
|
||||
|
||||
:param name: column name
|
||||
:type name: str
|
||||
:return: newly defined column
|
||||
:rtype: sqlalchemy.Column
|
||||
"""
|
||||
if cls.primary_key or cls.is_relation:
|
||||
raise ModelDefinitionError(
|
||||
"Primary key field and relations fields" "cannot be encrypted!"
|
||||
)
|
||||
column = sqlalchemy.Column(
|
||||
cls.alias or name,
|
||||
EncryptedString(
|
||||
_field_type=cls,
|
||||
encrypt_secret=cls.encrypt_secret,
|
||||
encrypt_backend=cls.encrypt_backend,
|
||||
encrypt_custom_backend=cls.encrypt_custom_backend,
|
||||
encrypt_max_length=cls.encrypt_max_length,
|
||||
),
|
||||
nullable=cls.nullable,
|
||||
index=cls.index,
|
||||
unique=cls.unique,
|
||||
default=cls.default,
|
||||
server_default=cls.server_default,
|
||||
)
|
||||
return column
|
||||
|
||||
@classmethod
|
||||
def expand_relationship(
|
||||
cls,
|
||||
value: Any,
|
||||
child: Union["Model", "NewBaseModel"],
|
||||
to_register: bool = True,
|
||||
cls,
|
||||
value: Any,
|
||||
child: Union["Model", "NewBaseModel"],
|
||||
to_register: bool = True,
|
||||
) -> Any:
|
||||
"""
|
||||
Function overwritten for relations, in basic field the value is returned as is.
|
||||
@ -332,7 +347,7 @@ class BaseField(FieldInfo):
|
||||
:rtype: None
|
||||
"""
|
||||
if cls.owner is not None and (
|
||||
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
|
||||
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
|
||||
):
|
||||
cls.self_reference = True
|
||||
cls.self_reference_primary = cls.name
|
||||
|
||||
@ -184,10 +184,25 @@ def ForeignKey( # noqa CFQ002
|
||||
|
||||
owner = kwargs.pop("owner", None)
|
||||
self_reference = kwargs.pop("self_reference", False)
|
||||
|
||||
default = kwargs.pop("default", None)
|
||||
if default is not None:
|
||||
encrypt_secret = kwargs.pop("encrypt_secret", None)
|
||||
encrypt_backend = kwargs.pop("encrypt_backend", None)
|
||||
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
|
||||
encrypt_max_length = kwargs.pop("encrypt_max_length", None)
|
||||
|
||||
not_supported = [
|
||||
default,
|
||||
encrypt_secret,
|
||||
encrypt_backend,
|
||||
encrypt_custom_backend,
|
||||
encrypt_max_length,
|
||||
]
|
||||
if any(x is not None for x in not_supported):
|
||||
raise ModelDefinitionError(
|
||||
"Argument 'default' is not supported " "on relation fields!"
|
||||
f"Argument {next((x for x in not_supported if x is not None))} "
|
||||
f"is not supported "
|
||||
"on relation fields!"
|
||||
)
|
||||
|
||||
if to.__class__ == ForwardRef:
|
||||
@ -386,8 +401,6 @@ class ForeignKeyField(BaseField):
|
||||
:return: (if needed) registered Model
|
||||
:rtype: Model
|
||||
"""
|
||||
if cls.to.pk_type() == uuid.UUID and isinstance(value, str):
|
||||
value = uuid.UUID(value)
|
||||
if not isinstance(value, cls.to.pk_type()):
|
||||
raise RelationshipInstanceError(
|
||||
f"Relationship error - ForeignKey {cls.to.__name__} "
|
||||
|
||||
@ -97,9 +97,23 @@ def ManyToMany(
|
||||
forbid_through_relations(cast(Type["Model"], through))
|
||||
|
||||
default = kwargs.pop("default", None)
|
||||
if default is not None:
|
||||
encrypt_secret = kwargs.pop("encrypt_secret", None)
|
||||
encrypt_backend = kwargs.pop("encrypt_backend", None)
|
||||
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
|
||||
encrypt_max_length = kwargs.pop("encrypt_max_length", None)
|
||||
|
||||
not_supported = [
|
||||
default,
|
||||
encrypt_secret,
|
||||
encrypt_backend,
|
||||
encrypt_custom_backend,
|
||||
encrypt_max_length,
|
||||
]
|
||||
if any(x is not None for x in not_supported):
|
||||
raise ModelDefinitionError(
|
||||
"Argument 'default' is not supported " "on relation fields!"
|
||||
f"Argument {next((x for x in not_supported if x is not None))} "
|
||||
f"is not supported "
|
||||
"on relation fields!"
|
||||
)
|
||||
|
||||
if to.__class__ == ForwardRef:
|
||||
|
||||
@ -76,8 +76,7 @@ class ModelFieldFactory:
|
||||
|
||||
encrypt_secret = kwargs.pop("encrypt_secret", None)
|
||||
encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE)
|
||||
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend",
|
||||
None)
|
||||
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
|
||||
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
|
||||
|
||||
namespace = dict(
|
||||
|
||||
44
ormar/fields/parsers.py
Normal file
44
ormar/fields/parsers.py
Normal file
@ -0,0 +1,44 @@
|
||||
import datetime
|
||||
import decimal
|
||||
from typing import Any
|
||||
|
||||
import pydantic
|
||||
from pydantic.datetime_parse import parse_date, parse_datetime, parse_time
|
||||
|
||||
try:
|
||||
import orjson as json
|
||||
except ImportError: # pragma: no cover
|
||||
import json # type: ignore
|
||||
|
||||
|
||||
def parse_bool(value: str) -> bool:
|
||||
return value == "true"
|
||||
|
||||
|
||||
def encode_bool(value: bool) -> str:
|
||||
return "true" if value else "false"
|
||||
|
||||
|
||||
def encode_json(value: Any) -> str:
|
||||
value = json.dumps(value) if not isinstance(value, str) else value
|
||||
value = value.decode("utf-8") if isinstance(value, bytes) else value
|
||||
return value
|
||||
|
||||
|
||||
ENCODERS_MAP = {
|
||||
bool: encode_bool,
|
||||
datetime.datetime: lambda x: x.isoformat(),
|
||||
datetime.date: lambda x: x.isoformat(),
|
||||
datetime.time: lambda x: x.isoformat(),
|
||||
pydantic.Json: encode_json,
|
||||
decimal.Decimal: float,
|
||||
}
|
||||
|
||||
DECODERS_MAP = {
|
||||
bool: parse_bool,
|
||||
datetime.datetime: parse_datetime,
|
||||
datetime.date: parse_date,
|
||||
datetime.time: parse_time,
|
||||
pydantic.Json: json.loads,
|
||||
decimal.Decimal: decimal.Decimal,
|
||||
}
|
||||
@ -1,50 +1,49 @@
|
||||
# inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils)
|
||||
import abc
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, TYPE_CHECKING, Type, Union
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Type, Union
|
||||
|
||||
import sqlalchemy.types as types
|
||||
from pydantic.utils import lenient_issubclass
|
||||
from sqlalchemy.engine.default import DefaultDialect
|
||||
|
||||
from ormar import ModelDefinitionError
|
||||
import ormar # noqa: I100, I202
|
||||
from ormar import ModelDefinitionError # noqa: I202, I100
|
||||
|
||||
cryptography = None
|
||||
try:
|
||||
import cryptography
|
||||
try: # pragma: nocover
|
||||
import cryptography # type: ignore
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
except ImportError:
|
||||
except ImportError: # pragma: nocover
|
||||
pass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if TYPE_CHECKING: # pragma: nocover
|
||||
from ormar import BaseField
|
||||
|
||||
|
||||
class EncryptBackend(abc.ABC):
|
||||
|
||||
def _update_key(self, key):
|
||||
def _refresh(self, key: Union[str, bytes]) -> None:
|
||||
if isinstance(key, str):
|
||||
key = key.encode()
|
||||
digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
|
||||
digest.update(key)
|
||||
engine_key = digest.finalize()
|
||||
|
||||
self._initialize_engine(engine_key)
|
||||
self._initialize_backend(engine_key)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _initialize_engine(self, secret_key: bytes):
|
||||
def _initialize_backend(self, secret_key: bytes) -> None: # pragma: nocover
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def encrypt(self, value: Any) -> str:
|
||||
def encrypt(self, value: Any) -> str: # pragma: nocover
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def decrypt(self, value: Any) -> str:
|
||||
def decrypt(self, value: Any) -> str: # pragma: nocover
|
||||
pass
|
||||
|
||||
|
||||
@ -53,11 +52,11 @@ class HashBackend(EncryptBackend):
|
||||
One-way hashing - in example for passwords, no way to decrypt the value!
|
||||
"""
|
||||
|
||||
def _initialize_engine(self, secret_key: bytes):
|
||||
def _initialize_backend(self, secret_key: bytes) -> None:
|
||||
self.secret_key = base64.urlsafe_b64encode(secret_key)
|
||||
|
||||
def encrypt(self, value: Any) -> str:
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(value, str): # pragma: nocover
|
||||
value = repr(value)
|
||||
value = value.encode()
|
||||
digest = hashes.Hash(hashes.SHA512(), backend=default_backend())
|
||||
@ -67,7 +66,7 @@ class HashBackend(EncryptBackend):
|
||||
return hashed_value.hex()
|
||||
|
||||
def decrypt(self, value: Any) -> str:
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(value, str): # pragma: nocover
|
||||
value = str(value)
|
||||
return value
|
||||
|
||||
@ -77,7 +76,7 @@ class FernetBackend(EncryptBackend):
|
||||
Two-way encryption, data stored in db are encrypted but decrypted during query.
|
||||
"""
|
||||
|
||||
def _initialize_engine(self, secret_key: bytes):
|
||||
def _initialize_backend(self, secret_key: bytes) -> None:
|
||||
self.secret_key = base64.urlsafe_b64encode(secret_key)
|
||||
self.fernet = Fernet(self.secret_key)
|
||||
|
||||
@ -86,14 +85,14 @@ class FernetBackend(EncryptBackend):
|
||||
value = repr(value)
|
||||
value = value.encode()
|
||||
encrypted = self.fernet.encrypt(value)
|
||||
return encrypted.decode('utf-8')
|
||||
return encrypted.decode("utf-8")
|
||||
|
||||
def decrypt(self, value: Any) -> str:
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(value, str): # pragma: nocover
|
||||
value = str(value)
|
||||
decrypted = self.fernet.decrypt(value.encode())
|
||||
decrypted: Union[str, bytes] = self.fernet.decrypt(value.encode())
|
||||
if not isinstance(decrypted, str):
|
||||
decrypted = decrypted.decode('utf-8')
|
||||
decrypted = decrypted.decode("utf-8")
|
||||
return decrypted
|
||||
|
||||
|
||||
@ -104,115 +103,82 @@ class EncryptBackends(Enum):
|
||||
CUSTOM = 3
|
||||
|
||||
|
||||
backends_map = {
|
||||
BACKENDS_MAP = {
|
||||
EncryptBackends.FERNET: FernetBackend,
|
||||
EncryptBackends.HASH: HashBackend,
|
||||
EncryptBackends.CUSTOM: None
|
||||
}
|
||||
|
||||
|
||||
class EncryptedString(types.TypeDecorator): # pragma nocover
|
||||
class EncryptedString(types.TypeDecorator):
|
||||
"""
|
||||
Used to store encrypted values in a database
|
||||
"""
|
||||
|
||||
impl = types.TypeEngine
|
||||
|
||||
def __init__(self,
|
||||
*args: Any,
|
||||
encrypt_secret: Union[str, Callable],
|
||||
_field_type: Type["BaseField"],
|
||||
encrypt_max_length: int = 5000,
|
||||
encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
|
||||
encrypt_custom_backend: Type[EncryptBackend] = None,
|
||||
**kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
if not cryptography:
|
||||
def __init__(
|
||||
self,
|
||||
encrypt_secret: Union[str, Callable],
|
||||
encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
|
||||
encrypt_custom_backend: Type[EncryptBackend] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
_field_type = kwargs.pop("_field_type")
|
||||
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
|
||||
super().__init__()
|
||||
if not cryptography: # pragma: nocover
|
||||
raise ModelDefinitionError(
|
||||
"In order to encrypt a column 'cryptography' is required!"
|
||||
)
|
||||
backend = backends_map.get(encrypt_backend, encrypt_custom_backend)
|
||||
if not backend or not issubclass(backend, EncryptBackend):
|
||||
backend = BACKENDS_MAP.get(encrypt_backend, encrypt_custom_backend)
|
||||
if not backend or not lenient_issubclass(backend, EncryptBackend):
|
||||
raise ModelDefinitionError("Wrong or no encrypt backend provided!")
|
||||
self.backend = backend()
|
||||
self._field_type = _field_type
|
||||
self._underlying_type = _field_type.column_type
|
||||
self._key = encrypt_secret
|
||||
self.max_length = encrypt_max_length
|
||||
|
||||
def __repr__(self) -> str:
|
||||
self.backend: EncryptBackend = backend()
|
||||
self._field_type: Type["BaseField"] = _field_type
|
||||
self._underlying_type: Any = _field_type.column_type
|
||||
self._key: Union[str, Callable] = encrypt_secret
|
||||
self.max_length: int = encrypt_max_length
|
||||
type_ = self._field_type.__type__
|
||||
if type_ is None: # pragma: nocover
|
||||
raise ModelDefinitionError(
|
||||
f"Improperly configured field " f"{self._field_type.name}"
|
||||
)
|
||||
self.type_: Any = type_
|
||||
|
||||
def __repr__(self) -> str: # pragma: nocover
|
||||
return f"VARCHAR({self.max_length})"
|
||||
|
||||
def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
|
||||
return dialect.type_descriptor(types.VARCHAR(self.max_length))
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
return self._key
|
||||
|
||||
@key.setter
|
||||
def key(self, value):
|
||||
self._key = value
|
||||
|
||||
def _update_key(self):
|
||||
def _refresh(self) -> None:
|
||||
key = self._key() if callable(self._key) else self._key
|
||||
self.backend._update_key(key)
|
||||
self.backend._refresh(key)
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
"""Encrypt a value on the way in."""
|
||||
if value is not None:
|
||||
self._update_key()
|
||||
def process_bind_param(self, value: Any, dialect: DefaultDialect) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
self._refresh()
|
||||
try:
|
||||
value = self._underlying_type.process_bind_param(value, dialect)
|
||||
except AttributeError:
|
||||
encoder = ormar.ENCODERS_MAP.get(self.type_, None)
|
||||
if encoder:
|
||||
value = encoder(value) # type: ignore
|
||||
|
||||
try:
|
||||
value = self._underlying_type.process_bind_param(
|
||||
value, dialect
|
||||
)
|
||||
return self.backend.encrypt(value)
|
||||
|
||||
except AttributeError:
|
||||
# Doesn't have 'process_bind_param'
|
||||
type_ = self._field_type.__type__
|
||||
if issubclass(type_, bool):
|
||||
value = 'true' if value else 'false'
|
||||
def process_result_value(self, value: Any, dialect: DefaultDialect) -> Any:
|
||||
if value is None:
|
||||
return value
|
||||
self._refresh()
|
||||
decrypted_value = self.backend.decrypt(value)
|
||||
try:
|
||||
return self._underlying_type.process_result_value(decrypted_value, dialect)
|
||||
except AttributeError:
|
||||
decoder = ormar.DECODERS_MAP.get(self.type_, None)
|
||||
if decoder:
|
||||
return decoder(decrypted_value) # type: ignore
|
||||
|
||||
elif issubclass(type_, (datetime.date, datetime.time)):
|
||||
value = value.isoformat()
|
||||
|
||||
# elif issubclass(type_, JSONType):
|
||||
# value = json.dumps(value)
|
||||
|
||||
return self.backend.encrypt(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
"""Decrypt value on the way out."""
|
||||
if value is not None:
|
||||
self._update_key()
|
||||
decrypted_value = self.backend.decrypt(value)
|
||||
|
||||
try:
|
||||
return self.underlying_type.process_result_value(
|
||||
decrypted_value, dialect
|
||||
)
|
||||
|
||||
except AttributeError:
|
||||
# Doesn't have 'process_result_value'
|
||||
|
||||
# Handle 'boolean' and 'dates'
|
||||
type_ = self._field_type.__type__
|
||||
# date_types = [datetime.datetime, datetime.time, datetime.date]
|
||||
|
||||
if issubclass(type_, bool):
|
||||
return decrypted_value == 'true'
|
||||
|
||||
# elif type_ in date_types:
|
||||
# return DatetimeHandler.process_value(
|
||||
# decrypted_value, type_
|
||||
# )
|
||||
|
||||
# elif issubclass(type_, JSONType):
|
||||
# return json.loads(decrypted_value)
|
||||
|
||||
# Handle all others
|
||||
return self.underlying_type.python_type(decrypted_value)
|
||||
|
||||
def _coerce(self, value):
|
||||
return self.underlying_type._coerce(value)
|
||||
return self._field_type.__type__(decrypted_value) # type: ignore
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import uuid
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import CHAR
|
||||
from sqlalchemy.engine.default import DefaultDialect
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
|
||||
class UUID(TypeDecorator): # pragma nocover
|
||||
class UUID(TypeDecorator):
|
||||
"""
|
||||
Platform-independent GUID type.
|
||||
Uses CHAR(36) if in a string mode, otherwise uses CHAR(32), to store UUID.
|
||||
@ -20,31 +20,11 @@ class UUID(TypeDecorator): # pragma nocover
|
||||
super().__init__(*args, **kwargs)
|
||||
self.uuid_format = uuid_format
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self) -> str: # pragma: nocover
|
||||
if self.uuid_format == "string":
|
||||
return "CHAR(36)"
|
||||
return "CHAR(32)"
|
||||
|
||||
def _cast_to_uuid(self, value: Union[str, int, bytes]) -> uuid.UUID:
|
||||
"""
|
||||
Parses given value into uuid.UUID field.
|
||||
|
||||
:param value: value to be parsed
|
||||
:type value: Union[str, int, bytes]
|
||||
:return: initialized uuid
|
||||
:rtype: uuid.UUID
|
||||
"""
|
||||
if not isinstance(value, uuid.UUID):
|
||||
if isinstance(value, bytes):
|
||||
ret_value = uuid.UUID(bytes=value)
|
||||
elif isinstance(value, int):
|
||||
ret_value = uuid.UUID(int=value)
|
||||
elif isinstance(value, str):
|
||||
ret_value = uuid.UUID(value)
|
||||
else:
|
||||
ret_value = value
|
||||
return ret_value
|
||||
|
||||
def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
|
||||
return (
|
||||
dialect.type_descriptor(CHAR(36))
|
||||
@ -53,12 +33,10 @@ class UUID(TypeDecorator): # pragma nocover
|
||||
)
|
||||
|
||||
def process_bind_param(
|
||||
self, value: Union[str, int, bytes, uuid.UUID, None], dialect: DefaultDialect
|
||||
self, value: uuid.UUID, dialect: DefaultDialect
|
||||
) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, uuid.UUID):
|
||||
value = self._cast_to_uuid(value)
|
||||
return str(value) if self.uuid_format == "string" else "%.32x" % value.int
|
||||
|
||||
def process_result_value(
|
||||
@ -68,4 +46,4 @@ class UUID(TypeDecorator): # pragma nocover
|
||||
return value
|
||||
if not isinstance(value, uuid.UUID):
|
||||
return uuid.UUID(value)
|
||||
return value
|
||||
return value # pragma: nocover
|
||||
|
||||
@ -53,6 +53,7 @@ def convert_choices_if_needed( # noqa: CCR001
|
||||
:return: value, choices list
|
||||
:rtype: Tuple[Any, List]
|
||||
"""
|
||||
# TODO use same maps as with EncryptedString
|
||||
choices = [o.value if isinstance(o, Enum) else o for o in field.choices]
|
||||
|
||||
if field.__type__ in [datetime.datetime, datetime.date, datetime.time]:
|
||||
|
||||
@ -88,13 +88,13 @@ class SqlJoin:
|
||||
return self.main_model.Meta.alias_manager
|
||||
|
||||
@property
|
||||
def to_table(self) -> str:
|
||||
def to_table(self) -> sqlalchemy.Table:
|
||||
"""
|
||||
Shortcut to table name of the next model
|
||||
:return: name of the target table
|
||||
:rtype: str
|
||||
"""
|
||||
return self.next_model.Meta.table.name
|
||||
return self.next_model.Meta.table
|
||||
|
||||
def _on_clause(
|
||||
self, previous_alias: str, from_clause: str, to_clause: str,
|
||||
@ -282,7 +282,7 @@ class SqlJoin:
|
||||
on_clause = self._on_clause(
|
||||
previous_alias=self.own_alias,
|
||||
from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}",
|
||||
to_clause=f"{self.to_table}.{to_key}",
|
||||
to_clause=f"{self.to_table.name}.{to_key}",
|
||||
)
|
||||
target_table = self.alias_manager.prefixed_table_name(
|
||||
self.next_alias, self.to_table
|
||||
@ -301,7 +301,7 @@ class SqlJoin:
|
||||
)
|
||||
self.columns.extend(
|
||||
self.alias_manager.prefixed_columns(
|
||||
self.next_alias, self.next_model.Meta.table, self_related_fields
|
||||
self.next_alias, target_table, self_related_fields
|
||||
)
|
||||
)
|
||||
self.used_aliases.append(self.next_alias)
|
||||
|
||||
@ -67,24 +67,21 @@ class AliasManager:
|
||||
if not fields
|
||||
else [col for col in table.columns if col.name in fields]
|
||||
)
|
||||
return [
|
||||
text(f"{alias}{table.name}.{column.name} as {alias}{column.name}")
|
||||
for column in all_columns
|
||||
]
|
||||
return [column.label(f"{alias}{column.name}") for column in all_columns]
|
||||
|
||||
@staticmethod
|
||||
def prefixed_table_name(alias: str, name: str) -> text:
|
||||
def prefixed_table_name(alias: str, table: sqlalchemy.Table) -> text:
|
||||
"""
|
||||
Creates text clause with table name with aliased name.
|
||||
|
||||
:param alias: alias of given table
|
||||
:type alias: str
|
||||
:param name: table name
|
||||
:type name: str
|
||||
:param table: table
|
||||
:type table: sqlalchemy.Table
|
||||
:return: sqlalchemy text clause as "table_name aliased_name"
|
||||
:rtype: sqlalchemy text clause
|
||||
"""
|
||||
return text(f"{name} {alias}_{name}")
|
||||
return table.alias(f"{alias}_{table.name}")
|
||||
|
||||
def add_relation_type(
|
||||
self, source_model: Type["Model"], relation_name: str, reverse_name: str = None,
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
# type: ignore
|
||||
import decimal
|
||||
import uuid
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import QueryDefinitionError
|
||||
from ormar import ModelDefinitionError
|
||||
from ormar.fields.sqlalchemy_encrypted import EncryptedString
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL)
|
||||
@ -18,22 +22,58 @@ class BaseMeta(ormar.ModelMeta):
|
||||
database = database
|
||||
|
||||
|
||||
default_fernet = dict(
|
||||
encrypt_secret="asd123", encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||
)
|
||||
|
||||
|
||||
class DummyBackend(ormar.fields.EncryptBackend):
|
||||
def _initialize_backend(self, secret_key: bytes) -> None:
|
||||
pass
|
||||
|
||||
def encrypt(self, value: Any) -> str:
|
||||
return value
|
||||
|
||||
def decrypt(self, value: Any) -> str:
|
||||
return value
|
||||
|
||||
|
||||
class Author(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "authors"
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100,
|
||||
encrypt_secret='asd123',
|
||||
encrypt_backend=ormar.EncryptBackends.FERNET)
|
||||
uuid_test = ormar.UUID(default=uuid.uuid4, uuid_format='string')
|
||||
password: str = ormar.String(max_length=100,
|
||||
encrypt_secret='udxc32',
|
||||
encrypt_backend=ormar.EncryptBackends.HASH)
|
||||
birth_year: int = ormar.Integer(nullable=True,
|
||||
encrypt_secret='secure89key%^&psdijfipew',
|
||||
encrypt_max_length=200,
|
||||
encrypt_backend=ormar.EncryptBackends.FERNET)
|
||||
name: str = ormar.String(max_length=100, **default_fernet)
|
||||
uuid_test = ormar.UUID(default=uuid.uuid4, uuid_format="string")
|
||||
uuid_test2 = ormar.UUID(nullable=True, uuid_format="string")
|
||||
password: str = ormar.String(
|
||||
max_length=128,
|
||||
encrypt_secret="udxc32",
|
||||
encrypt_backend=ormar.EncryptBackends.HASH,
|
||||
)
|
||||
birth_year: int = ormar.Integer(
|
||||
nullable=True,
|
||||
encrypt_secret="secure89key%^&psdijfipew",
|
||||
encrypt_max_length=200,
|
||||
encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||
)
|
||||
test_text: str = ormar.Text(default="", **default_fernet)
|
||||
test_bool: bool = ormar.Boolean(nullable=False, **default_fernet)
|
||||
test_float: float = ormar.Float(**default_fernet)
|
||||
test_float2: float = ormar.Float(nullable=True, **default_fernet)
|
||||
test_datetime = ormar.DateTime(default=datetime.datetime.now, **default_fernet)
|
||||
test_date = ormar.Date(default=datetime.date.today, **default_fernet)
|
||||
test_time = ormar.Time(default=datetime.time, **default_fernet)
|
||||
test_json = ormar.JSON(default={}, **default_fernet)
|
||||
test_bigint: int = ormar.BigInteger(default=0, **default_fernet)
|
||||
test_decimal = ormar.Decimal(scale=2, precision=10, **default_fernet)
|
||||
test_decimal2 = ormar.Decimal(max_digits=10, decimal_places=2, **default_fernet)
|
||||
custom_backend: str = ormar.String(
|
||||
max_length=200,
|
||||
encrypt_secret="asda8",
|
||||
encrypt_backend=ormar.EncryptBackends.CUSTOM,
|
||||
encrypt_custom_backend=DummyBackend,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
@ -45,16 +85,104 @@ def create_test_database():
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
def test_error_on_encrypted_pk():
|
||||
with pytest.raises(ModelDefinitionError):
|
||||
|
||||
class Wrong(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "wrongs"
|
||||
|
||||
id: int = ormar.Integer(
|
||||
primary_key=True,
|
||||
encrypt_secret="asd123",
|
||||
encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||
)
|
||||
|
||||
|
||||
def test_error_on_encrypted_relation():
|
||||
with pytest.raises(ModelDefinitionError):
|
||||
|
||||
class Wrong2(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "wrongs2"
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
author = ormar.ForeignKey(
|
||||
Author,
|
||||
encrypt_secret="asd123",
|
||||
encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||
)
|
||||
|
||||
|
||||
def test_error_on_encrypted_m2m_relation():
|
||||
with pytest.raises(ModelDefinitionError):
|
||||
|
||||
class Wrong3(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "wrongs3"
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
author = ormar.ManyToMany(
|
||||
Author,
|
||||
encrypt_secret="asd123",
|
||||
encrypt_backend=ormar.EncryptBackends.FERNET,
|
||||
)
|
||||
|
||||
|
||||
def test_wrong_backend():
|
||||
with pytest.raises(ModelDefinitionError):
|
||||
|
||||
class Wrong3(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "wrongs3"
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
author = ormar.Integer(
|
||||
encrypt_secret="asd123",
|
||||
encrypt_backend=ormar.EncryptBackends.CUSTOM,
|
||||
encrypt_custom_backend="aa",
|
||||
)
|
||||
|
||||
|
||||
def test_db_structure():
|
||||
assert Author.Meta.table.c.get('name').type.impl.__class__ == sqlalchemy.NVARCHAR
|
||||
assert Author.Meta.table.c.get('birth_year').type.max_length == 200
|
||||
assert Author.Meta.table.c.get("name").type.__class__ == EncryptedString
|
||||
assert Author.Meta.table.c.get("birth_year").type.max_length == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_query_foreign_key_type():
|
||||
async def test_save_and_retrieve():
|
||||
async with database:
|
||||
await Author(name='Test', birth_year=1988, password='test123').save()
|
||||
test_uuid = uuid.uuid4()
|
||||
await Author(
|
||||
name="Test",
|
||||
birth_year=1988,
|
||||
password="test123",
|
||||
uuid_test=test_uuid,
|
||||
test_float=1.2,
|
||||
test_bool=True,
|
||||
test_decimal=decimal.Decimal(3.5),
|
||||
test_decimal2=decimal.Decimal(5.5),
|
||||
test_json=dict(aa=12),
|
||||
custom_backend="test12",
|
||||
).save()
|
||||
author = await Author.objects.get()
|
||||
|
||||
assert author.name == 'Test'
|
||||
assert author.name == "Test"
|
||||
assert author.birth_year == 1988
|
||||
password = (
|
||||
"03e4a4d513e99cb3fe4ee3db282c053daa3f3572b849c3868939a306944ad5c08"
|
||||
"22b50d4886e10f4cd418c3f2df3ceb02e2e7ac6e920ae0c90f2dedfc8fa16e2"
|
||||
)
|
||||
assert author.password == password
|
||||
assert author.uuid_test == test_uuid
|
||||
assert author.uuid_test2 is None
|
||||
assert author.test_datetime.date() == datetime.date.today()
|
||||
assert author.test_date == datetime.date.today()
|
||||
assert author.test_text == ""
|
||||
assert author.test_float == 1.2
|
||||
assert author.test_float2 is None
|
||||
assert author.test_bigint == 0
|
||||
assert author.test_json == {"aa": 12}
|
||||
assert author.test_decimal == 3.5
|
||||
assert author.test_decimal2 == 5.5
|
||||
assert author.custom_backend == "test12"
|
||||
|
||||
@ -57,24 +57,24 @@ async def test_or_filters():
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(ormar.or_(author__name="J.R.R. Tolkien", year__gt=1970))
|
||||
.all()
|
||||
.filter(ormar.or_(author__name="J.R.R. Tolkien", year__gt=1970))
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 5
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(ormar.or_(author__name="J.R.R. Tolkien", year__lt=1995))
|
||||
.all()
|
||||
.filter(ormar.or_(author__name="J.R.R. Tolkien", year__lt=1995))
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 4
|
||||
assert not any([x.title == "The Tower of Fools" for x in books])
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(ormar.or_(year__gt=1960, year__lt=1940))
|
||||
.filter(author__name="J.R.R. Tolkien")
|
||||
.all()
|
||||
.filter(ormar.or_(year__gt=1960, year__lt=1940))
|
||||
.filter(author__name="J.R.R. Tolkien")
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 2
|
||||
assert books[0].title == "The Hobbit"
|
||||
@ -82,13 +82,13 @@ async def test_or_filters():
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(
|
||||
.filter(
|
||||
ormar.and_(
|
||||
ormar.or_(year__gt=1960, year__lt=1940),
|
||||
author__name="J.R.R. Tolkien",
|
||||
)
|
||||
)
|
||||
.all()
|
||||
.all()
|
||||
)
|
||||
|
||||
assert len(books) == 2
|
||||
@ -97,14 +97,14 @@ async def test_or_filters():
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(
|
||||
.filter(
|
||||
ormar.or_(
|
||||
ormar.and_(year__gt=1960, author__name="J.R.R. Tolkien"),
|
||||
ormar.and_(year__lt=2000, author__name="Andrzej Sapkowski"),
|
||||
)
|
||||
)
|
||||
.filter(title__startswith="The")
|
||||
.all()
|
||||
.filter(title__startswith="The")
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 2
|
||||
assert books[0].title == "The Silmarillion"
|
||||
@ -112,7 +112,7 @@ async def test_or_filters():
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(
|
||||
.filter(
|
||||
ormar.or_(
|
||||
ormar.and_(
|
||||
ormar.or_(year__gt=1960, year__lt=1940),
|
||||
@ -121,7 +121,7 @@ async def test_or_filters():
|
||||
ormar.and_(year__lt=2000, author__name="Andrzej Sapkowski"),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 3
|
||||
assert books[0].title == "The Hobbit"
|
||||
@ -130,29 +130,29 @@ async def test_or_filters():
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.exclude(
|
||||
.exclude(
|
||||
ormar.or_(
|
||||
ormar.and_(year__gt=1960, author__name="J.R.R. Tolkien"),
|
||||
ormar.and_(year__lt=2000, author__name="Andrzej Sapkowski"),
|
||||
)
|
||||
)
|
||||
.filter(title__startswith="The")
|
||||
.all()
|
||||
.filter(title__startswith="The")
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 3
|
||||
assert not any([x.title in ["The Silmarillion", "The Witcher"] for x in books])
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(
|
||||
.filter(
|
||||
ormar.or_(
|
||||
ormar.and_(year__gt=1960, author__name="J.R.R. Tolkien"),
|
||||
ormar.and_(year__lt=2000, author__name="Andrzej Sapkowski"),
|
||||
title__icontains="hobbit",
|
||||
)
|
||||
)
|
||||
.filter(title__startswith="The")
|
||||
.all()
|
||||
.filter(title__startswith="The")
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 3
|
||||
assert not any(
|
||||
@ -161,43 +161,43 @@ async def test_or_filters():
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(ormar.or_(year__gt=1980, year__lt=1910))
|
||||
.filter(title__startswith="The")
|
||||
.limit(1)
|
||||
.all()
|
||||
.filter(ormar.or_(year__gt=1980, year__lt=1910))
|
||||
.filter(title__startswith="The")
|
||||
.limit(1)
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 1
|
||||
assert books[0].title == "The Witcher"
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski"))
|
||||
.filter(title__startswith="The")
|
||||
.limit(1)
|
||||
.all()
|
||||
.filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski"))
|
||||
.filter(title__startswith="The")
|
||||
.limit(1)
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 1
|
||||
assert books[0].title == "The Witcher"
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski"))
|
||||
.filter(title__startswith="The")
|
||||
.limit(1)
|
||||
.offset(1)
|
||||
.all()
|
||||
.filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski"))
|
||||
.filter(title__startswith="The")
|
||||
.limit(1)
|
||||
.offset(1)
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 1
|
||||
assert books[0].title == "The Tower of Fools"
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski"))
|
||||
.filter(title__startswith="The")
|
||||
.limit(1)
|
||||
.offset(1)
|
||||
.order_by("-id")
|
||||
.all()
|
||||
.filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski"))
|
||||
.filter(title__startswith="The")
|
||||
.limit(1)
|
||||
.offset(1)
|
||||
.order_by("-id")
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 1
|
||||
assert books[0].title == "The Witcher"
|
||||
@ -220,19 +220,19 @@ async def test_or_filters():
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(ormar.or_(author__name="J.R.R. Tolkien"))
|
||||
.all()
|
||||
.filter(ormar.or_(author__name="J.R.R. Tolkien"))
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 3
|
||||
|
||||
books = (
|
||||
await Book.objects.select_related("author")
|
||||
.filter(
|
||||
.filter(
|
||||
ormar.or_(
|
||||
ormar.and_(author__name__icontains="tolkien"),
|
||||
ormar.and_(author__name__icontains="sapkowski"),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
.all()
|
||||
)
|
||||
assert len(books) == 5
|
||||
|
||||
Reference in New Issue
Block a user