liniting and applying black

This commit is contained in:
collerek
2020-08-09 07:51:06 +02:00
parent 9d9346fb13
commit 241628b1d9
9 changed files with 455 additions and 247 deletions

BIN
.coverage

Binary file not shown.

View File

@ -1,4 +1,5 @@
[flake8] [flake8]
ignore = ANN101 ignore = ANN101, ANN102, W503
max-complexity = 10 max-complexity = 10
max-line-length = 88
exclude = p38venv,.pytest_cache exclude = p38venv,.pytest_cache

View File

@ -1,6 +1,18 @@
from orm.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch from orm.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch
from orm.fields import BigInteger, Boolean, Date, DateTime, Decimal, Float, ForeignKey, Integer, JSON, String, Text, \ from orm.fields import (
Time BigInteger,
Boolean,
Date,
DateTime,
Decimal,
Float,
ForeignKey,
Integer,
JSON,
String,
Text,
Time,
)
from orm.models import Model from orm.models import Model
__version__ = "0.0.1" __version__ = "0.0.1"
@ -21,5 +33,5 @@ __all__ = [
"ModelDefinitionError", "ModelDefinitionError",
"ModelNotSet", "ModelNotSet",
"MultipleMatches", "MultipleMatches",
"NoMatch" "NoMatch",
] ]

View File

@ -1,14 +1,15 @@
import datetime import datetime
import decimal import decimal
from typing import List, Optional, TYPE_CHECKING, Type, Any, Union from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy
from pydantic import Json, BaseModel
from pydantic.fields import ModelField
import orm import orm
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
from pydantic import BaseModel, Json
from pydantic.fields import ModelField
import sqlalchemy
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from orm.models import Model from orm.models import Model
@ -16,33 +17,39 @@ if TYPE_CHECKING: # pragma no cover
class BaseField: class BaseField:
__type__ = None __type__ = None
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
name = kwargs.pop('name', None) name = kwargs.pop("name", None)
args = list(args) args = list(args)
if args: if args:
if isinstance(args[0], str): if isinstance(args[0], str):
if name is not None: if name is not None:
raise ModelDefinitionError('Column name cannot be passed positionally and as a keyword.') raise ModelDefinitionError(
"Column name cannot be passed positionally and as a keyword."
)
name = args.pop(0) name = args.pop(0)
self.name = name self.name = name
self.primary_key = kwargs.pop('primary_key', False) self.primary_key = kwargs.pop("primary_key", False)
self.autoincrement = kwargs.pop('autoincrement', self.primary_key and self.__type__ == int) self.autoincrement = kwargs.pop(
"autoincrement", self.primary_key and self.__type__ == int
)
self.nullable = kwargs.pop('nullable', not self.primary_key) self.nullable = kwargs.pop("nullable", not self.primary_key)
self.default = kwargs.pop('default', None) self.default = kwargs.pop("default", None)
self.server_default = kwargs.pop('server_default', None) self.server_default = kwargs.pop("server_default", None)
self.index = kwargs.pop('index', None) self.index = kwargs.pop("index", None)
self.unique = kwargs.pop('unique', None) self.unique = kwargs.pop("unique", None)
self.pydantic_only = kwargs.pop('pydantic_only', False) self.pydantic_only = kwargs.pop("pydantic_only", False)
if self.pydantic_only and self.primary_key: if self.pydantic_only and self.primary_key:
raise ModelDefinitionError('Primary key column cannot be pydantic only.') raise ModelDefinitionError("Primary key column cannot be pydantic only.")
@property @property
def is_required(self) -> bool: def is_required(self) -> bool:
return not self.nullable and not self.has_default and not self.is_auto_primary_key return (
not self.nullable and not self.has_default and not self.is_auto_primary_key
)
@property @property
def default_value(self) -> Any: def default_value(self) -> Any:
@ -81,16 +88,19 @@ class BaseField:
def get_constraints(self) -> Optional[List]: def get_constraints(self) -> Optional[List]:
return [] return []
def expand_relationship(self, value, child) -> Any: def expand_relationship(self, value: Any, child: "Model") -> Any:
return value return value
class String(BaseField): class String(BaseField):
__type__ = str __type__ = str
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
assert 'length' in kwargs, 'length is required' if "length" not in kwargs:
self.length = kwargs.pop('length') raise ModelDefinitionError(
"Param length is required for String model field."
)
self.length = kwargs.pop("length")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def get_column_type(self) -> sqlalchemy.Column: def get_column_type(self) -> sqlalchemy.Column:
@ -163,27 +173,41 @@ class BigInteger(BaseField):
class Decimal(BaseField): class Decimal(BaseField):
__type__ = decimal.Decimal __type__ = decimal.Decimal
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
assert 'precision' in kwargs, 'precision is required' if "length" not in kwargs or "precision" not in kwargs:
assert 'length' in kwargs, 'length is required' raise ModelDefinitionError(
self.length = kwargs.pop('length') "Params length and precision are required for Decimal model field."
self.precision = kwargs.pop('precision') )
self.length = kwargs.pop("length")
self.precision = kwargs.pop("precision")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def get_column_type(self) -> sqlalchemy.Column: def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.DECIMAL(self.length, self.precision) return sqlalchemy.DECIMAL(self.length, self.precision)
def create_dummy_instance(fk: Type['Model'], pk: int = None) -> 'Model': def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model":
init_dict = {fk.__pkname__: pk or -1} init_dict = {fk.__pkname__: pk or -1}
init_dict = {**init_dict, **{k: create_dummy_instance(v.to) init_dict = {
for k, v in fk.__model_fields__.items() **init_dict,
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual}} **{
k: create_dummy_instance(v.to)
for k, v in fk.__model_fields__.items()
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual
},
}
return fk(**init_dict) return fk(**init_dict)
class ForeignKey(BaseField): class ForeignKey(BaseField):
def __init__(self, to, name: str = None, related_name: str = None, nullable: bool = True, virtual: bool = False): def __init__(
self,
to: Type["Model"],
name: str = None,
related_name: str = None,
nullable: bool = True,
virtual: bool = False,
) -> None:
super().__init__(nullable=nullable, name=name) super().__init__(nullable=nullable, name=name)
self.virtual = virtual self.virtual = virtual
self.related_name = related_name self.related_name = related_name
@ -201,11 +225,16 @@ class ForeignKey(BaseField):
to_column = self.to.__model_fields__[self.to.__pkname__] to_column = self.to.__model_fields__[self.to.__pkname__]
return to_column.get_column_type() return to_column.get_column_type()
def expand_relationship(self, value, child) -> Union['Model', List['Model']]: def expand_relationship(
self, value: Any, child: "Model"
) -> Union["Model", List["Model"]]:
if not isinstance(value, (self.to, dict, int, str, list)) or ( if not isinstance(value, (self.to, dict, int, str, list)) or (
isinstance(value, orm.models.Model) and not isinstance(value, self.to)): isinstance(value, orm.models.Model) and not isinstance(value, self.to)
):
raise RelationshipInstanceError( raise RelationshipInstanceError(
'Relationship model can be build only from orm.Model, dict and integer or string (pk).') "Relationship model can be build only from orm.Model, "
"dict and integer or string (pk)."
)
if isinstance(value, list) and not isinstance(value, self.to): if isinstance(value, list) and not isinstance(value, self.to):
model = [self.expand_relationship(val, child) for val in value] model = [self.expand_relationship(val, child) for val in value]
return model return model
@ -217,19 +246,27 @@ class ForeignKey(BaseField):
else: else:
model = create_dummy_instance(fk=self.to, pk=value) model = create_dummy_instance(fk=self.to, pk=value)
child_model_name = self.related_name or child.__class__.__name__.lower() + 's' child_model_name = self.related_name or child.__class__.__name__.lower() + "s"
model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(), model._orm_relationship_manager.add_relation(
child.__class__.__name__.lower(), model.__class__.__name__.lower(),
model, child, virtual=self.virtual) child.__class__.__name__.lower(),
model,
child,
virtual=self.virtual,
)
if child_model_name not in model.__fields__ \ if (
and child.__class__.__name__.lower() not in model.__fields__: child_model_name not in model.__fields__
model.__fields__[child_model_name] = ModelField(name=child_model_name, and child.__class__.__name__.lower() not in model.__fields__
type_=Optional[child.__pydantic_model__], ):
model_config=child.__pydantic_model__.__config__, model.__fields__[child_model_name] = ModelField(
class_validators=child.__pydantic_model__.__validators__) name=child_model_name,
model.__model_fields__[child_model_name] = ForeignKey(child.__class__, type_=Optional[child.__pydantic_model__],
name=child_model_name, model_config=child.__pydantic_model__.__config__,
virtual=True) class_validators=child.__pydantic_model__.__validators__,
)
model.__model_fields__[child_model_name] = ForeignKey(
child.__class__, name=child_model_name, virtual=True
)
return model return model

View File

@ -1,26 +0,0 @@
from typing import Union, Set, Dict # pragma no cover
class Excludable: # pragma no cover
@staticmethod
def get_excluded(exclude: Union[Set, Dict, None], key: str = None):
# print(f'checking excluded for {key}', exclude)
if isinstance(exclude, dict):
if isinstance(exclude.get(key, {}), dict) and '__all__' in exclude.get(key, {}).keys():
return exclude.get(key).get('__all__')
return exclude.get(key, {})
return exclude
@staticmethod
def is_excluded(exclude: Union[Set, Dict, None], key: str = None):
if exclude is None:
return False
to_exclude = Excludable.get_excluded(exclude, key)
# print(f'to exclude for current key = {key}', to_exclude)
if isinstance(to_exclude, Set):
return key in to_exclude
elif to_exclude is ...:
return True
return False

View File

@ -2,35 +2,39 @@ import copy
import inspect import inspect
import json import json
import uuid import uuid
from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import Set, Dict from typing import Callable, Dict, Set
import databases import databases
import pydantic
import sqlalchemy
from pydantic import BaseModel, BaseConfig, create_model
import orm.queryset as qry import orm.queryset as qry
from orm.exceptions import ModelDefinitionError from orm.exceptions import ModelDefinitionError
from orm.fields import BaseField, ForeignKey from orm.fields import BaseField, ForeignKey
from orm.relations import RelationshipManager from orm.relations import RelationshipManager
import pydantic
from pydantic import BaseConfig, BaseModel, create_model
import sqlalchemy
relationship_manager = RelationshipManager() relationship_manager = RelationshipManager()
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
pydantic_fields = {field_name: ( pydantic_fields = {
base_field.__type__, field_name: (
... if base_field.is_required else base_field.default_value base_field.__type__,
) ... if base_field.is_required else base_field.default_value,
)
for field_name, base_field in object_dict.items() for field_name, base_field in object_dict.items()
if isinstance(base_field, BaseField)} if isinstance(base_field, BaseField)
}
return pydantic_fields return pydantic_fields
def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str], def sqlalchemy_columns_from_model_fields(
List[sqlalchemy.Column], name: str, object_dict: Dict, tablename: str
Dict[str, BaseField]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
pkname: Optional[str] = None pkname: Optional[str] = None
columns: List[sqlalchemy.Column] = [] columns: List[sqlalchemy.Column] = []
model_fields: Dict[str, BaseField] = {} model_fields: Dict[str, BaseField] = {}
@ -42,9 +46,16 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename
if field.primary_key: if field.primary_key:
pkname = field_name pkname = field_name
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's' reverse_name = (
relation_name = name.lower().title() + '_' + field.to.__name__.lower() field.related_name
relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename) or field.to.__name__.lower().title() + "_" + name.lower() + "s"
)
relation_name = (
name.lower().title() + "_" + field.to.__name__.lower()
)
relationship_manager.add_relation_type(
relation_name, reverse_name, field, tablename
)
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
return pkname, columns, model_fields return pkname, columns, model_fields
@ -57,9 +68,7 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class ModelMetaclass(type): class ModelMetaclass(type):
def __new__( def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
mcs: type, name: str, bases: Any, attrs: dict
) -> type:
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
@ -71,25 +80,29 @@ class ModelMetaclass(type):
metadata = attrs["__metadata__"] metadata = attrs["__metadata__"]
# sqlalchemy table creation # sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename) pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) name, attrs, tablename
attrs['__columns__'] = columns )
attrs['__pkname__'] = pkname attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns)
attrs["__columns__"] = columns
attrs["__pkname__"] = pkname
if not pkname: if not pkname:
raise ModelDefinitionError('Table has to have a primary key.') raise ModelDefinitionError("Table has to have a primary key.")
# pydantic model creation # pydantic model creation
pydantic_fields = parse_pydantic_field_from_model_fields(attrs) pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields) pydantic_model = create_model(
attrs['__pydantic_fields__'] = pydantic_fields name, __config__=get_pydantic_base_orm_config(), **pydantic_fields
attrs['__pydantic_model__'] = pydantic_model )
attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) attrs["__pydantic_fields__"] = pydantic_fields
attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__) attrs["__pydantic_model__"] = pydantic_model
attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
attrs['__model_fields__'] = model_fields attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
attrs["__model_fields__"] = model_fields
attrs['_orm_relationship_manager'] = relationship_manager attrs["_orm_relationship_manager"] = relationship_manager
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
@ -99,7 +112,8 @@ class ModelMetaclass(type):
class Model(list, metaclass=ModelMetaclass): class Model(list, metaclass=ModelMetaclass):
# Model inherits from list in order to be treated as request.Body parameter in fastapi routes, # Model inherits from list in order to be treated as
# request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts # inheriting from pydantic.BaseModel causes metaclass conflicts
__abstract__ = True __abstract__ = True
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -115,17 +129,20 @@ class Model(list, metaclass=ModelMetaclass):
objects = qry.QuerySet() objects = qry.QuerySet()
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
self._orm_id: str = uuid.uuid4().hex self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False self._orm_saved: bool = False
self.values: Optional[BaseModel] = None self.values: Optional[BaseModel] = None
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk") kwargs[self.__pkname__] = kwargs.pop("pk")
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} kwargs = {
k: self.__model_fields__[k].expand_relationship(v, self)
for k, v in kwargs.items()
}
self.values = self.__pydantic_model__(**kwargs) self.values = self.__pydantic_model__(**kwargs)
def __del__(self): def __del__(self) -> None:
self._orm_relationship_manager.deregister(self) self._orm_relationship_manager.deregister(self)
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Any) -> None:
@ -138,20 +155,24 @@ class Model(list, metaclass=ModelMetaclass):
value = self.__model_fields__[key].expand_relationship(value, self) value = self.__model_fields__[key].expand_relationship(value, self)
relation_key = self.__class__.__name__.title() + '_' + key relation_key = self.__class__.__name__.title() + "_" + key
if not self._orm_relationship_manager.contains(relation_key, self): if not self._orm_relationship_manager.contains(relation_key, self):
setattr(self.values, key, value) setattr(self.values, key, value)
else: else:
super().__setattr__(key, value) super().__setattr__(key, value)
def __getattribute__(self, key: str) -> Any: def __getattribute__(self, key: str) -> Any:
if key != '__fields__' and key in self.__fields__: if key != "__fields__" and key in self.__fields__:
relation_key = self.__class__.__name__.title() + '_' + key relation_key = self.__class__.__name__.title() + "_" + key
if self._orm_relationship_manager.contains(relation_key, self): if self._orm_relationship_manager.contains(relation_key, self):
return self._orm_relationship_manager.get(relation_key, self) return self._orm_relationship_manager.get(relation_key, self)
item = getattr(self.values, key, None) item = getattr(self.values, key, None)
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str): if (
item is not None
and self.is_conversion_to_json_needed(key)
and isinstance(item, str)
):
try: try:
item = json.loads(item) item = json.loads(item)
except TypeError: # pragma no cover except TypeError: # pragma no cover
@ -159,30 +180,41 @@ class Model(list, metaclass=ModelMetaclass):
return item return item
return super().__getattribute__(key) return super().__getattribute__(key)
def __eq__(self, other): def __eq__(self, other: "Model") -> bool:
return self.values.dict() == other.values.dict() return self.values.dict() == other.values.dict()
def __same__(self, other): def __same__(self, other: "Model") -> bool:
assert self.__class__ == other.__class__ if self.__class__ != other.__class__:
return False
return self._orm_id == other._orm_id or ( return self._orm_id == other._orm_id or (
self.values is not None and other.values is not None and self.pk == other.pk) self.values is not None and other.values is not None and self.pk == other.pk
)
def __repr__(self): # pragma no cover def __repr__(self) -> str: # pragma no cover
return self.values.__repr__() return self.values.__repr__()
@classmethod @classmethod
def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model': def from_row(
cls,
row: sqlalchemy.engine.ResultProxy,
select_related: List = None,
previous_table: str = None,
) -> "Model":
item = {} item = {}
select_related = select_related or [] select_related = select_related or []
table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name) table_prefix = cls._orm_relationship_manager.resolve_relation_join(
previous_table, cls.__table__.name
)
previous_table = cls.__table__.name previous_table = cls.__table__.name
for related in select_related: for related in select_related:
if "__" in related: if "__" in related:
first_part, remainder = related.split("__", 1) first_part, remainder = related.split("__", 1)
model_cls = cls.__model_fields__[first_part].to model_cls = cls.__model_fields__[first_part].to
child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table) child = model_cls.from_row(
row, select_related=[remainder], previous_table=previous_table
)
item[first_part] = child item[first_part] = child
else: else:
model_cls = cls.__model_fields__[related].to model_cls = cls.__model_fields__[related].to
@ -191,7 +223,9 @@ class Model(list, metaclass=ModelMetaclass):
for column in cls.__table__.columns: for column in cls.__table__.columns:
if column.name not in item: if column.name not in item:
item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}'] item[column.name] = row[
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
]
return cls(**item) return cls(**item)
@ -200,7 +234,7 @@ class Model(list, metaclass=ModelMetaclass):
# return cls.__pydantic_model__.validate(value=value) # return cls.__pydantic_model__.validate(value=value)
@classmethod @classmethod
def __get_validators__(cls): # pragma no cover def __get_validators__(cls) -> Callable: # pragma no cover
yield cls.__pydantic_model__.validate yield cls.__pydantic_model__.validate
# @classmethod # @classmethod
@ -211,11 +245,11 @@ class Model(list, metaclass=ModelMetaclass):
return self.__model_fields__.get(column_name).__type__ == pydantic.Json return self.__model_fields__.get(column_name).__type__ == pydantic.Json
@property @property
def pk(self): def pk(self) -> str:
return getattr(self.values, self.__pkname__) return getattr(self.values, self.__pkname__)
@pk.setter @pk.setter
def pk(self, value): def pk(self, value: Any) -> None:
setattr(self.values, self.__pkname__, value) setattr(self.values, self.__pkname__, value)
@property @property
@ -229,7 +263,9 @@ class Model(list, metaclass=ModelMetaclass):
if isinstance(nested_model, list): if isinstance(nested_model, list):
dict_instance[field] = [x.dict() for x in nested_model] dict_instance[field] = [x.dict() for x in nested_model]
else: else:
dict_instance[field] = nested_model.dict() if nested_model is not None else {} dict_instance[field] = (
nested_model.dict() if nested_model is not None else {}
)
return dict_instance return dict_instance
def from_dict(self, value_dict: Dict) -> None: def from_dict(self, value_dict: Dict) -> None:
@ -245,16 +281,22 @@ class Model(list, metaclass=ModelMetaclass):
def extract_related_names(cls) -> Set: def extract_related_names(cls) -> Set:
related_names = set() related_names = set()
for name, field in cls.__fields__.items(): for name, field in cls.__fields__.items():
if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): if inspect.isclass(field.type_) and issubclass(
field.type_, pydantic.BaseModel
):
related_names.add(name) related_names.add(name)
return related_names return related_names
def extract_model_db_fields(self) -> Dict: def extract_model_db_fields(self) -> Dict:
self_fields = self.extract_own_model_fields() self_fields = self.extract_own_model_fields()
self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns} self_fields = {
k: v for k, v in self_fields.items() if k in self.__table__.columns
}
for field in self.extract_related_names(): for field in self.extract_related_names():
if getattr(self, field) is not None: if getattr(self, field) is not None:
self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__) self_fields[field] = getattr(
getattr(self, field), self.__model_fields__[field].to.__pkname__
)
return self_fields return self_fields
async def save(self) -> int: async def save(self) -> int:
@ -264,7 +306,7 @@ class Model(list, metaclass=ModelMetaclass):
expr = self.__table__.insert() expr = self.__table__.insert()
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.__database__.execute(expr) item_id = await self.__database__.execute(expr)
setattr(self, 'pk', item_id) self.pk = item_id
return item_id return item_id
async def update(self, **kwargs: Any) -> int: async def update(self, **kwargs: Any) -> int:
@ -274,8 +316,11 @@ class Model(list, metaclass=ModelMetaclass):
self_fields = self.extract_model_db_fields() self_fields = self.extract_model_db_fields()
self_fields.pop(self.__pkname__) self_fields.pop(self.__pkname__)
expr = self.__table__.update().values(**self_fields).where( expr = (
self.pk_column == getattr(self, self.__pkname__)) self.__table__.update()
.values(**self_fields)
.where(self.pk_column == getattr(self, self.__pkname__))
)
result = await self.__database__.execute(expr) result = await self.__database__.execute(expr)
return result return result
@ -285,7 +330,7 @@ class Model(list, metaclass=ModelMetaclass):
result = await self.__database__.execute(expr) result = await self.__database__.execute(expr)
return result return result
async def load(self) -> 'Model': async def load(self) -> "Model":
expr = self.__table__.select().where(self.pk_column == self.pk) expr = self.__table__.select().where(self.pk_column == self.pk)
row = await self.__database__.fetch_one(expr) row = await self.__database__.fetch_one(expr)
self.from_dict(dict(row)) self.from_dict(dict(row))

View File

@ -1,11 +1,14 @@
from typing import List, TYPE_CHECKING, Type, NamedTuple from typing import Any, List, NamedTuple, TYPE_CHECKING, Tuple, Type, Union
import sqlalchemy import databases
from sqlalchemy import text
import orm import orm
from orm import ForeignKey from orm import ForeignKey
from orm.exceptions import NoMatch, MultipleMatches from orm.exceptions import MultipleMatches, NoMatch
from orm.fields import BaseField
import sqlalchemy
from sqlalchemy import text
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from orm.models import Model from orm.models import Model
@ -24,17 +27,23 @@ FILTER_OPERATORS = {
class JoinParameters(NamedTuple): class JoinParameters(NamedTuple):
prev_model: Type['Model'] prev_model: Type["Model"]
previous_alias: str previous_alias: str
from_table: str from_table: str
model_cls: Type['Model'] model_cls: Type["Model"]
class QuerySet: class QuerySet:
ESCAPE_CHARACTERS = ['%', '_'] ESCAPE_CHARACTERS = ["%", "_"]
def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, def __init__(
limit_count: int = None, offset: int = None): self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
) -> None:
self.model_cls = model_cls self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related self._select_related = [] if select_related is None else select_related
@ -48,47 +57,77 @@ class QuerySet:
self.columns = None self.columns = None
self.order_bys = None self.order_bys = None
def __get__(self, instance, owner): def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
return self.__class__(model_cls=owner) return self.__class__(model_cls=owner)
@property @property
def database(self): def database(self) -> databases.Database:
return self.model_cls.__database__ return self.model_cls.__database__
@property @property
def table(self): def table(self) -> sqlalchemy.Table:
return self.model_cls.__table__ return self.model_cls.__table__
def prefixed_columns(self, alias, table): def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]:
return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}') return [
for column in table.columns] text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
]
def prefixed_table_name(self, alias, name): def prefixed_table_name(self, alias: str, name: str) -> text:
return text(f'{name} {alias}_{name}') return text(f"{name} {alias}_{name}")
def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key): def on_clause(
return text(f'{alias}_{to_table}.{to_key}=' self,
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') from_table: str,
to_table: str,
previous_alias: str,
alias: str,
to_key: str,
from_key: str,
) -> text:
return text(
f"{alias}_{to_table}.{to_key}="
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}'
)
def build_join_parameters(self, part, join_params: JoinParameters): def build_join_parameters(
self, part: str, join_params: JoinParameters
) -> JoinParameters:
model_cls = join_params.model_cls.__model_fields__[part].to model_cls = join_params.model_cls.__model_fields__[part].to
to_table = model_cls.__table__.name to_table = model_cls.__table__.name
alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table) alias = model_cls._orm_relationship_manager.resolve_relation_join(
join_params.from_table, to_table
)
if alias not in self.used_aliases: if alias not in self.used_aliases:
if join_params.prev_model.__model_fields__[part].virtual: if join_params.prev_model.__model_fields__[part].virtual:
to_key = next((v for k, v in model_cls.__model_fields__.items() to_key = next(
if isinstance(v, ForeignKey) and v.to == join_params.prev_model), None).name (
v
for k, v in model_cls.__model_fields__.items()
if isinstance(v, ForeignKey) and v.to == join_params.prev_model
),
None,
).name
from_key = model_cls.__pkname__ from_key = model_cls.__pkname__
else: else:
to_key = model_cls.__pkname__ to_key = model_cls.__pkname__
from_key = part from_key = part
on_clause = self.on_clause(join_params.from_table, to_table, join_params.previous_alias, alias, to_key, on_clause = self.on_clause(
from_key) join_params.from_table,
to_table,
join_params.previous_alias,
alias,
to_key,
from_key,
)
target_table = self.prefixed_table_name(alias, to_table) target_table = self.prefixed_table_name(alias, to_table)
self.select_from = sqlalchemy.sql.outerjoin(self.select_from, target_table, on_clause) self.select_from = sqlalchemy.sql.outerjoin(
self.order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) self.select_from, target_table, on_clause
)
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}"))
self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) self.columns.extend(self.prefixed_columns(alias, model_cls.__table__))
self.used_aliases.append(alias) self.used_aliases.append(alias)
@ -98,44 +137,76 @@ class QuerySet:
return JoinParameters(prev_model, previous_alias, from_table, model_cls) return JoinParameters(prev_model, previous_alias, from_table, model_cls)
@staticmethod @staticmethod
def field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part) -> bool: def field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str
) -> bool:
return isinstance(field, ForeignKey) and field_name not in rel_part return isinstance(field, ForeignKey) and field_name not in rel_part
def field_qualifies_to_deeper_search(self, field, parent_virtual, nested, rel_part) -> bool: def field_qualifies_to_deeper_search(
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1]) prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any([x.startswith(prev_part_of_related) for x in self._select_related]) partial_match = any(
[x.startswith(prev_part_of_related) for x in self._select_related]
)
already_checked = any([x.startswith(rel_part) for x in self.auto_related]) already_checked = any([x.startswith(rel_part) for x in self.auto_related])
return ((field.virtual and parent_virtual) or (partial_match and not already_checked)) or not nested return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def extract_auto_required_relations(self, join_params: JoinParameters, def extract_auto_required_relations(
rel_part: str = '', nested: bool = False, parent_virtual: bool = False): self,
join_params: JoinParameters,
rel_part: str = "",
nested: bool = False,
parent_virtual: bool = False,
) -> None:
for field_name, field in join_params.prev_model.__model_fields__.items(): for field_name, field in join_params.prev_model.__model_fields__.items():
if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part): if self.field_is_a_foreign_key_and_no_circular_reference(
rel_part = field_name if not rel_part else rel_part + '__' + field_name field, field_name, rel_part
):
rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable: if not field.nullable:
if rel_part not in self._select_related: if rel_part not in self._select_related:
self.auto_related.append("__".join(rel_part.split("__")[:-1])) self.auto_related.append("__".join(rel_part.split("__")[:-1]))
rel_part = '' rel_part = ""
elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part): elif self.field_qualifies_to_deeper_search(
join_params = JoinParameters(field.to, join_params.previous_alias, field, parent_virtual, nested, rel_part
join_params.from_table, join_params.prev_model) ):
self.extract_auto_required_relations(join_params=join_params, join_params = JoinParameters(
rel_part=rel_part, nested=True, parent_virtual=field.virtual) field.to,
join_params.previous_alias,
join_params.from_table,
join_params.prev_model,
)
self.extract_auto_required_relations(
join_params=join_params,
rel_part=rel_part,
nested=True,
parent_virtual=field.virtual,
)
else: else:
rel_part = '' rel_part = ""
def build_select_expression(self): def build_select_expression(self) -> sqlalchemy.sql.select:
self.columns = list(self.table.columns) self.columns = list(self.table.columns)
self.order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
self.select_from = self.table self.select_from = self.table
for key in self.model_cls.__model_fields__: for key in self.model_cls.__model_fields__:
if not self.model_cls.__model_fields__[key].nullable \ if (
and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \ not self.model_cls.__model_fields__[key].nullable
and key not in self._select_related: and isinstance(
self.model_cls.__model_fields__[key], orm.fields.ForeignKey
)
and key not in self._select_related
):
self._select_related = [key] + self._select_related self._select_related = [key] + self._select_related
start_params = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
)
self.extract_auto_required_relations(start_params) self.extract_auto_required_relations(start_params)
if self.auto_related: if self.auto_related:
new_joins = [] new_joins = []
@ -146,7 +217,9 @@ class QuerySet:
self._select_related.sort(key=lambda item: (-len(item), item)) self._select_related.sort(key=lambda item: (-len(item), item))
for item in self._select_related: for item in self._select_related:
join_parameters = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) join_parameters = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
)
for part in item.split("__"): for part in item.split("__"):
join_parameters = self.build_join_parameters(part, join_parameters) join_parameters = self.build_join_parameters(part, join_parameters)
@ -180,7 +253,7 @@ class QuerySet:
return expr return expr
def filter(self, **kwargs): def filter(self, **kwargs: Any) -> "QuerySet":
filter_clauses = self.filter_clauses filter_clauses = self.filter_clauses
select_related = list(self._select_related) select_related = list(self._select_related)
@ -189,7 +262,7 @@ class QuerySet:
kwargs[pk_name] = kwargs.pop("pk") kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items(): for key, value in kwargs.items():
table_prefix = '' table_prefix = ""
if "__" in key: if "__" in key:
parts = key.split("__") parts = key.split("__")
@ -215,9 +288,13 @@ class QuerySet:
# against which the comparison is being made. # against which the comparison is being made.
previous_table = model_cls.__tablename__ previous_table = model_cls.__tablename__
for part in related_parts: for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__ current_table = model_cls.__model_fields__[
table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table, part
current_table) ].to.__tablename__
manager = model_cls._orm_relationship_manager
table_prefix = manager.resolve_relation_join(
previous_table, current_table
)
model_cls = model_cls.__model_fields__[part].to model_cls = model_cls.__model_fields__[part].to
previous_table = current_table previous_table = current_table
@ -236,25 +313,32 @@ class QuerySet:
has_escaped_character = False has_escaped_character = False
if op in ["contains", "icontains"]: if op in ["contains", "icontains"]:
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS has_escaped_character = any(
if c in value) c for c in self.ESCAPE_CHARACTERS if c in value
)
if has_escaped_character: if has_escaped_character:
# enable escape modifier # enable escape modifier
for char in self.ESCAPE_CHARACTERS: for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f'\\{char}') value = value.replace(char, f"\\{char}")
value = f"%{value}%" value = f"%{value}%"
if isinstance(value, orm.Model): if isinstance(value, orm.Model):
value = value.pk value = value.pk
clause = getattr(column, op_attr)(value) clause = getattr(column, op_attr)(value)
clause.modifiers['escape'] = '\\' if has_escaped_character else None clause.modifiers["escape"] = "\\" if has_escaped_character else None
clause_text = str(clause.compile(dialect=self.model_cls.__database__._backend._dialect, clause_text = str(
compile_kwargs={"literal_binds": True})) clause.compile(
alias = f'{table_prefix}_' if table_prefix else '' dialect=self.model_cls.__database__._backend._dialect,
aliased_name = f'{alias}{table.name}.{column.name}' compile_kwargs={"literal_binds": True},
clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name) )
)
alias = f"{table_prefix}_" if table_prefix else ""
aliased_name = f"{alias}{table.name}.{column.name}"
clause_text = clause_text.replace(
f"{table.name}.{column.name}", aliased_name
)
clause = text(clause_text) clause = text(clause_text)
filter_clauses.append(clause) filter_clauses.append(clause)
@ -264,10 +348,10 @@ class QuerySet:
filter_clauses=filter_clauses, filter_clauses=filter_clauses,
select_related=select_related, select_related=select_related,
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset offset=self.query_offset,
) )
def select_related(self, related): def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
if not isinstance(related, (list, tuple)): if not isinstance(related, (list, tuple)):
related = [related] related = [related]
@ -277,7 +361,7 @@ class QuerySet:
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
select_related=related, select_related=related,
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset offset=self.query_offset,
) )
async def exists(self) -> bool: async def exists(self) -> bool:
@ -290,25 +374,25 @@ class QuerySet:
expr = sqlalchemy.func.count().select().select_from(expr) expr = sqlalchemy.func.count().select().select_from(expr)
return await self.database.fetch_val(expr) return await self.database.fetch_val(expr)
def limit(self, limit_count: int): def limit(self, limit_count: int) -> "QuerySet":
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model_cls,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
select_related=self._select_related, select_related=self._select_related,
limit_count=limit_count, limit_count=limit_count,
offset=self.query_offset offset=self.query_offset,
) )
def offset(self, offset: int): def offset(self, offset: int) -> "QuerySet":
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model_cls,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
select_related=self._select_related, select_related=self._select_related,
limit_count=self.limit_count, limit_count=self.limit_count,
offset=offset offset=offset,
) )
async def first(self, **kwargs): async def first(self, **kwargs: Any) -> "Model":
if kwargs: if kwargs:
return await self.filter(**kwargs).first() return await self.filter(**kwargs).first()
@ -316,7 +400,7 @@ class QuerySet:
if rows: if rows:
return rows[0] return rows[0]
async def get(self, **kwargs): async def get(self, **kwargs: Any) -> "Model":
if kwargs: if kwargs:
return await self.filter(**kwargs).get() return await self.filter(**kwargs).get()
@ -329,7 +413,7 @@ class QuerySet:
raise MultipleMatches() raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related) return self.model_cls.from_row(rows[0], select_related=self._select_related)
async def all(self, **kwargs): async def all(self, **kwargs: Any) -> List["Model"]:
if kwargs: if kwargs:
return await self.filter(**kwargs).all() return await self.filter(**kwargs).all()
@ -345,7 +429,7 @@ class QuerySet:
return result_rows return result_rows
@classmethod @classmethod
def merge_result_rows(cls, result_rows): def merge_result_rows(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = [] merged_rows = []
for index, model in enumerate(result_rows): for index, model in enumerate(result_rows):
if index > 0 and model.pk == result_rows[index - 1].pk: if index > 0 and model.pk == result_rows[index - 1].pk:
@ -355,30 +439,45 @@ class QuerySet:
return merged_rows return merged_rows
@classmethod @classmethod
def merge_two_instances(cls, one: 'Model', other: 'Model'): def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.__model_fields__.keys(): for field in one.__model_fields__.keys():
# print(field, one.dict(), other.dict()) # print(field, one.dict(), other.dict())
if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model): if isinstance(getattr(one, field), list) and not isinstance(
getattr(one, field), orm.models.Model
):
setattr(other, field, getattr(one, field) + getattr(other, field)) setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), orm.models.Model): elif isinstance(getattr(one, field), orm.models.Model):
if getattr(one, field).pk == getattr(other, field).pk: if getattr(one, field).pk == getattr(other, field).pk:
setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field))) setattr(
other,
field,
cls.merge_two_instances(
getattr(one, field), getattr(other, field)
),
)
return other return other
async def create(self, **kwargs): async def create(self, **kwargs: Any) -> "Model":
new_kwargs = dict(**kwargs) new_kwargs = dict(**kwargs)
# Remove primary key when None to prevent not null constraint in postgresql. # Remove primary key when None to prevent not null constraint in postgresql.
pkname = self.model_cls.__pkname__ pkname = self.model_cls.__pkname__
pk = self.model_cls.__model_fields__[pkname] pk = self.model_cls.__model_fields__[pkname]
if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement): if (
pkname in new_kwargs
and new_kwargs.get(pkname) is None
and (pk.nullable or pk.autoincrement)
):
del new_kwargs[pkname] del new_kwargs[pkname]
# substitute related models with their pk # substitute related models with their pk
for field in self.model_cls.extract_related_names(): for field in self.model_cls.extract_related_names():
if field in new_kwargs and new_kwargs.get(field) is not None: if field in new_kwargs and new_kwargs.get(field) is not None:
new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__) new_kwargs[field] = getattr(
new_kwargs.get(field),
self.model_cls.__model_fields__[field].to.__pkname__,
)
# Build the insert expression. # Build the insert expression.
expr = self.table.insert() expr = self.table.insert()

