added mypy checks and some typehint changes to conform

This commit is contained in:
collerek
2020-09-29 14:05:08 +02:00
parent 6d56ea5e30
commit 3caa87057e
23 changed files with 274 additions and 202 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
import sqlalchemy
from sqlalchemy import text
@ -118,7 +118,7 @@ class QueryClause:
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str]
) -> Tuple[List[str], str, "Model"]:
) -> Tuple[List[str], str, Type["Model"]]:
table_prefix = ""
model_cls = self.model_cls
@ -168,9 +168,7 @@ class QueryClause:
return clause
@staticmethod
def _escape_characters_in_clause(
op: str, value: Union[str, "Model"]
) -> Tuple[str, bool]:
def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]:
has_escaped_character = False
if op not in [

View File

@ -22,8 +22,8 @@ class SqlJoin:
self,
used_aliases: List,
select_from: sqlalchemy.sql.select,
order_bys: List,
columns: List,
order_bys: List[sqlalchemy.sql.elements.TextClause],
columns: List[sqlalchemy.Column],
) -> None:
self.used_aliases = used_aliases
self.select_from = select_from

View File

@ -1,8 +1,10 @@
from typing import Optional
import sqlalchemy
class LimitQuery:
def __init__(self, limit_count: int) -> None:
def __init__(self, limit_count: Optional[int]) -> None:
self.limit_count = limit_count
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:

View File

@ -1,8 +1,10 @@
from typing import Optional
import sqlalchemy
class OffsetQuery:
def __init__(self, query_offset: int) -> None:
def __init__(self, query_offset: Optional[int]) -> None:
self.query_offset = query_offset
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:

View File

@ -1,4 +1,4 @@
from typing import List, TYPE_CHECKING, Tuple, Type
from typing import List, TYPE_CHECKING, Tuple, Type, Optional
import sqlalchemy
from sqlalchemy import text
@ -18,8 +18,8 @@ class Query:
filter_clauses: List,
exclude_clauses: List,
select_related: List,
limit_count: int,
offset: int,
limit_count: Optional[int],
offset: Optional[int],
) -> None:
self.query_offset = offset
self.limit_count = limit_count
@ -30,11 +30,11 @@ class Query:
self.model_cls = model_cls
self.table = self.model_cls.Meta.table
self.used_aliases = []
self.used_aliases: List[str] = []
self.select_from = None
self.columns = None
self.order_bys = None
self.select_from: List[str] = []
self.columns = [sqlalchemy.Column]
self.order_bys: List[sqlalchemy.sql.elements.TextClause] = []
@property
def prefixed_pk_name(self) -> str:
@ -89,7 +89,7 @@ class Query:
return expr
def _reset_query_parameters(self) -> None:
self.select_from = None
self.columns = None
self.order_bys = None
self.select_from = []
self.columns = []
self.order_bys = []
self.used_aliases = []

View File

