import copy
import inspect
import json
import uuid
from typing import Any, Dict, List, Set, Tuple, Type

import pydantic
import sqlalchemy
from pydantic import BaseConfig, create_model

from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch
from orm.fields import BaseField
from orm.relations import RelationshipManager


def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, tuple]:
    """Build pydantic ``create_model`` field definitions from ORM fields.

    Every value in ``object_dict`` that is a ``BaseField`` becomes an entry
    ``name -> (type, default)``. Required fields get ``...`` (pydantic's
    "no default" marker); the others get the field's ``default_value``.
    Entries that are not ``BaseField`` instances are ignored.
    """
    return {
        field_name: (
            base_field.__type__,
            ... if base_field.is_required else base_field.default_value,
        )
        for field_name, base_field in object_dict.items()
        if isinstance(base_field, BaseField)
    }


# Maps the lookup suffix used in ``filter(field__<op>=value)`` to the name of
# the SQLAlchemy ColumnElement method/operator that implements it.
FILTER_OPERATORS = {
    "exact": "__eq__",
    "iexact": "ilike",
    "contains": "like",
    "icontains": "ilike",
    "in": "in_",
    "gt": "__gt__",
    "gte": "__ge__",
    "lt": "__lt__",
    "lte": "__le__",
}


class QuerySet:
    """Immutable, lazily evaluated query builder for a ``Model`` class.

    Each refining method (``filter``, ``select_related``, ``limit``,
    ``offset``) returns a new ``QuerySet`` and leaves the receiver untouched.
    The database is only hit by the async methods ``get``, ``all`` and
    ``create``.
    """

    # Characters that act as wildcards inside SQL LIKE patterns.
    ESCAPE_CHARACTERS = ['%', '_']
    # Character declared as the ESCAPE for LIKE clauses built by ``filter``.
    _LIKE_ESCAPE = '\\'

    def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None,
                 select_related: List = None, limit_count: int = None, offset: int = None):
        self.model_cls = model_cls
        self.filter_clauses = [] if filter_clauses is None else filter_clauses
        self._select_related = [] if select_related is None else select_related
        self.limit_count = limit_count
        self.query_offset = offset

    def __get__(self, instance, owner):
        # Descriptor protocol: accessing ``SomeModel.objects`` yields a fresh
        # QuerySet bound to the owning model class.
        return self.__class__(model_cls=owner)

    @property
    def database(self):
        # The ``__database__`` attribute is read from the model class
        # (presumably declared by the user on each model).
        return self.model_cls.__database__

    @property
    def table(self):
        return self.model_cls.__table__

    def build_select_expression(self):
        """Compile this queryset into a SQLAlchemy ``Select``.

        The base table is joined with every table reached through the
        ``select_related`` paths. For example, ``"a__b"`` joins the table behind
        field ``a``, then the table behind ``b`` on that model. No ON clause is
        given, so ``sqlalchemy.sql.join`` infers one. Filter clauses are AND-ed
        together, and limit/offset are applied when set.
        """
        tables = [self.table]
        select_from = self.table
        # Relation paths (tuples of field names) already joined. This lets
        # e.g. ["a", "a__b"] join the table behind ``a`` only once.
        joined_paths = set()

        for item in self._select_related:
            model_cls = self.model_cls
            path = ()
            for part in item.split("__"):
                path += (part,)
                model_cls = model_cls.__model_fields__[part].to
                if path in joined_paths:
                    continue
                joined_paths.add(path)
                # Extend one shared join chain across all items. Restarting it
                # per item dropped earlier joins while their tables stayed in
                # the SELECT list, which produced a cartesian product.
                select_from = sqlalchemy.sql.join(select_from, model_cls.__table__)
                tables.append(model_cls.__table__)

        expr = sqlalchemy.sql.select(tables)
        expr = expr.select_from(select_from)

        if self.filter_clauses:
            if len(self.filter_clauses) == 1:
                clause = self.filter_clauses[0]
            else:
                clause = sqlalchemy.sql.and_(*self.filter_clauses)
            expr = expr.where(clause)

        # Compare against None so that limit(0) / offset(0) are honoured
        # instead of being treated as "not set".
        if self.limit_count is not None:
            expr = expr.limit(self.limit_count)
        if self.query_offset is not None:
            expr = expr.offset(self.query_offset)
        return expr

    def _escape_like(self, value: str) -> Tuple[str, bool]:
        """Escape LIKE wildcards so that ``value`` is matched literally.

        Returns ``(escaped_value, needs_escape)``. ``needs_escape`` tells the
        caller to attach an ESCAPE modifier to the generated clause.
        """
        escape = self._LIKE_ESCAPE
        specials = [c for c in self.ESCAPE_CHARACTERS if c != escape]
        if escape not in value and not any(c in value for c in specials):
            return value, False
        # Double the escape character first. Otherwise a literal backslash in
        # the input would combine with the escape added in front of a
        # following wildcard. Some backends also treat a bare backslash in
        # LIKE as an escape by default (e.g. PostgreSQL).
        value = value.replace(escape, escape * 2)
        for char in specials:
            value = value.replace(char, f"{escape}{char}")
        return value, True

    def filter(self, **kwargs):
        """Return a new queryset with extra WHERE conditions.

        Keyword names take the form ``field``, ``field__<op>`` or
        ``rel__...__field[__<op>]``. ``<op>`` is a key of ``FILTER_OPERATORS``
        and defaults to ``exact``. Related paths are added to
        ``select_related`` automatically. ``pk`` is accepted as an alias for
        the model's primary-key field. Model instances given as values are
        compared by their primary key.
        """
        # Copy both lists so the receiving queryset is never mutated. Appending
        # to ``self.filter_clauses`` in place leaked filters into every
        # queryset sharing that list.
        filter_clauses = list(self.filter_clauses)
        select_related = list(self._select_related)

        # Test for presence rather than truthiness so pk=0 / pk=None work.
        if "pk" in kwargs:
            pk_name = self.model_cls.__pkname__
            kwargs[pk_name] = kwargs.pop("pk")

        for key, value in kwargs.items():
            if "__" in key:
                parts = key.split("__")

                # Decide whether the final part is a filter operator or
                # the name of a related field.
                if parts[-1] in FILTER_OPERATORS:
                    op = parts[-1]
                    field_name = parts[-2]
                    related_parts = parts[:-2]
                else:
                    op = "exact"
                    field_name = parts[-1]
                    related_parts = parts[:-1]

                model_cls = self.model_cls
                if related_parts:
                    # Add any implied select_related.
                    related_str = "__".join(related_parts)
                    if related_str not in select_related:
                        select_related.append(related_str)

                    # Walk the relationships to the model class whose
                    # column is being compared.
                    for part in related_parts:
                        model_cls = model_cls.__model_fields__[part].to

                column = model_cls.__table__.columns[field_name]
            else:
                op = "exact"
                column = self.table.columns[key]

            # Map the operation code onto SQLAlchemy's ColumnElement
            # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
            op_attr = FILTER_OPERATORS[op]

            has_escaped_character = False
            if op in ("contains", "icontains", "iexact") and isinstance(value, str):
                # These ops compile to (I)LIKE, so wildcard characters in the
                # user's value must be escaped to be matched literally.
                value, has_escaped_character = self._escape_like(value)
            if op in ("contains", "icontains"):
                value = f"%{value}%"

            if isinstance(value, Model):
                value = value.pk

            clause = getattr(column, op_attr)(value)
            clause.modifiers['escape'] = self._LIKE_ESCAPE if has_escaped_character else None
            filter_clauses.append(clause)

        return self.__class__(
            model_cls=self.model_cls,
            filter_clauses=filter_clauses,
            select_related=select_related,
            limit_count=self.limit_count,
            offset=self.query_offset
        )

    def select_related(self, related):
        """Return a new queryset that also joins and loads ``related``.

        ``related`` is one ``"a__b"``-style path or a list/tuple of them.
        """
        if not isinstance(related, (list, tuple)):
            related = [related]
        # ``list(related)`` so a tuple argument can be concatenated.
        related = list(self._select_related) + list(related)
        return self.__class__(
            model_cls=self.model_cls,
            filter_clauses=self.filter_clauses,
            select_related=related,
            limit_count=self.limit_count,
            offset=self.query_offset
        )

    # async def exists(self) -> bool:
    #     expr = self.build_select_expression()
    #     expr = sqlalchemy.exists(expr).select()
    #     return await self.database.fetch_val(expr)

    def limit(self, limit_count: int):
        """Return a new queryset restricted to at most ``limit_count`` rows."""
        return self.__class__(
            model_cls=self.model_cls,
            filter_clauses=self.filter_clauses,
            select_related=self._select_related,
            limit_count=limit_count,
            offset=self.query_offset
        )

    def offset(self, offset: int):
        """Return a new queryset that skips the first ``offset`` rows."""
        return self.__class__(
            model_cls=self.model_cls,
            filter_clauses=self.filter_clauses,
            select_related=self._select_related,
            limit_count=self.limit_count,
            offset=offset
        )

    async def get(self, **kwargs):
        """Fetch exactly one instance, optionally filtering by ``kwargs`` first.

        Raises:
            NoMatch: no row matched.
            MultipleMatches: more than one row matched.
        """
        if kwargs:
            return await self.filter(**kwargs).get()

        # Two rows are enough to detect ambiguity.
        expr = self.build_select_expression().limit(2)
        rows = await self.database.fetch_all(expr)

        if not rows:
            raise NoMatch()
        if len(rows) > 1:
            raise MultipleMatches()
        return self.model_cls.from_row(rows[0], select_related=self._select_related)

    async def all(self, **kwargs):
        """Fetch all matching instances, optionally filtering by ``kwargs`` first."""
        if kwargs:
            return await self.filter(**kwargs).all()

        expr = self.build_select_expression()
        rows = await self.database.fetch_all(expr)
        return [
            self.model_cls.from_row(row, select_related=self._select_related)
            for row in rows
        ]

    async def create(self, **kwargs):
        """Insert a row built from ``kwargs`` and return the new instance.

        The instance's ``pk`` is set to the value returned by
        ``database.execute`` for the INSERT.
        """
        new_kwargs = dict(kwargs)

        # Remove primary key when None to prevent not null constraint in postgresql.
        pkname = self.model_cls.__pkname__
        pk = self.model_cls.__model_fields__[pkname]
        if pkname in new_kwargs and new_kwargs[pkname] is None and (pk.nullable or pk.autoincrement):
            del new_kwargs[pkname]

        # Substitute related model instances with their primary key; values
        # that are already raw keys are passed through unchanged.
        for field in self.model_cls.extract_related_names():
            related = new_kwargs.get(field)
            if isinstance(related, Model):
                related_pkname = self.model_cls.__model_fields__[field].to.__pkname__
                new_kwargs[field] = getattr(related, related_pkname)

        # Build the insert expression.
        expr = self.table.insert()
        expr = expr.values(**new_kwargs)

        # Execute the insert, and return a new model instance.
        instance = self.model_cls(**kwargs)
        instance.pk = await self.database.execute(expr)
        return instance


