Merge pull request #54 from collerek/prefetch_related

Add prefetch related QuerySet method
collerek
2020-11-26 18:35:14 +07:00
committed by GitHub
13 changed files with 1123 additions and 30 deletions

View File

@ -5,7 +5,8 @@ name: build
on:
push:
branches: [ master ]
branches-ignore:
- 'gh-pages'
pull_request:
branches: [ master ]

View File

@ -45,7 +45,7 @@ Ormar is built with:
* [`SQLAlchemy core`][sqlalchemy-core] for query building.
* [`databases`][databases] for cross-database async support.
* [`pydantic`][pydantic] for data validation.
* typing_extensions for python 3.6 - 3.7
* `typing_extensions` for python 3.6 - 3.7
### Migrations
@ -53,7 +53,8 @@ Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to p
database migrations.
**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.4.0`
**ormar is still under development:**
We recommend pinning any dependencies (e.g. with `ormar~=0.5.2`)
### Quick Start
@ -157,6 +158,7 @@ assert len(tracks) == 1
* `filter(**kwargs) -> QuerySet`
* `exclude(**kwargs) -> QuerySet`
* `select_related(related: Union[List, str]) -> QuerySet`
* `prefetch_related(related: Union[List, str]) -> QuerySet`
* `limit(limit_count: int) -> QuerySet`
* `offset(offset: int) -> QuerySet`
* `count() -> int`
@ -165,6 +167,7 @@ assert len(tracks) == 1
* `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet`
* `order_by(columns:Union[List, str]) -> QuerySet`
#### Relation types
* One to many - with `ForeignKey(to: Model)`

View File

@ -45,7 +45,7 @@ Ormar is built with:
* [`SQLAlchemy core`][sqlalchemy-core] for query building.
* [`databases`][databases] for cross-database async support.
* [`pydantic`][pydantic] for data validation.
* typing_extensions for python 3.6 - 3.7
* `typing_extensions` for python 3.6 - 3.7
### Migrations
@ -53,7 +53,8 @@ Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to p
database migrations.
**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.4.0`
**ormar is still under development:**
We recommend pinning any dependencies (e.g. with `ormar~=0.5.2`)
### Quick Start
@ -157,6 +158,7 @@ assert len(tracks) == 1
* `filter(**kwargs) -> QuerySet`
* `exclude(**kwargs) -> QuerySet`
* `select_related(related: Union[List, str]) -> QuerySet`
* `prefetch_related(related: Union[List, str]) -> QuerySet`
* `limit(limit_count: int) -> QuerySet`
* `offset(offset: int) -> QuerySet`
* `count() -> int`

View File

@ -253,11 +253,25 @@ notes = await Track.objects.exclude(position_gt=3).all()
`select_related(related: Union[List, str]) -> QuerySet`
Allows to prefetch related models.
Allows you to prefetch related models during the same query.
**With `select_related` only one query is ever run against the database**, meaning that one
(sometimes complicated) join is generated and the nested models are later processed in python.
To fetch related model use `ForeignKey` names.
To chain related `Models` relation use double underscore.
To chain related `Models` relation use double underscores between names.
!!!note
If you are coming from `django` note that `ormar` `select_related` differs -> in `django` you can `select_related`
only single relation types, while in `ormar` you can select related across `ForeignKey` relations,
the reverse side of `ForeignKey` (so virtual auto generated keys) and `ManyToMany` fields (so all relations as of the current version).
!!!tip
To control which model fields to select use `fields()` and `exclude_fields()` `QuerySet` methods.
!!!tip
To control order of models (both main or nested) use `order_by()` method.
```python
album = await Album.objects.select_related("tracks").all()
@ -286,6 +300,150 @@ Exactly the same behavior is for Many2Many fields, where you put the names of Ma
Something like `Track.objects.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()`
### prefetch_related
`prefetch_related(related: Union[List, str]) -> QuerySet`
Allows you to prefetch related models during a query, but unlike `select_related` each
subsequent model is fetched in a separate database query.
**With `prefetch_related` one query per Model is always run against the database**,
meaning that you will have multiple queries executed one after another.
To fetch related model use `ForeignKey` names.
To chain related `Models` relation use double underscores between names.
!!!tip
To control which model fields to select use `fields()` and `exclude_fields()` `QuerySet` methods.
!!!tip
To control order of models (both main or nested) use `order_by()` method.
```python
album = await Album.objects.prefetch_related("tracks").all()
# will return albums with all columns of their tracks
```
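The "one query per model" pattern can be sketched with the standard-library `sqlite3` module. This is only an illustration of the idea (query the parents, collect their ids, fetch the children with an `IN` filter, then stitch the results together in Python), not ormar's actual implementation; the table layout and the sample rows are assumptions made up for the example.

```python
import sqlite3
from collections import defaultdict

# Hypothetical schema and data, used only for this illustration.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE albums (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE tracks (id INTEGER PRIMARY KEY, title TEXT, album_id INTEGER);
    INSERT INTO albums VALUES (1, 'Malibu'), (2, 'Venice');
    INSERT INTO tracks VALUES (1, 'The Bird', 1), (2, 'Heart', 1), (3, 'Waterfall', 2);
    """
)

# Query 1: the main model.
albums = conn.execute("SELECT id, name FROM albums ORDER BY id").fetchall()
album_ids = [album_id for album_id, _ in albums]

# Query 2: the related model, filtered by the ids collected from query 1.
placeholders = ", ".join("?" for _ in album_ids)
tracks = conn.execute(
    "SELECT id, title, album_id FROM tracks "
    f"WHERE album_id IN ({placeholders}) ORDER BY id",
    album_ids,
).fetchall()

# Stitch children onto their parents in Python.
tracks_by_album = defaultdict(list)
for _, title, album_id in tracks:
    tracks_by_album[album_id].append(title)

result = {name: tracks_by_album[album_id] for album_id, name in albums}
print(result)
```

Each additional level of nesting adds one more query of the same shape, which is why the number of queries grows with the number of models rather than with the number of rows.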
You can provide a string or a list of strings
```python
classes = await SchoolClass.objects.prefetch_related(
["teachers__category", "students"]).all()
# will return classes with teachers and teachers categories
# as well as classes students
```
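To see how double-underscore paths describe a relation tree, here is a minimal standalone sketch that expands such paths into a nested dictionary. It is loosely modelled on the `translate_list_to_dict` helper touched in this PR but simplified; the name `paths_to_nested_dict` is hypothetical and not part of ormar.

```python
from typing import Any, Dict, List


def paths_to_nested_dict(paths: List[str]) -> Dict[str, Any]:
    """Expand "a__b" style relation paths into a nested dict (leaves are Ellipsis)."""
    result: Dict[str, Any] = {}
    for path in paths:
        level = result
        parts = path.split("__")
        for index, part in enumerate(parts):
            is_last = index == len(parts) - 1
            if is_last:
                # Keep an existing sub-dict if a longer path already created one.
                level.setdefault(part, ...)
            else:
                # Promote a leaf to a dict when a deeper path shows up.
                if not isinstance(level.get(part), dict):
                    level[part] = {}
                level = level[part]
    return result


nested = paths_to_nested_dict(["teachers__category", "students"])
print(nested)
```

Under this sketch `"teachers__category"` becomes a `teachers` node with a `category` leaf, while `"students"` is a leaf directly under the main model.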
The behavior is exactly the same for Many2Many fields, where you put the names of Many2Many fields and the final `Models` are fetched for you.
!!!warning
If you set a `ForeignKey` field as not nullable (so required), the not nullable `Models`
will be auto prefetched during all queries, even if you do not include them in `select_related`.
!!!note
All methods that do not return the rows explicitly return a `QuerySet` instance, so you can chain them together.
So operations like `filter()`, `select_related()`, `limit()` and `offset()` etc. can be chained.
Something like `Track.objects.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()`
### select_related vs prefetch_related
Which should you use -> `select_related` or `prefetch_related`?
Well, it really depends on your data. The best answer is to try both yourself and see which one performs better under your system constraints.
What to keep in mind:
#### Performance
**Number of queries**:
`select_related` always executes one query against the database, while `prefetch_related` executes multiple queries.
Usually the query (I/O) operation is the slowest one but it does not have to be.
**Number of rows**:
Imagine that you have 10 000 objects in table A, each of those objects has 3 children in table B,
and each object in table B in turn has 2 children in table C. Something like this:
```
Model C
/
Model B - Model C
/
Model A - Model B - Model C
\ \
\ Model C
\
Model B - Model C
\
Model C
```
That means that `select_related` will always return 60 000 rows (10 000 * 3 * 2) later compacted to 10 000 models.
How many rows will `prefetch_related` return?
Well, that depends. If each of models B and C is unique, it will return 10 000 rows in the first query, 30 000 rows
(each of the 3 children of A in table B is unique) in the second query and 60 000 rows (each of the 2 children of model B
in table C is unique) in the third query.
In this case `select_related` seems like the better choice: not only will it run one query compared to the 3 of
`prefetch_related`, but it will also return 60 000 rows compared to the 100 000 of `prefetch_related` (10+30+60k).
But what if each Model A has exactly the same 3 models B and each model B has exactly the same 2 models C? `select_related`
will still return 60 000 rows, while `prefetch_related` will return 10 000 rows for model A, 3 rows for model B and 2 rows for model C,
so 10 005 rows in total. Now, depending on the structure of the models (e.g. if they have long Text() fields etc.), `prefetch_related`
might be faster even though it needs to perform three separate queries instead of one.
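The row counts for both scenarios can be checked with a few lines of plain Python. This is just the arithmetic from the paragraphs above; the function names are hypothetical helpers for the illustration, not ormar API.

```python
def select_related_rows(parents: int, b_per_parent: int, c_per_b: int) -> int:
    """Rows returned by a single join: one row per (A, B, C) combination."""
    return parents * b_per_parent * c_per_b


def prefetch_related_rows(parents: int, distinct_b: int, distinct_c: int) -> int:
    """Rows returned by three separate queries: one query per model."""
    return parents + distinct_b + distinct_c


# One join, regardless of how many distinct B and C rows exist.
joined = select_related_rows(10_000, 3, 2)

# Scenario 1: every B and C row is unique.
all_unique = prefetch_related_rows(10_000, 30_000, 60_000)

# Scenario 2: all parents share the same 3 B rows and the same 2 C rows.
all_shared = prefetch_related_rows(10_000, 3, 2)

print(joined, all_unique, all_shared)
```

The join size depends only on the fan-out, while the prefetch total depends on how many distinct related rows exist, which is why the better choice changes with the data.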
#### Memory
`ormar` is a mini ORM meaning that it does not keep a registry of already loaded models.
That means that in the `select_related` example above you will always have 10 000 Models A, 30 000 Models B
(even if the number of unique rows in db is 3, since processing of `select_related` spawns **new** child models for each parent model)
and 60 000 Models C.
If the same Model B is shared by rows 1, 10, 100 etc. and you update one of those, the other rows
that share the same child will **not** be updated on the spot.
If you persist your changes into the database, the change **will be available only after reload
(either of each child separately or of the whole query again)**.
That means that `select_related` will use more memory, as each child is instantiated as a new object that takes up its own space.
!!!note
This might change in future versions if we decide to introduce caching.
!!!warning
By default all children (or even the same models loaded 2+ times) are completely independent, distinct python objects, even though they represent the same row in db.
They will evaluate as equal when compared, so in the example above:
```python
# will return True if child1 of both rows is the same child db row
row1.child1 == row100.child1
# same here:
model1 = await Model.objects.get(pk=1)
model2 = await Model.objects.get(pk=1) # same pk = same row in db
# will return `True`
model1 == model2
```
but
```python
# will return False (note that `id` is a python builtin function, not an ormar one).
id(row1.child1) == id(row100.child1)
# from above - will also return False
id(model1) == id(model2)
```
In contrast, with `prefetch_related` each distinct child model is instantiated
only once and the same child model is shared across all parent models.
That means that in the `prefetch_related` example above, if there are 3 distinct models in table B and 2 in table C,
there will be only 5 nested child models shared between all model A instances. That also means that if you update
any attribute, the change is visible on all parents, as they share the same child object.
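The difference between the two loading styles can be sketched in plain Python. The `Parent` and `Child` dataclasses below are hypothetical stand-ins for ormar models (they are not ormar classes), and the sample rows are made up; the point is only to show equal-but-distinct objects versus one shared object.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Child:
    pk: int
    name: str


@dataclass
class Parent:
    pk: int
    child: Child


# (parent pk, child pk, child name): both parents point at the same child row.
rows: List[Tuple[int, int, str]] = [(1, 7, "shared"), (100, 7, "shared")]

# select_related style: every joined row spawns a new child object.
joined = [Parent(pk=p, child=Child(pk=c, name=n)) for p, c, n in rows]

# prefetch_related style: one object per distinct child pk, reused by all parents.
children: Dict[int, Child] = {}
for _, c, n in rows:
    if c not in children:
        children[c] = Child(pk=c, name=n)
prefetched = [Parent(pk=p, child=children[c]) for p, c, _ in rows]

same_value = joined[0].child == joined[1].child  # compares field values
same_object_joined = joined[0].child is joined[1].child  # compares identity
same_object_prefetched = prefetched[0].child is prefetched[1].child

# An update on a shared child is visible through every parent.
prefetched[0].child.name = "renamed"
# An update on a per-row child stays local to that one parent.
joined[0].child.name = "renamed"

print(same_value, same_object_joined, same_object_prefetched)
print(prefetched[1].child.name, joined[1].child.name)
```

In this sketch the per-row children compare as equal yet are separate objects, so a change to one does not reach the other, while the shared child carries the change to every parent that references it.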
### limit
@ -352,6 +510,10 @@ has_sample = await Book.objects.filter(title='Sample').exists()
With `fields()` you can select subset of model columns to limit the data load.
!!!note
Note that `fields()` and `exclude_fields()` work both for main models (on normal queries like `get`, `all` etc.)
and for `select_related` and `prefetch_related` models (with nested notation).
Given a sample data like following:
```python
@ -433,6 +595,10 @@ It's the opposite of `fields()` method so check documentation above to see what
Especially check above how you can pass also nested dictionaries and sets as a mask to exclude fields from whole hierarchy.
!!!note
Note that `fields()` and `exclude_fields()` work both for main models (on normal queries like `get`, `all` etc.)
and for `select_related` and `prefetch_related` models (with nested notation).
Below you can find few simple examples:
```python hl_lines="47 48 60 61 67"

View File

@ -1,3 +1,13 @@
# 0.5.2
* Added `prefetch_related` method to load subsequent models in separate queries.
* Update docs
# 0.5.1
* Switched to github actions instead of travis
* Update badges in the docs
# 0.5.0
* Added save status -> you can check if model is saved with `ModelInstance.saved` property

View File

@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType()
__version__ = "0.5.1"
__version__ = "0.5.2"
__all__ = [
"Integer",
"BigInteger",

View File

@ -1,12 +1,15 @@
import inspect
from collections import OrderedDict
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
@ -38,6 +41,8 @@ class ModelTableProxy:
Meta: ModelMeta
_related_names: Set
_related_names_hash: Union[str, bytes]
pk: Any
get_name: Callable
def dict(self): # noqa A003
raise NotImplementedError # pragma no cover
@ -47,6 +52,64 @@ class ModelTableProxy:
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields
@classmethod
def get_related_field_name(cls, target_field: Type["BaseField"]) -> str:
if issubclass(target_field, ormar.fields.ManyToManyField):
return cls.resolve_relation_name(target_field.through, cls)
if target_field.virtual:
return cls.resolve_relation_name(target_field.to, cls)
return target_field.to.Meta.pkname
@staticmethod
def get_clause_target_and_filter_column_name(
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
) -> Tuple[Type["Model"], str]:
if reverse:
field = target_model.resolve_relation_field(target_model, parent_model)
if issubclass(field, ormar.fields.ManyToManyField):
sub_field = target_model.resolve_relation_field(
field.through, parent_model
)
return field.through, sub_field.get_alias()
return target_model, field.get_alias()
target_field = target_model.get_column_alias(target_model.Meta.pkname)
return target_model, target_field
@staticmethod
def get_column_name_for_id_extraction(
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
use_raw: bool,
) -> str:
if reverse:
column_name = parent_model.Meta.pkname
return (
parent_model.get_column_alias(column_name) if use_raw else column_name
)
column = target_model.resolve_relation_field(parent_model, target_model)
return column.get_alias() if use_raw else column.name
@classmethod
def get_filtered_names_to_extract(cls, prefetch_dict: Dict) -> List:
related_to_extract = []
if prefetch_dict and prefetch_dict is not Ellipsis:
related_to_extract = [
related
for related in cls.extract_related_names()
if related in prefetch_dict
]
return related_to_extract
def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]:
if target_field.virtual or issubclass(
target_field, ormar.fields.ManyToManyField
):
return self.pk
related_name = self.resolve_relation_name(self, target_field.to)
related_model = getattr(self, related_name)
return None if not related_model else related_model.pk
@classmethod
def extract_db_own_fields(cls) -> Set:
related_names = cls.extract_related_names()
@ -155,8 +218,18 @@ class ModelTableProxy:
@staticmethod
def resolve_relation_name( # noqa CCR001
item: Union["NewBaseModel", Type["NewBaseModel"]],
related: Union["NewBaseModel", Type["NewBaseModel"]],
item: Union[
"NewBaseModel",
Type["NewBaseModel"],
"ModelTableProxy",
Type["ModelTableProxy"],
],
related: Union[
"NewBaseModel",
Type["NewBaseModel"],
"ModelTableProxy",
Type["ModelTableProxy"],
],
) -> str:
for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField):
@ -236,9 +309,7 @@ class ModelTableProxy:
@staticmethod
def _populate_pk_column(
model: Type["Model"],
columns: List[str],
use_alias: bool = False,
model: Type["Model"], columns: List[str], use_alias: bool = False,
) -> List[str]:
pk_alias = (
model.get_column_alias(model.Meta.pkname)

View File

@ -0,0 +1,394 @@
from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import ormar
from ormar.fields import BaseField, ManyToManyField
from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query
from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict
if TYPE_CHECKING: # pragma: no cover
from ormar import Model
def add_relation_field_to_fields(
fields: Union[Set[Any], Dict[Any, Any], None], related_field_name: str
) -> Union[Set[Any], Dict[Any, Any], None]:
if fields and related_field_name not in fields:
if isinstance(fields, dict):
fields[related_field_name] = ...
elif isinstance(fields, set):
fields.add(related_field_name)
return fields
def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]:
sort_criteria = [
(key, value) for key, value in orders_by.items() if isinstance(value, str)
]
sort_criteria = sort_criteria[::-1]
for criteria in sort_criteria:
key, value = criteria
if value == "desc":
models.sort(key=lambda x: getattr(x, key), reverse=True)
else:
models.sort(key=lambda x: getattr(x, key))
return models
def set_children_on_model( # noqa: CCR001
model: "Model",
related: str,
children: Dict,
model_id: int,
models: Dict,
orders_by: Dict,
) -> None:
for key, child_models in children.items():
if key == model_id:
models_to_set = [models[child] for child in sorted(child_models)]
if models_to_set:
if orders_by and any(isinstance(x, str) for x in orders_by.values()):
models_to_set = sort_models(
models=models_to_set, orders_by=orders_by
)
for child in models_to_set:
setattr(model, related, child)
class PrefetchQuery:
def __init__( # noqa: CFQ002
self,
model_cls: Type["Model"],
fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]],
prefetch_related: List,
select_related: List,
orders_by: List,
) -> None:
self.model = model_cls
self.database = self.model.Meta.database
self._prefetch_related = prefetch_related
self._select_related = select_related
self._exclude_columns = exclude_fields
self._columns = fields
self.already_extracted: Dict = dict()
self.models: Dict = {}
self.select_dict = translate_list_to_dict(self._select_related)
self.orders_by = orders_by or []
self.order_dict = translate_list_to_dict(self.orders_by, is_order=True)
async def prefetch_related(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
self.models = extract_models_to_dict_of_lists(
model_type=self.model, models=models, select_dict=self.select_dict
)
self.models[self.model.get_name()] = models
return await self._prefetch_related_models(models=models, rows=rows)
def _extract_ids_from_raw_data(
self, parent_model: Type["Model"], column_name: str
) -> Set:
list_of_ids = set()
current_data = self.already_extracted.get(parent_model.get_name(), {})
table_prefix = current_data.get("prefix", "")
column_name = (f"{table_prefix}_" if table_prefix else "") + column_name
for row in current_data.get("raw", []):
if row[column_name]:
list_of_ids.add(row[column_name])
return list_of_ids
def _extract_ids_from_preloaded_models(
self, parent_model: Type["Model"], column_name: str
) -> Set:
list_of_ids = set()
for model in self.models.get(parent_model.get_name(), []):
child = getattr(model, column_name)
if isinstance(child, ormar.Model):
list_of_ids.add(child.pk)
else:
list_of_ids.add(child)
return list_of_ids
def _extract_required_ids(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
) -> Set:
use_raw = parent_model.get_name() not in self.models
column_name = parent_model.get_column_name_for_id_extraction(
parent_model=parent_model,
target_model=target_model,
reverse=reverse,
use_raw=use_raw,
)
if use_raw:
return self._extract_ids_from_raw_data(
parent_model=parent_model, column_name=column_name
)
return self._extract_ids_from_preloaded_models(
parent_model=parent_model, column_name=column_name
)
def _get_filter_for_prefetch(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
) -> List:
ids = self._extract_required_ids(
parent_model=parent_model, target_model=target_model, reverse=reverse,
)
if ids:
(
clause_target,
filter_column,
) = parent_model.get_clause_target_and_filter_column_name(
parent_model=parent_model, target_model=target_model, reverse=reverse
)
qryclause = QueryClause(
model_cls=clause_target, select_related=[], filter_clauses=[],
)
kwargs = {f"{filter_column}__in": ids}
filter_clauses, _ = qryclause.filter(**kwargs)
return filter_clauses
return []
def _populate_nested_related(
self, model: "Model", prefetch_dict: Dict, orders_by: Dict,
) -> "Model":
related_to_extract = model.get_filtered_names_to_extract(
prefetch_dict=prefetch_dict
)
for related in related_to_extract:
target_field = model.Meta.model_fields[related]
target_model = target_field.to.get_name()
model_id = model.get_relation_model_id(target_field=target_field)
if model_id is None: # pragma: no cover
continue
field_name = model.get_related_field_name(target_field=target_field)
children = self.already_extracted.get(target_model, {}).get(field_name, {})
models = self.already_extracted.get(target_model, {}).get("pk_models", {})
set_children_on_model(
model=model,
related=related,
children=children,
model_id=model_id,
models=models,
orders_by=orders_by.get(related, {}),
)
return model
async def _prefetch_related_models(
self, models: Sequence["Model"], rows: List
) -> Sequence["Model"]:
self.already_extracted = {self.model.get_name(): {"raw": rows}}
select_dict = translate_list_to_dict(self._select_related)
prefetch_dict = translate_list_to_dict(self._prefetch_related)
target_model = self.model
fields = self._columns
exclude_fields = self._exclude_columns
orders_by = self.order_dict
for related in prefetch_dict.keys():
await self._extract_related_models(
related=related,
target_model=target_model,
prefetch_dict=prefetch_dict.get(related, {}),
select_dict=select_dict.get(related, {}),
fields=fields,
exclude_fields=exclude_fields,
orders_by=orders_by.get(related, {}),
)
final_models = []
for model in models:
final_models.append(
self._populate_nested_related(
model=model, prefetch_dict=prefetch_dict, orders_by=self.order_dict
)
)
return models
async def _extract_related_models( # noqa: CFQ002, CCR001
self,
related: str,
target_model: Type["Model"],
prefetch_dict: Dict,
select_dict: Dict,
fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
orders_by: Dict,
) -> None:
fields = target_model.get_included(fields, related)
exclude_fields = target_model.get_excluded(exclude_fields, related)
target_field = target_model.Meta.model_fields[related]
reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
parent_model = target_model
filter_clauses = self._get_filter_for_prefetch(
parent_model=parent_model, target_model=target_field.to, reverse=reverse,
)
if not filter_clauses: # related field is empty
return
already_loaded = select_dict is Ellipsis or related in select_dict
if not already_loaded:
# If not already loaded with select_related
related_field_name = parent_model.get_related_field_name(
target_field=target_field
)
fields = add_relation_field_to_fields(
fields=fields, related_field_name=related_field_name
)
table_prefix, rows = await self._run_prefetch_query(
target_field=target_field,
fields=fields,
exclude_fields=exclude_fields,
filter_clauses=filter_clauses,
)
else:
rows = []
table_prefix = ""
if prefetch_dict and prefetch_dict is not Ellipsis:
for subrelated in prefetch_dict.keys():
await self._extract_related_models(
related=subrelated,
target_model=target_field.to,
prefetch_dict=prefetch_dict.get(subrelated, {}),
select_dict=self._get_select_related_if_apply(
subrelated, select_dict
),
fields=fields,
exclude_fields=exclude_fields,
orders_by=self._get_select_related_if_apply(subrelated, orders_by),
)
if not already_loaded:
self._populate_rows(
rows=rows,
parent_model=parent_model,
target_field=target_field,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields,
prefetch_dict=prefetch_dict,
orders_by=orders_by,
)
else:
self._update_already_loaded_rows(
target_field=target_field,
prefetch_dict=prefetch_dict,
orders_by=orders_by,
)
async def _run_prefetch_query(
self,
target_field: Type["BaseField"],
fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
filter_clauses: List,
) -> Tuple[str, List]:
target_model = target_field.to
target_name = target_model.get_name()
select_related = []
query_target = target_model
table_prefix = ""
if issubclass(target_field, ManyToManyField):
query_target = target_field.through
select_related = [target_name]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join(
from_table=query_target.Meta.tablename,
to_table=target_field.to.Meta.tablename,
)
self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix
qry = Query(
model_cls=query_target,
select_related=select_related,
filter_clauses=filter_clauses,
exclude_clauses=[],
offset=None,
limit_count=None,
fields=fields,
exclude_fields=exclude_fields,
order_bys=None,
)
expr = qry.build_select_expression()
# print(expr.compile(compile_kwargs={"literal_binds": True}))
rows = await self.database.fetch_all(expr)
self.already_extracted.setdefault(target_name, {}).update({"raw": rows})
return table_prefix, rows
@staticmethod
def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict:
return (
select_dict.get(related, {})
if (select_dict and select_dict is not Ellipsis and related in select_dict)
else {}
)
def _update_already_loaded_rows( # noqa: CFQ002
self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict,
) -> None:
target_model = target_field.to
for instance in self.models.get(target_model.get_name(), []):
self._populate_nested_related(
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
)
def _populate_rows( # noqa: CFQ002
self,
rows: List,
target_field: Type["BaseField"],
parent_model: Type["Model"],
table_prefix: str,
fields: Union[Set[Any], Dict[Any, Any], None],
exclude_fields: Union[Set[Any], Dict[Any, Any], None],
prefetch_dict: Dict,
orders_by: Dict,
) -> None:
target_model = target_field.to
for row in rows:
field_name = parent_model.get_related_field_name(target_field=target_field)
item = target_model.extract_prefixed_table_columns(
item={},
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields,
)
instance = target_model(**item)
instance = self._populate_nested_related(
model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by
)
field_db_name = target_model.get_column_alias(field_name)
models = self.already_extracted[target_model.get_name()].setdefault(
"pk_models", {}
)
if instance.pk not in models:
models[instance.pk] = instance
self.already_extracted[target_model.get_name()].setdefault(
field_name, dict()
).setdefault(row[field_db_name], set()).add(instance.pk)

View File

@ -9,6 +9,7 @@ from ormar import MultipleMatches, NoMatch
from ormar.exceptions import QueryDefinitionError
from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause
from ormar.queryset.prefetch_query import PrefetchQuery
from ormar.queryset.query import Query
from ormar.queryset.utils import update, update_dict_from_list
@ -30,11 +31,13 @@ class QuerySet:
columns: Dict = None,
exclude_columns: Dict = None,
order_bys: List = None,
prefetch_related: List = None,
) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses
self._select_related = [] if select_related is None else select_related
self._prefetch_related = [] if prefetch_related is None else prefetch_related
self.limit_count = limit_count
self.query_offset = offset
self._columns = columns or {}
@ -48,8 +51,7 @@ class QuerySet:
) -> "QuerySet":
if issubclass(owner, ormar.Model):
return self.__class__(model_cls=owner)
else: # pragma nocover
return self.__class__()
return self.__class__() # pragma: no cover
@property
def model_meta(self) -> "ModelMeta":
@ -63,6 +65,19 @@ class QuerySet:
raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls
async def _prefetch_related_models(
self, models: Sequence[Optional["Model"]], rows: List
) -> Sequence[Optional["Model"]]:
query = PrefetchQuery(
model_cls=self.model,
fields=self._columns,
exclude_fields=self._exclude_columns,
prefetch_related=self._prefetch_related,
select_related=self._select_related,
orders_by=self.order_bys,
)
return await query.prefetch_related(models=models, rows=rows) # type: ignore
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
result_rows = [
self.model.from_row(
@ -148,6 +163,7 @@ class QuerySet:
columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
)
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -168,6 +184,25 @@ class QuerySet:
columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
)
def prefetch_related(self, related: Union[List, str]) -> "QuerySet":
if not isinstance(related, list):
related = [related]
related = list(set(list(self._prefetch_related) + related))
return self.__class__(
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=self.query_offset,
columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
prefetch_related=related,
)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
@ -190,6 +225,7 @@ class QuerySet:
columns=self._columns,
exclude_columns=current_excluded,
order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
)
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
@ -212,6 +248,7 @@ class QuerySet:
columns=current_included,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
)
def order_by(self, columns: Union[List, str]) -> "QuerySet":
@ -229,6 +266,7 @@ class QuerySet:
columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=order_bys,
prefetch_related=self._prefetch_related,
)
async def exists(self) -> bool:
@ -279,6 +317,7 @@ class QuerySet:
columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
)
def offset(self, offset: int) -> "QuerySet":
@ -292,6 +331,7 @@ class QuerySet:
columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
prefetch_related=self._prefetch_related,
)
async def first(self, **kwargs: Any) -> "Model":
@ -312,6 +352,8 @@ class QuerySet:
rows = await self.database.fetch_all(expr)
processed_rows = self._process_query_result_rows(rows)
if self._prefetch_related and processed_rows:
processed_rows = await self._prefetch_related_models(processed_rows, rows)
self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore
@ -337,6 +379,8 @@ class QuerySet:
expr = self.build_select_expression()
rows = await self.database.fetch_all(expr)
result_rows = self._process_query_result_rows(rows)
if self._prefetch_related and result_rows:
result_rows = await self._prefetch_related_models(result_rows, rows)
return result_rows

View File

@ -1,6 +1,9 @@
import collections.abc
import copy
from typing import Any, Dict, List, Set, Union
from typing import Any, Dict, List, Sequence, Set, TYPE_CHECKING, Type, Union
if TYPE_CHECKING: # pragma no cover
from ormar import Model
def check_node_not_dict_or_not_last_node(
@ -11,18 +14,28 @@ def check_node_not_dict_or_not_last_node(
)
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001
def translate_list_to_dict( # noqa: CCR001
list_to_trans: Union[List, Set], is_order: bool = False
) -> Dict:
new_dict: Dict = dict()
for path in list_to_trans:
current_level = new_dict
parts = path.split("__")
def_val: Any = ...
if is_order:
if parts[0][0] == "-":
def_val = "desc"
parts[0] = parts[0][1:]
else:
def_val = "asc"
for part in parts:
if check_node_not_dict_or_not_last_node(
part=part, parts=parts, current_level=current_level
):
current_level[part] = dict()
elif part not in current_level:
current_level[part] = ...
current_level[part] = def_val
current_level = current_level[part]
return new_dict
@ -55,3 +68,39 @@ def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) ->
dict_to_update = translate_list_to_dict(list_to_update)
update(updated_dict, dict_to_update)
return updated_dict
+
+
+def extract_nested_models(  # noqa: CCR001
+    model: "Model", model_type: Type["Model"], select_dict: Dict, extracted: Dict
+) -> None:
+    follow = [rel for rel in model_type.extract_related_names() if rel in select_dict]
+    for related in follow:
+        child = getattr(model, related)
+        if child:
+            target_model = model_type.Meta.model_fields[related].to
+            if isinstance(child, list):
+                extracted.setdefault(target_model.get_name(), []).extend(child)
+                if select_dict[related] is not Ellipsis:
+                    for sub_child in child:
+                        extract_nested_models(
+                            sub_child, target_model, select_dict[related], extracted,
+                        )
+            else:
+                extracted.setdefault(target_model.get_name(), []).append(child)
+                if select_dict[related] is not Ellipsis:
+                    extract_nested_models(
+                        child, target_model, select_dict[related], extracted,
+                    )
+
+
+def extract_models_to_dict_of_lists(
+    model_type: Type["Model"],
+    models: Sequence["Model"],
+    select_dict: Dict,
+    extracted: Dict = None,
+) -> Dict:
+    if not extracted:
+        extracted = dict()
+    for model in models:
+        extract_nested_models(model, model_type, select_dict, extracted)
+    return extracted
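The recursive walk added above can be tried without a database. The sketch below uses a hypothetical `Node` class as a stand-in for an ormar model: it exposes `extract_related_names()` and `get_name()` like the calls in the diff, but keeps relation targets in a plain `relations` dict instead of `Meta.model_fields[related].to`. Everything except the walk itself is an assumption made for illustration.

```python
from typing import Dict, List, Optional, Sequence, Set, Type


class Node:
    """Hypothetical stand-in for an ormar Model; only what the walk needs."""

    relations: Dict[str, Type["Node"]] = {}  # relation name -> target class

    def __init__(self, name: str, **related) -> None:
        self.name = name
        for rel in self.relations:
            setattr(self, rel, related.get(rel))

    @classmethod
    def extract_related_names(cls) -> Set[str]:
        return set(cls.relations)

    @classmethod
    def get_name(cls) -> str:
        return cls.__name__.lower()


class Division(Node):
    relations = {}


class Shop(Node):
    relations = {"division": Division}


class Album(Node):
    relations = {"shops": Shop}


def extract_nested_models(
    model: Node, model_type: Type[Node], select_dict: Dict, extracted: Dict
) -> None:
    follow = [rel for rel in model_type.extract_related_names() if rel in select_dict]
    for related in follow:
        child = getattr(model, related)
        if not child:
            continue
        # The diff reads the target from Meta.model_fields[related].to instead.
        target = model_type.relations[related]
        children: List[Node] = child if isinstance(child, list) else [child]
        extracted.setdefault(target.get_name(), []).extend(children)
        # Ellipsis marks a leaf of the select tree; anything else is a sub-tree.
        if select_dict[related] is not Ellipsis:
            for sub_child in children:
                extract_nested_models(sub_child, target, select_dict[related], extracted)


def extract_models_to_dict_of_lists(
    model_type: Type[Node],
    models: Sequence[Node],
    select_dict: Dict,
    extracted: Optional[Dict] = None,
) -> Dict:
    if extracted is None:
        extracted = {}
    for model in models:
        extract_nested_models(model, model_type, select_dict, extracted)
    return extracted


div = Division("Div 1")
album = Album("Malibu", shops=[Shop("Shop 1", division=div), Shop("Shop 2", division=div)])
result = extract_models_to_dict_of_lists(Album, [album], {"shops": {"division": ...}})
print({key: [node.name for node in nodes] for key, nodes in result.items()})
```

Note that the shared division is collected once per shop: the walk, like the function in the diff, appends without deduplicating.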

@ -57,10 +57,6 @@ class Organisation(ormar.Model):
ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
-class Organization(object):
-    pass
 class Team(ormar.Model):
     class Meta:
         tablename = "teams"

@ -0,0 +1,311 @@
+from typing import List, Optional
+
+import databases
+import pytest
+import sqlalchemy
+
+import ormar
+from tests.settings import DATABASE_URL
+
+database = databases.Database(DATABASE_URL, force_rollback=True)
+metadata = sqlalchemy.MetaData()
+
+
+class RandomSet(ormar.Model):
+    class Meta:
+        tablename = "randoms"
+        metadata = metadata
+        database = database
+
+    id: int = ormar.Integer(name='random_id', primary_key=True)
+    name: str = ormar.String(max_length=100)
+
+
+class Tonation(ormar.Model):
+    class Meta:
+        tablename = "tonations"
+        metadata = metadata
+        database = database
+
+    id: int = ormar.Integer(primary_key=True)
+    name: str = ormar.String(name='tonation_name', max_length=100)
+    rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet)
+
+
+class Division(ormar.Model):
+    class Meta:
+        tablename = "divisions"
+        metadata = metadata
+        database = database
+
+    id: int = ormar.Integer(name='division_id', primary_key=True)
+    name: str = ormar.String(max_length=100, nullable=True)
+
+
+class Shop(ormar.Model):
+    class Meta:
+        tablename = "shops"
+        metadata = metadata
+        database = database
+
+    id: int = ormar.Integer(primary_key=True)
+    name: str = ormar.String(max_length=100, nullable=True)
+    division: Optional[Division] = ormar.ForeignKey(Division)
+
+
+class AlbumShops(ormar.Model):
+    class Meta:
+        tablename = "albums_x_shops"
+        metadata = metadata
+        database = database
+
+
+class Album(ormar.Model):
+    class Meta:
+        tablename = "albums"
+        metadata = metadata
+        database = database
+
+    id: int = ormar.Integer(primary_key=True)
+    name: str = ormar.String(max_length=100, nullable=True)
+    shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops)
+
+
+class Track(ormar.Model):
+    class Meta:
+        tablename = "tracks"
+        metadata = metadata
+        database = database
+
+    id: int = ormar.Integer(name='track_id', primary_key=True)
+    album: Optional[Album] = ormar.ForeignKey(Album)
+    title: str = ormar.String(max_length=100)
+    position: int = ormar.Integer()
+    tonation: Optional[Tonation] = ormar.ForeignKey(Tonation, name='tonation_id')
+
+
+class Cover(ormar.Model):
+    class Meta:
+        tablename = "covers"
+        metadata = metadata
+        database = database
+
+    id: int = ormar.Integer(primary_key=True)
+    album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures", name='album_id')
+    title: str = ormar.String(max_length=100)
+    artist: str = ormar.String(max_length=200, nullable=True)
+
+
+@pytest.fixture(autouse=True, scope="module")
+def create_test_database():
+    engine = sqlalchemy.create_engine(DATABASE_URL)
+    metadata.drop_all(engine)
+    metadata.create_all(engine)
+    yield
+    metadata.drop_all(engine)
+
+
+@pytest.mark.asyncio
+async def test_prefetch_related():
+    async with database:
+        async with database.transaction(force_rollback=True):
+            album = Album(name="Malibu")
+            await album.save()
+            ton1 = await Tonation.objects.create(name='B-mol')
+            await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
+            await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
+            await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
+            await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
+            await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
+            fantasies = Album(name="Fantasies")
+            await fantasies.save()
+            await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
+            await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
+            await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)
+            await Cover.objects.create(title='Cover3', album=fantasies, artist='Artist 3')
+            await Cover.objects.create(title='Cover4', album=fantasies, artist='Artist 4')
+            album = await Album.objects.filter(name='Malibu').prefetch_related(
+                ['tracks__tonation', 'cover_pictures']).get()
+            assert len(album.tracks) == 3
+            assert album.tracks[0].title == 'The Bird'
+            assert len(album.cover_pictures) == 2
+            assert album.cover_pictures[0].title == 'Cover1'
+            assert album.tracks[0].tonation.name == album.tracks[2].tonation.name == 'B-mol'
+            albums = await Album.objects.prefetch_related('tracks').all()
+            assert len(albums[0].tracks) == 3
+            assert len(albums[1].tracks) == 3
+            assert albums[0].tracks[0].title == "The Bird"
+            assert albums[1].tracks[0].title == "Help I'm Alive"
+            track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird")
+            assert track.album.name == "Malibu"
+            assert len(track.album.cover_pictures) == 2
+            assert track.album.cover_pictures[0].artist == 'Artist 1'
+            track = await Track.objects.prefetch_related(["album__cover_pictures"]).exclude_fields(
+                'album__cover_pictures__artist').get(title="The Bird")
+            assert track.album.name == "Malibu"
+            assert len(track.album.cover_pictures) == 2
+            assert track.album.cover_pictures[0].artist is None
+            tracks = await Track.objects.prefetch_related("album").all()
+            assert len(tracks) == 6
+
+
+@pytest.mark.asyncio
+async def test_prefetch_related_with_many_to_many():
+    async with database:
+        async with database.transaction(force_rollback=True):
+            div = await Division.objects.create(name='Div 1')
+            shop1 = await Shop.objects.create(name='Shop 1', division=div)
+            shop2 = await Shop.objects.create(name='Shop 2', division=div)
+            album = Album(name="Malibu")
+            await album.save()
+            await album.shops.add(shop1)
+            await album.shops.add(shop2)
+            await Track.objects.create(album=album, title="The Bird", position=1)
+            await Track.objects.create(album=album, title="Heart don't stand a chance", position=2)
+            await Track.objects.create(album=album, title="The Waters", position=3)
+            await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
+            await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
+            track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get(
+                title="The Bird")
+            assert track.album.name == "Malibu"
+            assert len(track.album.cover_pictures) == 2
+            assert track.album.cover_pictures[0].artist == 'Artist 1'
+            assert len(track.album.shops) == 2
+            assert track.album.shops[0].name == 'Shop 1'
+            assert track.album.shops[0].division.name == 'Div 1'
+            album2 = Album(name="Malibu 2")
+            await album2.save()
+            await album2.shops.add(shop1)
+            await album2.shops.add(shop2)
+            await Track.objects.create(album=album2, title="The Bird 2", position=1)
+            tracks = await Track.objects.prefetch_related(["album__shops"]).all()
+            assert tracks[0].album.name == 'Malibu'
+            assert tracks[0].album.shops[0].name == "Shop 1"
+            assert tracks[3].album.name == 'Malibu 2'
+            assert tracks[3].album.shops[0].name == "Shop 1"
+            assert tracks[0].album.shops[0] == tracks[3].album.shops[0]
+            assert id(tracks[0].album.shops[0]) == id(tracks[3].album.shops[0])
+            tracks[0].album.shops[0].name = 'Dummy'
+            assert tracks[0].album.shops[0].name == tracks[3].album.shops[0].name
+
+
+@pytest.mark.asyncio
+async def test_prefetch_related_empty():
+    async with database:
+        async with database.transaction(force_rollback=True):
+            await Track.objects.create(title="The Bird", position=1)
+            track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird")
+            assert track.title == 'The Bird'
+            assert track.album is None
+
+
+@pytest.mark.asyncio
+async def test_prefetch_related_with_select_related():
+    async with database:
+        async with database.transaction(force_rollback=True):
+            div = await Division.objects.create(name='Div 1')
+            shop1 = await Shop.objects.create(name='Shop 1', division=div)
+            shop2 = await Shop.objects.create(name='Shop 2', division=div)
+            album = Album(name="Malibu")
+            await album.save()
+            await album.shops.add(shop1)
+            await album.shops.add(shop2)
+            await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
+            await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
+            album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related(
+                ['cover_pictures', 'shops__division']).get()
+            assert len(album.tracks) == 0
+            assert len(album.cover_pictures) == 2
+            assert album.shops[0].division.name == 'Div 1'
+            rand_set = await RandomSet.objects.create(name='Rand 1')
+            ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set)
+            await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
+            await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
+            await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
+            album = await Album.objects.select_related('tracks__tonation__rand_set').filter(
+                name='Malibu').prefetch_related(
+                ['cover_pictures', 'shops__division']).order_by(
+                ['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get()
+            assert len(album.tracks) == 3
+            assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
+            assert len(album.cover_pictures) == 2
+            assert album.cover_pictures[0].artist == 'Artist 2'
+            assert len(album.shops) == 2
+            assert album.shops[0].name == 'Shop 2'
+            assert album.shops[0].division.name == 'Div 1'
+            track = await Track.objects.select_related('album').prefetch_related(
+                ["album__cover_pictures", "album__shops__division"]).get(
+                title="The Bird")
+            assert track.album.name == "Malibu"
+            assert len(track.album.cover_pictures) == 2
+            assert track.album.cover_pictures[0].artist == 'Artist 1'
+            assert len(track.album.shops) == 2
+            assert track.album.shops[0].name == 'Shop 1'
+            assert track.album.shops[0].division.name == 'Div 1'
+
+
+@pytest.mark.asyncio
+async def test_prefetch_related_with_select_related_and_fields():
+    async with database:
+        async with database.transaction(force_rollback=True):
+            div = await Division.objects.create(name='Div 1')
+            shop1 = await Shop.objects.create(name='Shop 1', division=div)
+            shop2 = await Shop.objects.create(name='Shop 2', division=div)
+            album = Album(name="Malibu")
+            await album.save()
+            await album.shops.add(shop1)
+            await album.shops.add(shop2)
+            await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
+            await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
+            rand_set = await RandomSet.objects.create(name='Rand 1')
+            ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set)
+            await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
+            await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
+            await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
+            album = await Album.objects.select_related('tracks__tonation__rand_set').filter(
+                name='Malibu').prefetch_related(
+                ['cover_pictures', 'shops__division']).exclude_fields({'shops': {'division': {'name'}}}).get()
+            assert len(album.tracks) == 3
+            assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
+            assert len(album.cover_pictures) == 2
+            assert album.cover_pictures[0].artist == 'Artist 1'
+            assert len(album.shops) == 2
+            assert album.shops[0].name == 'Shop 1'
+            assert album.shops[0].division.name is None
+            album = await Album.objects.select_related('tracks').filter(
+                name='Malibu').prefetch_related(
+                ['cover_pictures', 'shops__division']).fields(
+                {'name': ..., 'shops': {'division'}, 'cover_pictures': {'id': ..., 'title': ...}}
+            ).exclude_fields({'shops': {'division': {'name'}}}).get()
+            assert len(album.tracks) == 3
+            assert len(album.cover_pictures) == 2
+            assert album.cover_pictures[0].artist is None
+            assert album.cover_pictures[0].title is not None
+            assert len(album.shops) == 2
+            assert album.shops[0].name is None
+            assert album.shops[0].division is not None
+            assert album.shops[0].division.name is None
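The tests above mix `select_related` and `prefetch_related` on the same query. As those names are commonly used in ORMs, the first loads relations through a JOIN while the second issues a follow-up query keyed on the parent ids and stitches the rows together in Python. The sketch below illustrates only the second strategy with the standard-library `sqlite3` module; it is not ormar's implementation, and the table layout and sample rows are invented for the example.

```python
import sqlite3
from collections import defaultdict

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE albums (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE tracks (
        id INTEGER PRIMARY KEY, album_id INTEGER, title TEXT, position INTEGER
    );
    INSERT INTO albums VALUES (1, 'Malibu'), (2, 'Fantasies');
    INSERT INTO tracks VALUES
        (1, 1, 'The Bird', 1),
        (2, 1, 'The Waters', 3),
        (3, 2, 'Sick Muse', 2);
    """
)

# Query 1: load the parent rows only, with an empty slot for the relation.
albums = [
    {"id": album_id, "name": name, "tracks": []}
    for album_id, name in conn.execute("SELECT id, name FROM albums ORDER BY id")
]

# Query 2: fetch every child of the loaded parents in one round trip.
parent_ids = [album["id"] for album in albums]
placeholders = ",".join("?" for _ in parent_ids)
rows = conn.execute(
    f"SELECT album_id, title FROM tracks "
    f"WHERE album_id IN ({placeholders}) ORDER BY position",
    parent_ids,
).fetchall()

# Stitch children onto their parents in Python instead of in SQL.
children_by_parent = defaultdict(list)
for album_id, title in rows:
    children_by_parent[album_id].append(title)
for album in albums:
    album["tracks"] = children_by_parent[album["id"]]

print(albums)
conn.close()
```

The number of queries grows with the number of relations followed, not with the number of parent rows, which is the usual motivation for this strategy.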

@ -1,5 +1,11 @@
+import databases
+import sqlalchemy
+import ormar
 from ormar.models.excludable import Excludable
+from ormar.queryset.prefetch_query import sort_models
 from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
+from tests.settings import DATABASE_URL
 def test_empty_excludable():
@ -96,3 +102,43 @@ def test_updating_dict_inc_set_with_dict_inc_set():
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
"uu": Ellipsis,
}
+
+
+database = databases.Database(DATABASE_URL, force_rollback=True)
+metadata = sqlalchemy.MetaData()
+
+
+class SortModel(ormar.Model):
+    class Meta:
+        tablename = "sorts"
+        metadata = metadata
+        database = database
+
+    id: int = ormar.Integer(primary_key=True)
+    name: str = ormar.String(max_length=100)
+    sort_order: int = ormar.Integer()
+
+
+def test_sorting_models():
+    models = [
+        SortModel(id=1, name='Alice', sort_order=0),
+        SortModel(id=2, name='Al', sort_order=1),
+        SortModel(id=3, name='Zake', sort_order=1),
+        SortModel(id=4, name='Will', sort_order=0),
+        SortModel(id=5, name='Al', sort_order=2),
+        SortModel(id=6, name='Alice', sort_order=2)
+    ]
+    orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'}
+    models = sort_models(models, orders_by)
+    assert models[5].name == 'Zake'
+    assert models[0].name == 'Al'
+    assert models[1].name == 'Al'
+    assert [model.id for model in models] == [5, 2, 6, 1, 4, 3]
+    orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'}
+    models = sort_models(models, orders_by)
+    assert [model.id for model in models] == [2, 5, 1, 6, 4, 3]
+    orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None}
+    models = sort_models(models, orders_by)
+    assert [model.id for model in models] == [1, 4, 2, 3, 5, 6]
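The body of `sort_models` lives in `ormar/queryset/prefetch_query.py` and is not part of this excerpt. The sketch below is one way to get the ordering the test expects, assuming that only string values (`'asc'` / `'desc'`) count as sort criteria and everything else in `orders_by` is ignored. `Row` and `sort_rows` are hypothetical names; a plain dataclass replaces `SortModel` so the example runs without a database.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Row:
    id: int
    name: str
    sort_order: int


def sort_rows(rows: List[Row], orders_by: Dict) -> List[Row]:
    # Keep only flat "field -> direction" entries; dicts, sets, Ellipsis,
    # numbers and None are skipped (assumed from the test data above).
    criteria = [(key, value) for key, value in orders_by.items() if isinstance(value, str)]
    # list.sort is stable, so sorting by the least significant key first and
    # the most significant key last produces a multi-key ordering.
    for field, direction in reversed(criteria):
        rows.sort(key=lambda row: getattr(row, field), reverse=(direction == "desc"))
    return rows


def make_rows() -> List[Row]:
    return [
        Row(1, "Alice", 0),
        Row(2, "Al", 1),
        Row(3, "Zake", 1),
        Row(4, "Will", 0),
        Row(5, "Al", 2),
        Row(6, "Alice", 2),
    ]


first = sort_rows(make_rows(), {"name": "asc", "none": {}, "sort_order": "desc"})
second = sort_rows(make_rows(), {"name": "asc", "none": set("aa"), "id": "asc"})
third = sort_rows(
    make_rows(), {"sort_order": "asc", "none": ..., "id": "asc", "uu": 2, "aa": None}
)
print([row.id for row in first])
print([row.id for row in second])
print([row.id for row in third])
```

Relying on sort stability keeps mixed ascending and descending keys simple, at the cost of one pass over the list per criterion.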