Merge pull request #54 from collerek/prefetch_related

Add prefetch related QuerySet method
This commit is contained in:
collerek
2020-11-26 18:35:14 +07:00
committed by GitHub
13 changed files with 1123 additions and 30 deletions

View File

@ -5,7 +5,8 @@ name: build
on: on:
push: push:
branches: [ master ] branches-ignore:
- 'gh-pages'
pull_request: pull_request:
branches: [ master ] branches: [ master ]

View File

@ -45,7 +45,7 @@ Ormar is built with:
* [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`SQLAlchemy core`][sqlalchemy-core] for query building.
* [`databases`][databases] for cross-database async support. * [`databases`][databases] for cross-database async support.
* [`pydantic`][pydantic] for data validation. * [`pydantic`][pydantic] for data validation.
* typing_extensions for python 3.6 - 3.7 * `typing_extensions` for python 3.6 - 3.7
### Migrations ### Migrations
@ -53,7 +53,8 @@ Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to p
database migrations. database migrations.
**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.4.0` **ormar is still under development:**
We recommend pinning any dependencies (with i.e. `ormar~=0.5.2`)
### Quick Start ### Quick Start
@ -157,6 +158,7 @@ assert len(tracks) == 1
* `filter(**kwargs) -> QuerySet` * `filter(**kwargs) -> QuerySet`
* `exclude(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet`
* `select_related(related: Union[List, str]) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet`
* `prefetch_related(related: Union[List, str]) -> QuerySet`
* `limit(limit_count: int) -> QuerySet` * `limit(limit_count: int) -> QuerySet`
* `offset(offset: int) -> QuerySet` * `offset(offset: int) -> QuerySet`
* `count() -> int` * `count() -> int`
@ -165,6 +167,7 @@ assert len(tracks) == 1
* `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet` * `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet`
* `order_by(columns:Union[List, str]) -> QuerySet` * `order_by(columns:Union[List, str]) -> QuerySet`
#### Relation types #### Relation types
* One to many - with `ForeignKey(to: Model)` * One to many - with `ForeignKey(to: Model)`

View File

@ -45,7 +45,7 @@ Ormar is built with:
* [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`SQLAlchemy core`][sqlalchemy-core] for query building.
* [`databases`][databases] for cross-database async support. * [`databases`][databases] for cross-database async support.
* [`pydantic`][pydantic] for data validation. * [`pydantic`][pydantic] for data validation.
* typing_extensions for python 3.6 - 3.7 * `typing_extensions` for python 3.6 - 3.7
### Migrations ### Migrations
@ -53,7 +53,8 @@ Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to p
database migrations. database migrations.
**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.4.0` **ormar is still under development:**
We recommend pinning any dependencies (with i.e. `ormar~=0.5.2`)
### Quick Start ### Quick Start
@ -157,6 +158,7 @@ assert len(tracks) == 1
* `filter(**kwargs) -> QuerySet` * `filter(**kwargs) -> QuerySet`
* `exclude(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet`
* `select_related(related: Union[List, str]) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet`
* `prefetch_related(related: Union[List, str]) -> QuerySet`
* `limit(limit_count: int) -> QuerySet` * `limit(limit_count: int) -> QuerySet`
* `offset(offset: int) -> QuerySet` * `offset(offset: int) -> QuerySet`
* `count() -> int` * `count() -> int`

View File

@ -253,11 +253,25 @@ notes = await Track.objects.exclude(position_gt=3).all()
`select_related(related: Union[List, str]) -> QuerySet` `select_related(related: Union[List, str]) -> QuerySet`
Allows to prefetch related models. Allows to prefetch related models during the same query.
**With `select_related` always only one query is run against the database**, meaning that one
(sometimes complicated) join is generated and later nested models are processed in python.
To fetch related model use `ForeignKey` names. To fetch related model use `ForeignKey` names.
To chain related `Models` relation use double underscore. To chain related `Models` relation use double underscores between names.
!!!note
If you are coming from `django` note that `ormar` `select_related` differs -> in `django` you can `select_related`
only singe relation types, while in `ormar` you can select related across `ForeignKey` relation,
reverse side of `ForeignKey` (so virtual auto generated keys) and `ManyToMany` fields (so all relations as of current version).
!!!tip
To control which model fields to select use `fields()` and `exclude_fields()` `QuerySet` methods.
!!!tip
To control order of models (both main or nested) use `order_by()` method.
```python ```python
album = await Album.objects.select_related("tracks").all() album = await Album.objects.select_related("tracks").all()
@ -286,6 +300,150 @@ Exactly the same behavior is for Many2Many fields, where you put the names of Ma
Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()`
### prefetch_related
`prefetch_related(related: Union[List, str]) -> QuerySet`
Allows to prefetch related models during query - but opposite to `select_related` each
subsequent model is fetched in a separate database query.
**With `prefetch_related` always one query per Model is run against the database**,
meaning that you will have multiple queries executed one after another.
To fetch related model use `ForeignKey` names.
To chain related `Models` relation use double underscores between names.
!!!tip
To control which model fields to select use `fields()` and `exclude_fields()` `QuerySet` methods.
!!!tip
To control order of models (both main or nested) use `order_by()` method.
```python
album = await Album.objects.prefetch_related("tracks").all()
# will return album will all columns tracks
```
You can provide a string or a list of strings
```python
classes = await SchoolClass.objects.prefetch_related(
["teachers__category", "students"]).all()
# will return classes with teachers and teachers categories
# as well as classes students
```
Exactly the same behavior is for Many2Many fields, where you put the names of Many2Many fields and the final `Models` are fetched for you.
!!!warning
If you set `ForeignKey` field as not nullable (so required) during
all queries the not nullable `Models` will be auto prefetched, even if you do not include them in select_related.
!!!note
All methods that do not return the rows explicitly returns a QueySet instance so you can chain them together
So operations like `filter()`, `select_related()`, `limit()` and `offset()` etc. can be chained.
Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()`
### select_related vs prefetch_related
Which should you use -> `select_related` or `prefetch_related`?
Well, it really depends on your data. The best answer is try yourself and see which one performs faster/better in your system constraints.
What to keep in mind:
#### Performance
**Number of queries**:
`select_related` always executes one query against the database, while `prefetch_related` executes multiple queries.
Usually the query (I/O) operation is the slowest one but it does not have to be.
**Number of rows**:
Imagine that you have 10 000 object in one table A and each of those objects have 3 children in table B,
and subsequently each object in table B has 2 children in table C. Something like this:
```
Model C
/
Model B - Model C
/
Model A - Model B - Model C
\ \
\ Model C
\
Model B - Model C
\
Model C
```
That means that `select_related` will always return 60 000 rows (10 000 * 3 * 2) later compacted to 10 000 models.
How many rows will return `prefetch_related`?
Well, that depends, if each of models B and C is unique it will return 10 000 rows in first query, 30 000 rows
(each of 3 children of A in table B are unique) in second query and 60 000 rows (each of 2 children of model B
in table C are unique) in 3rd query.
In this case `select_related` seems like a better choice, not only it will run one query comparing to 3 of
`prefetch_related` but will also return 60 000 rows comparing to 100 000 of `prefetch_related` (10+30+60k).
But what if each Model A has exactly the same 3 models B and each models C has exactly same models C? `select_related`
will still return 60 000 rows, while `prefetch_related` will return 10 000 for model A, 3 rows for model B and 2 rows for Model C.
So in total 10 006 rows. Now depending on the structure of models (i.e. if it has long Text() fields etc.) `prefetch_related`
might be faster despite it needs to perform three separate queries instead of one.
#### Memory
`ormar` is a mini ORM meaning that it does not keep a registry of already loaded models.
That means that in `select_related` example above you will always have 10 000 Models A, 30 000 Models B
(even if the unique number of rows in db is 3 - processing of `select_related` spawns **new** child models for each parent model).
And 60 000 Models C.
If the same Model B is shared by rows 1, 10, 100 etc. and you update one of those, the rest of rows
that share the same child will **not** be updated on the spot.
If you persist your changes into the database the change **will be available only after reload
(either each child separately or the whole query again)**.
That means that `select_related` will use more memory as each child is instantiated as a new object - obviously using it's own space.
!!!note
This might change in future versions if we decide to introduce caching.
!!!warning
By default all children (or event the same models loaded 2+ times) are completely independent, distinct python objects, despite that they represent the same row in db.
They will evaluate to True when compared, so in example above:
```python
# will return True if child1 of both rows is the same child db row
row1.child1 == row100.child1
# same here:
model1 = await Model.get(pk=1)
model2 = await Model.get(pk=1) # same pk = same row in db
# will return `True`
model1 == model2
```
but
```python
# will return False (note that id is a python `builtin` function not ormar one).
id(row1.child1) == (ro100.child1)
# from above - will also return False
id(model1) == id(model2)
```
On the contrary - with `prefetch_related` each unique distinct child model is instantiated
only once and the same child models is shared across all parent models.
That means that in `prefetch_related` example above if there are 3 distinct models in table B and 2 in table C,
there will be only 5 children nested models shared between all model A instances. That also means that if you update
any attribute it will be updated on all parents as they share the same child object.
### limit ### limit
@ -352,6 +510,10 @@ has_sample = await Book.objects.filter(title='Sample').exists()
With `fields()` you can select subset of model columns to limit the data load. With `fields()` you can select subset of model columns to limit the data load.
!!!note
Note that `fields()` and `exclude_fields()` works both for main models (on normal queries like `get`, `all` etc.)
as well as `select_related` and `prefetch_related` models (with nested notation).
Given a sample data like following: Given a sample data like following:
```python ```python
@ -433,6 +595,10 @@ It's the opposite of `fields()` method so check documentation above to see what
Especially check above how you can pass also nested dictionaries and sets as a mask to exclude fields from whole hierarchy. Especially check above how you can pass also nested dictionaries and sets as a mask to exclude fields from whole hierarchy.
!!!note
Note that `fields()` and `exclude_fields()` works both for main models (on normal queries like `get`, `all` etc.)
as well as `select_related` and `prefetch_related` models (with nested notation).
Below you can find few simple examples: Below you can find few simple examples:
```python hl_lines="47 48 60 61 67" ```python hl_lines="47 48 60 61 67"

View File

@ -1,3 +1,13 @@
# 0.5.2
* Added `prefetch_related` method to load subsequent models in separate queries.
* Update docs
# 0.5.1
* Switched to github actions instead of travis
* Update badges in the docs
# 0.5.0 # 0.5.0
* Added save status -> you can check if model is saved with `ModelInstance.saved` property * Added save status -> you can check if model is saved with `ModelInstance.saved` property

View File

@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.5.1" __version__ = "0.5.2"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -1,12 +1,15 @@
import inspect import inspect
from collections import OrderedDict from collections import OrderedDict
from typing import ( from typing import (
Any,
Callable,
Dict, Dict,
List, List,
Optional, Optional,
Sequence, Sequence,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Tuple,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -38,6 +41,8 @@ class ModelTableProxy:
Meta: ModelMeta Meta: ModelMeta
_related_names: Set _related_names: Set
_related_names_hash: Union[str, bytes] _related_names_hash: Union[str, bytes]
pk: Any
get_name: Callable
def dict(self): # noqa A003 def dict(self): # noqa A003
raise NotImplementedError # pragma no cover raise NotImplementedError # pragma no cover
@ -47,6 +52,64 @@ class ModelTableProxy:
self_fields = {k: v for k, v in self.dict().items() if k not in related_names} self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields return self_fields
@classmethod
def get_related_field_name(cls, target_field: Type["BaseField"]) -> str:
if issubclass(target_field, ormar.fields.ManyToManyField):
return cls.resolve_relation_name(target_field.through, cls)
if target_field.virtual:
return cls.resolve_relation_name(target_field.to, cls)
return target_field.to.Meta.pkname
@staticmethod
def get_clause_target_and_filter_column_name(
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
) -> Tuple[Type["Model"], str]:
if reverse:
field = target_model.resolve_relation_field(target_model, parent_model)
if issubclass(field, ormar.fields.ManyToManyField):
sub_field = target_model.resolve_relation_field(
field.through, parent_model
)
return field.through, sub_field.get_alias()
return target_model, field.get_alias()
target_field = target_model.get_column_alias(target_model.Meta.pkname)
return target_model, target_field
@staticmethod
def get_column_name_for_id_extraction(
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
use_raw: bool,
) -> str:
if reverse:
column_name = parent_model.Meta.pkname
return (
parent_model.get_column_alias(column_name) if use_raw else column_name
)
column = target_model.resolve_relation_field(parent_model, target_model)
return column.get_alias() if use_raw else column.name
@classmethod
def get_filtered_names_to_extract(cls, prefetch_dict: Dict) -> List:
related_to_extract = []
if prefetch_dict and prefetch_dict is not Ellipsis:
related_to_extract = [
related
for related in cls.extract_related_names()
if related in prefetch_dict
]
return related_to_extract
def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]:
if target_field.virtual or issubclass(
target_field, ormar.fields.ManyToManyField
):
return self.pk
related_name = self.resolve_relation_name(self, target_field.to)
related_model = getattr(self, related_name)
return None if not related_model else related_model.pk
@classmethod @classmethod
def extract_db_own_fields(cls) -> Set: def extract_db_own_fields(cls) -> Set:
related_names = cls.extract_related_names() related_names = cls.extract_related_names()
@ -155,8 +218,18 @@ class ModelTableProxy:
@staticmethod @staticmethod
def resolve_relation_name( # noqa CCR001 def resolve_relation_name( # noqa CCR001
item: Union["NewBaseModel", Type["NewBaseModel"]], item: Union[
related: Union["NewBaseModel", Type["NewBaseModel"]], "NewBaseModel",
Type["NewBaseModel"],
"ModelTableProxy",
Type["ModelTableProxy"],
],
related: Union[
"NewBaseModel",
Type["NewBaseModel"],
"ModelTableProxy",
Type["ModelTableProxy"],
],
) -> str: ) -> str:
for name, field in item.Meta.model_fields.items(): for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField): if issubclass(field, ForeignKeyField):
@ -236,9 +309,7 @@ class ModelTableProxy:
@staticmethod @staticmethod
def _populate_pk_column( def _populate_pk_column(
model: Type["Model"], model: Type["Model"], columns: List[str], use_alias: bool = False,
columns: List[str],
use_alias: bool = False,
) -> List[str]: ) -> List[str]:
pk_alias = ( pk_alias = (
model.get_column_alias(model.Meta.pkname) model.get_column_alias(model.Meta.pkname)

View File

@ -0,0 +1,394 @@
from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import ormar
from ormar.fields import BaseField, ManyToManyField
from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query
from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict
if TYPE_CHECKING: # pragma: no cover
from ormar import Model
def add_relation_field_to_fields(
fields: Union[Set[Any], Dict[Any, Any], None], related_field_name: str
) -> Union[Set[Any], Dict[Any, Any], None]:
if fields and related_field_name not in fields:
if isinstance(fields, dict):
fields[related_field_name] = ...
elif isinstance(fields, set):
fields.add(related_field_name)
return fields
def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]:
sort_criteria = [
(key, value) for key, value in orders_by.items() if isinstance(value, str)
]
sort_criteria = sort_criteria[::-1]
for criteria in sort_criteria:
key, value = criteria
if value == "desc":
models.sort(key=lambda x: getattr(x, key), reverse=True)
else:
models.sort(key=lambda x: getattr(x, key))
return models
def set_children_on_model( # noqa: CCR001
model: "Model",
related: str,
children: Dict,
model_id: int,
models: Dict,
orders_by: Dict,
) -> None:
for key, child_models in children.items():
if key == model_id:
models_to_set = [models[child] for child in sorted(child_models)]
if models_to_set:
if orders_by and any(isinstance(x, str) for x in orders_by.values()):
models_to_set = sort_models(
models=models_to_set, orders_by=orders_by
)
for child in models_to_set:
setattr(model, related, child)
class PrefetchQuery:
def __init__( # noqa: CFQ002
self,
model_cls: Type["Model"],
fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]],
prefetch_related: List,
select_related: List,
orders_by: List,
) -> None:
self.model = model_cls
self.database = self.model.Meta.database
self._prefetch_related = prefetch_related
self._select_related = select_related
self._exclude_columns = exclude_fields
self._columns = fields
self.already_extracted: Dict = dict()
self.models: Dict = {}
self.select_dict = translate_list_to_dict(self._select_related)
self.orders_by = orders_by or []
self.order_dict = translate_list_to_dict(self.orders_by, is_order=True)
async def prefetch_related(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
self.models = extract_models_to_dict_of_lists(
model_type=self.model, models=models, select_dict=self.select_dict
)
self.models[self.model.get_name()] = models
return await self._prefetch_related_models(models=models, rows=rows)
def _extract_ids_from_raw_data(
self, parent_model: Type["Model"], column_name: str
) -> Set:
list_of_ids = set()
current_data = self.already_extracted.get(parent_model.get_name(), {})
table_prefix = current_data.get("prefix", "")
column_name = (f"{table_prefix}_" if table_prefix else "") + column_name
for row in current_data.get("raw", []):
if row[column_name]:
list_of_ids.add(row[column_name])
return list_of_ids
def _extract_ids_from_preloaded_models(
self, parent_model: Type["Model"], column_name: str
) -> Set:
list_of_ids = set()
for model in self.models.get(parent_model.get_name(), []):
child = getattr(model, column_name)
if isinstance(child, ormar.Model):
list_of_ids.add(child.pk)
else:
list_of_ids.add(child)
return list_of_ids
def _extract_required_ids(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
) -> Set:
use_raw = parent_model.get_name() not in self.models
column_name = parent_model.get_column_name_for_id_extraction(
parent_model=parent_model,
target_model=target_model,
reverse=reverse,
use_raw=use_raw,
)
if use_raw:
return self._extract_ids_from_raw_data(
parent_model=parent_model, column_name=column_name
)
return self._extract_ids_from_preloaded_models(
parent_model=parent_model, column_name=column_name
)
def _get_filter_for_prefetch(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
) -> List:
ids = self._extract_required_ids(
parent_model=parent_model, target_model=target_model, reverse=reverse,
)
if ids:
(
clause_target,
filter_column,
) = parent_model.get_clause_target_and_filter_column_name(
parent_model=parent_model, target_model=target_model, reverse=reverse
)
qryclause = QueryClause(
model_cls=clause_target, select_related=[], filter_clauses=[],
)
kwargs = {f"{filter_column}__in": ids}
filter_clauses, _ = qryclause.filter(**kwargs)
return filter_clauses
return []
def _populate_nested_related(
self, model: "Model", prefetch_dict: Dict, orders_by: Dict,
) -> "Model":
related_to_extract = model.get_filtered_names_to_extract(
prefetch_dict=prefetch_dict
)
for related in related_to_extract:
target_field = model.Meta.model_fields[related]
target_model = target_field.to.get_name()
model_id = model.get_relation_model_id(target_field=target_field)
if model_id is None: # pragma: no cover
continue
field_name = model.get_related_field_name(target_field=target_field)
children = self.already_extracted.get(target_model, {}).get(field_name, {})
models = self.already_extracted.get(target_model, {}).get("pk_models", {})
set_children_on_model(
model=model,
related=related,
children=children,
model_id=model_id,
models=models,
orders_by=orders_by.get(related, {}),
)
return model
async def _prefetch_related_models(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
self.already_extracted = {self.model.get_name(): {"raw": rows}}
select_dict = translate_list_to_dict(self._select_related)
prefetch_dict = translate_list_to_dict(self._prefetch_related)
target_model = self.model
fields = self._columns
exclude_fields = self._exclude_columns
orders_by = self.order_dict
for related in prefetch_dict.keys():
await self._extract_related_models(
related=related,
target_model=target_model,
prefetch_dict=prefetch_dict.get(related, {}),
select_dict=select_dict.get(related, {}),
fields=fields,
exclude_fields=exclude_fields,
orders_by=orders_by.get(related, {}),
)
final_models = []
for model in models:
final_models.append(
self._populate_nested_related(
model=model, prefetch_dict=prefetch_dict, orders_by=self.order_dict
)
)
return models
async def _extract_related_models( # noqa: CFQ002, CCR001
self,
related: str,
target_model: Type["Model"],
prefetch_dict: Dict,
select_dict: Dict,
fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
orders_by: Dict,
) -> None:
fields = target_model.get_included(fields, related)
exclude_fields = target_model.get_excluded(exclude_fields, related)
target_field = target_model.Meta.model_fields[related]
reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
parent_model = target_model
filter_clauses = self._get_filter_for_prefetch(
parent_model=parent_model, target_model=target_field.to, reverse=reverse,
)
if not filter_clauses: # related field is empty
return
already_loaded = select_dict is Ellipsis or related in select_dict
if not already_loaded:
# If not already loaded with select_related
related_field_name = parent_model.get_related_field_name(
target_field=target_field
)
fields = add_relation_field_to_fields(
fields=fields, related_field_name=related_field_name
)
table_prefix, rows = await self._run_prefetch_query(
target_field=target_field,
fields=fields,
exclude_fields=exclude_fields,
filter_clauses=filter_clauses,
)
else:
rows = []
table_prefix = ""
if prefetch_dict and prefetch_dict is not Ellipsis:
for subrelated in prefetch_dict.keys():
await self._extract_related_models(
related=subrelated,
target_model=target_field.to,
prefetch_dict=prefetch_dict.get(subrelated, {}),
select_dict=self._get_select_related_if_apply(
subrelated, select_dict
),
fields=fields,
exclude_fields=exclude_fields,
orders_by=self._get_select_related_if_apply(subrelated, orders_by),
)
if not already_loaded:
self._populate_rows(
rows=rows,
parent_model=parent_model,
target_field=target_field,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields,
prefetch_dict=prefetch_dict,
orders_by=orders_by,
)
else:
self._update_already_loaded_rows(
target_field=target_field,
prefetch_dict=prefetch_dict,
orders_by=orders_by,
)
async def _run_prefetch_query(
self,
target_field: Type["BaseField"],
fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
filter_clauses: List,
) -> Tuple[str, List]:
target_model = target_field.to
target_name = target_model.get_name()
select_related = []
query_target = target_model
table_prefix = ""
if issubclass(target_field, ManyToManyField):
query_target = target_field.through
select_related = [target_name]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename,
to_table=target_field.to.Meta.tablename,
)
self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix
qry = Query(
model_cls=query_target,
select_related=select_related,
filter_clauses=filter_clauses,
exclude_clauses=[],
offset=None,
limit_count=None,
fields=fields,
exclude_fields=exclude_fields,
order_bys=None,
)
expr = qry.build_select_expression()
# print(expr.compile(compile_kwargs={"literal_binds": True}))
rows = await self.database.fetch_all(expr)
self.already_extracted.setdefault(target_name, {}).update({"raw": rows})
return table_prefix, rows
@staticmethod
def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict:
return (
select_dict.get(related, {})
if (select_dict and select_dict is not Ellipsis and related in select_dict)
else {}
)
def _update_already_loaded_rows( # noqa: CFQ002
self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict,
) -> None:
target_model = target_field.to
for instance in self.models.get(target_model.get_name(), []):
self._populate_nested_related(
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
)
def _populate_rows( # noqa: CFQ002
self,
rows: List,
target_field: Type["BaseField"],
parent_model: Type["Model"],
table_prefix: str,
fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
prefetch_dict: Dict,
orders_by: Dict,
) -> None:
target_model = target_field.to
for row in rows:
field_name = parent_model.get_related_field_name(target_field=target_field)
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields,
)
instance = target_model(**item)
instance = self._populate_nested_related(
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
)
field_db_name = target_model.get_column_alias(field_name)
models = self.already_extracted[target_model.get_name()].setdefault(
"pk_models", {}
)
if instance.pk not in models:
models[instance.pk] = instance
self.already_extracted[target_model.get_name()].setdefault(
field_name, dict()
).setdefault(row[field_db_name], set()).add(instance.pk)

