add order_by support for prefetch_related
This commit is contained in:
@ -32,14 +32,48 @@ def add_relation_field_to_fields(
|
|||||||
return fields
|
return fields
|
||||||
|
|
||||||
|
|
||||||
|
def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]:
|
||||||
|
sort_criteria = [
|
||||||
|
(key, value) for key, value in orders_by.items() if isinstance(value, str)
|
||||||
|
]
|
||||||
|
sort_criteria = sort_criteria[::-1]
|
||||||
|
for criteria in sort_criteria:
|
||||||
|
if criteria[1] == "desc":
|
||||||
|
models.sort(key=lambda x: getattr(x, criteria[0]), reverse=True)
|
||||||
|
else:
|
||||||
|
models.sort(key=lambda x: getattr(x, criteria[0]))
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
def set_children_on_model( # noqa: CCR001
|
||||||
|
model: "Model",
|
||||||
|
related: str,
|
||||||
|
children: Dict,
|
||||||
|
model_id: int,
|
||||||
|
models: Dict,
|
||||||
|
orders_by: Dict,
|
||||||
|
) -> None:
|
||||||
|
for key, child_models in children.items():
|
||||||
|
if key == model_id:
|
||||||
|
models_to_set = [models[child] for child in sorted(child_models)]
|
||||||
|
if models_to_set:
|
||||||
|
if orders_by and any(isinstance(x, str) for x in orders_by.values()):
|
||||||
|
models_to_set = sort_models(
|
||||||
|
models=models_to_set, orders_by=orders_by
|
||||||
|
)
|
||||||
|
for child in models_to_set:
|
||||||
|
setattr(model, related, child)
|
||||||
|
|
||||||
|
|
||||||
class PrefetchQuery:
|
class PrefetchQuery:
|
||||||
def __init__(
|
def __init__( # noqa: CFQ002
|
||||||
self,
|
self,
|
||||||
model_cls: Type["Model"],
|
model_cls: Type["Model"],
|
||||||
fields: Optional[Union[Dict, Set]],
|
fields: Optional[Union[Dict, Set]],
|
||||||
exclude_fields: Optional[Union[Dict, Set]],
|
exclude_fields: Optional[Union[Dict, Set]],
|
||||||
prefetch_related: List,
|
prefetch_related: List,
|
||||||
select_related: List,
|
select_related: List,
|
||||||
|
orders_by: List,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.model = model_cls
|
self.model = model_cls
|
||||||
@ -51,6 +85,8 @@ class PrefetchQuery:
|
|||||||
self.already_extracted: Dict = dict()
|
self.already_extracted: Dict = dict()
|
||||||
self.models: Dict = {}
|
self.models: Dict = {}
|
||||||
self.select_dict = translate_list_to_dict(self._select_related)
|
self.select_dict = translate_list_to_dict(self._select_related)
|
||||||
|
self.orders_by = orders_by or []
|
||||||
|
self.order_dict = translate_list_to_dict(self.orders_by, is_order=True)
|
||||||
|
|
||||||
async def prefetch_related(
|
async def prefetch_related(
|
||||||
self, models: Sequence["Model"], rows: List
|
self, models: Sequence["Model"], rows: List
|
||||||
@ -128,7 +164,9 @@ class PrefetchQuery:
|
|||||||
return filter_clauses
|
return filter_clauses
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model":
|
def _populate_nested_related(
|
||||||
|
self, model: "Model", prefetch_dict: Dict, orders_by: Dict,
|
||||||
|
) -> "Model":
|
||||||
|
|
||||||
related_to_extract = model.get_filtered_names_to_extract(
|
related_to_extract = model.get_filtered_names_to_extract(
|
||||||
prefetch_dict=prefetch_dict
|
prefetch_dict=prefetch_dict
|
||||||
@ -146,25 +184,17 @@ class PrefetchQuery:
|
|||||||
|
|
||||||
children = self.already_extracted.get(target_model, {}).get(field_name, {})
|
children = self.already_extracted.get(target_model, {}).get(field_name, {})
|
||||||
models = self.already_extracted.get(target_model, {}).get("pk_models", {})
|
models = self.already_extracted.get(target_model, {}).get("pk_models", {})
|
||||||
self._set_children_on_model(
|
set_children_on_model(
|
||||||
model=model,
|
model=model,
|
||||||
related=related,
|
related=related,
|
||||||
children=children,
|
children=children,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
models=models,
|
models=models,
|
||||||
|
orders_by=orders_by.get(related, {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _set_children_on_model(
|
|
||||||
model: "Model", related: str, children: Dict, model_id: int, models: Dict
|
|
||||||
) -> None:
|
|
||||||
for key, child_models in children.items():
|
|
||||||
if key == model_id:
|
|
||||||
for child in child_models:
|
|
||||||
setattr(model, related, models.get(child))
|
|
||||||
|
|
||||||
async def _prefetch_related_models(
|
async def _prefetch_related_models(
|
||||||
self, models: Sequence["Model"], rows: List
|
self, models: Sequence["Model"], rows: List
|
||||||
) -> Sequence["Model"]:
|
) -> Sequence["Model"]:
|
||||||
@ -174,6 +204,7 @@ class PrefetchQuery:
|
|||||||
target_model = self.model
|
target_model = self.model
|
||||||
fields = self._columns
|
fields = self._columns
|
||||||
exclude_fields = self._exclude_columns
|
exclude_fields = self._exclude_columns
|
||||||
|
orders_by = self.order_dict
|
||||||
for related in prefetch_dict.keys():
|
for related in prefetch_dict.keys():
|
||||||
await self._extract_related_models(
|
await self._extract_related_models(
|
||||||
related=related,
|
related=related,
|
||||||
@ -182,11 +213,14 @@ class PrefetchQuery:
|
|||||||
select_dict=select_dict.get(related, {}),
|
select_dict=select_dict.get(related, {}),
|
||||||
fields=fields,
|
fields=fields,
|
||||||
exclude_fields=exclude_fields,
|
exclude_fields=exclude_fields,
|
||||||
|
orders_by=orders_by.get(related, {}),
|
||||||
)
|
)
|
||||||
final_models = []
|
final_models = []
|
||||||
for model in models:
|
for model in models:
|
||||||
final_models.append(
|
final_models.append(
|
||||||
self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,)
|
self._populate_nested_related(
|
||||||
|
model=model, prefetch_dict=prefetch_dict, orders_by=self.order_dict
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return models
|
return models
|
||||||
|
|
||||||
@ -198,6 +232,7 @@ class PrefetchQuery:
|
|||||||
select_dict: Dict,
|
select_dict: Dict,
|
||||||
fields: Union[Set[Any], Dict[Any, Any], None],
|
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
|
orders_by: Dict,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
fields = target_model.get_included(fields, related)
|
fields = target_model.get_included(fields, related)
|
||||||
@ -246,6 +281,7 @@ class PrefetchQuery:
|
|||||||
),
|
),
|
||||||
fields=fields,
|
fields=fields,
|
||||||
exclude_fields=exclude_fields,
|
exclude_fields=exclude_fields,
|
||||||
|
orders_by=self._get_select_related_if_apply(subrelated, orders_by),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not already_loaded:
|
if not already_loaded:
|
||||||
@ -257,10 +293,13 @@ class PrefetchQuery:
|
|||||||
fields=fields,
|
fields=fields,
|
||||||
exclude_fields=exclude_fields,
|
exclude_fields=exclude_fields,
|
||||||
prefetch_dict=prefetch_dict,
|
prefetch_dict=prefetch_dict,
|
||||||
|
orders_by=orders_by,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._update_already_loaded_rows(
|
self._update_already_loaded_rows(
|
||||||
target_field=target_field, prefetch_dict=prefetch_dict,
|
target_field=target_field,
|
||||||
|
prefetch_dict=prefetch_dict,
|
||||||
|
orders_by=orders_by,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _run_prefetch_query(
|
async def _run_prefetch_query(
|
||||||
@ -310,12 +349,12 @@ class PrefetchQuery:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _update_already_loaded_rows( # noqa: CFQ002
|
def _update_already_loaded_rows( # noqa: CFQ002
|
||||||
self, target_field: Type["BaseField"], prefetch_dict: Dict,
|
self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict,
|
||||||
) -> None:
|
) -> None:
|
||||||
target_model = target_field.to
|
target_model = target_field.to
|
||||||
for instance in self.models.get(target_model.get_name(), []):
|
for instance in self.models.get(target_model.get_name(), []):
|
||||||
self._populate_nested_related(
|
self._populate_nested_related(
|
||||||
model=instance, prefetch_dict=prefetch_dict,
|
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
|
||||||
)
|
)
|
||||||
|
|
||||||
def _populate_rows( # noqa: CFQ002
|
def _populate_rows( # noqa: CFQ002
|
||||||
@ -327,6 +366,7 @@ class PrefetchQuery:
|
|||||||
fields: Union[Set[Any], Dict[Any, Any], None],
|
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
|
||||||
prefetch_dict: Dict,
|
prefetch_dict: Dict,
|
||||||
|
orders_by: Dict,
|
||||||
) -> None:
|
) -> None:
|
||||||
target_model = target_field.to
|
target_model = target_field.to
|
||||||
for row in rows:
|
for row in rows:
|
||||||
@ -340,7 +380,7 @@ class PrefetchQuery:
|
|||||||
)
|
)
|
||||||
instance = target_model(**item)
|
instance = target_model(**item)
|
||||||
instance = self._populate_nested_related(
|
instance = self._populate_nested_related(
|
||||||
model=instance, prefetch_dict=prefetch_dict,
|
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
|
||||||
)
|
)
|
||||||
field_db_name = target_model.get_column_alias(field_name)
|
field_db_name = target_model.get_column_alias(field_name)
|
||||||
models = self.already_extracted[target_model.get_name()].setdefault(
|
models = self.already_extracted[target_model.get_name()].setdefault(
|
||||||
|
|||||||
@ -75,6 +75,7 @@ class QuerySet:
|
|||||||
exclude_fields=self._exclude_columns,
|
exclude_fields=self._exclude_columns,
|
||||||
prefetch_related=self._prefetch_related,
|
prefetch_related=self._prefetch_related,
|
||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
|
orders_by=self.order_bys,
|
||||||
)
|
)
|
||||||
return await query.prefetch_related(models=models, rows=rows) # type: ignore
|
return await query.prefetch_related(models=models, rows=rows) # type: ignore
|
||||||
|
|
||||||
|
|||||||
@ -14,18 +14,28 @@ def check_node_not_dict_or_not_last_node(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001
|
def translate_list_to_dict( # noqa: CCR001
|
||||||
|
list_to_trans: Union[List, Set], is_order: bool = False
|
||||||
|
) -> Dict:
|
||||||
new_dict: Dict = dict()
|
new_dict: Dict = dict()
|
||||||
for path in list_to_trans:
|
for path in list_to_trans:
|
||||||
current_level = new_dict
|
current_level = new_dict
|
||||||
parts = path.split("__")
|
parts = path.split("__")
|
||||||
|
def_val: Any = ...
|
||||||
|
if is_order:
|
||||||
|
if parts[0][0] == "-":
|
||||||
|
def_val = "desc"
|
||||||
|
parts[0] = parts[0][1:]
|
||||||
|
else:
|
||||||
|
def_val = "asc"
|
||||||
|
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if check_node_not_dict_or_not_last_node(
|
if check_node_not_dict_or_not_last_node(
|
||||||
part=part, parts=parts, current_level=current_level
|
part=part, parts=parts, current_level=current_level
|
||||||
):
|
):
|
||||||
current_level[part] = dict()
|
current_level[part] = dict()
|
||||||
elif part not in current_level:
|
elif part not in current_level:
|
||||||
current_level[part] = ...
|
current_level[part] = def_val
|
||||||
current_level = current_level[part]
|
current_level = current_level[part]
|
||||||
return new_dict
|
return new_dict
|
||||||
|
|
||||||
|
|||||||
@ -228,6 +228,7 @@ async def test_prefetch_related_with_select_related():
|
|||||||
|
|
||||||
album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related(
|
album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related(
|
||||||
['cover_pictures', 'shops__division']).get()
|
['cover_pictures', 'shops__division']).get()
|
||||||
|
|
||||||
assert len(album.tracks) == 0
|
assert len(album.tracks) == 0
|
||||||
assert len(album.cover_pictures) == 2
|
assert len(album.cover_pictures) == 2
|
||||||
assert album.shops[0].division.name == 'Div 1'
|
assert album.shops[0].division.name == 'Div 1'
|
||||||
@ -240,14 +241,15 @@ async def test_prefetch_related_with_select_related():
|
|||||||
|
|
||||||
album = await Album.objects.select_related('tracks__tonation__rand_set').filter(
|
album = await Album.objects.select_related('tracks__tonation__rand_set').filter(
|
||||||
name='Malibu').prefetch_related(
|
name='Malibu').prefetch_related(
|
||||||
['cover_pictures', 'shops__division']).get()
|
['cover_pictures', 'shops__division']).order_by(
|
||||||
|
['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get()
|
||||||
assert len(album.tracks) == 3
|
assert len(album.tracks) == 3
|
||||||
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
|
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
|
||||||
assert len(album.cover_pictures) == 2
|
assert len(album.cover_pictures) == 2
|
||||||
assert album.cover_pictures[0].artist == 'Artist 1'
|
assert album.cover_pictures[0].artist == 'Artist 2'
|
||||||
|
|
||||||
assert len(album.shops) == 2
|
assert len(album.shops) == 2
|
||||||
assert album.shops[0].name == 'Shop 1'
|
assert album.shops[0].name == 'Shop 2'
|
||||||
assert album.shops[0].division.name == 'Div 1'
|
assert album.shops[0].division.name == 'Div 1'
|
||||||
|
|
||||||
track = await Track.objects.select_related('album').prefetch_related(
|
track = await Track.objects.select_related('album').prefetch_related(
|
||||||
|
|||||||
@ -1,5 +1,11 @@
|
|||||||
|
import databases
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
import ormar
|
||||||
from ormar.models.excludable import Excludable
|
from ormar.models.excludable import Excludable
|
||||||
|
from ormar.queryset.prefetch_query import sort_models
|
||||||
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
|
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
|
||||||
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
|
|
||||||
def test_empty_excludable():
|
def test_empty_excludable():
|
||||||
@ -96,3 +102,43 @@ def test_updating_dict_inc_set_with_dict_inc_set():
|
|||||||
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
|
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
|
||||||
"uu": Ellipsis,
|
"uu": Ellipsis,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
class SortModel(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "sorts"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
name: str = ormar.String(max_length=100)
|
||||||
|
sort_order: int = ormar.Integer()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sorting_models():
|
||||||
|
models = [
|
||||||
|
SortModel(id=1, name='Alice', sort_order=0),
|
||||||
|
SortModel(id=2, name='Al', sort_order=1),
|
||||||
|
SortModel(id=3, name='Zake', sort_order=1),
|
||||||
|
SortModel(id=4, name='Will', sort_order=0),
|
||||||
|
SortModel(id=5, name='Al', sort_order=2),
|
||||||
|
SortModel(id=6, name='Alice', sort_order=2)
|
||||||
|
]
|
||||||
|
orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'}
|
||||||
|
models = sort_models(models, orders_by)
|
||||||
|
assert models[5].name == 'Zake'
|
||||||
|
assert models[0].name == 'Al'
|
||||||
|
assert models[1].name == 'Al'
|
||||||
|
assert [model.id for model in models] == [5, 2, 6, 1, 4, 3]
|
||||||
|
|
||||||
|
orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'}
|
||||||
|
models = sort_models(models, orders_by)
|
||||||
|
assert [model.id for model in models] == [2, 5, 1, 6, 4, 3]
|
||||||
|
|
||||||
|
orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None}
|
||||||
|
models = sort_models(models, orders_by)
|
||||||
|
assert [model.id for model in models] == [1, 4, 2, 3, 5, 6]
|
||||||
|
|||||||
Reference in New Issue
Block a user