refactors in join to register complex aliases on duplicate, to do is doing the same in filter clauses

This commit is contained in:
collerek
2021-01-17 12:29:21 +01:00
parent 28cc847b57
commit d6e2c85b79
9 changed files with 137 additions and 74 deletions

View File

@ -24,18 +24,20 @@ if TYPE_CHECKING: # pragma no cover
class SqlJoin:
def __init__( # noqa: CFQ002
self,
used_aliases: List,
select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List],
sorted_orders: OrderedDict,
main_model: Type["Model"],
relation_name: str,
related_models: Any = None,
own_alias: str = "",
self,
used_aliases: List,
select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List],
sorted_orders: OrderedDict,
main_model: Type["Model"],
relation_name: str,
relation_str: str,
related_models: Any = None,
own_alias: str = "",
source_model: Type["Model"] = None,
) -> None:
self.relation_name = relation_name
self.related_models = related_models or []
@ -53,6 +55,9 @@ class SqlJoin:
self._next_model: Optional[Type["Model"]] = None
self._next_alias: Optional[str] = None
self.relation_str = relation_str
self.source_model = source_model
@property
def next_model(self) -> Type["Model"]:
if not self._next_model: # pragma: nocover
@ -85,7 +90,8 @@ class SqlJoin:
"""
return self.main_model.Meta.alias_manager
def on_clause(self, previous_alias: str, from_clause: str, to_clause: str,) -> text:
def on_clause(self, previous_alias: str, from_clause: str,
to_clause: str, ) -> text:
"""
Receives aliases and names of both ends of the join and combines them
into one text clause used in joins.
@ -117,11 +123,7 @@ class SqlJoin:
self.process_m2m_through_table()
self.next_model = self.target_field.to
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self._forward_join()
self._process_following_joins()
@ -132,6 +134,23 @@ class SqlJoin:
self.sorted_orders,
)
def _forward_join(self):
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
else:
if '__' in self.relation_str:
relation_key = f'{self.source_model.get_name()}_{self.relation_str}'
if relation_key not in self.alias_manager:
print(f'registering {relation_key}')
self.next_alias = self.alias_manager.add_alias(
alias_key=relation_key)
else:
self.next_alias = self.alias_manager[relation_key]
self._process_join()
def _process_following_joins(self) -> None:
"""
Iterates through nested models to create subsequent joins.
@ -139,8 +158,8 @@ class SqlJoin:
for related_name in self.related_models:
remainder = None
if (
isinstance(self.related_models, dict)
and self.related_models[related_name]
isinstance(self.related_models, dict)
and self.related_models[related_name]
):
remainder = self.related_models[related_name]
self._process_deeper_join(related_name=related_name, remainder=remainder)
@ -175,7 +194,9 @@ class SqlJoin:
main_model=self.next_model,
relation_name=related_name,
related_models=remainder,
relation_str='__'.join([self.relation_str, related_name]),
own_alias=self.next_alias,
source_model=self.source_model or self.main_model
)
(
self.used_aliases,
@ -203,11 +224,8 @@ class SqlJoin:
self._replace_many_to_many_order_by_columns(self.relation_name, new_part)
self.next_model = self.target_field.through
self.next_alias = self.alias_manager.resolve_relation_alias(
from_model=self.target_field.owner, relation_name=self.relation_name
)
if self.next_alias not in self.used_aliases:
self._process_join()
self._forward_join()
self.relation_name = new_part
self.own_alias = self.next_alias
self.target_field = self.next_model.Meta.model_fields[self.relation_name]
@ -226,18 +244,18 @@ class SqlJoin:
"""
target_field = self.target_field
is_primary_self_ref = (
target_field.self_reference
and self.relation_name == target_field.self_reference_primary
target_field.self_reference
and self.relation_name == target_field.self_reference_primary
)
if (is_primary_self_ref and not reverse) or (
not is_primary_self_ref and reverse
not is_primary_self_ref and reverse
):
new_part = target_field.default_source_field_name() # type: ignore
else:
new_part = target_field.default_target_field_name() # type: ignore
return new_part
def _process_join(self,) -> None: # noqa: CFQ002
def _process_join(self, ) -> None: # noqa: CFQ002
"""
Resolves to and from column names and table names.
@ -317,10 +335,10 @@ class SqlJoin:
:rtype: bool
"""
return len(condition) >= 2 and (
condition[-2] == part or condition[-2][1:] == part
condition[-2] == part or condition[-2][1:] == part
)
def set_aliased_order_by(self, condition: List[str], to_table: str,) -> None:
def set_aliased_order_by(self, condition: List[str], to_table: str, ) -> None:
"""
Substitute hyphens ('-') with descending order.
Construct actual sqlalchemy text clause using aliased table and column name.
@ -335,7 +353,7 @@ class SqlJoin:
order = text(f"{self.next_alias}_{to_table}.{column_alias} {direction}")
self.sorted_orders["__".join(condition)] = order
def get_order_bys(self, to_table: str, pkname_alias: str,) -> None: # noqa: CCR001
def get_order_bys(self, to_table: str, pkname_alias: str, ) -> None: # noqa: CCR001
"""
Triggers construction of order bys if they are given.
Otherwise by default each table is sorted by a primary key column asc.