revert to use tables and columns with labels and aliases instead of text clauses, add encryption, mostly working encryption column type with configurable backends

This commit is contained in:
collerek
2021-03-09 20:29:27 +01:00
parent 8d96a3fb84
commit e29bea6f85
14 changed files with 415 additions and 253 deletions

View File

@ -38,9 +38,12 @@ from ormar.fields import (
BaseField, BaseField,
BigInteger, BigInteger,
Boolean, Boolean,
DECODERS_MAP,
Date, Date,
DateTime, DateTime,
Decimal, Decimal,
ENCODERS_MAP,
EncryptBackends,
Float, Float,
ForeignKey, ForeignKey,
ForeignKeyField, ForeignKeyField,
@ -53,7 +56,6 @@ from ormar.fields import (
Time, Time,
UUID, UUID,
UniqueColumns, UniqueColumns,
EncryptBackends
) # noqa: I100 ) # noqa: I100
from ormar.models import ExcludableItems, Model from ormar.models import ExcludableItems, Model
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
@ -111,5 +113,7 @@ __all__ = [
"ExcludableItems", "ExcludableItems",
"and_", "and_",
"or_", "or_",
"EncryptBackends" "EncryptBackends",
"ENCODERS_MAP",
"DECODERS_MAP",
] ]

View File

