add order_by support for prefetch_related

This commit is contained in:
collerek
2020-11-26 11:17:33 +01:00
parent bdea9b51ed
commit 78d1241807
5 changed files with 121 additions and 22 deletions

View File

@ -32,14 +32,48 @@ def add_relation_field_to_fields(
return fields return fields
def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]:
sort_criteria = [
(key, value) for key, value in orders_by.items() if isinstance(value, str)
]
sort_criteria = sort_criteria[::-1]
for criteria in sort_criteria:
if criteria[1] == "desc":
models.sort(key=lambda x: getattr(x, criteria[0]), reverse=True)
else:
models.sort(key=lambda x: getattr(x, criteria[0]))
return models
def set_children_on_model( # noqa: CCR001
model: "Model",
related: str,
children: Dict,
model_id: int,
models: Dict,
orders_by: Dict,
) -> None:
for key, child_models in children.items():
if key == model_id:
models_to_set = [models[child] for child in sorted(child_models)]
if models_to_set:
if orders_by and any(isinstance(x, str) for x in orders_by.values()):
models_to_set = sort_models(
models=models_to_set, orders_by=orders_by
)
for child in models_to_set:
setattr(model, related, child)
class PrefetchQuery: class PrefetchQuery:
def __init__( def __init__( # noqa: CFQ002
self, self,
model_cls: Type["Model"], model_cls: Type["Model"],
fields: Optional[Union[Dict, Set]], fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]], exclude_fields: Optional[Union[Dict, Set]],
prefetch_related: List, prefetch_related: List,
select_related: List, select_related: List,
orders_by: List,
) -> None: ) -> None:
self.model = model_cls self.model = model_cls
@ -51,6 +85,8 @@ class PrefetchQuery:
self.already_extracted: Dict = dict() self.already_extracted: Dict = dict()
self.models: Dict = {} self.models: Dict = {}
self.select_dict = translate_list_to_dict(self._select_related) self.select_dict = translate_list_to_dict(self._select_related)
self.orders_by = orders_by or []
self.order_dict = translate_list_to_dict(self.orders_by, is_order=True)
async def prefetch_related( async def prefetch_related(
self, models: Sequence["Model"], rows: List self, models: Sequence["Model"], rows: List
@ -128,7 +164,9 @@ class PrefetchQuery:
return filter_clauses return filter_clauses
return [] return []
def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model": def _populate_nested_related(
self, model: "Model", prefetch_dict: Dict, orders_by: Dict,
) -> "Model":
related_to_extract = model.get_filtered_names_to_extract( related_to_extract = model.get_filtered_names_to_extract(
prefetch_dict=prefetch_dict prefetch_dict=prefetch_dict
@ -146,25 +184,17 @@ class PrefetchQuery:
children = self.already_extracted.get(target_model, {}).get(field_name, {}) children = self.already_extracted.get(target_model, {}).get(field_name, {})
models = self.already_extracted.get(target_model, {}).get("pk_models", {}) models = self.already_extracted.get(target_model, {}).get("pk_models", {})
self._set_children_on_model( set_children_on_model(
model=model, model=model,
related=related, related=related,
children=children, children=children,
model_id=model_id, model_id=model_id,
models=models, models=models,
orders_by=orders_by.get(related, {}),
) )
return model return model
@staticmethod
def _set_children_on_model(
model: "Model", related: str, children: Dict, model_id: int, models: Dict
) -> None:
for key, child_models in children.items():
if key == model_id:
for child in child_models:
setattr(model, related, models.get(child))
async def _prefetch_related_models( async def _prefetch_related_models(
self, models: Sequence["Model"], rows: List self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]: ) -> Sequence["Model"]:
@ -174,6 +204,7 @@ class PrefetchQuery:
target_model = self.model target_model = self.model
fields = self._columns fields = self._columns
exclude_fields = self._exclude_columns exclude_fields = self._exclude_columns
orders_by = self.order_dict
for related in prefetch_dict.keys(): for related in prefetch_dict.keys():
await self._extract_related_models( await self._extract_related_models(
related=related, related=related,
@ -182,11 +213,14 @@ class PrefetchQuery:
select_dict=select_dict.get(related, {}), select_dict=select_dict.get(related, {}),
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
orders_by=orders_by.get(related, {}),
) )
final_models = [] final_models = []
for model in models: for model in models:
final_models.append( final_models.append(
self._populate_nested_related(model=model, prefetch_dict=prefetch_dict,) self._populate_nested_related(
model=model, prefetch_dict=prefetch_dict, orders_by=self.order_dict
)
) )
return models return models
@ -198,6 +232,7 @@ class PrefetchQuery:
select_dict: Dict, select_dict: Dict,
fields: Union[Set[Any], Dict[Any, Any], None], fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None], exclude_fields: Union[Set[Any], Dict[Any, Any], None],
orders_by: Dict,
) -> None: ) -> None:
fields = target_model.get_included(fields, related) fields = target_model.get_included(fields, related)
@ -246,6 +281,7 @@ class PrefetchQuery:
), ),
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
orders_by=self._get_select_related_if_apply(subrelated, orders_by),
) )
if not already_loaded: if not already_loaded:
@ -257,10 +293,13 @@ class PrefetchQuery:
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
prefetch_dict=prefetch_dict, prefetch_dict=prefetch_dict,
orders_by=orders_by,
) )
else: else:
self._update_already_loaded_rows( self._update_already_loaded_rows(
target_field=target_field, prefetch_dict=prefetch_dict, target_field=target_field,
prefetch_dict=prefetch_dict,
orders_by=orders_by,
) )
async def _run_prefetch_query( async def _run_prefetch_query(
@ -310,12 +349,12 @@ class PrefetchQuery:
) )
def _update_already_loaded_rows( # noqa: CFQ002 def _update_already_loaded_rows( # noqa: CFQ002
self, target_field: Type["BaseField"], prefetch_dict: Dict, self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict,
) -> None: ) -> None:
target_model = target_field.to target_model = target_field.to
for instance in self.models.get(target_model.get_name(), []): for instance in self.models.get(target_model.get_name(), []):
self._populate_nested_related( self._populate_nested_related(
model=instance, prefetch_dict=prefetch_dict, model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
) )
def _populate_rows( # noqa: CFQ002 def _populate_rows( # noqa: CFQ002
@ -327,6 +366,7 @@ class PrefetchQuery:
fields: Union[Set[Any], Dict[Any, Any], None], fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None], exclude_fields: Union[Set[Any], Dict[Any, Any], None],
prefetch_dict: Dict, prefetch_dict: Dict,
orders_by: Dict,
) -> None: ) -> None:
target_model = target_field.to target_model = target_field.to
for row in rows: for row in rows:
@ -340,7 +380,7 @@ class PrefetchQuery:
) )
instance = target_model(**item) instance = target_model(**item)
instance = self._populate_nested_related( instance = self._populate_nested_related(
model=instance, prefetch_dict=prefetch_dict, model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
) )
field_db_name = target_model.get_column_alias(field_name) field_db_name = target_model.get_column_alias(field_name)
models = self.already_extracted[target_model.get_name()].setdefault( models = self.already_extracted[target_model.get_name()].setdefault(

View File

@ -75,6 +75,7 @@ class QuerySet:
exclude_fields=self._exclude_columns, exclude_fields=self._exclude_columns,
prefetch_related=self._prefetch_related, prefetch_related=self._prefetch_related,
select_related=self._select_related, select_related=self._select_related,
orders_by=self.order_bys,
) )
return await query.prefetch_related(models=models, rows=rows) # type: ignore return await query.prefetch_related(models=models, rows=rows) # type: ignore

