From e29bea6f85dacb95a276ae3a10ef25b58d67341e Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 9 Mar 2021 20:29:27 +0100 Subject: [PATCH] revert to use tables and columns with labels and aliases instead of text clauses, add encryption, mostly working encryption column type with configurable backends --- ormar/__init__.py | 8 +- ormar/fields/__init__.py | 7 +- ormar/fields/base.py | 81 +++++++----- ormar/fields/foreign_key.py | 21 +++- ormar/fields/many_to_many.py | 18 ++- ormar/fields/model_fields.py | 3 +- ormar/fields/parsers.py | 44 +++++++ ormar/fields/sqlalchemy_encrypted.py | 182 +++++++++++---------------- ormar/fields/sqlalchemy_uuid.py | 32 +---- ormar/models/helpers/validation.py | 1 + ormar/queryset/join.py | 8 +- ormar/relations/alias_manager.py | 13 +- tests/test_encrypted_columns.py | 164 +++++++++++++++++++++--- tests/test_or_filters.py | 86 ++++++------- 14 files changed, 415 insertions(+), 253 deletions(-) create mode 100644 ormar/fields/parsers.py diff --git a/ormar/__init__.py b/ormar/__init__.py index d68a23c..9193543 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -38,9 +38,12 @@ from ormar.fields import ( BaseField, BigInteger, Boolean, + DECODERS_MAP, Date, DateTime, Decimal, + ENCODERS_MAP, + EncryptBackends, Float, ForeignKey, ForeignKeyField, @@ -53,7 +56,6 @@ from ormar.fields import ( Time, UUID, UniqueColumns, - EncryptBackends ) # noqa: I100 from ormar.models import ExcludableItems, Model from ormar.models.metaclass import ModelMeta @@ -111,5 +113,7 @@ __all__ = [ "ExcludableItems", "and_", "or_", - "EncryptBackends" + "EncryptBackends", + "ENCODERS_MAP", + "DECODERS_MAP", ] diff --git a/ormar/fields/__init__.py b/ormar/fields/__init__.py index 7a22c51..e0cb3b0 100644 --- a/ormar/fields/__init__.py +++ b/ormar/fields/__init__.py @@ -21,8 +21,9 @@ from ormar.fields.model_fields import ( Time, UUID, ) -from ormar.fields.through_field import Through, ThroughField +from ormar.fields.parsers import DECODERS_MAP, ENCODERS_MAP from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends +from ormar.fields.through_field import Through, ThroughField __all__ = [ "Decimal", @@ -46,5 +47,7 @@ __all__ = [ "ThroughField", "Through", "EncryptBackends", - "EncryptBackend" + "EncryptBackend", + "DECODERS_MAP", + "ENCODERS_MAP", ] diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 9fdadb7..041d934 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -6,8 +6,11 @@ from pydantic.fields import FieldInfo, Required, Undefined import ormar # noqa I101 from ormar import ModelDefinitionError -from ormar.fields.sqlalchemy_encrypted import EncryptBackend, EncryptBackends, \ - EncryptedString +from ormar.fields.sqlalchemy_encrypted import ( + EncryptBackend, + EncryptBackends, + EncryptedString, +) if TYPE_CHECKING: # pragma no cover from ormar.models import Model @@ -54,7 +57,7 @@ class BaseField(FieldInfo): encrypt_secret: str encrypt_backend: EncryptBackends = EncryptBackends.NONE - encrypt_custom_backend: Type[EncryptBackend] = None + encrypt_custom_backend: Optional[Type[EncryptBackend]] = None encrypt_max_length: int = 5000 default: Any @@ -101,11 +104,10 @@ class BaseField(FieldInfo): :rtype: bool """ return ( - field_name not in ["default", "default_factory", "alias", - "allow_mutation"] - and not field_name.startswith("__") - and hasattr(cls, field_name) - and not callable(getattr(cls, field_name)) + field_name not in ["default", "default_factory", "alias", "allow_mutation"] + and not field_name.startswith("__") + and hasattr(cls, field_name) + and not callable(getattr(cls, field_name)) ) @classmethod @@ -214,7 +216,7 @@ class BaseField(FieldInfo): :rtype: bool """ return cls.default is not None or ( - cls.server_default is not None and use_server + cls.server_default is not None and use_server ) @classmethod @@ -247,7 +249,7 @@ class BaseField(FieldInfo): ondelete=con.ondelete, onupdate=con.onupdate, name=f"fk_{cls.owner.Meta.tablename}_{cls.to.Meta.tablename}" - f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}", + f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}", ) for con in cls.constraints ] @@ -278,33 +280,46 @@ class BaseField(FieldInfo): server_default=cls.server_default, ) else: - if cls.primary_key or cls.is_relation: - raise ModelDefinitionError("Primary key field and relations fields" - "cannot be encrypted!") - column = sqlalchemy.Column( - cls.alias or name, - EncryptedString( - _field_type=cls, - encrypt_secret=cls.encrypt_secret, - encrypt_backend=cls.encrypt_backend, - encrypt_custom_backend=cls.encrypt_custom_backend, - encrypt_max_length=cls.encrypt_max_length - ), - nullable=cls.nullable, - index=cls.index, - unique=cls.unique, - default=cls.default, - server_default=cls.server_default, - ) + column = cls._get_encrypted_column(name=name) + return column + @classmethod + def _get_encrypted_column(cls, name: str) -> sqlalchemy.Column: + """ + Returns EncryptedString column type instead of actual column. + + :param name: column name + :type name: str + :return: newly defined column + :rtype: sqlalchemy.Column + """ + if cls.primary_key or cls.is_relation: + raise ModelDefinitionError( + "Primary key field and relations fields" "cannot be encrypted!" + ) + column = sqlalchemy.Column( + cls.alias or name, + EncryptedString( + _field_type=cls, + encrypt_secret=cls.encrypt_secret, + encrypt_backend=cls.encrypt_backend, + encrypt_custom_backend=cls.encrypt_custom_backend, + encrypt_max_length=cls.encrypt_max_length, + ), + nullable=cls.nullable, + index=cls.index, + unique=cls.unique, + default=cls.default, + server_default=cls.server_default, + ) return column @classmethod def expand_relationship( - cls, - value: Any, - child: Union["Model", "NewBaseModel"], - to_register: bool = True, + cls, + value: Any, + child: Union["Model", "NewBaseModel"], + to_register: bool = True, ) -> Any: """ Function overwritten for relations, in basic field the value is returned as is. @@ -332,7 +347,7 @@ class BaseField(FieldInfo): :rtype: None """ if cls.owner is not None and ( - cls.owner == cls.to or cls.owner.Meta == cls.to.Meta + cls.owner == cls.to or cls.owner.Meta == cls.to.Meta ): cls.self_reference = True cls.self_reference_primary = cls.name diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 7f1a500..0aafd5e 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -184,10 +184,25 @@ def ForeignKey( # noqa CFQ002 owner = kwargs.pop("owner", None) self_reference = kwargs.pop("self_reference", False) + default = kwargs.pop("default", None) - if default is not None: + encrypt_secret = kwargs.pop("encrypt_secret", None) + encrypt_backend = kwargs.pop("encrypt_backend", None) + encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) + encrypt_max_length = kwargs.pop("encrypt_max_length", None) + + not_supported = [ + default, + encrypt_secret, + encrypt_backend, + encrypt_custom_backend, + encrypt_max_length, + ] + if any(x is not None for x in not_supported): raise ModelDefinitionError( - "Argument 'default' is not supported " "on relation fields!" + f"Argument {next((x for x in not_supported if x is not None))} " + f"is not supported " + "on relation fields!" ) if to.__class__ == ForwardRef: @@ -386,8 +401,6 @@ class ForeignKeyField(BaseField): :return: (if needed) registered Model :rtype: Model """ - if cls.to.pk_type() == uuid.UUID and isinstance(value, str): - value = uuid.UUID(value) if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( f"Relationship error - ForeignKey {cls.to.__name__} " diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index db763e3..ca26364 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -97,9 +97,23 @@ def ManyToMany( forbid_through_relations(cast(Type["Model"], through)) default = kwargs.pop("default", None) - if default is not None: + encrypt_secret = kwargs.pop("encrypt_secret", None) + encrypt_backend = kwargs.pop("encrypt_backend", None) + encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) + encrypt_max_length = kwargs.pop("encrypt_max_length", None) + + not_supported = [ + default, + encrypt_secret, + encrypt_backend, + encrypt_custom_backend, + encrypt_max_length, + ] + if any(x is not None for x in not_supported): raise ModelDefinitionError( - "Argument 'default' is not supported " "on relation fields!" + f"Argument {next((x for x in not_supported if x is not None))} " + f"is not supported " + "on relation fields!" ) if to.__class__ == ForwardRef: diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 67ea3e6..9cfd4f4 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -76,8 +76,7 @@ class ModelFieldFactory: encrypt_secret = kwargs.pop("encrypt_secret", None) encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE) - encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", - None) + encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) encrypt_max_length = kwargs.pop("encrypt_max_length", 5000) namespace = dict( diff --git a/ormar/fields/parsers.py b/ormar/fields/parsers.py new file mode 100644 index 0000000..e0f1a53 --- /dev/null +++ b/ormar/fields/parsers.py @@ -0,0 +1,44 @@ +import datetime +import decimal +from typing import Any + +import pydantic +from pydantic.datetime_parse import parse_date, parse_datetime, parse_time + +try: + import orjson as json +except ImportError: # pragma: no cover + import json # type: ignore + + +def parse_bool(value: str) -> bool: + return value == "true" + + +def encode_bool(value: bool) -> str: + return "true" if value else "false" + + +def encode_json(value: Any) -> str: + value = json.dumps(value) if not isinstance(value, str) else value + value = value.decode("utf-8") if isinstance(value, bytes) else value + return value + + +ENCODERS_MAP = { + bool: encode_bool, + datetime.datetime: lambda x: x.isoformat(), + datetime.date: lambda x: x.isoformat(), + datetime.time: lambda x: x.isoformat(), + pydantic.Json: encode_json, + decimal.Decimal: float, +} + +DECODERS_MAP = { + bool: parse_bool, + datetime.datetime: parse_datetime, + datetime.date: parse_date, + datetime.time: parse_time, + pydantic.Json: json.loads, + decimal.Decimal: decimal.Decimal, +} diff --git a/ormar/fields/sqlalchemy_encrypted.py b/ormar/fields/sqlalchemy_encrypted.py index a2e10b7..bdb1832 100644 --- a/ormar/fields/sqlalchemy_encrypted.py +++ b/ormar/fields/sqlalchemy_encrypted.py @@ -1,50 +1,49 @@ # inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils) import abc import base64 -import datetime -import json from enum import Enum -from typing import Any, Callable, TYPE_CHECKING, Type, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Type, Union import sqlalchemy.types as types +from pydantic.utils import lenient_issubclass from sqlalchemy.engine.default import DefaultDialect -from ormar import ModelDefinitionError +import ormar # noqa: I100, I202 +from ormar import ModelDefinitionError # noqa: I202, I100 cryptography = None -try: - import cryptography +try: # pragma: nocover + import cryptography # type: ignore from cryptography.fernet import Fernet from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes -except ImportError: +except ImportError: # pragma: nocover pass -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: nocover from ormar import BaseField class EncryptBackend(abc.ABC): - - def _update_key(self, key): + def _refresh(self, key: Union[str, bytes]) -> None: if isinstance(key, str): key = key.encode() digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) digest.update(key) engine_key = digest.finalize() - self._initialize_engine(engine_key) + self._initialize_backend(engine_key) @abc.abstractmethod - def _initialize_engine(self, secret_key: bytes): + def _initialize_backend(self, secret_key: bytes) -> None: # pragma: nocover pass @abc.abstractmethod - def encrypt(self, value: Any) -> str: + def encrypt(self, value: Any) -> str: # pragma: nocover pass @abc.abstractmethod - def decrypt(self, value: Any) -> str: + def decrypt(self, value: Any) -> str: # pragma: nocover pass @@ -53,11 +52,11 @@ class HashBackend(EncryptBackend): One-way hashing - in example for passwords, no way to decrypt the value! """ - def _initialize_engine(self, secret_key: bytes): + def _initialize_backend(self, secret_key: bytes) -> None: self.secret_key = base64.urlsafe_b64encode(secret_key) def encrypt(self, value: Any) -> str: - if not isinstance(value, str): + if not isinstance(value, str): # pragma: nocover value = repr(value) value = value.encode() digest = hashes.Hash(hashes.SHA512(), backend=default_backend()) @@ -67,7 +66,7 @@ class HashBackend(EncryptBackend): return hashed_value.hex() def decrypt(self, value: Any) -> str: - if not isinstance(value, str): + if not isinstance(value, str): # pragma: nocover value = str(value) return value @@ -77,7 +76,7 @@ class FernetBackend(EncryptBackend): Two-way encryption, data stored in db are encrypted but decrypted during query. """ - def _initialize_engine(self, secret_key: bytes): + def _initialize_backend(self, secret_key: bytes) -> None: self.secret_key = base64.urlsafe_b64encode(secret_key) self.fernet = Fernet(self.secret_key) @@ -86,14 +85,14 @@ class FernetBackend(EncryptBackend): value = repr(value) value = value.encode() encrypted = self.fernet.encrypt(value) - return encrypted.decode('utf-8') + return encrypted.decode("utf-8") def decrypt(self, value: Any) -> str: - if not isinstance(value, str): + if not isinstance(value, str): # pragma: nocover value = str(value) - decrypted = self.fernet.decrypt(value.encode()) + decrypted: Union[str, bytes] = self.fernet.decrypt(value.encode()) if not isinstance(decrypted, str): - decrypted = decrypted.decode('utf-8') + decrypted = decrypted.decode("utf-8") return decrypted @@ -104,115 +103,82 @@ class EncryptBackends(Enum): CUSTOM = 3 -backends_map = { +BACKENDS_MAP = { EncryptBackends.FERNET: FernetBackend, EncryptBackends.HASH: HashBackend, - EncryptBackends.CUSTOM: None } -class EncryptedString(types.TypeDecorator): # pragma nocover +class EncryptedString(types.TypeDecorator): """ Used to store encrypted values in a database """ impl = types.TypeEngine - def __init__(self, - *args: Any, - encrypt_secret: Union[str, Callable], - _field_type: Type["BaseField"], - encrypt_max_length: int = 5000, - encrypt_backend: EncryptBackends = EncryptBackends.FERNET, - encrypt_custom_backend: Type[EncryptBackend] = None, - **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - if not cryptography: + def __init__( + self, + encrypt_secret: Union[str, Callable], + encrypt_backend: EncryptBackends = EncryptBackends.FERNET, + encrypt_custom_backend: Type[EncryptBackend] = None, + **kwargs: Any, + ) -> None: + _field_type = kwargs.pop("_field_type") + encrypt_max_length = kwargs.pop("encrypt_max_length", 5000) + super().__init__() + if not cryptography: # pragma: nocover raise ModelDefinitionError( "In order to encrypt a column 'cryptography' is required!" ) - backend = backends_map.get(encrypt_backend, encrypt_custom_backend) - if not backend or not issubclass(backend, EncryptBackend): + backend = BACKENDS_MAP.get(encrypt_backend, encrypt_custom_backend) + if not backend or not lenient_issubclass(backend, EncryptBackend): raise ModelDefinitionError("Wrong or no encrypt backend provided!") - self.backend = backend() - self._field_type = _field_type - self._underlying_type = _field_type.column_type - self._key = encrypt_secret - self.max_length = encrypt_max_length - def __repr__(self) -> str: + self.backend: EncryptBackend = backend() + self._field_type: Type["BaseField"] = _field_type + self._underlying_type: Any = _field_type.column_type + self._key: Union[str, Callable] = encrypt_secret + self.max_length: int = encrypt_max_length + type_ = self._field_type.__type__ + if type_ is None: # pragma: nocover + raise ModelDefinitionError( + f"Improperly configured field " f"{self._field_type.name}" + ) + self.type_: Any = type_ + + def __repr__(self) -> str: # pragma: nocover return f"VARCHAR({self.max_length})" def load_dialect_impl(self, dialect: DefaultDialect) -> Any: return dialect.type_descriptor(types.VARCHAR(self.max_length)) - @property - def key(self): - return self._key - - @key.setter - def key(self, value): - self._key = value - - def _update_key(self): + def _refresh(self) -> None: key = self._key() if callable(self._key) else self._key - self.backend._update_key(key) + self.backend._refresh(key) - def process_bind_param(self, value, dialect): - """Encrypt a value on the way in.""" - if value is not None: - self._update_key() + def process_bind_param(self, value: Any, dialect: DefaultDialect) -> Optional[str]: + if value is None: + return value + self._refresh() + try: + value = self._underlying_type.process_bind_param(value, dialect) + except AttributeError: + encoder = ormar.ENCODERS_MAP.get(self.type_, None) + if encoder: + value = encoder(value) # type: ignore - try: - value = self._underlying_type.process_bind_param( - value, dialect - ) + return self.backend.encrypt(value) - except AttributeError: - # Doesn't have 'process_bind_param' - type_ = self._field_type.__type__ - if issubclass(type_, bool): - value = 'true' if value else 'false' + def process_result_value(self, value: Any, dialect: DefaultDialect) -> Any: + if value is None: + return value + self._refresh() + decrypted_value = self.backend.decrypt(value) + try: + return self._underlying_type.process_result_value(decrypted_value, dialect) + except AttributeError: + decoder = ormar.DECODERS_MAP.get(self.type_, None) + if decoder: + return decoder(decrypted_value) # type: ignore - elif issubclass(type_, (datetime.date, datetime.time)): - value = value.isoformat() - - # elif issubclass(type_, JSONType): - # value = json.dumps(value) - - return self.backend.encrypt(value) - - def process_result_value(self, value, dialect): - """Decrypt value on the way out.""" - if value is not None: - self._update_key() - decrypted_value = self.backend.decrypt(value) - - try: - return self.underlying_type.process_result_value( - decrypted_value, dialect - ) - - except AttributeError: - # Doesn't have 'process_result_value' - - # Handle 'boolean' and 'dates' - type_ = self._field_type.__type__ - # date_types = [datetime.datetime, datetime.time, datetime.date] - - if issubclass(type_, bool): - return decrypted_value == 'true' - - # elif type_ in date_types: - # return DatetimeHandler.process_value( - # decrypted_value, type_ - # ) - - # elif issubclass(type_, JSONType): - # return json.loads(decrypted_value) - - # Handle all others - return self.underlying_type.python_type(decrypted_value) - - def _coerce(self, value): - return self.underlying_type._coerce(value) + return self._field_type.__type__(decrypted_value) # type: ignore diff --git a/ormar/fields/sqlalchemy_uuid.py b/ormar/fields/sqlalchemy_uuid.py index 2a6bfd7..826c0c8 100644 --- a/ormar/fields/sqlalchemy_uuid.py +++ b/ormar/fields/sqlalchemy_uuid.py @@ -1,12 +1,12 @@ import uuid -from typing import Any, Optional, Union +from typing import Any, Optional from sqlalchemy import CHAR from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.types import TypeDecorator -class UUID(TypeDecorator): # pragma nocover +class UUID(TypeDecorator): """ Platform-independent GUID type. Uses CHAR(36) if in a string mode, otherwise uses CHAR(32), to store UUID. @@ -20,31 +20,11 @@ class UUID(TypeDecorator): # pragma nocover super().__init__(*args, **kwargs) self.uuid_format = uuid_format - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: nocover if self.uuid_format == "string": return "CHAR(36)" return "CHAR(32)" - def _cast_to_uuid(self, value: Union[str, int, bytes]) -> uuid.UUID: - """ - Parses given value into uuid.UUID field. - - :param value: value to be parsed - :type value: Union[str, int, bytes] - :return: initialized uuid - :rtype: uuid.UUID - """ - if not isinstance(value, uuid.UUID): - if isinstance(value, bytes): - ret_value = uuid.UUID(bytes=value) - elif isinstance(value, int): - ret_value = uuid.UUID(int=value) - elif isinstance(value, str): - ret_value = uuid.UUID(value) - else: - ret_value = value - return ret_value - def load_dialect_impl(self, dialect: DefaultDialect) -> Any: return ( dialect.type_descriptor(CHAR(36)) @@ -53,12 +33,10 @@ class UUID(TypeDecorator): # pragma nocover ) def process_bind_param( - self, value: Union[str, int, bytes, uuid.UUID, None], dialect: DefaultDialect + self, value: uuid.UUID, dialect: DefaultDialect ) -> Optional[str]: if value is None: return value - if not isinstance(value, uuid.UUID): - value = self._cast_to_uuid(value) return str(value) if self.uuid_format == "string" else "%.32x" % value.int def process_result_value( @@ -68,4 +46,4 @@ class UUID(TypeDecorator): # pragma nocover return value if not isinstance(value, uuid.UUID): return uuid.UUID(value) - return value + return value # pragma: nocover diff --git a/ormar/models/helpers/validation.py b/ormar/models/helpers/validation.py index 582c3fa..bc87235 100644 --- a/ormar/models/helpers/validation.py +++ b/ormar/models/helpers/validation.py @@ -53,6 +53,7 @@ def convert_choices_if_needed( # noqa: CCR001 :return: value, choices list :rtype: Tuple[Any, List] """ + # TODO use same maps as with EncryptedString choices = [o.value if isinstance(o, Enum) else o for o in field.choices] if field.__type__ in [datetime.datetime, datetime.date, datetime.time]: diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index b9e71df..e710aef 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -88,13 +88,13 @@ class SqlJoin: return self.main_model.Meta.alias_manager @property - def to_table(self) -> str: + def to_table(self) -> sqlalchemy.Table: """ Shortcut to table name of the next model :return: name of the target table :rtype: str """ - return self.next_model.Meta.table.name + return self.next_model.Meta.table def _on_clause( self, previous_alias: str, from_clause: str, to_clause: str, @@ -282,7 +282,7 @@ class SqlJoin: on_clause = self._on_clause( previous_alias=self.own_alias, from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}", - to_clause=f"{self.to_table}.{to_key}", + to_clause=f"{self.to_table.name}.{to_key}", ) target_table = self.alias_manager.prefixed_table_name( self.next_alias, self.to_table @@ -301,7 +301,7 @@ class SqlJoin: ) self.columns.extend( self.alias_manager.prefixed_columns( - self.next_alias, self.next_model.Meta.table, self_related_fields + self.next_alias, target_table, self_related_fields ) ) self.used_aliases.append(self.next_alias) diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index 815a4dc..2ec6159 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -67,24 +67,21 @@ class AliasManager: if not fields else [col for col in table.columns if col.name in fields] ) - return [ - text(f"{alias}{table.name}.{column.name} as {alias}{column.name}") - for column in all_columns - ] + return [column.label(f"{alias}{column.name}") for column in all_columns] @staticmethod - def prefixed_table_name(alias: str, name: str) -> text: + def prefixed_table_name(alias: str, table: sqlalchemy.Table) -> text: """ Creates text clause with table name with aliased name. :param alias: alias of given table :type alias: str - :param name: table name - :type name: str + :param table: table + :type table: sqlalchemy.Table :return: sqlalchemy text clause as "table_name aliased_name" :rtype: sqlalchemy text clause """ - return text(f"{name} {alias}_{name}") + return table.alias(f"{alias}_{table.name}") def add_relation_type( self, source_model: Type["Model"], relation_name: str, reverse_name: str = None, diff --git a/tests/test_encrypted_columns.py b/tests/test_encrypted_columns.py index 4de0459..5eff02f 100644 --- a/tests/test_encrypted_columns.py +++ b/tests/test_encrypted_columns.py @@ -1,12 +1,16 @@ +# type: ignore +import decimal import uuid -from typing import Optional +import datetime +from typing import Any import databases import pytest import sqlalchemy import ormar -from ormar.exceptions import QueryDefinitionError +from ormar import ModelDefinitionError +from ormar.fields.sqlalchemy_encrypted import EncryptedString from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL) @@ -18,22 +22,58 @@ class BaseMeta(ormar.ModelMeta): database = database +default_fernet = dict( + encrypt_secret="asd123", encrypt_backend=ormar.EncryptBackends.FERNET, +) + + +class DummyBackend(ormar.fields.EncryptBackend): + def _initialize_backend(self, secret_key: bytes) -> None: + pass + + def encrypt(self, value: Any) -> str: + return value + + def decrypt(self, value: Any) -> str: + return value + + class Author(ormar.Model): class Meta(BaseMeta): tablename = "authors" id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(max_length=100, - encrypt_secret='asd123', - encrypt_backend=ormar.EncryptBackends.FERNET) - uuid_test = ormar.UUID(default=uuid.uuid4, uuid_format='string') - password: str = ormar.String(max_length=100, - encrypt_secret='udxc32', - encrypt_backend=ormar.EncryptBackends.HASH) - birth_year: int = ormar.Integer(nullable=True, - encrypt_secret='secure89key%^&psdijfipew', - encrypt_max_length=200, - encrypt_backend=ormar.EncryptBackends.FERNET) + name: str = ormar.String(max_length=100, **default_fernet) + uuid_test = ormar.UUID(default=uuid.uuid4, uuid_format="string") + uuid_test2 = ormar.UUID(nullable=True, uuid_format="string") + password: str = ormar.String( + max_length=128, + encrypt_secret="udxc32", + encrypt_backend=ormar.EncryptBackends.HASH, + ) + birth_year: int = ormar.Integer( + nullable=True, + encrypt_secret="secure89key%^&psdijfipew", + encrypt_max_length=200, + encrypt_backend=ormar.EncryptBackends.FERNET, + ) + test_text: str = ormar.Text(default="", **default_fernet) + test_bool: bool = ormar.Boolean(nullable=False, **default_fernet) + test_float: float = ormar.Float(**default_fernet) + test_float2: float = ormar.Float(nullable=True, **default_fernet) + test_datetime = ormar.DateTime(default=datetime.datetime.now, **default_fernet) + test_date = ormar.Date(default=datetime.date.today, **default_fernet) + test_time = ormar.Time(default=datetime.time, **default_fernet) + test_json = ormar.JSON(default={}, **default_fernet) + test_bigint: int = ormar.BigInteger(default=0, **default_fernet) + test_decimal = ormar.Decimal(scale=2, precision=10, **default_fernet) + test_decimal2 = ormar.Decimal(max_digits=10, decimal_places=2, **default_fernet) + custom_backend: str = ormar.String( + max_length=200, + encrypt_secret="asda8", + encrypt_backend=ormar.EncryptBackends.CUSTOM, + encrypt_custom_backend=DummyBackend, + ) @pytest.fixture(autouse=True, scope="module") @@ -45,16 +85,104 @@ def create_test_database(): metadata.drop_all(engine) +def test_error_on_encrypted_pk(): + with pytest.raises(ModelDefinitionError): + + class Wrong(ormar.Model): + class Meta(BaseMeta): + tablename = "wrongs" + + id: int = ormar.Integer( + primary_key=True, + encrypt_secret="asd123", + encrypt_backend=ormar.EncryptBackends.FERNET, + ) + + +def test_error_on_encrypted_relation(): + with pytest.raises(ModelDefinitionError): + + class Wrong2(ormar.Model): + class Meta(BaseMeta): + tablename = "wrongs2" + + id: int = ormar.Integer(primary_key=True) + author = ormar.ForeignKey( + Author, + encrypt_secret="asd123", + encrypt_backend=ormar.EncryptBackends.FERNET, + ) + + +def test_error_on_encrypted_m2m_relation(): + with pytest.raises(ModelDefinitionError): + + class Wrong3(ormar.Model): + class Meta(BaseMeta): + tablename = "wrongs3" + + id: int = ormar.Integer(primary_key=True) + author = ormar.ManyToMany( + Author, + encrypt_secret="asd123", + encrypt_backend=ormar.EncryptBackends.FERNET, + ) + + +def test_wrong_backend(): + with pytest.raises(ModelDefinitionError): + + class Wrong3(ormar.Model): + class Meta(BaseMeta): + tablename = "wrongs3" + + id: int = ormar.Integer(primary_key=True) + author = ormar.Integer( + encrypt_secret="asd123", + encrypt_backend=ormar.EncryptBackends.CUSTOM, + encrypt_custom_backend="aa", + ) + + def test_db_structure(): - assert Author.Meta.table.c.get('name').type.impl.__class__ == sqlalchemy.NVARCHAR - assert Author.Meta.table.c.get('birth_year').type.max_length == 200 + assert Author.Meta.table.c.get("name").type.__class__ == EncryptedString + assert Author.Meta.table.c.get("birth_year").type.max_length == 200 @pytest.mark.asyncio -async def test_wrong_query_foreign_key_type(): +async def test_save_and_retrieve(): async with database: - await Author(name='Test', birth_year=1988, password='test123').save() + test_uuid = uuid.uuid4() + await Author( + name="Test", + birth_year=1988, + password="test123", + uuid_test=test_uuid, + test_float=1.2, + test_bool=True, + test_decimal=decimal.Decimal(3.5), + test_decimal2=decimal.Decimal(5.5), + test_json=dict(aa=12), + custom_backend="test12", + ).save() author = await Author.objects.get() - assert author.name == 'Test' + assert author.name == "Test" assert author.birth_year == 1988 + password = ( + "03e4a4d513e99cb3fe4ee3db282c053daa3f3572b849c3868939a306944ad5c08" + "22b50d4886e10f4cd418c3f2df3ceb02e2e7ac6e920ae0c90f2dedfc8fa16e2" + ) + assert author.password == password + assert author.uuid_test == test_uuid + assert author.uuid_test2 is None + assert author.test_datetime.date() == datetime.date.today() + assert author.test_date == datetime.date.today() + assert author.test_text == "" + assert author.test_float == 1.2 + assert author.test_float2 is None + assert author.test_bigint == 0 + assert author.test_json == {"aa": 12} + assert author.test_decimal == 3.5 + assert author.test_decimal2 == 5.5 + assert author.custom_backend == "test12" diff --git a/tests/test_or_filters.py b/tests/test_or_filters.py index 6d6f8ab..81a412c 100644 --- a/tests/test_or_filters.py +++ b/tests/test_or_filters.py @@ -57,24 +57,24 @@ async def test_or_filters(): books = ( await Book.objects.select_related("author") - .filter(ormar.or_(author__name="J.R.R. Tolkien", year__gt=1970)) - .all() + .filter(ormar.or_(author__name="J.R.R. Tolkien", year__gt=1970)) + .all() ) assert len(books) == 5 books = ( await Book.objects.select_related("author") - .filter(ormar.or_(author__name="J.R.R. Tolkien", year__lt=1995)) - .all() + .filter(ormar.or_(author__name="J.R.R. Tolkien", year__lt=1995)) + .all() ) assert len(books) == 4 assert not any([x.title == "The Tower of Fools" for x in books]) books = ( await Book.objects.select_related("author") - .filter(ormar.or_(year__gt=1960, year__lt=1940)) - .filter(author__name="J.R.R. Tolkien") - .all() + .filter(ormar.or_(year__gt=1960, year__lt=1940)) + .filter(author__name="J.R.R. Tolkien") + .all() ) assert len(books) == 2 assert books[0].title == "The Hobbit" @@ -82,13 +82,13 @@ async def test_or_filters(): books = ( await Book.objects.select_related("author") - .filter( + .filter( ormar.and_( ormar.or_(year__gt=1960, year__lt=1940), author__name="J.R.R. Tolkien", ) ) - .all() + .all() ) assert len(books) == 2 @@ -97,14 +97,14 @@ async def test_or_filters(): books = ( await Book.objects.select_related("author") - .filter( + .filter( ormar.or_( ormar.and_(year__gt=1960, author__name="J.R.R. Tolkien"), ormar.and_(year__lt=2000, author__name="Andrzej Sapkowski"), ) ) - .filter(title__startswith="The") - .all() + .filter(title__startswith="The") + .all() ) assert len(books) == 2 assert books[0].title == "The Silmarillion" @@ -112,7 +112,7 @@ async def test_or_filters(): books = ( await Book.objects.select_related("author") - .filter( + .filter( ormar.or_( ormar.and_( ormar.or_(year__gt=1960, year__lt=1940), @@ -121,7 +121,7 @@ async def test_or_filters(): ormar.and_(year__lt=2000, author__name="Andrzej Sapkowski"), ) ) - .all() + .all() ) assert len(books) == 3 assert books[0].title == "The Hobbit" @@ -130,29 +130,29 @@ async def test_or_filters(): books = ( await Book.objects.select_related("author") - .exclude( + .exclude( ormar.or_( ormar.and_(year__gt=1960, author__name="J.R.R. Tolkien"), ormar.and_(year__lt=2000, author__name="Andrzej Sapkowski"), ) ) - .filter(title__startswith="The") - .all() + .filter(title__startswith="The") + .all() ) assert len(books) == 3 assert not any([x.title in ["The Silmarillion", "The Witcher"] for x in books]) books = ( await Book.objects.select_related("author") - .filter( + .filter( ormar.or_( ormar.and_(year__gt=1960, author__name="J.R.R. Tolkien"), ormar.and_(year__lt=2000, author__name="Andrzej Sapkowski"), title__icontains="hobbit", ) ) - .filter(title__startswith="The") - .all() + .filter(title__startswith="The") + .all() ) assert len(books) == 3 assert not any( @@ -161,43 +161,43 @@ async def test_or_filters(): books = ( await Book.objects.select_related("author") - .filter(ormar.or_(year__gt=1980, year__lt=1910)) - .filter(title__startswith="The") - .limit(1) - .all() + .filter(ormar.or_(year__gt=1980, year__lt=1910)) + .filter(title__startswith="The") + .limit(1) + .all() ) assert len(books) == 1 assert books[0].title == "The Witcher" books = ( await Book.objects.select_related("author") - .filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski")) - .filter(title__startswith="The") - .limit(1) - .all() + .filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski")) + .filter(title__startswith="The") + .limit(1) + .all() ) assert len(books) == 1 assert books[0].title == "The Witcher" books = ( await Book.objects.select_related("author") - .filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski")) - .filter(title__startswith="The") - .limit(1) - .offset(1) - .all() + .filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski")) + .filter(title__startswith="The") + .limit(1) + .offset(1) + .all() ) assert len(books) == 1 assert books[0].title == "The Tower of Fools" books = ( await Book.objects.select_related("author") - .filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski")) - .filter(title__startswith="The") - .limit(1) - .offset(1) - .order_by("-id") - .all() + .filter(ormar.or_(year__gt=1980, author__name="Andrzej Sapkowski")) + .filter(title__startswith="The") + .limit(1) + .offset(1) + .order_by("-id") + .all() ) assert len(books) == 1 assert books[0].title == "The Witcher" @@ -220,19 +220,19 @@ async def test_or_filters(): books = ( await Book.objects.select_related("author") - .filter(ormar.or_(author__name="J.R.R. Tolkien")) - .all() + .filter(ormar.or_(author__name="J.R.R. Tolkien")) + .all() ) assert len(books) == 3 books = ( await Book.objects.select_related("author") - .filter( + .filter( ormar.or_( ormar.and_(author__name__icontains="tolkien"), ormar.and_(author__name__icontains="sapkowski"), ) ) - .all() + .all() ) assert len(books) == 5