revert to use tables and columns with labels and aliases instead of text clauses, add encryption, mostly working encryption column type with configurable backends

This commit is contained in:
collerek
2021-03-09 20:29:27 +01:00
parent 8d96a3fb84
commit e29bea6f85
14 changed files with 415 additions and 253 deletions

View File

@ -38,9 +38,12 @@ from ormar.fields import (
BaseField,
BigInteger,
Boolean,
DECODERS_MAP,
Date,
DateTime,
Decimal,
ENCODERS_MAP,
EncryptBackends,
Float,
ForeignKey,
ForeignKeyField,
@ -53,7 +56,6 @@ from ormar.fields import (
Time,
UUID,
UniqueColumns,
EncryptBackends
) # noqa: I100
from ormar.models import ExcludableItems, Model
from ormar.models.metaclass import ModelMeta
@ -111,5 +113,7 @@ __all__ = [
"ExcludableItems",
"and_",
"or_",
"EncryptBackends"
"EncryptBackends",
"ENCODERS_MAP",
"DECODERS_MAP",
]

View File

@ -21,8 +21,9 @@ from ormar.fields.model_fields import (
Time,
UUID,
)
from ormar.fields.through_field import Through, ThroughField
from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends
from ormar.fields.through_field import Through, ThroughField
__all__ = [
"Decimal",
@ -46,5 +47,7 @@ __all__ = [
"ThroughField",
"Through",
"EncryptBackends",
"EncryptBackend"
"EncryptBackend",
"DECODERS_MAP",
"ENCODERS_MAP",
]

View File

@ -6,8 +6,11 @@ from pydantic.fields import FieldInfo, Required, Undefined
import ormar # noqa I101
from ormar import ModelDefinitionError
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends, \
EncryptedString
from ormar.fields.sqlalchemy_encrypted import (
EncryptBackend,
EncryptBackends,
EncryptedString,
)
if TYPE_CHECKING: # pragma no cover
from ormar.models import Model
@ -54,7 +57,7 @@ class BaseField(FieldInfo):
encrypt_secret: str
encrypt_backend: EncryptBackends = EncryptBackends.NONE
encrypt_custom_backend: Type[EncryptBackend] = None
encrypt_custom_backend: Optional[Type[EncryptBackend]] = None
encrypt_max_length: int = 5000
default: Any
@ -101,11 +104,10 @@ class BaseField(FieldInfo):
:rtype: bool
"""
return (
field_name not in ["default", "default_factory", "alias",
"allow_mutation"]
and not field_name.startswith("__")
and hasattr(cls, field_name)
and not callable(getattr(cls, field_name))
field_name not in ["default", "default_factory", "alias", "allow_mutation"]
and not field_name.startswith("__")
and hasattr(cls, field_name)
and not callable(getattr(cls, field_name))
)
@classmethod
@ -214,7 +216,7 @@ class BaseField(FieldInfo):
:rtype: bool
"""
return cls.default is not None or (
cls.server_default is not None and use_server
cls.server_default is not None and use_server
)
@classmethod
@ -247,7 +249,7 @@ class BaseField(FieldInfo):
ondelete=con.ondelete,
onupdate=con.onupdate,
name=f"fk_{cls.owner.Meta.tablename}_{cls.to.Meta.tablename}"
f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}",
f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}",
)
for con in cls.constraints
]
@ -278,33 +280,46 @@ class BaseField(FieldInfo):
server_default=cls.server_default,
)
else:
if cls.primary_key or cls.is_relation:
raise ModelDefinitionError("Primary key field and relations fields"
"cannot be encrypted!")
column = sqlalchemy.Column(
cls.alias or name,
EncryptedString(
_field_type=cls,
encrypt_secret=cls.encrypt_secret,
encrypt_backend=cls.encrypt_backend,
encrypt_custom_backend=cls.encrypt_custom_backend,
encrypt_max_length=cls.encrypt_max_length
),
nullable=cls.nullable,
index=cls.index,
unique=cls.unique,
default=cls.default,
server_default=cls.server_default,
)
column = cls._get_encrypted_column(name=name)
return column
@classmethod
def _get_encrypted_column(cls, name: str) -> sqlalchemy.Column:
    """
    Builds the sqlalchemy column for a field that should be stored encrypted.

    The column is typed as EncryptedString, configured from the field's
    encrypt_secret, encrypt_backend, encrypt_custom_backend and
    encrypt_max_length settings. Nullability, index, uniqueness and defaults
    are taken from the field itself.

    :raises ModelDefinitionError: if the field is a primary key or a relation
    :param name: column name, used when the field has no alias set
    :type name: str
    :return: newly defined column
    :rtype: sqlalchemy.Column
    """
    if cls.primary_key or cls.is_relation:
        # Single literal on purpose: the previous two adjacent literals were
        # concatenated without a separating space ("fieldscannot").
        raise ModelDefinitionError(
            "Primary key field and relations fields cannot be encrypted!"
        )
    column = sqlalchemy.Column(
        cls.alias or name,
        EncryptedString(
            _field_type=cls,
            encrypt_secret=cls.encrypt_secret,
            encrypt_backend=cls.encrypt_backend,
            encrypt_custom_backend=cls.encrypt_custom_backend,
            encrypt_max_length=cls.encrypt_max_length,
        ),
        nullable=cls.nullable,
        index=cls.index,
        unique=cls.unique,
        default=cls.default,
        server_default=cls.server_default,
    )
    return column
