add order_by support for prefetch_related
This commit is contained in:
@ -32,14 +32,48 @@ def add_relation_field_to_fields(
|
||||
return fields
|
||||
|
||||
|
||||
def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]:
|
||||
sort_criteria = [
|
||||
(key, value) for key, value in orders_by.items() if isinstance(value, str)
|
||||
]
|
||||
sort_criteria = sort_criteria[::-1]
|
||||
for criteria in sort_criteria:
|
||||
if criteria[1] == "desc":
|
||||
models.sort(key=lambda x: getattr(x, criteria[0]), reverse=True)
|
||||
else:
|
||||
models.sort(key=lambda x: getattr(x, criteria[0]))
|
||||
return models
|
||||
|
||||
|
||||
def set_children_on_model( # noqa: CCR001
|
||||
model: "Model",
|
||||
related: str,
|
||||
children: Dict,
|
||||
model_id: int,
|
||||
models: Dict,
|
||||
orders_by: Dict,
|
||||
) -> None:
|
||||
for key, child_models in children.items():
|
||||
if key == model_id:
|
||||
models_to_set = [models[child] for child in sorted(child_models)]
|
||||
if models_to_set:
|
||||
if orders_by and any(isinstance(x, str) for x in orders_by.values()):
|
||||
models_to_set = sort_models(
|
||||
models=models_to_set, orders_by=orders_by
|
||||
)
|
||||
for child in models_to_set:
|
||||
setattr(model, related, child)
|
||||
|
||||
|
||||
class PrefetchQuery:
|
||||
def __init__(
|
||||
def __init__( # noqa: CFQ002
|
||||
self,
|
||||
model_cls: Type["Model"],
|
||||
fields: Optional[Union[Dict, Set]],
|
||||
exclude_fields: Optional[Union[Dict, Set]],
|
||||
prefetch_related: List,
|
||||
select_related: List,
|
||||
orders_by: List,
|
||||
) -> None:
|
||||
|
||||
self.model = model_cls
|
||||
@ -51,6 +85,8 @@ class PrefetchQuery:
|
||||
self.already_extracted: Dict = dict()
|
||||
self.models: Dict = {}
|
||||
self.select_dict = translate_list_to_dict(self._select_related)
|
||||
self.orders_by = orders_by or []
|
||||
self.order_dict = translate_list_to_dict(self.orders_by, is_order=True)
|
||||
|
||||
async def prefetch_related(
|
||||
self, models: Sequence["Model"], rows: List
|
||||
@ -128,7 +164,9 @@ class PrefetchQuery:
|
||||
return filter_clauses
|
||||
return []
|
||||
|
||||
def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model":
|
||||
def _populate_nested_related(
|
||||
self, model: "Model", prefetch_dict: Dict, orders_by: Dict,
|
||||
) -> "Model":
|
||||
|
||||
related_to_extract = model.get_filtered_names_to_extract(
|
||||
prefetch_dict=prefetch_dict
|
||||
@ -146,25 +184,17 @@ class PrefetchQuery:
|
||||
|
||||
children = self.already_extracted.get(target_model, {}).get(field_name, {})
|
||||
models = self.already_extracted.get(target_model, {}).get("pk_models", {})
|
||||
self._set_children_on_model(
|
||||
set_children_on_model(
|
||||
model=model,
|
||||
related=related,
|
||||
children=children,
|
||||
model_id=model_id,
|
||||
models=models,
|
||||
orders_by=orders_by.get(related, {}),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _set_children_on_model(
|
||||
model: "Model", related: str, children: Dict, model_id: int, models: Dict
|
||||
) -> None:
|
||||
for key, child_models in children.items():
|
||||
if key == model_id:
|
||||
for child in child_models:
|
||||
setattr(model, related, models.get(child))
|
||||
|
||||
async def _prefetch_related_models(
|
||||
self, models: Sequence["Model"], rows: List
|
||||
) -> Sequence["Model"]:
|
||||
@ -174,6 +204,7 @@ class PrefetchQuery:
|
||||
target_model = self.model
|
||||
fields = self._columns
|
||||
exclude_fields = self._exclude_columns
|
||||
orders_by = self.order_dict
|
||||
for related in prefetch_dict.keys():
|
||||
await self._extract_related_models(
|
||||
related=related,
|
||||
@ -182,11 +213,14 @@ class PrefetchQuery:
|
||||
select_dict=select_dict.get(related, {}),
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
orders_by=orders_by.get(related, {}),
|
||||
)
|
||||
final_models = []
|
||||
for model in models:
|
||||
final_models.append(
|
||||
self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,)
|
||||
self._populate_nested_related(
|
||||
model=model, prefetch_dict=prefetch_dict, orders_by=self.order_dict
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
@ -198,6 +232,7 @@ class PrefetchQuery:
|
||||
select_dict: Dict,
|
||||
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
||||
orders_by: Dict,
|
||||
) -> None:
|
||||
|
||||
fields = target_model.get_included(fields, related)
|
||||
@ -246,6 +281,7 @@ class PrefetchQuery:
|
||||
),
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
orders_by=self._get_select_related_if_apply(subrelated, orders_by),
|
||||
)
|
||||
|
||||
if not already_loaded:
|
||||
@ -257,10 +293,13 @@ class PrefetchQuery:
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
prefetch_dict=prefetch_dict,
|
||||
orders_by=orders_by,
|
||||
)
|
||||
else:
|
||||
self._update_already_loaded_rows(
|
||||
target_field=target_field, prefetch_dict=prefetch_dict,
|
||||
target_field=target_field,
|
||||
prefetch_dict=prefetch_dict,
|
||||
orders_by=orders_by,
|
||||
)
|
||||
|
||||
async def _run_prefetch_query(
|
||||
@ -310,12 +349,12 @@ class PrefetchQuery:
|
||||
)
|
||||
|
||||
def _update_already_loaded_rows( # noqa: CFQ002
|
||||
self, target_field: Type["BaseField"], prefetch_dict: Dict,
|
||||
self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict,
|
||||
) -> None:
|
||||
target_model = target_field.to
|
||||
for instance in self.models.get(target_model.get_name(), []):
|
||||
self._populate_nested_related(
|
||||
model=instance, prefetch_dict=prefetch_dict,
|
||||
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
|
||||
)
|
||||
|
||||
def _populate_rows( # noqa: CFQ002
|
||||
@ -327,6 +366,7 @@ class PrefetchQuery:
|
||||
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
||||
prefetch_dict: Dict,
|
||||
orders_by: Dict,
|
||||
) -> None:
|
||||
target_model = target_field.to
|
||||
for row in rows:
|
||||
@ -340,7 +380,7 @@ class PrefetchQuery:
|
||||
)
|
||||
instance = target_model(**item)
|
||||
instance = self._populate_nested_related(
|
||||
model=instance, prefetch_dict=prefetch_dict,
|
||||
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
|
||||
)
|
||||
field_db_name = target_model.get_column_alias(field_name)
|
||||
models = self.already_extracted[target_model.get_name()].setdefault(
|
||||
|
||||
@ -75,6 +75,7 @@ class QuerySet:
|
||||
exclude_fields=self._exclude_columns,
|
||||
prefetch_related=self._prefetch_related,
|
||||
select_related=self._select_related,
|
||||
orders_by=self.order_bys,
|
||||
)
|
||||
return await query.prefetch_related(models=models, rows=rows) # type: ignore
|
||||
|
||||
|
||||
@ -14,18 +14,28 @@ def check_node_not_dict_or_not_last_node(
|
||||
)
|
||||
|
||||
|
||||
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001
|
||||
def translate_list_to_dict( # noqa: CCR001
|
||||
list_to_trans: Union[List, Set], is_order: bool = False
|
||||
) -> Dict:
|
||||
new_dict: Dict = dict()
|
||||
for path in list_to_trans:
|
||||
current_level = new_dict
|
||||
parts = path.split("__")
|
||||
def_val: Any = ...
|
||||
if is_order:
|
||||
if parts[0][0] == "-":
|
||||
def_val = "desc"
|
||||
parts[0] = parts[0][1:]
|
||||
else:
|
||||
def_val = "asc"
|
||||
|
||||
for part in parts:
|
||||
if check_node_not_dict_or_not_last_node(
|
||||
part=part, parts=parts, current_level=current_level
|
||||
):
|
||||
current_level[part] = dict()
|
||||
elif part not in current_level:
|
||||
current_level[part] = ...
|
||||
current_level[part] = def_val
|
||||
current_level = current_level[part]
|
||||
return new_dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user