liniting and applying black

This commit is contained in:
collerek
2020-08-09 07:51:06 +02:00
parent 9d9346fb13
commit 241628b1d9
9 changed files with 455 additions and 247 deletions

View File

@ -1,11 +1,14 @@
from typing import List, TYPE_CHECKING, Type, NamedTuple
from typing import Any, List, NamedTuple, TYPE_CHECKING, Tuple, Type, Union
import sqlalchemy
from sqlalchemy import text
import databases
import orm
from orm import ForeignKey
from orm.exceptions import NoMatch, MultipleMatches
from orm.exceptions import MultipleMatches, NoMatch
from orm.fields import BaseField
import sqlalchemy
from sqlalchemy import text
if TYPE_CHECKING: # pragma no cover
from orm.models import Model
@ -24,17 +27,23 @@ FILTER_OPERATORS = {
class JoinParameters(NamedTuple):
prev_model: Type['Model']
prev_model: Type["Model"]
previous_alias: str
from_table: str
model_cls: Type['Model']
model_cls: Type["Model"]
class QuerySet:
ESCAPE_CHARACTERS = ['%', '_']
ESCAPE_CHARACTERS = ["%", "_"]
def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None,
limit_count: int = None, offset: int = None):
def __init__(
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
@ -48,47 +57,77 @@ class QuerySet:
self.columns = None
self.order_bys = None
def __get__(self, instance, owner):
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
return self.__class__(model_cls=owner)
@property
def database(self):
def database(self) -> databases.Database:
return self.model_cls.__database__
@property
def table(self):
def table(self) -> sqlalchemy.Table:
return self.model_cls.__table__
def prefixed_columns(self, alias, table):
return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}')
for column in table.columns]
def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]:
return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
]
def prefixed_table_name(self, alias, name):
return text(f'{name} {alias}_{name}')
def prefixed_table_name(self, alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key):
return text(f'{alias}_{to_table}.{to_key}='
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}')
def on_clause(
self,
from_table: str,
to_table: str,
previous_alias: str,
alias: str,
to_key: str,
from_key: str,
) -> text:
return text(
f"{alias}_{to_table}.{to_key}="
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}'
)
def build_join_parameters(self, part, join_params: JoinParameters):
def build_join_parameters(
self, part: str, join_params: JoinParameters
) -> JoinParameters:
model_cls = join_params.model_cls.__model_fields__[part].to
to_table = model_cls.__table__.name
alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table)
alias = model_cls._orm_relationship_manager.resolve_relation_join(
join_params.from_table, to_table
)
if alias not in self.used_aliases:
if join_params.prev_model.__model_fields__[part].virtual:
to_key = next((v for k, v in model_cls.__model_fields__.items()
if isinstance(v, ForeignKey) and v.to == join_params.prev_model), None).name
to_key = next(
(
v
for k, v in model_cls.__model_fields__.items()
if isinstance(v, ForeignKey) and v.to == join_params.prev_model
),
None,
).name
from_key = model_cls.__pkname__
else:
to_key = model_cls.__pkname__
from_key = part
on_clause = self.on_clause(join_params.from_table, to_table, join_params.previous_alias, alias, to_key,
from_key)
on_clause = self.on_clause(
join_params.from_table,
to_table,
join_params.previous_alias,
alias,
to_key,
from_key,
)
target_table = self.prefixed_table_name(alias, to_table)
self.select_from = sqlalchemy.sql.outerjoin(self.select_from, target_table, on_clause)
self.order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}'))
self.select_from = sqlalchemy.sql.outerjoin(
self.select_from, target_table, on_clause
)
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}"))
self.columns.extend(self.prefixed_columns(alias, model_cls.__table__))
self.used_aliases.append(alias)
@ -98,44 +137,76 @@ class QuerySet:
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
@staticmethod
def field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part) -> bool:
def field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str
) -> bool:
return isinstance(field, ForeignKey) and field_name not in rel_part
def field_qualifies_to_deeper_search(self, field, parent_virtual, nested, rel_part) -> bool:
def field_qualifies_to_deeper_search(
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any([x.startswith(prev_part_of_related) for x in self._select_related])
partial_match = any(
[x.startswith(prev_part_of_related) for x in self._select_related]
)
already_checked = any([x.startswith(rel_part) for x in self.auto_related])
return ((field.virtual and parent_virtual) or (partial_match and not already_checked)) or not nested
return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def extract_auto_required_relations(self, join_params: JoinParameters,
rel_part: str = '', nested: bool = False, parent_virtual: bool = False):
def extract_auto_required_relations(
self,
join_params: JoinParameters,
rel_part: str = "",
nested: bool = False,
parent_virtual: bool = False,
) -> None:
for field_name, field in join_params.prev_model.__model_fields__.items():
if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part):
rel_part = field_name if not rel_part else rel_part + '__' + field_name
if self.field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part
):
rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable:
if rel_part not in self._select_related:
self.auto_related.append("__".join(rel_part.split("__")[:-1]))
rel_part = ''
elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part):
join_params = JoinParameters(field.to, join_params.previous_alias,
join_params.from_table, join_params.prev_model)
self.extract_auto_required_relations(join_params=join_params,
rel_part=rel_part, nested=True, parent_virtual=field.virtual)
rel_part = ""
elif self.field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part
):
join_params = JoinParameters(
field.to,
join_params.previous_alias,
join_params.from_table,
join_params.prev_model,
)
self.extract_auto_required_relations(
join_params=join_params,
rel_part=rel_part,
nested=True,
parent_virtual=field.virtual,
)
else:
rel_part = ''
rel_part = ""
def build_select_expression(self):
def build_select_expression(self) -> sqlalchemy.sql.select:
self.columns = list(self.table.columns)
self.order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')]
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
self.select_from = self.table
for key in self.model_cls.__model_fields__:
if not self.model_cls.__model_fields__[key].nullable \
and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \
and key not in self._select_related:
if (
not self.model_cls.__model_fields__[key].nullable
and isinstance(
self.model_cls.__model_fields__[key], orm.fields.ForeignKey
)
and key not in self._select_related
):
self._select_related = [key] + self._select_related
start_params = JoinParameters(self.model_cls, '', self.table.name, self.model_cls)
start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
)
self.extract_auto_required_relations(start_params)
if self.auto_related:
new_joins = []
@ -146,7 +217,9 @@ class QuerySet:
self._select_related.sort(key=lambda item: (-len(item), item))
for item in self._select_related:
join_parameters = JoinParameters(self.model_cls, '', self.table.name, self.model_cls)
join_parameters = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
)
for part in item.split("__"):
join_parameters = self.build_join_parameters(part, join_parameters)
@ -180,7 +253,7 @@ class QuerySet:
return expr
def filter(self, **kwargs):
def filter(self, **kwargs: Any) -> "QuerySet":
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
@ -189,7 +262,7 @@ class QuerySet:
kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items():
table_prefix = ''
table_prefix = ""
if "__" in key:
parts = key.split("__")
@ -215,9 +288,13 @@ class QuerySet:
# against which the comparison is being made.
previous_table = model_cls.__tablename__
for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__
table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table,
current_table)
current_table = model_cls.__model_fields__[
part
].to.__tablename__
manager = model_cls._orm_relationship_manager
table_prefix = manager.resolve_relation_join(
previous_table, current_table
)
model_cls = model_cls.__model_fields__[part].to
previous_table = current_table
@ -236,25 +313,32 @@ class QuerySet:
has_escaped_character = False
if op in ["contains", "icontains"]:
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS
if c in value)
has_escaped_character = any(
c for c in self.ESCAPE_CHARACTERS if c in value
)
if has_escaped_character:
# enable escape modifier
for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f'\\{char}')
value = value.replace(char, f"\\{char}")
value = f"%{value}%"
if isinstance(value, orm.Model):
value = value.pk
clause = getattr(column, op_attr)(value)
clause.modifiers['escape'] = '\\' if has_escaped_character else None
clause.modifiers["escape"] = "\\" if has_escaped_character else None
clause_text = str(clause.compile(dialect=self.model_cls.__database__._backend._dialect,
compile_kwargs={"literal_binds": True}))
alias = f'{table_prefix}_' if table_prefix else ''
aliased_name = f'{alias}{table.name}.{column.name}'
clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name)
clause_text = str(
clause.compile(
dialect=self.model_cls.__database__._backend._dialect,
compile_kwargs={"literal_binds": True},
)
)
alias = f"{table_prefix}_" if table_prefix else ""
aliased_name = f"{alias}{table.name}.{column.name}"
clause_text = clause_text.replace(
f"{table.name}.{column.name}", aliased_name
)
clause = text(clause_text)
filter_clauses.append(clause)
@ -264,10 +348,10 @@ class QuerySet:
filter_clauses=filter_clauses,
select_related=select_related,
limit_count=self.limit_count,
offset=self.query_offset
offset=self.query_offset,
)
def select_related(self, related):
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
if not isinstance(related, (list, tuple)):
related = [related]
@ -277,7 +361,7 @@ class QuerySet:
filter_clauses=self.filter_clauses,
select_related=related,
limit_count=self.limit_count,
offset=self.query_offset
offset=self.query_offset,
)
async def exists(self) -> bool:
@ -290,25 +374,25 @@ class QuerySet:
expr = sqlalchemy.func.count().select().select_from(expr)
return await self.database.fetch_val(expr)
def limit(self, limit_count: int):
def limit(self, limit_count: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=limit_count,
offset=self.query_offset
offset=self.query_offset,
)
def offset(self, offset: int):
def offset(self, offset: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=offset
offset=offset,
)
async def first(self, **kwargs):
async def first(self, **kwargs: Any) -> "Model":
if kwargs:
return await self.filter(**kwargs).first()
@ -316,7 +400,7 @@ class QuerySet:
if rows:
return rows[0]
async def get(self, **kwargs):
async def get(self, **kwargs: Any) -> "Model":
if kwargs:
return await self.filter(**kwargs).get()
@ -329,7 +413,7 @@ class QuerySet:
raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related)
async def all(self, **kwargs):
async def all(self, **kwargs: Any) -> List["Model"]:
if kwargs:
return await self.filter(**kwargs).all()
@ -345,7 +429,7 @@ class QuerySet:
return result_rows
@classmethod
def merge_result_rows(cls, result_rows):
def merge_result_rows(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = []
for index, model in enumerate(result_rows):
if index > 0 and model.pk == result_rows[index - 1].pk:
@ -355,30 +439,45 @@ class QuerySet:
return merged_rows
@classmethod
def merge_two_instances(cls, one: 'Model', other: 'Model'):
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.__model_fields__.keys():
# print(field, one.dict(), other.dict())
if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model):
if isinstance(getattr(one, field), list) and not isinstance(
getattr(one, field), orm.models.Model
):
setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), orm.models.Model):
if getattr(one, field).pk == getattr(other, field).pk:
setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field)))
setattr(
other,
field,
cls.merge_two_instances(
getattr(one, field), getattr(other, field)
),
)
return other
async def create(self, **kwargs):
async def create(self, **kwargs: Any) -> "Model":
new_kwargs = dict(**kwargs)
# Remove primary key when None to prevent not null constraint in postgresql.
pkname = self.model_cls.__pkname__
pk = self.model_cls.__model_fields__[pkname]
if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement):
if (
pkname in new_kwargs
and new_kwargs.get(pkname) is None
and (pk.nullable or pk.autoincrement)
):
del new_kwargs[pkname]
# substitute related models with their pk
for field in self.model_cls.extract_related_names():
if field in new_kwargs and new_kwargs.get(field) is not None:
new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__)
new_kwargs[field] = getattr(
new_kwargs.get(field),
self.model_cls.__model_fields__[field].to.__pkname__,
)
# Build the insert expression.
expr = self.table.insert()