@classmethod
def expand_relationship(
cls,
value: Any,
child: Union["Model", "NewBaseModel"],
to_register: bool = True,
cls,
value: Any,
child: Union["Model", "NewBaseModel"],
to_register: bool = True,
) -> Any:
"""
Function overwritten for relations, in basic field the value is returned as is.
@ -332,7 +347,7 @@ class BaseField(FieldInfo):
:rtype: None
"""
if cls.owner is not None and (
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta
):
cls.self_reference = True
cls.self_reference_primary = cls.name

View File

@ -184,10 +184,25 @@ def ForeignKey( # noqa CFQ002
owner = kwargs.pop("owner", None)
self_reference = kwargs.pop("self_reference", False)
default = kwargs.pop("default", None)
if default is not None:
encrypt_secret = kwargs.pop("encrypt_secret", None)
encrypt_backend = kwargs.pop("encrypt_backend", None)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
encrypt_max_length = kwargs.pop("encrypt_max_length", None)
not_supported = [
default,
encrypt_secret,
encrypt_backend,
encrypt_custom_backend,
encrypt_max_length,
]
if any(x is not None for x in not_supported):
raise ModelDefinitionError(
"Argument 'default' is not supported " "on relation fields!"
f"Argument {next((x for x in not_supported if x is not None))} "
f"is not supported "
"on relation fields!"
)
if to.__class__ == ForwardRef:
@ -386,8 +401,6 @@ class ForeignKeyField(BaseField):
:return: (if needed) registered Model
:rtype: Model
"""
if cls.to.pk_type() == uuid.UUID and isinstance(value, str):
value = uuid.UUID(value)
if not isinstance(value, cls.to.pk_type()):
raise RelationshipInstanceError(
f"Relationship error - ForeignKey {cls.to.__name__} "

View File

@ -97,9 +97,23 @@ def ManyToMany(
forbid_through_relations(cast(Type["Model"], through))
default = kwargs.pop("default", None)
if default is not None:
encrypt_secret = kwargs.pop("encrypt_secret", None)
encrypt_backend = kwargs.pop("encrypt_backend", None)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
encrypt_max_length = kwargs.pop("encrypt_max_length", None)
not_supported = [
default,
encrypt_secret,
encrypt_backend,
encrypt_custom_backend,
encrypt_max_length,
]
if any(x is not None for x in not_supported):
raise ModelDefinitionError(
"Argument 'default' is not supported " "on relation fields!"
f"Argument {next((x for x in not_supported if x is not None))} "
f"is not supported "
"on relation fields!"
)
if to.__class__ == ForwardRef:

View File

@ -76,8 +76,7 @@ class ModelFieldFactory:
encrypt_secret = kwargs.pop("encrypt_secret", None)
encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend",
None)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
namespace = dict(

44
ormar/fields/parsers.py Normal file
View File

@ -0,0 +1,44 @@
import datetime
import decimal
from typing import Any
import pydantic
from pydantic.datetime_parse import parse_date, parse_datetime, parse_time
try:
import orjson as json
except ImportError: # pragma: no cover
import json # type: ignore
def parse_bool(value: str) -> bool:
    """
    Decodes the textual form of a boolean.

    Only the exact, lowercase string "true" is decoded as True,
    every other value is decoded as False.

    :param value: text to decode
    :type value: str
    :return: decoded flag
    :rtype: bool
    """
    is_true = value == "true"
    return is_true
def encode_bool(value: bool) -> str:
    """
    Encodes a boolean as text.

    Truthy values become "true", falsy values become "false".

    :param value: flag to encode
    :type value: bool
    :return: "true" or "false"
    :rtype: str
    """
    if value:
        return "true"
    return "false"
def encode_json(value: Any) -> str:
    """
    Encodes a value as json text.

    Strings are returned untouched (they are not dumped again), everything
    else is dumped with the json module in use. When the dump comes back
    as bytes it is decoded as utf-8 so that the result is always text.

    :param value: value to encode
    :type value: Any
    :return: json text
    :rtype: str
    """
    if isinstance(value, str):
        return value
    dumped = json.dumps(value)
    if isinstance(dumped, bytes):
        return dumped.decode("utf-8")
    return dumped
# Maps a python type to the callable that converts a value of that type
# into a plain, serializable representation before it is stored.
# Types that are not listed here have no dedicated encoder.
ENCODERS_MAP = {
    bool: encode_bool,
    # date and time values are stored in their ISO 8601 text form
    datetime.datetime: lambda x: x.isoformat(),
    datetime.date: lambda x: x.isoformat(),
    datetime.time: lambda x: x.isoformat(),
    pydantic.Json: encode_json,
    # NOTE(review): Decimal is converted to float (not to text), which can
    # lose precision for high-precision values - confirm this is acceptable.
    decimal.Decimal: float,
}
# Maps a python type to the callable that parses a stored value back into
# that type. Keys mirror the ones used in ENCODERS_MAP.
DECODERS_MAP = {
    bool: parse_bool,
    # date and time values are parsed with pydantic's datetime parsers
    datetime.datetime: parse_datetime,
    datetime.date: parse_date,
    datetime.time: parse_time,
    # json is orjson when it is installed, the stdlib json module otherwise
    pydantic.Json: json.loads,
    decimal.Decimal: decimal.Decimal,
}

View File

@ -1,50 +1,49 @@
# inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils)
import abc
import base64
import datetime
import json
from enum import Enum
from typing import Any, Callable, TYPE_CHECKING, Type, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy.types as types
from pydantic.utils import lenient_issubclass
from sqlalchemy.engine.default import DefaultDialect
from ormar import ModelDefinitionError
import ormar # noqa: I100, I202
from ormar import ModelDefinitionError # noqa: I202, I100
cryptography = None
try:
import cryptography
try: # pragma: nocover
import cryptography # type: ignore
from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
except ImportError:
except ImportError: # pragma: nocover
pass
if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: nocover
from ormar import BaseField
class EncryptBackend(abc.ABC):
def _update_key(self, key):
def _refresh(self, key: Union[str, bytes]) -> None:
if isinstance(key, str):
key = key.encode()
digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
digest.update(key)
engine_key = digest.finalize()
self._initialize_engine(engine_key)
self._initialize_backend(engine_key)
@abc.abstractmethod
def _initialize_engine(self, secret_key: bytes):
def _initialize_backend(self, secret_key: bytes) -> None: # pragma: nocover
pass
@abc.abstractmethod
def encrypt(self, value: Any) -> str:
def encrypt(self, value: Any) -> str: # pragma: nocover
pass
@abc.abstractmethod
def decrypt(self, value: Any) -> str:
def decrypt(self, value: Any) -> str: # pragma: nocover
pass
@ -53,11 +52,11 @@ class HashBackend(EncryptBackend):
One-way hashing - in example for passwords, no way to decrypt the value!
"""
def _initialize_engine(self, secret_key: bytes):
def _initialize_backend(self, secret_key: bytes) -> None:
self.secret_key = base64.urlsafe_b64encode(secret_key)
def encrypt(self, value: Any) -> str:
if not isinstance(value, str):
if not isinstance(value, str): # pragma: nocover
value = repr(value)
value = value.encode()
digest = hashes.Hash(hashes.SHA512(), backend=default_backend())
@ -67,7 +66,7 @@ class HashBackend(EncryptBackend):
return hashed_value.hex()
def decrypt(self, value: Any) -> str:
if not isinstance(value, str):
if not isinstance(value, str): # pragma: nocover
value = str(value)
return value
@ -77,7 +76,7 @@ class FernetBackend(EncryptBackend):
Two-way encryption, data stored in db are encrypted but decrypted during query.
"""
def _initialize_engine(self, secret_key: bytes):
def _initialize_backend(self, secret_key: bytes) -> None:
self.secret_key = base64.urlsafe_b64encode(secret_key)
self.fernet = Fernet(self.secret_key)
@ -86,14 +85,14 @@ class FernetBackend(EncryptBackend):
value = repr(value)
value = value.encode()
encrypted = self.fernet.encrypt(value)
return encrypted.decode('utf-8')
return encrypted.decode("utf-8")
def decrypt(self, value: Any) -> str:
if not isinstance(value, str):
if not isinstance(value, str): # pragma: nocover
value = str(value)
decrypted = self.fernet.decrypt(value.encode())
decrypted: Union[str, bytes] = self.fernet.decrypt(value.encode())
if not isinstance(decrypted, str):
decrypted = decrypted.decode('utf-8')
decrypted = decrypted.decode("utf-8")
return decrypted
@ -104,115 +103,82 @@ class EncryptBackends(Enum):
CUSTOM = 3
backends_map = {
BACKENDS_MAP = {
EncryptBackends.FERNET: FernetBackend,
EncryptBackends.HASH: HashBackend,
EncryptBackends.CUSTOM: None
}
class EncryptedString(types.TypeDecorator): # pragma nocover
class EncryptedString(types.TypeDecorator):
"""
Used to store encrypted values in a database
"""
impl = types.TypeEngine
def __init__(self,
*args: Any,
encrypt_secret: Union[str, Callable],
_field_type: Type["BaseField"],
encrypt_max_length: int = 5000,
encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
encrypt_custom_backend: Type[EncryptBackend] = None,
**kwargs: Any) -> None:
super().__init__(*args, **kwargs)
if not cryptography:
def __init__(
self,
encrypt_secret: Union[str, Callable],
encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
encrypt_custom_backend: Type[EncryptBackend] = None,
**kwargs: Any,
) -> None:
_field_type = kwargs.pop("_field_type")
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
super().__init__()
if not cryptography: # pragma: nocover
raise ModelDefinitionError(
"In order to encrypt a column 'cryptography' is required!"
)
backend = backends_map.get(encrypt_backend, encrypt_custom_backend)
if not backend or not issubclass(backend, EncryptBackend):
backend = BACKENDS_MAP.get(encrypt_backend, encrypt_custom_backend)
if not backend or not lenient_issubclass(backend, EncryptBackend):
raise ModelDefinitionError("Wrong or no encrypt backend provided!")
self.backend = backend()
self._field_type = _field_type
self._underlying_type = _field_type.column_type
self._key = encrypt_secret
self.max_length = encrypt_max_length
def __repr__(self) -> str:
self.backend: EncryptBackend = backend()
self._field_type: Type["BaseField"] = _field_type
self._underlying_type: Any = _field_type.column_type
self._key: Union[str, Callable] = encrypt_secret
self.max_length: int = encrypt_max_length
type_ = self._field_type.__type__
if type_ is None: # pragma: nocover
raise ModelDefinitionError(
f"Improperly configured field " f"{self._field_type.name}"
)
self.type_: Any = type_
def __repr__(self) -> str: # pragma: nocover
return f"VARCHAR({self.max_length})"
def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
return dialect.type_descriptor(types.VARCHAR(self.max_length))
@property
def key(self):
return self._key
@key.setter
def key(self, value):
self._key = value
def _update_key(self):
def _refresh(self) -> None:
key = self._key() if callable(self._key) else self._key
self.backend._update_key(key)
self.backend._refresh(key)
def process_bind_param(self, value, dialect):
"""Encrypt a value on the way in."""
if value is not None:
self._update_key()
def process_bind_param(self, value: Any, dialect: DefaultDialect) -> Optional[str]:
if value is None:
return value
self._refresh()
try:
value = self._underlying_type.process_bind_param(value, dialect)
except AttributeError:
encoder = ormar.ENCODERS_MAP.get(self.type_, None)
if encoder:
value = encoder(value) # type: ignore
try:
value = self._underlying_type.process_bind_param(
value, dialect
)
return self.backend.encrypt(value)
except AttributeError:
# Doesn't have 'process_bind_param'
type_ = self._field_type.__type__
if issubclass(type_, bool):
value = 'true' if value else 'false'
def process_result_value(self, value: Any, dialect: DefaultDialect) -> Any:
if value is None:
return value
self._refresh()
decrypted_value = self.backend.decrypt(value)
try:
return self._underlying_type.process_result_value(decrypted_value, dialect)
except AttributeError:
decoder = ormar.DECODERS_MAP.get(self.type_, None)
if decoder:
return decoder(decrypted_value) # type: ignore
elif issubclass(type_, (datetime.date, datetime.time)):
value = value.isoformat()
# elif issubclass(type_, JSONType):
# value = json.dumps(value)
return self.backend.encrypt(value)
def process_result_value(self, value, dialect):
"""Decrypt value on the way out."""
if value is not None:
self._update_key()
decrypted_value = self.backend.decrypt(value)
try:
return self.underlying_type.process_result_value(
decrypted_value, dialect
)
except AttributeError:
# Doesn't have 'process_result_value'
# Handle 'boolean' and 'dates'
type_ = self._field_type.__type__
# date_types = [datetime.datetime, datetime.time, datetime.date]
if issubclass(type_, bool):
return decrypted_value == 'true'
# elif type_ in date_types:
# return DatetimeHandler.process_value(
# decrypted_value, type_
# )
# elif issubclass(type_, JSONType):
# return json.loads(decrypted_value)
# Handle all others
return self.underlying_type.python_type(decrypted_value)
def _coerce(self, value):
return self.underlying_type._coerce(value)
return self._field_type.__type__(decrypted_value) # type: ignore

View File

@ -1,12 +1,12 @@
import uuid
from typing import Any, Optional, Union
from typing import Any, Optional
from sqlalchemy import CHAR
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.types import TypeDecorator
class UUID(TypeDecorator): # pragma nocover
class UUID(TypeDecorator):
"""
Platform-independent GUID type.
Uses CHAR(36) if in a string mode, otherwise uses CHAR(32), to store UUID.
@ -20,31 +20,11 @@ class UUID(TypeDecorator): # pragma nocover
super().__init__(*args, **kwargs)
self.uuid_format = uuid_format
def __repr__(self) -> str:
def __repr__(self) -> str: # pragma: nocover
if self.uuid_format == "string":
return "CHAR(36)"
return "CHAR(32)"
def _cast_to_uuid(self, value: Union[str, int, bytes]) -> uuid.UUID:
"""
Parses given value into uuid.UUID field.
:param value: value to be parsed
:type value: Union[str, int, bytes]
:return: initialized uuid
:rtype: uuid.UUID
"""
if not isinstance(value, uuid.UUID):
if isinstance(value, bytes):
ret_value = uuid.UUID(bytes=value)
elif isinstance(value, int):
ret_value = uuid.UUID(int=value)
elif isinstance(value, str):
ret_value = uuid.UUID(value)
else:
ret_value = value
return ret_value
def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
return (
dialect.type_descriptor(CHAR(36))
@ -53,12 +33,10 @@ class UUID(TypeDecorator): # pragma nocover
)
def process_bind_param(
self, value: Union[str, int, bytes, uuid.UUID, None], dialect: DefaultDialect
self, value: uuid.UUID, dialect: DefaultDialect
) -> Optional[str]:
if value is None:
return value
if not isinstance(value, uuid.UUID):
value = self._cast_to_uuid(value)
return str(value) if self.uuid_format == "string" else "%.32x" % value.int
def process_result_value(
@ -68,4 +46,4 @@ class UUID(TypeDecorator): # pragma nocover
return value
if not isinstance(value, uuid.UUID):
return uuid.UUID(value)
return value
return value # pragma: nocover

View File

@ -53,6 +53,7 @@ def convert_choices_if_needed( # noqa: CCR001
:return: value, choices list
:rtype: Tuple[Any, List]
"""
# TODO use same maps as with EncryptedString
choices = [o.value if isinstance(o, Enum) else o for o in field.choices]
if field.__type__ in [datetime.datetime, datetime.date, datetime.time]:

View File

@ -88,13 +88,13 @@ class SqlJoin:
return self.main_model.Meta.alias_manager
@property
def to_table(self) -> str:
def to_table(self) -> sqlalchemy.Table:
"""
Shortcut to the table of the next model
:return: the target table
:rtype: sqlalchemy.Table
"""
return self.next_model.Meta.table.name
return self.next_model.Meta.table
def _on_clause(
self, previous_alias: str, from_clause: str, to_clause: str,
@ -282,7 +282,7 @@ class SqlJoin:
on_clause = self._on_clause(
previous_alias=self.own_alias,
from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}",
to_clause=f"{self.to_table}.{to_key}",
to_clause=f"{self.to_table.name}.{to_key}",
)
target_table = self.alias_manager.prefixed_table_name(
self.next_alias, self.to_table
@ -301,7 +301,7 @@ class SqlJoin:
)
self.columns.extend(
self.alias_manager.prefixed_columns(
self.next_alias, self.next_model.Meta.table, self_related_fields
self.next_alias, target_table, self_related_fields
)
)
self.used_aliases.append(self.next_alias)

View File

@ -67,24 +67,21 @@ class AliasManager:
if not fields
else [col for col in table.columns if col.name in fields]
)
return [
text(f"{alias}{table.name}.{column.name} as {alias}{column.name}")
for column in all_columns
]
return [column.label(f"{alias}{column.name}") for column in all_columns]
@staticmethod
def prefixed_table_name(alias: str, name: str) -> text:
def prefixed_table_name(alias: str, table: sqlalchemy.Table) -> text:
"""
Creates an alias of the given table, named "{alias}_{table name}".
:param alias: alias of given table
:type alias: str
:param name: table name
:type name: str
:param table: table
:type table: sqlalchemy.Table
:return: given table aliased as "{alias}_{table name}"
:rtype: sqlalchemy.sql.expression.Alias
"""
return text(f"{name} {alias}_{name}")
return table.alias(f"{alias}_{table.name}")
def add_relation_type(
self, source_model: Type["Model"], relation_name: str, reverse_name: str = None,