class ModelMetaclass(type):
    """Metaclass that turns a declarative ``Model`` subclass into a mapped model.

    Classes whose body sets ``__abstract__`` are created unchanged. For every
    other class it reads ``__tablename__`` and ``__metadata__`` from the class
    body, collects the ``BaseField`` attributes, and attaches:

    * ``__table__`` / ``__columns__``: the SQLAlchemy table and its columns.
      Fields flagged ``pydantic_only`` get no column.
    * ``__pkname__``: the name of the field flagged ``primary_key``.
    * ``__pydantic_fields__``, ``__pydantic_model__`` and deep copies of the
      pydantic model's ``__fields__``, ``__signature__`` and
      ``__annotations__``.
    * ``__model_fields__``: the ``BaseField`` objects by name.
    """

    def __new__(
            mcs: type,
            name: str,
            bases: Any,
            attrs: dict
    ) -> type:
        # Abstract bases (e.g. ``Model``) get no table. Checking this first
        # means each concrete class is built exactly once.
        if attrs.get("__abstract__"):
            return super().__new__(  # type: ignore
                mcs, name, bases, attrs
            )

        tablename = attrs["__tablename__"]
        metadata = attrs["__metadata__"]
        pkname = None

        columns = []
        model_fields = {}
        for field_name, field in attrs.items():
            if isinstance(field, BaseField):
                model_fields[field_name] = field
                if not field.pydantic_only:
                    if field.primary_key:
                        pkname = field_name
                    columns.append(field.get_column(field_name))

        # Validate before building the Table. sqlalchemy.Table registers itself
        # on ``metadata``, so raising afterwards left a stray table behind.
        if not pkname:
            raise ModelDefinitionError('Table has to have a primary key.')

        # sqlalchemy table creation
        attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
        attrs['__columns__'] = columns
        attrs['__pkname__'] = pkname

        # pydantic model creation
        pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
        config = type('Config', (BaseConfig,), {'orm_mode': True})
        pydantic_model = create_model(name, __config__=config, **pydantic_fields)
        attrs['__pydantic_fields__'] = pydantic_fields
        attrs['__pydantic_model__'] = pydantic_model

        attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__)
        attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__)
        attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)

        attrs['__model_fields__'] = model_fields

        return super().__new__(  # type: ignore
            mcs, name, bases, attrs
        )


