Refactor in join in order to make possibility for nested duplicated relations (and it was a mess :D)
This commit is contained in:
@ -1,8 +1,8 @@
|
||||
from collections import OrderedDict
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
@ -14,24 +14,13 @@ from typing import (
|
||||
import sqlalchemy
|
||||
from sqlalchemy import text
|
||||
|
||||
from ormar.fields import ManyToManyField # noqa I100
|
||||
from ormar.fields import BaseField, ManyToManyField # noqa I100
|
||||
from ormar.relations import AliasManager
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
|
||||
|
||||
class JoinParameters(NamedTuple):
|
||||
"""
|
||||
Named tuple that holds set of parameters passed during join construction.
|
||||
"""
|
||||
|
||||
prev_model: Type["Model"]
|
||||
previous_alias: str
|
||||
from_table: str
|
||||
model_cls: Type["Model"]
|
||||
|
||||
|
||||
class SqlJoin:
|
||||
def __init__( # noqa: CFQ002
|
||||
self,
|
||||
@ -42,7 +31,12 @@ class SqlJoin:
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
order_columns: Optional[List],
|
||||
sorted_orders: OrderedDict,
|
||||
main_model: Type["Model"],
|
||||
related_models: Any = None,
|
||||
own_alias: str = "",
|
||||
) -> None:
|
||||
self.own_alias = own_alias
|
||||
self.related_models = related_models or []
|
||||
self.used_aliases = used_aliases
|
||||
self.select_from = select_from
|
||||
self.columns = columns
|
||||
@ -50,18 +44,17 @@ class SqlJoin:
|
||||
self.exclude_fields = exclude_fields
|
||||
self.order_columns = order_columns
|
||||
self.sorted_orders = sorted_orders
|
||||
self.main_model = main_model
|
||||
|
||||
@staticmethod
|
||||
def alias_manager(model_cls: Type["Model"]) -> AliasManager:
|
||||
@property
|
||||
def alias_manager(self) -> AliasManager:
|
||||
"""
|
||||
Shortcut for ormars model AliasManager stored on Meta.
|
||||
Shortcut for ormar's model AliasManager stored on Meta.
|
||||
|
||||
:param model_cls: ormar Model class
|
||||
:type model_cls: Type[Model]
|
||||
:return: alias manager from model's Meta
|
||||
:rtype: AliasManager
|
||||
"""
|
||||
return model_cls.Meta.alias_manager
|
||||
return self.main_model.Meta.alias_manager
|
||||
|
||||
@staticmethod
|
||||
def on_clause(
|
||||
@ -86,33 +79,32 @@ class SqlJoin:
|
||||
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
||||
return text(f"{left_part}={right_part}")
|
||||
|
||||
@staticmethod
|
||||
def update_inclusions(
|
||||
model_cls: Type["Model"],
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
nested_name: str,
|
||||
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
|
||||
"""
|
||||
Extract nested fields and exclude_fields if applicable.
|
||||
|
||||
:param model_cls: ormar model class
|
||||
:type model_cls: Type["Model"]
|
||||
:param fields: fields to include
|
||||
:type fields: Optional[Union[Set, Dict]]
|
||||
:param exclude_fields: fields to exclude
|
||||
:type exclude_fields: Optional[Union[Set, Dict]]
|
||||
:param nested_name: name of the nested field
|
||||
:type nested_name: str
|
||||
:return: updated exclude and include fields from nested objects
|
||||
:rtype: Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]
|
||||
"""
|
||||
fields = model_cls.get_included(fields, nested_name)
|
||||
exclude_fields = model_cls.get_excluded(exclude_fields, nested_name)
|
||||
return fields, exclude_fields
|
||||
def process_deeper_join(
|
||||
self, related_name: str, model_cls: Type["Model"], remainder: Any, alias: str,
|
||||
) -> None:
|
||||
sql_join = SqlJoin(
|
||||
used_aliases=self.used_aliases,
|
||||
select_from=self.select_from,
|
||||
columns=self.columns,
|
||||
fields=self.main_model.get_excluded(self.fields, related_name),
|
||||
exclude_fields=self.main_model.get_excluded(
|
||||
self.exclude_fields, related_name
|
||||
),
|
||||
order_columns=self.order_columns,
|
||||
sorted_orders=self.sorted_orders,
|
||||
main_model=model_cls,
|
||||
related_models=remainder,
|
||||
own_alias=alias,
|
||||
)
|
||||
(
|
||||
self.used_aliases,
|
||||
self.select_from,
|
||||
self.columns,
|
||||
self.sorted_orders,
|
||||
) = sql_join.build_join(related_name)
|
||||
|
||||
def build_join( # noqa: CCR001
|
||||
self, item: str, join_parameters: JoinParameters
|
||||
self, related: str
|
||||
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
|
||||
"""
|
||||
Main external access point for building a join.
|
||||
@ -120,59 +112,61 @@ class SqlJoin:
|
||||
handles switching to through models for m2m relations, returns updated lists of
|
||||
used_aliases and sort_orders.
|
||||
|
||||
:param item: string with join definition
|
||||
:type item: str
|
||||
:param join_parameters: parameters from previous/ current join
|
||||
:type join_parameters: JoinParameters
|
||||
:param related: string with join definition
|
||||
:type related: str
|
||||
:return: list of used aliases, select from, list of aliased columns, sort orders
|
||||
:rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict]
|
||||
"""
|
||||
fields = self.fields
|
||||
exclude_fields = self.exclude_fields
|
||||
target_field = self.main_model.Meta.model_fields[related]
|
||||
prev_model = self.main_model
|
||||
# TODO: Finish refactoring here
|
||||
if issubclass(target_field, ManyToManyField):
|
||||
new_part = self.process_m2m_related_name_change(
|
||||
target_field=target_field, related=related
|
||||
)
|
||||
self._replace_many_to_many_order_by_columns(related, new_part)
|
||||
|
||||
for index, part in enumerate(item.split("__")):
|
||||
if issubclass(
|
||||
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
||||
model_cls = target_field.through
|
||||
alias = self.alias_manager.resolve_relation_alias(
|
||||
from_model=prev_model, relation_name=related
|
||||
)
|
||||
if alias not in self.used_aliases:
|
||||
self._process_join(
|
||||
model_cls=model_cls,
|
||||
related=related,
|
||||
alias=alias,
|
||||
target_field=target_field,
|
||||
)
|
||||
related = new_part
|
||||
self.own_alias = alias
|
||||
prev_model = model_cls
|
||||
target_field = target_field.through.Meta.model_fields[related]
|
||||
|
||||
model_cls = target_field.to
|
||||
alias = model_cls.Meta.alias_manager.resolve_relation_alias(
|
||||
from_model=prev_model, relation_name=related
|
||||
)
|
||||
if alias not in self.used_aliases:
|
||||
self._process_join(
|
||||
model_cls=model_cls,
|
||||
prev_model=prev_model,
|
||||
related=related,
|
||||
alias=alias,
|
||||
target_field=target_field,
|
||||
)
|
||||
|
||||
for related_name in self.related_models:
|
||||
remainder = None
|
||||
if (
|
||||
isinstance(self.related_models, dict)
|
||||
and self.related_models[related_name]
|
||||
):
|
||||
_fields = join_parameters.model_cls.Meta.model_fields
|
||||
target_field = _fields[part]
|
||||
if (
|
||||
target_field.self_reference
|
||||
and part == target_field.self_reference_primary
|
||||
):
|
||||
new_part = target_field.default_source_field_name() # type: ignore
|
||||
else:
|
||||
new_part = target_field.default_target_field_name() # type: ignore
|
||||
self._switch_many_to_many_order_columns(part, new_part)
|
||||
if index > 0: # nested joins
|
||||
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||
model_cls=join_parameters.model_cls,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
nested_name=part,
|
||||
)
|
||||
|
||||
join_parameters = self._build_join_parameters(
|
||||
part=part,
|
||||
join_params=join_parameters,
|
||||
is_multi=True,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
)
|
||||
part = new_part
|
||||
|
||||
if index > 0: # nested joins
|
||||
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||
model_cls=join_parameters.model_cls,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
nested_name=part,
|
||||
)
|
||||
join_parameters = self._build_join_parameters(
|
||||
part=part,
|
||||
join_params=join_parameters,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
remainder = self.related_models[related_name]
|
||||
self.process_deeper_join(
|
||||
related_name=related_name,
|
||||
model_cls=model_cls,
|
||||
remainder=remainder,
|
||||
alias=alias,
|
||||
)
|
||||
|
||||
return (
|
||||
@ -182,65 +176,44 @@ class SqlJoin:
|
||||
self.sorted_orders,
|
||||
)
|
||||
|
||||
def _build_join_parameters(
|
||||
self,
|
||||
part: str,
|
||||
join_params: JoinParameters,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
is_multi: bool = False,
|
||||
) -> JoinParameters:
|
||||
@staticmethod
|
||||
def process_m2m_related_name_change(
|
||||
target_field: Type[ManyToManyField], related: str, reverse: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Updates used_aliases to not join multiple times to the same table.
|
||||
Updates join parameters with new values.
|
||||
Extracts relation name to link join through the Through model declared on
|
||||
relation field.
|
||||
|
||||
:param part: part of the join str definition
|
||||
:type part: str
|
||||
:param join_params: parameters from previous/ current join
|
||||
:type join_params: JoinParameters
|
||||
:param fields: fields to include
|
||||
:type fields: Optional[Union[Set, Dict]]
|
||||
:param exclude_fields: fields to exclude
|
||||
:type exclude_fields: Optional[Union[Set, Dict]]
|
||||
:param is_multi: flag if the relation is m2m
|
||||
:type is_multi: bool
|
||||
:return: updated join parameters
|
||||
:rtype: ormar.queryset.join.JoinParameters
|
||||
Changes the same names in order_by queries if they are present.
|
||||
|
||||
:param reverse: flag if it's on_clause lookup - use reverse fields
|
||||
:type reverse: bool
|
||||
:param target_field: relation field
|
||||
:type target_field: Type[ManyToManyField]
|
||||
:param related: name of the relation
|
||||
:type related: str
|
||||
:return: new relation name switched to through model field
|
||||
:rtype: str
|
||||
"""
|
||||
if is_multi:
|
||||
model_cls = join_params.model_cls.Meta.model_fields[part].through
|
||||
else:
|
||||
model_cls = join_params.model_cls.Meta.model_fields[part].to
|
||||
to_table = model_cls.Meta.table.name
|
||||
|
||||
alias = model_cls.Meta.alias_manager.resolve_relation_alias(
|
||||
join_params.prev_model, part
|
||||
is_primary_self_ref = (
|
||||
target_field.self_reference
|
||||
and related == target_field.self_reference_primary
|
||||
)
|
||||
if alias not in self.used_aliases:
|
||||
self._process_join(
|
||||
join_params=join_params,
|
||||
is_multi=is_multi,
|
||||
model_cls=model_cls,
|
||||
part=part,
|
||||
alias=alias,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
)
|
||||
|
||||
previous_alias = alias
|
||||
from_table = to_table
|
||||
prev_model = model_cls
|
||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||
if (is_primary_self_ref and not reverse) or (
|
||||
not is_primary_self_ref and reverse
|
||||
):
|
||||
new_part = target_field.default_source_field_name() # type: ignore
|
||||
else:
|
||||
new_part = target_field.default_target_field_name() # type: ignore
|
||||
return new_part
|
||||
|
||||
def _process_join( # noqa: CFQ002
|
||||
self,
|
||||
join_params: JoinParameters,
|
||||
is_multi: bool,
|
||||
model_cls: Type["Model"],
|
||||
part: str,
|
||||
related: str,
|
||||
alias: str,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
target_field: Type[BaseField],
|
||||
prev_model: Type["Model"] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Resolves to and from column names and table names.
|
||||
@ -255,63 +228,53 @@ class SqlJoin:
|
||||
|
||||
Process order_by causes for non m2m relations.
|
||||
|
||||
:param join_params: parameters from previous/ current join
|
||||
:type join_params: JoinParameters
|
||||
:param is_multi: flag if it's m2m relation
|
||||
:type is_multi: bool
|
||||
:param model_cls:
|
||||
:type model_cls: ormar.models.metaclass.ModelMetaclass
|
||||
:param part: name of the field used in join
|
||||
:type part: str
|
||||
:param related: name of the field used in join
|
||||
:type related: str
|
||||
:param alias: alias of the current join
|
||||
:type alias: str
|
||||
:param fields: fields to include
|
||||
:type fields: Optional[Union[Set, Dict]]
|
||||
:param exclude_fields: fields to exclude
|
||||
:type exclude_fields: Optional[Union[Set, Dict]]
|
||||
"""
|
||||
to_table = model_cls.Meta.table.name
|
||||
to_key, from_key = self.get_to_and_from_keys(
|
||||
join_params, is_multi, model_cls, part
|
||||
)
|
||||
to_key, from_key = self.get_to_and_from_keys(related, target_field)
|
||||
|
||||
prev_model = prev_model or self.main_model
|
||||
|
||||
on_clause = self.on_clause(
|
||||
previous_alias=join_params.previous_alias,
|
||||
previous_alias=self.own_alias,
|
||||
alias=alias,
|
||||
from_clause=f"{join_params.from_table}.{from_key}",
|
||||
from_clause=f"{prev_model.Meta.tablename}.{from_key}",
|
||||
to_clause=f"{to_table}.{to_key}",
|
||||
)
|
||||
target_table = self.alias_manager(model_cls).prefixed_table_name(
|
||||
alias, to_table
|
||||
)
|
||||
target_table = self.alias_manager.prefixed_table_name(alias, to_table)
|
||||
self.select_from = sqlalchemy.sql.outerjoin(
|
||||
self.select_from, target_table, on_clause
|
||||
)
|
||||
|
||||
pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname)
|
||||
if not is_multi:
|
||||
if not issubclass(target_field, ManyToManyField):
|
||||
self.get_order_bys(
|
||||
alias=alias,
|
||||
to_table=to_table,
|
||||
pkname_alias=pkname_alias,
|
||||
part=part,
|
||||
part=related,
|
||||
model_cls=model_cls,
|
||||
)
|
||||
|
||||
self_related_fields = model_cls.own_table_columns(
|
||||
model=model_cls,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
fields=self.fields,
|
||||
exclude_fields=self.exclude_fields,
|
||||
use_alias=True,
|
||||
)
|
||||
self.columns.extend(
|
||||
self.alias_manager(model_cls).prefixed_columns(
|
||||
self.alias_manager.prefixed_columns(
|
||||
alias, model_cls.Meta.table, self_related_fields
|
||||
)
|
||||
)
|
||||
self.used_aliases.append(alias)
|
||||
|
||||
def _switch_many_to_many_order_columns(self, part: str, new_part: str) -> None:
|
||||
def _replace_many_to_many_order_by_columns(self, part: str, new_part: str) -> None:
|
||||
"""
|
||||
Substitutes the name of the relation with actual model name in m2m order bys.
|
||||
|
||||
@ -325,7 +288,7 @@ class SqlJoin:
|
||||
x.split("__") for x in self.order_columns if "__" in x
|
||||
]
|
||||
for condition in split_order_columns:
|
||||
if condition[-2] == part or condition[-2][1:] == part:
|
||||
if self._check_if_condition_apply(condition, part):
|
||||
condition[-2] = condition[-2].replace(part, new_part)
|
||||
self.order_columns = [x for x in self.order_columns if "__" not in x] + [
|
||||
"__".join(x) for x in split_order_columns
|
||||
@ -413,51 +376,34 @@ class SqlJoin:
|
||||
order = text(f"{alias}_{to_table}.{pkname_alias}")
|
||||
self.sorted_orders[f"{alias}.{pkname_alias}"] = order
|
||||
|
||||
@staticmethod
|
||||
def get_to_and_from_keys(
|
||||
join_params: JoinParameters,
|
||||
is_multi: bool,
|
||||
model_cls: Type["Model"],
|
||||
part: str,
|
||||
self, related: str, target_field: Type[BaseField]
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Based on the relation type, name of the relation and previous models and parts
|
||||
stored in JoinParameters it resolves the current to and from keys, which are
|
||||
different for ManyToMany relation, ForeignKey and reverse part of relations.
|
||||
different for ManyToMany relation, ForeignKey and reverse related of relations.
|
||||
|
||||
:param join_params: parameters from previous/ current join
|
||||
:type join_params: JoinParameters
|
||||
:param is_multi: flag if the relation is of m2m type
|
||||
:type is_multi: bool
|
||||
:param model_cls: ormar model class
|
||||
:type model_cls: Type[Model]
|
||||
:param part: name of the current relation join
|
||||
:type part: str
|
||||
:param target_field: relation field
|
||||
:type target_field: Type[ForeignKeyField]
|
||||
:param related: name of the current relation join
|
||||
:type related: str
|
||||
:return: to key and from key
|
||||
:rtype: Tuple[str, str]
|
||||
"""
|
||||
if is_multi:
|
||||
target_field = join_params.model_cls.Meta.model_fields[part]
|
||||
if (
|
||||
target_field.self_reference
|
||||
and part == target_field.self_reference_primary
|
||||
):
|
||||
to_key = target_field.default_target_field_name() # type: ignore
|
||||
else:
|
||||
to_key = target_field.default_source_field_name() # type: ignore
|
||||
from_key = join_params.prev_model.get_column_alias(
|
||||
join_params.prev_model.Meta.pkname
|
||||
if issubclass(target_field, ManyToManyField):
|
||||
to_key = self.process_m2m_related_name_change(
|
||||
target_field=target_field, related=related, reverse=True
|
||||
)
|
||||
from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname)
|
||||
|
||||
elif join_params.prev_model.Meta.model_fields[part].virtual:
|
||||
to_field = join_params.prev_model.Meta.model_fields[part].get_related_name()
|
||||
to_key = model_cls.get_column_alias(to_field)
|
||||
from_key = join_params.prev_model.get_column_alias(
|
||||
join_params.prev_model.Meta.pkname
|
||||
)
|
||||
elif target_field.virtual:
|
||||
to_field = target_field.get_related_name()
|
||||
to_key = target_field.to.get_column_alias(to_field)
|
||||
from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname)
|
||||
|
||||
else:
|
||||
to_key = model_cls.get_column_alias(model_cls.Meta.pkname)
|
||||
from_key = join_params.prev_model.get_column_alias(part)
|
||||
to_key = target_field.to.get_column_alias(target_field.to.Meta.pkname)
|
||||
from_key = self.main_model.get_column_alias(related)
|
||||
|
||||
return to_key, from_key
|
||||
|
||||
Reference in New Issue
Block a user