View File

@ -9,6 +9,7 @@ from ormar import MultipleMatches, NoMatch
from ormar.exceptions import QueryDefinitionError from ormar.exceptions import QueryDefinitionError
from ormar.queryset import FilterQuery from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.prefetch_query import PrefetchQuery
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.utils import update, update_dict_from_list from ormar.queryset.utils import update, update_dict_from_list
@ -30,11 +31,13 @@ class QuerySet:
columns: Dict = None, columns: Dict = None,
exclude_columns: Dict = None, exclude_columns: Dict = None,
order_bys: List = None, order_bys: List = None,
prefetch_related: List = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses self.filter_clauses = [] if filter_clauses is None else filter_clauses
self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses
self._select_related = [] if select_related is None else select_related self._select_related = [] if select_related is None else select_related
self._prefetch_related = [] if prefetch_related is None else prefetch_related
self.limit_count = limit_count self.limit_count = limit_count
self.query_offset = offset self.query_offset = offset
self._columns = columns or {} self._columns = columns or {}
@ -48,8 +51,7 @@ class QuerySet:
) -> "QuerySet": ) -> "QuerySet":
if issubclass(owner, ormar.Model): if issubclass(owner, ormar.Model):
return self.__class__(model_cls=owner) return self.__class__(model_cls=owner)
else: # pragma nocover return self.__class__() # pragma: no cover
return self.__class__()
@property @property
def model_meta(self) -> "ModelMeta": def model_meta(self) -> "ModelMeta":
@ -63,6 +65,19 @@ class QuerySet:
raise ValueError("Model class of QuerySet is not initialized") raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls return self.model_cls
async def _prefetch_related_models(
self, models: Sequence[Optional["Model"]], rows: List
) -> Sequence[Optional["Model"]]:
query = PrefetchQuery(
model_cls=self.model,
fields=self._columns,
exclude_fields=self._exclude_columns,
prefetch_related=self._prefetch_related,
select_related=self._select_related,
orders_by=self.order_bys,
)
return await query.prefetch_related(models=models, rows=rows) # type: ignore
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
result_rows = [ result_rows = [
self.model.from_row( self.model.from_row(
@ -148,6 +163,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003 def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -168,6 +184,25 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
)
def prefetch_related(self, related: Union[List, str]) -> "QuerySet":
if not isinstance(related, list):
related = [related]
related = list(set(list(self._prefetch_related) + related))
return self.__class__(
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=self.query_offset,
columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
prefetch_related=related,
) )
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
@ -190,6 +225,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=current_excluded, exclude_columns=current_excluded,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
@ -212,6 +248,7 @@ class QuerySet:
columns=current_included, columns=current_included,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
def order_by(self, columns: Union[List, str]) -> "QuerySet": def order_by(self, columns: Union[List, str]) -> "QuerySet":
@ -229,6 +266,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=order_bys, order_bys=order_bys,
prefetch_related=self._prefetch_related,
) )
async def exists(self) -> bool: async def exists(self) -> bool:
@ -279,6 +317,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
def offset(self, offset: int) -> "QuerySet": def offset(self, offset: int) -> "QuerySet":
@ -292,6 +331,7 @@ class QuerySet:
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
) )
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> "Model":
@ -312,6 +352,8 @@ class QuerySet:
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
processed_rows = self._process_query_result_rows(rows) processed_rows = self._process_query_result_rows(rows)
if self._prefetch_related and processed_rows:
processed_rows = await self._prefetch_related_models(processed_rows, rows)
self.check_single_result_rows_count(processed_rows) self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore return processed_rows[0] # type: ignore
@ -337,6 +379,8 @@ class QuerySet:
expr = self.build_select_expression() expr = self.build_select_expression()
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
result_rows = self._process_query_result_rows(rows) result_rows = self._process_query_result_rows(rows)
if self._prefetch_related and result_rows:
result_rows = await self._prefetch_related_models(result_rows, rows)
return result_rows return result_rows