@ -21,8 +21,9 @@ from ormar.fields.model_fields import (
Time, Time,
UUID, UUID,
) )
from ormar.fields.through_field import Through, ThroughField from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends
from ormar.fields.through_field import Through, ThroughField
__all__ = [ __all__ = [
"Decimal", "Decimal",
@ -46,5 +47,7 @@ __all__ = [
"ThroughField", "ThroughField",
"Through", "Through",
"EncryptBackends", "EncryptBackends",
"EncryptBackend" "EncryptBackend",
"DECODERS_MAP",
"ENCODERS_MAP",
] ]

View File

@ -6,8 +6,11 @@ from pydantic.fields import FieldInfo, Required, Undefined
import ormar # noqa I101 import ormar # noqa I101
from ormar import ModelDefinitionError from ormar import ModelDefinitionError
from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends, \ from ormar.fields.sqlalchemy_encrypted import (
EncryptedString EncryptBackend,
EncryptBackends,
EncryptedString,
)
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model
@ -54,7 +57,7 @@ class BaseField(FieldInfo):
encrypt_secret: str encrypt_secret: str
encrypt_backend: EncryptBackends = EncryptBackends.NONE encrypt_backend: EncryptBackends = EncryptBackends.NONE
encrypt_custom_backend: Type[EncryptBackend] = None encrypt_custom_backend: Optional[Type[EncryptBackend]] = None
encrypt_max_length: int = 5000 encrypt_max_length: int = 5000
default: Any default: Any
@ -101,8 +104,7 @@ class BaseField(FieldInfo):
:rtype: bool :rtype: bool
""" """
return ( return (
field_name not in ["default", "default_factory", "alias", field_name not in ["default", "default_factory", "alias", "allow_mutation"]
"allow_mutation"]
and not field_name.startswith("__") and not field_name.startswith("__")
and hasattr(cls, field_name) and hasattr(cls, field_name)
and not callable(getattr(cls, field_name)) and not callable(getattr(cls, field_name))
@ -278,9 +280,23 @@ class BaseField(FieldInfo):
server_default=cls.server_default, server_default=cls.server_default,
) )
else: else:
column = cls._get_encrypted_column(name=name)
return column
@classmethod
def _get_encrypted_column(cls, name: str) -> sqlalchemy.Column:
"""
Returns EncryptedString column type instead of actual column.
:param name: column name
:type name: str
:return: newly defined column
:rtype: sqlalchemy.Column
"""
if cls.primary_key or cls.is_relation: if cls.primary_key or cls.is_relation:
raise ModelDefinitionError("Primary key field and relations fields" raise ModelDefinitionError(
"cannot be encrypted!") "Primary key field and relations fields" "cannot be encrypted!"
)
column = sqlalchemy.Column( column = sqlalchemy.Column(
cls.alias or name, cls.alias or name,
EncryptedString( EncryptedString(
@ -288,7 +304,7 @@ class BaseField(FieldInfo):
encrypt_secret=cls.encrypt_secret, encrypt_secret=cls.encrypt_secret,
encrypt_backend=cls.encrypt_backend, encrypt_backend=cls.encrypt_backend,
encrypt_custom_backend=cls.encrypt_custom_backend, encrypt_custom_backend=cls.encrypt_custom_backend,
encrypt_max_length=cls.encrypt_max_length encrypt_max_length=cls.encrypt_max_length,
), ),
nullable=cls.nullable, nullable=cls.nullable,
index=cls.index, index=cls.index,
@ -296,7 +312,6 @@ class BaseField(FieldInfo):
default=cls.default, default=cls.default,
server_default=cls.server_default, server_default=cls.server_default,
) )
return column return column
@classmethod @classmethod

View File

@ -184,10 +184,25 @@ def ForeignKey( # noqa CFQ002
owner = kwargs.pop("owner", None) owner = kwargs.pop("owner", None)
self_reference = kwargs.pop("self_reference", False) self_reference = kwargs.pop("self_reference", False)
default = kwargs.pop("default", None) default = kwargs.pop("default", None)
if default is not None: encrypt_secret = kwargs.pop("encrypt_secret", None)
encrypt_backend = kwargs.pop("encrypt_backend", None)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
encrypt_max_length = kwargs.pop("encrypt_max_length", None)
not_supported = [
default,
encrypt_secret,
encrypt_backend,
encrypt_custom_backend,
encrypt_max_length,
]
if any(x is not None for x in not_supported):
raise ModelDefinitionError( raise ModelDefinitionError(
"Argument 'default' is not supported " "on relation fields!" f"Argument {next((x for x in not_supported if x is not None))} "
f"is not supported "
"on relation fields!"
) )
if to.__class__ == ForwardRef: if to.__class__ == ForwardRef:
@ -386,8 +401,6 @@ class ForeignKeyField(BaseField):
:return: (if needed) registered Model :return: (if needed) registered Model
:rtype: Model :rtype: Model
""" """
if cls.to.pk_type() == uuid.UUID and isinstance(value, str):
value = uuid.UUID(value)
if not isinstance(value, cls.to.pk_type()): if not isinstance(value, cls.to.pk_type()):
raise RelationshipInstanceError( raise RelationshipInstanceError(
f"Relationship error - ForeignKey {cls.to.__name__} " f"Relationship error - ForeignKey {cls.to.__name__} "

View File

@ -97,9 +97,23 @@ def ManyToMany(
forbid_through_relations(cast(Type["Model"], through)) forbid_through_relations(cast(Type["Model"], through))
default = kwargs.pop("default", None) default = kwargs.pop("default", None)
if default is not None: encrypt_secret = kwargs.pop("encrypt_secret", None)
encrypt_backend = kwargs.pop("encrypt_backend", None)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
encrypt_max_length = kwargs.pop("encrypt_max_length", None)
not_supported = [
default,
encrypt_secret,
encrypt_backend,
encrypt_custom_backend,
encrypt_max_length,
]
if any(x is not None for x in not_supported):
raise ModelDefinitionError( raise ModelDefinitionError(
"Argument 'default' is not supported " "on relation fields!" f"Argument {next((x for x in not_supported if x is not None))} "
f"is not supported "
"on relation fields!"
) )
if to.__class__ == ForwardRef: if to.__class__ == ForwardRef:

View File

@ -76,8 +76,7 @@ class ModelFieldFactory:
encrypt_secret = kwargs.pop("encrypt_secret", None) encrypt_secret = kwargs.pop("encrypt_secret", None)
encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE) encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
None)
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000) encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
namespace = dict( namespace = dict(

44
ormar/fields/parsers.py Normal file
View File

@ -0,0 +1,44 @@
import datetime
import decimal
from typing import Any
import pydantic
from pydantic.datetime_parse import parse_date, parse_datetime, parse_time
try:
import orjson as json
except ImportError: # pragma: no cover
import json # type: ignore
def parse_bool(value: str) -> bool:
return value == "true"
def encode_bool(value: bool) -> str:
return "true" if value else "false"
def encode_json(value: Any) -> str:
value = json.dumps(value) if not isinstance(value, str) else value
value = value.decode("utf-8") if isinstance(value, bytes) else value
return value
ENCODERS_MAP = {
bool: encode_bool,
datetime.datetime: lambda x: x.isoformat(),
datetime.date: lambda x: x.isoformat(),
datetime.time: lambda x: x.isoformat(),
pydantic.Json: encode_json,
decimal.Decimal: float,
}
DECODERS_MAP = {
bool: parse_bool,
datetime.datetime: parse_datetime,
datetime.date: parse_date,
datetime.time: parse_time,
pydantic.Json: json.loads,
decimal.Decimal: decimal.Decimal,
}

View File

@ -1,50 +1,49 @@
# inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils) # inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils)
import abc import abc
import base64 import base64
import datetime
import json
from enum import Enum from enum import Enum
from typing import Any, Callable, TYPE_CHECKING, Type, Union from typing import Any, Callable, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy.types as types import sqlalchemy.types as types
from pydantic.utils import lenient_issubclass
from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.engine.default import DefaultDialect
from ormar import ModelDefinitionError import ormar # noqa: I100, I202
from ormar import ModelDefinitionError # noqa: I202, I100
cryptography = None cryptography = None
try: try: # pragma: nocover
import cryptography import cryptography # type: ignore
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
except ImportError: except ImportError: # pragma: nocover
pass pass
if TYPE_CHECKING: if TYPE_CHECKING: # pragma: nocover
from ormar import BaseField from ormar import BaseField
class EncryptBackend(abc.ABC): class EncryptBackend(abc.ABC):
def _refresh(self, key: Union[str, bytes]) -> None:
def _update_key(self, key):
if isinstance(key, str): if isinstance(key, str):
key = key.encode() key = key.encode()
digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
digest.update(key) digest.update(key)
engine_key = digest.finalize() engine_key = digest.finalize()
self._initialize_engine(engine_key) self._initialize_backend(engine_key)
@abc.abstractmethod @abc.abstractmethod
def _initialize_engine(self, secret_key: bytes): def _initialize_backend(self, secret_key: bytes) -> None: # pragma: nocover
pass pass
@abc.abstractmethod @abc.abstractmethod
def encrypt(self, value: Any) -> str: def encrypt(self, value: Any) -> str: # pragma: nocover
pass pass
@abc.abstractmethod @abc.abstractmethod
def decrypt(self, value: Any) -> str: def decrypt(self, value: Any) -> str: # pragma: nocover
pass pass
@ -53,11 +52,11 @@ class HashBackend(EncryptBackend):
One-way hashing - in example for passwords, no way to decrypt the value! One-way hashing - in example for passwords, no way to decrypt the value!
""" """
def _initialize_engine(self, secret_key: bytes): def _initialize_backend(self, secret_key: bytes) -> None:
self.secret_key = base64.urlsafe_b64encode(secret_key) self.secret_key = base64.urlsafe_b64encode(secret_key)
def encrypt(self, value: Any) -> str: def encrypt(self, value: Any) -> str:
if not isinstance(value, str): if not isinstance(value, str): # pragma: nocover
value = repr(value) value = repr(value)
value = value.encode() value = value.encode()
digest = hashes.Hash(hashes.SHA512(), backend=default_backend()) digest = hashes.Hash(hashes.SHA512(), backend=default_backend())
@ -67,7 +66,7 @@ class HashBackend(EncryptBackend):
return hashed_value.hex() return hashed_value.hex()
def decrypt(self, value: Any) -> str: def decrypt(self, value: Any) -> str:
if not isinstance(value, str): if not isinstance(value, str): # pragma: nocover
value = str(value) value = str(value)
return value return value
@ -77,7 +76,7 @@ class FernetBackend(EncryptBackend):
Two-way encryption, data stored in db are encrypted but decrypted during query. Two-way encryption, data stored in db are encrypted but decrypted during query.
""" """
def _initialize_engine(self, secret_key: bytes): def _initialize_backend(self, secret_key: bytes) -> None:
self.secret_key = base64.urlsafe_b64encode(secret_key) self.secret_key = base64.urlsafe_b64encode(secret_key)
self.fernet = Fernet(self.secret_key) self.fernet = Fernet(self.secret_key)
@ -86,14 +85,14 @@ class FernetBackend(EncryptBackend):
value = repr(value) value = repr(value)
value = value.encode() value = value.encode()
encrypted = self.fernet.encrypt(value) encrypted = self.fernet.encrypt(value)
return encrypted.decode('utf-8') return encrypted.decode("utf-8")
def decrypt(self, value: Any) -> str: def decrypt(self, value: Any) -> str:
if not isinstance(value, str): if not isinstance(value, str): # pragma: nocover
value = str(value) value = str(value)
decrypted = self.fernet.decrypt(value.encode()) decrypted: Union[str, bytes] = self.fernet.decrypt(value.encode())
if not isinstance(decrypted, str): if not isinstance(decrypted, str):
decrypted = decrypted.decode('utf-8') decrypted = decrypted.decode("utf-8")
return decrypted return decrypted
@ -104,115 +103,82 @@ class EncryptBackends(Enum):
CUSTOM = 3 CUSTOM = 3
backends_map = { BACKENDS_MAP = {
EncryptBackends.FERNET: FernetBackend, EncryptBackends.FERNET: FernetBackend,
EncryptBackends.HASH: HashBackend, EncryptBackends.HASH: HashBackend,
EncryptBackends.CUSTOM: None
} }
class EncryptedString(types.TypeDecorator): # pragma nocover class EncryptedString(types.TypeDecorator):
""" """
Used to store encrypted values in a database Used to store encrypted values in a database
""" """
impl = types.TypeEngine impl = types.TypeEngine
def __init__(self, def __init__(
*args: Any, self,
encrypt_secret: Union[str, Callable], encrypt_secret: Union[str, Callable],
_field_type: Type["BaseField"],
encrypt_max_length: int = 5000,
encrypt_backend: EncryptBackends = EncryptBackends.FERNET, encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
encrypt_custom_backend: Type[EncryptBackend] = None, encrypt_custom_backend: Type[EncryptBackend] = None,
**kwargs: Any) -> None: **kwargs: Any,
super().__init__(*args, **kwargs) ) -> None:
if not cryptography: _field_type = kwargs.pop("_field_type")
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
super().__init__()
if not cryptography: # pragma: nocover
raise ModelDefinitionError( raise ModelDefinitionError(
"In order to encrypt a column 'cryptography' is required!" "In order to encrypt a column 'cryptography' is required!"
) )
backend = backends_map.get(encrypt_backend, encrypt_custom_backend) backend = BACKENDS_MAP.get(encrypt_backend, encrypt_custom_backend)
if not backend or not issubclass(backend, EncryptBackend): if not backend or not lenient_issubclass(backend, EncryptBackend):
raise ModelDefinitionError("Wrong or no encrypt backend provided!") raise ModelDefinitionError("Wrong or no encrypt backend provided!")
self.backend = backend()
self._field_type = _field_type
self._underlying_type = _field_type.column_type
self._key = encrypt_secret
self.max_length = encrypt_max_length
def __repr__(self) -> str: self.backend: EncryptBackend = backend()
self._field_type: Type["BaseField"] = _field_type
self._underlying_type: Any = _field_type.column_type
self._key: Union[str, Callable] = encrypt_secret
self.max_length: int = encrypt_max_length
type_ = self._field_type.__type__
if type_ is None: # pragma: nocover
raise ModelDefinitionError(
f"Improperly configured field " f"{self._field_type.name}"
)
self.type_: Any = type_
def __repr__(self) -> str: # pragma: nocover
return f"VARCHAR({self.max_length})" return f"VARCHAR({self.max_length})"
def load_dialect_impl(self, dialect: DefaultDialect) -> Any: def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
return dialect.type_descriptor(types.VARCHAR(self.max_length)) return dialect.type_descriptor(types.VARCHAR(self.max_length))
@property def _refresh(self) -> None:
def key(self):
return self._key
@key.setter
def key(self, value):
self._key = value
def _update_key(self):
key = self._key() if callable(self._key) else self._key key = self._key() if callable(self._key) else self._key
self.backend._update_key(key) self.backend._refresh(key)
def process_bind_param(self, value, dialect):
"""Encrypt a value on the way in."""
if value is not None:
self._update_key()
def process_bind_param(self, value: Any, dialect: DefaultDialect) -> Optional[str]:
if value is None:
return value
self._refresh()
try: try:
value = self._underlying_type.process_bind_param( value = self._underlying_type.process_bind_param(value, dialect)
value, dialect
)
except AttributeError: except AttributeError:
# Doesn't have 'process_bind_param' encoder = ormar.ENCODERS_MAP.get(self.type_, None)
type_ = self._field_type.__type__ if encoder:
if issubclass(type_, bool): value = encoder(value) # type: ignore
value = 'true' if value else 'false'
elif issubclass(type_, (datetime.date, datetime.time)):
value = value.isoformat()
# elif issubclass(type_, JSONType):
# value = json.dumps(value)
return self.backend.encrypt(value) return self.backend.encrypt(value)
def process_result_value(self, value, dialect): def process_result_value(self, value: Any, dialect: DefaultDialect) -> Any:
"""Decrypt value on the way out.""" if value is None:
if value is not None: return value
self._update_key() self._refresh()
decrypted_value = self.backend.decrypt(value) decrypted_value = self.backend.decrypt(value)
try: try:
return self.underlying_type.process_result_value( return self._underlying_type.process_result_value(decrypted_value, dialect)
decrypted_value, dialect
)
except AttributeError: except AttributeError:
# Doesn't have 'process_result_value' decoder = ormar.DECODERS_MAP.get(self.type_, None)
if decoder:
return decoder(decrypted_value) # type: ignore
# Handle 'boolean' and 'dates' return self._field_type.__type__(decrypted_value) # type: ignore
type_ = self._field_type.__type__
# date_types = [datetime.datetime, datetime.time, datetime.date]
if issubclass(type_, bool):
return decrypted_value == 'true'
# elif type_ in date_types:
# return DatetimeHandler.process_value(
# decrypted_value, type_
# )
# elif issubclass(type_, JSONType):
# return json.loads(decrypted_value)
# Handle all others
return self.underlying_type.python_type(decrypted_value)
def _coerce(self, value):
return self.underlying_type._coerce(value)

View File

@ -1,12 +1,12 @@
import uuid import uuid
from typing import Any, Optional, Union from typing import Any, Optional
from sqlalchemy import CHAR from sqlalchemy import CHAR
from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.types import TypeDecorator from sqlalchemy.types import TypeDecorator
class UUID(TypeDecorator): # pragma nocover class UUID(TypeDecorator):
""" """
Platform-independent GUID type. Platform-independent GUID type.
Uses CHAR(36) if in a string mode, otherwise uses CHAR(32), to store UUID. Uses CHAR(36) if in a string mode, otherwise uses CHAR(32), to store UUID.
@ -20,31 +20,11 @@ class UUID(TypeDecorator): # pragma nocover
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.uuid_format = uuid_format self.uuid_format = uuid_format
def __repr__(self) -> str: def __repr__(self) -> str: # pragma: nocover
if self.uuid_format == "string": if self.uuid_format == "string":
return "CHAR(36)" return "CHAR(36)"
return "CHAR(32)" return "CHAR(32)"
def _cast_to_uuid(self, value: Union[str, int, bytes]) -> uuid.UUID:
"""
Parses given value into uuid.UUID field.
:param value: value to be parsed
:type value: Union[str, int, bytes]
:return: initialized uuid
:rtype: uuid.UUID
"""
if not isinstance(value, uuid.UUID):
if isinstance(value, bytes):
ret_value = uuid.UUID(bytes=value)
elif isinstance(value, int):
ret_value = uuid.UUID(int=value)
elif isinstance(value, str):
ret_value = uuid.UUID(value)
else:
ret_value = value
return ret_value
def load_dialect_impl(self, dialect: DefaultDialect) -> Any: def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
return ( return (
dialect.type_descriptor(CHAR(36)) dialect.type_descriptor(CHAR(36))
@ -53,12 +33,10 @@ class UUID(TypeDecorator): # pragma nocover
) )
def process_bind_param( def process_bind_param(
self, value: Union[str, int, bytes, uuid.UUID, None], dialect: DefaultDialect self, value: uuid.UUID, dialect: DefaultDialect
) -> Optional[str]: ) -> Optional[str]:
if value is None: if value is None:
return value return value
if not isinstance(value, uuid.UUID):
value = self._cast_to_uuid(value)
return str(value) if self.uuid_format == "string" else "%.32x" % value.int return str(value) if self.uuid_format == "string" else "%.32x" % value.int
def process_result_value( def process_result_value(
@ -68,4 +46,4 @@ class UUID(TypeDecorator): # pragma nocover
return value return value
if not isinstance(value, uuid.UUID): if not isinstance(value, uuid.UUID):
return uuid.UUID(value) return uuid.UUID(value)
return value return value # pragma: nocover

View File

@ -53,6 +53,7 @@ def convert_choices_if_needed( # noqa: CCR001
:return: value, choices list :return: value, choices list
:rtype: Tuple[Any, List] :rtype: Tuple[Any, List]
""" """
# TODO use same maps as with EncryptedString
choices = [o.value if isinstance(o, Enum) else o for o in field.choices] choices = [o.value if isinstance(o, Enum) else o for o in field.choices]
if field.__type__ in [datetime.datetime, datetime.date, datetime.time]: if field.__type__ in [datetime.datetime, datetime.date, datetime.time]:

View File

@ -88,13 +88,13 @@ class SqlJoin:
return self.main_model.Meta.alias_manager return self.main_model.Meta.alias_manager
@property @property
def to_table(self) -> str: def to_table(self) -> sqlalchemy.Table:
""" """
Shortcut to table name of the next model Shortcut to table name of the next model
:return: name of the target table :return: name of the target table
:rtype: str :rtype: str
""" """
return self.next_model.Meta.table.name return self.next_model.Meta.table
def _on_clause( def _on_clause(
self, previous_alias: str, from_clause: str, to_clause: str, self, previous_alias: str, from_clause: str, to_clause: str,
@ -282,7 +282,7 @@ class SqlJoin:
on_clause = self._on_clause( on_clause = self._on_clause(
previous_alias=self.own_alias, previous_alias=self.own_alias,
from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}", from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}",
to_clause=f"{self.to_table}.{to_key}", to_clause=f"{self.to_table.name}.{to_key}",
) )
target_table = self.alias_manager.prefixed_table_name( target_table = self.alias_manager.prefixed_table_name(
self.next_alias, self.to_table self.next_alias, self.to_table
@ -301,7 +301,7 @@ class SqlJoin:
) )
self.columns.extend( self.columns.extend(
self.alias_manager.prefixed_columns( self.alias_manager.prefixed_columns(
self.next_alias, self.next_model.Meta.table, self_related_fields self.next_alias, target_table, self_related_fields
) )
) )
self.used_aliases.append(self.next_alias) self.used_aliases.append(self.next_alias)

View File

@ -67,24 +67,21 @@ class AliasManager:
if not fields if not fields
else [col for col in table.columns if col.name in fields] else [col for col in table.columns if col.name in fields]
) )
return [ return [column.label(f"{alias}{column.name}") for column in all_columns]
text(f"{alias}{table.name}.{column.name} as {alias}{column.name}")
for column in all_columns
]
@staticmethod @staticmethod
def prefixed_table_name(alias: str, name: str) -> text: def prefixed_table_name(alias: str, table: sqlalchemy.Table) -> text:
""" """
Creates text clause with table name with aliased name. Creates text clause with table name with aliased name.
:param alias: alias of given table :param alias: alias of given table
:type alias: str :type alias: str
:param name: table name :param table: table
:type name: str :type table: sqlalchemy.Table
:return: sqlalchemy text clause as "table_name aliased_name" :return: sqlalchemy text clause as "table_name aliased_name"
:rtype: sqlalchemy text clause :rtype: sqlalchemy text clause
""" """
return text(f"{name} {alias}_{name}") return table.alias(f"{alias}_{table.name}")
def add_relation_type( def add_relation_type(
self, source_model: Type["Model"], relation_name: str, reverse_name: str = None, self, source_model: Type["Model"], relation_name: str, reverse_name: str = None,

View File

@ -1,12 +1,16 @@
# type: ignore
import decimal
import uuid import uuid
from typing import Optional import datetime
from typing import Any
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
import ormar import ormar
from ormar.exceptions import QueryDefinitionError from ormar import ModelDefinitionError
from ormar.fields.sqlalchemy_encrypted import EncryptedString
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL) database = databases.Database(DATABASE_URL)
@ -18,22 +22,58 @@ class BaseMeta(ormar.ModelMeta):
database = database database = database
default_fernet = dict(
encrypt_secret="asd123", encrypt_backend=ormar.EncryptBackends.FERNET,
)
class DummyBackend(ormar.fields.EncryptBackend):
def _initialize_backend(self, secret_key: bytes) -> None:
pass
def encrypt(self, value: Any) -> str:
return value
def decrypt(self, value: Any) -> str:
return value
class Author(ormar.Model): class Author(ormar.Model):
class Meta(BaseMeta): class Meta(BaseMeta):
tablename = "authors" tablename = "authors"
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, name: str = ormar.String(max_length=100, **default_fernet)
encrypt_secret='asd123', uuid_test = ormar.UUID(default=uuid.uuid4, uuid_format="string")
encrypt_backend=ormar.EncryptBackends.FERNET) uuid_test2 = ormar.UUID(nullable=True, uuid_format="string")
uuid_test = ormar.UUID(default=uuid.uuid4, uuid_format='string') password: str = ormar.String(
password: str = ormar.String(max_length=100, max_length=128,
encrypt_secret='udxc32', encrypt_secret="udxc32",
encrypt_backend=ormar.EncryptBackends.HASH) encrypt_backend=ormar.EncryptBackends.HASH,
birth_year: int = ormar.Integer(nullable=True, )
encrypt_secret='secure89key%^&psdijfipew', birth_year: int = ormar.Integer(
nullable=True,
encrypt_secret="secure89key%^&psdijfipew",
encrypt_max_length=200, encrypt_max_length=200,
encrypt_backend=ormar.EncryptBackends.FERNET) encrypt_backend=ormar.EncryptBackends.FERNET,
)
test_text: str = ormar.Text(default="", **default_fernet)
test_bool: bool = ormar.Boolean(nullable=False, **default_fernet)
test_float: float = ormar.Float(**default_fernet)
test_float2: float = ormar.Float(nullable=True, **default_fernet)
test_datetime = ormar.DateTime(default=datetime.datetime.now, **default_fernet)
test_date = ormar.Date(default=datetime.date.today, **default_fernet)
test_time = ormar.Time(default=datetime.time, **default_fernet)
test_json = ormar.JSON(default={}, **default_fernet)
test_bigint: int = ormar.BigInteger(default=0, **default_fernet)
test_decimal = ormar.Decimal(scale=2, precision=10, **default_fernet)
test_decimal2 = ormar.Decimal(max_digits=10, decimal_places=2, **default_fernet)
custom_backend: str = ormar.String(
max_length=200,
encrypt_secret="asda8",
encrypt_backend=ormar.EncryptBackends.CUSTOM,
encrypt_custom_backend=DummyBackend,
)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -45,16 +85,104 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
def test_error_on_encrypted_pk():
with pytest.raises(ModelDefinitionError):
class Wrong(ormar.Model):
class Meta(BaseMeta):
tablename = "wrongs"
id: int = ormar.Integer(
primary_key=True,
encrypt_secret="asd123",
encrypt_backend=ormar.EncryptBackends.FERNET,
)
def test_error_on_encrypted_relation():
with pytest.raises(ModelDefinitionError):
class Wrong2(ormar.Model):
class Meta(BaseMeta):
tablename = "wrongs2"
id: int = ormar.Integer(primary_key=True)
author = ormar.ForeignKey(
Author,
encrypt_secret="asd123",
encrypt_backend=ormar.EncryptBackends.FERNET,
)
def test_error_on_encrypted_m2m_relation():
with pytest.raises(ModelDefinitionError):
class Wrong3(ormar.Model):
class Meta(BaseMeta):
tablename = "wrongs3"
id: int = ormar.Integer(primary_key=True)
author = ormar.ManyToMany(
Author,
encrypt_secret="asd123",
encrypt_backend=ormar.EncryptBackends.FERNET,
)
def test_wrong_backend():
with pytest.raises(ModelDefinitionError):
class Wrong3(ormar.Model):
class Meta(BaseMeta):
tablename = "wrongs3"
id: int = ormar.Integer(primary_key=True)
author = ormar.Integer(
encrypt_secret="asd123",
encrypt_backend=ormar.EncryptBackends.CUSTOM,
encrypt_custom_backend="aa",
)
def test_db_structure(): def test_db_structure():
assert Author.Meta.table.c.get('name').type.impl.__class__ == sqlalchemy.NVARCHAR assert Author.Meta.table.c.get("name").type.__class__ == EncryptedString
assert Author.Meta.table.c.get('birth_year').type.max_length == 200 assert Author.Meta.table.c.get("birth_year").type.max_length == 200
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wrong_query_foreign_key_type(): async def test_save_and_retrieve():
async with database: async with database:
await Author(name='Test', birth_year=1988, password='test123').save() test_uuid = uuid.uuid4()
await Author(
name="Test",
birth_year=1988,
password="test123",
uuid_test=test_uuid,
test_float=1.2,
test_bool=True,
test_decimal=decimal.Decimal(3.5),
test_decimal2=decimal.Decimal(5.5),
test_json=dict(aa=12),
custom_backend="test12",
).save()
author = await Author.objects.get() author = await Author.objects.get()
assert author.name == 'Test' assert author.name == "Test"
assert author.birth_year == 1988 assert author.birth_year == 1988
password = (
"03e4a4d513e99cb3fe4ee3db282c053daa3f3572b849c3868939a306944ad5c08"
"22b50d4886e10f4cd418c3f2df3ceb02e2e7ac6e920ae0c90f2dedfc8fa16e2"
)
assert author.password == password
assert author.uuid_test == test_uuid
assert author.uuid_test2 is None
assert author.test_datetime.date() == datetime.date.today()
assert author.test_date == datetime.date.today()
assert author.test_text == ""
assert author.test_float == 1.2
assert author.test_float2 is None
assert author.test_bigint == 0
assert author.test_json == {"aa": 12}
assert author.test_decimal == 3.5
assert author.test_decimal2 == 5.5
assert author.custom_backend == "test12"