refactori into packages
This commit is contained in:
@ -5,7 +5,7 @@ import sqlalchemy
|
|||||||
from pydantic import Json
|
from pydantic import Json
|
||||||
|
|
||||||
from orm.fields.base import BaseField # noqa I101
|
from orm.fields.base import BaseField # noqa I101
|
||||||
from orm.fields.required_decorator import RequiredParams
|
from orm.fields.decorators import RequiredParams
|
||||||
|
|
||||||
|
|
||||||
@RequiredParams("length")
|
@RequiredParams("length")
|
||||||
|
|||||||
405
orm/models.py
405
orm/models.py
@ -1,405 +0,0 @@
|
|||||||
import copy
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
|
|
||||||
from typing import Callable, Dict, Set
|
|
||||||
|
|
||||||
import databases
|
|
||||||
import pydantic
|
|
||||||
import sqlalchemy
|
|
||||||
from pydantic import BaseConfig, BaseModel, create_model
|
|
||||||
from pydantic.fields import ModelField
|
|
||||||
|
|
||||||
import orm.queryset as qry # noqa I100
|
|
||||||
from orm import ForeignKey
|
|
||||||
from orm.exceptions import ModelDefinitionError
|
|
||||||
from orm.fields.base import BaseField
|
|
||||||
from orm.relations import RelationshipManager
|
|
||||||
|
|
||||||
relationship_manager = RelationshipManager()
|
|
||||||
|
|
||||||
|
|
||||||
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
|
|
||||||
pydantic_fields = {
|
|
||||||
field_name: (
|
|
||||||
base_field.__type__,
|
|
||||||
... if base_field.is_required else base_field.default_value,
|
|
||||||
)
|
|
||||||
for field_name, base_field in object_dict.items()
|
|
||||||
if isinstance(base_field, BaseField)
|
|
||||||
}
|
|
||||||
return pydantic_fields
|
|
||||||
|
|
||||||
|
|
||||||
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
|
|
||||||
child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s"
|
|
||||||
reverse_name = field.related_name or child_relation_name
|
|
||||||
relation_name = name.lower().title() + "_" + field.to.get_name()
|
|
||||||
relationship_manager.add_relation_type(
|
|
||||||
relation_name, reverse_name, field, table_name
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def expand_reverse_relationships(model: Type["Model"]) -> None:
|
|
||||||
for model_field in model.__model_fields__.values():
|
|
||||||
if isinstance(model_field, ForeignKey):
|
|
||||||
child_model_name = model_field.related_name or model.__name__.lower() + "s"
|
|
||||||
parent_model = model_field.to
|
|
||||||
child = model
|
|
||||||
if (
|
|
||||||
child_model_name not in parent_model.__fields__
|
|
||||||
and child.get_name() not in parent_model.__fields__
|
|
||||||
):
|
|
||||||
register_reverse_model_fields(parent_model, child, child_model_name)
|
|
||||||
|
|
||||||
|
|
||||||
def register_reverse_model_fields(
|
|
||||||
model: Type["Model"], child: Type["Model"], child_model_name: str
|
|
||||||
) -> None:
|
|
||||||
model.__fields__[child_model_name] = ModelField(
|
|
||||||
name=child_model_name,
|
|
||||||
type_=Optional[child.__pydantic_model__],
|
|
||||||
model_config=child.__pydantic_model__.__config__,
|
|
||||||
class_validators=child.__pydantic_model__.__validators__,
|
|
||||||
)
|
|
||||||
model.__model_fields__[child_model_name] = ForeignKey(
|
|
||||||
child, name=child_model_name, virtual=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sqlalchemy_columns_from_model_fields(
|
|
||||||
name: str, object_dict: Dict, table_name: str
|
|
||||||
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
|
|
||||||
pkname: Optional[str] = None
|
|
||||||
columns: List[sqlalchemy.Column] = []
|
|
||||||
model_fields: Dict[str, BaseField] = {}
|
|
||||||
|
|
||||||
for field_name, field in object_dict.items():
|
|
||||||
if isinstance(field, BaseField):
|
|
||||||
model_fields[field_name] = field
|
|
||||||
if not field.pydantic_only:
|
|
||||||
if field.primary_key:
|
|
||||||
pkname = field_name
|
|
||||||
if isinstance(field, ForeignKey):
|
|
||||||
register_relation_on_build(table_name, field, name)
|
|
||||||
columns.append(field.get_column(field_name))
|
|
||||||
return pkname, columns, model_fields
|
|
||||||
|
|
||||||
|
|
||||||
def get_pydantic_base_orm_config() -> Type[BaseConfig]:
|
|
||||||
class Config(BaseConfig):
|
|
||||||
orm_mode = True
|
|
||||||
|
|
||||||
return Config
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMetaclass(type):
|
|
||||||
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
|
||||||
new_model = super().__new__( # type: ignore
|
|
||||||
mcs, name, bases, attrs
|
|
||||||
)
|
|
||||||
|
|
||||||
if attrs.get("__abstract__"):
|
|
||||||
return new_model
|
|
||||||
|
|
||||||
tablename = attrs["__tablename__"]
|
|
||||||
metadata = attrs["__metadata__"]
|
|
||||||
|
|
||||||
# sqlalchemy table creation
|
|
||||||
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
|
|
||||||
name, attrs, tablename
|
|
||||||
)
|
|
||||||
attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns)
|
|
||||||
attrs["__columns__"] = columns
|
|
||||||
attrs["__pkname__"] = pkname
|
|
||||||
|
|
||||||
if not pkname:
|
|
||||||
raise ModelDefinitionError("Table has to have a primary key.")
|
|
||||||
|
|
||||||
# pydantic model creation
|
|
||||||
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
|
|
||||||
pydantic_model = create_model(
|
|
||||||
name, __config__=get_pydantic_base_orm_config(), **pydantic_fields
|
|
||||||
)
|
|
||||||
attrs["__pydantic_fields__"] = pydantic_fields
|
|
||||||
attrs["__pydantic_model__"] = pydantic_model
|
|
||||||
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
|
|
||||||
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
|
|
||||||
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
|
|
||||||
|
|
||||||
attrs["__model_fields__"] = model_fields
|
|
||||||
attrs["_orm_relationship_manager"] = relationship_manager
|
|
||||||
|
|
||||||
new_model = super().__new__( # type: ignore
|
|
||||||
mcs, name, bases, attrs
|
|
||||||
)
|
|
||||||
|
|
||||||
expand_reverse_relationships(new_model)
|
|
||||||
|
|
||||||
return new_model
|
|
||||||
|
|
||||||
|
|
||||||
class FakePydantic(list, metaclass=ModelMetaclass):
|
|
||||||
# FakePydantic inherits from list in order to be treated as
|
|
||||||
# request.Body parameter in fastapi routes,
|
|
||||||
# inheriting from pydantic.BaseModel causes metaclass conflicts
|
|
||||||
__abstract__ = True
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
|
||||||
__model_fields__: Dict[str, TypeVar[BaseField]]
|
|
||||||
__table__: sqlalchemy.Table
|
|
||||||
__fields__: Dict[str, pydantic.fields.ModelField]
|
|
||||||
__pydantic_model__: Type[BaseModel]
|
|
||||||
__pkname__: str
|
|
||||||
__tablename__: str
|
|
||||||
__metadata__: sqlalchemy.MetaData
|
|
||||||
__database__: databases.Database
|
|
||||||
_orm_relationship_manager: RelationshipManager
|
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self._orm_id: str = uuid.uuid4().hex
|
|
||||||
self._orm_saved: bool = False
|
|
||||||
self.values: Optional[BaseModel] = None
|
|
||||||
|
|
||||||
if "pk" in kwargs:
|
|
||||||
kwargs[self.__pkname__] = kwargs.pop("pk")
|
|
||||||
kwargs = {
|
|
||||||
k: self.__model_fields__[k].expand_relationship(v, self)
|
|
||||||
for k, v in kwargs.items()
|
|
||||||
}
|
|
||||||
self.values = self.__pydantic_model__(**kwargs)
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
self._orm_relationship_manager.deregister(self)
|
|
||||||
|
|
||||||
def __setattr__(self, key: str, value: Any) -> None:
|
|
||||||
if key in self.__fields__:
|
|
||||||
if self._is_conversion_to_json_needed(key) and not isinstance(value, str):
|
|
||||||
try:
|
|
||||||
value = json.dumps(value)
|
|
||||||
except TypeError: # pragma no cover
|
|
||||||
pass
|
|
||||||
|
|
||||||
value = self.__model_fields__[key].expand_relationship(value, self)
|
|
||||||
|
|
||||||
relation_key = self.__class__.__name__.title() + "_" + key
|
|
||||||
if not self._orm_relationship_manager.contains(relation_key, self):
|
|
||||||
setattr(self.values, key, value)
|
|
||||||
else:
|
|
||||||
super().__setattr__(key, value)
|
|
||||||
|
|
||||||
def __getattribute__(self, key: str) -> Any:
|
|
||||||
if key != "__fields__" and key in self.__fields__:
|
|
||||||
relation_key = self.__class__.__name__.title() + "_" + key
|
|
||||||
if self._orm_relationship_manager.contains(relation_key, self):
|
|
||||||
return self._orm_relationship_manager.get(relation_key, self)
|
|
||||||
|
|
||||||
item = getattr(self.values, key, None)
|
|
||||||
if (
|
|
||||||
item is not None
|
|
||||||
and self._is_conversion_to_json_needed(key)
|
|
||||||
and isinstance(item, str)
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
item = json.loads(item)
|
|
||||||
except TypeError: # pragma no cover
|
|
||||||
pass
|
|
||||||
return item
|
|
||||||
return super().__getattribute__(key)
|
|
||||||
|
|
||||||
def __eq__(self, other: "Model") -> bool:
|
|
||||||
return self.values.dict() == other.values.dict()
|
|
||||||
|
|
||||||
def __same__(self, other: "Model") -> bool:
|
|
||||||
if self.__class__ != other.__class__: # pragma no cover
|
|
||||||
return False
|
|
||||||
return self._orm_id == other._orm_id or (
|
|
||||||
self.values is not None and other.values is not None and self.pk == other.pk
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str: # pragma no cover
|
|
||||||
return self.values.__repr__()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_validators__(cls) -> Callable: # pragma no cover
|
|
||||||
yield cls.__pydantic_model__.validate
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_name(cls, title: bool = False, lower: bool = True) -> str:
|
|
||||||
name = cls.__name__
|
|
||||||
if lower:
|
|
||||||
name = name.lower()
|
|
||||||
if title:
|
|
||||||
name = name.title()
|
|
||||||
return name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pk_column(self) -> sqlalchemy.Column:
|
|
||||||
return self.__table__.primary_key.columns.values()[0]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def pk_type(cls) -> Any:
|
|
||||||
return cls.__model_fields__[cls.__pkname__].__type__
|
|
||||||
|
|
||||||
def dict(self) -> Dict: # noqa: A003
|
|
||||||
dict_instance = self.values.dict()
|
|
||||||
for field in self._extract_related_names():
|
|
||||||
nested_model = getattr(self, field)
|
|
||||||
if isinstance(nested_model, list):
|
|
||||||
dict_instance[field] = [x.dict() for x in nested_model]
|
|
||||||
else:
|
|
||||||
dict_instance[field] = (
|
|
||||||
nested_model.dict() if nested_model is not None else {}
|
|
||||||
)
|
|
||||||
return dict_instance
|
|
||||||
|
|
||||||
def from_dict(self, value_dict: Dict) -> None:
|
|
||||||
for key, value in value_dict.items():
|
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
def _is_conversion_to_json_needed(self, column_name: str) -> bool:
|
|
||||||
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
|
|
||||||
|
|
||||||
def _extract_own_model_fields(self) -> Dict:
|
|
||||||
related_names = self._extract_related_names()
|
|
||||||
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
|
|
||||||
return self_fields
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _extract_related_names(cls) -> Set:
|
|
||||||
related_names = set()
|
|
||||||
for name, field in cls.__fields__.items():
|
|
||||||
if inspect.isclass(field.type_) and issubclass(
|
|
||||||
field.type_, pydantic.BaseModel
|
|
||||||
):
|
|
||||||
related_names.add(name)
|
|
||||||
return related_names
|
|
||||||
|
|
||||||
def _extract_model_db_fields(self) -> Dict:
|
|
||||||
self_fields = self._extract_own_model_fields()
|
|
||||||
self_fields = {
|
|
||||||
k: v for k, v in self_fields.items() if k in self.__table__.columns
|
|
||||||
}
|
|
||||||
for field in self._extract_related_names():
|
|
||||||
if getattr(self, field) is not None:
|
|
||||||
self_fields[field] = getattr(
|
|
||||||
getattr(self, field), self.__model_fields__[field].to.__pkname__
|
|
||||||
)
|
|
||||||
return self_fields
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
|
|
||||||
merged_rows = []
|
|
||||||
for index, model in enumerate(result_rows):
|
|
||||||
if index > 0 and model.pk == result_rows[index - 1].pk:
|
|
||||||
result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
|
|
||||||
else:
|
|
||||||
merged_rows.append(model)
|
|
||||||
return merged_rows
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
|
|
||||||
for field in one.__model_fields__.keys():
|
|
||||||
# print(field, one.dict(), other.dict())
|
|
||||||
if isinstance(getattr(one, field), list) and not isinstance(
|
|
||||||
getattr(one, field), Model
|
|
||||||
):
|
|
||||||
setattr(other, field, getattr(one, field) + getattr(other, field))
|
|
||||||
elif isinstance(getattr(one, field), Model):
|
|
||||||
if getattr(one, field).pk == getattr(other, field).pk:
|
|
||||||
setattr(
|
|
||||||
other,
|
|
||||||
field,
|
|
||||||
cls.merge_two_instances(
|
|
||||||
getattr(one, field), getattr(other, field)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return other
|
|
||||||
|
|
||||||
|
|
||||||
class Model(FakePydantic):
|
|
||||||
__abstract__ = True
|
|
||||||
|
|
||||||
objects = qry.QuerySet()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_row(
|
|
||||||
cls,
|
|
||||||
row: sqlalchemy.engine.ResultProxy,
|
|
||||||
select_related: List = None,
|
|
||||||
previous_table: str = None,
|
|
||||||
) -> "Model":
|
|
||||||
|
|
||||||
item = {}
|
|
||||||
select_related = select_related or []
|
|
||||||
|
|
||||||
table_prefix = cls._orm_relationship_manager.resolve_relation_join(
|
|
||||||
previous_table, cls.__table__.name
|
|
||||||
)
|
|
||||||
previous_table = cls.__table__.name
|
|
||||||
for related in select_related:
|
|
||||||
if "__" in related:
|
|
||||||
first_part, remainder = related.split("__", 1)
|
|
||||||
model_cls = cls.__model_fields__[first_part].to
|
|
||||||
child = model_cls.from_row(
|
|
||||||
row, select_related=[remainder], previous_table=previous_table
|
|
||||||
)
|
|
||||||
item[first_part] = child
|
|
||||||
else:
|
|
||||||
model_cls = cls.__model_fields__[related].to
|
|
||||||
child = model_cls.from_row(row, previous_table=previous_table)
|
|
||||||
item[related] = child
|
|
||||||
|
|
||||||
for column in cls.__table__.columns:
|
|
||||||
if column.name not in item:
|
|
||||||
item[column.name] = row[
|
|
||||||
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
|
|
||||||
]
|
|
||||||
|
|
||||||
return cls(**item)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pk(self) -> str:
|
|
||||||
return getattr(self.values, self.__pkname__)
|
|
||||||
|
|
||||||
@pk.setter
|
|
||||||
def pk(self, value: Any) -> None:
|
|
||||||
setattr(self.values, self.__pkname__, value)
|
|
||||||
|
|
||||||
async def save(self) -> int:
|
|
||||||
self_fields = self._extract_model_db_fields()
|
|
||||||
if self.__model_fields__.get(self.__pkname__).autoincrement:
|
|
||||||
self_fields.pop(self.__pkname__, None)
|
|
||||||
expr = self.__table__.insert()
|
|
||||||
expr = expr.values(**self_fields)
|
|
||||||
item_id = await self.__database__.execute(expr)
|
|
||||||
self.pk = item_id
|
|
||||||
return item_id
|
|
||||||
|
|
||||||
async def update(self, **kwargs: Any) -> int:
|
|
||||||
if kwargs:
|
|
||||||
new_values = {**self.dict(), **kwargs}
|
|
||||||
self.from_dict(new_values)
|
|
||||||
|
|
||||||
self_fields = self._extract_model_db_fields()
|
|
||||||
self_fields.pop(self.__pkname__)
|
|
||||||
expr = (
|
|
||||||
self.__table__.update()
|
|
||||||
.values(**self_fields)
|
|
||||||
.where(self.pk_column == getattr(self, self.__pkname__))
|
|
||||||
)
|
|
||||||
result = await self.__database__.execute(expr)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def delete(self) -> int:
|
|
||||||
expr = self.__table__.delete()
|
|
||||||
expr = expr.where(self.pk_column == (getattr(self, self.__pkname__)))
|
|
||||||
result = await self.__database__.execute(expr)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def load(self) -> "Model":
|
|
||||||
expr = self.__table__.select().where(self.pk_column == self.pk)
|
|
||||||
row = await self.__database__.fetch_one(expr)
|
|
||||||
self.from_dict(dict(row))
|
|
||||||
return self
|
|
||||||
5
orm/models/__init__.py
Normal file
5
orm/models/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from orm.models.model import Model
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Model"
|
||||||
|
]
|
||||||
195
orm/models/fakepydantic.py
Normal file
195
orm/models/fakepydantic.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import TYPE_CHECKING, Dict, TypeVar, Type, Any, Optional, Callable, Set, List
|
||||||
|
|
||||||
|
import databases
|
||||||
|
import pydantic
|
||||||
|
import sqlalchemy
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import orm
|
||||||
|
from orm.fields import BaseField
|
||||||
|
from orm.models.metaclass import ModelMetaclass
|
||||||
|
from orm.relations import RelationshipManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING: #pragma no cover
|
||||||
|
from orm.models.model import Model
|
||||||
|
|
||||||
|
|
||||||
|
class FakePydantic(list, metaclass=ModelMetaclass):
|
||||||
|
# FakePydantic inherits from list in order to be treated as
|
||||||
|
# request.Body parameter in fastapi routes,
|
||||||
|
# inheriting from pydantic.BaseModel causes metaclass conflicts
|
||||||
|
__abstract__ = True
|
||||||
|
if TYPE_CHECKING: # pragma no cover
|
||||||
|
__model_fields__: Dict[str, TypeVar[BaseField]]
|
||||||
|
__table__: sqlalchemy.Table
|
||||||
|
__fields__: Dict[str, pydantic.fields.ModelField]
|
||||||
|
__pydantic_model__: Type[BaseModel]
|
||||||
|
__pkname__: str
|
||||||
|
__tablename__: str
|
||||||
|
__metadata__: sqlalchemy.MetaData
|
||||||
|
__database__: databases.Database
|
||||||
|
_orm_relationship_manager: RelationshipManager
|
||||||
|
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._orm_id: str = uuid.uuid4().hex
|
||||||
|
self._orm_saved: bool = False
|
||||||
|
self.values: Optional[BaseModel] = None
|
||||||
|
|
||||||
|
if "pk" in kwargs:
|
||||||
|
kwargs[self.__pkname__] = kwargs.pop("pk")
|
||||||
|
kwargs = {
|
||||||
|
k: self.__model_fields__[k].expand_relationship(v, self)
|
||||||
|
for k, v in kwargs.items()
|
||||||
|
}
|
||||||
|
self.values = self.__pydantic_model__(**kwargs)
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
self._orm_relationship_manager.deregister(self)
|
||||||
|
|
||||||
|
def __setattr__(self, key: str, value: Any) -> None:
|
||||||
|
if key in self.__fields__:
|
||||||
|
if self._is_conversion_to_json_needed(key) and not isinstance(value, str):
|
||||||
|
try:
|
||||||
|
value = json.dumps(value)
|
||||||
|
except TypeError: # pragma no cover
|
||||||
|
pass
|
||||||
|
|
||||||
|
value = self.__model_fields__[key].expand_relationship(value, self)
|
||||||
|
|
||||||
|
relation_key = self.__class__.__name__.title() + "_" + key
|
||||||
|
if not self._orm_relationship_manager.contains(relation_key, self):
|
||||||
|
setattr(self.values, key, value)
|
||||||
|
else:
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattribute__(self, key: str) -> Any:
|
||||||
|
if key != "__fields__" and key in self.__fields__:
|
||||||
|
relation_key = self.__class__.__name__.title() + "_" + key
|
||||||
|
if self._orm_relationship_manager.contains(relation_key, self):
|
||||||
|
return self._orm_relationship_manager.get(relation_key, self)
|
||||||
|
|
||||||
|
item = getattr(self.values, key, None)
|
||||||
|
if (
|
||||||
|
item is not None
|
||||||
|
and self._is_conversion_to_json_needed(key)
|
||||||
|
and isinstance(item, str)
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
item = json.loads(item)
|
||||||
|
except TypeError: # pragma no cover
|
||||||
|
pass
|
||||||
|
return item
|
||||||
|
return super().__getattribute__(key)
|
||||||
|
|
||||||
|
def __eq__(self, other: "Model") -> bool:
|
||||||
|
return self.values.dict() == other.values.dict()
|
||||||
|
|
||||||
|
def __same__(self, other: "Model") -> bool:
|
||||||
|
if self.__class__ != other.__class__: # pragma no cover
|
||||||
|
return False
|
||||||
|
return self._orm_id == other._orm_id or (
|
||||||
|
self.values is not None and other.values is not None and self.pk == other.pk
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str: # pragma no cover
|
||||||
|
return self.values.__repr__()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls) -> Callable: # pragma no cover
|
||||||
|
yield cls.__pydantic_model__.validate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_name(cls, title: bool = False, lower: bool = True) -> str:
|
||||||
|
name = cls.__name__
|
||||||
|
if lower:
|
||||||
|
name = name.lower()
|
||||||
|
if title:
|
||||||
|
name = name.title()
|
||||||
|
return name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pk_column(self) -> sqlalchemy.Column:
|
||||||
|
return self.__table__.primary_key.columns.values()[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pk_type(cls) -> Any:
|
||||||
|
return cls.__model_fields__[cls.__pkname__].__type__
|
||||||
|
|
||||||
|
def dict(self) -> Dict: # noqa: A003
|
||||||
|
dict_instance = self.values.dict()
|
||||||
|
for field in self._extract_related_names():
|
||||||
|
nested_model = getattr(self, field)
|
||||||
|
if isinstance(nested_model, list):
|
||||||
|
dict_instance[field] = [x.dict() for x in nested_model]
|
||||||
|
else:
|
||||||
|
dict_instance[field] = (
|
||||||
|
nested_model.dict() if nested_model is not None else {}
|
||||||
|
)
|
||||||
|
return dict_instance
|
||||||
|
|
||||||
|
def from_dict(self, value_dict: Dict) -> None:
|
||||||
|
for key, value in value_dict.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def _is_conversion_to_json_needed(self, column_name: str) -> bool:
|
||||||
|
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
|
||||||
|
|
||||||
|
def _extract_own_model_fields(self) -> Dict:
|
||||||
|
related_names = self._extract_related_names()
|
||||||
|
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
|
||||||
|
return self_fields
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_related_names(cls) -> Set:
|
||||||
|
related_names = set()
|
||||||
|
for name, field in cls.__fields__.items():
|
||||||
|
if inspect.isclass(field.type_) and issubclass(
|
||||||
|
field.type_, pydantic.BaseModel
|
||||||
|
):
|
||||||
|
related_names.add(name)
|
||||||
|
return related_names
|
||||||
|
|
||||||
|
def _extract_model_db_fields(self) -> Dict:
|
||||||
|
self_fields = self._extract_own_model_fields()
|
||||||
|
self_fields = {
|
||||||
|
k: v for k, v in self_fields.items() if k in self.__table__.columns
|
||||||
|
}
|
||||||
|
for field in self._extract_related_names():
|
||||||
|
if getattr(self, field) is not None:
|
||||||
|
self_fields[field] = getattr(
|
||||||
|
getattr(self, field), self.__model_fields__[field].to.__pkname__
|
||||||
|
)
|
||||||
|
return self_fields
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
|
||||||
|
merged_rows = []
|
||||||
|
for index, model in enumerate(result_rows):
|
||||||
|
if index > 0 and model.pk == result_rows[index - 1].pk:
|
||||||
|
result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
|
||||||
|
else:
|
||||||
|
merged_rows.append(model)
|
||||||
|
return merged_rows
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
|
||||||
|
for field in one.__model_fields__.keys():
|
||||||
|
# print(field, one.dict(), other.dict())
|
||||||
|
if isinstance(getattr(one, field), list) and not isinstance(
|
||||||
|
getattr(one, field), orm.Model
|
||||||
|
):
|
||||||
|
setattr(other, field, getattr(one, field) + getattr(other, field))
|
||||||
|
elif isinstance(getattr(one, field), orm.Model):
|
||||||
|
if getattr(one, field).pk == getattr(other, field).pk:
|
||||||
|
setattr(
|
||||||
|
other,
|
||||||
|
field,
|
||||||
|
cls.merge_two_instances(
|
||||||
|
getattr(one, field), getattr(other, field)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return other
|
||||||
132
orm/models/metaclass.py
Normal file
132
orm/models/metaclass.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
import copy
|
||||||
|
from typing import Dict, Tuple, Type, Optional, List, Any
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from pydantic import BaseConfig, create_model
|
||||||
|
from pydantic.fields import ModelField
|
||||||
|
|
||||||
|
from orm import ForeignKey, ModelDefinitionError
|
||||||
|
from orm.fields import BaseField
|
||||||
|
from orm.relations import RelationshipManager
|
||||||
|
|
||||||
|
relationship_manager = RelationshipManager()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
|
||||||
|
pydantic_fields = {
|
||||||
|
field_name: (
|
||||||
|
base_field.__type__,
|
||||||
|
... if base_field.is_required else base_field.default_value,
|
||||||
|
)
|
||||||
|
for field_name, base_field in object_dict.items()
|
||||||
|
if isinstance(base_field, BaseField)
|
||||||
|
}
|
||||||
|
return pydantic_fields
|
||||||
|
|
||||||
|
|
||||||
|
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
|
||||||
|
child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s"
|
||||||
|
reverse_name = field.related_name or child_relation_name
|
||||||
|
relation_name = name.lower().title() + "_" + field.to.get_name()
|
||||||
|
relationship_manager.add_relation_type(
|
||||||
|
relation_name, reverse_name, field, table_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_reverse_relationships(model: Type["Model"]) -> None:
|
||||||
|
for model_field in model.__model_fields__.values():
|
||||||
|
if isinstance(model_field, ForeignKey):
|
||||||
|
child_model_name = model_field.related_name or model.get_name() + "s"
|
||||||
|
parent_model = model_field.to
|
||||||
|
child = model
|
||||||
|
if (
|
||||||
|
child_model_name not in parent_model.__fields__
|
||||||
|
and child.get_name() not in parent_model.__fields__
|
||||||
|
):
|
||||||
|
register_reverse_model_fields(parent_model, child, child_model_name)
|
||||||
|
|
||||||
|
|
||||||
|
def register_reverse_model_fields(
|
||||||
|
model: Type["Model"], child: Type["Model"], child_model_name: str
|
||||||
|
) -> None:
|
||||||
|
model.__fields__[child_model_name] = ModelField(
|
||||||
|
name=child_model_name,
|
||||||
|
type_=Optional[child.__pydantic_model__],
|
||||||
|
model_config=child.__pydantic_model__.__config__,
|
||||||
|
class_validators=child.__pydantic_model__.__validators__,
|
||||||
|
)
|
||||||
|
model.__model_fields__[child_model_name] = ForeignKey(
|
||||||
|
child, name=child_model_name, virtual=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sqlalchemy_columns_from_model_fields(
|
||||||
|
name: str, object_dict: Dict, table_name: str
|
||||||
|
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
|
||||||
|
pkname: Optional[str] = None
|
||||||
|
columns: List[sqlalchemy.Column] = []
|
||||||
|
model_fields: Dict[str, BaseField] = {}
|
||||||
|
|
||||||
|
for field_name, field in object_dict.items():
|
||||||
|
if isinstance(field, BaseField):
|
||||||
|
model_fields[field_name] = field
|
||||||
|
if not field.pydantic_only:
|
||||||
|
if field.primary_key:
|
||||||
|
pkname = field_name
|
||||||
|
if isinstance(field, ForeignKey):
|
||||||
|
register_relation_on_build(table_name, field, name)
|
||||||
|
columns.append(field.get_column(field_name))
|
||||||
|
return pkname, columns, model_fields
|
||||||
|
|
||||||
|
|
||||||
|
def get_pydantic_base_orm_config() -> Type[BaseConfig]:
|
||||||
|
class Config(BaseConfig):
|
||||||
|
orm_mode = True
|
||||||
|
|
||||||
|
return Config
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMetaclass(type):
|
||||||
|
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
||||||
|
new_model = super().__new__( # type: ignore
|
||||||
|
mcs, name, bases, attrs
|
||||||
|
)
|
||||||
|
|
||||||
|
if attrs.get("__abstract__"):
|
||||||
|
return new_model
|
||||||
|
|
||||||
|
tablename = attrs["__tablename__"]
|
||||||
|
metadata = attrs["__metadata__"]
|
||||||
|
|
||||||
|
# sqlalchemy table creation
|
||||||
|
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
|
||||||
|
name, attrs, tablename
|
||||||
|
)
|
||||||
|
attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns)
|
||||||
|
attrs["__columns__"] = columns
|
||||||
|
attrs["__pkname__"] = pkname
|
||||||
|
|
||||||
|
if not pkname:
|
||||||
|
raise ModelDefinitionError("Table has to have a primary key.")
|
||||||
|
|
||||||
|
# pydantic model creation
|
||||||
|
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
|
||||||
|
pydantic_model = create_model(
|
||||||
|
name, __config__=get_pydantic_base_orm_config(), **pydantic_fields
|
||||||
|
)
|
||||||
|
attrs["__pydantic_fields__"] = pydantic_fields
|
||||||
|
attrs["__pydantic_model__"] = pydantic_model
|
||||||
|
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
|
||||||
|
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
|
||||||
|
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
|
||||||
|
|
||||||
|
attrs["__model_fields__"] = model_fields
|
||||||
|
attrs["_orm_relationship_manager"] = relationship_manager
|
||||||
|
|
||||||
|
new_model = super().__new__( # type: ignore
|
||||||
|
mcs, name, bases, attrs
|
||||||
|
)
|
||||||
|
|
||||||
|
expand_reverse_relationships(new_model)
|
||||||
|
|
||||||
|
return new_model
|
||||||
93
orm/models/model.py
Normal file
93
orm/models/model.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
import orm.queryset.queryset
|
||||||
|
from orm.models.fakepydantic import FakePydantic
|
||||||
|
|
||||||
|
|
||||||
|
class Model(FakePydantic):
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
objects = orm.queryset.queryset.QuerySet()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_row(
|
||||||
|
cls,
|
||||||
|
row: sqlalchemy.engine.ResultProxy,
|
||||||
|
select_related: List = None,
|
||||||
|
previous_table: str = None,
|
||||||
|
) -> "Model":
|
||||||
|
|
||||||
|
item = {}
|
||||||
|
select_related = select_related or []
|
||||||
|
|
||||||
|
table_prefix = cls._orm_relationship_manager.resolve_relation_join(
|
||||||
|
previous_table, cls.__table__.name
|
||||||
|
)
|
||||||
|
previous_table = cls.__table__.name
|
||||||
|
for related in select_related:
|
||||||
|
if "__" in related:
|
||||||
|
first_part, remainder = related.split("__", 1)
|
||||||
|
model_cls = cls.__model_fields__[first_part].to
|
||||||
|
child = model_cls.from_row(
|
||||||
|
row, select_related=[remainder], previous_table=previous_table
|
||||||
|
)
|
||||||
|
item[first_part] = child
|
||||||
|
else:
|
||||||
|
model_cls = cls.__model_fields__[related].to
|
||||||
|
child = model_cls.from_row(row, previous_table=previous_table)
|
||||||
|
item[related] = child
|
||||||
|
|
||||||
|
for column in cls.__table__.columns:
|
||||||
|
if column.name not in item:
|
||||||
|
item[column.name] = row[
|
||||||
|
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
|
||||||
|
]
|
||||||
|
|
||||||
|
return cls(**item)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pk(self) -> str:
|
||||||
|
return getattr(self.values, self.__pkname__)
|
||||||
|
|
||||||
|
@pk.setter
|
||||||
|
def pk(self, value: Any) -> None:
|
||||||
|
setattr(self.values, self.__pkname__, value)
|
||||||
|
|
||||||
|
async def save(self) -> int:
|
||||||
|
self_fields = self._extract_model_db_fields()
|
||||||
|
if self.__model_fields__.get(self.__pkname__).autoincrement:
|
||||||
|
self_fields.pop(self.__pkname__, None)
|
||||||
|
expr = self.__table__.insert()
|
||||||
|
expr = expr.values(**self_fields)
|
||||||
|
item_id = await self.__database__.execute(expr)
|
||||||
|
self.pk = item_id
|
||||||
|
return item_id
|
||||||
|
|
||||||
|
async def update(self, **kwargs: Any) -> int:
|
||||||
|
if kwargs:
|
||||||
|
new_values = {**self.dict(), **kwargs}
|
||||||
|
self.from_dict(new_values)
|
||||||
|
|
||||||
|
self_fields = self._extract_model_db_fields()
|
||||||
|
self_fields.pop(self.__pkname__)
|
||||||
|
expr = (
|
||||||
|
self.__table__.update()
|
||||||
|
.values(**self_fields)
|
||||||
|
.where(self.pk_column == getattr(self, self.__pkname__))
|
||||||
|
)
|
||||||
|
result = await self.__database__.execute(expr)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def delete(self) -> int:
|
||||||
|
expr = self.__table__.delete()
|
||||||
|
expr = expr.where(self.pk_column == (getattr(self, self.__pkname__)))
|
||||||
|
result = await self.__database__.execute(expr)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def load(self) -> "Model":
|
||||||
|
expr = self.__table__.select().where(self.pk_column == self.pk)
|
||||||
|
row = await self.__database__.fetch_one(expr)
|
||||||
|
self.from_dict(dict(row))
|
||||||
|
return self
|
||||||
571
orm/queryset.py
571
orm/queryset.py
@ -1,571 +0,0 @@
|
|||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import databases
|
|
||||||
import sqlalchemy
|
|
||||||
from sqlalchemy import text
|
|
||||||
|
|
||||||
import orm # noqa I100
|
|
||||||
import orm.fields.foreign_key
|
|
||||||
from orm import ForeignKey
|
|
||||||
from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError
|
|
||||||
from orm.fields.base import BaseField
|
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
|
||||||
from orm.models import Model
|
|
||||||
|
|
||||||
FILTER_OPERATORS = {
|
|
||||||
"exact": "__eq__",
|
|
||||||
"iexact": "ilike",
|
|
||||||
"contains": "like",
|
|
||||||
"icontains": "ilike",
|
|
||||||
"in": "in_",
|
|
||||||
"gt": "__gt__",
|
|
||||||
"gte": "__ge__",
|
|
||||||
"lt": "__lt__",
|
|
||||||
"lte": "__le__",
|
|
||||||
}
|
|
||||||
|
|
||||||
ESCAPE_CHARACTERS = ["%", "_"]
|
|
||||||
|
|
||||||
|
|
||||||
class JoinParameters(NamedTuple):
|
|
||||||
prev_model: Type["Model"]
|
|
||||||
previous_alias: str
|
|
||||||
from_table: str
|
|
||||||
model_cls: Type["Model"]
|
|
||||||
|
|
||||||
|
|
||||||
class Query:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_cls: Type["Model"],
|
|
||||||
filter_clauses: List,
|
|
||||||
select_related: List,
|
|
||||||
limit_count: int,
|
|
||||||
offset: int,
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
self.query_offset = offset
|
|
||||||
self.limit_count = limit_count
|
|
||||||
self._select_related = select_related
|
|
||||||
self.filter_clauses = filter_clauses
|
|
||||||
|
|
||||||
self.model_cls = model_cls
|
|
||||||
self.table = self.model_cls.__table__
|
|
||||||
|
|
||||||
self.auto_related = []
|
|
||||||
self.used_aliases = []
|
|
||||||
|
|
||||||
self.select_from = None
|
|
||||||
self.columns = None
|
|
||||||
self.order_bys = None
|
|
||||||
|
|
||||||
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
|
|
||||||
self.columns = list(self.table.columns)
|
|
||||||
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
|
|
||||||
self.select_from = self.table
|
|
||||||
|
|
||||||
for key in self.model_cls.__model_fields__:
|
|
||||||
if (
|
|
||||||
not self.model_cls.__model_fields__[key].nullable
|
|
||||||
and isinstance(
|
|
||||||
self.model_cls.__model_fields__[key],
|
|
||||||
orm.fields.foreign_key.ForeignKey,
|
|
||||||
)
|
|
||||||
and key not in self._select_related
|
|
||||||
):
|
|
||||||
self._select_related = [key] + self._select_related
|
|
||||||
|
|
||||||
start_params = JoinParameters(
|
|
||||||
self.model_cls, "", self.table.name, self.model_cls
|
|
||||||
)
|
|
||||||
self._extract_auto_required_relations(prev_model=start_params.prev_model)
|
|
||||||
self._include_auto_related_models()
|
|
||||||
self._select_related.sort(key=lambda item: (-len(item), item))
|
|
||||||
|
|
||||||
for item in self._select_related:
|
|
||||||
join_parameters = JoinParameters(
|
|
||||||
self.model_cls, "", self.table.name, self.model_cls
|
|
||||||
)
|
|
||||||
|
|
||||||
for part in item.split("__"):
|
|
||||||
join_parameters = self._build_join_parameters(part, join_parameters)
|
|
||||||
|
|
||||||
expr = sqlalchemy.sql.select(self.columns)
|
|
||||||
expr = expr.select_from(self.select_from)
|
|
||||||
|
|
||||||
expr = self._apply_expression_modifiers(expr)
|
|
||||||
|
|
||||||
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
|
||||||
self._reset_query_parameters()
|
|
||||||
|
|
||||||
return expr, self._select_related
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
|
||||||
return [
|
|
||||||
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
|
|
||||||
for column in table.columns
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def prefixed_table_name(alias: str, name: str) -> text:
|
|
||||||
return text(f"{name} {alias}_{name}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _field_is_a_foreign_key_and_no_circular_reference(
|
|
||||||
field: BaseField, field_name: str, rel_part: str
|
|
||||||
) -> bool:
|
|
||||||
return isinstance(field, ForeignKey) and field_name not in rel_part
|
|
||||||
|
|
||||||
def _field_qualifies_to_deeper_search(
|
|
||||||
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
|
|
||||||
) -> bool:
|
|
||||||
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
|
|
||||||
partial_match = any(
|
|
||||||
[x.startswith(prev_part_of_related) for x in self._select_related]
|
|
||||||
)
|
|
||||||
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
|
|
||||||
return (
|
|
||||||
(field.virtual and parent_virtual)
|
|
||||||
or (partial_match and not already_checked)
|
|
||||||
) or not nested
|
|
||||||
|
|
||||||
def on_clause(
|
|
||||||
self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
|
|
||||||
) -> text:
|
|
||||||
left_part = f"{alias}_{to_clause}"
|
|
||||||
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
|
||||||
return text(f"{left_part}={right_part}")
|
|
||||||
|
|
||||||
def _build_join_parameters(
|
|
||||||
self, part: str, join_params: JoinParameters
|
|
||||||
) -> JoinParameters:
|
|
||||||
model_cls = join_params.model_cls.__model_fields__[part].to
|
|
||||||
to_table = model_cls.__table__.name
|
|
||||||
|
|
||||||
alias = model_cls._orm_relationship_manager.resolve_relation_join(
|
|
||||||
join_params.from_table, to_table
|
|
||||||
)
|
|
||||||
if alias not in self.used_aliases:
|
|
||||||
if join_params.prev_model.__model_fields__[part].virtual:
|
|
||||||
to_key = next(
|
|
||||||
(
|
|
||||||
v
|
|
||||||
for k, v in model_cls.__model_fields__.items()
|
|
||||||
if isinstance(v, ForeignKey) and v.to == join_params.prev_model
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
).name
|
|
||||||
from_key = model_cls.__pkname__
|
|
||||||
else:
|
|
||||||
to_key = model_cls.__pkname__
|
|
||||||
from_key = part
|
|
||||||
|
|
||||||
on_clause = self.on_clause(
|
|
||||||
previous_alias=join_params.previous_alias,
|
|
||||||
alias=alias,
|
|
||||||
from_clause=f"{join_params.from_table}.{from_key}",
|
|
||||||
to_clause=f"{to_table}.{to_key}",
|
|
||||||
)
|
|
||||||
target_table = self.prefixed_table_name(alias, to_table)
|
|
||||||
self.select_from = sqlalchemy.sql.outerjoin(
|
|
||||||
self.select_from, target_table, on_clause
|
|
||||||
)
|
|
||||||
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}"))
|
|
||||||
self.columns.extend(self.prefixed_columns(alias, model_cls.__table__))
|
|
||||||
self.used_aliases.append(alias)
|
|
||||||
|
|
||||||
previous_alias = alias
|
|
||||||
from_table = to_table
|
|
||||||
prev_model = model_cls
|
|
||||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
|
||||||
|
|
||||||
def _extract_auto_required_relations(
|
|
||||||
self,
|
|
||||||
prev_model: Type["Model"],
|
|
||||||
rel_part: str = "",
|
|
||||||
nested: bool = False,
|
|
||||||
parent_virtual: bool = False,
|
|
||||||
) -> None:
|
|
||||||
for field_name, field in prev_model.__model_fields__.items():
|
|
||||||
if self._field_is_a_foreign_key_and_no_circular_reference(
|
|
||||||
field, field_name, rel_part
|
|
||||||
):
|
|
||||||
rel_part = field_name if not rel_part else rel_part + "__" + field_name
|
|
||||||
if not field.nullable:
|
|
||||||
if rel_part not in self._select_related:
|
|
||||||
self.auto_related.append("__".join(rel_part.split("__")[:-1]))
|
|
||||||
rel_part = ""
|
|
||||||
elif self._field_qualifies_to_deeper_search(
|
|
||||||
field, parent_virtual, nested, rel_part
|
|
||||||
):
|
|
||||||
self._extract_auto_required_relations(
|
|
||||||
prev_model=field.to,
|
|
||||||
rel_part=rel_part,
|
|
||||||
nested=True,
|
|
||||||
parent_virtual=field.virtual,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
rel_part = ""
|
|
||||||
|
|
||||||
def _include_auto_related_models(self) -> None:
|
|
||||||
if self.auto_related:
|
|
||||||
new_joins = []
|
|
||||||
for join in self._select_related:
|
|
||||||
if not any([x.startswith(join) for x in self.auto_related]):
|
|
||||||
new_joins.append(join)
|
|
||||||
self._select_related = new_joins + self.auto_related
|
|
||||||
|
|
||||||
def _apply_expression_modifiers(
|
|
||||||
self, expr: sqlalchemy.sql.select
|
|
||||||
) -> sqlalchemy.sql.select:
|
|
||||||
if self.filter_clauses:
|
|
||||||
if len(self.filter_clauses) == 1:
|
|
||||||
clause = self.filter_clauses[0]
|
|
||||||
else:
|
|
||||||
clause = sqlalchemy.sql.and_(*self.filter_clauses)
|
|
||||||
expr = expr.where(clause)
|
|
||||||
|
|
||||||
if self.limit_count:
|
|
||||||
expr = expr.limit(self.limit_count)
|
|
||||||
|
|
||||||
if self.query_offset:
|
|
||||||
expr = expr.offset(self.query_offset)
|
|
||||||
|
|
||||||
for order in self.order_bys:
|
|
||||||
expr = expr.order_by(order)
|
|
||||||
return expr
|
|
||||||
|
|
||||||
def _reset_query_parameters(self) -> None:
|
|
||||||
self.select_from = None
|
|
||||||
self.columns = None
|
|
||||||
self.order_bys = None
|
|
||||||
self.auto_related = []
|
|
||||||
self.used_aliases = []
|
|
||||||
|
|
||||||
|
|
||||||
class QueryClause:
|
|
||||||
def __init__(
|
|
||||||
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
self._select_related = select_related
|
|
||||||
self.filter_clauses = filter_clauses
|
|
||||||
|
|
||||||
self.model_cls = model_cls
|
|
||||||
self.table = self.model_cls.__table__
|
|
||||||
|
|
||||||
def filter( # noqa: A003
|
|
||||||
self, **kwargs: Any
|
|
||||||
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
|
||||||
filter_clauses = self.filter_clauses
|
|
||||||
select_related = list(self._select_related)
|
|
||||||
|
|
||||||
if kwargs.get("pk"):
|
|
||||||
pk_name = self.model_cls.__pkname__
|
|
||||||
kwargs[pk_name] = kwargs.pop("pk")
|
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
table_prefix = ""
|
|
||||||
if "__" in key:
|
|
||||||
parts = key.split("__")
|
|
||||||
|
|
||||||
(
|
|
||||||
op,
|
|
||||||
field_name,
|
|
||||||
related_parts,
|
|
||||||
) = self._extract_operator_field_and_related(parts)
|
|
||||||
|
|
||||||
model_cls = self.model_cls
|
|
||||||
if related_parts:
|
|
||||||
(
|
|
||||||
select_related,
|
|
||||||
table_prefix,
|
|
||||||
model_cls,
|
|
||||||
) = self._determine_filter_target_table(
|
|
||||||
related_parts, select_related
|
|
||||||
)
|
|
||||||
|
|
||||||
table = model_cls.__table__
|
|
||||||
column = model_cls.__table__.columns[field_name]
|
|
||||||
|
|
||||||
else:
|
|
||||||
op = "exact"
|
|
||||||
column = self.table.columns[key]
|
|
||||||
table = self.table
|
|
||||||
|
|
||||||
value, has_escaped_character = self._escape_characters_in_clause(op, value)
|
|
||||||
|
|
||||||
if isinstance(value, orm.Model):
|
|
||||||
value = value.pk
|
|
||||||
|
|
||||||
op_attr = FILTER_OPERATORS[op]
|
|
||||||
clause = getattr(column, op_attr)(value)
|
|
||||||
clause = self._compile_clause(
|
|
||||||
clause,
|
|
||||||
column,
|
|
||||||
table,
|
|
||||||
table_prefix,
|
|
||||||
modifiers={"escape": "\\" if has_escaped_character else None},
|
|
||||||
)
|
|
||||||
filter_clauses.append(clause)
|
|
||||||
|
|
||||||
return filter_clauses, select_related
|
|
||||||
|
|
||||||
def _determine_filter_target_table(
|
|
||||||
self, related_parts: List[str], select_related: List[str]
|
|
||||||
) -> Tuple[List[str], str, "Model"]:
|
|
||||||
|
|
||||||
table_prefix = ""
|
|
||||||
model_cls = self.model_cls
|
|
||||||
select_related = [relation for relation in select_related]
|
|
||||||
|
|
||||||
# Add any implied select_related
|
|
||||||
related_str = "__".join(related_parts)
|
|
||||||
if related_str not in select_related:
|
|
||||||
select_related.append(related_str)
|
|
||||||
|
|
||||||
# Walk the relationships to the actual model class
|
|
||||||
# against which the comparison is being made.
|
|
||||||
previous_table = model_cls.__tablename__
|
|
||||||
for part in related_parts:
|
|
||||||
current_table = model_cls.__model_fields__[part].to.__tablename__
|
|
||||||
manager = model_cls._orm_relationship_manager
|
|
||||||
table_prefix = manager.resolve_relation_join(previous_table, current_table)
|
|
||||||
model_cls = model_cls.__model_fields__[part].to
|
|
||||||
previous_table = current_table
|
|
||||||
return select_related, table_prefix, model_cls
|
|
||||||
|
|
||||||
def _compile_clause(
|
|
||||||
self,
|
|
||||||
clause: sqlalchemy.sql.expression.BinaryExpression,
|
|
||||||
column: sqlalchemy.Column,
|
|
||||||
table: sqlalchemy.Table,
|
|
||||||
table_prefix: str,
|
|
||||||
modifiers: Dict,
|
|
||||||
) -> sqlalchemy.sql.expression.TextClause:
|
|
||||||
for modifier, modifier_value in modifiers.items():
|
|
||||||
clause.modifiers[modifier] = modifier_value
|
|
||||||
|
|
||||||
clause_text = str(
|
|
||||||
clause.compile(
|
|
||||||
dialect=self.model_cls.__database__._backend._dialect,
|
|
||||||
compile_kwargs={"literal_binds": True},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
alias = f"{table_prefix}_" if table_prefix else ""
|
|
||||||
aliased_name = f"{alias}{table.name}.{column.name}"
|
|
||||||
clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name)
|
|
||||||
clause = text(clause_text)
|
|
||||||
return clause
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _escape_characters_in_clause(
|
|
||||||
op: str, value: Union[str, "Model"]
|
|
||||||
) -> Tuple[str, bool]:
|
|
||||||
has_escaped_character = False
|
|
||||||
|
|
||||||
if op in ["contains", "icontains"]:
|
|
||||||
if isinstance(value, orm.Model):
|
|
||||||
raise QueryDefinitionError(
|
|
||||||
"You cannot use contains and icontains with instance of the Model"
|
|
||||||
)
|
|
||||||
|
|
||||||
has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value)
|
|
||||||
|
|
||||||
if has_escaped_character:
|
|
||||||
# enable escape modifier
|
|
||||||
for char in ESCAPE_CHARACTERS:
|
|
||||||
value = value.replace(char, f"\\{char}")
|
|
||||||
value = f"%{value}%"
|
|
||||||
|
|
||||||
return value, has_escaped_character
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_operator_field_and_related(
|
|
||||||
parts: List[str],
|
|
||||||
) -> Tuple[str, str, Optional[List]]:
|
|
||||||
if parts[-1] in FILTER_OPERATORS:
|
|
||||||
op = parts[-1]
|
|
||||||
field_name = parts[-2]
|
|
||||||
related_parts = parts[:-2]
|
|
||||||
else:
|
|
||||||
op = "exact"
|
|
||||||
field_name = parts[-1]
|
|
||||||
related_parts = parts[:-1]
|
|
||||||
|
|
||||||
return op, field_name, related_parts
|
|
||||||
|
|
||||||
|
|
||||||
class QuerySet:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_cls: Type["Model"] = None,
|
|
||||||
filter_clauses: List = None,
|
|
||||||
select_related: List = None,
|
|
||||||
limit_count: int = None,
|
|
||||||
offset: int = None,
|
|
||||||
) -> None:
|
|
||||||
self.model_cls = model_cls
|
|
||||||
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
|
||||||
self._select_related = [] if select_related is None else select_related
|
|
||||||
self.limit_count = limit_count
|
|
||||||
self.query_offset = offset
|
|
||||||
self.order_bys = None
|
|
||||||
|
|
||||||
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
|
|
||||||
return self.__class__(model_cls=owner)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def database(self) -> databases.Database:
|
|
||||||
return self.model_cls.__database__
|
|
||||||
|
|
||||||
@property
|
|
||||||
def table(self) -> sqlalchemy.Table:
|
|
||||||
return self.model_cls.__table__
|
|
||||||
|
|
||||||
def build_select_expression(self) -> sqlalchemy.sql.select:
|
|
||||||
qry = Query(
|
|
||||||
model_cls=self.model_cls,
|
|
||||||
select_related=self._select_related,
|
|
||||||
filter_clauses=self.filter_clauses,
|
|
||||||
offset=self.query_offset,
|
|
||||||
limit_count=self.limit_count,
|
|
||||||
)
|
|
||||||
exp, self._select_related = qry.build_select_expression()
|
|
||||||
return exp
|
|
||||||
|
|
||||||
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
|
||||||
qryclause = QueryClause(
|
|
||||||
model_cls=self.model_cls,
|
|
||||||
select_related=self._select_related,
|
|
||||||
filter_clauses=self.filter_clauses,
|
|
||||||
)
|
|
||||||
filter_clauses, select_related = qryclause.filter(**kwargs)
|
|
||||||
|
|
||||||
return self.__class__(
|
|
||||||
model_cls=self.model_cls,
|
|
||||||
filter_clauses=filter_clauses,
|
|
||||||
select_related=select_related,
|
|
||||||
limit_count=self.limit_count,
|
|
||||||
offset=self.query_offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
|
|
||||||
if not isinstance(related, (list, tuple)):
|
|
||||||
related = [related]
|
|
||||||
|
|
||||||
related = list(self._select_related) + related
|
|
||||||
return self.__class__(
|
|
||||||
model_cls=self.model_cls,
|
|
||||||
filter_clauses=self.filter_clauses,
|
|
||||||
select_related=related,
|
|
||||||
limit_count=self.limit_count,
|
|
||||||
offset=self.query_offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def exists(self) -> bool:
|
|
||||||
expr = self.build_select_expression()
|
|
||||||
expr = sqlalchemy.exists(expr).select()
|
|
||||||
return await self.database.fetch_val(expr)
|
|
||||||
|
|
||||||
async def count(self) -> int:
|
|
||||||
expr = self.build_select_expression().alias("subquery_for_count")
|
|
||||||
expr = sqlalchemy.func.count().select().select_from(expr)
|
|
||||||
return await self.database.fetch_val(expr)
|
|
||||||
|
|
||||||
def limit(self, limit_count: int) -> "QuerySet":
|
|
||||||
return self.__class__(
|
|
||||||
model_cls=self.model_cls,
|
|
||||||
filter_clauses=self.filter_clauses,
|
|
||||||
select_related=self._select_related,
|
|
||||||
limit_count=limit_count,
|
|
||||||
offset=self.query_offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
def offset(self, offset: int) -> "QuerySet":
|
|
||||||
return self.__class__(
|
|
||||||
model_cls=self.model_cls,
|
|
||||||
filter_clauses=self.filter_clauses,
|
|
||||||
select_related=self._select_related,
|
|
||||||
limit_count=self.limit_count,
|
|
||||||
offset=offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def first(self, **kwargs: Any) -> "Model":
|
|
||||||
if kwargs:
|
|
||||||
return await self.filter(**kwargs).first()
|
|
||||||
|
|
||||||
rows = await self.limit(1).all()
|
|
||||||
if rows:
|
|
||||||
return rows[0]
|
|
||||||
|
|
||||||
async def get(self, **kwargs: Any) -> "Model":
|
|
||||||
if kwargs:
|
|
||||||
return await self.filter(**kwargs).get()
|
|
||||||
|
|
||||||
expr = self.build_select_expression().limit(2)
|
|
||||||
rows = await self.database.fetch_all(expr)
|
|
||||||
|
|
||||||
if not rows:
|
|
||||||
raise NoMatch()
|
|
||||||
if len(rows) > 1:
|
|
||||||
raise MultipleMatches()
|
|
||||||
return self.model_cls.from_row(rows[0], select_related=self._select_related)
|
|
||||||
|
|
||||||
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
|
|
||||||
if kwargs:
|
|
||||||
return await self.filter(**kwargs).all()
|
|
||||||
|
|
||||||
expr = self.build_select_expression()
|
|
||||||
rows = await self.database.fetch_all(expr)
|
|
||||||
result_rows = [
|
|
||||||
self.model_cls.from_row(row, select_related=self._select_related)
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
result_rows = self.model_cls.merge_instances_list(result_rows)
|
|
||||||
|
|
||||||
return result_rows
|
|
||||||
|
|
||||||
async def create(self, **kwargs: Any) -> "Model":
|
|
||||||
|
|
||||||
new_kwargs = dict(**kwargs)
|
|
||||||
|
|
||||||
# Remove primary key when None to prevent not null constraint in postgresql.
|
|
||||||
pkname = self.model_cls.__pkname__
|
|
||||||
pk = self.model_cls.__model_fields__[pkname]
|
|
||||||
if (
|
|
||||||
pkname in new_kwargs
|
|
||||||
and new_kwargs.get(pkname) is None
|
|
||||||
and (pk.nullable or pk.autoincrement)
|
|
||||||
):
|
|
||||||
del new_kwargs[pkname]
|
|
||||||
|
|
||||||
# substitute related models with their pk
|
|
||||||
for field in self.model_cls._extract_related_names():
|
|
||||||
if field in new_kwargs and new_kwargs.get(field) is not None:
|
|
||||||
new_kwargs[field] = getattr(
|
|
||||||
new_kwargs.get(field),
|
|
||||||
self.model_cls.__model_fields__[field].to.__pkname__,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build the insert expression.
|
|
||||||
expr = self.table.insert()
|
|
||||||
expr = expr.values(**new_kwargs)
|
|
||||||
|
|
||||||
# Execute the insert, and return a new model instance.
|
|
||||||
instance = self.model_cls(**kwargs)
|
|
||||||
instance.pk = await self.database.execute(expr)
|
|
||||||
return instance
|
|
||||||
0
orm/queryset/__init__.py
Normal file
0
orm/queryset/__init__.py
Normal file
176
orm/queryset/clause.py
Normal file
176
orm/queryset/clause.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
from typing import Type, List, Any, Tuple, Dict, Union, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
import orm
|
||||||
|
from orm.exceptions import QueryDefinitionError
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma no cover
|
||||||
|
from orm import Model
|
||||||
|
|
||||||
|
FILTER_OPERATORS = {
|
||||||
|
"exact": "__eq__",
|
||||||
|
"iexact": "ilike",
|
||||||
|
"contains": "like",
|
||||||
|
"icontains": "ilike",
|
||||||
|
"in": "in_",
|
||||||
|
"gt": "__gt__",
|
||||||
|
"gte": "__ge__",
|
||||||
|
"lt": "__lt__",
|
||||||
|
"lte": "__le__",
|
||||||
|
}
|
||||||
|
ESCAPE_CHARACTERS = ["%", "_"]
|
||||||
|
|
||||||
|
|
||||||
|
class QueryClause:
|
||||||
|
def __init__(
|
||||||
|
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
self._select_related = select_related
|
||||||
|
self.filter_clauses = filter_clauses
|
||||||
|
|
||||||
|
self.model_cls = model_cls
|
||||||
|
self.table = self.model_cls.__table__
|
||||||
|
|
||||||
|
def filter( # noqa: A003
|
||||||
|
self, **kwargs: Any
|
||||||
|
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
||||||
|
filter_clauses = self.filter_clauses
|
||||||
|
select_related = list(self._select_related)
|
||||||
|
|
||||||
|
if kwargs.get("pk"):
|
||||||
|
pk_name = self.model_cls.__pkname__
|
||||||
|
kwargs[pk_name] = kwargs.pop("pk")
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
table_prefix = ""
|
||||||
|
if "__" in key:
|
||||||
|
parts = key.split("__")
|
||||||
|
|
||||||
|
(
|
||||||
|
op,
|
||||||
|
field_name,
|
||||||
|
related_parts,
|
||||||
|
) = self._extract_operator_field_and_related(parts)
|
||||||
|
|
||||||
|
model_cls = self.model_cls
|
||||||
|
if related_parts:
|
||||||
|
(
|
||||||
|
select_related,
|
||||||
|
table_prefix,
|
||||||
|
model_cls,
|
||||||
|
) = self._determine_filter_target_table(
|
||||||
|
related_parts, select_related
|
||||||
|
)
|
||||||
|
|
||||||
|
table = model_cls.__table__
|
||||||
|
column = model_cls.__table__.columns[field_name]
|
||||||
|
|
||||||
|
else:
|
||||||
|
op = "exact"
|
||||||
|
column = self.table.columns[key]
|
||||||
|
table = self.table
|
||||||
|
|
||||||
|
value, has_escaped_character = self._escape_characters_in_clause(op, value)
|
||||||
|
|
||||||
|
if isinstance(value, orm.Model):
|
||||||
|
value = value.pk
|
||||||
|
|
||||||
|
op_attr = FILTER_OPERATORS[op]
|
||||||
|
clause = getattr(column, op_attr)(value)
|
||||||
|
clause = self._compile_clause(
|
||||||
|
clause,
|
||||||
|
column,
|
||||||
|
table,
|
||||||
|
table_prefix,
|
||||||
|
modifiers={"escape": "\\" if has_escaped_character else None},
|
||||||
|
)
|
||||||
|
filter_clauses.append(clause)
|
||||||
|
|
||||||
|
return filter_clauses, select_related
|
||||||
|
|
||||||
|
def _determine_filter_target_table(
|
||||||
|
self, related_parts: List[str], select_related: List[str]
|
||||||
|
) -> Tuple[List[str], str, "Model"]:
|
||||||
|
|
||||||
|
table_prefix = ""
|
||||||
|
model_cls = self.model_cls
|
||||||
|
select_related = [relation for relation in select_related]
|
||||||
|
|
||||||
|
# Add any implied select_related
|
||||||
|
related_str = "__".join(related_parts)
|
||||||
|
if related_str not in select_related:
|
||||||
|
select_related.append(related_str)
|
||||||
|
|
||||||
|
# Walk the relationships to the actual model class
|
||||||
|
# against which the comparison is being made.
|
||||||
|
previous_table = model_cls.__tablename__
|
||||||
|
for part in related_parts:
|
||||||
|
current_table = model_cls.__model_fields__[part].to.__tablename__
|
||||||
|
manager = model_cls._orm_relationship_manager
|
||||||
|
table_prefix = manager.resolve_relation_join(previous_table, current_table)
|
||||||
|
model_cls = model_cls.__model_fields__[part].to
|
||||||
|
previous_table = current_table
|
||||||
|
return select_related, table_prefix, model_cls
|
||||||
|
|
||||||
|
def _compile_clause(
|
||||||
|
self,
|
||||||
|
clause: sqlalchemy.sql.expression.BinaryExpression,
|
||||||
|
column: sqlalchemy.Column,
|
||||||
|
table: sqlalchemy.Table,
|
||||||
|
table_prefix: str,
|
||||||
|
modifiers: Dict,
|
||||||
|
) -> sqlalchemy.sql.expression.TextClause:
|
||||||
|
for modifier, modifier_value in modifiers.items():
|
||||||
|
clause.modifiers[modifier] = modifier_value
|
||||||
|
|
||||||
|
clause_text = str(
|
||||||
|
clause.compile(
|
||||||
|
dialect=self.model_cls.__database__._backend._dialect,
|
||||||
|
compile_kwargs={"literal_binds": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
alias = f"{table_prefix}_" if table_prefix else ""
|
||||||
|
aliased_name = f"{alias}{table.name}.{column.name}"
|
||||||
|
clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name)
|
||||||
|
clause = text(clause_text)
|
||||||
|
return clause
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _escape_characters_in_clause(
|
||||||
|
op: str, value: Union[str, "Model"]
|
||||||
|
) -> Tuple[str, bool]:
|
||||||
|
has_escaped_character = False
|
||||||
|
|
||||||
|
if op in ["contains", "icontains"]:
|
||||||
|
if isinstance(value, orm.Model):
|
||||||
|
raise QueryDefinitionError(
|
||||||
|
"You cannot use contains and icontains with instance of the Model"
|
||||||
|
)
|
||||||
|
|
||||||
|
has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value)
|
||||||
|
|
||||||
|
if has_escaped_character:
|
||||||
|
# enable escape modifier
|
||||||
|
for char in ESCAPE_CHARACTERS:
|
||||||
|
value = value.replace(char, f"\\{char}")
|
||||||
|
value = f"%{value}%"
|
||||||
|
|
||||||
|
return value, has_escaped_character
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_operator_field_and_related(
|
||||||
|
parts: List[str],
|
||||||
|
) -> Tuple[str, str, Optional[List]]:
|
||||||
|
if parts[-1] in FILTER_OPERATORS:
|
||||||
|
op = parts[-1]
|
||||||
|
field_name = parts[-2]
|
||||||
|
related_parts = parts[:-2]
|
||||||
|
else:
|
||||||
|
op = "exact"
|
||||||
|
field_name = parts[-1]
|
||||||
|
related_parts = parts[:-1]
|
||||||
|
|
||||||
|
return op, field_name, related_parts
|
||||||
228
orm/queryset/query.py
Normal file
228
orm/queryset/query.py
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
from typing import NamedTuple, Type, List, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
import orm
|
||||||
|
from orm import ForeignKey
|
||||||
|
from orm.fields import BaseField
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma no cover
|
||||||
|
from orm import Model
|
||||||
|
|
||||||
|
|
||||||
|
class JoinParameters(NamedTuple):
|
||||||
|
prev_model: Type["Model"]
|
||||||
|
previous_alias: str
|
||||||
|
from_table: str
|
||||||
|
model_cls: Type["Model"]
|
||||||
|
|
||||||
|
|
||||||
|
class Query:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_cls: Type["Model"],
|
||||||
|
filter_clauses: List,
|
||||||
|
select_related: List,
|
||||||
|
limit_count: int,
|
||||||
|
offset: int,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
self.query_offset = offset
|
||||||
|
self.limit_count = limit_count
|
||||||
|
self._select_related = select_related
|
||||||
|
self.filter_clauses = filter_clauses
|
||||||
|
|
||||||
|
self.model_cls = model_cls
|
||||||
|
self.table = self.model_cls.__table__
|
||||||
|
|
||||||
|
self.auto_related = []
|
||||||
|
self.used_aliases = []
|
||||||
|
|
||||||
|
self.select_from = None
|
||||||
|
self.columns = None
|
||||||
|
self.order_bys = None
|
||||||
|
|
||||||
|
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
|
||||||
|
self.columns = list(self.table.columns)
|
||||||
|
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
|
||||||
|
self.select_from = self.table
|
||||||
|
|
||||||
|
for key in self.model_cls.__model_fields__:
|
||||||
|
if (
|
||||||
|
not self.model_cls.__model_fields__[key].nullable
|
||||||
|
and isinstance(
|
||||||
|
self.model_cls.__model_fields__[key],
|
||||||
|
orm.fields.foreign_key.ForeignKey,
|
||||||
|
)
|
||||||
|
and key not in self._select_related
|
||||||
|
):
|
||||||
|
self._select_related = [key] + self._select_related
|
||||||
|
|
||||||
|
start_params = JoinParameters(
|
||||||
|
self.model_cls, "", self.table.name, self.model_cls
|
||||||
|
)
|
||||||
|
self._extract_auto_required_relations(prev_model=start_params.prev_model)
|
||||||
|
self._include_auto_related_models()
|
||||||
|
self._select_related.sort(key=lambda item: (-len(item), item))
|
||||||
|
|
||||||
|
for item in self._select_related:
|
||||||
|
join_parameters = JoinParameters(
|
||||||
|
self.model_cls, "", self.table.name, self.model_cls
|
||||||
|
)
|
||||||
|
|
||||||
|
for part in item.split("__"):
|
||||||
|
join_parameters = self._build_join_parameters(part, join_parameters)
|
||||||
|
|
||||||
|
expr = sqlalchemy.sql.select(self.columns)
|
||||||
|
expr = expr.select_from(self.select_from)
|
||||||
|
|
||||||
|
expr = self._apply_expression_modifiers(expr)
|
||||||
|
|
||||||
|
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
||||||
|
self._reset_query_parameters()
|
||||||
|
|
||||||
|
return expr, self._select_related
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||||
|
return [
|
||||||
|
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
|
||||||
|
for column in table.columns
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prefixed_table_name(alias: str, name: str) -> text:
|
||||||
|
return text(f"{name} {alias}_{name}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _field_is_a_foreign_key_and_no_circular_reference(
|
||||||
|
field: BaseField, field_name: str, rel_part: str
|
||||||
|
) -> bool:
|
||||||
|
return isinstance(field, ForeignKey) and field_name not in rel_part
|
||||||
|
|
||||||
|
def _field_qualifies_to_deeper_search(
|
||||||
|
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
|
||||||
|
) -> bool:
|
||||||
|
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
|
||||||
|
partial_match = any(
|
||||||
|
[x.startswith(prev_part_of_related) for x in self._select_related]
|
||||||
|
)
|
||||||
|
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
|
||||||
|
return (
|
||||||
|
(field.virtual and parent_virtual)
|
||||||
|
or (partial_match and not already_checked)
|
||||||
|
) or not nested
|
||||||
|
|
||||||
|
def on_clause(
|
||||||
|
self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
|
||||||
|
) -> text:
|
||||||
|
left_part = f"{alias}_{to_clause}"
|
||||||
|
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
||||||
|
return text(f"{left_part}={right_part}")
|
||||||
|
|
||||||
|
def _build_join_parameters(
|
||||||
|
self, part: str, join_params: JoinParameters
|
||||||
|
) -> JoinParameters:
|
||||||
|
model_cls = join_params.model_cls.__model_fields__[part].to
|
||||||
|
to_table = model_cls.__table__.name
|
||||||
|
|
||||||
|
alias = model_cls._orm_relationship_manager.resolve_relation_join(
|
||||||
|
join_params.from_table, to_table
|
||||||
|
)
|
||||||
|
if alias not in self.used_aliases:
|
||||||
|
if join_params.prev_model.__model_fields__[part].virtual:
|
||||||
|
to_key = next(
|
||||||
|
(
|
||||||
|
v
|
||||||
|
for k, v in model_cls.__model_fields__.items()
|
||||||
|
if isinstance(v, ForeignKey) and v.to == join_params.prev_model
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
).name
|
||||||
|
from_key = model_cls.__pkname__
|
||||||
|
else:
|
||||||
|
to_key = model_cls.__pkname__
|
||||||
|
from_key = part
|
||||||
|
|
||||||
|
on_clause = self.on_clause(
|
||||||
|
previous_alias=join_params.previous_alias,
|
||||||
|
alias=alias,
|
||||||
|
from_clause=f"{join_params.from_table}.{from_key}",
|
||||||
|
to_clause=f"{to_table}.{to_key}",
|
||||||
|
)
|
||||||
|
target_table = self.prefixed_table_name(alias, to_table)
|
||||||
|
self.select_from = sqlalchemy.sql.outerjoin(
|
||||||
|
self.select_from, target_table, on_clause
|
||||||
|
)
|
||||||
|
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}"))
|
||||||
|
self.columns.extend(self.prefixed_columns(alias, model_cls.__table__))
|
||||||
|
self.used_aliases.append(alias)
|
||||||
|
|
||||||
|
previous_alias = alias
|
||||||
|
from_table = to_table
|
||||||
|
prev_model = model_cls
|
||||||
|
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||||
|
|
||||||
|
def _extract_auto_required_relations(
|
||||||
|
self,
|
||||||
|
prev_model: Type["Model"],
|
||||||
|
rel_part: str = "",
|
||||||
|
nested: bool = False,
|
||||||
|
parent_virtual: bool = False,
|
||||||
|
) -> None:
|
||||||
|
for field_name, field in prev_model.__model_fields__.items():
|
||||||
|
if self._field_is_a_foreign_key_and_no_circular_reference(
|
||||||
|
field, field_name, rel_part
|
||||||
|
):
|
||||||
|
rel_part = field_name if not rel_part else rel_part + "__" + field_name
|
||||||
|
if not field.nullable:
|
||||||
|
if rel_part not in self._select_related:
|
||||||
|
self.auto_related.append("__".join(rel_part.split("__")[:-1]))
|
||||||
|
rel_part = ""
|
||||||
|
elif self._field_qualifies_to_deeper_search(
|
||||||
|
field, parent_virtual, nested, rel_part
|
||||||
|
):
|
||||||
|
self._extract_auto_required_relations(
|
||||||
|
prev_model=field.to,
|
||||||
|
rel_part=rel_part,
|
||||||
|
nested=True,
|
||||||
|
parent_virtual=field.virtual,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rel_part = ""
|
||||||
|
|
||||||
|
def _include_auto_related_models(self) -> None:
|
||||||
|
if self.auto_related:
|
||||||
|
new_joins = []
|
||||||
|
for join in self._select_related:
|
||||||
|
if not any([x.startswith(join) for x in self.auto_related]):
|
||||||
|
new_joins.append(join)
|
||||||
|
self._select_related = new_joins + self.auto_related
|
||||||
|
|
||||||
|
def _apply_expression_modifiers(
|
||||||
|
self, expr: sqlalchemy.sql.select
|
||||||
|
) -> sqlalchemy.sql.select:
|
||||||
|
if self.filter_clauses:
|
||||||
|
if len(self.filter_clauses) == 1:
|
||||||
|
clause = self.filter_clauses[0]
|
||||||
|
else:
|
||||||
|
clause = sqlalchemy.sql.and_(*self.filter_clauses)
|
||||||
|
expr = expr.where(clause)
|
||||||
|
|
||||||
|
if self.limit_count:
|
||||||
|
expr = expr.limit(self.limit_count)
|
||||||
|
|
||||||
|
if self.query_offset:
|
||||||
|
expr = expr.offset(self.query_offset)
|
||||||
|
|
||||||
|
for order in self.order_bys:
|
||||||
|
expr = expr.order_by(order)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def _reset_query_parameters(self) -> None:
|
||||||
|
self.select_from = None
|
||||||
|
self.columns = None
|
||||||
|
self.order_bys = None
|
||||||
|
self.auto_related = []
|
||||||
|
self.used_aliases = []
|
||||||
175
orm/queryset/queryset.py
Normal file
175
orm/queryset/queryset.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
from typing import Type, List, Any, Union, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
|
import databases
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
import orm # noqa I100
|
||||||
|
from orm import NoMatch, MultipleMatches
|
||||||
|
from orm.queryset.clause import QueryClause
|
||||||
|
from orm.queryset.query import Query
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma no cover
|
||||||
|
from orm import Model
|
||||||
|
|
||||||
|
|
||||||
|
class QuerySet:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_cls: Type["Model"] = None,
|
||||||
|
filter_clauses: List = None,
|
||||||
|
select_related: List = None,
|
||||||
|
limit_count: int = None,
|
||||||
|
offset: int = None,
|
||||||
|
) -> None:
|
||||||
|
self.model_cls = model_cls
|
||||||
|
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
||||||
|
self._select_related = [] if select_related is None else select_related
|
||||||
|
self.limit_count = limit_count
|
||||||
|
self.query_offset = offset
|
||||||
|
self.order_bys = None
|
||||||
|
|
||||||
|
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
|
||||||
|
return self.__class__(model_cls=owner)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def database(self) -> databases.Database:
|
||||||
|
return self.model_cls.__database__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def table(self) -> sqlalchemy.Table:
|
||||||
|
return self.model_cls.__table__
|
||||||
|
|
||||||
|
def build_select_expression(self) -> sqlalchemy.sql.select:
|
||||||
|
qry = Query(
|
||||||
|
model_cls=self.model_cls,
|
||||||
|
select_related=self._select_related,
|
||||||
|
filter_clauses=self.filter_clauses,
|
||||||
|
offset=self.query_offset,
|
||||||
|
limit_count=self.limit_count,
|
||||||
|
)
|
||||||
|
exp, self._select_related = qry.build_select_expression()
|
||||||
|
return exp
|
||||||
|
|
||||||
|
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||||
|
qryclause = QueryClause(
|
||||||
|
model_cls=self.model_cls,
|
||||||
|
select_related=self._select_related,
|
||||||
|
filter_clauses=self.filter_clauses,
|
||||||
|
)
|
||||||
|
filter_clauses, select_related = qryclause.filter(**kwargs)
|
||||||
|
|
||||||
|
return self.__class__(
|
||||||
|
model_cls=self.model_cls,
|
||||||
|
filter_clauses=filter_clauses,
|
||||||
|
select_related=select_related,
|
||||||
|
limit_count=self.limit_count,
|
||||||
|
offset=self.query_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
|
||||||
|
if not isinstance(related, (list, tuple)):
|
||||||
|
related = [related]
|
||||||
|
|
||||||
|
related = list(self._select_related) + related
|
||||||
|
return self.__class__(
|
||||||
|
model_cls=self.model_cls,
|
||||||
|
filter_clauses=self.filter_clauses,
|
||||||
|
select_related=related,
|
||||||
|
limit_count=self.limit_count,
|
||||||
|
offset=self.query_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def exists(self) -> bool:
|
||||||
|
expr = self.build_select_expression()
|
||||||
|
expr = sqlalchemy.exists(expr).select()
|
||||||
|
return await self.database.fetch_val(expr)
|
||||||
|
|
||||||
|
async def count(self) -> int:
|
||||||
|
expr = self.build_select_expression().alias("subquery_for_count")
|
||||||
|
expr = sqlalchemy.func.count().select().select_from(expr)
|
||||||
|
return await self.database.fetch_val(expr)
|
||||||
|
|
||||||
|
def limit(self, limit_count: int) -> "QuerySet":
|
||||||
|
return self.__class__(
|
||||||
|
model_cls=self.model_cls,
|
||||||
|
filter_clauses=self.filter_clauses,
|
||||||
|
select_related=self._select_related,
|
||||||
|
limit_count=limit_count,
|
||||||
|
offset=self.query_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
def offset(self, offset: int) -> "QuerySet":
|
||||||
|
return self.__class__(
|
||||||
|
model_cls=self.model_cls,
|
||||||
|
filter_clauses=self.filter_clauses,
|
||||||
|
select_related=self._select_related,
|
||||||
|
limit_count=self.limit_count,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def first(self, **kwargs: Any) -> "Model":
|
||||||
|
if kwargs:
|
||||||
|
return await self.filter(**kwargs).first()
|
||||||
|
|
||||||
|
rows = await self.limit(1).all()
|
||||||
|
if rows:
|
||||||
|
return rows[0]
|
||||||
|
|
||||||
|
async def get(self, **kwargs: Any) -> "Model":
|
||||||
|
if kwargs:
|
||||||
|
return await self.filter(**kwargs).get()
|
||||||
|
|
||||||
|
expr = self.build_select_expression().limit(2)
|
||||||
|
rows = await self.database.fetch_all(expr)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
raise NoMatch()
|
||||||
|
if len(rows) > 1:
|
||||||
|
raise MultipleMatches()
|
||||||
|
return self.model_cls.from_row(rows[0], select_related=self._select_related)
|
||||||
|
|
||||||
|
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
|
||||||
|
if kwargs:
|
||||||
|
return await self.filter(**kwargs).all()
|
||||||
|
|
||||||
|
expr = self.build_select_expression()
|
||||||
|
rows = await self.database.fetch_all(expr)
|
||||||
|
result_rows = [
|
||||||
|
self.model_cls.from_row(row, select_related=self._select_related)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
result_rows = self.model_cls.merge_instances_list(result_rows)
|
||||||
|
|
||||||
|
return result_rows
|
||||||
|
|
||||||
|
async def create(self, **kwargs: Any) -> "Model":
|
||||||
|
|
||||||
|
new_kwargs = dict(**kwargs)
|
||||||
|
|
||||||
|
# Remove primary key when None to prevent not null constraint in postgresql.
|
||||||
|
pkname = self.model_cls.__pkname__
|
||||||
|
pk = self.model_cls.__model_fields__[pkname]
|
||||||
|
if (
|
||||||
|
pkname in new_kwargs
|
||||||
|
and new_kwargs.get(pkname) is None
|
||||||
|
and (pk.nullable or pk.autoincrement)
|
||||||
|
):
|
||||||
|
del new_kwargs[pkname]
|
||||||
|
|
||||||
|
# substitute related models with their pk
|
||||||
|
for field in self.model_cls._extract_related_names():
|
||||||
|
if field in new_kwargs and new_kwargs.get(field) is not None:
|
||||||
|
new_kwargs[field] = getattr(
|
||||||
|
new_kwargs.get(field),
|
||||||
|
self.model_cls.__model_fields__[field].to.__pkname__,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build the insert expression.
|
||||||
|
expr = self.table.insert()
|
||||||
|
expr = expr.values(**new_kwargs)
|
||||||
|
|
||||||
|
# Execute the insert, and return a new model instance.
|
||||||
|
instance = self.model_cls(**kwargs)
|
||||||
|
instance.pk = await self.database.execute(expr)
|
||||||
|
return instance
|
||||||
Reference in New Issue
Block a user