View File

@ -1,6 +1,9 @@
import collections.abc import collections.abc
import copy import copy
from typing import Any, Dict, List, Set, Union from typing import Any, Dict, List, Sequence, Set, TYPE_CHECKING, Type, Union
if TYPE_CHECKING: # pragma no cover
from ormar import Model
def check_node_not_dict_or_not_last_node( def check_node_not_dict_or_not_last_node(
@ -11,18 +14,28 @@ def check_node_not_dict_or_not_last_node(
) )
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001 def translate_list_to_dict( # noqa: CCR001
list_to_trans: Union[List, Set], is_order: bool = False
) -> Dict:
new_dict: Dict = dict() new_dict: Dict = dict()
for path in list_to_trans: for path in list_to_trans:
current_level = new_dict current_level = new_dict
parts = path.split("__") parts = path.split("__")
def_val: Any = ...
if is_order:
if parts[0][0] == "-":
def_val = "desc"
parts[0] = parts[0][1:]
else:
def_val = "asc"
for part in parts: for part in parts:
if check_node_not_dict_or_not_last_node( if check_node_not_dict_or_not_last_node(
part=part, parts=parts, current_level=current_level part=part, parts=parts, current_level=current_level
): ):
current_level[part] = dict() current_level[part] = dict()
elif part not in current_level: elif part not in current_level:
current_level[part] = ... current_level[part] = def_val
current_level = current_level[part] current_level = current_level[part]
return new_dict return new_dict
@ -55,3 +68,39 @@ def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) ->
dict_to_update = translate_list_to_dict(list_to_update) dict_to_update = translate_list_to_dict(list_to_update)
update(updated_dict, dict_to_update) update(updated_dict, dict_to_update)
return updated_dict return updated_dict
def extract_nested_models( # noqa: CCR001
model: "Model", model_type: Type["Model"], select_dict: Dict, extracted: Dict
) -> None:
follow = [rel for rel in model_type.extract_related_names() if rel in select_dict]
for related in follow:
child = getattr(model, related)
if child:
target_model = model_type.Meta.model_fields[related].to
if isinstance(child, list):
extracted.setdefault(target_model.get_name(), []).extend(child)
if select_dict[related] is not Ellipsis:
for sub_child in child:
extract_nested_models(
sub_child, target_model, select_dict[related], extracted,
)
else:
extracted.setdefault(target_model.get_name(), []).append(child)
if select_dict[related] is not Ellipsis:
extract_nested_models(
child, target_model, select_dict[related], extracted,
)
def extract_models_to_dict_of_lists(
model_type: Type["Model"],
models: Sequence["Model"],
select_dict: Dict,
extracted: Dict = None,
) -> Dict:
if not extracted:
extracted = dict()
for model in models:
extract_nested_models(model, model_type, select_dict, extracted)
return extracted

