allow passing a dict and set to fields and exclude_fields, store it as dict
This commit is contained in:
@ -1,5 +1,15 @@
|
||||
from collections import OrderedDict
|
||||
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Tuple, Type
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import text
|
||||
@ -24,8 +34,8 @@ class SqlJoin:
|
||||
used_aliases: List,
|
||||
select_from: sqlalchemy.sql.select,
|
||||
columns: List[sqlalchemy.Column],
|
||||
fields: List,
|
||||
exclude_fields: List,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
order_columns: Optional[List],
|
||||
sorted_orders: OrderedDict,
|
||||
) -> None:
|
||||
@ -49,21 +59,60 @@ class SqlJoin:
|
||||
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
||||
return text(f"{left_part}={right_part}")
|
||||
|
||||
def build_join(
|
||||
@staticmethod
|
||||
def update_inclusions(
|
||||
model_cls: Type["Model"],
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
nested_name: str,
|
||||
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
|
||||
fields = model_cls.get_included(fields, nested_name)
|
||||
exclude_fields = model_cls.get_included(exclude_fields, nested_name)
|
||||
return fields, exclude_fields
|
||||
|
||||
def build_join( # noqa: CCR001
|
||||
self, item: str, join_parameters: JoinParameters
|
||||
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
|
||||
for part in item.split("__"):
|
||||
|
||||
fields = self.fields
|
||||
exclude_fields = self.exclude_fields
|
||||
|
||||
for index, part in enumerate(item.split("__")):
|
||||
if issubclass(
|
||||
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
||||
):
|
||||
_fields = join_parameters.model_cls.Meta.model_fields
|
||||
new_part = _fields[part].to.get_name()
|
||||
self._switch_many_to_many_order_columns(part, new_part)
|
||||
if index > 0: # nested joins
|
||||
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||
model_cls=join_parameters.model_cls,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
nested_name=part,
|
||||
)
|
||||
|
||||
join_parameters = self._build_join_parameters(
|
||||
part, join_parameters, is_multi=True
|
||||
part=part,
|
||||
join_params=join_parameters,
|
||||
is_multi=True,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
)
|
||||
part = new_part
|
||||
join_parameters = self._build_join_parameters(part, join_parameters)
|
||||
if index > 0: # nested joins
|
||||
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||
model_cls=join_parameters.model_cls,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
nested_name=part,
|
||||
)
|
||||
join_parameters = self._build_join_parameters(
|
||||
part=part,
|
||||
join_params=join_parameters,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
)
|
||||
|
||||
return (
|
||||
self.used_aliases,
|
||||
@ -73,7 +122,12 @@ class SqlJoin:
|
||||
)
|
||||
|
||||
def _build_join_parameters(
|
||||
self, part: str, join_params: JoinParameters, is_multi: bool = False
|
||||
self,
|
||||
part: str,
|
||||
join_params: JoinParameters,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
is_multi: bool = False,
|
||||
) -> JoinParameters:
|
||||
if is_multi:
|
||||
model_cls = join_params.model_cls.Meta.model_fields[part].through
|
||||
@ -85,20 +139,30 @@ class SqlJoin:
|
||||
join_params.from_table, to_table
|
||||
)
|
||||
if alias not in self.used_aliases:
|
||||
self._process_join(join_params, is_multi, model_cls, part, alias)
|
||||
self._process_join(
|
||||
join_params=join_params,
|
||||
is_multi=is_multi,
|
||||
model_cls=model_cls,
|
||||
part=part,
|
||||
alias=alias,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
)
|
||||
|
||||
previous_alias = alias
|
||||
from_table = to_table
|
||||
prev_model = model_cls
|
||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||
|
||||
def _process_join(
|
||||
def _process_join( # noqa: CFQ002
|
||||
self,
|
||||
join_params: JoinParameters,
|
||||
is_multi: bool,
|
||||
model_cls: Type["Model"],
|
||||
part: str,
|
||||
alias: str,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
) -> None:
|
||||
to_table = model_cls.Meta.table.name
|
||||
to_key, from_key = self.get_to_and_from_keys(
|
||||
@ -129,7 +193,10 @@ class SqlJoin:
|
||||
)
|
||||
|
||||
self_related_fields = model_cls.own_table_columns(
|
||||
model_cls, self.fields, self.exclude_fields, nested=True,
|
||||
model=model_cls,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
use_alias=True,
|
||||
)
|
||||
self.columns.extend(
|
||||
self.relation_manager(model_cls).prefixed_columns(
|
||||
|
||||
Reference in New Issue
Block a user