import copy
import inspect
import json
import uuid
from abc import ABCMeta
from typing import Any, List, Type
from typing import Set, Dict

import pydantic
import sqlalchemy
from pydantic import BaseConfig, create_model

from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch
from orm.fields import BaseField
from orm.relations import RelationshipManager


def parse_pydantic_field_from_model_fields(object_dict: dict):
    """Build ``create_model`` field definitions from ORM field declarations.

    Every value of *object_dict* that is a ``BaseField`` becomes an entry
    ``name -> (type, default)``; required fields use pydantic's ``...``
    marker, optional ones use the field's ``default_value``. Any other
    entries (methods, dunder attributes, ...) are ignored.
    """
    pydantic_fields = {
        field_name: (
            base_field.__type__,
            ... if base_field.is_required else base_field.default_value,
        )
        for field_name, base_field in object_dict.items()
        if isinstance(base_field, BaseField)
    }
    return pydantic_fields


# Lookup suffix (``field__<op>``) -> name of the SQLAlchemy column method used.
FILTER_OPERATORS = {
    "exact": "__eq__",
    "iexact": "ilike",
    "contains": "like",
    "icontains": "ilike",
    "in": "in_",
    "gt": "__gt__",
    "gte": "__ge__",
    "lt": "__lt__",
    "lte": "__le__",
}


class QuerySet:
    """Chainable query builder bound to a ``Model`` class.

    The chaining methods (``filter``, ``select_related``, ``limit``,
    ``offset``) return a new ``QuerySet`` instead of modifying the receiver.
    """

    # Characters that are wildcards inside a SQL LIKE pattern.
    ESCAPE_CHARACTERS = ['%', '_']

    def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None,
                 select_related: List = None, limit_count: int = None, offset: int = None):
        self.model_cls = model_cls
        self.filter_clauses = [] if filter_clauses is None else filter_clauses
        self._select_related = [] if select_related is None else select_related
        self.limit_count = limit_count
        self.query_offset = offset

    def __get__(self, instance, owner):
        # Descriptor hook: ``SomeModel.objects`` yields a fresh QuerySet bound
        # to the class the attribute was accessed through.
        return self.__class__(model_cls=owner)

    @property
    def database(self):
        return self.model_cls.__database__

    @property
    def table(self):
        return self.model_cls.__table__

    def build_select_expression(self):
        """Compile this queryset's state into a SQLAlchemy ``Select``.

        Tables named in ``select_related`` are joined onto the model's table,
        filter clauses are AND-ed together, and limit/offset are applied when
        set.
        """
        tables = [self.table]
        joined_tables = []
        select_from = self.table

        for item in self._select_related:
            model_cls = self.model_cls
            for part in item.split("__"):
                model_cls = model_cls.__model_fields__[part].to
                related_table = model_cls.__table__
                # All relations go into ONE join tree (``select_from`` is not
                # reset per item, which would leave earlier tables only in the
                # column list and cross-join them). A related table reached
                # again via a shared prefix ("a" and "a__b") or a duplicate
                # entry is joined only once.
                if any(table is related_table for table in joined_tables):
                    continue
                select_from = sqlalchemy.sql.join(select_from, related_table)
                joined_tables.append(related_table)
                tables.append(related_table)

        expr = sqlalchemy.sql.select(tables)
        expr = expr.select_from(select_from)

        if self.filter_clauses:
            if len(self.filter_clauses) == 1:
                clause = self.filter_clauses[0]
            else:
                clause = sqlalchemy.sql.and_(*self.filter_clauses)
            expr = expr.where(clause)

        if self.limit_count:
            expr = expr.limit(self.limit_count)

        if self.query_offset:
            expr = expr.offset(self.query_offset)

        # print(expr.compile(compile_kwargs={"literal_binds": True}))
        return expr

    def filter(self, **kwargs):
        """Return a new queryset narrowed by Django-style lookups.

        Keys have the form ``field``, ``relation__field`` (any depth), each
        optionally suffixed with ``__<op>`` where ``op`` is a key of
        ``FILTER_OPERATORS``. ``pk`` is an alias for the primary-key column.
        Relations used in a lookup are added to ``select_related``.
        """
        # Copy so neither this queryset nor any clone sharing the list is
        # mutated by the new clauses.
        filter_clauses = list(self.filter_clauses)
        select_related = list(self._select_related)

        # Membership test, not truthiness, so falsy keys such as pk=0 are
        # still mapped onto the real primary-key column.
        if "pk" in kwargs:
            pk_name = self.model_cls.__pkname__
            kwargs[pk_name] = kwargs.pop("pk")

        for key, value in kwargs.items():
            if "__" in key:
                parts = key.split("__")

                # Determine if we should treat the final part as a
                # filter operator or as a related field.
                if parts[-1] in FILTER_OPERATORS:
                    op = parts[-1]
                    field_name = parts[-2]
                    related_parts = parts[:-2]
                else:
                    op = "exact"
                    field_name = parts[-1]
                    related_parts = parts[:-1]

                model_cls = self.model_cls
                if related_parts:
                    # Add any implied select_related
                    related_str = "__".join(related_parts)
                    if related_str not in select_related:
                        select_related.append(related_str)

                    # Walk the relationships to the actual model class
                    # against which the comparison is being made.
                    for part in related_parts:
                        model_cls = model_cls.__model_fields__[part].to

                column = model_cls.__table__.columns[field_name]

            else:
                op = "exact"
                column = self.table.columns[key]

            # Map the operation code onto SQLAlchemy's ColumnElement
            # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
            op_attr = FILTER_OPERATORS[op]
            has_escaped_character = False

            if op in ["contains", "icontains"]:
                has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in value)
                if has_escaped_character:
                    # enable escape modifier
                    for char in self.ESCAPE_CHARACTERS:
                        value = value.replace(char, f'\\{char}')
                value = f"%{value}%"

            # Related model instances are compared by their primary key.
            if isinstance(value, Model):
                value = value.pk

            clause = getattr(column, op_attr)(value)
            clause.modifiers['escape'] = '\\' if has_escaped_character else None

            filter_clauses.append(clause)

        return self.__class__(
            model_cls=self.model_cls,
            filter_clauses=filter_clauses,
            select_related=select_related,
            limit_count=self.limit_count,
            offset=self.query_offset
        )

    def select_related(self, related):
        """Return a new queryset that also joins the given relation path(s).

        *related* is a single ``"rel__subrel"`` string or a list/tuple of them.
        """
        if not isinstance(related, (list, tuple)):
            related = [related]

        # list(...) so a tuple argument can be concatenated as well.
        related = list(self._select_related) + list(related)
        return self.__class__(
            model_cls=self.model_cls,
            filter_clauses=self.filter_clauses,
            select_related=related,
            limit_count=self.limit_count,
            offset=self.query_offset
        )

    async def exists(self) -> bool:
        """Return whether the query matches at least one row."""
        expr = self.build_select_expression()
        expr = sqlalchemy.exists(expr).select()
        return await self.database.fetch_val(expr)

    async def count(self) -> int:
        """Return the number of rows the query matches."""
        expr = self.build_select_expression().alias("subquery_for_count")
        expr = sqlalchemy.func.count().select().select_from(expr)
        return await self.database.fetch_val(expr)

    def limit(self, limit_count: int):
        """Return a new queryset limited to *limit_count* rows."""
        return self.__class__(
            model_cls=self.model_cls,
            filter_clauses=self.filter_clauses,
            select_related=self._select_related,
            limit_count=limit_count,
            offset=self.query_offset
        )

    def offset(self, offset: int):
        """Return a new queryset that skips the first *offset* rows."""
        return self.__class__(
            model_cls=self.model_cls,
            filter_clauses=self.filter_clauses,
            select_related=self._select_related,
            limit_count=self.limit_count,
            offset=offset
        )

    async def first(self, **kwargs):
        """Return the first matching model instance, or ``None`` if none match.

        Keyword arguments are applied as ``filter(**kwargs)`` first.
        """
        if kwargs:
            return await self.filter(**kwargs).first()

        rows = await self.limit(1).all()
        if rows:
            return rows[0]
        return None

    async def get(self, **kwargs):
        """Return the single matching model instance.

        Raises:
            NoMatch: no row matches.
            MultipleMatches: more than one row matches.
        """
        if kwargs:
            return await self.filter(**kwargs).get()

        # Two rows are enough to detect "more than one".
        expr = self.build_select_expression().limit(2)
        rows = await self.database.fetch_all(expr)

        if not rows:
            raise NoMatch()
        if len(rows) > 1:
            raise MultipleMatches()
        return self.model_cls.from_row(rows[0], select_related=self._select_related)

    async def all(self, **kwargs):
        """Return every matching row as a list of model instances."""
        if kwargs:
            return await self.filter(**kwargs).all()

        expr = self.build_select_expression()
        rows = await self.database.fetch_all(expr)
        return [
            self.model_cls.from_row(row, select_related=self._select_related)
            for row in rows
        ]

    async def create(self, **kwargs):
        """Insert a row built from *kwargs* and return the new model instance.

        Related model instances are stored by their primary key. The instance's
        ``pk`` is set from the value returned by ``database.execute``.
        """
        new_kwargs = dict(**kwargs)

        # Remove primary key when None to prevent not null constraint in postgresql.
        pkname = self.model_cls.__pkname__
        pk = self.model_cls.__model_fields__[pkname]
        if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement):
            del new_kwargs[pkname]

        # substitute related models with their pk
        for field in self.model_cls.extract_related_names():
            if field in new_kwargs and new_kwargs.get(field) is not None:
                new_kwargs[field] = getattr(new_kwargs.get(field),
                                            self.model_cls.__model_fields__[field].to.__pkname__)

        # Build the insert expression.
        expr = self.table.insert()
        expr = expr.values(**new_kwargs)

        # Execute the insert, and return a new model instance.
        instance = self.model_cls(**kwargs)
        instance.pk = await self.database.execute(expr)
        return instance