View File

@ -57,10 +57,6 @@ class Organisation(ormar.Model):
ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"]) ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
class Organization(object):
pass
class Team(ormar.Model): class Team(ormar.Model):
class Meta: class Meta:
tablename = "teams" tablename = "teams"
@ -241,8 +237,8 @@ async def test_fk_filter():
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(album__name="Fantasies") .filter(album__name="Fantasies")
.all() .all()
) )
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
@ -250,8 +246,8 @@ async def test_fk_filter():
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(album__name__icontains="fan") .filter(album__name__icontains="fan")
.all() .all()
) )
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
@ -296,8 +292,8 @@ async def test_multiple_fk():
members = ( members = (
await Member.objects.select_related("team__org") await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd") .filter(team__org__ident="ACME Ltd")
.all() .all()
) )
assert len(members) == 4 assert len(members) == 4
for member in members: for member in members:
@ -329,8 +325,8 @@ async def test_pk_filter():
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(position=2, album__name="Test") .filter(position=2, album__name="Test")
.all() .all()
) )
assert len(tracks) == 1 assert len(tracks) == 1

View File

@ -0,0 +1,311 @@
from typing import List, Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class RandomSet(ormar.Model):
class Meta:
tablename = "randoms"
metadata = metadata
database = database
id: int = ormar.Integer(name='random_id', primary_key=True)
name: str = ormar.String(max_length=100)
class Tonation(ormar.Model):
class Meta:
tablename = "tonations"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(name='tonation_name', max_length=100)
rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet)
class Division(ormar.Model):
class Meta:
tablename = "divisions"
metadata = metadata
database = database
id: int = ormar.Integer(name='division_id', primary_key=True)
name: str = ormar.String(max_length=100, nullable=True)
class Shop(ormar.Model):
class Meta:
tablename = "shops"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=True)
division: Optional[Division] = ormar.ForeignKey(Division)
class AlbumShops(ormar.Model):
class Meta:
tablename = "albums_x_shops"
metadata = metadata
database = database
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=True)
shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops)
class Track(ormar.Model):
class Meta:
tablename = "tracks"
metadata = metadata
database = database
id: int = ormar.Integer(name='track_id', primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100)
position: int = ormar.Integer()
tonation: Optional[Tonation] = ormar.ForeignKey(Tonation, name='tonation_id')
class Cover(ormar.Model):
class Meta:
tablename = "covers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures", name='album_id')
title: str = ormar.String(max_length=100)
artist: str = ormar.String(max_length=200, nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_prefetch_related():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Malibu")
await album.save()
ton1 = await Tonation.objects.create(name='B-mol')
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
fantasies = Album(name="Fantasies")
await fantasies.save()
await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)
await Cover.objects.create(title='Cover3', album=fantasies, artist='Artist 3')
await Cover.objects.create(title='Cover4', album=fantasies, artist='Artist 4')
album = await Album.objects.filter(name='Malibu').prefetch_related(
['tracks__tonation', 'cover_pictures']).get()
assert len(album.tracks) == 3
assert album.tracks[0].title == 'The Bird'
assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].title == 'Cover1'
assert album.tracks[0].tonation.name == album.tracks[2].tonation.name == 'B-mol'
albums = await Album.objects.prefetch_related('tracks').all()
assert len(albums[0].tracks) == 3
assert len(albums[1].tracks) == 3
assert albums[0].tracks[0].title == "The Bird"
assert albums[1].tracks[0].title == "Help I'm Alive"
track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1'
track = await Track.objects.prefetch_related(["album__cover_pictures"]).exclude_fields(
'album__cover_pictures__artist').get(title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist is None
tracks = await Track.objects.prefetch_related("album").all()
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_prefetch_related_with_many_to_many():
async with database:
async with database.transaction(force_rollback=True):
div = await Division.objects.create(name='Div 1')
shop1 = await Shop.objects.create(name='Shop 1', division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div)
album = Album(name="Malibu")
await album.save()
await album.shops.add(shop1)
await album.shops.add(shop2)
await Track.objects.create(album=album, title="The Bird", position=1)
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2)
await Track.objects.create(album=album, title="The Waters", position=3)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get(
title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1'
assert len(track.album.shops) == 2
assert track.album.shops[0].name == 'Shop 1'
assert track.album.shops[0].division.name == 'Div 1'
album2 = Album(name="Malibu 2")
await album2.save()
await album2.shops.add(shop1)
await album2.shops.add(shop2)
await Track.objects.create(album=album2, title="The Bird 2", position=1)
tracks = await Track.objects.prefetch_related(["album__shops"]).all()
assert tracks[0].album.name == 'Malibu'
assert tracks[0].album.shops[0].name == "Shop 1"
assert tracks[3].album.name == 'Malibu 2'
assert tracks[3].album.shops[0].name == "Shop 1"
assert tracks[0].album.shops[0] == tracks[3].album.shops[0]
assert id(tracks[0].album.shops[0]) == id(tracks[3].album.shops[0])
tracks[0].album.shops[0].name = 'Dummy'
assert tracks[0].album.shops[0].name == tracks[3].album.shops[0].name
@pytest.mark.asyncio
async def test_prefetch_related_empty():
async with database:
async with database.transaction(force_rollback=True):
await Track.objects.create(title="The Bird", position=1)
track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird")
assert track.title == 'The Bird'
assert track.album is None
@pytest.mark.asyncio
async def test_prefetch_related_with_select_related():
async with database:
async with database.transaction(force_rollback=True):
div = await Division.objects.create(name='Div 1')
shop1 = await Shop.objects.create(name='Shop 1', division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div)
album = Album(name="Malibu")
await album.save()
await album.shops.add(shop1)
await album.shops.add(shop2)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related(
['cover_pictures', 'shops__division']).get()
assert len(album.tracks) == 0
assert len(album.cover_pictures) == 2
assert album.shops[0].division.name == 'Div 1'
rand_set = await RandomSet.objects.create(name='Rand 1')
ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set)
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
album = await Album.objects.select_related('tracks__tonation__rand_set').filter(
name='Malibu').prefetch_related(
['cover_pictures', 'shops__division']).order_by(
['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get()
assert len(album.tracks) == 3
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].artist == 'Artist 2'
assert len(album.shops) == 2
assert album.shops[0].name == 'Shop 2'
assert album.shops[0].division.name == 'Div 1'
track = await Track.objects.select_related('album').prefetch_related(
["album__cover_pictures", "album__shops__division"]).get(
title="The Bird")
assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1'
assert len(track.album.shops) == 2
assert track.album.shops[0].name == 'Shop 1'
assert track.album.shops[0].division.name == 'Div 1'
@pytest.mark.asyncio
async def test_prefetch_related_with_select_related_and_fields():
async with database:
async with database.transaction(force_rollback=True):
div = await Division.objects.create(name='Div 1')
shop1 = await Shop.objects.create(name='Shop 1', division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div)
album = Album(name="Malibu")
await album.save()
await album.shops.add(shop1)
await album.shops.add(shop2)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
rand_set = await RandomSet.objects.create(name='Rand 1')
ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set)
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
album = await Album.objects.select_related('tracks__tonation__rand_set').filter(
name='Malibu').prefetch_related(
['cover_pictures', 'shops__division']).exclude_fields({'shops': {'division': {'name'}}}).get()
assert len(album.tracks) == 3
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].artist == 'Artist 1'
assert len(album.shops) == 2
assert album.shops[0].name == 'Shop 1'
assert album.shops[0].division.name is None
album = await Album.objects.select_related('tracks').filter(
name='Malibu').prefetch_related(
['cover_pictures', 'shops__division']).fields(
{'name': ..., 'shops': {'division'}, 'cover_pictures': {'id': ..., 'title': ...}}
).exclude_fields({'shops': {'division': {'name'}}}).get()
assert len(album.tracks) == 3
assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].artist is None
assert album.cover_pictures[0].title is not None
assert len(album.shops) == 2
assert album.shops[0].name is None
assert album.shops[0].division is not None
assert album.shops[0].division.name is None

