From 241628b1d94398bc270cbe0f01eb715a7f8be6f7 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 07:51:06 +0200 Subject: [PATCH] liniting and applying black --- .coverage | Bin 53248 -> 53248 bytes .flake8 | 3 +- orm/__init__.py | 18 ++- orm/fields.py | 133 +++++++++++------ orm/helpers.py | 26 ---- orm/models.py | 165 +++++++++++++-------- orm/queryset.py | 261 +++++++++++++++++++++++---------- orm/relations.py | 80 ++++++---- tests/test_model_definition.py | 16 ++ 9 files changed, 455 insertions(+), 247 deletions(-) delete mode 100644 orm/helpers.py diff --git a/.coverage b/.coverage index c907f3e5af7db5ef5c792c69402959e4bfe2988f..dc4c4fa39b8c2f40a9c341058db4ebd6e14f7779 100644 GIT binary patch delta 402 zcmZozz}&Ead4rZdySbIIv6YG8W<&jZ0*ro}1sufqnbcS(`O9*rWv1q&6zdgKYOzeN z_mkw!FUpNctw>HSD9Oyvn_T8UmD!(PeKTKxB>&|9adKR6<*bwI;-o}TRD(=m(PR#v z+!!amd1`#2g4|>V{-6A>`0w#w;a|mH!EeGZ$1ldu$@hfs4Bu+L$(scQ^7$sO>$4OQ z<6&Xs)Z;0e{(t_@IiK%8duC~GKlx{$UA+hgP(q8N?&zP~o%dgU{r~Ux`tQGmoY+|z zIU8Bncb)$KW6l5DZ{OW~{4Mh?~}|M@q6_UX+x&;Fg2!cZ0e`^B$)H_dL| z{`skzpZ~>c7FH!|LU{%o=w|)GskT9*|hi1 n=D(>g|9*DwzdPoP94tVULM&-%X`Xw9y90s{e delta 341 zcmV-b0jmChpaX!Q1F$MD2R1q~GdeUmvoSB;5CKxN5I`0W0xAWQP){lXCIpgC9R+4- zWo%@Vc2AW9Cl6Dz4p12nlTTkO0R@x)ULFl*X=Q9=b1ras1StbolYw3rvz1?QARev& z5BU%358w~G53vu04@(a>4=)cQ4*d@24#p0yvk?%A4wI9PIT92M1OW*u40f0QpZT2g z-{$5yli-d%e-a1;0SP7ueiide`}O{U_t*P<2{;D?0SR&k+U0+I_P5*JZLV*B@4j^p z4|?B!+iu(Vx_i9a@n?>`+kXArpZEU!xx0P$so(Bbo-+q31_S{KRtA3W<(=REbNSi4 z`PU=^(0lvn`kw7&`*#-{?CyTM_xJbT`+Xk>1q1;JG7<%5W@aM<1OW+11nw65*WK>k neg8k%?9FDg*=#nO&1S#(`Q6>W*ZVIB0|WsH5(C<^=#Lmc2E>+& diff --git a/.flake8 b/.flake8 index c2bfa9d..ec05a50 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,5 @@ [flake8] -ignore = ANN101 +ignore = ANN101, ANN102, W503 max-complexity = 10 +max-line-length = 88 exclude = p38venv,.pytest_cache diff --git a/orm/__init__.py b/orm/__init__.py index c947954..39adee8 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -1,6 +1,18 @@ from orm.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch -from orm.fields import BigInteger, Boolean, Date, DateTime, Decimal, Float, ForeignKey, Integer, JSON, String, Text, \ - Time +from orm.fields import ( + BigInteger, + Boolean, + Date, + DateTime, + Decimal, + Float, + ForeignKey, + Integer, + JSON, + String, + Text, + Time, +) from orm.models import Model __version__ = "0.0.1" @@ -21,5 +33,5 @@ __all__ = [ "ModelDefinitionError", "ModelNotSet", "MultipleMatches", - "NoMatch" + "NoMatch", ] diff --git a/orm/fields.py b/orm/fields.py index 80814f0..9c9f3f7 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,14 +1,15 @@ import datetime import decimal -from typing import List, Optional, TYPE_CHECKING, Type, Any, Union - -import sqlalchemy -from pydantic import Json, BaseModel -from pydantic.fields import ModelField +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union import orm from orm.exceptions import ModelDefinitionError, RelationshipInstanceError +from pydantic import BaseModel, Json +from pydantic.fields import ModelField + +import sqlalchemy + if TYPE_CHECKING: # pragma no cover from orm.models import Model @@ -16,33 +17,39 @@ if TYPE_CHECKING: # pragma no cover class BaseField: __type__ = None - def __init__(self, *args, **kwargs) -> None: - name = kwargs.pop('name', None) + def __init__(self, *args: Any, **kwargs: Any) -> None: + name = kwargs.pop("name", None) args = list(args) if args: if isinstance(args[0], str): if name is not None: - raise ModelDefinitionError('Column name cannot be passed positionally and as a keyword.') + raise ModelDefinitionError( + "Column name cannot be passed positionally and as a keyword." + ) name = args.pop(0) self.name = name - self.primary_key = kwargs.pop('primary_key', False) - self.autoincrement = kwargs.pop('autoincrement', self.primary_key and self.__type__ == int) + self.primary_key = kwargs.pop("primary_key", False) + self.autoincrement = kwargs.pop( + "autoincrement", self.primary_key and self.__type__ == int + ) - self.nullable = kwargs.pop('nullable', not self.primary_key) - self.default = kwargs.pop('default', None) - self.server_default = kwargs.pop('server_default', None) + self.nullable = kwargs.pop("nullable", not self.primary_key) + self.default = kwargs.pop("default", None) + self.server_default = kwargs.pop("server_default", None) - self.index = kwargs.pop('index', None) - self.unique = kwargs.pop('unique', None) + self.index = kwargs.pop("index", None) + self.unique = kwargs.pop("unique", None) - self.pydantic_only = kwargs.pop('pydantic_only', False) + self.pydantic_only = kwargs.pop("pydantic_only", False) if self.pydantic_only and self.primary_key: - raise ModelDefinitionError('Primary key column cannot be pydantic only.') + raise ModelDefinitionError("Primary key column cannot be pydantic only.") @property def is_required(self) -> bool: - return not self.nullable and not self.has_default and not self.is_auto_primary_key + return ( + not self.nullable and not self.has_default and not self.is_auto_primary_key + ) @property def default_value(self) -> Any: @@ -81,16 +88,19 @@ class BaseField: def get_constraints(self) -> Optional[List]: return [] - def expand_relationship(self, value, child) -> Any: + def expand_relationship(self, value: Any, child: "Model") -> Any: return value class String(BaseField): __type__ = str - def __init__(self, *args, **kwargs): - assert 'length' in kwargs, 'length is required' - self.length = kwargs.pop('length') + def __init__(self, *args: Any, **kwargs: Any) -> None: + if "length" not in kwargs: + raise ModelDefinitionError( + "Param length is required for String model field." + ) + self.length = kwargs.pop("length") super().__init__(*args, **kwargs) def get_column_type(self) -> sqlalchemy.Column: @@ -163,27 +173,41 @@ class BigInteger(BaseField): class Decimal(BaseField): __type__ = decimal.Decimal - def __init__(self, *args, **kwargs): - assert 'precision' in kwargs, 'precision is required' - assert 'length' in kwargs, 'length is required' - self.length = kwargs.pop('length') - self.precision = kwargs.pop('precision') + def __init__(self, *args: Any, **kwargs: Any) -> None: + if "length" not in kwargs or "precision" not in kwargs: + raise ModelDefinitionError( + "Params length and precision are required for Decimal model field." + ) + self.length = kwargs.pop("length") + self.precision = kwargs.pop("precision") super().__init__(*args, **kwargs) def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.DECIMAL(self.length, self.precision) -def create_dummy_instance(fk: Type['Model'], pk: int = None) -> 'Model': +def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": init_dict = {fk.__pkname__: pk or -1} - init_dict = {**init_dict, **{k: create_dummy_instance(v.to) - for k, v in fk.__model_fields__.items() - if isinstance(v, ForeignKey) and not v.nullable and not v.virtual}} + init_dict = { + **init_dict, + **{ + k: create_dummy_instance(v.to) + for k, v in fk.__model_fields__.items() + if isinstance(v, ForeignKey) and not v.nullable and not v.virtual + }, + } return fk(**init_dict) class ForeignKey(BaseField): - def __init__(self, to, name: str = None, related_name: str = None, nullable: bool = True, virtual: bool = False): + def __init__( + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, + ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual self.related_name = related_name @@ -201,11 +225,16 @@ class ForeignKey(BaseField): to_column = self.to.__model_fields__[self.to.__pkname__] return to_column.get_column_type() - def expand_relationship(self, value, child) -> Union['Model', List['Model']]: + def expand_relationship( + self, value: Any, child: "Model" + ) -> Union["Model", List["Model"]]: if not isinstance(value, (self.to, dict, int, str, list)) or ( - isinstance(value, orm.models.Model) and not isinstance(value, self.to)): + isinstance(value, orm.models.Model) and not isinstance(value, self.to) + ): raise RelationshipInstanceError( - 'Relationship model can be build only from orm.Model, dict and integer or string (pk).') + "Relationship model can be build only from orm.Model, " + "dict and integer or string (pk)." + ) if isinstance(value, list) and not isinstance(value, self.to): model = [self.expand_relationship(val, child) for val in value] return model @@ -217,19 +246,27 @@ class ForeignKey(BaseField): else: model = create_dummy_instance(fk=self.to, pk=value) - child_model_name = self.related_name or child.__class__.__name__.lower() + 's' - model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(), - child.__class__.__name__.lower(), - model, child, virtual=self.virtual) + child_model_name = self.related_name or child.__class__.__name__.lower() + "s" + model._orm_relationship_manager.add_relation( + model.__class__.__name__.lower(), + child.__class__.__name__.lower(), + model, + child, + virtual=self.virtual, + ) - if child_model_name not in model.__fields__ \ - and child.__class__.__name__.lower() not in model.__fields__: - model.__fields__[child_model_name] = ModelField(name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__) - model.__model_fields__[child_model_name] = ForeignKey(child.__class__, - name=child_model_name, - virtual=True) + if ( + child_model_name not in model.__fields__ + and child.__class__.__name__.lower() not in model.__fields__ + ): + model.__fields__[child_model_name] = ModelField( + name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__, + ) + model.__model_fields__[child_model_name] = ForeignKey( + child.__class__, name=child_model_name, virtual=True + ) return model diff --git a/orm/helpers.py b/orm/helpers.py deleted file mode 100644 index f8d7cfb..0000000 --- a/orm/helpers.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Union, Set, Dict # pragma no cover - - -class Excludable: # pragma no cover - - @staticmethod - def get_excluded(exclude: Union[Set, Dict, None], key: str = None): - # print(f'checking excluded for {key}', exclude) - if isinstance(exclude, dict): - if isinstance(exclude.get(key, {}), dict) and '__all__' in exclude.get(key, {}).keys(): - return exclude.get(key).get('__all__') - return exclude.get(key, {}) - return exclude - - @staticmethod - def is_excluded(exclude: Union[Set, Dict, None], key: str = None): - if exclude is None: - return False - to_exclude = Excludable.get_excluded(exclude, key) - # print(f'to exclude for current key = {key}', to_exclude) - - if isinstance(to_exclude, Set): - return key in to_exclude - elif to_exclude is ...: - return True - return False diff --git a/orm/models.py b/orm/models.py index b3f7d78..68e32f4 100644 --- a/orm/models.py +++ b/orm/models.py @@ -2,35 +2,39 @@ import copy import inspect import json import uuid -from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple -from typing import Set, Dict +from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar +from typing import Callable, Dict, Set import databases -import pydantic -import sqlalchemy -from pydantic import BaseModel, BaseConfig, create_model import orm.queryset as qry from orm.exceptions import ModelDefinitionError from orm.fields import BaseField, ForeignKey from orm.relations import RelationshipManager +import pydantic +from pydantic import BaseConfig, BaseModel, create_model + +import sqlalchemy + relationship_manager = RelationshipManager() def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: - pydantic_fields = {field_name: ( - base_field.__type__, - ... if base_field.is_required else base_field.default_value - ) + pydantic_fields = { + field_name: ( + base_field.__type__, + ... if base_field.is_required else base_field.default_value, + ) for field_name, base_field in object_dict.items() - if isinstance(base_field, BaseField)} + if isinstance(base_field, BaseField) + } return pydantic_fields -def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str], - List[sqlalchemy.Column], - Dict[str, BaseField]]: +def sqlalchemy_columns_from_model_fields( + name: str, object_dict: Dict, tablename: str +) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] model_fields: Dict[str, BaseField] = {} @@ -42,9 +46,16 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename if field.primary_key: pkname = field_name if isinstance(field, ForeignKey): - reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's' - relation_name = name.lower().title() + '_' + field.to.__name__.lower() - relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename) + reverse_name = ( + field.related_name + or field.to.__name__.lower().title() + "_" + name.lower() + "s" + ) + relation_name = ( + name.lower().title() + "_" + field.to.__name__.lower() + ) + relationship_manager.add_relation_type( + relation_name, reverse_name, field, tablename + ) columns.append(field.get_column(field_name)) return pkname, columns, model_fields @@ -57,9 +68,7 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]: class ModelMetaclass(type): - def __new__( - mcs: type, name: str, bases: Any, attrs: dict - ) -> type: + def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) @@ -71,25 +80,29 @@ class ModelMetaclass(type): metadata = attrs["__metadata__"] # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename) - attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) - attrs['__columns__'] = columns - attrs['__pkname__'] = pkname + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( + name, attrs, tablename + ) + attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) + attrs["__columns__"] = columns + attrs["__pkname__"] = pkname if not pkname: - raise ModelDefinitionError('Table has to have a primary key.') + raise ModelDefinitionError("Table has to have a primary key.") # pydantic model creation pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields) - attrs['__pydantic_fields__'] = pydantic_fields - attrs['__pydantic_model__'] = pydantic_model - attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) - attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__) - attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) - attrs['__model_fields__'] = model_fields + pydantic_model = create_model( + name, __config__=get_pydantic_base_orm_config(), **pydantic_fields + ) + attrs["__pydantic_fields__"] = pydantic_fields + attrs["__pydantic_model__"] = pydantic_model + attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) + attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) + attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) + attrs["__model_fields__"] = model_fields - attrs['_orm_relationship_manager'] = relationship_manager + attrs["_orm_relationship_manager"] = relationship_manager new_model = super().__new__( # type: ignore mcs, name, bases, attrs @@ -99,7 +112,8 @@ class ModelMetaclass(type): class Model(list, metaclass=ModelMetaclass): - # Model inherits from list in order to be treated as request.Body parameter in fastapi routes, + # Model inherits from list in order to be treated as + # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts __abstract__ = True if TYPE_CHECKING: # pragma no cover @@ -115,17 +129,20 @@ class Model(list, metaclass=ModelMetaclass): objects = qry.QuerySet() - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self._orm_id: str = uuid.uuid4().hex self._orm_saved: bool = False self.values: Optional[BaseModel] = None if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") - kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} + kwargs = { + k: self.__model_fields__[k].expand_relationship(v, self) + for k, v in kwargs.items() + } self.values = self.__pydantic_model__(**kwargs) - def __del__(self): + def __del__(self) -> None: self._orm_relationship_manager.deregister(self) def __setattr__(self, key: str, value: Any) -> None: @@ -138,20 +155,24 @@ class Model(list, metaclass=ModelMetaclass): value = self.__model_fields__[key].expand_relationship(value, self) - relation_key = self.__class__.__name__.title() + '_' + key + relation_key = self.__class__.__name__.title() + "_" + key if not self._orm_relationship_manager.contains(relation_key, self): setattr(self.values, key, value) else: super().__setattr__(key, value) def __getattribute__(self, key: str) -> Any: - if key != '__fields__' and key in self.__fields__: - relation_key = self.__class__.__name__.title() + '_' + key + if key != "__fields__" and key in self.__fields__: + relation_key = self.__class__.__name__.title() + "_" + key if self._orm_relationship_manager.contains(relation_key, self): return self._orm_relationship_manager.get(relation_key, self) item = getattr(self.values, key, None) - if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str): + if ( + item is not None + and self.is_conversion_to_json_needed(key) + and isinstance(item, str) + ): try: item = json.loads(item) except TypeError: # pragma no cover @@ -159,30 +180,41 @@ class Model(list, metaclass=ModelMetaclass): return item return super().__getattribute__(key) - def __eq__(self, other): + def __eq__(self, other: "Model") -> bool: return self.values.dict() == other.values.dict() - def __same__(self, other): - assert self.__class__ == other.__class__ + def __same__(self, other: "Model") -> bool: + if self.__class__ != other.__class__: + return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk) + self.values is not None and other.values is not None and self.pk == other.pk + ) - def __repr__(self): # pragma no cover + def __repr__(self) -> str: # pragma no cover return self.values.__repr__() @classmethod - def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model': + def from_row( + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, + ) -> "Model": item = {} select_related = select_related or [] - table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name) + table_prefix = cls._orm_relationship_manager.resolve_relation_join( + previous_table, cls.__table__.name + ) previous_table = cls.__table__.name for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) model_cls = cls.__model_fields__[first_part].to - child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table) + child = model_cls.from_row( + row, select_related=[remainder], previous_table=previous_table + ) item[first_part] = child else: model_cls = cls.__model_fields__[related].to @@ -191,7 +223,9 @@ class Model(list, metaclass=ModelMetaclass): for column in cls.__table__.columns: if column.name not in item: - item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}'] + item[column.name] = row[ + f'{table_prefix + "_" if table_prefix else ""}{column.name}' + ] return cls(**item) @@ -200,7 +234,7 @@ class Model(list, metaclass=ModelMetaclass): # return cls.__pydantic_model__.validate(value=value) @classmethod - def __get_validators__(cls): # pragma no cover + def __get_validators__(cls) -> Callable: # pragma no cover yield cls.__pydantic_model__.validate # @classmethod @@ -211,11 +245,11 @@ class Model(list, metaclass=ModelMetaclass): return self.__model_fields__.get(column_name).__type__ == pydantic.Json @property - def pk(self): + def pk(self) -> str: return getattr(self.values, self.__pkname__) @pk.setter - def pk(self, value): + def pk(self, value: Any) -> None: setattr(self.values, self.__pkname__, value) @property @@ -229,7 +263,9 @@ class Model(list, metaclass=ModelMetaclass): if isinstance(nested_model, list): dict_instance[field] = [x.dict() for x in nested_model] else: - dict_instance[field] = nested_model.dict() if nested_model is not None else {} + dict_instance[field] = ( + nested_model.dict() if nested_model is not None else {} + ) return dict_instance def from_dict(self, value_dict: Dict) -> None: @@ -245,16 +281,22 @@ class Model(list, metaclass=ModelMetaclass): def extract_related_names(cls) -> Set: related_names = set() for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): + if inspect.isclass(field.type_) and issubclass( + field.type_, pydantic.BaseModel + ): related_names.add(name) return related_names def extract_model_db_fields(self) -> Dict: self_fields = self.extract_own_model_fields() - self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns} + self_fields = { + k: v for k, v in self_fields.items() if k in self.__table__.columns + } for field in self.extract_related_names(): if getattr(self, field) is not None: - self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__) + self_fields[field] = getattr( + getattr(self, field), self.__model_fields__[field].to.__pkname__ + ) return self_fields async def save(self) -> int: @@ -264,7 +306,7 @@ class Model(list, metaclass=ModelMetaclass): expr = self.__table__.insert() expr = expr.values(**self_fields) item_id = await self.__database__.execute(expr) - setattr(self, 'pk', item_id) + self.pk = item_id return item_id async def update(self, **kwargs: Any) -> int: @@ -274,8 +316,11 @@ class Model(list, metaclass=ModelMetaclass): self_fields = self.extract_model_db_fields() self_fields.pop(self.__pkname__) - expr = self.__table__.update().values(**self_fields).where( - self.pk_column == getattr(self, self.__pkname__)) + expr = ( + self.__table__.update() + .values(**self_fields) + .where(self.pk_column == getattr(self, self.__pkname__)) + ) result = await self.__database__.execute(expr) return result @@ -285,7 +330,7 @@ class Model(list, metaclass=ModelMetaclass): result = await self.__database__.execute(expr) return result - async def load(self) -> 'Model': + async def load(self) -> "Model": expr = self.__table__.select().where(self.pk_column == self.pk) row = await self.__database__.fetch_one(expr) self.from_dict(dict(row)) diff --git a/orm/queryset.py b/orm/queryset.py index 36160e8..39a187c 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -1,11 +1,14 @@ -from typing import List, TYPE_CHECKING, Type, NamedTuple +from typing import Any, List, NamedTuple, TYPE_CHECKING, Tuple, Type, Union -import sqlalchemy -from sqlalchemy import text +import databases import orm from orm import ForeignKey -from orm.exceptions import NoMatch, MultipleMatches +from orm.exceptions import MultipleMatches, NoMatch +from orm.fields import BaseField + +import sqlalchemy +from sqlalchemy import text if TYPE_CHECKING: # pragma no cover from orm.models import Model @@ -24,17 +27,23 @@ FILTER_OPERATORS = { class JoinParameters(NamedTuple): - prev_model: Type['Model'] + prev_model: Type["Model"] previous_alias: str from_table: str - model_cls: Type['Model'] + model_cls: Type["Model"] class QuerySet: - ESCAPE_CHARACTERS = ['%', '_'] + ESCAPE_CHARACTERS = ["%", "_"] - def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, - limit_count: int = None, offset: int = None): + def __init__( + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses self._select_related = [] if select_related is None else select_related @@ -48,47 +57,77 @@ class QuerySet: self.columns = None self.order_bys = None - def __get__(self, instance, owner): + def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": return self.__class__(model_cls=owner) @property - def database(self): + def database(self) -> databases.Database: return self.model_cls.__database__ @property - def table(self): + def table(self) -> sqlalchemy.Table: return self.model_cls.__table__ - def prefixed_columns(self, alias, table): - return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}') - for column in table.columns] + def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]: + return [ + text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") + for column in table.columns + ] - def prefixed_table_name(self, alias, name): - return text(f'{name} {alias}_{name}') + def prefixed_table_name(self, alias: str, name: str) -> text: + return text(f"{name} {alias}_{name}") - def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key): - return text(f'{alias}_{to_table}.{to_key}=' - f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') + def on_clause( + self, + from_table: str, + to_table: str, + previous_alias: str, + alias: str, + to_key: str, + from_key: str, + ) -> text: + return text( + f"{alias}_{to_table}.{to_key}=" + f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}' + ) - def build_join_parameters(self, part, join_params: JoinParameters): + def build_join_parameters( + self, part: str, join_params: JoinParameters + ) -> JoinParameters: model_cls = join_params.model_cls.__model_fields__[part].to to_table = model_cls.__table__.name - alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table) + alias = model_cls._orm_relationship_manager.resolve_relation_join( + join_params.from_table, to_table + ) if alias not in self.used_aliases: if join_params.prev_model.__model_fields__[part].virtual: - to_key = next((v for k, v in model_cls.__model_fields__.items() - if isinstance(v, ForeignKey) and v.to == join_params.prev_model), None).name + to_key = next( + ( + v + for k, v in model_cls.__model_fields__.items() + if isinstance(v, ForeignKey) and v.to == join_params.prev_model + ), + None, + ).name from_key = model_cls.__pkname__ else: to_key = model_cls.__pkname__ from_key = part - on_clause = self.on_clause(join_params.from_table, to_table, join_params.previous_alias, alias, to_key, - from_key) + on_clause = self.on_clause( + join_params.from_table, + to_table, + join_params.previous_alias, + alias, + to_key, + from_key, + ) target_table = self.prefixed_table_name(alias, to_table) - self.select_from = sqlalchemy.sql.outerjoin(self.select_from, target_table, on_clause) - self.order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) + self.select_from = sqlalchemy.sql.outerjoin( + self.select_from, target_table, on_clause + ) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) self.used_aliases.append(alias) @@ -98,44 +137,76 @@ class QuerySet: return JoinParameters(prev_model, previous_alias, from_table, model_cls) @staticmethod - def field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part) -> bool: + def field_is_a_foreign_key_and_no_circular_reference( + field: BaseField, field_name: str, rel_part: str + ) -> bool: return isinstance(field, ForeignKey) and field_name not in rel_part - def field_qualifies_to_deeper_search(self, field, parent_virtual, nested, rel_part) -> bool: + def field_qualifies_to_deeper_search( + self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) - partial_match = any([x.startswith(prev_part_of_related) for x in self._select_related]) + partial_match = any( + [x.startswith(prev_part_of_related) for x in self._select_related] + ) already_checked = any([x.startswith(rel_part) for x in self.auto_related]) - return ((field.virtual and parent_virtual) or (partial_match and not already_checked)) or not nested + return ( + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested - def extract_auto_required_relations(self, join_params: JoinParameters, - rel_part: str = '', nested: bool = False, parent_virtual: bool = False): + def extract_auto_required_relations( + self, + join_params: JoinParameters, + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, + ) -> None: for field_name, field in join_params.prev_model.__model_fields__.items(): - if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part): - rel_part = field_name if not rel_part else rel_part + '__' + field_name + if self.field_is_a_foreign_key_and_no_circular_reference( + field, field_name, rel_part + ): + rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: if rel_part not in self._select_related: self.auto_related.append("__".join(rel_part.split("__")[:-1])) - rel_part = '' - elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part): - join_params = JoinParameters(field.to, join_params.previous_alias, - join_params.from_table, join_params.prev_model) - self.extract_auto_required_relations(join_params=join_params, - rel_part=rel_part, nested=True, parent_virtual=field.virtual) + rel_part = "" + elif self.field_qualifies_to_deeper_search( + field, parent_virtual, nested, rel_part + ): + join_params = JoinParameters( + field.to, + join_params.previous_alias, + join_params.from_table, + join_params.prev_model, + ) + self.extract_auto_required_relations( + join_params=join_params, + rel_part=rel_part, + nested=True, + parent_virtual=field.virtual, + ) else: - rel_part = '' + rel_part = "" - def build_select_expression(self): + def build_select_expression(self) -> sqlalchemy.sql.select: self.columns = list(self.table.columns) - self.order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] + self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] self.select_from = self.table for key in self.model_cls.__model_fields__: - if not self.model_cls.__model_fields__[key].nullable \ - and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \ - and key not in self._select_related: + if ( + not self.model_cls.__model_fields__[key].nullable + and isinstance( + self.model_cls.__model_fields__[key], orm.fields.ForeignKey + ) + and key not in self._select_related + ): self._select_related = [key] + self._select_related - start_params = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) + start_params = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) self.extract_auto_required_relations(start_params) if self.auto_related: new_joins = [] @@ -146,7 +217,9 @@ class QuerySet: self._select_related.sort(key=lambda item: (-len(item), item)) for item in self._select_related: - join_parameters = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) + join_parameters = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) for part in item.split("__"): join_parameters = self.build_join_parameters(part, join_parameters) @@ -180,7 +253,7 @@ class QuerySet: return expr - def filter(self, **kwargs): + def filter(self, **kwargs: Any) -> "QuerySet": filter_clauses = self.filter_clauses select_related = list(self._select_related) @@ -189,7 +262,7 @@ class QuerySet: kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): - table_prefix = '' + table_prefix = "" if "__" in key: parts = key.split("__") @@ -215,9 +288,13 @@ class QuerySet: # against which the comparison is being made. previous_table = model_cls.__tablename__ for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table, - current_table) + current_table = model_cls.__model_fields__[ + part + ].to.__tablename__ + manager = model_cls._orm_relationship_manager + table_prefix = manager.resolve_relation_join( + previous_table, current_table + ) model_cls = model_cls.__model_fields__[part].to previous_table = current_table @@ -236,25 +313,32 @@ class QuerySet: has_escaped_character = False if op in ["contains", "icontains"]: - has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS - if c in value) + has_escaped_character = any( + c for c in self.ESCAPE_CHARACTERS if c in value + ) if has_escaped_character: # enable escape modifier for char in self.ESCAPE_CHARACTERS: - value = value.replace(char, f'\\{char}') + value = value.replace(char, f"\\{char}") value = f"%{value}%" if isinstance(value, orm.Model): value = value.pk clause = getattr(column, op_attr)(value) - clause.modifiers['escape'] = '\\' if has_escaped_character else None + clause.modifiers["escape"] = "\\" if has_escaped_character else None - clause_text = str(clause.compile(dialect=self.model_cls.__database__._backend._dialect, - compile_kwargs={"literal_binds": True})) - alias = f'{table_prefix}_' if table_prefix else '' - aliased_name = f'{alias}{table.name}.{column.name}' - clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name) + clause_text = str( + clause.compile( + dialect=self.model_cls.__database__._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{table_prefix}_" if table_prefix else "" + aliased_name = f"{alias}{table.name}.{column.name}" + clause_text = clause_text.replace( + f"{table.name}.{column.name}", aliased_name + ) clause = text(clause_text) filter_clauses.append(clause) @@ -264,10 +348,10 @@ class QuerySet: filter_clauses=filter_clauses, select_related=select_related, limit_count=self.limit_count, - offset=self.query_offset + offset=self.query_offset, ) - def select_related(self, related): + def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": if not isinstance(related, (list, tuple)): related = [related] @@ -277,7 +361,7 @@ class QuerySet: filter_clauses=self.filter_clauses, select_related=related, limit_count=self.limit_count, - offset=self.query_offset + offset=self.query_offset, ) async def exists(self) -> bool: @@ -290,25 +374,25 @@ class QuerySet: expr = sqlalchemy.func.count().select().select_from(expr) return await self.database.fetch_val(expr) - def limit(self, limit_count: int): + def limit(self, limit_count: int) -> "QuerySet": return self.__class__( model_cls=self.model_cls, filter_clauses=self.filter_clauses, select_related=self._select_related, limit_count=limit_count, - offset=self.query_offset + offset=self.query_offset, ) - def offset(self, offset: int): + def offset(self, offset: int) -> "QuerySet": return self.__class__( model_cls=self.model_cls, filter_clauses=self.filter_clauses, select_related=self._select_related, limit_count=self.limit_count, - offset=offset + offset=offset, ) - async def first(self, **kwargs): + async def first(self, **kwargs: Any) -> "Model": if kwargs: return await self.filter(**kwargs).first() @@ -316,7 +400,7 @@ class QuerySet: if rows: return rows[0] - async def get(self, **kwargs): + async def get(self, **kwargs: Any) -> "Model": if kwargs: return await self.filter(**kwargs).get() @@ -329,7 +413,7 @@ class QuerySet: raise MultipleMatches() return self.model_cls.from_row(rows[0], select_related=self._select_related) - async def all(self, **kwargs): + async def all(self, **kwargs: Any) -> List["Model"]: if kwargs: return await self.filter(**kwargs).all() @@ -345,7 +429,7 @@ class QuerySet: return result_rows @classmethod - def merge_result_rows(cls, result_rows): + def merge_result_rows(cls, result_rows: List["Model"]) -> List["Model"]: merged_rows = [] for index, model in enumerate(result_rows): if index > 0 and model.pk == result_rows[index - 1].pk: @@ -355,30 +439,45 @@ class QuerySet: return merged_rows @classmethod - def merge_two_instances(cls, one: 'Model', other: 'Model'): + def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": for field in one.__model_fields__.keys(): # print(field, one.dict(), other.dict()) - if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model): + if isinstance(getattr(one, field), list) and not isinstance( + getattr(one, field), orm.models.Model + ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), orm.models.Model): if getattr(one, field).pk == getattr(other, field).pk: - setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field))) + setattr( + other, + field, + cls.merge_two_instances( + getattr(one, field), getattr(other, field) + ), + ) return other - async def create(self, **kwargs): + async def create(self, **kwargs: Any) -> "Model": new_kwargs = dict(**kwargs) # Remove primary key when None to prevent not null constraint in postgresql. pkname = self.model_cls.__pkname__ pk = self.model_cls.__model_fields__[pkname] - if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement): + if ( + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) + ): del new_kwargs[pkname] # substitute related models with their pk for field in self.model_cls.extract_related_names(): if field in new_kwargs and new_kwargs.get(field) is not None: - new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__) + new_kwargs[field] = getattr( + new_kwargs.get(field), + self.model_cls.__model_fields__[field].to.__pkname__, + ) # Build the insert expression. expr = self.table.insert() diff --git a/orm/relations.py b/orm/relations.py index a9a6971..b5741e1 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -2,7 +2,7 @@ import pprint import string import uuid from random import choices -from typing import TYPE_CHECKING, List +from typing import Dict, List, TYPE_CHECKING, Union from weakref import proxy from orm.fields import ForeignKey @@ -11,40 +11,58 @@ if TYPE_CHECKING: # pragma no cover from orm.models import Model -def get_table_alias(): - return ''.join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] +def get_table_alias() -> str: + return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] -def get_relation_config(relation_type: str, table_name: str, field: ForeignKey): +def get_relation_config( + relation_type: str, table_name: str, field: ForeignKey +) -> Dict[str, str]: alias = get_table_alias() - config = {'type': relation_type, - 'table_alias': alias, - 'source_table': table_name if relation_type == 'primary' else field.to.__tablename__, - 'target_table': field.to.__tablename__ if relation_type == 'primary' else table_name - } + config = { + "type": relation_type, + "table_alias": alias, + "source_table": table_name + if relation_type == "primary" + else field.to.__tablename__, + "target_table": field.to.__tablename__ + if relation_type == "primary" + else table_name, + } return config class RelationshipManager: - - def __init__(self): + def __init__(self) -> None: self._relations = dict() - def add_relation_type(self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str): - print(relations_key, reverse_key) + def add_relation_type( + self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + ) -> None: if relations_key not in self._relations: - self._relations[relations_key] = get_relation_config('primary', table_name, field) + self._relations[relations_key] = get_relation_config( + "primary", table_name, field + ) if reverse_key not in self._relations: - self._relations[reverse_key] = get_relation_config('reverse', table_name, field) + self._relations[reverse_key] = get_relation_config( + "reverse", table_name, field + ) - def deregister(self, model: 'Model'): + def deregister(self, model: "Model") -> None: # print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): if model.__class__.__name__.lower() in rel_type.lower(): if model._orm_id in self._relations[rel_type]: del self._relations[rel_type][model._orm_id] - def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False): + def add_relation( + self, + parent_name: str, + child_name: str, + parent: "Model", + child: "Model", + virtual: bool = False, + ) -> None: parent_id = parent._orm_id child_id = child._orm_id if virtual: @@ -53,12 +71,18 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) + parents_list = self._relations[ + parent_name.lower().title() + "_" + child_name + "s" + ].setdefault(parent_id, []) self.append_related_model(parents_list, child) - children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) + children_list = self._relations[ + child_name.lower().title() + "_" + parent_name + ].setdefault(child_id, []) self.append_related_model(children_list, parent) - def append_related_model(self, relations_list: List['Model'], model: 'Model'): + def append_related_model( + self, relations_list: List["Model"], model: "Model" + ) -> None: for x in relations_list: try: if x.__same__(model): @@ -68,26 +92,26 @@ class RelationshipManager: relations_list.append(model) - def contains(self, relations_key: str, object: 'Model'): + def contains(self, relations_key: str, object: "Model") -> bool: if relations_key in self._relations: return object._orm_id in self._relations[relations_key] return False - def get(self, relations_key: str, object: 'Model'): + def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]: if relations_key in self._relations: if object._orm_id in self._relations[relations_key]: - if self._relations[relations_key]['type'] == 'primary': + if self._relations[relations_key]["type"] == "primary": return self._relations[relations_key][object._orm_id][0] return self._relations[relations_key][object._orm_id] def resolve_relation_join(self, from_table: str, to_table: str) -> str: for k, v in self._relations.items(): - if v['source_table'] == from_table and v['target_table'] == to_table: - return self._relations[k]['table_alias'] - return '' + if v["source_table"] == from_table and v["target_table"] == to_table: + return self._relations[k]["table_alias"] + return "" - def __str__(self): # pragma no cover + def __str__(self) -> str: # pragma no cover return pprint.pformat(self._relations, indent=4, width=1) - def __repr__(self): # pragma no cover + def __repr__(self) -> str: # pragma no cover return self.__str__() diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index f06f141..1b17457 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -109,6 +109,22 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition(): test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) +def test_decimal_error_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example4" + __metadata__ = metadata + test = fields.Decimal(name='test12', primary_key=True) + + +def test_string_error_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example4" + __metadata__ = metadata + test = fields.String(name='test12', primary_key=True) + + def test_json_conversion_in_model(): with pytest.raises(pydantic.ValidationError): ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True)