@ -1,4 +1,4 @@
from typing import Any, List, Mapping, TYPE_CHECKING, Tuple, Type, Union
from typing import Any, List, Mapping, TYPE_CHECKING, Type, Union, Optional
import databases
import sqlalchemy
@ -13,17 +13,18 @@ from ormar.queryset.query import Query
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.models.metaclass import ModelMeta
class QuerySet:
def __init__( # noqa CFQ002
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
exclude_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
exclude_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
@ -36,47 +37,60 @@ class QuerySet:
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
return self.__class__(model_cls=owner)
def _process_query_result_rows(self, rows: List[Mapping]) -> List["Model"]:
@property
def model_meta(self) -> "ModelMeta":
if not self.model_cls: # pragma nocover
raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls.Meta
@property
def model(self) -> Type["Model"]:
if not self.model_cls: # pragma nocover
raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]:
result_rows = [
self.model_cls.from_row(row, select_related=self._select_related)
self.model.from_row(row, select_related=self._select_related)
for row in rows
]
rows = self.model_cls.merge_instances_list(result_rows)
return rows
if result_rows:
return self.model.merge_instances_list(result_rows) # type: ignore
return result_rows
def _populate_default_values(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_cls.Meta.model_fields.items():
for field_name, field in self.model_meta.model_fields.items():
if field_name not in new_kwargs and field.has_default():
new_kwargs[field_name] = field.get_default()
return new_kwargs
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
pkname = self.model_cls.Meta.pkname
pk = self.model_cls.Meta.model_fields[pkname]
pkname = self.model_meta.pkname
pk = self.model_meta.model_fields[pkname]
if new_kwargs.get(pkname, ormar.Undefined) is None and (
pk.nullable or pk.autoincrement
pk.nullable or pk.autoincrement
):
del new_kwargs[pkname]
return new_kwargs
@staticmethod
def check_single_result_rows_count(rows: List["Model"]) -> None:
if not rows:
def check_single_result_rows_count(rows: List[Optional["Model"]]) -> None:
if not rows or rows[0] is None:
raise NoMatch()
if len(rows) > 1:
raise MultipleMatches()
@property
def database(self) -> databases.Database:
return self.model_cls.Meta.database
return self.model_meta.database
@property
def table(self) -> sqlalchemy.Table:
return self.model_cls.Meta.table
return self.model_meta.table
def build_select_expression(self) -> sqlalchemy.sql.select:
qry = Query(
model_cls=self.model_cls,
model_cls=self.model,
select_related=self._select_related,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
@ -89,7 +103,7 @@ class QuerySet:
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
qryclause = QueryClause(
model_cls=self.model_cls,
model_cls=self.model,
select_related=self._select_related,
filter_clauses=self.filter_clauses,
)
@ -102,7 +116,7 @@ class QuerySet:
filter_clauses = filter_clauses
return self.__class__(
model_cls=self.model_cls,
model_cls=self.model,
filter_clauses=filter_clauses,
exclude_clauses=exclude_clauses,
select_related=select_related,
@ -113,13 +127,13 @@ class QuerySet:
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
return self.filter(_exclude=True, **kwargs)
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
if not isinstance(related, (list, tuple)):
def select_related(self, related: Union[List, str]) -> "QuerySet":
if not isinstance(related, list):
related = [related]
related = list(set(list(self._select_related) + related))
return self.__class__(
model_cls=self.model_cls,
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=related,
@ -138,7 +152,7 @@ class QuerySet:
return await self.database.fetch_val(expr)
async def update(self, each: bool = False, **kwargs: Any) -> int:
self_fields = self.model_cls.extract_db_own_fields()
self_fields = self.model.extract_db_own_fields()
updates = {k: v for k, v in kwargs.items() if k in self_fields}
if not each and not self.filter_clauses:
raise QueryDefinitionError(
@ -165,7 +179,7 @@ class QuerySet:
def limit(self, limit_count: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related,
@ -175,7 +189,7 @@ class QuerySet:
def offset(self, offset: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related,
@ -189,7 +203,7 @@ class QuerySet:
rows = await self.limit(1).all()
self.check_single_result_rows_count(rows)
return rows[0]
return rows[0] # type: ignore
async def get(self, **kwargs: Any) -> "Model":
if kwargs:
@ -200,9 +214,9 @@ class QuerySet:
expr = expr.limit(2)
rows = await self.database.fetch_all(expr)
rows = self._process_query_result_rows(rows)
self.check_single_result_rows_count(rows)
return rows[0]
processed_rows = self._process_query_result_rows(rows)
self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore
async def get_or_create(self, **kwargs: Any) -> "Model":
try:
@ -211,7 +225,7 @@ class QuerySet:
return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model":
pk_name = self.model_cls.Meta.pkname
pk_name = self.model_meta.pkname
if "pk" in kwargs:
kwargs[pk_name] = kwargs.pop("pk")
if pk_name not in kwargs or kwargs.get(pk_name) is None:
@ -219,7 +233,7 @@ class QuerySet:
model = await self.get(pk=kwargs[pk_name])
return await model.update(**kwargs)
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
if kwargs:
return await self.filter(**kwargs).all()
@ -233,20 +247,20 @@ class QuerySet:
new_kwargs = dict(**kwargs)
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
expr = self.table.insert()
expr = expr.values(**new_kwargs)
# Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs)
instance = self.model(**kwargs)
pk = await self.database.execute(expr)
pk_name = self.model_cls.Meta.pkname
pk_name = self.model_meta.pkname
if pk_name not in kwargs and pk_name in new_kwargs:
instance.pk = new_kwargs[self.model_cls.Meta.pkname]
if pk and isinstance(pk, self.model_cls.pk_type()):
setattr(instance, self.model_cls.Meta.pkname, pk)
instance.pk = new_kwargs[self.model_meta.pkname]
if pk and isinstance(pk, self.model.pk_type()):
setattr(instance, self.model_meta.pkname, pk)
return instance
async def bulk_create(self, objects: List["Model"]) -> None:
@ -254,7 +268,7 @@ class QuerySet:
for objt in objects:
new_kwargs = objt.dict()
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
ready_objects.append(new_kwargs)
@ -262,13 +276,15 @@ class QuerySet:
await self.database.execute_many(expr, ready_objects)
async def bulk_update(
self, objects: List["Model"], columns: List[str] = None
self, objects: List["Model"], columns: List[str] = None
) -> None:
ready_objects = []
pk_name = self.model_cls.Meta.pkname
pk_name = self.model_meta.pkname
if not columns:
columns = self.model_cls.extract_db_own_fields().union(
self.model_cls.extract_related_names()
columns = list(
self.model.extract_db_own_fields().union(
self.model.extract_related_names()
)
)
if pk_name not in columns:
@ -279,13 +295,13 @@ class QuerySet:
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
raise QueryDefinitionError(
"You cannot update unsaved objects. "
f"{self.model_cls.__name__} has to have {pk_name} filled."
f"{self.model.__name__} has to have {pk_name} filled."
)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
ready_objects.append(new_kwargs)
pk_column = self.model_cls.Meta.table.c.get(pk_name)
pk_column = self.model_meta.table.c.get(pk_name)
expr = self.table.update().where(pk_column == bindparam("new_" + pk_name))
expr = expr.values(
**{k: bindparam("new_" + k) for k in columns if k != pk_name}