Refactor in join in order to make possibility for nested duplicated relations (and it was a mess :D)

This commit is contained in:
collerek
2021-01-15 17:05:23 +01:00
parent d10141ba6f
commit 0fe95b0c7b
14 changed files with 271 additions and 303 deletions

View File

@ -1,8 +1,8 @@
from collections import OrderedDict
from typing import (
Any,
Dict,
List,
NamedTuple,
Optional,
Set,
TYPE_CHECKING,
@ -14,24 +14,13 @@ from typing import (
import sqlalchemy
from sqlalchemy import text
from ormar.fields import ManyToManyField # noqa I100
from ormar.fields import BaseField, ManyToManyField # noqa I100
from ormar.relations import AliasManager
if TYPE_CHECKING: # pragma no cover
from ormar import Model
class JoinParameters(NamedTuple):
"""
Named tuple that holds set of parameters passed during join construction.
"""
prev_model: Type["Model"]
previous_alias: str
from_table: str
model_cls: Type["Model"]
class SqlJoin:
def __init__( # noqa: CFQ002
self,
@ -42,7 +31,12 @@ class SqlJoin:
exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List],
sorted_orders: OrderedDict,
main_model: Type["Model"],
related_models: Any = None,
own_alias: str = "",
) -> None:
self.own_alias = own_alias
self.related_models = related_models or []
self.used_aliases = used_aliases
self.select_from = select_from
self.columns = columns
@ -50,18 +44,17 @@ class SqlJoin:
self.exclude_fields = exclude_fields
self.order_columns = order_columns
self.sorted_orders = sorted_orders
self.main_model = main_model
@staticmethod
def alias_manager(model_cls: Type["Model"]) -> AliasManager:
@property
def alias_manager(self) -> AliasManager:
"""
Shortcut for ormars model AliasManager stored on Meta.
Shortcut for ormar's model AliasManager stored on Meta.
:param model_cls: ormar Model class
:type model_cls: Type[Model]
:return: alias manager from model's Meta
:rtype: AliasManager
"""
return model_cls.Meta.alias_manager
return self.main_model.Meta.alias_manager
@staticmethod
def on_clause(
@ -86,33 +79,32 @@ class SqlJoin:
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}")
@staticmethod
def update_inclusions(
model_cls: Type["Model"],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
nested_name: str,
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
"""
Extract nested fields and exclude_fields if applicable.
:param model_cls: ormar model class
:type model_cls: Type["Model"]
:param fields: fields to include
:type fields: Optional[Union[Set, Dict]]
:param exclude_fields: fields to exclude
:type exclude_fields: Optional[Union[Set, Dict]]
:param nested_name: name of the nested field
:type nested_name: str
:return: updated exclude and include fields from nested objects
:rtype: Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]
"""
fields = model_cls.get_included(fields, nested_name)
exclude_fields = model_cls.get_excluded(exclude_fields, nested_name)
return fields, exclude_fields
def process_deeper_join(
self, related_name: str, model_cls: Type["Model"], remainder: Any, alias: str,
) -> None:
sql_join = SqlJoin(
used_aliases=self.used_aliases,
select_from=self.select_from,
columns=self.columns,
fields=self.main_model.get_excluded(self.fields, related_name),
exclude_fields=self.main_model.get_excluded(
self.exclude_fields, related_name
),
order_columns=self.order_columns,
sorted_orders=self.sorted_orders,
main_model=model_cls,
related_models=remainder,
own_alias=alias,
)
(
self.used_aliases,
self.select_from,
self.columns,
self.sorted_orders,
) = sql_join.build_join(related_name)
def build_join( # noqa: CCR001
self, item: str, join_parameters: JoinParameters
self, related: str
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
"""
Main external access point for building a join.
@ -120,59 +112,61 @@ class SqlJoin:
handles switching to through models for m2m relations, returns updated lists of
used_aliases and sort_orders.
:param item: string with join definition
:type item: str
:param join_parameters: parameters from previous/ current join
:type join_parameters: JoinParameters
:param related: string with join definition
:type related: str
:return: list of used aliases, select from, list of aliased columns, sort orders
:rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict]
"""
fields = self.fields
exclude_fields = self.exclude_fields
target_field = self.main_model.Meta.model_fields[related]
prev_model = self.main_model
# TODO: Finish refactoring here
if issubclass(target_field, ManyToManyField):
new_part = self.process_m2m_related_name_change(
target_field=target_field, related=related
)
self._replace_many_to_many_order_by_columns(related, new_part)
for index, part in enumerate(item.split("__")):
if issubclass(
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
model_cls = target_field.through
alias = self.alias_manager.resolve_relation_alias(
from_model=prev_model, relation_name=related
)
if alias not in self.used_aliases:
self._process_join(
model_cls=model_cls,
related=related,
alias=alias,
target_field=target_field,
)
related = new_part
self.own_alias = alias
prev_model = model_cls
target_field = target_field.through.Meta.model_fields[related]
model_cls = target_field.to
alias = model_cls.Meta.alias_manager.resolve_relation_alias(
from_model=prev_model, relation_name=related
)
if alias not in self.used_aliases:
self._process_join(
model_cls=model_cls,
prev_model=prev_model,
related=related,
alias=alias,
target_field=target_field,
)
for related_name in self.related_models:
remainder = None
if (
isinstance(self.related_models, dict)
and self.related_models[related_name]
):
_fields = join_parameters.model_cls.Meta.model_fields
target_field = _fields[part]
if (
target_field.self_reference
and part == target_field.self_reference_primary
):
new_part = target_field.default_source_field_name() # type: ignore
else:
new_part = target_field.default_target_field_name() # type: ignore
self._switch_many_to_many_order_columns(part, new_part)
if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters(
part=part,
join_params=join_parameters,
is_multi=True,
fields=fields,
exclude_fields=exclude_fields,
)
part = new_part
if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters(
part=part,
join_params=join_parameters,
fields=fields,
exclude_fields=exclude_fields,
remainder = self.related_models[related_name]
self.process_deeper_join(
related_name=related_name,
model_cls=model_cls,
remainder=remainder,
alias=alias,
)
return (
@ -182,65 +176,44 @@ class SqlJoin:
self.sorted_orders,
)
def _build_join_parameters(
self,
part: str,
join_params: JoinParameters,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
is_multi: bool = False,
) -> JoinParameters:
@staticmethod
def process_m2m_related_name_change(
target_field: Type[ManyToManyField], related: str, reverse: bool = False
) -> str:
"""
Updates used_aliases to not join multiple times to the same table.
Updates join parameters with new values.
Extracts relation name to link join through the Through model declared on
relation field.
:param part: part of the join str definition
:type part: str
:param join_params: parameters from previous/ current join
:type join_params: JoinParameters
:param fields: fields to include
:type fields: Optional[Union[Set, Dict]]
:param exclude_fields: fields to exclude
:type exclude_fields: Optional[Union[Set, Dict]]
:param is_multi: flag if the relation is m2m
:type is_multi: bool
:return: updated join parameters
:rtype: ormar.queryset.join.JoinParameters
Changes the same names in order_by queries if they are present.
:param reverse: flag if it's on_clause lookup - use reverse fields
:type reverse: bool
:param target_field: relation field
:type target_field: Type[ManyToManyField]
:param related: name of the relation
:type related: str
:return: new relation name switched to through model field
:rtype: str
"""
if is_multi:
model_cls = join_params.model_cls.Meta.model_fields[part].through
else:
model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.Meta.table.name
alias = model_cls.Meta.alias_manager.resolve_relation_alias(
join_params.prev_model, part
is_primary_self_ref = (
target_field.self_reference
and related == target_field.self_reference_primary
)
if alias not in self.used_aliases:
self._process_join(
join_params=join_params,
is_multi=is_multi,
model_cls=model_cls,
part=part,
alias=alias,
fields=fields,
exclude_fields=exclude_fields,
)
previous_alias = alias
from_table = to_table
prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
if (is_primary_self_ref and not reverse) or (
not is_primary_self_ref and reverse
):
new_part = target_field.default_source_field_name() # type: ignore
else:
new_part = target_field.default_target_field_name() # type: ignore
return new_part
def _process_join( # noqa: CFQ002
self,
join_params: JoinParameters,
is_multi: bool,
model_cls: Type["Model"],
part: str,
related: str,
alias: str,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
target_field: Type[BaseField],
prev_model: Type["Model"] = None,
) -> None:
"""
Resolves to and from column names and table names.
@ -255,63 +228,53 @@ class SqlJoin:
Process order_by causes for non m2m relations.
:param join_params: parameters from previous/ current join
:type join_params: JoinParameters
:param is_multi: flag if it's m2m relation
:type is_multi: bool
:param model_cls:
:type model_cls: ormar.models.metaclass.ModelMetaclass
:param part: name of the field used in join
:type part: str
:param related: name of the field used in join
:type related: str
:param alias: alias of the current join
:type alias: str
:param fields: fields to include
:type fields: Optional[Union[Set, Dict]]
:param exclude_fields: fields to exclude
:type exclude_fields: Optional[Union[Set, Dict]]
"""
to_table = model_cls.Meta.table.name
to_key, from_key = self.get_to_and_from_keys(
join_params, is_multi, model_cls, part
)
to_key, from_key = self.get_to_and_from_keys(related, target_field)
prev_model = prev_model or self.main_model
on_clause = self.on_clause(
previous_alias=join_params.previous_alias,
previous_alias=self.own_alias,
alias=alias,
from_clause=f"{join_params.from_table}.{from_key}",
from_clause=f"{prev_model.Meta.tablename}.{from_key}",
to_clause=f"{to_table}.{to_key}",
)
target_table = self.alias_manager(model_cls).prefixed_table_name(
alias, to_table
)
target_table = self.alias_manager.prefixed_table_name(alias, to_table)
self.select_from = sqlalchemy.sql.outerjoin(
self.select_from, target_table, on_clause
)
pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname)
if not is_multi:
if not issubclass(target_field, ManyToManyField):
self.get_order_bys(
alias=alias,
to_table=to_table,
pkname_alias=pkname_alias,
part=part,
part=related,
model_cls=model_cls,
)
self_related_fields = model_cls.own_table_columns(
model=model_cls,
fields=fields,
exclude_fields=exclude_fields,
fields=self.fields,
exclude_fields=self.exclude_fields,
use_alias=True,
)
self.columns.extend(
self.alias_manager(model_cls).prefixed_columns(
self.alias_manager.prefixed_columns(
alias, model_cls.Meta.table, self_related_fields
)
)
self.used_aliases.append(alias)
def _switch_many_to_many_order_columns(self, part: str, new_part: str) -> None:
def _replace_many_to_many_order_by_columns(self, part: str, new_part: str) -> None:
"""
Substitutes the name of the relation with actual model name in m2m order bys.
@ -325,7 +288,7 @@ class SqlJoin:
x.split("__") for x in self.order_columns if "__" in x
]
for condition in split_order_columns:
if condition[-2] == part or condition[-2][1:] == part:
if self._check_if_condition_apply(condition, part):
condition[-2] = condition[-2].replace(part, new_part)
self.order_columns = [x for x in self.order_columns if "__" not in x] + [
"__".join(x) for x in split_order_columns
@ -413,51 +376,34 @@ class SqlJoin:
order = text(f"{alias}_{to_table}.{pkname_alias}")
self.sorted_orders[f"{alias}.{pkname_alias}"] = order
@staticmethod
def get_to_and_from_keys(
join_params: JoinParameters,
is_multi: bool,
model_cls: Type["Model"],
part: str,
self, related: str, target_field: Type[BaseField]
) -> Tuple[str, str]:
"""
Based on the relation type, name of the relation and previous models and parts
stored in JoinParameters it resolves the current to and from keys, which are
different for ManyToMany relation, ForeignKey and reverse part of relations.
different for ManyToMany relation, ForeignKey and reverse related of relations.
:param join_params: parameters from previous/ current join
:type join_params: JoinParameters
:param is_multi: flag if the relation is of m2m type
:type is_multi: bool
:param model_cls: ormar model class
:type model_cls: Type[Model]
:param part: name of the current relation join
:type part: str
:param target_field: relation field
:type target_field: Type[ForeignKeyField]
:param related: name of the current relation join
:type related: str
:return: to key and from key
:rtype: Tuple[str, str]
"""
if is_multi:
target_field = join_params.model_cls.Meta.model_fields[part]
if (
target_field.self_reference
and part == target_field.self_reference_primary
):
to_key = target_field.default_target_field_name() # type: ignore
else:
to_key = target_field.default_source_field_name() # type: ignore
from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname
if issubclass(target_field, ManyToManyField):
to_key = self.process_m2m_related_name_change(
target_field=target_field, related=related, reverse=True
)
from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname)
elif join_params.prev_model.Meta.model_fields[part].virtual:
to_field = join_params.prev_model.Meta.model_fields[part].get_related_name()
to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname
)
elif target_field.virtual:
to_field = target_field.get_related_name()
to_key = target_field.to.get_column_alias(to_field)
from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname)
else:
to_key = model_cls.get_column_alias(model_cls.Meta.pkname)
from_key = join_params.prev_model.get_column_alias(part)
to_key = target_field.to.get_column_alias(target_field.to.Meta.pkname)
from_key = self.main_model.get_column_alias(related)
return to_key, from_key