added mypy checks and some typehint changes to conform

This commit is contained in:
collerek
2020-09-29 14:05:08 +02:00
parent 6d56ea5e30
commit 3caa87057e
23 changed files with 274 additions and 202 deletions

BIN
.coverage

Binary file not shown.

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
p38venv
.idea
.pytest_cache
.mypy_cache
*.pyc
*.log
test.db

5
mypy.ini Normal file
View File

@ -0,0 +1,5 @@
[mypy]
python_version = 3.8
[mypy-sqlalchemy.*]
ignore_missing_imports = True

View File

@ -1,7 +1,8 @@
from typing import Any, List, Optional, TYPE_CHECKING, Union
from typing import Any, List, Optional, TYPE_CHECKING, Union, Type
import sqlalchemy
from pydantic import Field, typing
from pydantic.fields import FieldInfo
from ormar import ModelDefinitionError # noqa I101
@ -15,6 +16,7 @@ class BaseField:
column_type: sqlalchemy.Column
constraints: List = []
name: str
primary_key: bool
autoincrement: bool
@ -24,12 +26,14 @@ class BaseField:
pydantic_only: bool
virtual: bool = False
choices: typing.Sequence
to: Type["Model"]
through: Type["Model"]
default: Any
server_default: Any
@classmethod
def default_value(cls) -> Optional[Field]:
def default_value(cls) -> Optional[FieldInfo]:
if cls.is_auto_primary_key():
return Field(default=None)
if cls.has_default():

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Generator
import sqlalchemy
@ -7,7 +7,7 @@ from ormar.exceptions import RelationshipInstanceError
from ormar.fields.base import BaseField
if TYPE_CHECKING: # pragma no cover
from ormar.models import Model
from ormar.models import Model, NewBaseModel
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
@ -23,16 +23,16 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
def ForeignKey( # noqa CFQ002
to: Type["Model"],
*,
name: str = None,
unique: bool = False,
nullable: bool = True,
related_name: str = None,
virtual: bool = False,
onupdate: str = None,
ondelete: str = None,
) -> Type[object]:
to: Type["Model"],
*,
name: str = None,
unique: bool = False,
nullable: bool = True,
related_name: str = None,
virtual: bool = False,
onupdate: str = None,
ondelete: str = None,
) -> Type["ForeignKeyField"]:
fk_string = to.Meta.tablename + "." + to.Meta.pkname
to_field = to.__fields__[to.Meta.pkname]
namespace = dict(
@ -65,7 +65,7 @@ class ForeignKeyField(BaseField):
virtual: bool
@classmethod
def __get_validators__(cls) -> Callable:
def __get_validators__(cls) -> Generator:
yield cls.validate
@classmethod
@ -74,13 +74,13 @@ class ForeignKeyField(BaseField):
@classmethod
def _extract_model_from_sequence(
cls, value: List, child: "Model", to_register: bool
) -> Union["Model", List["Model"]]:
return [cls.expand_relationship(val, child, to_register) for val in value]
cls, value: List, child: "Model", to_register: bool
) -> List["Model"]:
return [cls.expand_relationship(val, child, to_register) for val in value] # type: ignore
@classmethod
def _register_existing_model(
cls, value: "Model", child: "Model", to_register: bool
cls, value: "Model", child: "Model", to_register: bool
) -> "Model":
if to_register:
cls.register_relation(value, child)
@ -88,7 +88,7 @@ class ForeignKeyField(BaseField):
@classmethod
def _construct_model_from_dict(
cls, value: dict, child: "Model", to_register: bool
cls, value: dict, child: "Model", to_register: bool
) -> "Model":
if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname:
value["__pk_only__"] = True
@ -99,7 +99,7 @@ class ForeignKeyField(BaseField):
@classmethod
def _construct_model_from_pk(
cls, value: Any, child: "Model", to_register: bool
cls, value: Any, child: "Model", to_register: bool
) -> "Model":
if not isinstance(value, cls.to.pk_type()):
raise RelationshipInstanceError(
@ -120,7 +120,7 @@ class ForeignKeyField(BaseField):
@classmethod
def expand_relationship(
cls, value: Any, child: "Model", to_register: bool = True
cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True
) -> Optional[Union["Model", List["Model"]]]:
if value is None:
return None if not cls.virtual else []
@ -131,7 +131,7 @@ class ForeignKeyField(BaseField):
"list": cls._extract_model_from_sequence,
}
model = constructors.get(
model = constructors.get( # type: ignore
value.__class__.__name__, cls._construct_model_from_pk
)(value, child, to_register)
return model

View File

@ -15,7 +15,7 @@ def ManyToMany(
unique: bool = False,
related_name: str = None,
virtual: bool = False,
) -> Type[object]:
) -> Type["ManyToManyField"]:
to_field = to.__fields__[to.Meta.pkname]
namespace = dict(
to=to,

View File

@ -18,10 +18,10 @@ def is_field_nullable(
class ModelFieldFactory:
_bases = BaseField
_type = None
_bases: Any = BaseField
_type: Any = None
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore
cls.validate(**kwargs)
default = kwargs.pop("default", None)
@ -58,7 +58,7 @@ class String(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str
def __new__( # noqa CFQ002
def __new__( # type: ignore # noqa CFQ002
cls,
*,
allow_blank: bool = False,
@ -68,7 +68,7 @@ class String(ModelFieldFactory):
curtail_length: int = None,
regex: str = None,
**kwargs: Any
) -> Type[str]:
) -> Type[BaseField]: # type: ignore
kwargs = {
**kwargs,
**{
@ -96,14 +96,14 @@ class Integer(ModelFieldFactory):
_bases = (pydantic.ConstrainedInt, BaseField)
_type = int
def __new__(
def __new__( # type: ignore
cls,
*,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
**kwargs: Any
) -> Type[int]:
) -> Type[BaseField]:
autoincrement = kwargs.pop("autoincrement", None)
autoincrement = (
autoincrement
@ -131,9 +131,9 @@ class Text(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str
def __new__(
def __new__( # type: ignore
cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any
) -> Type[str]:
) -> Type[BaseField]:
kwargs = {
**kwargs,
**{
@ -153,14 +153,14 @@ class Float(ModelFieldFactory):
_bases = (pydantic.ConstrainedFloat, BaseField)
_type = float
def __new__(
def __new__( # type: ignore
cls,
*,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
**kwargs: Any
) -> Type[int]:
) -> Type[BaseField]:
kwargs = {
**kwargs,
**{
@ -236,7 +236,7 @@ class Decimal(ModelFieldFactory):
_bases = (pydantic.ConstrainedDecimal, BaseField)
_type = decimal.Decimal
def __new__( # noqa CFQ002
def __new__( # type: ignore # noqa CFQ002
cls,
*,
minimum: float = None,
@ -247,7 +247,7 @@ class Decimal(ModelFieldFactory):
max_digits: int = None,
decimal_places: int = None,
**kwargs: Any
) -> Type[decimal.Decimal]:
) -> Type[BaseField]:
kwargs = {
**kwargs,
**{

View File

@ -28,15 +28,19 @@ class ModelMeta:
database: databases.Database
columns: List[sqlalchemy.Column]
pkname: str
model_fields: Dict[str, Union[BaseField, ForeignKey]]
model_fields: Dict[
str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
]
alias_manager: AliasManager
def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None:
alias_manager.add_relation_type(field.to.Meta.tablename, table_name)
def register_many_to_many_relation_on_build(table_name: str, field: ManyToMany) -> None:
def register_many_to_many_relation_on_build(
table_name: str, field: Type[ManyToManyField]
) -> None:
alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
alias_manager.add_relation_type(
field.through.Meta.tablename, field.to.Meta.tablename
@ -106,7 +110,7 @@ def create_pydantic_field(
) -> None:
model_field.through.__fields__[field_name] = ModelField(
name=field_name,
type_=Optional[model],
type_=model,
model_config=model.__config__,
required=False,
class_validators={},
@ -130,7 +134,7 @@ def create_and_append_m2m_fk(
def check_pk_column_validity(
field_name: str, field: BaseField, pkname: str
field_name: str, field: BaseField, pkname: Optional[str]
) -> Optional[str]:
if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.")
@ -218,6 +222,7 @@ def populate_meta_tablename_columns_and_pk(
) -> Type["Model"]:
tablename = name.lower() + "s"
new_model.Meta.tablename = new_model.Meta.tablename or tablename
pkname: Optional[str]
if hasattr(new_model.Meta, "columns"):
columns = new_model.Meta.table.columns
@ -226,12 +231,13 @@ def populate_meta_tablename_columns_and_pk(
pkname, columns = sqlalchemy_columns_from_model_fields(
new_model.Meta.model_fields, new_model.Meta.tablename
)
if pkname is None:
raise ModelDefinitionError("Table has to have a primary key.")
new_model.Meta.columns = columns
new_model.Meta.pkname = pkname
if not new_model.Meta.pkname:
raise ModelDefinitionError("Table has to have a primary key.")
return new_model
@ -253,8 +259,8 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
return Config
def check_if_field_has_choices(field: BaseField) -> bool:
return hasattr(field, "choices") and field.choices
def check_if_field_has_choices(field: Type[BaseField]) -> bool:
return hasattr(field, "choices") and bool(field.choices)
def model_initialized_and_has_model_fields(model: Type["Model"]) -> bool:
@ -287,7 +293,7 @@ def populate_choices_validators( # noqa CCR001
class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
def __new__(mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict) -> "ModelMetaclass": # type: ignore
attrs["Config"] = get_pydantic_base_orm_config()
attrs["__name__"] = name
attrs = extract_annotations_and_default_vals(attrs, bases)
@ -306,7 +312,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
field_name = new_model.Meta.pkname
field = Integer(name=field_name, primary_key=True)
attrs["__annotations__"][field_name] = field
populate_default_pydantic_field_value(field, field_name, attrs)
populate_default_pydantic_field_value(field, field_name, attrs) # type: ignore
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs

View File

@ -1,5 +1,5 @@
import itertools
from typing import Any, List, Tuple, Union
from typing import Any, List, Dict, Optional
import sqlalchemy
from databases.backends.postgres import Record
@ -9,8 +9,8 @@ from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100
def group_related_list(list_: List) -> dict:
test_dict = dict()
def group_related_list(list_: List) -> Dict:
test_dict: Dict[str, Any] = dict()
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped:
group_list = list(group)
@ -29,14 +29,14 @@ class Model(NewBaseModel):
@classmethod
def from_row(
cls,
row: sqlalchemy.engine.ResultProxy,
select_related: List = None,
related_models: Any = None,
previous_table: str = None,
) -> Union["Model", Tuple["Model", dict]]:
cls,
row: sqlalchemy.engine.ResultProxy,
select_related: List = None,
related_models: Any = None,
previous_table: str = None,
) -> Optional["Model"]:
item = {}
item: Dict[str, Any] = {}
select_related = select_related or []
related_models = related_models or []
if select_related:
@ -44,17 +44,20 @@ class Model(NewBaseModel):
# breakpoint()
if (
previous_table
and previous_table in cls.Meta.model_fields
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
previous_table
and previous_table in cls.Meta.model_fields
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
):
previous_table = cls.Meta.model_fields[
previous_table
].through.Meta.tablename
table_prefix = cls.Meta.alias_manager.resolve_relation_join(
previous_table, cls.Meta.table.name
)
if previous_table:
table_prefix = cls.Meta.alias_manager.resolve_relation_join(
previous_table, cls.Meta.table.name
)
else:
table_prefix = ''
previous_table = cls.Meta.table.name
item = cls.populate_nested_models_from_row(
@ -67,11 +70,11 @@ class Model(NewBaseModel):
@classmethod
def populate_nested_models_from_row(
cls,
item: dict,
row: sqlalchemy.engine.ResultProxy,
related_models: Any,
previous_table: sqlalchemy.Table,
cls,
item: dict,
row: sqlalchemy.engine.ResultProxy,
related_models: Any,
previous_table: sqlalchemy.Table,
) -> dict:
for related in related_models:
if isinstance(related_models, dict) and related_models[related]:
@ -90,7 +93,7 @@ class Model(NewBaseModel):
@classmethod
def extract_prefixed_table_columns( # noqa CCR001
cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str
cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str
) -> dict:
for column in cls.Meta.table.columns:
if column.name not in item:
@ -106,7 +109,7 @@ class Model(NewBaseModel):
async def save(self) -> "Model":
self_fields = self._extract_model_db_fields()
if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
self_fields.pop(self.Meta.pkname, None)
self_fields = self.objects._populate_default_values(self_fields)
expr = self.Meta.table.insert()
@ -138,5 +141,7 @@ class Model(NewBaseModel):
async def load(self) -> "Model":
expr = self.Meta.table.select().where(self.pk_column == self.pk)
row = await self.Meta.database.fetch_one(expr)
if not row: # pragma nocover
raise ValueError('Instance was deleted from database and cannot be refreshed')
self.from_dict(dict(row))
return self

View File

@ -1,5 +1,5 @@
import inspect
from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union, Dict
import ormar
from ormar.exceptions import RelationshipInstanceError
@ -9,6 +9,7 @@ from ormar.models.metaclass import ModelMeta
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.models import NewBaseModel
Field = TypeVar("Field", bound=BaseField)
@ -17,10 +18,10 @@ class ModelTableProxy:
if TYPE_CHECKING: # pragma no cover
Meta: ModelMeta
def dict(): # noqa A003
def dict(self): # noqa A003
raise NotImplementedError # pragma no cover
def _extract_own_model_fields(self) -> dict:
def _extract_own_model_fields(self) -> Dict:
related_names = self.extract_related_names()
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields
@ -34,7 +35,7 @@ class ModelTableProxy:
return self_fields
@classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
def substitute_models_with_pks(cls, model_dict: Dict) -> Dict:
for field in cls.extract_related_names():
field_value = model_dict.get(field, None)
if field_value is not None:
@ -80,7 +81,7 @@ class ModelTableProxy:
related_names.add(name)
return related_names
def _extract_model_db_fields(self) -> dict:
def _extract_model_db_fields(self) -> Dict:
self_fields = self._extract_own_model_fields()
self_fields = {
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
@ -92,7 +93,9 @@ class ModelTableProxy:
return self_fields
@staticmethod
def resolve_relation_name(item: "Model", related: "Model") -> Optional[str]:
def resolve_relation_name(
item: Union["NewBaseModel", Type["NewBaseModel"]], related: Union["NewBaseModel", Type["NewBaseModel"]]
) -> str:
for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField):
# fastapi is creating clones of response model
@ -100,11 +103,14 @@ class ModelTableProxy:
# so we need to compare Meta too as this one is copied as is
if field.to == related.__class__ or field.to.Meta == related.Meta:
return name
raise ValueError(
f"No relation between {item.get_name()} and {related.get_name()}"
) # pragma nocover
@staticmethod
def resolve_relation_field(
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
) -> Type[Field]:
) -> Union[Type[BaseField], Type[ForeignKeyField]]:
name = ModelTableProxy.resolve_relation_name(item, related)
to_field = item.Meta.model_fields.get(name)
if not to_field: # pragma no cover
@ -116,7 +122,7 @@ class ModelTableProxy:
@classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = []
merged_rows: List["Model"] = []
for index, model in enumerate(result_rows):
if index > 0 and model.pk == merged_rows[-1].pk:
merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])

View File

@ -3,13 +3,13 @@ import uuid
from typing import (
AbstractSet,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
@ -39,7 +39,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
__slots__ = ("_orm_id", "_orm_saved", "_orm")
if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]]
__model_fields__: Dict[str, Type[BaseField]]
__table__: sqlalchemy.Table
__fields__: Dict[str, pydantic.fields.ModelField]
__pydantic_model__: Type[BaseModel]
@ -84,7 +84,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
for k, v in kwargs.items()
}
values, fields_set, validation_error = pydantic.validate_model(self, kwargs)
values, fields_set, validation_error = pydantic.validate_model(self, kwargs) # type: ignore
if validation_error and not pk_only:
raise validation_error
@ -134,13 +134,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
) -> Optional[Union["Model", List["Model"]]]:
if item in self._orm:
return self._orm.get(item)
return None
def __eq__(self, other: "Model") -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, NewBaseModel):
return self.__same__(other)
return super().__eq__(other) # pragma no cover
def __same__(self, other: "Model") -> bool:
def __same__(self, other: "NewBaseModel") -> bool:
return (
self._orm_id == other._orm_id
or self.dict() == other.dict()
@ -205,19 +206,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
dict_instance[field] = None
return dict_instance
def from_dict(self, value_dict: Dict) -> "Model":
def from_dict(self, value_dict: Dict) -> "NewBaseModel":
for key, value in value_dict.items():
setattr(self, key, value)
return self
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]:
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, Dict]:
if not self._is_conversion_to_json_needed(column_name):
return value
condition = (
isinstance(value, str) if op == "loads" else not isinstance(value, str)
)
operand = json.loads if op == "loads" else json.dumps
operand: Callable[[Any], Any] = json.loads if op == "loads" else json.dumps
if condition:
try:
@ -227,4 +228,4 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
return value
def _is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json
return self.Meta.model_fields[column_name].__type__ == pydantic.Json

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
import sqlalchemy
from sqlalchemy import text
@ -118,7 +118,7 @@ class QueryClause:
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str]
) -> Tuple[List[str], str, "Model"]:
) -> Tuple[List[str], str, Type["Model"]]:
table_prefix = ""
model_cls = self.model_cls
@ -168,9 +168,7 @@ class QueryClause:
return clause
@staticmethod
def _escape_characters_in_clause(
op: str, value: Union[str, "Model"]
) -> Tuple[str, bool]:
def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]:
has_escaped_character = False
if op not in [

View File

@ -22,8 +22,8 @@ class SqlJoin:
self,
used_aliases: List,
select_from: sqlalchemy.sql.select,
order_bys: List,
columns: List,
order_bys: List[sqlalchemy.sql.elements.TextClause],
columns: List[sqlalchemy.Column],
) -> None:
self.used_aliases = used_aliases
self.select_from = select_from

View File

@ -1,8 +1,10 @@
from typing import Optional
import sqlalchemy
class LimitQuery:
def __init__(self, limit_count: int) -> None:
def __init__(self, limit_count: Optional[int]) -> None:
self.limit_count = limit_count
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:

View File

@ -1,8 +1,10 @@
from typing import Optional
import sqlalchemy
class OffsetQuery:
def __init__(self, query_offset: int) -> None:
def __init__(self, query_offset: Optional[int]) -> None:
self.query_offset = query_offset
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:

View File

@ -1,4 +1,4 @@
from typing import List, TYPE_CHECKING, Tuple, Type
from typing import List, TYPE_CHECKING, Tuple, Type, Optional
import sqlalchemy
from sqlalchemy import text
@ -18,8 +18,8 @@ class Query:
filter_clauses: List,
exclude_clauses: List,
select_related: List,
limit_count: int,
offset: int,
limit_count: Optional[int],
offset: Optional[int],
) -> None:
self.query_offset = offset
self.limit_count = limit_count
@ -30,11 +30,11 @@ class Query:
self.model_cls = model_cls
self.table = self.model_cls.Meta.table
self.used_aliases = []
self.used_aliases: List[str] = []
self.select_from = None
self.columns = None
self.order_bys = None
self.select_from: List[str] = []
self.columns = [sqlalchemy.Column]
self.order_bys: List[sqlalchemy.sql.elements.TextClause] = []
@property
def prefixed_pk_name(self) -> str:
@ -89,7 +89,7 @@ class Query:
return expr
def _reset_query_parameters(self) -> None:
self.select_from = None
self.columns = None
self.order_bys = None
self.select_from = []
self.columns = []
self.order_bys = []
self.used_aliases = []

View File

@ -1,4 +1,4 @@
from typing import Any, List, Mapping, TYPE_CHECKING, Tuple, Type, Union
from typing import Any, List, Mapping, TYPE_CHECKING, Type, Union, Optional
import databases
import sqlalchemy
@ -13,17 +13,18 @@ from ormar.queryset.query import Query
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.models.metaclass import ModelMeta
class QuerySet:
def __init__( # noqa CFQ002
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
exclude_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
exclude_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
@ -36,47 +37,60 @@ class QuerySet:
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
return self.__class__(model_cls=owner)
def _process_query_result_rows(self, rows: List[Mapping]) -> List["Model"]:
@property
def model_meta(self) -> "ModelMeta":
if not self.model_cls: # pragma nocover
raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls.Meta
@property
def model(self) -> Type["Model"]:
if not self.model_cls: # pragma nocover
raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]:
result_rows = [
self.model_cls.from_row(row, select_related=self._select_related)
self.model.from_row(row, select_related=self._select_related)
for row in rows
]
rows = self.model_cls.merge_instances_list(result_rows)
return rows
if result_rows:
return self.model.merge_instances_list(result_rows) # type: ignore
return result_rows
def _populate_default_values(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_cls.Meta.model_fields.items():
for field_name, field in self.model_meta.model_fields.items():
if field_name not in new_kwargs and field.has_default():
new_kwargs[field_name] = field.get_default()
return new_kwargs
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
pkname = self.model_cls.Meta.pkname
pk = self.model_cls.Meta.model_fields[pkname]
pkname = self.model_meta.pkname
pk = self.model_meta.model_fields[pkname]
if new_kwargs.get(pkname, ormar.Undefined) is None and (
pk.nullable or pk.autoincrement
pk.nullable or pk.autoincrement
):
del new_kwargs[pkname]
return new_kwargs
@staticmethod
def check_single_result_rows_count(rows: List["Model"]) -> None:
if not rows:
def check_single_result_rows_count(rows: List[Optional["Model"]]) -> None:
if not rows or rows[0] is None:
raise NoMatch()
if len(rows) > 1:
raise MultipleMatches()
@property
def database(self) -> databases.Database:
return self.model_cls.Meta.database
return self.model_meta.database
@property
def table(self) -> sqlalchemy.Table:
return self.model_cls.Meta.table
return self.model_meta.table
def build_select_expression(self) -> sqlalchemy.sql.select:
qry = Query(
model_cls=self.model_cls,
model_cls=self.model,
select_related=self._select_related,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
@ -89,7 +103,7 @@ class QuerySet:
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
qryclause = QueryClause(
model_cls=self.model_cls,
model_cls=self.model,
select_related=self._select_related,
filter_clauses=self.filter_clauses,
)
@ -102,7 +116,7 @@ class QuerySet:
filter_clauses = filter_clauses
return self.__class__(
model_cls=self.model_cls,
model_cls=self.model,
filter_clauses=filter_clauses,
exclude_clauses=exclude_clauses,
select_related=select_related,
@ -113,13 +127,13 @@ class QuerySet:
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
return self.filter(_exclude=True, **kwargs)
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
if not isinstance(related, (list, tuple)):
def select_related(self, related: Union[List, str]) -> "QuerySet":
if not isinstance(related, list):
related = [related]
related = list(set(list(self._select_related) + related))
return self.__class__(
model_cls=self.model_cls,
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=related,
@ -138,7 +152,7 @@ class QuerySet:
return await self.database.fetch_val(expr)
async def update(self, each: bool = False, **kwargs: Any) -> int:
self_fields = self.model_cls.extract_db_own_fields()
self_fields = self.model.extract_db_own_fields()
updates = {k: v for k, v in kwargs.items() if k in self_fields}
if not each and not self.filter_clauses:
raise QueryDefinitionError(
@ -165,7 +179,7 @@ class QuerySet:
def limit(self, limit_count: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related,
@ -175,7 +189,7 @@ class QuerySet:
def offset(self, offset: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related,
@ -189,7 +203,7 @@ class QuerySet:
rows = await self.limit(1).all()
self.check_single_result_rows_count(rows)
return rows[0]
return rows[0] # type: ignore
async def get(self, **kwargs: Any) -> "Model":
if kwargs:
@ -200,9 +214,9 @@ class QuerySet:
expr = expr.limit(2)
rows = await self.database.fetch_all(expr)
rows = self._process_query_result_rows(rows)
self.check_single_result_rows_count(rows)
return rows[0]
processed_rows = self._process_query_result_rows(rows)
self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore
async def get_or_create(self, **kwargs: Any) -> "Model":
try:
@ -211,7 +225,7 @@ class QuerySet:
return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model":
pk_name = self.model_cls.Meta.pkname
pk_name = self.model_meta.pkname
if "pk" in kwargs:
kwargs[pk_name] = kwargs.pop("pk")
if pk_name not in kwargs or kwargs.get(pk_name) is None:
@ -219,7 +233,7 @@ class QuerySet:
model = await self.get(pk=kwargs[pk_name])
return await model.update(**kwargs)
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
if kwargs:
return await self.filter(**kwargs).all()
@ -233,20 +247,20 @@ class QuerySet:
new_kwargs = dict(**kwargs)
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
expr = self.table.insert()
expr = expr.values(**new_kwargs)
# Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs)
instance = self.model(**kwargs)
pk = await self.database.execute(expr)
pk_name = self.model_cls.Meta.pkname
pk_name = self.model_meta.pkname
if pk_name not in kwargs and pk_name in new_kwargs:
instance.pk = new_kwargs[self.model_cls.Meta.pkname]
if pk and isinstance(pk, self.model_cls.pk_type()):
setattr(instance, self.model_cls.Meta.pkname, pk)
instance.pk = new_kwargs[self.model_meta.pkname]
if pk and isinstance(pk, self.model.pk_type()):
setattr(instance, self.model_meta.pkname, pk)
return instance
async def bulk_create(self, objects: List["Model"]) -> None:
@ -254,7 +268,7 @@ class QuerySet:
for objt in objects:
new_kwargs = objt.dict()
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
ready_objects.append(new_kwargs)
@ -262,13 +276,15 @@ class QuerySet:
await self.database.execute_many(expr, ready_objects)
async def bulk_update(
self, objects: List["Model"], columns: List[str] = None
self, objects: List["Model"], columns: List[str] = None
) -> None:
ready_objects = []
pk_name = self.model_cls.Meta.pkname
pk_name = self.model_meta.pkname
if not columns:
columns = self.model_cls.extract_db_own_fields().union(
self.model_cls.extract_related_names()
columns = list(
self.model.extract_db_own_fields().union(
self.model.extract_related_names()
)
)
if pk_name not in columns:
@ -279,13 +295,13 @@ class QuerySet:
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
raise QueryDefinitionError(
"You cannot update unsaved objects. "
f"{self.model_cls.__name__} has to have {pk_name} filled."
f"{self.model.__name__} has to have {pk_name} filled."
)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
ready_objects.append(new_kwargs)
pk_column = self.model_cls.Meta.table.c.get(pk_name)
pk_column = self.model_meta.table.c.get(pk_name)
expr = self.table.update().where(pk_column == bindparam("new_" + pk_name))
expr = expr.values(
**{k: bindparam("new_" + k) for k in columns if k != pk_name}

View File

@ -1,7 +1,7 @@
import string
import uuid
from random import choices
from typing import List
from typing import List, Dict
import sqlalchemy
from sqlalchemy import text
@ -14,7 +14,7 @@ def get_table_alias() -> str:
class AliasManager:
def __init__(self) -> None:
self._aliases = dict()
self._aliases: Dict[str, str] = dict()
@staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:

View File

@ -13,8 +13,8 @@ class QuerysetProxy:
relation: "Relation"
def __init__(self, relation: "Relation") -> None:
self.relation = relation
self.queryset = None
self.relation: Relation = relation
self.queryset: "QuerySet"
def _assign_child_to_parent(self, child: "Model") -> None:
owner = self.relation._owner

View File

@ -9,6 +9,7 @@ from ormar.relations.relation_proxy import RelationProxy
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.relations import RelationsManager
from ormar.models import NewBaseModel
class RelationType(Enum):
@ -19,24 +20,26 @@ class RelationType(Enum):
class Relation:
def __init__(
self,
manager: "RelationsManager",
type_: RelationType,
to: Type["Model"],
through: Type["Model"] = None,
self,
manager: "RelationsManager",
type_: RelationType,
to: Type["Model"],
through: Type["Model"] = None,
) -> None:
self.manager = manager
self._owner = manager.owner
self._type = type_
self.to = to
self.through = through
self.related_models = (
self._owner: "Model" = manager.owner
self._type: RelationType = type_
self.to: Type["Model"] = to
self.through: Optional[Type["Model"]] = through
self.related_models: Optional[Union[RelationProxy, "Model"]] = (
RelationProxy(relation=self)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None
)
def _find_existing(self, child: "Model") -> Optional[int]:
if not isinstance(self.related_models, RelationProxy): # pragma nocover
raise ValueError("Cannot find existing models in parent relation type")
for ind, relation_child in enumerate(self.related_models[:]):
try:
if relation_child == child:
@ -52,7 +55,7 @@ class Relation:
self._owner.__dict__[relation_name] = child
else:
if self._find_existing(child) is None:
self.related_models.append(child)
self.related_models.append(child) # type: ignore
rel = self._owner.__dict__.get(relation_name, [])
rel = rel or []
if not isinstance(rel, list):
@ -60,19 +63,19 @@ class Relation:
rel.append(child)
self._owner.__dict__[relation_name] = rel
def remove(self, child: "Model") -> None:
def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
if self.related_models.__same__(child):
if self.related_models == child:
self.related_models = None
del self._owner.__dict__[relation_name]
else:
position = self._find_existing(child)
if position is not None:
self.related_models.pop(position)
self.related_models.pop(position) # type: ignore
del self._owner.__dict__[relation_name][position]
def get(self) -> Union[List["Model"], "Model"]:
def get(self) -> Optional[Union[List["Model"], "Model"]]:
return self.related_models
def __repr__(self) -> str: # pragma no cover

View File

@ -1,6 +1,7 @@
from typing import List, Optional, TYPE_CHECKING, Type, Union
from typing import List, Optional, TYPE_CHECKING, Type, Union, Dict
from weakref import proxy
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.relation import Relation, RelationType
@ -11,25 +12,28 @@ from ormar.relations.utils import (
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.models import NewBaseModel
class RelationsManager:
def __init__(
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
self,
related_fields: List[Type[ForeignKeyField]] = None,
owner: "NewBaseModel" = None,
) -> None:
self.owner = proxy(owner)
self._related_fields = related_fields or []
self._related_names = [field.name for field in self._related_fields]
self._relations = dict()
self._relations: Dict[str, Relation] = dict()
for field in self._related_fields:
self._add_relation(field)
def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType:
def _get_relation_type(self, field: Type[BaseField]) -> RelationType:
if issubclass(field, ManyToManyField):
return RelationType.MULTIPLE
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
def _add_relation(self, field: Type[BaseField]) -> None:
self._relations[field.name] = Relation(
manager=self,
type_=self._get_relation_type(field),
@ -44,15 +48,17 @@ class RelationsManager:
relation = self._relations.get(name, None)
if relation is not None:
return relation.get()
return None # pragma nocover
def _get(self, name: str) -> Optional[Relation]:
relation = self._relations.get(name, None)
if relation is not None:
return relation
return None
@staticmethod
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
to_field = child.resolve_relation_field(child, parent)
to_field: Type[BaseField] = child.resolve_relation_field(child, parent)
(parent, child, child_name, to_name,) = get_relations_sides_and_names(
to_field, parent, child, child_name, virtual
@ -61,18 +67,22 @@ class RelationsManager:
parent_relation = parent._orm._get(child_name)
if not parent_relation:
parent_relation = register_missing_relation(parent, child, child_name)
parent_relation.add(child)
child._orm._get(to_name).add(parent)
parent_relation.add(child) # type: ignore
def remove(self, name: str, child: "Model") -> None:
child_relation = child._orm._get(to_name)
if child_relation:
child_relation.add(parent)
def remove(self, name: str, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
relation = self._get(name)
relation.remove(child)
if relation:
relation.remove(child)
@staticmethod
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
def remove_parent(item: Union["NewBaseModel", Type["NewBaseModel"]], name: "Model") -> None:
related_model = name
name = item.resolve_relation_name(item, related_model)
if name in item._orm:
rel_name = item.resolve_relation_name(item, related_model)
if rel_name in item._orm:
relation_name = item.resolve_relation_name(related_model, item)
item._orm.remove(name, related_model)
item._orm.remove(rel_name, related_model)
related_model._orm.remove(relation_name, item)

View File

@ -13,22 +13,30 @@ if TYPE_CHECKING: # pragma no cover
class RelationProxy(list):
def __init__(self, relation: "Relation") -> None:
super(RelationProxy, self).__init__()
self.relation = relation
self._owner = self.relation.manager.owner
self.relation: Relation = relation
self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation)
def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]:
if not self.queryset_proxy.queryset:
self.queryset_proxy.queryset = self._set_queryset()
self._initialize_queryset()
return getattr(self.queryset_proxy, item)
return super().__getattribute__(item)
def __getattr__(self, item: str) -> Any:
if not self.queryset_proxy.queryset:
self.queryset_proxy.queryset = self._set_queryset()
self._initialize_queryset()
return getattr(self.queryset_proxy, item)
def _initialize_queryset(self) -> None:
if not self._check_if_queryset_is_initialized():
self.queryset_proxy.queryset = self._set_queryset()
def _check_if_queryset_is_initialized(self) -> bool:
return (
hasattr(self.queryset_proxy, "queryset")
and self.queryset_proxy.queryset is not None
)
def _set_queryset(self) -> "QuerySet":
owner_table = self.relation._owner.Meta.tablename
pkname = self.relation._owner.Meta.pkname
@ -45,10 +53,15 @@ class RelationProxy(list):
)
return queryset
async def remove(self, item: "Model") -> None:
async def remove(self, item: "Model") -> None: # type: ignore
super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
item._orm._get(rel_name).remove(self._owner)
relation = item._orm._get(rel_name)
if relation is None: # pragma nocover
raise ValueError(
f"{self._owner.get_name()} does not have relation {rel_name}"
)
relation.remove(self._owner)
if self.relation._type == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.delete_through_instance(item)

View File

@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, Tuple, Type
from typing import TYPE_CHECKING, Tuple, Type, Optional
from weakref import proxy
import ormar
from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields import BaseField
from ormar.fields.many_to_many import ManyToManyField
from ormar.relations import Relation
@ -12,7 +12,7 @@ if TYPE_CHECKING: # pragma no cover
def register_missing_relation(
parent: "Model", child: "Model", child_name: str
) -> Relation:
) -> Optional[Relation]:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
@ -22,7 +22,7 @@ def register_missing_relation(
def get_relations_sides_and_names(
to_field: Type[ForeignKeyField],
to_field: Type[BaseField],
parent: "Model",
child: "Model",
child_name: str,