liniting and applying black

This commit is contained in:
collerek
2020-08-09 07:51:06 +02:00
parent 9d9346fb13
commit 241628b1d9
9 changed files with 455 additions and 247 deletions

View File

@ -2,35 +2,39 @@ import copy
import inspect
import json
import uuid
from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple
from typing import Set, Dict
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import Callable, Dict, Set
import databases
import pydantic
import sqlalchemy
from pydantic import BaseModel, BaseConfig, create_model
import orm.queryset as qry
from orm.exceptions import ModelDefinitionError
from orm.fields import BaseField, ForeignKey
from orm.relations import RelationshipManager
import pydantic
from pydantic import BaseConfig, BaseModel, create_model
import sqlalchemy
relationship_manager = RelationshipManager()
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
pydantic_fields = {field_name: (
base_field.__type__,
... if base_field.is_required else base_field.default_value
)
pydantic_fields = {
field_name: (
base_field.__type__,
... if base_field.is_required else base_field.default_value,
)
for field_name, base_field in object_dict.items()
if isinstance(base_field, BaseField)}
if isinstance(base_field, BaseField)
}
return pydantic_fields
def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str],
List[sqlalchemy.Column],
Dict[str, BaseField]]:
def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, tablename: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
pkname: Optional[str] = None
columns: List[sqlalchemy.Column] = []
model_fields: Dict[str, BaseField] = {}
@ -42,9 +46,16 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename
if field.primary_key:
pkname = field_name
if isinstance(field, ForeignKey):
reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's'
relation_name = name.lower().title() + '_' + field.to.__name__.lower()
relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename)
reverse_name = (
field.related_name
or field.to.__name__.lower().title() + "_" + name.lower() + "s"
)
relation_name = (
name.lower().title() + "_" + field.to.__name__.lower()
)
relationship_manager.add_relation_type(
relation_name, reverse_name, field, tablename
)
columns.append(field.get_column(field_name))
return pkname, columns, model_fields
@ -57,9 +68,7 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class ModelMetaclass(type):
def __new__(
mcs: type, name: str, bases: Any, attrs: dict
) -> type:
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
@ -71,25 +80,29 @@ class ModelMetaclass(type):
metadata = attrs["__metadata__"]
# sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename)
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
attrs['__columns__'] = columns
attrs['__pkname__'] = pkname
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
name, attrs, tablename
)
attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns)
attrs["__columns__"] = columns
attrs["__pkname__"] = pkname
if not pkname:
raise ModelDefinitionError('Table has to have a primary key.')
raise ModelDefinitionError("Table has to have a primary key.")
# pydantic model creation
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields)
attrs['__pydantic_fields__'] = pydantic_fields
attrs['__pydantic_model__'] = pydantic_model
attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__)
attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__)
attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)
attrs['__model_fields__'] = model_fields
pydantic_model = create_model(
name, __config__=get_pydantic_base_orm_config(), **pydantic_fields
)
attrs["__pydantic_fields__"] = pydantic_fields
attrs["__pydantic_model__"] = pydantic_model
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
attrs["__model_fields__"] = model_fields
attrs['_orm_relationship_manager'] = relationship_manager
attrs["_orm_relationship_manager"] = relationship_manager
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
@ -99,7 +112,8 @@ class ModelMetaclass(type):
class Model(list, metaclass=ModelMetaclass):
# Model inherits from list in order to be treated as request.Body parameter in fastapi routes,
# Model inherits from list in order to be treated as
# request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts
__abstract__ = True
if TYPE_CHECKING: # pragma no cover
@ -115,17 +129,20 @@ class Model(list, metaclass=ModelMetaclass):
objects = qry.QuerySet()
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False
self.values: Optional[BaseModel] = None
if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk")
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
kwargs = {
k: self.__model_fields__[k].expand_relationship(v, self)
for k, v in kwargs.items()
}
self.values = self.__pydantic_model__(**kwargs)
def __del__(self):
def __del__(self) -> None:
self._orm_relationship_manager.deregister(self)
def __setattr__(self, key: str, value: Any) -> None:
@ -138,20 +155,24 @@ class Model(list, metaclass=ModelMetaclass):
value = self.__model_fields__[key].expand_relationship(value, self)
relation_key = self.__class__.__name__.title() + '_' + key
relation_key = self.__class__.__name__.title() + "_" + key
if not self._orm_relationship_manager.contains(relation_key, self):
setattr(self.values, key, value)
else:
super().__setattr__(key, value)
def __getattribute__(self, key: str) -> Any:
if key != '__fields__' and key in self.__fields__:
relation_key = self.__class__.__name__.title() + '_' + key
if key != "__fields__" and key in self.__fields__:
relation_key = self.__class__.__name__.title() + "_" + key
if self._orm_relationship_manager.contains(relation_key, self):
return self._orm_relationship_manager.get(relation_key, self)
item = getattr(self.values, key, None)
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
if (
item is not None
and self.is_conversion_to_json_needed(key)
and isinstance(item, str)
):
try:
item = json.loads(item)
except TypeError: # pragma no cover
@ -159,30 +180,41 @@ class Model(list, metaclass=ModelMetaclass):
return item
return super().__getattribute__(key)
def __eq__(self, other):
def __eq__(self, other: "Model") -> bool:
return self.values.dict() == other.values.dict()
def __same__(self, other):
assert self.__class__ == other.__class__
def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__:
return False
return self._orm_id == other._orm_id or (
self.values is not None and other.values is not None and self.pk == other.pk)
self.values is not None and other.values is not None and self.pk == other.pk
)
def __repr__(self): # pragma no cover
def __repr__(self) -> str: # pragma no cover
return self.values.__repr__()
@classmethod
def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model':
def from_row(
cls,
row: sqlalchemy.engine.ResultProxy,
select_related: List = None,
previous_table: str = None,
) -> "Model":
item = {}
select_related = select_related or []
table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name)
table_prefix = cls._orm_relationship_manager.resolve_relation_join(
previous_table, cls.__table__.name
)
previous_table = cls.__table__.name
for related in select_related:
if "__" in related:
first_part, remainder = related.split("__", 1)
model_cls = cls.__model_fields__[first_part].to
child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table)
child = model_cls.from_row(
row, select_related=[remainder], previous_table=previous_table
)
item[first_part] = child
else:
model_cls = cls.__model_fields__[related].to
@ -191,7 +223,9 @@ class Model(list, metaclass=ModelMetaclass):
for column in cls.__table__.columns:
if column.name not in item:
item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}']
item[column.name] = row[
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
]
return cls(**item)
@ -200,7 +234,7 @@ class Model(list, metaclass=ModelMetaclass):
# return cls.__pydantic_model__.validate(value=value)
@classmethod
def __get_validators__(cls): # pragma no cover
def __get_validators__(cls) -> Callable: # pragma no cover
yield cls.__pydantic_model__.validate
# @classmethod
@ -211,11 +245,11 @@ class Model(list, metaclass=ModelMetaclass):
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
@property
def pk(self):
def pk(self) -> str:
return getattr(self.values, self.__pkname__)
@pk.setter
def pk(self, value):
def pk(self, value: Any) -> None:
setattr(self.values, self.__pkname__, value)
@property
@ -229,7 +263,9 @@ class Model(list, metaclass=ModelMetaclass):
if isinstance(nested_model, list):
dict_instance[field] = [x.dict() for x in nested_model]
else:
dict_instance[field] = nested_model.dict() if nested_model is not None else {}
dict_instance[field] = (
nested_model.dict() if nested_model is not None else {}
)
return dict_instance
def from_dict(self, value_dict: Dict) -> None:
@ -245,16 +281,22 @@ class Model(list, metaclass=ModelMetaclass):
def extract_related_names(cls) -> Set:
related_names = set()
for name, field in cls.__fields__.items():
if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
if inspect.isclass(field.type_) and issubclass(
field.type_, pydantic.BaseModel
):
related_names.add(name)
return related_names
def extract_model_db_fields(self) -> Dict:
self_fields = self.extract_own_model_fields()
self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns}
self_fields = {
k: v for k, v in self_fields.items() if k in self.__table__.columns
}
for field in self.extract_related_names():
if getattr(self, field) is not None:
self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__)
self_fields[field] = getattr(
getattr(self, field), self.__model_fields__[field].to.__pkname__
)
return self_fields
async def save(self) -> int:
@ -264,7 +306,7 @@ class Model(list, metaclass=ModelMetaclass):
expr = self.__table__.insert()
expr = expr.values(**self_fields)
item_id = await self.__database__.execute(expr)
setattr(self, 'pk', item_id)
self.pk = item_id
return item_id
async def update(self, **kwargs: Any) -> int:
@ -274,8 +316,11 @@ class Model(list, metaclass=ModelMetaclass):
self_fields = self.extract_model_db_fields()
self_fields.pop(self.__pkname__)
expr = self.__table__.update().values(**self_fields).where(
self.pk_column == getattr(self, self.__pkname__))
expr = (
self.__table__.update()
.values(**self_fields)
.where(self.pk_column == getattr(self, self.__pkname__))
)
result = await self.__database__.execute(expr)
return result
@ -285,7 +330,7 @@ class Model(list, metaclass=ModelMetaclass):
result = await self.__database__.execute(expr)
return result
async def load(self) -> 'Model':
async def load(self) -> "Model":
expr = self.__table__.select().where(self.pk_column == self.pk)
row = await self.__database__.fetch_one(expr)
self.from_dict(dict(row))