View File

@ -14,18 +14,28 @@ def check_node_not_dict_or_not_last_node(
) )
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001 def translate_list_to_dict( # noqa: CCR001
list_to_trans: Union[List, Set], is_order: bool = False
) -> Dict:
new_dict: Dict = dict() new_dict: Dict = dict()
for path in list_to_trans: for path in list_to_trans:
current_level = new_dict current_level = new_dict
parts = path.split("__") parts = path.split("__")
def_val: Any = ...
if is_order:
if parts[0][0] == "-":
def_val = "desc"
parts[0] = parts[0][1:]
else:
def_val = "asc"
for part in parts: for part in parts:
if check_node_not_dict_or_not_last_node( if check_node_not_dict_or_not_last_node(
part=part, parts=parts, current_level=current_level part=part, parts=parts, current_level=current_level
): ):
current_level[part] = dict() current_level[part] = dict()
elif part not in current_level: elif part not in current_level:
current_level[part] = ... current_level[part] = def_val
current_level = current_level[part] current_level = current_level[part]
return new_dict return new_dict

View File

@ -228,6 +228,7 @@ async def test_prefetch_related_with_select_related():
album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related( album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related(
['cover_pictures', 'shops__division']).get() ['cover_pictures', 'shops__division']).get()
assert len(album.tracks) == 0 assert len(album.tracks) == 0
assert len(album.cover_pictures) == 2 assert len(album.cover_pictures) == 2
assert album.shops[0].division.name == 'Div 1' assert album.shops[0].division.name == 'Div 1'
@ -240,14 +241,15 @@ async def test_prefetch_related_with_select_related():
album = await Album.objects.select_related('tracks__tonation__rand_set').filter( album = await Album.objects.select_related('tracks__tonation__rand_set').filter(
name='Malibu').prefetch_related( name='Malibu').prefetch_related(
['cover_pictures', 'shops__division']).get() ['cover_pictures', 'shops__division']).order_by(
['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get()
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
assert len(album.cover_pictures) == 2 assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].artist == 'Artist 1' assert album.cover_pictures[0].artist == 'Artist 2'
assert len(album.shops) == 2 assert len(album.shops) == 2
assert album.shops[0].name == 'Shop 1' assert album.shops[0].name == 'Shop 2'
assert album.shops[0].division.name == 'Div 1' assert album.shops[0].division.name == 'Div 1'
track = await Track.objects.select_related('album').prefetch_related( track = await Track.objects.select_related('album').prefetch_related(

View File

@ -1,5 +1,11 @@
import databases
import sqlalchemy
import ormar
from ormar.models.excludable import Excludable from ormar.models.excludable import Excludable
from ormar.queryset.prefetch_query import sort_models
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
from tests.settings import DATABASE_URL
def test_empty_excludable(): def test_empty_excludable():
@ -96,3 +102,43 @@ def test_updating_dict_inc_set_with_dict_inc_set():
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis}, "cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
"uu": Ellipsis, "uu": Ellipsis,
} }
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class SortModel(ormar.Model):
class Meta:
tablename = "sorts"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
sort_order: int = ormar.Integer()
def test_sorting_models():
models = [
SortModel(id=1, name='Alice', sort_order=0),
SortModel(id=2, name='Al', sort_order=1),
SortModel(id=3, name='Zake', sort_order=1),
SortModel(id=4, name='Will', sort_order=0),
SortModel(id=5, name='Al', sort_order=2),
SortModel(id=6, name='Alice', sort_order=2)
]
orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'}
models = sort_models(models, orders_by)
assert models[5].name == 'Zake'
assert models[0].name == 'Al'
assert models[1].name == 'Al'
assert [model.id for model in models] == [5, 2, 6, 1, 4, 3]
orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'}
models = sort_models(models, orders_by)
assert [model.id for model in models] == [2, 5, 1, 6, 4, 3]
orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None}
models = sort_models(models, orders_by)
assert [model.id for model in models] == [1, 4, 2, 3, 5, 6]