liniting and applying black
This commit is contained in:
3
.flake8
3
.flake8
@ -1,4 +1,5 @@
|
|||||||
[flake8]
|
[flake8]
|
||||||
ignore = ANN101
|
ignore = ANN101, ANN102, W503
|
||||||
max-complexity = 10
|
max-complexity = 10
|
||||||
|
max-line-length = 88
|
||||||
exclude = p38venv,.pytest_cache
|
exclude = p38venv,.pytest_cache
|
||||||
|
|||||||
@ -1,6 +1,18 @@
|
|||||||
from orm.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch
|
from orm.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch
|
||||||
from orm.fields import BigInteger, Boolean, Date, DateTime, Decimal, Float, ForeignKey, Integer, JSON, String, Text, \
|
from orm.fields import (
|
||||||
Time
|
BigInteger,
|
||||||
|
Boolean,
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
Decimal,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
JSON,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
Time,
|
||||||
|
)
|
||||||
from orm.models import Model
|
from orm.models import Model
|
||||||
|
|
||||||
__version__ = "0.0.1"
|
__version__ = "0.0.1"
|
||||||
@ -21,5 +33,5 @@ __all__ = [
|
|||||||
"ModelDefinitionError",
|
"ModelDefinitionError",
|
||||||
"ModelNotSet",
|
"ModelNotSet",
|
||||||
"MultipleMatches",
|
"MultipleMatches",
|
||||||
"NoMatch"
|
"NoMatch",
|
||||||
]
|
]
|
||||||
|
|||||||
125
orm/fields.py
125
orm/fields.py
@ -1,14 +1,15 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import decimal
|
import decimal
|
||||||
from typing import List, Optional, TYPE_CHECKING, Type, Any, Union
|
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
|
||||||
|
|
||||||
import sqlalchemy
|
|
||||||
from pydantic import Json, BaseModel
|
|
||||||
from pydantic.fields import ModelField
|
|
||||||
|
|
||||||
import orm
|
import orm
|
||||||
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
|
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Json
|
||||||
|
from pydantic.fields import ModelField
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from orm.models import Model
|
from orm.models import Model
|
||||||
|
|
||||||
@ -16,33 +17,39 @@ if TYPE_CHECKING: # pragma no cover
|
|||||||
class BaseField:
|
class BaseField:
|
||||||
__type__ = None
|
__type__ = None
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
name = kwargs.pop('name', None)
|
name = kwargs.pop("name", None)
|
||||||
args = list(args)
|
args = list(args)
|
||||||
if args:
|
if args:
|
||||||
if isinstance(args[0], str):
|
if isinstance(args[0], str):
|
||||||
if name is not None:
|
if name is not None:
|
||||||
raise ModelDefinitionError('Column name cannot be passed positionally and as a keyword.')
|
raise ModelDefinitionError(
|
||||||
|
"Column name cannot be passed positionally and as a keyword."
|
||||||
|
)
|
||||||
name = args.pop(0)
|
name = args.pop(0)
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.primary_key = kwargs.pop('primary_key', False)
|
self.primary_key = kwargs.pop("primary_key", False)
|
||||||
self.autoincrement = kwargs.pop('autoincrement', self.primary_key and self.__type__ == int)
|
self.autoincrement = kwargs.pop(
|
||||||
|
"autoincrement", self.primary_key and self.__type__ == int
|
||||||
|
)
|
||||||
|
|
||||||
self.nullable = kwargs.pop('nullable', not self.primary_key)
|
self.nullable = kwargs.pop("nullable", not self.primary_key)
|
||||||
self.default = kwargs.pop('default', None)
|
self.default = kwargs.pop("default", None)
|
||||||
self.server_default = kwargs.pop('server_default', None)
|
self.server_default = kwargs.pop("server_default", None)
|
||||||
|
|
||||||
self.index = kwargs.pop('index', None)
|
self.index = kwargs.pop("index", None)
|
||||||
self.unique = kwargs.pop('unique', None)
|
self.unique = kwargs.pop("unique", None)
|
||||||
|
|
||||||
self.pydantic_only = kwargs.pop('pydantic_only', False)
|
self.pydantic_only = kwargs.pop("pydantic_only", False)
|
||||||
if self.pydantic_only and self.primary_key:
|
if self.pydantic_only and self.primary_key:
|
||||||
raise ModelDefinitionError('Primary key column cannot be pydantic only.')
|
raise ModelDefinitionError("Primary key column cannot be pydantic only.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_required(self) -> bool:
|
def is_required(self) -> bool:
|
||||||
return not self.nullable and not self.has_default and not self.is_auto_primary_key
|
return (
|
||||||
|
not self.nullable and not self.has_default and not self.is_auto_primary_key
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_value(self) -> Any:
|
def default_value(self) -> Any:
|
||||||
@ -81,16 +88,19 @@ class BaseField:
|
|||||||
def get_constraints(self) -> Optional[List]:
|
def get_constraints(self) -> Optional[List]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def expand_relationship(self, value, child) -> Any:
|
def expand_relationship(self, value: Any, child: "Model") -> Any:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
class String(BaseField):
|
class String(BaseField):
|
||||||
__type__ = str
|
__type__ = str
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
assert 'length' in kwargs, 'length is required'
|
if "length" not in kwargs:
|
||||||
self.length = kwargs.pop('length')
|
raise ModelDefinitionError(
|
||||||
|
"Param length is required for String model field."
|
||||||
|
)
|
||||||
|
self.length = kwargs.pop("length")
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def get_column_type(self) -> sqlalchemy.Column:
|
def get_column_type(self) -> sqlalchemy.Column:
|
||||||
@ -163,27 +173,41 @@ class BigInteger(BaseField):
|
|||||||
class Decimal(BaseField):
|
class Decimal(BaseField):
|
||||||
__type__ = decimal.Decimal
|
__type__ = decimal.Decimal
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
assert 'precision' in kwargs, 'precision is required'
|
if "length" not in kwargs or "precision" not in kwargs:
|
||||||
assert 'length' in kwargs, 'length is required'
|
raise ModelDefinitionError(
|
||||||
self.length = kwargs.pop('length')
|
"Params length and precision are required for Decimal model field."
|
||||||
self.precision = kwargs.pop('precision')
|
)
|
||||||
|
self.length = kwargs.pop("length")
|
||||||
|
self.precision = kwargs.pop("precision")
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def get_column_type(self) -> sqlalchemy.Column:
|
def get_column_type(self) -> sqlalchemy.Column:
|
||||||
return sqlalchemy.DECIMAL(self.length, self.precision)
|
return sqlalchemy.DECIMAL(self.length, self.precision)
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_instance(fk: Type['Model'], pk: int = None) -> 'Model':
|
def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model":
|
||||||
init_dict = {fk.__pkname__: pk or -1}
|
init_dict = {fk.__pkname__: pk or -1}
|
||||||
init_dict = {**init_dict, **{k: create_dummy_instance(v.to)
|
init_dict = {
|
||||||
|
**init_dict,
|
||||||
|
**{
|
||||||
|
k: create_dummy_instance(v.to)
|
||||||
for k, v in fk.__model_fields__.items()
|
for k, v in fk.__model_fields__.items()
|
||||||
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual}}
|
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual
|
||||||
|
},
|
||||||
|
}
|
||||||
return fk(**init_dict)
|
return fk(**init_dict)
|
||||||
|
|
||||||
|
|
||||||
class ForeignKey(BaseField):
|
class ForeignKey(BaseField):
|
||||||
def __init__(self, to, name: str = None, related_name: str = None, nullable: bool = True, virtual: bool = False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
to: Type["Model"],
|
||||||
|
name: str = None,
|
||||||
|
related_name: str = None,
|
||||||
|
nullable: bool = True,
|
||||||
|
virtual: bool = False,
|
||||||
|
) -> None:
|
||||||
super().__init__(nullable=nullable, name=name)
|
super().__init__(nullable=nullable, name=name)
|
||||||
self.virtual = virtual
|
self.virtual = virtual
|
||||||
self.related_name = related_name
|
self.related_name = related_name
|
||||||
@ -201,11 +225,16 @@ class ForeignKey(BaseField):
|
|||||||
to_column = self.to.__model_fields__[self.to.__pkname__]
|
to_column = self.to.__model_fields__[self.to.__pkname__]
|
||||||
return to_column.get_column_type()
|
return to_column.get_column_type()
|
||||||
|
|
||||||
def expand_relationship(self, value, child) -> Union['Model', List['Model']]:
|
def expand_relationship(
|
||||||
|
self, value: Any, child: "Model"
|
||||||
|
) -> Union["Model", List["Model"]]:
|
||||||
if not isinstance(value, (self.to, dict, int, str, list)) or (
|
if not isinstance(value, (self.to, dict, int, str, list)) or (
|
||||||
isinstance(value, orm.models.Model) and not isinstance(value, self.to)):
|
isinstance(value, orm.models.Model) and not isinstance(value, self.to)
|
||||||
|
):
|
||||||
raise RelationshipInstanceError(
|
raise RelationshipInstanceError(
|
||||||
'Relationship model can be build only from orm.Model, dict and integer or string (pk).')
|
"Relationship model can be build only from orm.Model, "
|
||||||
|
"dict and integer or string (pk)."
|
||||||
|
)
|
||||||
if isinstance(value, list) and not isinstance(value, self.to):
|
if isinstance(value, list) and not isinstance(value, self.to):
|
||||||
model = [self.expand_relationship(val, child) for val in value]
|
model = [self.expand_relationship(val, child) for val in value]
|
||||||
return model
|
return model
|
||||||
@ -217,19 +246,27 @@ class ForeignKey(BaseField):
|
|||||||
else:
|
else:
|
||||||
model = create_dummy_instance(fk=self.to, pk=value)
|
model = create_dummy_instance(fk=self.to, pk=value)
|
||||||
|
|
||||||
child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
|
child_model_name = self.related_name or child.__class__.__name__.lower() + "s"
|
||||||
model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(),
|
model._orm_relationship_manager.add_relation(
|
||||||
|
model.__class__.__name__.lower(),
|
||||||
child.__class__.__name__.lower(),
|
child.__class__.__name__.lower(),
|
||||||
model, child, virtual=self.virtual)
|
model,
|
||||||
|
child,
|
||||||
|
virtual=self.virtual,
|
||||||
|
)
|
||||||
|
|
||||||
if child_model_name not in model.__fields__ \
|
if (
|
||||||
and child.__class__.__name__.lower() not in model.__fields__:
|
child_model_name not in model.__fields__
|
||||||
model.__fields__[child_model_name] = ModelField(name=child_model_name,
|
and child.__class__.__name__.lower() not in model.__fields__
|
||||||
|
):
|
||||||
|
model.__fields__[child_model_name] = ModelField(
|
||||||
|
name=child_model_name,
|
||||||
type_=Optional[child.__pydantic_model__],
|
type_=Optional[child.__pydantic_model__],
|
||||||
model_config=child.__pydantic_model__.__config__,
|
model_config=child.__pydantic_model__.__config__,
|
||||||
class_validators=child.__pydantic_model__.__validators__)
|
class_validators=child.__pydantic_model__.__validators__,
|
||||||
model.__model_fields__[child_model_name] = ForeignKey(child.__class__,
|
)
|
||||||
name=child_model_name,
|
model.__model_fields__[child_model_name] = ForeignKey(
|
||||||
virtual=True)
|
child.__class__, name=child_model_name, virtual=True
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
@ -1,26 +0,0 @@
|
|||||||
from typing import Union, Set, Dict # pragma no cover
|
|
||||||
|
|
||||||
|
|
||||||
class Excludable: # pragma no cover
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_excluded(exclude: Union[Set, Dict, None], key: str = None):
|
|
||||||
# print(f'checking excluded for {key}', exclude)
|
|
||||||
if isinstance(exclude, dict):
|
|
||||||
if isinstance(exclude.get(key, {}), dict) and '__all__' in exclude.get(key, {}).keys():
|
|
||||||
return exclude.get(key).get('__all__')
|
|
||||||
return exclude.get(key, {})
|
|
||||||
return exclude
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_excluded(exclude: Union[Set, Dict, None], key: str = None):
|
|
||||||
if exclude is None:
|
|
||||||
return False
|
|
||||||
to_exclude = Excludable.get_excluded(exclude, key)
|
|
||||||
# print(f'to exclude for current key = {key}', to_exclude)
|
|
||||||
|
|
||||||
if isinstance(to_exclude, Set):
|
|
||||||
return key in to_exclude
|
|
||||||
elif to_exclude is ...:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
161
orm/models.py
161
orm/models.py
@ -2,35 +2,39 @@ import copy
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple
|
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
|
||||||
from typing import Set, Dict
|
from typing import Callable, Dict, Set
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pydantic
|
|
||||||
import sqlalchemy
|
|
||||||
from pydantic import BaseModel, BaseConfig, create_model
|
|
||||||
|
|
||||||
import orm.queryset as qry
|
import orm.queryset as qry
|
||||||
from orm.exceptions import ModelDefinitionError
|
from orm.exceptions import ModelDefinitionError
|
||||||
from orm.fields import BaseField, ForeignKey
|
from orm.fields import BaseField, ForeignKey
|
||||||
from orm.relations import RelationshipManager
|
from orm.relations import RelationshipManager
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
from pydantic import BaseConfig, BaseModel, create_model
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
relationship_manager = RelationshipManager()
|
relationship_manager = RelationshipManager()
|
||||||
|
|
||||||
|
|
||||||
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
|
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
|
||||||
pydantic_fields = {field_name: (
|
pydantic_fields = {
|
||||||
|
field_name: (
|
||||||
base_field.__type__,
|
base_field.__type__,
|
||||||
... if base_field.is_required else base_field.default_value
|
... if base_field.is_required else base_field.default_value,
|
||||||
)
|
)
|
||||||
for field_name, base_field in object_dict.items()
|
for field_name, base_field in object_dict.items()
|
||||||
if isinstance(base_field, BaseField)}
|
if isinstance(base_field, BaseField)
|
||||||
|
}
|
||||||
return pydantic_fields
|
return pydantic_fields
|
||||||
|
|
||||||
|
|
||||||
def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str],
|
def sqlalchemy_columns_from_model_fields(
|
||||||
List[sqlalchemy.Column],
|
name: str, object_dict: Dict, tablename: str
|
||||||
Dict[str, BaseField]]:
|
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
|
||||||
pkname: Optional[str] = None
|
pkname: Optional[str] = None
|
||||||
columns: List[sqlalchemy.Column] = []
|
columns: List[sqlalchemy.Column] = []
|
||||||
model_fields: Dict[str, BaseField] = {}
|
model_fields: Dict[str, BaseField] = {}
|
||||||
@ -42,9 +46,16 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename
|
|||||||
if field.primary_key:
|
if field.primary_key:
|
||||||
pkname = field_name
|
pkname = field_name
|
||||||
if isinstance(field, ForeignKey):
|
if isinstance(field, ForeignKey):
|
||||||
reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's'
|
reverse_name = (
|
||||||
relation_name = name.lower().title() + '_' + field.to.__name__.lower()
|
field.related_name
|
||||||
relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename)
|
or field.to.__name__.lower().title() + "_" + name.lower() + "s"
|
||||||
|
)
|
||||||
|
relation_name = (
|
||||||
|
name.lower().title() + "_" + field.to.__name__.lower()
|
||||||
|
)
|
||||||
|
relationship_manager.add_relation_type(
|
||||||
|
relation_name, reverse_name, field, tablename
|
||||||
|
)
|
||||||
columns.append(field.get_column(field_name))
|
columns.append(field.get_column(field_name))
|
||||||
return pkname, columns, model_fields
|
return pkname, columns, model_fields
|
||||||
|
|
||||||
@ -57,9 +68,7 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
|
|||||||
|
|
||||||
|
|
||||||
class ModelMetaclass(type):
|
class ModelMetaclass(type):
|
||||||
def __new__(
|
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
||||||
mcs: type, name: str, bases: Any, attrs: dict
|
|
||||||
) -> type:
|
|
||||||
new_model = super().__new__( # type: ignore
|
new_model = super().__new__( # type: ignore
|
||||||
mcs, name, bases, attrs
|
mcs, name, bases, attrs
|
||||||
)
|
)
|
||||||
@ -71,25 +80,29 @@ class ModelMetaclass(type):
|
|||||||
metadata = attrs["__metadata__"]
|
metadata = attrs["__metadata__"]
|
||||||
|
|
||||||
# sqlalchemy table creation
|
# sqlalchemy table creation
|
||||||
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename)
|
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
|
||||||
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
|
name, attrs, tablename
|
||||||
attrs['__columns__'] = columns
|
)
|
||||||
attrs['__pkname__'] = pkname
|
attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns)
|
||||||
|
attrs["__columns__"] = columns
|
||||||
|
attrs["__pkname__"] = pkname
|
||||||
|
|
||||||
if not pkname:
|
if not pkname:
|
||||||
raise ModelDefinitionError('Table has to have a primary key.')
|
raise ModelDefinitionError("Table has to have a primary key.")
|
||||||
|
|
||||||
# pydantic model creation
|
# pydantic model creation
|
||||||
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
|
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
|
||||||
pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields)
|
pydantic_model = create_model(
|
||||||
attrs['__pydantic_fields__'] = pydantic_fields
|
name, __config__=get_pydantic_base_orm_config(), **pydantic_fields
|
||||||
attrs['__pydantic_model__'] = pydantic_model
|
)
|
||||||
attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__)
|
attrs["__pydantic_fields__"] = pydantic_fields
|
||||||
attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__)
|
attrs["__pydantic_model__"] = pydantic_model
|
||||||
attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)
|
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
|
||||||
attrs['__model_fields__'] = model_fields
|
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
|
||||||
|
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
|
||||||
|
attrs["__model_fields__"] = model_fields
|
||||||
|
|
||||||
attrs['_orm_relationship_manager'] = relationship_manager
|
attrs["_orm_relationship_manager"] = relationship_manager
|
||||||
|
|
||||||
new_model = super().__new__( # type: ignore
|
new_model = super().__new__( # type: ignore
|
||||||
mcs, name, bases, attrs
|
mcs, name, bases, attrs
|
||||||
@ -99,7 +112,8 @@ class ModelMetaclass(type):
|
|||||||
|
|
||||||
|
|
||||||
class Model(list, metaclass=ModelMetaclass):
|
class Model(list, metaclass=ModelMetaclass):
|
||||||
# Model inherits from list in order to be treated as request.Body parameter in fastapi routes,
|
# Model inherits from list in order to be treated as
|
||||||
|
# request.Body parameter in fastapi routes,
|
||||||
# inheriting from pydantic.BaseModel causes metaclass conflicts
|
# inheriting from pydantic.BaseModel causes metaclass conflicts
|
||||||
__abstract__ = True
|
__abstract__ = True
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
@ -115,17 +129,20 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
|
|
||||||
objects = qry.QuerySet()
|
objects = qry.QuerySet()
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
self._orm_id: str = uuid.uuid4().hex
|
self._orm_id: str = uuid.uuid4().hex
|
||||||
self._orm_saved: bool = False
|
self._orm_saved: bool = False
|
||||||
self.values: Optional[BaseModel] = None
|
self.values: Optional[BaseModel] = None
|
||||||
|
|
||||||
if "pk" in kwargs:
|
if "pk" in kwargs:
|
||||||
kwargs[self.__pkname__] = kwargs.pop("pk")
|
kwargs[self.__pkname__] = kwargs.pop("pk")
|
||||||
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
|
kwargs = {
|
||||||
|
k: self.__model_fields__[k].expand_relationship(v, self)
|
||||||
|
for k, v in kwargs.items()
|
||||||
|
}
|
||||||
self.values = self.__pydantic_model__(**kwargs)
|
self.values = self.__pydantic_model__(**kwargs)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self) -> None:
|
||||||
self._orm_relationship_manager.deregister(self)
|
self._orm_relationship_manager.deregister(self)
|
||||||
|
|
||||||
def __setattr__(self, key: str, value: Any) -> None:
|
def __setattr__(self, key: str, value: Any) -> None:
|
||||||
@ -138,20 +155,24 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
|
|
||||||
value = self.__model_fields__[key].expand_relationship(value, self)
|
value = self.__model_fields__[key].expand_relationship(value, self)
|
||||||
|
|
||||||
relation_key = self.__class__.__name__.title() + '_' + key
|
relation_key = self.__class__.__name__.title() + "_" + key
|
||||||
if not self._orm_relationship_manager.contains(relation_key, self):
|
if not self._orm_relationship_manager.contains(relation_key, self):
|
||||||
setattr(self.values, key, value)
|
setattr(self.values, key, value)
|
||||||
else:
|
else:
|
||||||
super().__setattr__(key, value)
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
def __getattribute__(self, key: str) -> Any:
|
def __getattribute__(self, key: str) -> Any:
|
||||||
if key != '__fields__' and key in self.__fields__:
|
if key != "__fields__" and key in self.__fields__:
|
||||||
relation_key = self.__class__.__name__.title() + '_' + key
|
relation_key = self.__class__.__name__.title() + "_" + key
|
||||||
if self._orm_relationship_manager.contains(relation_key, self):
|
if self._orm_relationship_manager.contains(relation_key, self):
|
||||||
return self._orm_relationship_manager.get(relation_key, self)
|
return self._orm_relationship_manager.get(relation_key, self)
|
||||||
|
|
||||||
item = getattr(self.values, key, None)
|
item = getattr(self.values, key, None)
|
||||||
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
|
if (
|
||||||
|
item is not None
|
||||||
|
and self.is_conversion_to_json_needed(key)
|
||||||
|
and isinstance(item, str)
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
item = json.loads(item)
|
item = json.loads(item)
|
||||||
except TypeError: # pragma no cover
|
except TypeError: # pragma no cover
|
||||||
@ -159,30 +180,41 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
return item
|
return item
|
||||||
return super().__getattribute__(key)
|
return super().__getattribute__(key)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: "Model") -> bool:
|
||||||
return self.values.dict() == other.values.dict()
|
return self.values.dict() == other.values.dict()
|
||||||
|
|
||||||
def __same__(self, other):
|
def __same__(self, other: "Model") -> bool:
|
||||||
assert self.__class__ == other.__class__
|
if self.__class__ != other.__class__:
|
||||||
|
return False
|
||||||
return self._orm_id == other._orm_id or (
|
return self._orm_id == other._orm_id or (
|
||||||
self.values is not None and other.values is not None and self.pk == other.pk)
|
self.values is not None and other.values is not None and self.pk == other.pk
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self): # pragma no cover
|
def __repr__(self) -> str: # pragma no cover
|
||||||
return self.values.__repr__()
|
return self.values.__repr__()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model':
|
def from_row(
|
||||||
|
cls,
|
||||||
|
row: sqlalchemy.engine.ResultProxy,
|
||||||
|
select_related: List = None,
|
||||||
|
previous_table: str = None,
|
||||||
|
) -> "Model":
|
||||||
|
|
||||||
item = {}
|
item = {}
|
||||||
select_related = select_related or []
|
select_related = select_related or []
|
||||||
|
|
||||||
table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name)
|
table_prefix = cls._orm_relationship_manager.resolve_relation_join(
|
||||||
|
previous_table, cls.__table__.name
|
||||||
|
)
|
||||||
previous_table = cls.__table__.name
|
previous_table = cls.__table__.name
|
||||||
for related in select_related:
|
for related in select_related:
|
||||||
if "__" in related:
|
if "__" in related:
|
||||||
first_part, remainder = related.split("__", 1)
|
first_part, remainder = related.split("__", 1)
|
||||||
model_cls = cls.__model_fields__[first_part].to
|
model_cls = cls.__model_fields__[first_part].to
|
||||||
child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table)
|
child = model_cls.from_row(
|
||||||
|
row, select_related=[remainder], previous_table=previous_table
|
||||||
|
)
|
||||||
item[first_part] = child
|
item[first_part] = child
|
||||||
else:
|
else:
|
||||||
model_cls = cls.__model_fields__[related].to
|
model_cls = cls.__model_fields__[related].to
|
||||||
@ -191,7 +223,9 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
|
|
||||||
for column in cls.__table__.columns:
|
for column in cls.__table__.columns:
|
||||||
if column.name not in item:
|
if column.name not in item:
|
||||||
item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}']
|
item[column.name] = row[
|
||||||
|
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
|
||||||
|
]
|
||||||
|
|
||||||
return cls(**item)
|
return cls(**item)
|
||||||
|
|
||||||
@ -200,7 +234,7 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
# return cls.__pydantic_model__.validate(value=value)
|
# return cls.__pydantic_model__.validate(value=value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_validators__(cls): # pragma no cover
|
def __get_validators__(cls) -> Callable: # pragma no cover
|
||||||
yield cls.__pydantic_model__.validate
|
yield cls.__pydantic_model__.validate
|
||||||
|
|
||||||
# @classmethod
|
# @classmethod
|
||||||
@ -211,11 +245,11 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
|
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pk(self):
|
def pk(self) -> str:
|
||||||
return getattr(self.values, self.__pkname__)
|
return getattr(self.values, self.__pkname__)
|
||||||
|
|
||||||
@pk.setter
|
@pk.setter
|
||||||
def pk(self, value):
|
def pk(self, value: Any) -> None:
|
||||||
setattr(self.values, self.__pkname__, value)
|
setattr(self.values, self.__pkname__, value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -229,7 +263,9 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
if isinstance(nested_model, list):
|
if isinstance(nested_model, list):
|
||||||
dict_instance[field] = [x.dict() for x in nested_model]
|
dict_instance[field] = [x.dict() for x in nested_model]
|
||||||
else:
|
else:
|
||||||
dict_instance[field] = nested_model.dict() if nested_model is not None else {}
|
dict_instance[field] = (
|
||||||
|
nested_model.dict() if nested_model is not None else {}
|
||||||
|
)
|
||||||
return dict_instance
|
return dict_instance
|
||||||
|
|
||||||
def from_dict(self, value_dict: Dict) -> None:
|
def from_dict(self, value_dict: Dict) -> None:
|
||||||
@ -245,16 +281,22 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
def extract_related_names(cls) -> Set:
|
def extract_related_names(cls) -> Set:
|
||||||
related_names = set()
|
related_names = set()
|
||||||
for name, field in cls.__fields__.items():
|
for name, field in cls.__fields__.items():
|
||||||
if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
|
if inspect.isclass(field.type_) and issubclass(
|
||||||
|
field.type_, pydantic.BaseModel
|
||||||
|
):
|
||||||
related_names.add(name)
|
related_names.add(name)
|
||||||
return related_names
|
return related_names
|
||||||
|
|
||||||
def extract_model_db_fields(self) -> Dict:
|
def extract_model_db_fields(self) -> Dict:
|
||||||
self_fields = self.extract_own_model_fields()
|
self_fields = self.extract_own_model_fields()
|
||||||
self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns}
|
self_fields = {
|
||||||
|
k: v for k, v in self_fields.items() if k in self.__table__.columns
|
||||||
|
}
|
||||||
for field in self.extract_related_names():
|
for field in self.extract_related_names():
|
||||||
if getattr(self, field) is not None:
|
if getattr(self, field) is not None:
|
||||||
self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__)
|
self_fields[field] = getattr(
|
||||||
|
getattr(self, field), self.__model_fields__[field].to.__pkname__
|
||||||
|
)
|
||||||
return self_fields
|
return self_fields
|
||||||
|
|
||||||
async def save(self) -> int:
|
async def save(self) -> int:
|
||||||
@ -264,7 +306,7 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
expr = self.__table__.insert()
|
expr = self.__table__.insert()
|
||||||
expr = expr.values(**self_fields)
|
expr = expr.values(**self_fields)
|
||||||
item_id = await self.__database__.execute(expr)
|
item_id = await self.__database__.execute(expr)
|
||||||
setattr(self, 'pk', item_id)
|
self.pk = item_id
|
||||||
return item_id
|
return item_id
|
||||||
|
|
||||||
async def update(self, **kwargs: Any) -> int:
|
async def update(self, **kwargs: Any) -> int:
|
||||||
@ -274,8 +316,11 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
|
|
||||||
self_fields = self.extract_model_db_fields()
|
self_fields = self.extract_model_db_fields()
|
||||||
self_fields.pop(self.__pkname__)
|
self_fields.pop(self.__pkname__)
|
||||||
expr = self.__table__.update().values(**self_fields).where(
|
expr = (
|
||||||
self.pk_column == getattr(self, self.__pkname__))
|
self.__table__.update()
|
||||||
|
.values(**self_fields)
|
||||||
|
.where(self.pk_column == getattr(self, self.__pkname__))
|
||||||
|
)
|
||||||
result = await self.__database__.execute(expr)
|
result = await self.__database__.execute(expr)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -285,7 +330,7 @@ class Model(list, metaclass=ModelMetaclass):
|
|||||||
result = await self.__database__.execute(expr)
|
result = await self.__database__.execute(expr)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def load(self) -> 'Model':
|
async def load(self) -> "Model":
|
||||||
expr = self.__table__.select().where(self.pk_column == self.pk)
|
expr = self.__table__.select().where(self.pk_column == self.pk)
|
||||||
row = await self.__database__.fetch_one(expr)
|
row = await self.__database__.fetch_one(expr)
|
||||||
self.from_dict(dict(row))
|
self.from_dict(dict(row))
|
||||||
|
|||||||
261
orm/queryset.py
261
orm/queryset.py
@ -1,11 +1,14 @@
|
|||||||
from typing import List, TYPE_CHECKING, Type, NamedTuple
|
from typing import Any, List, NamedTuple, TYPE_CHECKING, Tuple, Type, Union
|
||||||
|
|
||||||
import sqlalchemy
|
import databases
|
||||||
from sqlalchemy import text
|
|
||||||
|
|
||||||
import orm
|
import orm
|
||||||
from orm import ForeignKey
|
from orm import ForeignKey
|
||||||
from orm.exceptions import NoMatch, MultipleMatches
|
from orm.exceptions import MultipleMatches, NoMatch
|
||||||
|
from orm.fields import BaseField
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from orm.models import Model
|
from orm.models import Model
|
||||||
@ -24,17 +27,23 @@ FILTER_OPERATORS = {
|
|||||||
|
|
||||||
|
|
||||||
class JoinParameters(NamedTuple):
|
class JoinParameters(NamedTuple):
|
||||||
prev_model: Type['Model']
|
prev_model: Type["Model"]
|
||||||
previous_alias: str
|
previous_alias: str
|
||||||
from_table: str
|
from_table: str
|
||||||
model_cls: Type['Model']
|
model_cls: Type["Model"]
|
||||||
|
|
||||||
|
|
||||||
class QuerySet:
|
class QuerySet:
|
||||||
ESCAPE_CHARACTERS = ['%', '_']
|
ESCAPE_CHARACTERS = ["%", "_"]
|
||||||
|
|
||||||
def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None,
|
def __init__(
|
||||||
limit_count: int = None, offset: int = None):
|
self,
|
||||||
|
model_cls: Type["Model"] = None,
|
||||||
|
filter_clauses: List = None,
|
||||||
|
select_related: List = None,
|
||||||
|
limit_count: int = None,
|
||||||
|
offset: int = None,
|
||||||
|
) -> None:
|
||||||
self.model_cls = model_cls
|
self.model_cls = model_cls
|
||||||
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
||||||
self._select_related = [] if select_related is None else select_related
|
self._select_related = [] if select_related is None else select_related
|
||||||
@ -48,47 +57,77 @@ class QuerySet:
|
|||||||
self.columns = None
|
self.columns = None
|
||||||
self.order_bys = None
|
self.order_bys = None
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
|
||||||
return self.__class__(model_cls=owner)
|
return self.__class__(model_cls=owner)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def database(self):
|
def database(self) -> databases.Database:
|
||||||
return self.model_cls.__database__
|
return self.model_cls.__database__
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def table(self):
|
def table(self) -> sqlalchemy.Table:
|
||||||
return self.model_cls.__table__
|
return self.model_cls.__table__
|
||||||
|
|
||||||
def prefixed_columns(self, alias, table):
|
def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||||
return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}')
|
return [
|
||||||
for column in table.columns]
|
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
|
||||||
|
for column in table.columns
|
||||||
|
]
|
||||||
|
|
||||||
def prefixed_table_name(self, alias, name):
|
def prefixed_table_name(self, alias: str, name: str) -> text:
|
||||||
return text(f'{name} {alias}_{name}')
|
return text(f"{name} {alias}_{name}")
|
||||||
|
|
||||||
def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key):
|
def on_clause(
|
||||||
return text(f'{alias}_{to_table}.{to_key}='
|
self,
|
||||||
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}')
|
from_table: str,
|
||||||
|
to_table: str,
|
||||||
|
previous_alias: str,
|
||||||
|
alias: str,
|
||||||
|
to_key: str,
|
||||||
|
from_key: str,
|
||||||
|
) -> text:
|
||||||
|
return text(
|
||||||
|
f"{alias}_{to_table}.{to_key}="
|
||||||
|
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}'
|
||||||
|
)
|
||||||
|
|
||||||
def build_join_parameters(self, part, join_params: JoinParameters):
|
def build_join_parameters(
|
||||||
|
self, part: str, join_params: JoinParameters
|
||||||
|
) -> JoinParameters:
|
||||||
model_cls = join_params.model_cls.__model_fields__[part].to
|
model_cls = join_params.model_cls.__model_fields__[part].to
|
||||||
to_table = model_cls.__table__.name
|
to_table = model_cls.__table__.name
|
||||||
|
|
||||||
alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table)
|
alias = model_cls._orm_relationship_manager.resolve_relation_join(
|
||||||
|
join_params.from_table, to_table
|
||||||
|
)
|
||||||
if alias not in self.used_aliases:
|
if alias not in self.used_aliases:
|
||||||
if join_params.prev_model.__model_fields__[part].virtual:
|
if join_params.prev_model.__model_fields__[part].virtual:
|
||||||
to_key = next((v for k, v in model_cls.__model_fields__.items()
|
to_key = next(
|
||||||
if isinstance(v, ForeignKey) and v.to == join_params.prev_model), None).name
|
(
|
||||||
|
v
|
||||||
|
for k, v in model_cls.__model_fields__.items()
|
||||||
|
if isinstance(v, ForeignKey) and v.to == join_params.prev_model
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
).name
|
||||||
from_key = model_cls.__pkname__
|
from_key = model_cls.__pkname__
|
||||||
else:
|
else:
|
||||||
to_key = model_cls.__pkname__
|
to_key = model_cls.__pkname__
|
||||||
from_key = part
|
from_key = part
|
||||||
|
|
||||||
on_clause = self.on_clause(join_params.from_table, to_table, join_params.previous_alias, alias, to_key,
|
on_clause = self.on_clause(
|
||||||
from_key)
|
join_params.from_table,
|
||||||
|
to_table,
|
||||||
|
join_params.previous_alias,
|
||||||
|
alias,
|
||||||
|
to_key,
|
||||||
|
from_key,
|
||||||
|
)
|
||||||
target_table = self.prefixed_table_name(alias, to_table)
|
target_table = self.prefixed_table_name(alias, to_table)
|
||||||
self.select_from = sqlalchemy.sql.outerjoin(self.select_from, target_table, on_clause)
|
self.select_from = sqlalchemy.sql.outerjoin(
|
||||||
self.order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}'))
|
self.select_from, target_table, on_clause
|
||||||
|
)
|
||||||
|
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}"))
|
||||||
self.columns.extend(self.prefixed_columns(alias, model_cls.__table__))
|
self.columns.extend(self.prefixed_columns(alias, model_cls.__table__))
|
||||||
self.used_aliases.append(alias)
|
self.used_aliases.append(alias)
|
||||||
|
|
||||||
@ -98,44 +137,76 @@ class QuerySet:
|
|||||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part) -> bool:
|
def field_is_a_foreign_key_and_no_circular_reference(
|
||||||
|
field: BaseField, field_name: str, rel_part: str
|
||||||
|
) -> bool:
|
||||||
return isinstance(field, ForeignKey) and field_name not in rel_part
|
return isinstance(field, ForeignKey) and field_name not in rel_part
|
||||||
|
|
||||||
def field_qualifies_to_deeper_search(self, field, parent_virtual, nested, rel_part) -> bool:
|
def field_qualifies_to_deeper_search(
|
||||||
|
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
|
||||||
|
) -> bool:
|
||||||
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
|
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
|
||||||
partial_match = any([x.startswith(prev_part_of_related) for x in self._select_related])
|
partial_match = any(
|
||||||
|
[x.startswith(prev_part_of_related) for x in self._select_related]
|
||||||
|
)
|
||||||
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
|
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
|
||||||
return ((field.virtual and parent_virtual) or (partial_match and not already_checked)) or not nested
|
return (
|
||||||
|
(field.virtual and parent_virtual)
|
||||||
|
or (partial_match and not already_checked)
|
||||||
|
) or not nested
|
||||||
|
|
||||||
def extract_auto_required_relations(self, join_params: JoinParameters,
|
def extract_auto_required_relations(
|
||||||
rel_part: str = '', nested: bool = False, parent_virtual: bool = False):
|
self,
|
||||||
|
join_params: JoinParameters,
|
||||||
|
rel_part: str = "",
|
||||||
|
nested: bool = False,
|
||||||
|
parent_virtual: bool = False,
|
||||||
|
) -> None:
|
||||||
for field_name, field in join_params.prev_model.__model_fields__.items():
|
for field_name, field in join_params.prev_model.__model_fields__.items():
|
||||||
if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part):
|
if self.field_is_a_foreign_key_and_no_circular_reference(
|
||||||
rel_part = field_name if not rel_part else rel_part + '__' + field_name
|
field, field_name, rel_part
|
||||||
|
):
|
||||||
|
rel_part = field_name if not rel_part else rel_part + "__" + field_name
|
||||||
if not field.nullable:
|
if not field.nullable:
|
||||||
if rel_part not in self._select_related:
|
if rel_part not in self._select_related:
|
||||||
self.auto_related.append("__".join(rel_part.split("__")[:-1]))
|
self.auto_related.append("__".join(rel_part.split("__")[:-1]))
|
||||||
rel_part = ''
|
rel_part = ""
|
||||||
elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part):
|
elif self.field_qualifies_to_deeper_search(
|
||||||
join_params = JoinParameters(field.to, join_params.previous_alias,
|
field, parent_virtual, nested, rel_part
|
||||||
join_params.from_table, join_params.prev_model)
|
):
|
||||||
self.extract_auto_required_relations(join_params=join_params,
|
join_params = JoinParameters(
|
||||||
rel_part=rel_part, nested=True, parent_virtual=field.virtual)
|
field.to,
|
||||||
|
join_params.previous_alias,
|
||||||
|
join_params.from_table,
|
||||||
|
join_params.prev_model,
|
||||||
|
)
|
||||||
|
self.extract_auto_required_relations(
|
||||||
|
join_params=join_params,
|
||||||
|
rel_part=rel_part,
|
||||||
|
nested=True,
|
||||||
|
parent_virtual=field.virtual,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
rel_part = ''
|
rel_part = ""
|
||||||
|
|
||||||
def build_select_expression(self):
|
def build_select_expression(self) -> sqlalchemy.sql.select:
|
||||||
self.columns = list(self.table.columns)
|
self.columns = list(self.table.columns)
|
||||||
self.order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')]
|
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
|
||||||
self.select_from = self.table
|
self.select_from = self.table
|
||||||
|
|
||||||
for key in self.model_cls.__model_fields__:
|
for key in self.model_cls.__model_fields__:
|
||||||
if not self.model_cls.__model_fields__[key].nullable \
|
if (
|
||||||
and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \
|
not self.model_cls.__model_fields__[key].nullable
|
||||||
and key not in self._select_related:
|
and isinstance(
|
||||||
|
self.model_cls.__model_fields__[key], orm.fields.ForeignKey
|
||||||
|
)
|
||||||
|
and key not in self._select_related
|
||||||
|
):
|
||||||
self._select_related = [key] + self._select_related
|
self._select_related = [key] + self._select_related
|
||||||
|
|
||||||
start_params = JoinParameters(self.model_cls, '', self.table.name, self.model_cls)
|
start_params = JoinParameters(
|
||||||
|
self.model_cls, "", self.table.name, self.model_cls
|
||||||
|
)
|
||||||
self.extract_auto_required_relations(start_params)
|
self.extract_auto_required_relations(start_params)
|
||||||
if self.auto_related:
|
if self.auto_related:
|
||||||
new_joins = []
|
new_joins = []
|
||||||
@ -146,7 +217,9 @@ class QuerySet:
|
|||||||
self._select_related.sort(key=lambda item: (-len(item), item))
|
self._select_related.sort(key=lambda item: (-len(item), item))
|
||||||
|
|
||||||
for item in self._select_related:
|
for item in self._select_related:
|
||||||
join_parameters = JoinParameters(self.model_cls, '', self.table.name, self.model_cls)
|
join_parameters = JoinParameters(
|
||||||
|
self.model_cls, "", self.table.name, self.model_cls
|
||||||
|
)
|
||||||
|
|
||||||
for part in item.split("__"):
|
for part in item.split("__"):
|
||||||
join_parameters = self.build_join_parameters(part, join_parameters)
|
join_parameters = self.build_join_parameters(part, join_parameters)
|
||||||
@ -180,7 +253,7 @@ class QuerySet:
|
|||||||
|
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
def filter(self, **kwargs):
|
def filter(self, **kwargs: Any) -> "QuerySet":
|
||||||
filter_clauses = self.filter_clauses
|
filter_clauses = self.filter_clauses
|
||||||
select_related = list(self._select_related)
|
select_related = list(self._select_related)
|
||||||
|
|
||||||
@ -189,7 +262,7 @@ class QuerySet:
|
|||||||
kwargs[pk_name] = kwargs.pop("pk")
|
kwargs[pk_name] = kwargs.pop("pk")
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
table_prefix = ''
|
table_prefix = ""
|
||||||
if "__" in key:
|
if "__" in key:
|
||||||
parts = key.split("__")
|
parts = key.split("__")
|
||||||
|
|
||||||
@ -215,9 +288,13 @@ class QuerySet:
|
|||||||
# against which the comparison is being made.
|
# against which the comparison is being made.
|
||||||
previous_table = model_cls.__tablename__
|
previous_table = model_cls.__tablename__
|
||||||
for part in related_parts:
|
for part in related_parts:
|
||||||
current_table = model_cls.__model_fields__[part].to.__tablename__
|
current_table = model_cls.__model_fields__[
|
||||||
table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table,
|
part
|
||||||
current_table)
|
].to.__tablename__
|
||||||
|
manager = model_cls._orm_relationship_manager
|
||||||
|
table_prefix = manager.resolve_relation_join(
|
||||||
|
previous_table, current_table
|
||||||
|
)
|
||||||
model_cls = model_cls.__model_fields__[part].to
|
model_cls = model_cls.__model_fields__[part].to
|
||||||
previous_table = current_table
|
previous_table = current_table
|
||||||
|
|
||||||
@ -236,25 +313,32 @@ class QuerySet:
|
|||||||
has_escaped_character = False
|
has_escaped_character = False
|
||||||
|
|
||||||
if op in ["contains", "icontains"]:
|
if op in ["contains", "icontains"]:
|
||||||
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS
|
has_escaped_character = any(
|
||||||
if c in value)
|
c for c in self.ESCAPE_CHARACTERS if c in value
|
||||||
|
)
|
||||||
if has_escaped_character:
|
if has_escaped_character:
|
||||||
# enable escape modifier
|
# enable escape modifier
|
||||||
for char in self.ESCAPE_CHARACTERS:
|
for char in self.ESCAPE_CHARACTERS:
|
||||||
value = value.replace(char, f'\\{char}')
|
value = value.replace(char, f"\\{char}")
|
||||||
value = f"%{value}%"
|
value = f"%{value}%"
|
||||||
|
|
||||||
if isinstance(value, orm.Model):
|
if isinstance(value, orm.Model):
|
||||||
value = value.pk
|
value = value.pk
|
||||||
|
|
||||||
clause = getattr(column, op_attr)(value)
|
clause = getattr(column, op_attr)(value)
|
||||||
clause.modifiers['escape'] = '\\' if has_escaped_character else None
|
clause.modifiers["escape"] = "\\" if has_escaped_character else None
|
||||||
|
|
||||||
clause_text = str(clause.compile(dialect=self.model_cls.__database__._backend._dialect,
|
clause_text = str(
|
||||||
compile_kwargs={"literal_binds": True}))
|
clause.compile(
|
||||||
alias = f'{table_prefix}_' if table_prefix else ''
|
dialect=self.model_cls.__database__._backend._dialect,
|
||||||
aliased_name = f'{alias}{table.name}.{column.name}'
|
compile_kwargs={"literal_binds": True},
|
||||||
clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name)
|
)
|
||||||
|
)
|
||||||
|
alias = f"{table_prefix}_" if table_prefix else ""
|
||||||
|
aliased_name = f"{alias}{table.name}.{column.name}"
|
||||||
|
clause_text = clause_text.replace(
|
||||||
|
f"{table.name}.{column.name}", aliased_name
|
||||||
|
)
|
||||||
clause = text(clause_text)
|
clause = text(clause_text)
|
||||||
|
|
||||||
filter_clauses.append(clause)
|
filter_clauses.append(clause)
|
||||||
@ -264,10 +348,10 @@ class QuerySet:
|
|||||||
filter_clauses=filter_clauses,
|
filter_clauses=filter_clauses,
|
||||||
select_related=select_related,
|
select_related=select_related,
|
||||||
limit_count=self.limit_count,
|
limit_count=self.limit_count,
|
||||||
offset=self.query_offset
|
offset=self.query_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
def select_related(self, related):
|
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
|
||||||
if not isinstance(related, (list, tuple)):
|
if not isinstance(related, (list, tuple)):
|
||||||
related = [related]
|
related = [related]
|
||||||
|
|
||||||
@ -277,7 +361,7 @@ class QuerySet:
|
|||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
select_related=related,
|
select_related=related,
|
||||||
limit_count=self.limit_count,
|
limit_count=self.limit_count,
|
||||||
offset=self.query_offset
|
offset=self.query_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def exists(self) -> bool:
|
async def exists(self) -> bool:
|
||||||
@ -290,25 +374,25 @@ class QuerySet:
|
|||||||
expr = sqlalchemy.func.count().select().select_from(expr)
|
expr = sqlalchemy.func.count().select().select_from(expr)
|
||||||
return await self.database.fetch_val(expr)
|
return await self.database.fetch_val(expr)
|
||||||
|
|
||||||
def limit(self, limit_count: int):
|
def limit(self, limit_count: int) -> "QuerySet":
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model_cls,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
limit_count=limit_count,
|
limit_count=limit_count,
|
||||||
offset=self.query_offset
|
offset=self.query_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
def offset(self, offset: int):
|
def offset(self, offset: int) -> "QuerySet":
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model_cls,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
limit_count=self.limit_count,
|
limit_count=self.limit_count,
|
||||||
offset=offset
|
offset=offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def first(self, **kwargs):
|
async def first(self, **kwargs: Any) -> "Model":
|
||||||
if kwargs:
|
if kwargs:
|
||||||
return await self.filter(**kwargs).first()
|
return await self.filter(**kwargs).first()
|
||||||
|
|
||||||
@ -316,7 +400,7 @@ class QuerySet:
|
|||||||
if rows:
|
if rows:
|
||||||
return rows[0]
|
return rows[0]
|
||||||
|
|
||||||
async def get(self, **kwargs):
|
async def get(self, **kwargs: Any) -> "Model":
|
||||||
if kwargs:
|
if kwargs:
|
||||||
return await self.filter(**kwargs).get()
|
return await self.filter(**kwargs).get()
|
||||||
|
|
||||||
@ -329,7 +413,7 @@ class QuerySet:
|
|||||||
raise MultipleMatches()
|
raise MultipleMatches()
|
||||||
return self.model_cls.from_row(rows[0], select_related=self._select_related)
|
return self.model_cls.from_row(rows[0], select_related=self._select_related)
|
||||||
|
|
||||||
async def all(self, **kwargs):
|
async def all(self, **kwargs: Any) -> List["Model"]:
|
||||||
if kwargs:
|
if kwargs:
|
||||||
return await self.filter(**kwargs).all()
|
return await self.filter(**kwargs).all()
|
||||||
|
|
||||||
@ -345,7 +429,7 @@ class QuerySet:
|
|||||||
return result_rows
|
return result_rows
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def merge_result_rows(cls, result_rows):
|
def merge_result_rows(cls, result_rows: List["Model"]) -> List["Model"]:
|
||||||
merged_rows = []
|
merged_rows = []
|
||||||
for index, model in enumerate(result_rows):
|
for index, model in enumerate(result_rows):
|
||||||
if index > 0 and model.pk == result_rows[index - 1].pk:
|
if index > 0 and model.pk == result_rows[index - 1].pk:
|
||||||
@ -355,30 +439,45 @@ class QuerySet:
|
|||||||
return merged_rows
|
return merged_rows
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def merge_two_instances(cls, one: 'Model', other: 'Model'):
|
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
|
||||||
for field in one.__model_fields__.keys():
|
for field in one.__model_fields__.keys():
|
||||||
# print(field, one.dict(), other.dict())
|
# print(field, one.dict(), other.dict())
|
||||||
if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model):
|
if isinstance(getattr(one, field), list) and not isinstance(
|
||||||
|
getattr(one, field), orm.models.Model
|
||||||
|
):
|
||||||
setattr(other, field, getattr(one, field) + getattr(other, field))
|
setattr(other, field, getattr(one, field) + getattr(other, field))
|
||||||
elif isinstance(getattr(one, field), orm.models.Model):
|
elif isinstance(getattr(one, field), orm.models.Model):
|
||||||
if getattr(one, field).pk == getattr(other, field).pk:
|
if getattr(one, field).pk == getattr(other, field).pk:
|
||||||
setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field)))
|
setattr(
|
||||||
|
other,
|
||||||
|
field,
|
||||||
|
cls.merge_two_instances(
|
||||||
|
getattr(one, field), getattr(other, field)
|
||||||
|
),
|
||||||
|
)
|
||||||
return other
|
return other
|
||||||
|
|
||||||
async def create(self, **kwargs):
|
async def create(self, **kwargs: Any) -> "Model":
|
||||||
|
|
||||||
new_kwargs = dict(**kwargs)
|
new_kwargs = dict(**kwargs)
|
||||||
|
|
||||||
# Remove primary key when None to prevent not null constraint in postgresql.
|
# Remove primary key when None to prevent not null constraint in postgresql.
|
||||||
pkname = self.model_cls.__pkname__
|
pkname = self.model_cls.__pkname__
|
||||||
pk = self.model_cls.__model_fields__[pkname]
|
pk = self.model_cls.__model_fields__[pkname]
|
||||||
if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement):
|
if (
|
||||||
|
pkname in new_kwargs
|
||||||
|
and new_kwargs.get(pkname) is None
|
||||||
|
and (pk.nullable or pk.autoincrement)
|
||||||
|
):
|
||||||
del new_kwargs[pkname]
|
del new_kwargs[pkname]
|
||||||
|
|
||||||
# substitute related models with their pk
|
# substitute related models with their pk
|
||||||
for field in self.model_cls.extract_related_names():
|
for field in self.model_cls.extract_related_names():
|
||||||
if field in new_kwargs and new_kwargs.get(field) is not None:
|
if field in new_kwargs and new_kwargs.get(field) is not None:
|
||||||
new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__)
|
new_kwargs[field] = getattr(
|
||||||
|
new_kwargs.get(field),
|
||||||
|
self.model_cls.__model_fields__[field].to.__pkname__,
|
||||||
|
)
|
||||||
|
|
||||||
# Build the insert expression.
|
# Build the insert expression.
|
||||||
expr = self.table.insert()
|
expr = self.table.insert()
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import pprint
|
|||||||
import string
|
import string
|
||||||
import uuid
|
import uuid
|
||||||
from random import choices
|
from random import choices
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import Dict, List, TYPE_CHECKING, Union
|
||||||
from weakref import proxy
|
from weakref import proxy
|
||||||
|
|
||||||
from orm.fields import ForeignKey
|
from orm.fields import ForeignKey
|
||||||
@ -11,40 +11,58 @@ if TYPE_CHECKING: # pragma no cover
|
|||||||
from orm.models import Model
|
from orm.models import Model
|
||||||
|
|
||||||
|
|
||||||
def get_table_alias():
|
def get_table_alias() -> str:
|
||||||
return ''.join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
|
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
|
||||||
|
|
||||||
|
|
||||||
def get_relation_config(relation_type: str, table_name: str, field: ForeignKey):
|
def get_relation_config(
|
||||||
|
relation_type: str, table_name: str, field: ForeignKey
|
||||||
|
) -> Dict[str, str]:
|
||||||
alias = get_table_alias()
|
alias = get_table_alias()
|
||||||
config = {'type': relation_type,
|
config = {
|
||||||
'table_alias': alias,
|
"type": relation_type,
|
||||||
'source_table': table_name if relation_type == 'primary' else field.to.__tablename__,
|
"table_alias": alias,
|
||||||
'target_table': field.to.__tablename__ if relation_type == 'primary' else table_name
|
"source_table": table_name
|
||||||
|
if relation_type == "primary"
|
||||||
|
else field.to.__tablename__,
|
||||||
|
"target_table": field.to.__tablename__
|
||||||
|
if relation_type == "primary"
|
||||||
|
else table_name,
|
||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
class RelationshipManager:
|
class RelationshipManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
def __init__(self):
|
|
||||||
self._relations = dict()
|
self._relations = dict()
|
||||||
|
|
||||||
def add_relation_type(self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str):
|
def add_relation_type(
|
||||||
print(relations_key, reverse_key)
|
self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str
|
||||||
|
) -> None:
|
||||||
if relations_key not in self._relations:
|
if relations_key not in self._relations:
|
||||||
self._relations[relations_key] = get_relation_config('primary', table_name, field)
|
self._relations[relations_key] = get_relation_config(
|
||||||
|
"primary", table_name, field
|
||||||
|
)
|
||||||
if reverse_key not in self._relations:
|
if reverse_key not in self._relations:
|
||||||
self._relations[reverse_key] = get_relation_config('reverse', table_name, field)
|
self._relations[reverse_key] = get_relation_config(
|
||||||
|
"reverse", table_name, field
|
||||||
|
)
|
||||||
|
|
||||||
def deregister(self, model: 'Model'):
|
def deregister(self, model: "Model") -> None:
|
||||||
# print(f'deregistering {model.__class__.__name__}, {model._orm_id}')
|
# print(f'deregistering {model.__class__.__name__}, {model._orm_id}')
|
||||||
for rel_type in self._relations.keys():
|
for rel_type in self._relations.keys():
|
||||||
if model.__class__.__name__.lower() in rel_type.lower():
|
if model.__class__.__name__.lower() in rel_type.lower():
|
||||||
if model._orm_id in self._relations[rel_type]:
|
if model._orm_id in self._relations[rel_type]:
|
||||||
del self._relations[rel_type][model._orm_id]
|
del self._relations[rel_type][model._orm_id]
|
||||||
|
|
||||||
def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False):
|
def add_relation(
|
||||||
|
self,
|
||||||
|
parent_name: str,
|
||||||
|
child_name: str,
|
||||||
|
parent: "Model",
|
||||||
|
child: "Model",
|
||||||
|
virtual: bool = False,
|
||||||
|
) -> None:
|
||||||
parent_id = parent._orm_id
|
parent_id = parent._orm_id
|
||||||
child_id = child._orm_id
|
child_id = child._orm_id
|
||||||
if virtual:
|
if virtual:
|
||||||
@ -53,12 +71,18 @@ class RelationshipManager:
|
|||||||
child, parent = parent, proxy(child)
|
child, parent = parent, proxy(child)
|
||||||
else:
|
else:
|
||||||
child = proxy(child)
|
child = proxy(child)
|
||||||
parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, [])
|
parents_list = self._relations[
|
||||||
|
parent_name.lower().title() + "_" + child_name + "s"
|
||||||
|
].setdefault(parent_id, [])
|
||||||
self.append_related_model(parents_list, child)
|
self.append_related_model(parents_list, child)
|
||||||
children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, [])
|
children_list = self._relations[
|
||||||
|
child_name.lower().title() + "_" + parent_name
|
||||||
|
].setdefault(child_id, [])
|
||||||
self.append_related_model(children_list, parent)
|
self.append_related_model(children_list, parent)
|
||||||
|
|
||||||
def append_related_model(self, relations_list: List['Model'], model: 'Model'):
|
def append_related_model(
|
||||||
|
self, relations_list: List["Model"], model: "Model"
|
||||||
|
) -> None:
|
||||||
for x in relations_list:
|
for x in relations_list:
|
||||||
try:
|
try:
|
||||||
if x.__same__(model):
|
if x.__same__(model):
|
||||||
@ -68,26 +92,26 @@ class RelationshipManager:
|
|||||||
|
|
||||||
relations_list.append(model)
|
relations_list.append(model)
|
||||||
|
|
||||||
def contains(self, relations_key: str, object: 'Model'):
|
def contains(self, relations_key: str, object: "Model") -> bool:
|
||||||
if relations_key in self._relations:
|
if relations_key in self._relations:
|
||||||
return object._orm_id in self._relations[relations_key]
|
return object._orm_id in self._relations[relations_key]
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get(self, relations_key: str, object: 'Model'):
|
def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]:
|
||||||
if relations_key in self._relations:
|
if relations_key in self._relations:
|
||||||
if object._orm_id in self._relations[relations_key]:
|
if object._orm_id in self._relations[relations_key]:
|
||||||
if self._relations[relations_key]['type'] == 'primary':
|
if self._relations[relations_key]["type"] == "primary":
|
||||||
return self._relations[relations_key][object._orm_id][0]
|
return self._relations[relations_key][object._orm_id][0]
|
||||||
return self._relations[relations_key][object._orm_id]
|
return self._relations[relations_key][object._orm_id]
|
||||||
|
|
||||||
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
|
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
|
||||||
for k, v in self._relations.items():
|
for k, v in self._relations.items():
|
||||||
if v['source_table'] == from_table and v['target_table'] == to_table:
|
if v["source_table"] == from_table and v["target_table"] == to_table:
|
||||||
return self._relations[k]['table_alias']
|
return self._relations[k]["table_alias"]
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
def __str__(self): # pragma no cover
|
def __str__(self) -> str: # pragma no cover
|
||||||
return pprint.pformat(self._relations, indent=4, width=1)
|
return pprint.pformat(self._relations, indent=4, width=1)
|
||||||
|
|
||||||
def __repr__(self): # pragma no cover
|
def __repr__(self) -> str: # pragma no cover
|
||||||
return self.__str__()
|
return self.__str__()
|
||||||
|
|||||||
@ -109,6 +109,22 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition():
|
|||||||
test = fields.Integer(name='test12', primary_key=True, pydantic_only=True)
|
test = fields.Integer(name='test12', primary_key=True, pydantic_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decimal_error_in_model_definition():
|
||||||
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
class ExampleModel2(Model):
|
||||||
|
__tablename__ = "example4"
|
||||||
|
__metadata__ = metadata
|
||||||
|
test = fields.Decimal(name='test12', primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_error_in_model_definition():
|
||||||
|
with pytest.raises(ModelDefinitionError):
|
||||||
|
class ExampleModel2(Model):
|
||||||
|
__tablename__ = "example4"
|
||||||
|
__metadata__ = metadata
|
||||||
|
test = fields.String(name='test12', primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
def test_json_conversion_in_model():
|
def test_json_conversion_in_model():
|
||||||
with pytest.raises(pydantic.ValidationError):
|
with pytest.raises(pydantic.ValidationError):
|
||||||
ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True)
|
ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user