add order_by support for prefetch_related
This commit is contained in:
@ -32,14 +32,48 @@ def add_relation_field_to_fields(
|
||||
return fields
|
||||
|
||||
|
||||
def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]:
|
||||
sort_criteria = [
|
||||
(key, value) for key, value in orders_by.items() if isinstance(value, str)
|
||||
]
|
||||
sort_criteria = sort_criteria[::-1]
|
||||
for criteria in sort_criteria:
|
||||
if criteria[1] == "desc":
|
||||
models.sort(key=lambda x: getattr(x, criteria[0]), reverse=True)
|
||||
else:
|
||||
models.sort(key=lambda x: getattr(x, criteria[0]))
|
||||
return models
|
||||
|
||||
|
||||
def set_children_on_model( # noqa: CCR001
|
||||
model: "Model",
|
||||
related: str,
|
||||
children: Dict,
|
||||
model_id: int,
|
||||
models: Dict,
|
||||
orders_by: Dict,
|
||||
) -> None:
|
||||
for key, child_models in children.items():
|
||||
if key == model_id:
|
||||
models_to_set = [models[child] for child in sorted(child_models)]
|
||||
if models_to_set:
|
||||
if orders_by and any(isinstance(x, str) for x in orders_by.values()):
|
||||
models_to_set = sort_models(
|
||||
models=models_to_set, orders_by=orders_by
|
||||
)
|
||||
for child in models_to_set:
|
||||
setattr(model, related, child)
|
||||
|
||||
|
||||
class PrefetchQuery:
|
||||
def __init__(
|
||||
def __init__( # noqa: CFQ002
|
||||
self,
|
||||
model_cls: Type["Model"],
|
||||
fields: Optional[Union[Dict, Set]],
|
||||
exclude_fields: Optional[Union[Dict, Set]],
|
||||
prefetch_related: List,
|
||||
select_related: List,
|
||||
orders_by: List,
|
||||
) -> None:
|
||||
|
||||
self.model = model_cls
|
||||
@ -51,6 +85,8 @@ class PrefetchQuery:
|
||||
self.already_extracted: Dict = dict()
|
||||
self.models: Dict = {}
|
||||
self.select_dict = translate_list_to_dict(self._select_related)
|
||||
self.orders_by = orders_by or []
|
||||
self.order_dict = translate_list_to_dict(self.orders_by, is_order=True)
|
||||
|
||||
async def prefetch_related(
|
||||
self, models: Sequence["Model"], rows: List
|
||||
@ -128,7 +164,9 @@ class PrefetchQuery:
|
||||
return filter_clauses
|
||||
return []
|
||||
|
||||
def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model":
|
||||
def _populate_nested_related(
|
||||
self, model: "Model", prefetch_dict: Dict, orders_by: Dict,
|
||||
) -> "Model":
|
||||
|
||||
related_to_extract = model.get_filtered_names_to_extract(
|
||||
prefetch_dict=prefetch_dict
|
||||
@ -146,25 +184,17 @@ class PrefetchQuery:
|
||||
|
||||
children = self.already_extracted.get(target_model, {}).get(field_name, {})
|
||||
models = self.already_extracted.get(target_model, {}).get("pk_models", {})
|
||||
self._set_children_on_model(
|
||||
set_children_on_model(
|
||||
model=model,
|
||||
related=related,
|
||||
children=children,
|
||||
model_id=model_id,
|
||||
models=models,
|
||||
orders_by=orders_by.get(related, {}),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _set_children_on_model(
|
||||
model: "Model", related: str, children: Dict, model_id: int, models: Dict
|
||||
) -> None:
|
||||
for key, child_models in children.items():
|
||||
if key == model_id:
|
||||
for child in child_models:
|
||||
setattr(model, related, models.get(child))
|
||||
|
||||
async def _prefetch_related_models(
|
||||
self, models: Sequence["Model"], rows: List
|
||||
) -> Sequence["Model"]:
|
||||
@ -174,6 +204,7 @@ class PrefetchQuery:
|
||||
target_model = self.model
|
||||
fields = self._columns
|
||||
exclude_fields = self._exclude_columns
|
||||
orders_by = self.order_dict
|
||||
for related in prefetch_dict.keys():
|
||||
await self._extract_related_models(
|
||||
related=related,
|
||||
@ -182,11 +213,14 @@ class PrefetchQuery:
|
||||
select_dict=select_dict.get(related, {}),
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
orders_by=orders_by.get(related, {}),
|
||||
)
|
||||
final_models = []
|
||||
for model in models:
|
||||
final_models.append(
|
||||
self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,)
|
||||
self._populate_nested_related(
|
||||
model=model, prefetch_dict=prefetch_dict, orders_by=self.order_dict
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
@ -198,6 +232,7 @@ class PrefetchQuery:
|
||||
select_dict: Dict,
|
||||
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
||||
orders_by: Dict,
|
||||
) -> None:
|
||||
|
||||
fields = target_model.get_included(fields, related)
|
||||
@ -246,6 +281,7 @@ class PrefetchQuery:
|
||||
),
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
orders_by=self._get_select_related_if_apply(subrelated, orders_by),
|
||||
)
|
||||
|
||||
if not already_loaded:
|
||||
@ -257,10 +293,13 @@ class PrefetchQuery:
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
prefetch_dict=prefetch_dict,
|
||||
orders_by=orders_by,
|
||||
)
|
||||
else:
|
||||
self._update_already_loaded_rows(
|
||||
target_field=target_field, prefetch_dict=prefetch_dict,
|
||||
target_field=target_field,
|
||||
prefetch_dict=prefetch_dict,
|
||||
orders_by=orders_by,
|
||||
)
|
||||
|
||||
async def _run_prefetch_query(
|
||||
@ -310,12 +349,12 @@ class PrefetchQuery:
|
||||
)
|
||||
|
||||
def _update_already_loaded_rows( # noqa: CFQ002
|
||||
self, target_field: Type["BaseField"], prefetch_dict: Dict,
|
||||
self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict,
|
||||
) -> None:
|
||||
target_model = target_field.to
|
||||
for instance in self.models.get(target_model.get_name(), []):
|
||||
self._populate_nested_related(
|
||||
model=instance, prefetch_dict=prefetch_dict,
|
||||
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
|
||||
)
|
||||
|
||||
def _populate_rows( # noqa: CFQ002
|
||||
@ -327,6 +366,7 @@ class PrefetchQuery:
|
||||
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
||||
prefetch_dict: Dict,
|
||||
orders_by: Dict,
|
||||
) -> None:
|
||||
target_model = target_field.to
|
||||
for row in rows:
|
||||
@ -340,7 +380,7 @@ class PrefetchQuery:
|
||||
)
|
||||
instance = target_model(**item)
|
||||
instance = self._populate_nested_related(
|
||||
model=instance, prefetch_dict=prefetch_dict,
|
||||
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
|
||||
)
|
||||
field_db_name = target_model.get_column_alias(field_name)
|
||||
models = self.already_extracted[target_model.get_name()].setdefault(
|
||||
|
||||
@ -75,6 +75,7 @@ class QuerySet:
|
||||
exclude_fields=self._exclude_columns,
|
||||
prefetch_related=self._prefetch_related,
|
||||
select_related=self._select_related,
|
||||
orders_by=self.order_bys,
|
||||
)
|
||||
return await query.prefetch_related(models=models, rows=rows) # type: ignore
|
||||
|
||||
|
||||
@ -14,18 +14,28 @@ def check_node_not_dict_or_not_last_node(
|
||||
)
|
||||
|
||||
|
||||
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001
|
||||
def translate_list_to_dict( # noqa: CCR001
|
||||
list_to_trans: Union[List, Set], is_order: bool = False
|
||||
) -> Dict:
|
||||
new_dict: Dict = dict()
|
||||
for path in list_to_trans:
|
||||
current_level = new_dict
|
||||
parts = path.split("__")
|
||||
def_val: Any = ...
|
||||
if is_order:
|
||||
if parts[0][0] == "-":
|
||||
def_val = "desc"
|
||||
parts[0] = parts[0][1:]
|
||||
else:
|
||||
def_val = "asc"
|
||||
|
||||
for part in parts:
|
||||
if check_node_not_dict_or_not_last_node(
|
||||
part=part, parts=parts, current_level=current_level
|
||||
):
|
||||
current_level[part] = dict()
|
||||
elif part not in current_level:
|
||||
current_level[part] = ...
|
||||
current_level[part] = def_val
|
||||
current_level = current_level[part]
|
||||
return new_dict
|
||||
|
||||
|
||||
@ -228,6 +228,7 @@ async def test_prefetch_related_with_select_related():
|
||||
|
||||
album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related(
|
||||
['cover_pictures', 'shops__division']).get()
|
||||
|
||||
assert len(album.tracks) == 0
|
||||
assert len(album.cover_pictures) == 2
|
||||
assert album.shops[0].division.name == 'Div 1'
|
||||
@ -240,14 +241,15 @@ async def test_prefetch_related_with_select_related():
|
||||
|
||||
album = await Album.objects.select_related('tracks__tonation__rand_set').filter(
|
||||
name='Malibu').prefetch_related(
|
||||
['cover_pictures', 'shops__division']).get()
|
||||
['cover_pictures', 'shops__division']).order_by(
|
||||
['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get()
|
||||
assert len(album.tracks) == 3
|
||||
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
|
||||
assert len(album.cover_pictures) == 2
|
||||
assert album.cover_pictures[0].artist == 'Artist 1'
|
||||
assert album.cover_pictures[0].artist == 'Artist 2'
|
||||
|
||||
assert len(album.shops) == 2
|
||||
assert album.shops[0].name == 'Shop 1'
|
||||
assert album.shops[0].name == 'Shop 2'
|
||||
assert album.shops[0].division.name == 'Div 1'
|
||||
|
||||
track = await Track.objects.select_related('album').prefetch_related(
|
||||
|
||||
@ -1,5 +1,11 @@
|
||||
import databases
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar.models.excludable import Excludable
|
||||
from ormar.queryset.prefetch_query import sort_models
|
||||
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
|
||||
def test_empty_excludable():
|
||||
@ -96,3 +102,43 @@ def test_updating_dict_inc_set_with_dict_inc_set():
|
||||
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
|
||||
"uu": Ellipsis,
|
||||
}
|
||||
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class SortModel(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "sorts"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100)
|
||||
sort_order: int = ormar.Integer()
|
||||
|
||||
|
||||
def test_sorting_models():
|
||||
models = [
|
||||
SortModel(id=1, name='Alice', sort_order=0),
|
||||
SortModel(id=2, name='Al', sort_order=1),
|
||||
SortModel(id=3, name='Zake', sort_order=1),
|
||||
SortModel(id=4, name='Will', sort_order=0),
|
||||
SortModel(id=5, name='Al', sort_order=2),
|
||||
SortModel(id=6, name='Alice', sort_order=2)
|
||||
]
|
||||
orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'}
|
||||
models = sort_models(models, orders_by)
|
||||
assert models[5].name == 'Zake'
|
||||
assert models[0].name == 'Al'
|
||||
assert models[1].name == 'Al'
|
||||
assert [model.id for model in models] == [5, 2, 6, 1, 4, 3]
|
||||
|
||||
orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'}
|
||||
models = sort_models(models, orders_by)
|
||||
assert [model.id for model in models] == [2, 5, 1, 6, 4, 3]
|
||||
|
||||
orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None}
|
||||
models = sort_models(models, orders_by)
|
||||
assert [model.id for model in models] == [1, 4, 2, 3, 5, 6]
|
||||
|
||||
Reference in New Issue
Block a user