class Model(metaclass=ModelMetaclass):
    """Base class for ORM models.

    Field values live on a pydantic model instance stored in ``self.values``.
    Attribute access for declared fields is routed through it by
    ``__getattribute__`` / ``__setattr__``. ``objects`` gives a ``QuerySet``
    bound to the concrete subclass.
    """

    __abstract__ = True

    objects = QuerySet()

    def __init__(self, **kwargs) -> None:
        self._orm_id = uuid.uuid4().hex
        self._orm_saved = False
        self._orm_relationship_manager = RelationshipManager(self)
        self._orm_observers = []

        if "pk" in kwargs:
            kwargs[self.__pkname__] = kwargs.pop("pk")
        # Each value passes through its field's ``expand_relationship``
        # (defined in orm.fields; presumably resolves related-model input).
        kwargs = {k: self.__model_fields__[k].expand_relationship(v, self)
                  for k, v in kwargs.items()}
        self.values = self.__pydantic_model__(**kwargs)

    def __setattr__(self, key: str, value: Any) -> None:
        # Declared fields are stored on ``self.values``; everything else is a
        # normal instance attribute.
        if key in self.__fields__:
            if self.is_conversion_to_json_needed(key) and not isinstance(value, str):
                try:
                    value = json.dumps(value)
                except TypeError:  # pragma no cover
                    # Not JSON-serialisable: store the value as-is.
                    pass

            value = self.__model_fields__[key].expand_relationship(value, self)

            setattr(self.values, key, value)
        else:
            super().__setattr__(key, value)

    def __getattribute__(self, key: str) -> Any:
        # ``__fields__`` is excluded to avoid recursing while looking it up.
        if key != '__fields__' and key in self.__fields__:
            if key in self._orm_relationship_manager:
                return self._orm_relationship_manager.get(key)

            item = getattr(self.values, key, None)
            if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
                try:
                    item = json.loads(item)
                except (TypeError, ValueError):  # pragma no cover
                    # json.loads raises JSONDecodeError (a ValueError) on
                    # malformed input; fall back to the raw string.
                    pass
            return item
        return super().__getattribute__(key)

    def __repr__(self):  # pragma no cover
        return self.values.__repr__()

    # def attach(self, observer: 'Model'):
    #     if all([obs._orm_id != observer._orm_id for obs in self._orm_observers]):
    #         self._orm_observers.append(observer)
    #
    # def detach(self, observer: 'Model'):
    #     for ind, obs in enumerate(self._orm_observers):
    #         if obs._orm_id == observer._orm_id:
    #             del self._orm_observers[ind]
    #             break
    #

    def notify(self):
        """Call ``orm_update(self)`` on every registered observer."""
        for obs in self._orm_observers:  # pragma no cover
            obs.orm_update(self)

    def orm_update(self, subject: 'Model') -> None:  # pragma no cover
        print('should be updated here')

    @classmethod
    def from_row(cls, row, select_related: List = None) -> 'Model':
        """Build an instance from a result row, nesting ``select_related`` models.

        Paths sharing a first segment (e.g. ``"a"`` and ``"a__b"``) are merged,
        so the nested ``a`` instance is built once with everything requested
        beneath it.
        """
        item = {}

        # Group requested paths by their first segment: {"a": ["b", ...]}.
        nested: Dict[str, List[str]] = {}
        for related in select_related or []:
            first_part, _, remainder = related.partition("__")
            remainders = nested.setdefault(first_part, [])
            if remainder and remainder not in remainders:
                remainders.append(remainder)

        for first_part, remainders in nested.items():
            model_cls = cls.__model_fields__[first_part].to
            item[first_part] = model_cls.from_row(row, select_related=remainders)

        for column in cls.__table__.columns:
            if column.name not in item:
                item[column.name] = row[column]

        return cls(**item)

    def is_conversion_to_json_needed(self, column_name: str) -> bool:
        """Return True when the field's declared type is ``pydantic.Json``."""
        return self.__model_fields__.get(column_name).__type__ == pydantic.Json

    @property
    def pk(self):
        return getattr(self.values, self.__pkname__)

    @pk.setter
    def pk(self, value):
        setattr(self.values, self.__pkname__, value)

    @property
    def pk_column(self) -> sqlalchemy.Column:
        return self.__table__.primary_key.columns.values()[0]

    def dict(self) -> Dict:
        return self.values.dict()

    def from_dict(self, value_dict: Dict) -> None:
        """Assign each key/value pair via ``setattr`` (so field handling applies)."""
        for key, value in value_dict.items():
            setattr(self, key, value)

    def extract_own_model_fields(self) -> Dict:
        """Return ``self.dict()`` without the related-model fields."""
        related_names = self.extract_related_names()
        self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
        return self_fields

    @classmethod
    def extract_related_names(cls) -> Set:
        """Names of fields whose pydantic type is a ``pydantic.BaseModel`` subclass."""
        related_names = set()
        for name, field in cls.__fields__.items():
            if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
                related_names.add(name)
            # elif field.sub_fields and any(
            #         [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]):
            #     related_names.add(name)
        return related_names

    def extract_model_db_fields(self) -> Dict:
        """Column values to write: own table columns plus related models' primary keys."""
        self_fields = self.extract_own_model_fields()
        self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns}
        for field in self.extract_related_names():
            if getattr(self, field) is not None:
                self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__)
        return self_fields

    async def save(self) -> int:
        """INSERT this instance and store the returned value as its ``pk``.

        When the pk field is autoincrement its value is left out of the INSERT.
        Returns the value from ``database.execute``.
        """
        self_fields = self.extract_model_db_fields()
        if self.__model_fields__.get(self.__pkname__).autoincrement:
            self_fields.pop(self.__pkname__, None)
        expr = self.__table__.insert()
        expr = expr.values(**self_fields)
        item_id = await self.__database__.execute(expr)
        setattr(self, 'pk', item_id)
        self.notify()
        return item_id

    async def update(self, **kwargs: Any) -> int:
        """Apply ``kwargs`` (if any), then UPDATE this row by primary key.

        Every column value except the pk is written. Returns the value from
        ``database.execute``.
        """
        if kwargs:
            new_values = {**self.dict(), **kwargs}
            self.from_dict(new_values)

        self_fields = self.extract_model_db_fields()
        self_fields.pop(self.__pkname__)
        expr = self.__table__.update().values(**self_fields).where(
            self.pk_column == getattr(self, self.__pkname__))
        result = await self.__database__.execute(expr)
        self.notify()
        return result

    async def delete(self) -> int:
        """DELETE the row matching this instance's primary key."""
        expr = self.__table__.delete()
        expr = expr.where(self.pk_column == (getattr(self, self.__pkname__)))
        result = await self.__database__.execute(expr)
        self.notify()
        return result

    async def load(self) -> 'Model':
        """Refresh this instance's fields from the row with its primary key.

        Raises:
            NoMatch: no row has this instance's primary key.
        """
        expr = self.__table__.select().where(self.pk_column == self.pk)
        row = await self.__database__.fetch_one(expr)
        if row is None:
            # Previously ``dict(None)`` raised an opaque TypeError here.
            raise NoMatch()
        self.from_dict(dict(row))
        self.notify()
        return self