class ModelMetaclass(type):
    """Metaclass that turns ``BaseField`` declarations into a table and a schema.

    For every non-abstract model it creates a ``sqlalchemy.Table`` on
    ``__metadata__`` and a pydantic model holding the field values, and
    stores both (plus bookkeeping such as ``__pkname__``) on the class.
    """

    def __new__(
            mcs: type,
            name: str,
            bases: Any,
            attrs: dict
    ) -> type:
        # Abstract bases only need a plain class; checked first so concrete
        # models are built exactly once.
        if attrs.get("__abstract__"):
            return super().__new__(  # type: ignore
                mcs, name, bases, attrs
            )

        tablename = attrs["__tablename__"]
        metadata = attrs["__metadata__"]
        pkname = None

        columns = []
        model_fields = {}
        for field_name, field in attrs.items():
            if isinstance(field, BaseField):
                model_fields[field_name] = field
                # pydantic_only fields live on the pydantic model but get no column.
                if not field.pydantic_only:
                    if field.primary_key:
                        pkname = field_name
                    columns.append(field.get_column(field_name))

        # Validate before sqlalchemy.Table registers the table on `metadata`,
        # so a rejected definition leaves nothing behind there.
        if not pkname:
            raise ModelDefinitionError('Table has to have a primary key.')

        # sqlalchemy table creation
        attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
        attrs['__columns__'] = columns
        attrs['__pkname__'] = pkname

        # pydantic model creation
        pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
        config = type('Config', (BaseConfig,), {'orm_mode': True})
        pydantic_model = create_model(name, __config__=config, **pydantic_fields)
        attrs['__pydantic_fields__'] = pydantic_fields
        attrs['__pydantic_model__'] = pydantic_model

        attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__)
        attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__)
        attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)

        attrs['__model_fields__'] = model_fields

        new_model = super().__new__(  # type: ignore
            mcs, name, bases, attrs
        )

        return new_model


