refactors in fields

This commit is contained in:
collerek
2020-08-09 12:04:44 +02:00
parent d9755234c1
commit 3f2568b27e
6 changed files with 140 additions and 114 deletions

BIN
.coverage

Binary file not shown.

View File

@ -1,6 +1,6 @@
import datetime
import decimal
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, Union
import orm
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
@ -29,6 +29,9 @@ class BaseField:
name = args.pop(0)
self.name = name
self._populate_from_kwargs(kwargs)
def _populate_from_kwargs(self, kwargs: Dict) -> None:
self.primary_key = kwargs.pop("primary_key", False)
self.autoincrement = kwargs.pop(
"autoincrement", self.primary_key and self.__type__ == int
@ -79,7 +82,7 @@ class BaseField:
index=self.index,
unique=self.unique,
default=self.default,
server_default=self.server_default
server_default=self.server_default,
)
def get_column_type(self) -> sqlalchemy.types.TypeEngine:
@ -228,13 +231,13 @@ class ForeignKey(BaseField):
def expand_relationship(
self, value: Any, child: "Model"
) -> Union["Model", List["Model"]]:
if not isinstance(value, (self.to, dict, int, str, list)) or (
isinstance(value, orm.models.Model) and not isinstance(value, self.to)
):
if isinstance(value, orm.models.Model) and not isinstance(value, self.to):
raise RelationshipInstanceError(
"Relationship model can be build only from orm.Model, "
"dict and integer or string (pk)."
f"Relationship error - expecting: {self.to.__name__}, "
f"but {value.__class__.__name__} encountered."
)
if isinstance(value, list) and not isinstance(value, self.to):
model = [self.expand_relationship(val, child) for val in value]
return model
@ -244,9 +247,19 @@ class ForeignKey(BaseField):
elif isinstance(value, dict):
model = self.to(**value)
else:
if not isinstance(value, self.to.pk_type()):
raise RelationshipInstanceError(
f"Relationship error - ForeignKey {self.to.__name__} is of type {self.to.pk_type()} "
f"of type {self.__type__} while {type(value)} passed as a parameter."
)
model = create_dummy_instance(fk=self.to, pk=value)
child_model_name = self.related_name or child.__class__.__name__.lower() + "s"
self.add_to_relationship_registry(model, child)
return model
def add_to_relationship_registry(self, model: "Model", child: "Model") -> None:
child_model_name = self.related_name or child.get_name() + "s"
model._orm_relationship_manager.add_relation(
model.__class__.__name__.lower(),
child.__class__.__name__.lower(),
@ -257,8 +270,14 @@ class ForeignKey(BaseField):
if (
child_model_name not in model.__fields__
and child.__class__.__name__.lower() not in model.__fields__
and child.get_name() not in model.__fields__
):
self.register_reverse_model_fields(model, child, child_model_name)
@staticmethod
def register_reverse_model_fields(
model: "Model", child: "Model", child_model_name: str
) -> None:
model.__fields__[child_model_name] = ModelField(
name=child_model_name,
type_=Optional[child.__pydantic_model__],
@ -268,5 +287,3 @@ class ForeignKey(BaseField):
model.__model_fields__[child_model_name] = ForeignKey(
child.__class__, name=child_model_name, virtual=True
)
return model

View File

@ -32,8 +32,17 @@ def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple
return pydantic_fields
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s"
reverse_name = field.related_name or child_relation_name
relation_name = name.lower().title() + "_" + field.to.get_name()
relationship_manager.add_relation_type(
relation_name, reverse_name, field, table_name
)
def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, tablename: str
name: str, object_dict: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
pkname: Optional[str] = None
columns: List[sqlalchemy.Column] = []
@ -46,14 +55,7 @@ def sqlalchemy_columns_from_model_fields(
if field.primary_key:
pkname = field_name
if isinstance(field, ForeignKey):
child_relation_name = (
field.to.get_name(title=True) + "_" + name.lower() + "s"
)
reverse_name = field.related_name or child_relation_name
relation_name = name.lower().title() + "_" + field.to.get_name()
relationship_manager.add_relation_type(
relation_name, reverse_name, field, tablename
)
register_relation_on_build(table_name, field, name)
columns.append(field.get_column(field_name))
return pkname, columns, model_fields
@ -109,8 +111,8 @@ class ModelMetaclass(type):
return new_model
class Model(list, metaclass=ModelMetaclass):
# Model inherits from list in order to be treated as
class FakePydantic(list, metaclass=ModelMetaclass):
# FakePydantic inherits from list in order to be treated as
# request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts
__abstract__ = True
@ -125,9 +127,8 @@ class Model(list, metaclass=ModelMetaclass):
__database__: databases.Database
_orm_relationship_manager: RelationshipManager
objects = qry.QuerySet()
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False
self.values: Optional[BaseModel] = None
@ -145,7 +146,7 @@ class Model(list, metaclass=ModelMetaclass):
def __setattr__(self, key: str, value: Any) -> None:
if key in self.__fields__:
if self.is_conversion_to_json_needed(key) and not isinstance(value, str):
if self._is_conversion_to_json_needed(key) and not isinstance(value, str):
try:
value = json.dumps(value)
except TypeError: # pragma no cover
@ -168,7 +169,7 @@ class Model(list, metaclass=ModelMetaclass):
item = getattr(self.values, key, None)
if (
item is not None
and self.is_conversion_to_json_needed(key)
and self._is_conversion_to_json_needed(key)
and isinstance(item, str)
):
try:
@ -191,6 +192,79 @@ class Model(list, metaclass=ModelMetaclass):
def __repr__(self) -> str: # pragma no cover
return self.values.__repr__()
@classmethod
def __get_validators__(cls) -> Callable: # pragma no cover
yield cls.__pydantic_model__.validate
@classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str:
name = cls.__name__
if lower:
name = name.lower()
if title:
name = name.title()
return name
@property
def pk_column(self) -> sqlalchemy.Column:
return self.__table__.primary_key.columns.values()[0]
@classmethod
def pk_type(cls):
return cls.__model_fields__[cls.__pkname__].__type__
def dict(self) -> Dict: # noqa: A003
dict_instance = self.values.dict()
for field in self._extract_related_names():
nested_model = getattr(self, field)
if isinstance(nested_model, list):
dict_instance[field] = [x.dict() for x in nested_model]
else:
dict_instance[field] = (
nested_model.dict() if nested_model is not None else {}
)
return dict_instance
def from_dict(self, value_dict: Dict) -> None:
for key, value in value_dict.items():
setattr(self, key, value)
def _is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
def _extract_own_model_fields(self) -> Dict:
related_names = self._extract_related_names()
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields
@classmethod
def _extract_related_names(cls) -> Set:
related_names = set()
for name, field in cls.__fields__.items():
if inspect.isclass(field.type_) and issubclass(
field.type_, pydantic.BaseModel
):
related_names.add(name)
return related_names
def _extract_model_db_fields(self) -> Dict:
self_fields = self._extract_own_model_fields()
self_fields = {
k: v for k, v in self_fields.items() if k in self.__table__.columns
}
for field in self._extract_related_names():
if getattr(self, field) is not None:
self_fields[field] = getattr(
getattr(self, field), self.__model_fields__[field].to.__pkname__
)
return self_fields
class Model(FakePydantic):
__abstract__ = True
objects = qry.QuerySet()
@classmethod
def from_row(
cls,
@ -227,30 +301,6 @@ class Model(list, metaclass=ModelMetaclass):
return cls(**item)
# @classmethod
# def validate(cls, value: Any) -> 'BaseModel': # pragma no cover
# return cls.__pydantic_model__.validate(value=value)
@classmethod
def __get_validators__(cls) -> Callable: # pragma no cover
yield cls.__pydantic_model__.validate
# @classmethod
# def schema(cls, by_alias: bool = True): # pragma no cover
# return cls.__pydantic_model__.schema(by_alias=by_alias)
@classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str:
name = cls.__name__
if lower:
name = name.lower()
if title:
name = name.title()
return name
def is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
@property
def pk(self) -> str:
return getattr(self.values, self.__pkname__)
@ -259,55 +309,8 @@ class Model(list, metaclass=ModelMetaclass):
def pk(self, value: Any) -> None:
setattr(self.values, self.__pkname__, value)
@property
def pk_column(self) -> sqlalchemy.Column:
return self.__table__.primary_key.columns.values()[0]
def dict(self) -> Dict: # noqa: A003
dict_instance = self.values.dict()
for field in self.extract_related_names():
nested_model = getattr(self, field)
if isinstance(nested_model, list):
dict_instance[field] = [x.dict() for x in nested_model]
else:
dict_instance[field] = (
nested_model.dict() if nested_model is not None else {}
)
return dict_instance
def from_dict(self, value_dict: Dict) -> None:
for key, value in value_dict.items():
setattr(self, key, value)
def extract_own_model_fields(self) -> Dict:
related_names = self.extract_related_names()
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields
@classmethod
def extract_related_names(cls) -> Set:
related_names = set()
for name, field in cls.__fields__.items():
if inspect.isclass(field.type_) and issubclass(
field.type_, pydantic.BaseModel
):
related_names.add(name)
return related_names
def extract_model_db_fields(self) -> Dict:
self_fields = self.extract_own_model_fields()
self_fields = {
k: v for k, v in self_fields.items() if k in self.__table__.columns
}
for field in self.extract_related_names():
if getattr(self, field) is not None:
self_fields[field] = getattr(
getattr(self, field), self.__model_fields__[field].to.__pkname__
)
return self_fields
async def save(self) -> int:
self_fields = self.extract_model_db_fields()
self_fields = self._extract_model_db_fields()
if self.__model_fields__.get(self.__pkname__).autoincrement:
self_fields.pop(self.__pkname__, None)
expr = self.__table__.insert()
@ -321,7 +324,7 @@ class Model(list, metaclass=ModelMetaclass):
new_values = {**self.dict(), **kwargs}
self.from_dict(new_values)
self_fields = self.extract_model_db_fields()
self_fields = self._extract_model_db_fields()
self_fields.pop(self.__pkname__)
expr = (
self.__table__.update()

View File

@ -527,7 +527,7 @@ class QuerySet:
del new_kwargs[pkname]
# substitute related models with their pk
for field in self.model_cls.extract_related_names():
for field in self.model_cls._extract_related_names():
if field in new_kwargs and new_kwargs.get(field) is not None:
new_kwargs[field] = getattr(
new_kwargs.get(field),

View File

@ -8,7 +8,7 @@ from weakref import proxy
from orm.fields import ForeignKey
if TYPE_CHECKING: # pragma no cover
from orm.models import Model
from orm.models import FakePydantic, Model
def get_table_alias() -> str:
@ -48,7 +48,7 @@ class RelationshipManager:
"reverse", table_name, field
)
def deregister(self, model: "Model") -> None:
def deregister(self, model: "FakePydantic") -> None:
# print(f'deregistering {model.__class__.__name__}, {model._orm_id}')
for rel_type in self._relations.keys():
if model.__class__.__name__.lower() in rel_type.lower():
@ -59,8 +59,8 @@ class RelationshipManager:
self,
parent_name: str,
child_name: str,
parent: "Model",
child: "Model",
parent: "FakePydantic",
child: "FakePydantic",
virtual: bool = False,
) -> None:
parent_id = parent._orm_id
@ -91,13 +91,13 @@ class RelationshipManager:
relations_list.append(model)
def contains(self, relations_key: str, instance: "Model") -> bool:
def contains(self, relations_key: str, instance: "FakePydantic") -> bool:
if relations_key in self._relations:
return instance._orm_id in self._relations[relations_key]
return False
def get(
self, relations_key: str, instance: "Model"
self, relations_key: str, instance: "FakePydantic"
) -> Union["Model", List["Model"]]:
if relations_key in self._relations:
if instance._orm_id in self._relations[relations_key]:

View File

@ -67,6 +67,12 @@ def create_test_database():
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_wrong_query_foreign_key_type():
with pytest.raises(RelationshipInstanceError):
Track(title="The Error", album="wrong_pk_type")
@pytest.mark.asyncio
async def test_model_crud():
async with database: