liniting and applying black

This commit is contained in:
collerek
2020-08-09 07:51:06 +02:00
parent 9d9346fb13
commit 241628b1d9
9 changed files with 455 additions and 247 deletions

BIN
.coverage

Binary file not shown.

View File

@ -1,4 +1,5 @@
[flake8]
ignore = ANN101
ignore = ANN101, ANN102, W503
max-complexity = 10
max-line-length = 88
exclude = p38venv,.pytest_cache

View File

@ -1,6 +1,18 @@
from orm.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch
from orm.fields import BigInteger, Boolean, Date, DateTime, Decimal, Float, ForeignKey, Integer, JSON, String, Text, \
Time
from orm.fields import (
BigInteger,
Boolean,
Date,
DateTime,
Decimal,
Float,
ForeignKey,
Integer,
JSON,
String,
Text,
Time,
)
from orm.models import Model
__version__ = "0.0.1"
@ -21,5 +33,5 @@ __all__ = [
"ModelDefinitionError",
"ModelNotSet",
"MultipleMatches",
"NoMatch"
"NoMatch",
]

View File

@ -1,14 +1,15 @@
import datetime
import decimal
from typing import List, Optional, TYPE_CHECKING, Type, Any, Union
import sqlalchemy
from pydantic import Json, BaseModel
from pydantic.fields import ModelField
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
import orm
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
from pydantic import BaseModel, Json
from pydantic.fields import ModelField
import sqlalchemy
if TYPE_CHECKING: # pragma no cover
from orm.models import Model
@ -16,33 +17,39 @@ if TYPE_CHECKING: # pragma no cover
class BaseField:
__type__ = None
def __init__(self, *args, **kwargs) -> None:
name = kwargs.pop('name', None)
def __init__(self, *args: Any, **kwargs: Any) -> None:
name = kwargs.pop("name", None)
args = list(args)
if args:
if isinstance(args[0], str):
if name is not None:
raise ModelDefinitionError('Column name cannot be passed positionally and as a keyword.')
raise ModelDefinitionError(
"Column name cannot be passed positionally and as a keyword."
)
name = args.pop(0)
self.name = name
self.primary_key = kwargs.pop('primary_key', False)
self.autoincrement = kwargs.pop('autoincrement', self.primary_key and self.__type__ == int)
self.primary_key = kwargs.pop("primary_key", False)
self.autoincrement = kwargs.pop(
"autoincrement", self.primary_key and self.__type__ == int
)
self.nullable = kwargs.pop('nullable', not self.primary_key)
self.default = kwargs.pop('default', None)
self.server_default = kwargs.pop('server_default', None)
self.nullable = kwargs.pop("nullable", not self.primary_key)
self.default = kwargs.pop("default", None)
self.server_default = kwargs.pop("server_default", None)
self.index = kwargs.pop('index', None)
self.unique = kwargs.pop('unique', None)
self.index = kwargs.pop("index", None)
self.unique = kwargs.pop("unique", None)
self.pydantic_only = kwargs.pop('pydantic_only', False)
self.pydantic_only = kwargs.pop("pydantic_only", False)
if self.pydantic_only and self.primary_key:
raise ModelDefinitionError('Primary key column cannot be pydantic only.')
raise ModelDefinitionError("Primary key column cannot be pydantic only.")
@property
def is_required(self) -> bool:
return not self.nullable and not self.has_default and not self.is_auto_primary_key
return (
not self.nullable and not self.has_default and not self.is_auto_primary_key
)
@property
def default_value(self) -> Any:
@ -81,16 +88,19 @@ class BaseField:
def get_constraints(self) -> Optional[List]:
return []
def expand_relationship(self, value, child) -> Any:
def expand_relationship(self, value: Any, child: "Model") -> Any:
return value
class String(BaseField):
__type__ = str
def __init__(self, *args, **kwargs):
assert 'length' in kwargs, 'length is required'
self.length = kwargs.pop('length')
def __init__(self, *args: Any, **kwargs: Any) -> None:
if "length" not in kwargs:
raise ModelDefinitionError(
"Param length is required for String model field."
)
self.length = kwargs.pop("length")
super().__init__(*args, **kwargs)
def get_column_type(self) -> sqlalchemy.Column:
@ -163,27 +173,41 @@ class BigInteger(BaseField):
class Decimal(BaseField):
__type__ = decimal.Decimal
def __init__(self, *args, **kwargs):
assert 'precision' in kwargs, 'precision is required'
assert 'length' in kwargs, 'length is required'
self.length = kwargs.pop('length')
self.precision = kwargs.pop('precision')
def __init__(self, *args: Any, **kwargs: Any) -> None:
if "length" not in kwargs or "precision" not in kwargs:
raise ModelDefinitionError(
"Params length and precision are required for Decimal model field."
)
self.length = kwargs.pop("length")
self.precision = kwargs.pop("precision")
super().__init__(*args, **kwargs)
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.DECIMAL(self.length, self.precision)
def create_dummy_instance(fk: Type['Model'], pk: int = None) -> 'Model':
def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model":
init_dict = {fk.__pkname__: pk or -1}
init_dict = {**init_dict, **{k: create_dummy_instance(v.to)
for k, v in fk.__model_fields__.items()
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual}}
init_dict = {
**init_dict,
**{
k: create_dummy_instance(v.to)
for k, v in fk.__model_fields__.items()
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual
},
}
return fk(**init_dict)
class ForeignKey(BaseField):
def __init__(self, to, name: str = None, related_name: str = None, nullable: bool = True, virtual: bool = False):
def __init__(
self,
to: Type["Model"],
name: str = None,
related_name: str = None,
nullable: bool = True,
virtual: bool = False,
) -> None:
super().__init__(nullable=nullable, name=name)
self.virtual = virtual
self.related_name = related_name
@ -201,11 +225,16 @@ class ForeignKey(BaseField):
to_column = self.to.__model_fields__[self.to.__pkname__]
return to_column.get_column_type()
def expand_relationship(self, value, child) -> Union['Model', List['Model']]:
def expand_relationship(
self, value: Any, child: "Model"
) -> Union["Model", List["Model"]]:
if not isinstance(value, (self.to, dict, int, str, list)) or (
isinstance(value, orm.models.Model) and not isinstance(value, self.to)):
isinstance(value, orm.models.Model) and not isinstance(value, self.to)
):
raise RelationshipInstanceError(
'Relationship model can be build only from orm.Model, dict and integer or string (pk).')
"Relationship model can be build only from orm.Model, "
"dict and integer or string (pk)."
)
if isinstance(value, list) and not isinstance(value, self.to):
model = [self.expand_relationship(val, child) for val in value]
return model
@ -217,19 +246,27 @@ class ForeignKey(BaseField):
else:
model = create_dummy_instance(fk=self.to, pk=value)
child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(),
child.__class__.__name__.lower(),
model, child, virtual=self.virtual)
child_model_name = self.related_name or child.__class__.__name__.lower() + "s"
model._orm_relationship_manager.add_relation(
model.__class__.__name__.lower(),
child.__class__.__name__.lower(),
model,
child,
virtual=self.virtual,
)
if child_model_name not in model.__fields__ \
and child.__class__.__name__.lower() not in model.__fields__:
model.__fields__[child_model_name] = ModelField(name=child_model_name,
type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__)
model.__model_fields__[child_model_name] = ForeignKey(child.__class__,
name=child_model_name,
virtual=True)
if (
child_model_name not in model.__fields__
and child.__class__.__name__.lower() not in model.__fields__
):
model.__fields__[child_model_name] = ModelField(
name=child_model_name,
type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__,
)
model.__model_fields__[child_model_name] = ForeignKey(
child.__class__, name=child_model_name, virtual=True
)
return model

View File

@ -1,26 +0,0 @@
from typing import Union, Set, Dict # pragma no cover
class Excludable: # pragma no cover
@staticmethod
def get_excluded(exclude: Union[Set, Dict, None], key: str = None):
# print(f'checking excluded for {key}', exclude)
if isinstance(exclude, dict):
if isinstance(exclude.get(key, {}), dict) and '__all__' in exclude.get(key, {}).keys():
return exclude.get(key).get('__all__')
return exclude.get(key, {})
return exclude
@staticmethod
def is_excluded(exclude: Union[Set, Dict, None], key: str = None):
if exclude is None:
return False
to_exclude = Excludable.get_excluded(exclude, key)
# print(f'to exclude for current key = {key}', to_exclude)
if isinstance(to_exclude, Set):
return key in to_exclude
elif to_exclude is ...:
return True
return False

View File

@ -2,35 +2,39 @@ import copy
import inspect
import json
import uuid
from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple
from typing import Set, Dict
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import Callable, Dict, Set
import databases
import pydantic
import sqlalchemy
from pydantic import BaseModel, BaseConfig, create_model
import orm.queryset as qry
from orm.exceptions import ModelDefinitionError
from orm.fields import BaseField, ForeignKey
from orm.relations import RelationshipManager
import pydantic
from pydantic import BaseConfig, BaseModel, create_model
import sqlalchemy
relationship_manager = RelationshipManager()
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
pydantic_fields = {field_name: (
base_field.__type__,
... if base_field.is_required else base_field.default_value
)
pydantic_fields = {
field_name: (
base_field.__type__,
... if base_field.is_required else base_field.default_value,
)
for field_name, base_field in object_dict.items()
if isinstance(base_field, BaseField)}
if isinstance(base_field, BaseField)
}
return pydantic_fields
def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str],
List[sqlalchemy.Column],
Dict[str, BaseField]]:
def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, tablename: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
pkname: Optional[str] = None
columns: List[sqlalchemy.Column] = []
model_fields: Dict[str, BaseField] = {}
@ -42,9 +46,16 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename
if field.primary_key:
pkname = field_name
if isinstance(field, ForeignKey):
reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's'
relation_name = name.lower().title() + '_' + field.to.__name__.lower()
relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename)
reverse_name = (
field.related_name
or field.to.__name__.lower().title() + "_" + name.lower() + "s"
)
relation_name = (
name.lower().title() + "_" + field.to.__name__.lower()
)
relationship_manager.add_relation_type(
relation_name, reverse_name, field, tablename
)
columns.append(field.get_column(field_name))
return pkname, columns, model_fields
@ -57,9 +68,7 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class ModelMetaclass(type):
def __new__(
mcs: type, name: str, bases: Any, attrs: dict
) -> type:
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
@ -71,25 +80,29 @@ class ModelMetaclass(type):
metadata = attrs["__metadata__"]
# sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename)
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
attrs['__columns__'] = columns
attrs['__pkname__'] = pkname
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
name, attrs, tablename
)
attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns)
attrs["__columns__"] = columns
attrs["__pkname__"] = pkname
if not pkname:
raise ModelDefinitionError('Table has to have a primary key.')
raise ModelDefinitionError("Table has to have a primary key.")
# pydantic model creation
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields)
attrs['__pydantic_fields__'] = pydantic_fields
attrs['__pydantic_model__'] = pydantic_model
attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__)
attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__)
attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)
attrs['__model_fields__'] = model_fields
pydantic_model = create_model(
name, __config__=get_pydantic_base_orm_config(), **pydantic_fields
)
attrs["__pydantic_fields__"] = pydantic_fields
attrs["__pydantic_model__"] = pydantic_model
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
attrs["__model_fields__"] = model_fields
attrs['_orm_relationship_manager'] = relationship_manager
attrs["_orm_relationship_manager"] = relationship_manager
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
@ -99,7 +112,8 @@ class ModelMetaclass(type):
class Model(list, metaclass=ModelMetaclass):
# Model inherits from list in order to be treated as request.Body parameter in fastapi routes,
# Model inherits from list in order to be treated as
# request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts
__abstract__ = True
if TYPE_CHECKING: # pragma no cover
@ -115,17 +129,20 @@ class Model(list, metaclass=ModelMetaclass):
objects = qry.QuerySet()
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False
self.values: Optional[BaseModel] = None
if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk")
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
kwargs = {
k: self.__model_fields__[k].expand_relationship(v, self)
for k, v in kwargs.items()
}
self.values = self.__pydantic_model__(**kwargs)
def __del__(self):
def __del__(self) -> None:
self._orm_relationship_manager.deregister(self)
def __setattr__(self, key: str, value: Any) -> None:
@ -138,20 +155,24 @@ class Model(list, metaclass=ModelMetaclass):
value = self.__model_fields__[key].expand_relationship(value, self)
relation_key = self.__class__.__name__.title() + '_' + key
relation_key = self.__class__.__name__.title() + "_" + key
if not self._orm_relationship_manager.contains(relation_key, self):
setattr(self.values, key, value)
else:
super().__setattr__(key, value)
def __getattribute__(self, key: str) -> Any:
if key != '__fields__' and key in self.__fields__:
relation_key = self.__class__.__name__.title() + '_' + key
if key != "__fields__" and key in self.__fields__:
relation_key = self.__class__.__name__.title() + "_" + key
if self._orm_relationship_manager.contains(relation_key, self):
return self._orm_relationship_manager.get(relation_key, self)
item = getattr(self.values, key, None)
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
if (
item is not None
and self.is_conversion_to_json_needed(key)
and isinstance(item, str)
):
try:
item = json.loads(item)
except TypeError: # pragma no cover
@ -159,30 +180,41 @@ class Model(list, metaclass=ModelMetaclass):
return item
return super().__getattribute__(key)
def __eq__(self, other):
def __eq__(self, other: "Model") -> bool:
return self.values.dict() == other.values.dict()
def __same__(self, other):
assert self.__class__ == other.__class__
def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__:
return False
return self._orm_id == other._orm_id or (
self.values is not None and other.values is not None and self.pk == other.pk)
self.values is not None and other.values is not None and self.pk == other.pk
)
def __repr__(self): # pragma no cover
def __repr__(self) -> str: # pragma no cover
return self.values.__repr__()
@classmethod
def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model':
def from_row(
cls,
row: sqlalchemy.engine.ResultProxy,
select_related: List = None,
previous_table: str = None,
) -> "Model":
item = {}
select_related = select_related or []
table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name)
table_prefix = cls._orm_relationship_manager.resolve_relation_join(
previous_table, cls.__table__.name
)
previous_table = cls.__table__.name
for related in select_related:
if "__" in related:
first_part, remainder = related.split("__", 1)
model_cls = cls.__model_fields__[first_part].to
child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table)
child = model_cls.from_row(
row, select_related=[remainder], previous_table=previous_table
)
item[first_part] = child
else:
model_cls = cls.__model_fields__[related].to
@ -191,7 +223,9 @@ class Model(list, metaclass=ModelMetaclass):
for column in cls.__table__.columns:
if column.name not in item:
item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}']
item[column.name] = row[
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
]
return cls(**item)
@ -200,7 +234,7 @@ class Model(list, metaclass=ModelMetaclass):
# return cls.__pydantic_model__.validate(value=value)
@classmethod
def __get_validators__(cls): # pragma no cover
def __get_validators__(cls) -> Callable: # pragma no cover
yield cls.__pydantic_model__.validate
# @classmethod
@ -211,11 +245,11 @@ class Model(list, metaclass=ModelMetaclass):
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
@property
def pk(self):
def pk(self) -> str:
return getattr(self.values, self.__pkname__)
@pk.setter
def pk(self, value):
def pk(self, value: Any) -> None:
setattr(self.values, self.__pkname__, value)
@property
@ -229,7 +263,9 @@ class Model(list, metaclass=ModelMetaclass):
if isinstance(nested_model, list):
dict_instance[field] = [x.dict() for x in nested_model]
else:
dict_instance[field] = nested_model.dict() if nested_model is not None else {}
dict_instance[field] = (
nested_model.dict() if nested_model is not None else {}
)
return dict_instance
def from_dict(self, value_dict: Dict) -> None:
@ -245,16 +281,22 @@ class Model(list, metaclass=ModelMetaclass):
def extract_related_names(cls) -> Set:
related_names = set()
for name, field in cls.__fields__.items():
if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
if inspect.isclass(field.type_) and issubclass(
field.type_, pydantic.BaseModel
):
related_names.add(name)
return related_names
def extract_model_db_fields(self) -> Dict:
self_fields = self.extract_own_model_fields()
self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns}
self_fields = {
k: v for k, v in self_fields.items() if k in self.__table__.columns
}
for field in self.extract_related_names():
if getattr(self, field) is not None:
self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__)
self_fields[field] = getattr(
getattr(self, field), self.__model_fields__[field].to.__pkname__
)
return self_fields
async def save(self) -> int:
@ -264,7 +306,7 @@ class Model(list, metaclass=ModelMetaclass):
expr = self.__table__.insert()
expr = expr.values(**self_fields)
item_id = await self.__database__.execute(expr)
setattr(self, 'pk', item_id)
self.pk = item_id
return item_id
async def update(self, **kwargs: Any) -> int:
@ -274,8 +316,11 @@ class Model(list, metaclass=ModelMetaclass):
self_fields = self.extract_model_db_fields()
self_fields.pop(self.__pkname__)
expr = self.__table__.update().values(**self_fields).where(
self.pk_column == getattr(self, self.__pkname__))
expr = (
self.__table__.update()
.values(**self_fields)
.where(self.pk_column == getattr(self, self.__pkname__))
)
result = await self.__database__.execute(expr)
return result
@ -285,7 +330,7 @@ class Model(list, metaclass=ModelMetaclass):
result = await self.__database__.execute(expr)
return result
async def load(self) -> 'Model':
async def load(self) -> "Model":
expr = self.__table__.select().where(self.pk_column == self.pk)
row = await self.__database__.fetch_one(expr)
self.from_dict(dict(row))

View File

@ -1,11 +1,14 @@
from typing import List, TYPE_CHECKING, Type, NamedTuple
from typing import Any, List, NamedTuple, TYPE_CHECKING, Tuple, Type, Union
import sqlalchemy
from sqlalchemy import text
import databases
import orm
from orm import ForeignKey
from orm.exceptions import NoMatch, MultipleMatches
from orm.exceptions import MultipleMatches, NoMatch
from orm.fields import BaseField
import sqlalchemy
from sqlalchemy import text
if TYPE_CHECKING: # pragma no cover
from orm.models import Model
@ -24,17 +27,23 @@ FILTER_OPERATORS = {
class JoinParameters(NamedTuple):
prev_model: Type['Model']
prev_model: Type["Model"]
previous_alias: str
from_table: str
model_cls: Type['Model']
model_cls: Type["Model"]
class QuerySet:
ESCAPE_CHARACTERS = ['%', '_']
ESCAPE_CHARACTERS = ["%", "_"]
def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None,
limit_count: int = None, offset: int = None):
def __init__(
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
@ -48,47 +57,77 @@ class QuerySet:
self.columns = None
self.order_bys = None
def __get__(self, instance, owner):
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
return self.__class__(model_cls=owner)
@property
def database(self):
def database(self) -> databases.Database:
return self.model_cls.__database__
@property
def table(self):
def table(self) -> sqlalchemy.Table:
return self.model_cls.__table__
def prefixed_columns(self, alias, table):
return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}')
for column in table.columns]
def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]:
return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
]
def prefixed_table_name(self, alias, name):
return text(f'{name} {alias}_{name}')
def prefixed_table_name(self, alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key):
return text(f'{alias}_{to_table}.{to_key}='
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}')
def on_clause(
self,
from_table: str,
to_table: str,
previous_alias: str,
alias: str,
to_key: str,
from_key: str,
) -> text:
return text(
f"{alias}_{to_table}.{to_key}="
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}'
)
def build_join_parameters(self, part, join_params: JoinParameters):
def build_join_parameters(
self, part: str, join_params: JoinParameters
) -> JoinParameters:
model_cls = join_params.model_cls.__model_fields__[part].to
to_table = model_cls.__table__.name
alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table)
alias = model_cls._orm_relationship_manager.resolve_relation_join(
join_params.from_table, to_table
)
if alias not in self.used_aliases:
if join_params.prev_model.__model_fields__[part].virtual:
to_key = next((v for k, v in model_cls.__model_fields__.items()
if isinstance(v, ForeignKey) and v.to == join_params.prev_model), None).name
to_key = next(
(
v
for k, v in model_cls.__model_fields__.items()
if isinstance(v, ForeignKey) and v.to == join_params.prev_model
),
None,
).name
from_key = model_cls.__pkname__
else:
to_key = model_cls.__pkname__
from_key = part
on_clause = self.on_clause(join_params.from_table, to_table, join_params.previous_alias, alias, to_key,
from_key)
on_clause = self.on_clause(
join_params.from_table,
to_table,
join_params.previous_alias,
alias,
to_key,
from_key,
)
target_table = self.prefixed_table_name(alias, to_table)
self.select_from = sqlalchemy.sql.outerjoin(self.select_from, target_table, on_clause)
self.order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}'))
self.select_from = sqlalchemy.sql.outerjoin(
self.select_from, target_table, on_clause
)
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}"))
self.columns.extend(self.prefixed_columns(alias, model_cls.__table__))
self.used_aliases.append(alias)
@ -98,44 +137,76 @@ class QuerySet:
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
@staticmethod
def field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part) -> bool:
def field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str
) -> bool:
return isinstance(field, ForeignKey) and field_name not in rel_part
def field_qualifies_to_deeper_search(self, field, parent_virtual, nested, rel_part) -> bool:
def field_qualifies_to_deeper_search(
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any([x.startswith(prev_part_of_related) for x in self._select_related])
partial_match = any(
[x.startswith(prev_part_of_related) for x in self._select_related]
)
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
return ((field.virtual and parent_virtual) or (partial_match and not already_checked)) or not nested
return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def extract_auto_required_relations(self, join_params: JoinParameters,
rel_part: str = '', nested: bool = False, parent_virtual: bool = False):
def extract_auto_required_relations(
self,
join_params: JoinParameters,
rel_part: str = "",
nested: bool = False,
parent_virtual: bool = False,
) -> None:
for field_name, field in join_params.prev_model.__model_fields__.items():
if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part):
rel_part = field_name if not rel_part else rel_part + '__' + field_name
if self.field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part
):
rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable:
if rel_part not in self._select_related:
self.auto_related.append("__".join(rel_part.split("__")[:-1]))
rel_part = ''
elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part):
join_params = JoinParameters(field.to, join_params.previous_alias,
join_params.from_table, join_params.prev_model)
self.extract_auto_required_relations(join_params=join_params,
rel_part=rel_part, nested=True, parent_virtual=field.virtual)
rel_part = ""
elif self.field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part
):
join_params = JoinParameters(
field.to,
join_params.previous_alias,
join_params.from_table,
join_params.prev_model,
)
self.extract_auto_required_relations(
join_params=join_params,
rel_part=rel_part,
nested=True,
parent_virtual=field.virtual,
)
else:
rel_part = ''
rel_part = ""
def build_select_expression(self):
def build_select_expression(self) -> sqlalchemy.sql.select:
self.columns = list(self.table.columns)
self.order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')]
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
self.select_from = self.table
for key in self.model_cls.__model_fields__:
if not self.model_cls.__model_fields__[key].nullable \
and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \
and key not in self._select_related:
if (
not self.model_cls.__model_fields__[key].nullable
and isinstance(
self.model_cls.__model_fields__[key], orm.fields.ForeignKey
)
and key not in self._select_related
):
self._select_related = [key] + self._select_related
start_params = JoinParameters(self.model_cls, '', self.table.name, self.model_cls)
start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
)
self.extract_auto_required_relations(start_params)
if self.auto_related:
new_joins = []
@ -146,7 +217,9 @@ class QuerySet:
self._select_related.sort(key=lambda item: (-len(item), item))
for item in self._select_related:
join_parameters = JoinParameters(self.model_cls, '', self.table.name, self.model_cls)
join_parameters = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
)
for part in item.split("__"):
join_parameters = self.build_join_parameters(part, join_parameters)
@ -180,7 +253,7 @@ class QuerySet:
return expr
def filter(self, **kwargs):
def filter(self, **kwargs: Any) -> "QuerySet":
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
@ -189,7 +262,7 @@ class QuerySet:
kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items():
table_prefix = ''
table_prefix = ""
if "__" in key:
parts = key.split("__")
@ -215,9 +288,13 @@ class QuerySet:
# against which the comparison is being made.
previous_table = model_cls.__tablename__
for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__
table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table,
current_table)
current_table = model_cls.__model_fields__[
part
].to.__tablename__
manager = model_cls._orm_relationship_manager
table_prefix = manager.resolve_relation_join(
previous_table, current_table
)
model_cls = model_cls.__model_fields__[part].to
previous_table = current_table
@ -236,25 +313,32 @@ class QuerySet:
has_escaped_character = False
if op in ["contains", "icontains"]:
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS
if c in value)
has_escaped_character = any(
c for c in self.ESCAPE_CHARACTERS if c in value
)
if has_escaped_character:
# enable escape modifier
for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f'\\{char}')
value = value.replace(char, f"\\{char}")
value = f"%{value}%"
if isinstance(value, orm.Model):
value = value.pk
clause = getattr(column, op_attr)(value)
clause.modifiers['escape'] = '\\' if has_escaped_character else None
clause.modifiers["escape"] = "\\" if has_escaped_character else None
clause_text = str(clause.compile(dialect=self.model_cls.__database__._backend._dialect,
compile_kwargs={"literal_binds": True}))
alias = f'{table_prefix}_' if table_prefix else ''
aliased_name = f'{alias}{table.name}.{column.name}'
clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name)
clause_text = str(
clause.compile(
dialect=self.model_cls.__database__._backend._dialect,
compile_kwargs={"literal_binds": True},
)
)
alias = f"{table_prefix}_" if table_prefix else ""
aliased_name = f"{alias}{table.name}.{column.name}"
clause_text = clause_text.replace(
f"{table.name}.{column.name}", aliased_name
)
clause = text(clause_text)
filter_clauses.append(clause)
@ -264,10 +348,10 @@ class QuerySet:
filter_clauses=filter_clauses,
select_related=select_related,
limit_count=self.limit_count,
offset=self.query_offset
offset=self.query_offset,
)
def select_related(self, related):
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
if not isinstance(related, (list, tuple)):
related = [related]
@ -277,7 +361,7 @@ class QuerySet:
filter_clauses=self.filter_clauses,
select_related=related,
limit_count=self.limit_count,
offset=self.query_offset
offset=self.query_offset,
)
async def exists(self) -> bool:
@ -290,25 +374,25 @@ class QuerySet:
expr = sqlalchemy.func.count().select().select_from(expr)
return await self.database.fetch_val(expr)
def limit(self, limit_count: int):
def limit(self, limit_count: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=limit_count,
offset=self.query_offset
offset=self.query_offset,
)
def offset(self, offset: int):
def offset(self, offset: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=offset
offset=offset,
)
async def first(self, **kwargs):
async def first(self, **kwargs: Any) -> "Model":
if kwargs:
return await self.filter(**kwargs).first()
@ -316,7 +400,7 @@ class QuerySet:
if rows:
return rows[0]
async def get(self, **kwargs):
async def get(self, **kwargs: Any) -> "Model":
if kwargs:
return await self.filter(**kwargs).get()
@ -329,7 +413,7 @@ class QuerySet:
raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related)
async def all(self, **kwargs):
async def all(self, **kwargs: Any) -> List["Model"]:
if kwargs:
return await self.filter(**kwargs).all()
@ -345,7 +429,7 @@ class QuerySet:
return result_rows
@classmethod
def merge_result_rows(cls, result_rows):
def merge_result_rows(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = []
for index, model in enumerate(result_rows):
if index > 0 and model.pk == result_rows[index - 1].pk:
@ -355,30 +439,45 @@ class QuerySet:
return merged_rows
@classmethod
def merge_two_instances(cls, one: 'Model', other: 'Model'):
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.__model_fields__.keys():
# print(field, one.dict(), other.dict())
if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model):
if isinstance(getattr(one, field), list) and not isinstance(
getattr(one, field), orm.models.Model
):
setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), orm.models.Model):
if getattr(one, field).pk == getattr(other, field).pk:
setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field)))
setattr(
other,
field,
cls.merge_two_instances(
getattr(one, field), getattr(other, field)
),
)
return other
async def create(self, **kwargs):
async def create(self, **kwargs: Any) -> "Model":
new_kwargs = dict(**kwargs)
# Remove primary key when None to prevent not null constraint in postgresql.
pkname = self.model_cls.__pkname__
pk = self.model_cls.__model_fields__[pkname]
if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement):
if (
pkname in new_kwargs
and new_kwargs.get(pkname) is None
and (pk.nullable or pk.autoincrement)
):
del new_kwargs[pkname]
# substitute related models with their pk
for field in self.model_cls.extract_related_names():
if field in new_kwargs and new_kwargs.get(field) is not None:
new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__)
new_kwargs[field] = getattr(
new_kwargs.get(field),
self.model_cls.__model_fields__[field].to.__pkname__,
)
# Build the insert expression.
expr = self.table.insert()

View File

@ -2,7 +2,7 @@ import pprint
import string
import uuid
from random import choices
from typing import TYPE_CHECKING, List
from typing import Dict, List, TYPE_CHECKING, Union
from weakref import proxy
from orm.fields import ForeignKey
@ -11,40 +11,58 @@ if TYPE_CHECKING: # pragma no cover
from orm.models import Model
def get_table_alias():
return ''.join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
def get_table_alias() -> str:
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
def get_relation_config(relation_type: str, table_name: str, field: ForeignKey):
def get_relation_config(
relation_type: str, table_name: str, field: ForeignKey
) -> Dict[str, str]:
alias = get_table_alias()
config = {'type': relation_type,
'table_alias': alias,
'source_table': table_name if relation_type == 'primary' else field.to.__tablename__,
'target_table': field.to.__tablename__ if relation_type == 'primary' else table_name
}
config = {
"type": relation_type,
"table_alias": alias,
"source_table": table_name
if relation_type == "primary"
else field.to.__tablename__,
"target_table": field.to.__tablename__
if relation_type == "primary"
else table_name,
}
return config
class RelationshipManager:
def __init__(self):
def __init__(self) -> None:
self._relations = dict()
def add_relation_type(self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str):
print(relations_key, reverse_key)
def add_relation_type(
self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str
) -> None:
if relations_key not in self._relations:
self._relations[relations_key] = get_relation_config('primary', table_name, field)
self._relations[relations_key] = get_relation_config(
"primary", table_name, field
)
if reverse_key not in self._relations:
self._relations[reverse_key] = get_relation_config('reverse', table_name, field)
self._relations[reverse_key] = get_relation_config(
"reverse", table_name, field
)
def deregister(self, model: 'Model'):
def deregister(self, model: "Model") -> None:
# print(f'deregistering {model.__class__.__name__}, {model._orm_id}')
for rel_type in self._relations.keys():
if model.__class__.__name__.lower() in rel_type.lower():
if model._orm_id in self._relations[rel_type]:
del self._relations[rel_type][model._orm_id]
def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False):
def add_relation(
self,
parent_name: str,
child_name: str,
parent: "Model",
child: "Model",
virtual: bool = False,
) -> None:
parent_id = parent._orm_id
child_id = child._orm_id
if virtual:
@ -53,12 +71,18 @@ class RelationshipManager:
child, parent = parent, proxy(child)
else:
child = proxy(child)
parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, [])
parents_list = self._relations[
parent_name.lower().title() + "_" + child_name + "s"
].setdefault(parent_id, [])
self.append_related_model(parents_list, child)
children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, [])
children_list = self._relations[
child_name.lower().title() + "_" + parent_name
].setdefault(child_id, [])
self.append_related_model(children_list, parent)
def append_related_model(self, relations_list: List['Model'], model: 'Model'):
def append_related_model(
self, relations_list: List["Model"], model: "Model"
) -> None:
for x in relations_list:
try:
if x.__same__(model):
@ -68,26 +92,26 @@ class RelationshipManager:
relations_list.append(model)
def contains(self, relations_key: str, object: 'Model'):
def contains(self, relations_key: str, object: "Model") -> bool:
if relations_key in self._relations:
return object._orm_id in self._relations[relations_key]
return False
def get(self, relations_key: str, object: 'Model'):
def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]:
if relations_key in self._relations:
if object._orm_id in self._relations[relations_key]:
if self._relations[relations_key]['type'] == 'primary':
if self._relations[relations_key]["type"] == "primary":
return self._relations[relations_key][object._orm_id][0]
return self._relations[relations_key][object._orm_id]
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
for k, v in self._relations.items():
if v['source_table'] == from_table and v['target_table'] == to_table:
return self._relations[k]['table_alias']
return ''
if v["source_table"] == from_table and v["target_table"] == to_table:
return self._relations[k]["table_alias"]
return ""
def __str__(self): # pragma no cover
def __str__(self) -> str: # pragma no cover
return pprint.pformat(self._relations, indent=4, width=1)
def __repr__(self): # pragma no cover
def __repr__(self) -> str: # pragma no cover
return self.__str__()

View File

@ -109,6 +109,22 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition():
test = fields.Integer(name='test12', primary_key=True, pydantic_only=True)
def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
__tablename__ = "example4"
__metadata__ = metadata
test = fields.Decimal(name='test12', primary_key=True)
def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
__tablename__ = "example4"
__metadata__ = metadata
test = fields.String(name='test12', primary_key=True)
def test_json_conversion_in_model():
with pytest.raises(pydantic.ValidationError):
ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True)