# ormar/orm/models/fakepydantic.py
# Snapshot: 2020-08-12 10:33:37 +02:00 - 210 lines, 7.2 KiB (Python)

import inspect
import json
import uuid
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Set,
    TYPE_CHECKING,
    Type,
    TypeVar,
    Union,
)

import databases
import pydantic
import sqlalchemy
from pydantic import BaseModel

import orm  # noqa I100
from orm.fields import BaseField
from orm.models.metaclass import ModelMetaclass
from orm.relations import RelationshipManager

if TYPE_CHECKING:  # pragma no cover
    from orm.models.model import Model
class FakePydantic(list, metaclass=ModelMetaclass):
    """Base class for ORM models backed by a generated pydantic model.

    Field values are held by an instance of ``__pydantic_model__`` stored in
    ``self.values``. Attribute access for declared fields is routed either to
    the relationship manager (for registered relations) or to ``self.values``.
    """

    # FakePydantic inherits from list in order to be treated as
    # request.Body parameter in fastapi routes,
    # inheriting from pydantic.BaseModel causes metaclass conflicts
    __abstract__ = True
    if TYPE_CHECKING:  # pragma no cover
        # field definitions keyed by field name
        __model_fields__: Dict[str, BaseField]
        __table__: sqlalchemy.Table
        __fields__: Dict[str, pydantic.fields.ModelField]
        __pydantic_model__: Type[BaseModel]
        __pkname__: str
        __tablename__: str
        __metadata__: sqlalchemy.MetaData
        __database__: databases.Database
        _orm_relationship_manager: RelationshipManager

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the underlying pydantic model from keyword arguments.

        ``pk`` is accepted as an alias for the primary-key field name. Each
        value is JSON-encoded when its field is ``pydantic.Json`` (the same
        conversion ``__setattr__`` applies) and then passed through the
        field's ``expand_relationship`` before pydantic validation.

        Raises:
            KeyError: if a keyword does not name a declared field.
        """
        super().__init__()
        self._orm_id: str = uuid.uuid4().hex
        self._orm_saved: bool = False
        self.values: Optional[BaseModel] = None

        if "pk" in kwargs:
            kwargs[self.__pkname__] = kwargs.pop("pk")

        # Mirror __setattr__: JSON-encode first so dict/list values for Json
        # fields validate, then let the field expand relationships.
        kwargs = {
            k: self.__model_fields__[k].expand_relationship(
                self._convert_json(k, v, op="dumps"), self
            )
            for k, v in kwargs.items()
        }
        self.values = self.__pydantic_model__(**kwargs)

    def __del__(self) -> None:
        """Deregister this instance from the relationship manager."""
        self._orm_relationship_manager.deregister(self)

    def __setattr__(self, key: str, value: Any) -> None:
        """Write declared fields to ``self.values``; other names normally.

        For declared fields the value is JSON-encoded (Json fields only) and
        passed through ``expand_relationship``. It is written to
        ``self.values`` only when the relationship manager holds no entry for
        ``<ModelName>_<key>``.
        """
        if key not in self.__fields__:
            super().__setattr__(key, value)
            return

        value = self._convert_json(key, value, op="dumps")
        value = self.__model_fields__[key].expand_relationship(value, self)
        relation_key = self.get_name(title=True) + "_" + key
        if not self._orm_relationship_manager.contains(relation_key, self):
            setattr(self.values, key, value)

    def __getattribute__(self, key: str) -> Any:
        """Resolve declared fields from relations or from ``self.values``.

        A relation registered in the relationship manager wins. Otherwise
        the value is read from ``self.values`` (``None`` if absent) and
        JSON-decoded for Json fields. Non-field names use default lookup.
        """
        # "__fields__" itself must bypass the membership test below,
        # otherwise this lookup would recurse forever.
        if key != "__fields__" and key in self.__fields__:
            relation_key = self.get_name(title=True) + "_" + key
            if self._orm_relationship_manager.contains(relation_key, self):
                return self._orm_relationship_manager.get(relation_key, self)
            item = getattr(self.values, key, None)
            return self._convert_json(key, item, op="loads")
        return super().__getattribute__(key)

    def __eq__(self, other: object) -> bool:
        """Compare two models by the ``dict()`` of their pydantic values.

        Non-model operands yield ``NotImplemented`` (so ``==`` evaluates to
        False) instead of raising ``AttributeError`` on ``other.values``.
        """
        if not isinstance(other, FakePydantic):
            return NotImplemented
        return self.values.dict() == other.values.dict()

    def __same__(self, other: "Model") -> bool:
        """Return True if ``other`` is the same model row as ``self``.

        Requires the same class, and then either the same ``_orm_id`` or
        (when both have values) an equal ``pk``.
        """
        if self.__class__ != other.__class__:  # pragma no cover
            return False
        return self._orm_id == other._orm_id or (
            self.values is not None and other.values is not None and self.pk == other.pk
        )

    def __repr__(self) -> str:  # pragma no cover
        """Delegate to the repr of the underlying pydantic model."""
        return self.values.__repr__()

    @classmethod
    def __get_validators__(cls) -> Iterator[Callable]:  # pragma no cover
        """Yield pydantic validators: delegate to the generated model."""
        yield cls.__pydantic_model__.validate

    @classmethod
    def get_name(cls, title: bool = False, lower: bool = True) -> str:
        """Return the class name, optionally lower-cased and/or title-cased.

        Lower-casing is applied before title-casing.
        """
        name = cls.__name__
        if lower:
            name = name.lower()
        if title:
            name = name.title()
        return name

    @property
    def pk_column(self) -> sqlalchemy.Column:
        """Return the first primary-key column of ``__table__``."""
        return self.__table__.primary_key.columns.values()[0]

    @classmethod
    def pk_type(cls) -> Any:
        """Return the ``__type__`` of the primary-key field definition."""
        return cls.__model_fields__[cls.__pkname__].__type__

    def dict(self) -> Dict:  # noqa: A003
        """Return field values as a dict with related models expanded.

        Starts from ``self.values.dict()`` and replaces each related field
        with ``dict()`` of the related model(s): a list of models becomes a
        list of dicts, and a ``None`` relation becomes ``{}``.
        """
        dict_instance = self.values.dict()
        for field in self._extract_related_names():
            nested_model = getattr(self, field)
            if isinstance(nested_model, list):
                dict_instance[field] = [x.dict() for x in nested_model]
            else:
                dict_instance[field] = (
                    nested_model.dict() if nested_model is not None else {}
                )
        return dict_instance

    def from_dict(self, value_dict: Dict) -> None:
        """Assign every key/value pair of ``value_dict`` via ``setattr``."""
        for key, value in value_dict.items():
            setattr(self, key, value)

    def _convert_json(self, column_name: str, value: Any, op: str) -> Any:
        """JSON-encode or -decode ``value`` when ``column_name`` is a Json field.

        Args:
            column_name: field name used to look up the field definition.
            value: the value to convert.
            op: ``"loads"`` decodes string values; any other op (callers pass
                ``"dumps"``) encodes non-string values.

        Returns:
            The converted value, or ``value`` unchanged when the field is not
            a Json field, no conversion applies, or the conversion fails.
        """
        if not self._is_conversion_to_json_needed(column_name):
            return value

        if op == "loads":
            needs_conversion = isinstance(value, str)
            operand = json.loads
        else:
            needs_conversion = not isinstance(value, str)
            operand = json.dumps

        if needs_conversion:
            try:
                return operand(value)
            except (TypeError, ValueError):  # pragma no cover
                # Best effort: unserializable objects (TypeError), circular
                # structures and non-JSON strings (ValueError, incl.
                # json.JSONDecodeError) are returned unchanged.
                pass
        return value

    def _is_conversion_to_json_needed(self, column_name: str) -> bool:
        """Return True if ``column_name`` is declared with type ``pydantic.Json``.

        Names without a field definition return False instead of raising.
        """
        field = self.__model_fields__.get(column_name)
        return field is not None and field.__type__ == pydantic.Json

    def _extract_own_model_fields(self) -> Dict:
        """Return ``self.dict()`` without the related-model fields."""
        related_names = self._extract_related_names()
        return {k: v for k, v in self.dict().items() if k not in related_names}

    @classmethod
    def _extract_related_names(cls) -> Set:
        """Return names of fields whose type is a pydantic model class."""
        return {
            name
            for name, field in cls.__fields__.items()
            if inspect.isclass(field.type_)
            and issubclass(field.type_, pydantic.BaseModel)
        }

    def _extract_model_db_fields(self) -> Dict:
        """Return the values to persist for this row.

        Own (non-relation) fields are kept only if they are columns of
        ``__table__``. Each non-None related model is replaced by the value
        of the attribute named by its target's ``__pkname__``.
        """
        self_fields = {
            k: v
            for k, v in self._extract_own_model_fields().items()
            if k in self.__table__.columns
        }
        for field in self._extract_related_names():
            related = getattr(self, field)
            if related is not None:
                target_pkname = self.__model_fields__[field].to.__pkname__
                self_fields[field] = getattr(related, target_pkname)
        return self_fields

    @classmethod
    def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
        """Collapse consecutive rows with equal ``pk`` into one instance.

        Each run of equal ``pk`` values is folded into the first row of the
        run via ``merge_two_instances``. The input list itself is not
        modified, but its instances may be, because ``merge_two_instances``
        updates its ``other`` argument in place.
        """
        merged_rows: List["Model"] = []
        for index, model in enumerate(result_rows):
            if index > 0 and model.pk == result_rows[index - 1].pk:
                # The merge result belongs in the output list. Writing it to
                # result_rows[-1] (as before) clobbered the last input row
                # while it was still being iterated.
                merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
            else:
                merged_rows.append(model)
        return merged_rows

    @classmethod
    def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
        """Merge field data from ``one`` into ``other`` and return ``other``.

        For each field of ``one``:

        * a nested ``orm.Model`` whose ``pk`` equals that of ``other``'s
          value is merged recursively;
        * any other list value is concatenated as ``one + other``.

        ``other`` is updated in place via ``setattr``.
        """
        for field in one.__model_fields__:
            one_value = getattr(one, field)
            # Models are list subclasses (FakePydantic inherits from list),
            # so the model check has to come before the generic list check.
            if isinstance(one_value, orm.Model):
                other_value = getattr(other, field)
                if one_value.pk == other_value.pk:
                    setattr(
                        other,
                        field,
                        cls.merge_two_instances(one_value, other_value),
                    )
            elif isinstance(one_value, list):
                setattr(other, field, one_value + getattr(other, field))
        return other