diff --git a/.coverage b/.coverage index b5b4f93..2007c58 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/fields/required_decorator.py b/orm/fields/decorators.py similarity index 100% rename from orm/fields/required_decorator.py rename to orm/fields/decorators.py diff --git a/orm/fields/model_fields.py b/orm/fields/model_fields.py index 4d841e1..f14391e 100644 --- a/orm/fields/model_fields.py +++ b/orm/fields/model_fields.py @@ -5,7 +5,7 @@ import sqlalchemy from pydantic import Json from orm.fields.base import BaseField # noqa I101 -from orm.fields.required_decorator import RequiredParams +from orm.fields.decorators import RequiredParams @RequiredParams("length") diff --git a/orm/models.py b/orm/models.py deleted file mode 100644 index 22750eb..0000000 --- a/orm/models.py +++ /dev/null @@ -1,405 +0,0 @@ -import copy -import inspect -import json -import uuid -from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar -from typing import Callable, Dict, Set - -import databases -import pydantic -import sqlalchemy -from pydantic import BaseConfig, BaseModel, create_model -from pydantic.fields import ModelField - -import orm.queryset as qry # noqa I100 -from orm import ForeignKey -from orm.exceptions import ModelDefinitionError -from orm.fields.base import BaseField -from orm.relations import RelationshipManager - -relationship_manager = RelationshipManager() - - -def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: - pydantic_fields = { - field_name: ( - base_field.__type__, - ... if base_field.is_required else base_field.default_value, - ) - for field_name, base_field in object_dict.items() - if isinstance(base_field, BaseField) - } - return pydantic_fields - - -def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: - child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s" - reverse_name = field.related_name or child_relation_name - relation_name = name.lower().title() + "_" + field.to.get_name() - relationship_manager.add_relation_type( - relation_name, reverse_name, field, table_name - ) - - -def expand_reverse_relationships(model: Type["Model"]) -> None: - for model_field in model.__model_fields__.values(): - if isinstance(model_field, ForeignKey): - child_model_name = model_field.related_name or model.__name__.lower() + "s" - parent_model = model_field.to - child = model - if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ - ): - register_reverse_model_fields(parent_model, child, child_model_name) - - -def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str -) -> None: - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( - child, name=child_model_name, virtual=True - ) - - -def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str -) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: - pkname: Optional[str] = None - columns: List[sqlalchemy.Column] = [] - model_fields: Dict[str, BaseField] = {} - - for field_name, field in object_dict.items(): - if isinstance(field, BaseField): - model_fields[field_name] = field - if not field.pydantic_only: - if field.primary_key: - pkname = field_name - if isinstance(field, ForeignKey): - register_relation_on_build(table_name, field, name) - columns.append(field.get_column(field_name)) - return pkname, columns, model_fields - - -def get_pydantic_base_orm_config() -> Type[BaseConfig]: - class Config(BaseConfig): - orm_mode = True - - return Config - - -class ModelMetaclass(type): - def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: - new_model = super().__new__( # type: ignore - mcs, name, bases, attrs - ) - - if attrs.get("__abstract__"): - return new_model - - tablename = attrs["__tablename__"] - metadata = attrs["__metadata__"] - - # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, tablename - ) - attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) - attrs["__columns__"] = columns - attrs["__pkname__"] = pkname - - if not pkname: - raise ModelDefinitionError("Table has to have a primary key.") - - # pydantic model creation - pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model( - name, __config__=get_pydantic_base_orm_config(), **pydantic_fields - ) - attrs["__pydantic_fields__"] = pydantic_fields - attrs["__pydantic_model__"] = pydantic_model - attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) - attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) - attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) - - attrs["__model_fields__"] = model_fields - attrs["_orm_relationship_manager"] = relationship_manager - - new_model = super().__new__( # type: ignore - mcs, name, bases, attrs - ) - - expand_reverse_relationships(new_model) - - return new_model - - -class FakePydantic(list, metaclass=ModelMetaclass): - # FakePydantic inherits from list in order to be treated as - # request.Body parameter in fastapi routes, - # inheriting from pydantic.BaseModel causes metaclass conflicts - __abstract__ = True - if TYPE_CHECKING: # pragma no cover - __model_fields__: Dict[str, TypeVar[BaseField]] - __table__: sqlalchemy.Table - __fields__: Dict[str, pydantic.fields.ModelField] - __pydantic_model__: Type[BaseModel] - __pkname__: str - __tablename__: str - __metadata__: sqlalchemy.MetaData - __database__: databases.Database - _orm_relationship_manager: RelationshipManager - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self._orm_id: str = uuid.uuid4().hex - self._orm_saved: bool = False - self.values: Optional[BaseModel] = None - - if "pk" in kwargs: - kwargs[self.__pkname__] = kwargs.pop("pk") - kwargs = { - k: self.__model_fields__[k].expand_relationship(v, self) - for k, v in kwargs.items() - } - self.values = self.__pydantic_model__(**kwargs) - - def __del__(self) -> None: - self._orm_relationship_manager.deregister(self) - - def __setattr__(self, key: str, value: Any) -> None: - if key in self.__fields__: - if self._is_conversion_to_json_needed(key) and not isinstance(value, str): - try: - value = json.dumps(value) - except TypeError: # pragma no cover - pass - - value = self.__model_fields__[key].expand_relationship(value, self) - - relation_key = self.__class__.__name__.title() + "_" + key - if not self._orm_relationship_manager.contains(relation_key, self): - setattr(self.values, key, value) - else: - super().__setattr__(key, value) - - def __getattribute__(self, key: str) -> Any: - if key != "__fields__" and key in self.__fields__: - relation_key = self.__class__.__name__.title() + "_" + key - if self._orm_relationship_manager.contains(relation_key, self): - return self._orm_relationship_manager.get(relation_key, self) - - item = getattr(self.values, key, None) - if ( - item is not None - and self._is_conversion_to_json_needed(key) - and isinstance(item, str) - ): - try: - item = json.loads(item) - except TypeError: # pragma no cover - pass - return item - return super().__getattribute__(key) - - def __eq__(self, other: "Model") -> bool: - return self.values.dict() == other.values.dict() - - def __same__(self, other: "Model") -> bool: - if self.__class__ != other.__class__: # pragma no cover - return False - return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk - ) - - def __repr__(self) -> str: # pragma no cover - return self.values.__repr__() - - @classmethod - def __get_validators__(cls) -> Callable: # pragma no cover - yield cls.__pydantic_model__.validate - - @classmethod - def get_name(cls, title: bool = False, lower: bool = True) -> str: - name = cls.__name__ - if lower: - name = name.lower() - if title: - name = name.title() - return name - - @property - def pk_column(self) -> sqlalchemy.Column: - return self.__table__.primary_key.columns.values()[0] - - @classmethod - def pk_type(cls) -> Any: - return cls.__model_fields__[cls.__pkname__].__type__ - - def dict(self) -> Dict: # noqa: A003 - dict_instance = self.values.dict() - for field in self._extract_related_names(): - nested_model = getattr(self, field) - if isinstance(nested_model, list): - dict_instance[field] = [x.dict() for x in nested_model] - else: - dict_instance[field] = ( - nested_model.dict() if nested_model is not None else {} - ) - return dict_instance - - def from_dict(self, value_dict: Dict) -> None: - for key, value in value_dict.items(): - setattr(self, key, value) - - def _is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.__model_fields__.get(column_name).__type__ == pydantic.Json - - def _extract_own_model_fields(self) -> Dict: - related_names = self._extract_related_names() - self_fields = {k: v for k, v in self.dict().items() if k not in related_names} - return self_fields - - @classmethod - def _extract_related_names(cls) -> Set: - related_names = set() - for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel - ): - related_names.add(name) - return related_names - - def _extract_model_db_fields(self) -> Dict: - self_fields = self._extract_own_model_fields() - self_fields = { - k: v for k, v in self_fields.items() if k in self.__table__.columns - } - for field in self._extract_related_names(): - if getattr(self, field) is not None: - self_fields[field] = getattr( - getattr(self, field), self.__model_fields__[field].to.__pkname__ - ) - return self_fields - - @classmethod - def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: - merged_rows = [] - for index, model in enumerate(result_rows): - if index > 0 and model.pk == result_rows[index - 1].pk: - result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) - else: - merged_rows.append(model) - return merged_rows - - @classmethod - def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": - for field in one.__model_fields__.keys(): - # print(field, one.dict(), other.dict()) - if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), Model - ): - setattr(other, field, getattr(one, field) + getattr(other, field)) - elif isinstance(getattr(one, field), Model): - if getattr(one, field).pk == getattr(other, field).pk: - setattr( - other, - field, - cls.merge_two_instances( - getattr(one, field), getattr(other, field) - ), - ) - return other - - -class Model(FakePydantic): - __abstract__ = True - - objects = qry.QuerySet() - - @classmethod - def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, - ) -> "Model": - - item = {} - select_related = select_related or [] - - table_prefix = cls._orm_relationship_manager.resolve_relation_join( - previous_table, cls.__table__.name - ) - previous_table = cls.__table__.name - for related in select_related: - if "__" in related: - first_part, remainder = related.split("__", 1) - model_cls = cls.__model_fields__[first_part].to - child = model_cls.from_row( - row, select_related=[remainder], previous_table=previous_table - ) - item[first_part] = child - else: - model_cls = cls.__model_fields__[related].to - child = model_cls.from_row(row, previous_table=previous_table) - item[related] = child - - for column in cls.__table__.columns: - if column.name not in item: - item[column.name] = row[ - f'{table_prefix + "_" if table_prefix else ""}{column.name}' - ] - - return cls(**item) - - @property - def pk(self) -> str: - return getattr(self.values, self.__pkname__) - - @pk.setter - def pk(self, value: Any) -> None: - setattr(self.values, self.__pkname__, value) - - async def save(self) -> int: - self_fields = self._extract_model_db_fields() - if self.__model_fields__.get(self.__pkname__).autoincrement: - self_fields.pop(self.__pkname__, None) - expr = self.__table__.insert() - expr = expr.values(**self_fields) - item_id = await self.__database__.execute(expr) - self.pk = item_id - return item_id - - async def update(self, **kwargs: Any) -> int: - if kwargs: - new_values = {**self.dict(), **kwargs} - self.from_dict(new_values) - - self_fields = self._extract_model_db_fields() - self_fields.pop(self.__pkname__) - expr = ( - self.__table__.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) - ) - result = await self.__database__.execute(expr) - return result - - async def delete(self) -> int: - expr = self.__table__.delete() - expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) - result = await self.__database__.execute(expr) - return result - - async def load(self) -> "Model": - expr = self.__table__.select().where(self.pk_column == self.pk) - row = await self.__database__.fetch_one(expr) - self.from_dict(dict(row)) - return self diff --git a/orm/models/__init__.py b/orm/models/__init__.py new file mode 100644 index 0000000..00fac17 --- /dev/null +++ b/orm/models/__init__.py @@ -0,0 +1,5 @@ +from orm.models.model import Model + +__all__ = [ + "Model" +] diff --git a/orm/models/fakepydantic.py b/orm/models/fakepydantic.py new file mode 100644 index 0000000..bbced86 --- /dev/null +++ b/orm/models/fakepydantic.py @@ -0,0 +1,195 @@ +import inspect +import json +import uuid +from typing import TYPE_CHECKING, Dict, TypeVar, Type, Any, Optional, Callable, Set, List + +import databases +import pydantic +import sqlalchemy +from pydantic import BaseModel + +import orm +from orm.fields import BaseField +from orm.models.metaclass import ModelMetaclass +from orm.relations import RelationshipManager + +if TYPE_CHECKING: #pragma no cover + from orm.models.model import Model + + +class FakePydantic(list, metaclass=ModelMetaclass): + # FakePydantic inherits from list in order to be treated as + # request.Body parameter in fastapi routes, + # inheriting from pydantic.BaseModel causes metaclass conflicts + __abstract__ = True + if TYPE_CHECKING: # pragma no cover + __model_fields__: Dict[str, TypeVar[BaseField]] + __table__: sqlalchemy.Table + __fields__: Dict[str, pydantic.fields.ModelField] + __pydantic_model__: Type[BaseModel] + __pkname__: str + __tablename__: str + __metadata__: sqlalchemy.MetaData + __database__: databases.Database + _orm_relationship_manager: RelationshipManager + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + self._orm_id: str = uuid.uuid4().hex + self._orm_saved: bool = False + self.values: Optional[BaseModel] = None + + if "pk" in kwargs: + kwargs[self.__pkname__] = kwargs.pop("pk") + kwargs = { + k: self.__model_fields__[k].expand_relationship(v, self) + for k, v in kwargs.items() + } + self.values = self.__pydantic_model__(**kwargs) + + def __del__(self) -> None: + self._orm_relationship_manager.deregister(self) + + def __setattr__(self, key: str, value: Any) -> None: + if key in self.__fields__: + if self._is_conversion_to_json_needed(key) and not isinstance(value, str): + try: + value = json.dumps(value) + except TypeError: # pragma no cover + pass + + value = self.__model_fields__[key].expand_relationship(value, self) + + relation_key = self.__class__.__name__.title() + "_" + key + if not self._orm_relationship_manager.contains(relation_key, self): + setattr(self.values, key, value) + else: + super().__setattr__(key, value) + + def __getattribute__(self, key: str) -> Any: + if key != "__fields__" and key in self.__fields__: + relation_key = self.__class__.__name__.title() + "_" + key + if self._orm_relationship_manager.contains(relation_key, self): + return self._orm_relationship_manager.get(relation_key, self) + + item = getattr(self.values, key, None) + if ( + item is not None + and self._is_conversion_to_json_needed(key) + and isinstance(item, str) + ): + try: + item = json.loads(item) + except TypeError: # pragma no cover + pass + return item + return super().__getattribute__(key) + + def __eq__(self, other: "Model") -> bool: + return self.values.dict() == other.values.dict() + + def __same__(self, other: "Model") -> bool: + if self.__class__ != other.__class__: # pragma no cover + return False + return self._orm_id == other._orm_id or ( + self.values is not None and other.values is not None and self.pk == other.pk + ) + + def __repr__(self) -> str: # pragma no cover + return self.values.__repr__() + + @classmethod + def __get_validators__(cls) -> Callable: # pragma no cover + yield cls.__pydantic_model__.validate + + @classmethod + def get_name(cls, title: bool = False, lower: bool = True) -> str: + name = cls.__name__ + if lower: + name = name.lower() + if title: + name = name.title() + return name + + @property + def pk_column(self) -> sqlalchemy.Column: + return self.__table__.primary_key.columns.values()[0] + + @classmethod + def pk_type(cls) -> Any: + return cls.__model_fields__[cls.__pkname__].__type__ + + def dict(self) -> Dict: # noqa: A003 + dict_instance = self.values.dict() + for field in self._extract_related_names(): + nested_model = getattr(self, field) + if isinstance(nested_model, list): + dict_instance[field] = [x.dict() for x in nested_model] + else: + dict_instance[field] = ( + nested_model.dict() if nested_model is not None else {} + ) + return dict_instance + + def from_dict(self, value_dict: Dict) -> None: + for key, value in value_dict.items(): + setattr(self, key, value) + + def _is_conversion_to_json_needed(self, column_name: str) -> bool: + return self.__model_fields__.get(column_name).__type__ == pydantic.Json + + def _extract_own_model_fields(self) -> Dict: + related_names = self._extract_related_names() + self_fields = {k: v for k, v in self.dict().items() if k not in related_names} + return self_fields + + @classmethod + def _extract_related_names(cls) -> Set: + related_names = set() + for name, field in cls.__fields__.items(): + if inspect.isclass(field.type_) and issubclass( + field.type_, pydantic.BaseModel + ): + related_names.add(name) + return related_names + + def _extract_model_db_fields(self) -> Dict: + self_fields = self._extract_own_model_fields() + self_fields = { + k: v for k, v in self_fields.items() if k in self.__table__.columns + } + for field in self._extract_related_names(): + if getattr(self, field) is not None: + self_fields[field] = getattr( + getattr(self, field), self.__model_fields__[field].to.__pkname__ + ) + return self_fields + + @classmethod + def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: + merged_rows = [] + for index, model in enumerate(result_rows): + if index > 0 and model.pk == result_rows[index - 1].pk: + result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) + else: + merged_rows.append(model) + return merged_rows + + @classmethod + def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": + for field in one.__model_fields__.keys(): + # print(field, one.dict(), other.dict()) + if isinstance(getattr(one, field), list) and not isinstance( + getattr(one, field), orm.Model + ): + setattr(other, field, getattr(one, field) + getattr(other, field)) + elif isinstance(getattr(one, field), orm.Model): + if getattr(one, field).pk == getattr(other, field).pk: + setattr( + other, + field, + cls.merge_two_instances( + getattr(one, field), getattr(other, field) + ), + ) + return other diff --git a/orm/models/metaclass.py b/orm/models/metaclass.py new file mode 100644 index 0000000..b3546bb --- /dev/null +++ b/orm/models/metaclass.py @@ -0,0 +1,132 @@ +import copy +from typing import Dict, Tuple, Type, Optional, List, Any + +import sqlalchemy +from pydantic import BaseConfig, create_model +from pydantic.fields import ModelField + +from orm import ForeignKey, ModelDefinitionError +from orm.fields import BaseField +from orm.relations import RelationshipManager + +relationship_manager = RelationshipManager() + + +def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: + pydantic_fields = { + field_name: ( + base_field.__type__, + ... if base_field.is_required else base_field.default_value, + ) + for field_name, base_field in object_dict.items() + if isinstance(base_field, BaseField) + } + return pydantic_fields + + +def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: + child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s" + reverse_name = field.related_name or child_relation_name + relation_name = name.lower().title() + "_" + field.to.get_name() + relationship_manager.add_relation_type( + relation_name, reverse_name, field, table_name + ) + + +def expand_reverse_relationships(model: Type["Model"]) -> None: + for model_field in model.__model_fields__.values(): + if isinstance(model_field, ForeignKey): + child_model_name = model_field.related_name or model.get_name() + "s" + parent_model = model_field.to + child = model + if ( + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ + ): + register_reverse_model_fields(parent_model, child, child_model_name) + + +def register_reverse_model_fields( + model: Type["Model"], child: Type["Model"], child_model_name: str +) -> None: + model.__fields__[child_model_name] = ModelField( + name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__, + ) + model.__model_fields__[child_model_name] = ForeignKey( + child, name=child_model_name, virtual=True + ) + + +def sqlalchemy_columns_from_model_fields( + name: str, object_dict: Dict, table_name: str +) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: + pkname: Optional[str] = None + columns: List[sqlalchemy.Column] = [] + model_fields: Dict[str, BaseField] = {} + + for field_name, field in object_dict.items(): + if isinstance(field, BaseField): + model_fields[field_name] = field + if not field.pydantic_only: + if field.primary_key: + pkname = field_name + if isinstance(field, ForeignKey): + register_relation_on_build(table_name, field, name) + columns.append(field.get_column(field_name)) + return pkname, columns, model_fields + + +def get_pydantic_base_orm_config() -> Type[BaseConfig]: + class Config(BaseConfig): + orm_mode = True + + return Config + + +class ModelMetaclass(type): + def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + + if attrs.get("__abstract__"): + return new_model + + tablename = attrs["__tablename__"] + metadata = attrs["__metadata__"] + + # sqlalchemy table creation + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( + name, attrs, tablename + ) + attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) + attrs["__columns__"] = columns + attrs["__pkname__"] = pkname + + if not pkname: + raise ModelDefinitionError("Table has to have a primary key.") + + # pydantic model creation + pydantic_fields = parse_pydantic_field_from_model_fields(attrs) + pydantic_model = create_model( + name, __config__=get_pydantic_base_orm_config(), **pydantic_fields + ) + attrs["__pydantic_fields__"] = pydantic_fields + attrs["__pydantic_model__"] = pydantic_model + attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) + attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) + attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) + + attrs["__model_fields__"] = model_fields + attrs["_orm_relationship_manager"] = relationship_manager + + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + + expand_reverse_relationships(new_model) + + return new_model \ No newline at end of file diff --git a/orm/models/model.py b/orm/models/model.py new file mode 100644 index 0000000..d949d95 --- /dev/null +++ b/orm/models/model.py @@ -0,0 +1,93 @@ +from typing import List, Any + +import sqlalchemy + +import orm.queryset.queryset +from orm.models.fakepydantic import FakePydantic + + +class Model(FakePydantic): + __abstract__ = True + + objects = orm.queryset.queryset.QuerySet() + + @classmethod + def from_row( + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, + ) -> "Model": + + item = {} + select_related = select_related or [] + + table_prefix = cls._orm_relationship_manager.resolve_relation_join( + previous_table, cls.__table__.name + ) + previous_table = cls.__table__.name + for related in select_related: + if "__" in related: + first_part, remainder = related.split("__", 1) + model_cls = cls.__model_fields__[first_part].to + child = model_cls.from_row( + row, select_related=[remainder], previous_table=previous_table + ) + item[first_part] = child + else: + model_cls = cls.__model_fields__[related].to + child = model_cls.from_row(row, previous_table=previous_table) + item[related] = child + + for column in cls.__table__.columns: + if column.name not in item: + item[column.name] = row[ + f'{table_prefix + "_" if table_prefix else ""}{column.name}' + ] + + return cls(**item) + + @property + def pk(self) -> str: + return getattr(self.values, self.__pkname__) + + @pk.setter + def pk(self, value: Any) -> None: + setattr(self.values, self.__pkname__, value) + + async def save(self) -> int: + self_fields = self._extract_model_db_fields() + if self.__model_fields__.get(self.__pkname__).autoincrement: + self_fields.pop(self.__pkname__, None) + expr = self.__table__.insert() + expr = expr.values(**self_fields) + item_id = await self.__database__.execute(expr) + self.pk = item_id + return item_id + + async def update(self, **kwargs: Any) -> int: + if kwargs: + new_values = {**self.dict(), **kwargs} + self.from_dict(new_values) + + self_fields = self._extract_model_db_fields() + self_fields.pop(self.__pkname__) + expr = ( + self.__table__.update() + .values(**self_fields) + .where(self.pk_column == getattr(self, self.__pkname__)) + ) + result = await self.__database__.execute(expr) + return result + + async def delete(self) -> int: + expr = self.__table__.delete() + expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) + result = await self.__database__.execute(expr) + return result + + async def load(self) -> "Model": + expr = self.__table__.select().where(self.pk_column == self.pk) + row = await self.__database__.fetch_one(expr) + self.from_dict(dict(row)) + return self \ No newline at end of file diff --git a/orm/queryset.py b/orm/queryset.py deleted file mode 100644 index 5033f13..0000000 --- a/orm/queryset.py +++ /dev/null @@ -1,571 +0,0 @@ -from typing import ( - Any, - Dict, - List, - NamedTuple, - Optional, - TYPE_CHECKING, - Tuple, - Type, - Union, -) - -import databases -import sqlalchemy -from sqlalchemy import text - -import orm # noqa I100 -import orm.fields.foreign_key -from orm import ForeignKey -from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError -from orm.fields.base import BaseField - -if TYPE_CHECKING: # pragma no cover - from orm.models import Model - -FILTER_OPERATORS = { - "exact": "__eq__", - "iexact": "ilike", - "contains": "like", - "icontains": "ilike", - "in": "in_", - "gt": "__gt__", - "gte": "__ge__", - "lt": "__lt__", - "lte": "__le__", -} - -ESCAPE_CHARACTERS = ["%", "_"] - - -class JoinParameters(NamedTuple): - prev_model: Type["Model"] - previous_alias: str - from_table: str - model_cls: Type["Model"] - - -class Query: - def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, - ) -> None: - - self.query_offset = offset - self.limit_count = limit_count - self._select_related = select_related - self.filter_clauses = filter_clauses - - self.model_cls = model_cls - self.table = self.model_cls.__table__ - - self.auto_related = [] - self.used_aliases = [] - - self.select_from = None - self.columns = None - self.order_bys = None - - def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: - self.columns = list(self.table.columns) - self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] - self.select_from = self.table - - for key in self.model_cls.__model_fields__: - if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], - orm.fields.foreign_key.ForeignKey, - ) - and key not in self._select_related - ): - self._select_related = [key] + self._select_related - - start_params = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) - self._extract_auto_required_relations(prev_model=start_params.prev_model) - self._include_auto_related_models() - self._select_related.sort(key=lambda item: (-len(item), item)) - - for item in self._select_related: - join_parameters = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) - - for part in item.split("__"): - join_parameters = self._build_join_parameters(part, join_parameters) - - expr = sqlalchemy.sql.select(self.columns) - expr = expr.select_from(self.select_from) - - expr = self._apply_expression_modifiers(expr) - - # print(expr.compile(compile_kwargs={"literal_binds": True})) - self._reset_query_parameters() - - return expr, self._select_related - - @staticmethod - def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: - return [ - text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") - for column in table.columns - ] - - @staticmethod - def prefixed_table_name(alias: str, name: str) -> text: - return text(f"{name} {alias}_{name}") - - @staticmethod - def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str - ) -> bool: - return isinstance(field, ForeignKey) and field_name not in rel_part - - def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str - ) -> bool: - prev_part_of_related = "__".join(rel_part.split("__")[:-1]) - partial_match = any( - [x.startswith(prev_part_of_related) for x in self._select_related] - ) - already_checked = any([x.startswith(rel_part) for x in self.auto_related]) - return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested - - def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, - ) -> text: - left_part = f"{alias}_{to_clause}" - right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" - return text(f"{left_part}={right_part}") - - def _build_join_parameters( - self, part: str, join_params: JoinParameters - ) -> JoinParameters: - model_cls = join_params.model_cls.__model_fields__[part].to - to_table = model_cls.__table__.name - - alias = model_cls._orm_relationship_manager.resolve_relation_join( - join_params.from_table, to_table - ) - if alias not in self.used_aliases: - if join_params.prev_model.__model_fields__[part].virtual: - to_key = next( - ( - v - for k, v in model_cls.__model_fields__.items() - if isinstance(v, ForeignKey) and v.to == join_params.prev_model - ), - None, - ).name - from_key = model_cls.__pkname__ - else: - to_key = model_cls.__pkname__ - from_key = part - - on_clause = self.on_clause( - previous_alias=join_params.previous_alias, - alias=alias, - from_clause=f"{join_params.from_table}.{from_key}", - to_clause=f"{to_table}.{to_key}", - ) - target_table = self.prefixed_table_name(alias, to_table) - self.select_from = sqlalchemy.sql.outerjoin( - self.select_from, target_table, on_clause - ) - self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) - self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) - self.used_aliases.append(alias) - - previous_alias = alias - from_table = to_table - prev_model = model_cls - return JoinParameters(prev_model, previous_alias, from_table, model_cls) - - def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, - ) -> None: - for field_name, field in prev_model.__model_fields__.items(): - if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part - ): - rel_part = field_name if not rel_part else rel_part + "__" + field_name - if not field.nullable: - if rel_part not in self._select_related: - self.auto_related.append("__".join(rel_part.split("__")[:-1])) - rel_part = "" - elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part - ): - self._extract_auto_required_relations( - prev_model=field.to, - rel_part=rel_part, - nested=True, - parent_virtual=field.virtual, - ) - else: - rel_part = "" - - def _include_auto_related_models(self) -> None: - if self.auto_related: - new_joins = [] - for join in self._select_related: - if not any([x.startswith(join) for x in self.auto_related]): - new_joins.append(join) - self._select_related = new_joins + self.auto_related - - def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select - ) -> sqlalchemy.sql.select: - if self.filter_clauses: - if len(self.filter_clauses) == 1: - clause = self.filter_clauses[0] - else: - clause = sqlalchemy.sql.and_(*self.filter_clauses) - expr = expr.where(clause) - - if self.limit_count: - expr = expr.limit(self.limit_count) - - if self.query_offset: - expr = expr.offset(self.query_offset) - - for order in self.order_bys: - expr = expr.order_by(order) - return expr - - def _reset_query_parameters(self) -> None: - self.select_from = None - self.columns = None - self.order_bys = None - self.auto_related = [] - self.used_aliases = [] - - -class QueryClause: - def __init__( - self, model_cls: Type["Model"], filter_clauses: List, select_related: List, - ) -> None: - - self._select_related = select_related - self.filter_clauses = filter_clauses - - self.model_cls = model_cls - self.table = self.model_cls.__table__ - - def filter( # noqa: A003 - self, **kwargs: Any - ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: - filter_clauses = self.filter_clauses - select_related = list(self._select_related) - - if kwargs.get("pk"): - pk_name = self.model_cls.__pkname__ - kwargs[pk_name] = kwargs.pop("pk") - - for key, value in kwargs.items(): - table_prefix = "" - if "__" in key: - parts = key.split("__") - - ( - op, - field_name, - related_parts, - ) = self._extract_operator_field_and_related(parts) - - model_cls = self.model_cls - if related_parts: - ( - select_related, - table_prefix, - model_cls, - ) = self._determine_filter_target_table( - related_parts, select_related - ) - - table = model_cls.__table__ - column = model_cls.__table__.columns[field_name] - - else: - op = "exact" - column = self.table.columns[key] - table = self.table - - value, has_escaped_character = self._escape_characters_in_clause(op, value) - - if isinstance(value, orm.Model): - value = value.pk - - op_attr = FILTER_OPERATORS[op] - clause = getattr(column, op_attr)(value) - clause = self._compile_clause( - clause, - column, - table, - table_prefix, - modifiers={"escape": "\\" if has_escaped_character else None}, - ) - filter_clauses.append(clause) - - return filter_clauses, select_related - - def _determine_filter_target_table( - self, related_parts: List[str], select_related: List[str] - ) -> Tuple[List[str], str, "Model"]: - - table_prefix = "" - model_cls = self.model_cls - select_related = [relation for relation in select_related] - - # Add any implied select_related - related_str = "__".join(related_parts) - if related_str not in select_related: - select_related.append(related_str) - - # Walk the relationships to the actual model class - # against which the comparison is being made. - previous_table = model_cls.__tablename__ - for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - manager = model_cls._orm_relationship_manager - table_prefix = manager.resolve_relation_join(previous_table, current_table) - model_cls = model_cls.__model_fields__[part].to - previous_table = current_table - return select_related, table_prefix, model_cls - - def _compile_clause( - self, - clause: sqlalchemy.sql.expression.BinaryExpression, - column: sqlalchemy.Column, - table: sqlalchemy.Table, - table_prefix: str, - modifiers: Dict, - ) -> sqlalchemy.sql.expression.TextClause: - for modifier, modifier_value in modifiers.items(): - clause.modifiers[modifier] = modifier_value - - clause_text = str( - clause.compile( - dialect=self.model_cls.__database__._backend._dialect, - compile_kwargs={"literal_binds": True}, - ) - ) - alias = f"{table_prefix}_" if table_prefix else "" - aliased_name = f"{alias}{table.name}.{column.name}" - clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) - clause = text(clause_text) - return clause - - @staticmethod - def _escape_characters_in_clause( - op: str, value: Union[str, "Model"] - ) -> Tuple[str, bool]: - has_escaped_character = False - - if op in ["contains", "icontains"]: - if isinstance(value, orm.Model): - raise QueryDefinitionError( - "You cannot use contains and icontains with instance of the Model" - ) - - has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) - - if has_escaped_character: - # enable escape modifier - for char in ESCAPE_CHARACTERS: - value = value.replace(char, f"\\{char}") - value = f"%{value}%" - - return value, has_escaped_character - - @staticmethod - def _extract_operator_field_and_related( - parts: List[str], - ) -> Tuple[str, str, Optional[List]]: - if parts[-1] in FILTER_OPERATORS: - op = parts[-1] - field_name = parts[-2] - related_parts = parts[:-2] - else: - op = "exact" - field_name = parts[-1] - related_parts = parts[:-1] - - return op, field_name, related_parts - - -class QuerySet: - def __init__( - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, - ) -> None: - self.model_cls = model_cls - self.filter_clauses = [] if filter_clauses is None else filter_clauses - self._select_related = [] if select_related is None else select_related - self.limit_count = limit_count - self.query_offset = offset - self.order_bys = None - - def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": - return self.__class__(model_cls=owner) - - @property - def database(self) -> databases.Database: - return self.model_cls.__database__ - - @property - def table(self) -> sqlalchemy.Table: - return self.model_cls.__table__ - - def build_select_expression(self) -> sqlalchemy.sql.select: - qry = Query( - model_cls=self.model_cls, - select_related=self._select_related, - filter_clauses=self.filter_clauses, - offset=self.query_offset, - limit_count=self.limit_count, - ) - exp, self._select_related = qry.build_select_expression() - return exp - - def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 - qryclause = QueryClause( - model_cls=self.model_cls, - select_related=self._select_related, - filter_clauses=self.filter_clauses, - ) - filter_clauses, select_related = qryclause.filter(**kwargs) - - return self.__class__( - model_cls=self.model_cls, - filter_clauses=filter_clauses, - select_related=select_related, - limit_count=self.limit_count, - offset=self.query_offset, - ) - - def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": - if not isinstance(related, (list, tuple)): - related = [related] - - related = list(self._select_related) + related - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=related, - limit_count=self.limit_count, - offset=self.query_offset, - ) - - async def exists(self) -> bool: - expr = self.build_select_expression() - expr = sqlalchemy.exists(expr).select() - return await self.database.fetch_val(expr) - - async def count(self) -> int: - expr = self.build_select_expression().alias("subquery_for_count") - expr = sqlalchemy.func.count().select().select_from(expr) - return await self.database.fetch_val(expr) - - def limit(self, limit_count: int) -> "QuerySet": - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=self._select_related, - limit_count=limit_count, - offset=self.query_offset, - ) - - def offset(self, offset: int) -> "QuerySet": - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=self._select_related, - limit_count=self.limit_count, - offset=offset, - ) - - async def first(self, **kwargs: Any) -> "Model": - if kwargs: - return await self.filter(**kwargs).first() - - rows = await self.limit(1).all() - if rows: - return rows[0] - - async def get(self, **kwargs: Any) -> "Model": - if kwargs: - return await self.filter(**kwargs).get() - - expr = self.build_select_expression().limit(2) - rows = await self.database.fetch_all(expr) - - if not rows: - raise NoMatch() - if len(rows) > 1: - raise MultipleMatches() - return self.model_cls.from_row(rows[0], select_related=self._select_related) - - async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 - if kwargs: - return await self.filter(**kwargs).all() - - expr = self.build_select_expression() - rows = await self.database.fetch_all(expr) - result_rows = [ - self.model_cls.from_row(row, select_related=self._select_related) - for row in rows - ] - - result_rows = self.model_cls.merge_instances_list(result_rows) - - return result_rows - - async def create(self, **kwargs: Any) -> "Model": - - new_kwargs = dict(**kwargs) - - # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.__pkname__ - pk = self.model_cls.__model_fields__[pkname] - if ( - pkname in new_kwargs - and new_kwargs.get(pkname) is None - and (pk.nullable or pk.autoincrement) - ): - del new_kwargs[pkname] - - # substitute related models with their pk - for field in self.model_cls._extract_related_names(): - if field in new_kwargs and new_kwargs.get(field) is not None: - new_kwargs[field] = getattr( - new_kwargs.get(field), - self.model_cls.__model_fields__[field].to.__pkname__, - ) - - # Build the insert expression. - expr = self.table.insert() - expr = expr.values(**new_kwargs) - - # Execute the insert, and return a new model instance. - instance = self.model_cls(**kwargs) - instance.pk = await self.database.execute(expr) - return instance diff --git a/orm/queryset/__init__.py b/orm/queryset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/orm/queryset/clause.py b/orm/queryset/clause.py new file mode 100644 index 0000000..0cc2309 --- /dev/null +++ b/orm/queryset/clause.py @@ -0,0 +1,176 @@ +from typing import Type, List, Any, Tuple, Dict, Union, Optional, TYPE_CHECKING + +import sqlalchemy +from sqlalchemy import text + +import orm +from orm.exceptions import QueryDefinitionError + +if TYPE_CHECKING: # pragma no cover + from orm import Model + +FILTER_OPERATORS = { + "exact": "__eq__", + "iexact": "ilike", + "contains": "like", + "icontains": "ilike", + "in": "in_", + "gt": "__gt__", + "gte": "__ge__", + "lt": "__lt__", + "lte": "__le__", +} +ESCAPE_CHARACTERS = ["%", "_"] + + +class QueryClause: + def __init__( + self, model_cls: Type["Model"], filter_clauses: List, select_related: List, + ) -> None: + + self._select_related = select_related + self.filter_clauses = filter_clauses + + self.model_cls = model_cls + self.table = self.model_cls.__table__ + + def filter( # noqa: A003 + self, **kwargs: Any + ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: + filter_clauses = self.filter_clauses + select_related = list(self._select_related) + + if kwargs.get("pk"): + pk_name = self.model_cls.__pkname__ + kwargs[pk_name] = kwargs.pop("pk") + + for key, value in kwargs.items(): + table_prefix = "" + if "__" in key: + parts = key.split("__") + + ( + op, + field_name, + related_parts, + ) = self._extract_operator_field_and_related(parts) + + model_cls = self.model_cls + if related_parts: + ( + select_related, + table_prefix, + model_cls, + ) = self._determine_filter_target_table( + related_parts, select_related + ) + + table = model_cls.__table__ + column = model_cls.__table__.columns[field_name] + + else: + op = "exact" + column = self.table.columns[key] + table = self.table + + value, has_escaped_character = self._escape_characters_in_clause(op, value) + + if isinstance(value, orm.Model): + value = value.pk + + op_attr = FILTER_OPERATORS[op] + clause = getattr(column, op_attr)(value) + clause = self._compile_clause( + clause, + column, + table, + table_prefix, + modifiers={"escape": "\\" if has_escaped_character else None}, + ) + filter_clauses.append(clause) + + return filter_clauses, select_related + + def _determine_filter_target_table( + self, related_parts: List[str], select_related: List[str] + ) -> Tuple[List[str], str, "Model"]: + + table_prefix = "" + model_cls = self.model_cls + select_related = [relation for relation in select_related] + + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + previous_table = model_cls.__tablename__ + for part in related_parts: + current_table = model_cls.__model_fields__[part].to.__tablename__ + manager = model_cls._orm_relationship_manager + table_prefix = manager.resolve_relation_join(previous_table, current_table) + model_cls = model_cls.__model_fields__[part].to + previous_table = current_table + return select_related, table_prefix, model_cls + + def _compile_clause( + self, + clause: sqlalchemy.sql.expression.BinaryExpression, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + modifiers: Dict, + ) -> sqlalchemy.sql.expression.TextClause: + for modifier, modifier_value in modifiers.items(): + clause.modifiers[modifier] = modifier_value + + clause_text = str( + clause.compile( + dialect=self.model_cls.__database__._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{table_prefix}_" if table_prefix else "" + aliased_name = f"{alias}{table.name}.{column.name}" + clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) + clause = text(clause_text) + return clause + + @staticmethod + def _escape_characters_in_clause( + op: str, value: Union[str, "Model"] + ) -> Tuple[str, bool]: + has_escaped_character = False + + if op in ["contains", "icontains"]: + if isinstance(value, orm.Model): + raise QueryDefinitionError( + "You cannot use contains and icontains with instance of the Model" + ) + + has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) + + if has_escaped_character: + # enable escape modifier + for char in ESCAPE_CHARACTERS: + value = value.replace(char, f"\\{char}") + value = f"%{value}%" + + return value, has_escaped_character + + @staticmethod + def _extract_operator_field_and_related( + parts: List[str], + ) -> Tuple[str, str, Optional[List]]: + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + return op, field_name, related_parts diff --git a/orm/queryset/query.py b/orm/queryset/query.py new file mode 100644 index 0000000..22b2db1 --- /dev/null +++ b/orm/queryset/query.py @@ -0,0 +1,228 @@ +from typing import NamedTuple, Type, List, Tuple, TYPE_CHECKING + +import sqlalchemy +from sqlalchemy import text + +import orm +from orm import ForeignKey +from orm.fields import BaseField + +if TYPE_CHECKING: # pragma no cover + from orm import Model + + +class JoinParameters(NamedTuple): + prev_model: Type["Model"] + previous_alias: str + from_table: str + model_cls: Type["Model"] + + +class Query: + def __init__( + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, + ) -> None: + + self.query_offset = offset + self.limit_count = limit_count + self._select_related = select_related + self.filter_clauses = filter_clauses + + self.model_cls = model_cls + self.table = self.model_cls.__table__ + + self.auto_related = [] + self.used_aliases = [] + + self.select_from = None + self.columns = None + self.order_bys = None + + def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: + self.columns = list(self.table.columns) + self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] + self.select_from = self.table + + for key in self.model_cls.__model_fields__: + if ( + not self.model_cls.__model_fields__[key].nullable + and isinstance( + self.model_cls.__model_fields__[key], + orm.fields.foreign_key.ForeignKey, + ) + and key not in self._select_related + ): + self._select_related = [key] + self._select_related + + start_params = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + self._extract_auto_required_relations(prev_model=start_params.prev_model) + self._include_auto_related_models() + self._select_related.sort(key=lambda item: (-len(item), item)) + + for item in self._select_related: + join_parameters = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + + for part in item.split("__"): + join_parameters = self._build_join_parameters(part, join_parameters) + + expr = sqlalchemy.sql.select(self.columns) + expr = expr.select_from(self.select_from) + + expr = self._apply_expression_modifiers(expr) + + # print(expr.compile(compile_kwargs={"literal_binds": True})) + self._reset_query_parameters() + + return expr, self._select_related + + @staticmethod + def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: + return [ + text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") + for column in table.columns + ] + + @staticmethod + def prefixed_table_name(alias: str, name: str) -> text: + return text(f"{name} {alias}_{name}") + + @staticmethod + def _field_is_a_foreign_key_and_no_circular_reference( + field: BaseField, field_name: str, rel_part: str + ) -> bool: + return isinstance(field, ForeignKey) and field_name not in rel_part + + def _field_qualifies_to_deeper_search( + self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + ) -> bool: + prev_part_of_related = "__".join(rel_part.split("__")[:-1]) + partial_match = any( + [x.startswith(prev_part_of_related) for x in self._select_related] + ) + already_checked = any([x.startswith(rel_part) for x in self.auto_related]) + return ( + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested + + def on_clause( + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + ) -> text: + left_part = f"{alias}_{to_clause}" + right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" + return text(f"{left_part}={right_part}") + + def _build_join_parameters( + self, part: str, join_params: JoinParameters + ) -> JoinParameters: + model_cls = join_params.model_cls.__model_fields__[part].to + to_table = model_cls.__table__.name + + alias = model_cls._orm_relationship_manager.resolve_relation_join( + join_params.from_table, to_table + ) + if alias not in self.used_aliases: + if join_params.prev_model.__model_fields__[part].virtual: + to_key = next( + ( + v + for k, v in model_cls.__model_fields__.items() + if isinstance(v, ForeignKey) and v.to == join_params.prev_model + ), + None, + ).name + from_key = model_cls.__pkname__ + else: + to_key = model_cls.__pkname__ + from_key = part + + on_clause = self.on_clause( + previous_alias=join_params.previous_alias, + alias=alias, + from_clause=f"{join_params.from_table}.{from_key}", + to_clause=f"{to_table}.{to_key}", + ) + target_table = self.prefixed_table_name(alias, to_table) + self.select_from = sqlalchemy.sql.outerjoin( + self.select_from, target_table, on_clause + ) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) + self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + self.used_aliases.append(alias) + + previous_alias = alias + from_table = to_table + prev_model = model_cls + return JoinParameters(prev_model, previous_alias, from_table, model_cls) + + def _extract_auto_required_relations( + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, + ) -> None: + for field_name, field in prev_model.__model_fields__.items(): + if self._field_is_a_foreign_key_and_no_circular_reference( + field, field_name, rel_part + ): + rel_part = field_name if not rel_part else rel_part + "__" + field_name + if not field.nullable: + if rel_part not in self._select_related: + self.auto_related.append("__".join(rel_part.split("__")[:-1])) + rel_part = "" + elif self._field_qualifies_to_deeper_search( + field, parent_virtual, nested, rel_part + ): + self._extract_auto_required_relations( + prev_model=field.to, + rel_part=rel_part, + nested=True, + parent_virtual=field.virtual, + ) + else: + rel_part = "" + + def _include_auto_related_models(self) -> None: + if self.auto_related: + new_joins = [] + for join in self._select_related: + if not any([x.startswith(join) for x in self.auto_related]): + new_joins.append(join) + self._select_related = new_joins + self.auto_related + + def _apply_expression_modifiers( + self, expr: sqlalchemy.sql.select + ) -> sqlalchemy.sql.select: + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + for order in self.order_bys: + expr = expr.order_by(order) + return expr + + def _reset_query_parameters(self) -> None: + self.select_from = None + self.columns = None + self.order_bys = None + self.auto_related = [] + self.used_aliases = [] diff --git a/orm/queryset/queryset.py b/orm/queryset/queryset.py new file mode 100644 index 0000000..a61a5fa --- /dev/null +++ b/orm/queryset/queryset.py @@ -0,0 +1,175 @@ +from typing import Type, List, Any, Union, Tuple, TYPE_CHECKING + +import databases +import sqlalchemy + +import orm # noqa I100 +from orm import NoMatch, MultipleMatches +from orm.queryset.clause import QueryClause +from orm.queryset.query import Query + +if TYPE_CHECKING: # pragma no cover + from orm import Model + + +class QuerySet: + def __init__( + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + ) -> None: + self.model_cls = model_cls + self.filter_clauses = [] if filter_clauses is None else filter_clauses + self._select_related = [] if select_related is None else select_related + self.limit_count = limit_count + self.query_offset = offset + self.order_bys = None + + def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": + return self.__class__(model_cls=owner) + + @property + def database(self) -> databases.Database: + return self.model_cls.__database__ + + @property + def table(self) -> sqlalchemy.Table: + return self.model_cls.__table__ + + def build_select_expression(self) -> sqlalchemy.sql.select: + qry = Query( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + offset=self.query_offset, + limit_count=self.limit_count, + ) + exp, self._select_related = qry.build_select_expression() + return exp + + def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + qryclause = QueryClause( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + ) + filter_clauses, select_related = qryclause.filter(**kwargs) + + return self.__class__( + model_cls=self.model_cls, + filter_clauses=filter_clauses, + select_related=select_related, + limit_count=self.limit_count, + offset=self.query_offset, + ) + + def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": + if not isinstance(related, (list, tuple)): + related = [related] + + related = list(self._select_related) + related + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=related, + limit_count=self.limit_count, + offset=self.query_offset, + ) + + async def exists(self) -> bool: + expr = self.build_select_expression() + expr = sqlalchemy.exists(expr).select() + return await self.database.fetch_val(expr) + + async def count(self) -> int: + expr = self.build_select_expression().alias("subquery_for_count") + expr = sqlalchemy.func.count().select().select_from(expr) + return await self.database.fetch_val(expr) + + def limit(self, limit_count: int) -> "QuerySet": + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=limit_count, + offset=self.query_offset, + ) + + def offset(self, offset: int) -> "QuerySet": + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=offset, + ) + + async def first(self, **kwargs: Any) -> "Model": + if kwargs: + return await self.filter(**kwargs).first() + + rows = await self.limit(1).all() + if rows: + return rows[0] + + async def get(self, **kwargs: Any) -> "Model": + if kwargs: + return await self.filter(**kwargs).get() + + expr = self.build_select_expression().limit(2) + rows = await self.database.fetch_all(expr) + + if not rows: + raise NoMatch() + if len(rows) > 1: + raise MultipleMatches() + return self.model_cls.from_row(rows[0], select_related=self._select_related) + + async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 + if kwargs: + return await self.filter(**kwargs).all() + + expr = self.build_select_expression() + rows = await self.database.fetch_all(expr) + result_rows = [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + + result_rows = self.model_cls.merge_instances_list(result_rows) + + return result_rows + + async def create(self, **kwargs: Any) -> "Model": + + new_kwargs = dict(**kwargs) + + # Remove primary key when None to prevent not null constraint in postgresql. + pkname = self.model_cls.__pkname__ + pk = self.model_cls.__model_fields__[pkname] + if ( + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) + ): + del new_kwargs[pkname] + + # substitute related models with their pk + for field in self.model_cls._extract_related_names(): + if field in new_kwargs and new_kwargs.get(field) is not None: + new_kwargs[field] = getattr( + new_kwargs.get(field), + self.model_cls.__model_fields__[field].to.__pkname__, + ) + + # Build the insert expression. + expr = self.table.insert() + expr = expr.values(**new_kwargs) + + # Execute the insert, and return a new model instance. + instance = self.model_cls(**kwargs) + instance.pk = await self.database.execute(expr) + return instance