class Model(tuple, metaclass=ModelMetaclass):
    """Base class for ORM models.

    Field values are held in ``self.values`` (an instance of the generated
    ``__pydantic_model__``); attribute access for declared fields is routed
    there by ``__getattribute__``/``__setattr__``.
    """

    __abstract__ = True

    objects = QuerySet()

    def __init__(self, **kwargs) -> None:
        self._orm_id = uuid.uuid4().hex
        self._orm_saved = False
        self._orm_relationship_manager = RelationshipManager(self)
        self._orm_observers = []

        if "pk" in kwargs:
            kwargs[self.__pkname__] = kwargs.pop("pk")
        kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
        self.values = self.__pydantic_model__(**kwargs)

    def __setattr__(self, key: str, value: Any) -> None:
        if key in self.__fields__:
            # Json fields are stored as their serialized text.
            if self.is_conversion_to_json_needed(key) and not isinstance(value, str):
                try:
                    value = json.dumps(value)
                except TypeError:  # pragma no cover
                    pass

            value = self.__model_fields__[key].expand_relationship(value, self)
            setattr(self.values, key, value)
        else:
            super().__setattr__(key, value)

    def __getattribute__(self, key: str) -> Any:
        # '__fields__' is excluded to avoid recursing through this method.
        if key != '__fields__' and key in self.__fields__:
            if key in self._orm_relationship_manager:
                parent_item = self._orm_relationship_manager.get(key)
                return parent_item

            item = getattr(self.values, key, None)
            if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
                try:
                    item = json.loads(item)
                except (TypeError, ValueError):  # pragma no cover
                    # json.loads reports malformed text with ValueError
                    # (JSONDecodeError); keep the raw string in that case.
                    pass
            return item

        return super().__getattribute__(key)

    def __eq__(self, other):
        # Non-model operands are simply unequal. NotImplemented is avoided on
        # purpose: Model subclasses (an always-empty) tuple, so the reflected
        # tuple comparison would make e.g. `model == ()` True.
        if not isinstance(other, Model):
            return False
        return self.values.dict() == other.values.dict()

    def __ne__(self, other):
        # Must be explicit: otherwise tuple.__ne__ is inherited and compares
        # the (empty) tuples, making `a != b` False for any two models.
        return not self == other

    def __repr__(self):  # pragma no cover
        return self.values.__repr__()

    # def attach(self, observer: 'Model'):
    #     if all([obs._orm_id != observer._orm_id for obs in self._orm_observers]):
    #         self._orm_observers.append(observer)
    #
    # def detach(self, observer: 'Model'):
    #     for ind, obs in enumerate(self._orm_observers):
    #         if obs._orm_id == observer._orm_id:
    #             del self._orm_observers[ind]
    #             break
    #

    def notify(self):
        """Call ``orm_update(self)`` on every registered observer."""
        for obs in self._orm_observers:  # pragma no cover
            obs.orm_update(self)

    def orm_update(self, subject: 'Model') -> None:  # pragma no cover
        print('should be updated here')

    @classmethod
    def from_row(cls, row, select_related: List = None) -> 'Model':
        """Build an instance from a result row, recursing into *select_related* paths."""
        item = {}
        select_related = select_related or []
        for related in select_related:
            if "__" in related:
                first_part, remainder = related.split("__", 1)
                model_cls = cls.__model_fields__[first_part].to
                item[first_part] = model_cls.from_row(row, select_related=[remainder])
            else:
                model_cls = cls.__model_fields__[related].to
                item[related] = model_cls.from_row(row)

        # Plain columns; a relation already built above takes precedence over
        # its raw foreign-key column of the same name.
        for column in cls.__table__.columns:
            if column.name not in item:
                item[column.name] = row[column]

        return cls(**item)

    @classmethod
    def validate(cls: Type['Model'], value: Any) -> 'Model':  # pragma no cover
        # pydantic's BaseModel.validate is a classmethod taking only the value.
        return cls.__pydantic_model__.validate(value)

    @classmethod
    def __get_validators__(cls):  # pragma no cover
        yield cls.__pydantic_model__.validate

    @classmethod
    def schema(cls, by_alias: bool = True):  # pragma no cover
        # pydantic's BaseModel.schema is a classmethod; by_alias is keyword.
        return cls.__pydantic_model__.schema(by_alias=by_alias)

    def is_conversion_to_json_needed(self, column_name: str) -> bool:
        """Return whether *column_name* is declared with type ``pydantic.Json``."""
        return self.__model_fields__.get(column_name).__type__ == pydantic.Json

    @property
    def pk(self):
        return getattr(self.values, self.__pkname__)

    @pk.setter
    def pk(self, value):
        setattr(self.values, self.__pkname__, value)

    @property
    def pk_column(self) -> sqlalchemy.Column:
        return self.__table__.primary_key.columns.values()[0]

    def dict(self) -> Dict:
        return self.values.dict()

    def from_dict(self, value_dict: Dict) -> None:
        """Assign every key/value of *value_dict* through ``__setattr__``."""
        for key, value in value_dict.items():
            setattr(self, key, value)

    def extract_own_model_fields(self) -> Dict:
        """Return ``self.dict()`` without the relation fields."""
        related_names = self.extract_related_names()
        self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
        return self_fields

    @classmethod
    def extract_related_names(cls) -> Set:
        """Return names of fields whose type is a pydantic model (relations)."""
        related_names = set()
        for name, field in cls.__fields__.items():
            if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
                related_names.add(name)
            # elif field.sub_fields and any(
            #         [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]):
            #     related_names.add(name)
        return related_names

    def extract_model_db_fields(self) -> Dict:
        """Return column values for persisting; relations are reduced to their pk."""
        self_fields = self.extract_own_model_fields()
        self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns}
        for field in self.extract_related_names():
            if getattr(self, field) is not None:
                self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__)
        return self_fields

    async def save(self) -> int:
        """INSERT this instance, set ``pk`` from the execute result and return it."""
        self_fields = self.extract_model_db_fields()
        if self.__model_fields__.get(self.__pkname__).autoincrement:
            self_fields.pop(self.__pkname__, None)
        expr = self.__table__.insert()
        expr = expr.values(**self_fields)
        item_id = await self.__database__.execute(expr)
        setattr(self, 'pk', item_id)
        self.notify()
        return item_id

    async def update(self, **kwargs: Any) -> int:
        """Apply *kwargs* (if any) and UPDATE this row by primary key."""
        if kwargs:
            new_values = {**self.dict(), **kwargs}
            self.from_dict(new_values)

        self_fields = self.extract_model_db_fields()
        self_fields.pop(self.__pkname__)
        expr = self.__table__.update().values(**self_fields).where(
            self.pk_column == getattr(self, self.__pkname__))
        result = await self.__database__.execute(expr)
        self.notify()
        return result

    async def delete(self) -> int:
        """DELETE this row by primary key and return the execute result."""
        expr = self.__table__.delete()
        expr = expr.where(self.pk_column == (getattr(self, self.__pkname__)))
        result = await self.__database__.execute(expr)
        self.notify()
        return result

    async def load(self) -> 'Model':
        """Re-read this row by primary key into the instance and return it."""
        expr = self.__table__.select().where(self.pk_column == self.pk)
        row = await self.__database__.fetch_one(expr)
        self.from_dict(dict(row))
        self.notify()
        return self