View File

@ -1,5 +1,11 @@
import databases
import sqlalchemy
import ormar
from ormar.models.excludable import Excludable from ormar.models.excludable import Excludable
from ormar.queryset.prefetch_query import sort_models
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
from tests.settings import DATABASE_URL
def test_empty_excludable(): def test_empty_excludable():
@ -96,3 +102,43 @@ def test_updating_dict_inc_set_with_dict_inc_set():
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis}, "cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
"uu": Ellipsis, "uu": Ellipsis,
} }
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class SortModel(ormar.Model):
class Meta:
tablename = "sorts"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
sort_order: int = ormar.Integer()
def test_sorting_models():
models = [
SortModel(id=1, name='Alice', sort_order=0),
SortModel(id=2, name='Al', sort_order=1),
SortModel(id=3, name='Zake', sort_order=1),
SortModel(id=4, name='Will', sort_order=0),
SortModel(id=5, name='Al', sort_order=2),
SortModel(id=6, name='Alice', sort_order=2)
]
orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'}
models = sort_models(models, orders_by)
assert models[5].name == 'Zake'
assert models[0].name == 'Al'
assert models[1].name == 'Al'
assert [model.id for model in models] == [5, 2, 6, 1, 4, 3]
orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'}
models = sort_models(models, orders_by)
assert [model.id for model in models] == [2, 5, 1, 6, 4, 3]
orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None}
models = sort_models(models, orders_by)
assert [model.id for model in models] == [1, 4, 2, 3, 5, 6]