View File

@ -2,7 +2,7 @@ import pprint
import string import string
import uuid import uuid
from random import choices from random import choices
from typing import TYPE_CHECKING, List from typing import Dict, List, TYPE_CHECKING, Union
from weakref import proxy from weakref import proxy
from orm.fields import ForeignKey from orm.fields import ForeignKey
@ -11,40 +11,58 @@ if TYPE_CHECKING: # pragma no cover
from orm.models import Model from orm.models import Model
def get_table_alias(): def get_table_alias() -> str:
return ''.join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
def get_relation_config(relation_type: str, table_name: str, field: ForeignKey): def get_relation_config(
relation_type: str, table_name: str, field: ForeignKey
) -> Dict[str, str]:
alias = get_table_alias() alias = get_table_alias()
config = {'type': relation_type, config = {
'table_alias': alias, "type": relation_type,
'source_table': table_name if relation_type == 'primary' else field.to.__tablename__, "table_alias": alias,
'target_table': field.to.__tablename__ if relation_type == 'primary' else table_name "source_table": table_name
} if relation_type == "primary"
else field.to.__tablename__,
"target_table": field.to.__tablename__
if relation_type == "primary"
else table_name,
}
return config return config
class RelationshipManager: class RelationshipManager:
def __init__(self) -> None:
def __init__(self):
self._relations = dict() self._relations = dict()
def add_relation_type(self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str): def add_relation_type(
print(relations_key, reverse_key) self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str
) -> None:
if relations_key not in self._relations: if relations_key not in self._relations:
self._relations[relations_key] = get_relation_config('primary', table_name, field) self._relations[relations_key] = get_relation_config(
"primary", table_name, field
)
if reverse_key not in self._relations: if reverse_key not in self._relations:
self._relations[reverse_key] = get_relation_config('reverse', table_name, field) self._relations[reverse_key] = get_relation_config(
"reverse", table_name, field
)
def deregister(self, model: 'Model'): def deregister(self, model: "Model") -> None:
# print(f'deregistering {model.__class__.__name__}, {model._orm_id}') # print(f'deregistering {model.__class__.__name__}, {model._orm_id}')
for rel_type in self._relations.keys(): for rel_type in self._relations.keys():
if model.__class__.__name__.lower() in rel_type.lower(): if model.__class__.__name__.lower() in rel_type.lower():
if model._orm_id in self._relations[rel_type]: if model._orm_id in self._relations[rel_type]:
del self._relations[rel_type][model._orm_id] del self._relations[rel_type][model._orm_id]
def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False): def add_relation(
self,
parent_name: str,
child_name: str,
parent: "Model",
child: "Model",
virtual: bool = False,
) -> None:
parent_id = parent._orm_id parent_id = parent._orm_id
child_id = child._orm_id child_id = child._orm_id
if virtual: if virtual:
@ -53,12 +71,18 @@ class RelationshipManager:
child, parent = parent, proxy(child) child, parent = parent, proxy(child)
else: else:
child = proxy(child) child = proxy(child)
parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) parents_list = self._relations[
parent_name.lower().title() + "_" + child_name + "s"
].setdefault(parent_id, [])
self.append_related_model(parents_list, child) self.append_related_model(parents_list, child)
children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) children_list = self._relations[
child_name.lower().title() + "_" + parent_name
].setdefault(child_id, [])
self.append_related_model(children_list, parent) self.append_related_model(children_list, parent)
def append_related_model(self, relations_list: List['Model'], model: 'Model'): def append_related_model(
self, relations_list: List["Model"], model: "Model"
) -> None:
for x in relations_list: for x in relations_list:
try: try:
if x.__same__(model): if x.__same__(model):
@ -68,26 +92,26 @@ class RelationshipManager:
relations_list.append(model) relations_list.append(model)
def contains(self, relations_key: str, object: 'Model'): def contains(self, relations_key: str, object: "Model") -> bool:
if relations_key in self._relations: if relations_key in self._relations:
return object._orm_id in self._relations[relations_key] return object._orm_id in self._relations[relations_key]
return False return False
def get(self, relations_key: str, object: 'Model'): def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]:
if relations_key in self._relations: if relations_key in self._relations:
if object._orm_id in self._relations[relations_key]: if object._orm_id in self._relations[relations_key]:
if self._relations[relations_key]['type'] == 'primary': if self._relations[relations_key]["type"] == "primary":
return self._relations[relations_key][object._orm_id][0] return self._relations[relations_key][object._orm_id][0]
return self._relations[relations_key][object._orm_id] return self._relations[relations_key][object._orm_id]
def resolve_relation_join(self, from_table: str, to_table: str) -> str: def resolve_relation_join(self, from_table: str, to_table: str) -> str:
for k, v in self._relations.items(): for k, v in self._relations.items():
if v['source_table'] == from_table and v['target_table'] == to_table: if v["source_table"] == from_table and v["target_table"] == to_table:
return self._relations[k]['table_alias'] return self._relations[k]["table_alias"]
return '' return ""
def __str__(self): # pragma no cover def __str__(self) -> str: # pragma no cover
return pprint.pformat(self._relations, indent=4, width=1) return pprint.pformat(self._relations, indent=4, width=1)
def __repr__(self): # pragma no cover def __repr__(self) -> str: # pragma no cover
return self.__str__() return self.__str__()

View File

@ -109,6 +109,22 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition():
test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) test = fields.Integer(name='test12', primary_key=True, pydantic_only=True)
def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
__tablename__ = "example4"
__metadata__ = metadata
test = fields.Decimal(name='test12', primary_key=True)
def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
__tablename__ = "example4"
__metadata__ = metadata
test = fields.String(name='test12', primary_key=True)
def test_json_conversion_in_model(): def test_json_conversion_in_model():
with pytest.raises(pydantic.ValidationError): with pytest.raises(pydantic.ValidationError):
ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True) ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True)