diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 788fe5d..198c9b6 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -32,14 +32,48 @@ def add_relation_field_to_fields( return fields +def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]: + sort_criteria = [ + (key, value) for key, value in orders_by.items() if isinstance(value, str) + ] + sort_criteria = sort_criteria[::-1] + for criteria in sort_criteria: + if criteria[1] == "desc": + models.sort(key=lambda x: getattr(x, criteria[0]), reverse=True) + else: + models.sort(key=lambda x: getattr(x, criteria[0])) + return models + + +def set_children_on_model( # noqa: CCR001 + model: "Model", + related: str, + children: Dict, + model_id: int, + models: Dict, + orders_by: Dict, +) -> None: + for key, child_models in children.items(): + if key == model_id: + models_to_set = [models[child] for child in sorted(child_models)] + if models_to_set: + if orders_by and any(isinstance(x, str) for x in orders_by.values()): + models_to_set = sort_models( + models=models_to_set, orders_by=orders_by + ) + for child in models_to_set: + setattr(model, related, child) + + class PrefetchQuery: - def __init__( + def __init__( # noqa: CFQ002 self, model_cls: Type["Model"], fields: Optional[Union[Dict, Set]], exclude_fields: Optional[Union[Dict, Set]], prefetch_related: List, select_related: List, + orders_by: List, ) -> None: self.model = model_cls @@ -51,6 +85,8 @@ class PrefetchQuery: self.already_extracted: Dict = dict() self.models: Dict = {} self.select_dict = translate_list_to_dict(self._select_related) + self.orders_by = orders_by or [] + self.order_dict = translate_list_to_dict(self.orders_by, is_order=True) async def prefetch_related( self, models: Sequence["Model"], rows: List @@ -128,7 +164,9 @@ class PrefetchQuery: return filter_clauses return [] - def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model": + def _populate_nested_related( + self, model: "Model", prefetch_dict: Dict, orders_by: Dict, + ) -> "Model": related_to_extract = model.get_filtered_names_to_extract( prefetch_dict=prefetch_dict @@ -146,25 +184,17 @@ class PrefetchQuery: children = self.already_extracted.get(target_model, {}).get(field_name, {}) models = self.already_extracted.get(target_model, {}).get("pk_models", {}) - self._set_children_on_model( + set_children_on_model( model=model, related=related, children=children, model_id=model_id, models=models, + orders_by=orders_by.get(related, {}), ) return model - @staticmethod - def _set_children_on_model( - model: "Model", related: str, children: Dict, model_id: int, models: Dict - ) -> None: - for key, child_models in children.items(): - if key == model_id: - for child in child_models: - setattr(model, related, models.get(child)) - async def _prefetch_related_models( self, models: Sequence["Model"], rows: List ) -> Sequence["Model"]: @@ -174,6 +204,7 @@ class PrefetchQuery: target_model = self.model fields = self._columns exclude_fields = self._exclude_columns + orders_by = self.order_dict for related in prefetch_dict.keys(): await self._extract_related_models( related=related, @@ -182,11 +213,14 @@ class PrefetchQuery: select_dict=select_dict.get(related, {}), fields=fields, exclude_fields=exclude_fields, + orders_by=orders_by.get(related, {}), ) final_models = [] for model in models: final_models.append( - self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,) + self._populate_nested_related( + model=model, prefetch_dict=prefetch_dict, orders_by=self.order_dict + ) ) return models @@ -198,6 +232,7 @@ class PrefetchQuery: select_dict: Dict, fields: Union[Set[Any], Dict[Any, Any], None], exclude_fields: Union[Set[Any], Dict[Any, Any], None], + orders_by: Dict, ) -> None: fields = target_model.get_included(fields, related) @@ -246,6 +281,7 @@ class PrefetchQuery: ), fields=fields, exclude_fields=exclude_fields, + orders_by=self._get_select_related_if_apply(subrelated, orders_by), ) if not already_loaded: @@ -257,10 +293,13 @@ class PrefetchQuery: fields=fields, exclude_fields=exclude_fields, prefetch_dict=prefetch_dict, + orders_by=orders_by, ) else: self._update_already_loaded_rows( - target_field=target_field, prefetch_dict=prefetch_dict, + target_field=target_field, + prefetch_dict=prefetch_dict, + orders_by=orders_by, ) async def _run_prefetch_query( @@ -310,12 +349,12 @@ class PrefetchQuery: ) def _update_already_loaded_rows( # noqa: CFQ002 - self, target_field: Type["BaseField"], prefetch_dict: Dict, + self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict, ) -> None: target_model = target_field.to for instance in self.models.get(target_model.get_name(), []): self._populate_nested_related( - model=instance, prefetch_dict=prefetch_dict, + model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by ) def _populate_rows( # noqa: CFQ002 @@ -327,6 +366,7 @@ class PrefetchQuery: fields: Union[Set[Any], Dict[Any, Any], None], exclude_fields: Union[Set[Any], Dict[Any, Any], None], prefetch_dict: Dict, + orders_by: Dict, ) -> None: target_model = target_field.to for row in rows: @@ -340,7 +380,7 @@ class PrefetchQuery: ) instance = target_model(**item) instance = self._populate_nested_related( - model=instance, prefetch_dict=prefetch_dict, + model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by ) field_db_name = target_model.get_column_alias(field_name) models = self.already_extracted[target_model.get_name()].setdefault( diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index bc36965..093abf7 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -75,6 +75,7 @@ class QuerySet: exclude_fields=self._exclude_columns, prefetch_related=self._prefetch_related, select_related=self._select_related, + orders_by=self.order_bys, ) return await query.prefetch_related(models=models, rows=rows) # type: ignore diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index ed422dd..bed2e25 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -14,18 +14,28 @@ def check_node_not_dict_or_not_last_node( ) -def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001 +def translate_list_to_dict( # noqa: CCR001 + list_to_trans: Union[List, Set], is_order: bool = False +) -> Dict: new_dict: Dict = dict() for path in list_to_trans: current_level = new_dict parts = path.split("__") + def_val: Any = ... + if is_order: + if parts[0][0] == "-": + def_val = "desc" + parts[0] = parts[0][1:] + else: + def_val = "asc" + for part in parts: if check_node_not_dict_or_not_last_node( part=part, parts=parts, current_level=current_level ): current_level[part] = dict() elif part not in current_level: - current_level[part] = ... + current_level[part] = def_val current_level = current_level[part] return new_dict diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 4d3fba9..bfc09d7 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -228,6 +228,7 @@ async def test_prefetch_related_with_select_related(): album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related( ['cover_pictures', 'shops__division']).get() + assert len(album.tracks) == 0 assert len(album.cover_pictures) == 2 assert album.shops[0].division.name == 'Div 1' @@ -240,14 +241,15 @@ async def test_prefetch_related_with_select_related(): album = await Album.objects.select_related('tracks__tonation__rand_set').filter( name='Malibu').prefetch_related( - ['cover_pictures', 'shops__division']).get() + ['cover_pictures', 'shops__division']).order_by( + ['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get() assert len(album.tracks) == 3 assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 assert len(album.cover_pictures) == 2 - assert album.cover_pictures[0].artist == 'Artist 1' + assert album.cover_pictures[0].artist == 'Artist 2' assert len(album.shops) == 2 - assert album.shops[0].name == 'Shop 1' + assert album.shops[0].name == 'Shop 2' assert album.shops[0].division.name == 'Div 1' track = await Track.objects.select_related('album').prefetch_related( diff --git a/tests/test_queryset_utils.py b/tests/test_queryset_utils.py index bac0a26..97695b9 100644 --- a/tests/test_queryset_utils.py +++ b/tests/test_queryset_utils.py @@ -1,5 +1,11 @@ +import databases +import sqlalchemy + +import ormar from ormar.models.excludable import Excludable +from ormar.queryset.prefetch_query import sort_models from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update +from tests.settings import DATABASE_URL def test_empty_excludable(): @@ -96,3 +102,43 @@ def test_updating_dict_inc_set_with_dict_inc_set(): "cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis}, "uu": Ellipsis, } + + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class SortModel(ormar.Model): + class Meta: + tablename = "sorts" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + sort_order: int = ormar.Integer() + + +def test_sorting_models(): + models = [ + SortModel(id=1, name='Alice', sort_order=0), + SortModel(id=2, name='Al', sort_order=1), + SortModel(id=3, name='Zake', sort_order=1), + SortModel(id=4, name='Will', sort_order=0), + SortModel(id=5, name='Al', sort_order=2), + SortModel(id=6, name='Alice', sort_order=2) + ] + orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'} + models = sort_models(models, orders_by) + assert models[5].name == 'Zake' + assert models[0].name == 'Al' + assert models[1].name == 'Al' + assert [model.id for model in models] == [5, 2, 6, 1, 4, 3] + + orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'} + models = sort_models(models, orders_by) + assert [model.id for model in models] == [2, 5, 1, 6, 4, 3] + + orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None} + models = sort_models(models, orders_by) + assert [model.id for model in models] == [1, 4, 2, 3, 5, 6]