liniting and applying black
This commit is contained in:
165
orm/models.py
165
orm/models.py
@ -2,35 +2,39 @@ import copy
|
||||
import inspect
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple
|
||||
from typing import Set, Dict
|
||||
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
|
||||
from typing import Callable, Dict, Set
|
||||
|
||||
import databases
|
||||
import pydantic
|
||||
import sqlalchemy
|
||||
from pydantic import BaseModel, BaseConfig, create_model
|
||||
|
||||
import orm.queryset as qry
|
||||
from orm.exceptions import ModelDefinitionError
|
||||
from orm.fields import BaseField, ForeignKey
|
||||
from orm.relations import RelationshipManager
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseConfig, BaseModel, create_model
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
relationship_manager = RelationshipManager()
|
||||
|
||||
|
||||
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
|
||||
pydantic_fields = {field_name: (
|
||||
base_field.__type__,
|
||||
... if base_field.is_required else base_field.default_value
|
||||
)
|
||||
pydantic_fields = {
|
||||
field_name: (
|
||||
base_field.__type__,
|
||||
... if base_field.is_required else base_field.default_value,
|
||||
)
|
||||
for field_name, base_field in object_dict.items()
|
||||
if isinstance(base_field, BaseField)}
|
||||
if isinstance(base_field, BaseField)
|
||||
}
|
||||
return pydantic_fields
|
||||
|
||||
|
||||
def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str],
|
||||
List[sqlalchemy.Column],
|
||||
Dict[str, BaseField]]:
|
||||
def sqlalchemy_columns_from_model_fields(
|
||||
name: str, object_dict: Dict, tablename: str
|
||||
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
|
||||
pkname: Optional[str] = None
|
||||
columns: List[sqlalchemy.Column] = []
|
||||
model_fields: Dict[str, BaseField] = {}
|
||||
@ -42,9 +46,16 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename
|
||||
if field.primary_key:
|
||||
pkname = field_name
|
||||
if isinstance(field, ForeignKey):
|
||||
reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's'
|
||||
relation_name = name.lower().title() + '_' + field.to.__name__.lower()
|
||||
relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename)
|
||||
reverse_name = (
|
||||
field.related_name
|
||||
or field.to.__name__.lower().title() + "_" + name.lower() + "s"
|
||||
)
|
||||
relation_name = (
|
||||
name.lower().title() + "_" + field.to.__name__.lower()
|
||||
)
|
||||
relationship_manager.add_relation_type(
|
||||
relation_name, reverse_name, field, tablename
|
||||
)
|
||||
columns.append(field.get_column(field_name))
|
||||
return pkname, columns, model_fields
|
||||
|
||||
@ -57,9 +68,7 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
|
||||
|
||||
|
||||
class ModelMetaclass(type):
|
||||
def __new__(
|
||||
mcs: type, name: str, bases: Any, attrs: dict
|
||||
) -> type:
|
||||
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
||||
new_model = super().__new__( # type: ignore
|
||||
mcs, name, bases, attrs
|
||||
)
|
||||
@ -71,25 +80,29 @@ class ModelMetaclass(type):
|
||||
metadata = attrs["__metadata__"]
|
||||
|
||||
# sqlalchemy table creation
|
||||
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename)
|
||||
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
|
||||
attrs['__columns__'] = columns
|
||||
attrs['__pkname__'] = pkname
|
||||
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
|
||||
name, attrs, tablename
|
||||
)
|
||||
attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns)
|
||||
attrs["__columns__"] = columns
|
||||
attrs["__pkname__"] = pkname
|
||||
|
||||
if not pkname:
|
||||
raise ModelDefinitionError('Table has to have a primary key.')
|
||||
raise ModelDefinitionError("Table has to have a primary key.")
|
||||
|
||||
# pydantic model creation
|
||||
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
|
||||
pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields)
|
||||
attrs['__pydantic_fields__'] = pydantic_fields
|
||||
attrs['__pydantic_model__'] = pydantic_model
|
||||
attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__)
|
||||
attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__)
|
||||
attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)
|
||||
attrs['__model_fields__'] = model_fields
|
||||
pydantic_model = create_model(
|
||||
name, __config__=get_pydantic_base_orm_config(), **pydantic_fields
|
||||
)
|
||||
attrs["__pydantic_fields__"] = pydantic_fields
|
||||
attrs["__pydantic_model__"] = pydantic_model
|
||||
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
|
||||
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
|
||||
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
|
||||
attrs["__model_fields__"] = model_fields
|
||||
|
||||
attrs['_orm_relationship_manager'] = relationship_manager
|
||||
attrs["_orm_relationship_manager"] = relationship_manager
|
||||
|
||||
new_model = super().__new__( # type: ignore
|
||||
mcs, name, bases, attrs
|
||||
@ -99,7 +112,8 @@ class ModelMetaclass(type):
|
||||
|
||||
|
||||
class Model(list, metaclass=ModelMetaclass):
|
||||
# Model inherits from list in order to be treated as request.Body parameter in fastapi routes,
|
||||
# Model inherits from list in order to be treated as
|
||||
# request.Body parameter in fastapi routes,
|
||||
# inheriting from pydantic.BaseModel causes metaclass conflicts
|
||||
__abstract__ = True
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
@ -115,17 +129,20 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
|
||||
objects = qry.QuerySet()
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._orm_id: str = uuid.uuid4().hex
|
||||
self._orm_saved: bool = False
|
||||
self.values: Optional[BaseModel] = None
|
||||
|
||||
if "pk" in kwargs:
|
||||
kwargs[self.__pkname__] = kwargs.pop("pk")
|
||||
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
|
||||
kwargs = {
|
||||
k: self.__model_fields__[k].expand_relationship(v, self)
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
self.values = self.__pydantic_model__(**kwargs)
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
self._orm_relationship_manager.deregister(self)
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
@ -138,20 +155,24 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
|
||||
value = self.__model_fields__[key].expand_relationship(value, self)
|
||||
|
||||
relation_key = self.__class__.__name__.title() + '_' + key
|
||||
relation_key = self.__class__.__name__.title() + "_" + key
|
||||
if not self._orm_relationship_manager.contains(relation_key, self):
|
||||
setattr(self.values, key, value)
|
||||
else:
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __getattribute__(self, key: str) -> Any:
|
||||
if key != '__fields__' and key in self.__fields__:
|
||||
relation_key = self.__class__.__name__.title() + '_' + key
|
||||
if key != "__fields__" and key in self.__fields__:
|
||||
relation_key = self.__class__.__name__.title() + "_" + key
|
||||
if self._orm_relationship_manager.contains(relation_key, self):
|
||||
return self._orm_relationship_manager.get(relation_key, self)
|
||||
|
||||
item = getattr(self.values, key, None)
|
||||
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
|
||||
if (
|
||||
item is not None
|
||||
and self.is_conversion_to_json_needed(key)
|
||||
and isinstance(item, str)
|
||||
):
|
||||
try:
|
||||
item = json.loads(item)
|
||||
except TypeError: # pragma no cover
|
||||
@ -159,30 +180,41 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
return item
|
||||
return super().__getattribute__(key)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: "Model") -> bool:
|
||||
return self.values.dict() == other.values.dict()
|
||||
|
||||
def __same__(self, other):
|
||||
assert self.__class__ == other.__class__
|
||||
def __same__(self, other: "Model") -> bool:
|
||||
if self.__class__ != other.__class__:
|
||||
return False
|
||||
return self._orm_id == other._orm_id or (
|
||||
self.values is not None and other.values is not None and self.pk == other.pk)
|
||||
self.values is not None and other.values is not None and self.pk == other.pk
|
||||
)
|
||||
|
||||
def __repr__(self): # pragma no cover
|
||||
def __repr__(self) -> str: # pragma no cover
|
||||
return self.values.__repr__()
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model':
|
||||
def from_row(
|
||||
cls,
|
||||
row: sqlalchemy.engine.ResultProxy,
|
||||
select_related: List = None,
|
||||
previous_table: str = None,
|
||||
) -> "Model":
|
||||
|
||||
item = {}
|
||||
select_related = select_related or []
|
||||
|
||||
table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name)
|
||||
table_prefix = cls._orm_relationship_manager.resolve_relation_join(
|
||||
previous_table, cls.__table__.name
|
||||
)
|
||||
previous_table = cls.__table__.name
|
||||
for related in select_related:
|
||||
if "__" in related:
|
||||
first_part, remainder = related.split("__", 1)
|
||||
model_cls = cls.__model_fields__[first_part].to
|
||||
child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table)
|
||||
child = model_cls.from_row(
|
||||
row, select_related=[remainder], previous_table=previous_table
|
||||
)
|
||||
item[first_part] = child
|
||||
else:
|
||||
model_cls = cls.__model_fields__[related].to
|
||||
@ -191,7 +223,9 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
|
||||
for column in cls.__table__.columns:
|
||||
if column.name not in item:
|
||||
item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}']
|
||||
item[column.name] = row[
|
||||
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
|
||||
]
|
||||
|
||||
return cls(**item)
|
||||
|
||||
@ -200,7 +234,7 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
# return cls.__pydantic_model__.validate(value=value)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls): # pragma no cover
|
||||
def __get_validators__(cls) -> Callable: # pragma no cover
|
||||
yield cls.__pydantic_model__.validate
|
||||
|
||||
# @classmethod
|
||||
@ -211,11 +245,11 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
|
||||
|
||||
@property
|
||||
def pk(self):
|
||||
def pk(self) -> str:
|
||||
return getattr(self.values, self.__pkname__)
|
||||
|
||||
@pk.setter
|
||||
def pk(self, value):
|
||||
def pk(self, value: Any) -> None:
|
||||
setattr(self.values, self.__pkname__, value)
|
||||
|
||||
@property
|
||||
@ -229,7 +263,9 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
if isinstance(nested_model, list):
|
||||
dict_instance[field] = [x.dict() for x in nested_model]
|
||||
else:
|
||||
dict_instance[field] = nested_model.dict() if nested_model is not None else {}
|
||||
dict_instance[field] = (
|
||||
nested_model.dict() if nested_model is not None else {}
|
||||
)
|
||||
return dict_instance
|
||||
|
||||
def from_dict(self, value_dict: Dict) -> None:
|
||||
@ -245,16 +281,22 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
def extract_related_names(cls) -> Set:
|
||||
related_names = set()
|
||||
for name, field in cls.__fields__.items():
|
||||
if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
|
||||
if inspect.isclass(field.type_) and issubclass(
|
||||
field.type_, pydantic.BaseModel
|
||||
):
|
||||
related_names.add(name)
|
||||
return related_names
|
||||
|
||||
def extract_model_db_fields(self) -> Dict:
|
||||
self_fields = self.extract_own_model_fields()
|
||||
self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns}
|
||||
self_fields = {
|
||||
k: v for k, v in self_fields.items() if k in self.__table__.columns
|
||||
}
|
||||
for field in self.extract_related_names():
|
||||
if getattr(self, field) is not None:
|
||||
self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__)
|
||||
self_fields[field] = getattr(
|
||||
getattr(self, field), self.__model_fields__[field].to.__pkname__
|
||||
)
|
||||
return self_fields
|
||||
|
||||
async def save(self) -> int:
|
||||
@ -264,7 +306,7 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
expr = self.__table__.insert()
|
||||
expr = expr.values(**self_fields)
|
||||
item_id = await self.__database__.execute(expr)
|
||||
setattr(self, 'pk', item_id)
|
||||
self.pk = item_id
|
||||
return item_id
|
||||
|
||||
async def update(self, **kwargs: Any) -> int:
|
||||
@ -274,8 +316,11 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
|
||||
self_fields = self.extract_model_db_fields()
|
||||
self_fields.pop(self.__pkname__)
|
||||
expr = self.__table__.update().values(**self_fields).where(
|
||||
self.pk_column == getattr(self, self.__pkname__))
|
||||
expr = (
|
||||
self.__table__.update()
|
||||
.values(**self_fields)
|
||||
.where(self.pk_column == getattr(self, self.__pkname__))
|
||||
)
|
||||
result = await self.__database__.execute(expr)
|
||||
return result
|
||||
|
||||
@ -285,7 +330,7 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
result = await self.__database__.execute(expr)
|
||||
return result
|
||||
|
||||
async def load(self) -> 'Model':
|
||||
async def load(self) -> "Model":
|
||||
expr = self.__table__.select().where(self.pk_column == self.pk)
|
||||
row = await self.__database__.fetch_one(expr)
|
||||
self.from_dict(dict(row))
|
||||
|
||||
Reference in New Issue
Block a user