diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index 7bd02a9..a953a77 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -5,7 +5,8 @@ name: build on: push: - branches: [ master ] + branches-ignore: + - 'gh-pages' pull_request: branches: [ master ] diff --git a/README.md b/README.md index 25fd50f..b4b1890 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Ormar is built with: * [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`databases`][databases] for cross-database async support. * [`pydantic`][pydantic] for data validation. - * typing_extensions for python 3.6 - 3.7 + * `typing_extensions` for python 3.6 - 3.7 ### Migrations @@ -53,7 +53,8 @@ Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to p database migrations. -**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.4.0` +**ormar is still under development:** +We recommend pinning any dependencies (e.g. with `ormar~=0.5.2`) ### Quick Start @@ -157,6 +158,7 @@ assert len(tracks) == 1 * `filter(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet` +* `prefetch_related(related: Union[List, str]) -> QuerySet` * `limit(limit_count: int) -> QuerySet` * `offset(offset: int) -> QuerySet` * `count() -> int` @@ -165,6 +167,7 @@ assert len(tracks) == 1 * `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet` * `order_by(columns:Union[List, str]) -> QuerySet` + #### Relation types * One to many - with `ForeignKey(to: Model)` diff --git a/docs/index.md b/docs/index.md index f69de18..b4b1890 100644 --- a/docs/index.md +++ b/docs/index.md @@ -45,7 +45,7 @@ Ormar is built with: * [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`databases`][databases] for cross-database async support. * [`pydantic`][pydantic] for data validation. - * typing_extensions for python 3.6 - 3.7 + * `typing_extensions` for python 3.6 - 3.7 ### Migrations @@ -53,7 +53,8 @@ Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to p database migrations. -**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.4.0` +**ormar is still under development:** +We recommend pinning any dependencies (e.g. with `ormar~=0.5.2`) ### Quick Start @@ -157,6 +158,7 @@ assert len(tracks) == 1 * `filter(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet` +* `prefetch_related(related: Union[List, str]) -> QuerySet` * `limit(limit_count: int) -> QuerySet` * `offset(offset: int) -> QuerySet` * `count() -> int` diff --git a/docs/queries.md b/docs/queries.md index f7f3d7b..c848bc1 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -253,11 +253,25 @@ notes = await Track.objects.exclude(position_gt=3).all() `select_related(related: Union[List, str]) -> QuerySet` -Allows to prefetch related models. +Allows you to prefetch related models within the same query. + +**With `select_related` exactly one query is always run against the database**, meaning that one +(sometimes complicated) join is generated and the nested models are later processed in python. To fetch related model use `ForeignKey` names. -To chain related `Models` relation use double underscore. +To chain relations through related `Models` use double underscores between names.
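For instance, a short sketch of chaining two relations with the double-underscore notation (it reuses the `SchoolClass` models that the `prefetch_related` example further down assumes; they are not defined in this diff):

```python
# one joined query: SchoolClass -> teachers -> category
classes = await SchoolClass.objects.select_related("teachers__category").all()
```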
+ +!!!note + If you are coming from `django`, note that `ormar` `select_related` differs -> in `django` you can `select_related` + only single relation types, while in `ormar` you can select related across `ForeignKey` relations, + the reverse side of `ForeignKey` (so virtual auto-generated keys) and `ManyToMany` fields (so all relation types as of the current version). + +!!!tip + To control which model fields to select, use the `fields()` and `exclude_fields()` `QuerySet` methods. + +!!!tip + To control the order of models (both main and nested), use the `order_by()` method. ```python album = await Album.objects.select_related("tracks").all() @@ -286,6 +300,150 @@ Exactly the same behavior is for Many2Many fields, where you put the names of Ma Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` +### prefetch_related + +`prefetch_related(related: Union[List, str]) -> QuerySet` + +Allows you to prefetch related models during a query - but unlike `select_related`, each +subsequent model is fetched in a separate database query. + +**With `prefetch_related` exactly one query per Model is run against the database**, +meaning that you will have multiple queries executed one after another. + +To fetch related model use `ForeignKey` names. + +To chain relations through related `Models` use double underscores between names. + +!!!tip + To control which model fields to select, use the `fields()` and `exclude_fields()` `QuerySet` methods. + +!!!tip + To control the order of models (both main and nested), use the `order_by()` method. + +```python +album = await Album.objects.prefetch_related("tracks").all() +# will return album with all its tracks +``` + +You can provide a string or a list of strings + +```python +classes = await SchoolClass.objects.prefetch_related( +["teachers__category", "students"]).all() +# will return classes with teachers and the teachers' categories +# as well as the classes' students +``` + +Exactly the same behavior applies to Many2Many fields, where you put the names of the Many2Many fields and the final `Models` are fetched for you. + +!!!warning + If you set a `ForeignKey` field as not nullable (so required), then during + all queries the not nullable `Models` will be auto prefetched, even if you do not include them in `select_related`. + +!!!note + All methods that do not return the rows explicitly return a `QuerySet` instance so you can chain them together + + So operations like `filter()`, `select_related()`, `limit()` and `offset()` etc. can be chained. + + Something like `Track.objects.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` + +### select_related vs prefetch_related + +Which should you use -> `select_related` or `prefetch_related`? + +Well, it really depends on your data. The best answer is to try it yourself and see which one performs faster/better within your system constraints. + +What to keep in mind: + +#### Performance + +**Number of queries**: +`select_related` always executes one query against the database, while `prefetch_related` executes multiple queries. +Usually the query (I/O) operation is the slowest one, but it does not have to be. + +**Number of rows**: +Imagine that you have 10 000 objects in table A and each of those objects has 3 children in table B, +and subsequently each object in table B has 2 children in table C.
Something like this: + +``` + Model C + / + Model B - Model C + / +Model A - Model B - Model C + \ \ + \ Model C + \ + Model B - Model C + \ + Model C +``` + +That means that `select_related` will always return 60 000 rows (10 000 * 3 * 2), later compacted to 10 000 models. + +How many rows will `prefetch_related` return? + +Well, that depends: if each of the models B and C is unique, it will return 10 000 rows in the first query, 30 000 rows +(each of the 3 children of A in table B is unique) in the second query and 60 000 rows (each of the 2 children of model B +in table C is unique) in the 3rd query. + +In this case `select_related` seems like a better choice: not only will it run one query compared to 3 of +`prefetch_related`, but it will also return 60 000 rows compared to 100 000 of `prefetch_related` (10+30+60k). + +But what if each Model A has exactly the same 3 models B and each model B has exactly the same 2 models C? `select_related` +will still return 60 000 rows, while `prefetch_related` will return 10 000 rows for model A, 3 rows for model B and 2 rows for model C. +So 10 005 rows in total. Now, depending on the structure of the models (e.g. if they have long Text() fields etc.), `prefetch_related` +might be faster despite the fact that it needs to perform three separate queries instead of one. + +#### Memory + +`ormar` is a mini ORM, meaning that it does not keep a registry of already loaded models. + +That means that in the `select_related` example above you will always have 10 000 Models A, 30 000 Models B +(even if the number of unique rows in db is 3 - processing of `select_related` spawns **new** child models for each parent model) +and 60 000 Models C. + +If the same Model B is shared by rows 1, 10, 100 etc. and you update one of those, the rest of the rows +that share the same child will **not** be updated on the spot. +If you persist your changes into the database, the change **will be available only after a reload +(of either each child separately or the whole query again)**. +That means that `select_related` will use more memory, as each child is instantiated as a new object - obviously using its own space. + +!!!note + This might change in future versions if we decide to introduce caching. + +!!!warning + By default all children (or even the same models loaded 2+ times) are completely independent, distinct python objects, despite the fact that they represent the same row in db. + + They will evaluate to True when compared, so in the example above: + + ```python + # will return True if child1 of both rows is the same child db row + row1.child1 == row100.child1 + + # same here: + model1 = await Model.get(pk=1) + model2 = await Model.get(pk=1) # same pk = same row in db + # will return `True` + model1 == model2 + ``` + + but + + ```python + # will return False (note that id is a python `builtin` function, not an ormar one). + id(row1.child1) == id(row100.child1) + + # from above - will also return False + id(model1) == id(model2) + ``` + + +On the contrary - with `prefetch_related` each unique, distinct child model is instantiated +only once and the same child model is shared across all parent models. +That means that in the `prefetch_related` example above, if there are 3 distinct models in table B and 2 in table C, +there will be only 5 nested child models shared between all Model A instances. That also means that if you update +any attribute, it will be updated on all parents as they share the same child object.
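A minimal sketch of the practical difference, mirroring the identity assertions in the new `test_prefetch_related` tests (hypothetical `Track`/`Album`/`Shop` models; it assumes the compared parents reference the same database rows):

```python
# select_related: one joined query, but every parent spawns its OWN child instance
tracks = await Track.objects.select_related("album").all()
assert tracks[0].album == tracks[1].album            # same db row -> compares equal
assert id(tracks[0].album) != id(tracks[1].album)    # yet two distinct python objects

# prefetch_related: one query per model, identical children are shared
tracks = await Track.objects.prefetch_related("album__shops").all()
assert id(tracks[0].album.shops[0]) == id(tracks[3].album.shops[0])
tracks[0].album.shops[0].name = "Dummy"
assert tracks[3].album.shops[0].name == "Dummy"      # change visible through all parents
```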
### limit @@ -352,6 +510,10 @@ has_sample = await Book.objects.filter(title='Sample').exists() With `fields()` you can select subset of model columns to limit the data load. +!!!note + Note that `fields()` and `exclude_fields()` work both for main models (on normal queries like `get`, `all` etc.) + as well as for `select_related` and `prefetch_related` models (with nested notation). + Given a sample data like following: ```python @@ -433,6 +595,10 @@ It's the opposite of `fields()` method so check documentation above to see what Especially check above how you can pass also nested dictionaries and sets as a mask to exclude fields from whole hierarchy. +!!!note + Note that `fields()` and `exclude_fields()` work both for main models (on normal queries like `get`, `all` etc.) + as well as for `select_related` and `prefetch_related` models (with nested notation). + Below you can find few simple examples: ```python hl_lines="47 48 60 61 67" diff --git a/docs/releases.md b/docs/releases.md index e636ceb..cdbd6b0 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,13 @@ +# 0.5.2 + +* Added `prefetch_related` method to load subsequent models in separate queries. +* Updated the docs + +# 0.5.1 + +* Switched to GitHub Actions instead of Travis +* Updated the badges in the docs + # 0.5.0 * Added save status -> you can check if model is saved with `ModelInstance.saved` property diff --git a/ormar/__init__.py b/ormar/__init__.py index 87f1b5b..cd3712a 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.5.1" +__version__ = "0.5.2" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 1ae257e..06b5060 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -1,12 +1,15 @@ import inspect from collections import OrderedDict from typing import ( + Any, + Callable, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, + Tuple, Type, TypeVar, Union, ) @@ -38,6 +41,8 @@ class ModelTableProxy: Meta: ModelMeta _related_names: Set _related_names_hash: Union[str, bytes] + pk: Any + get_name: Callable def dict(self): # noqa A003 raise NotImplementedError # pragma no cover @@ -47,6 +52,64 @@ class ModelTableProxy: self_fields = {k: v for k, v in self.dict().items() if k not in related_names} return self_fields + @classmethod + def get_related_field_name(cls, target_field: Type["BaseField"]) -> str: + if issubclass(target_field, ormar.fields.ManyToManyField): + return cls.resolve_relation_name(target_field.through, cls) + if target_field.virtual: + return cls.resolve_relation_name(target_field.to, cls) + return target_field.to.Meta.pkname + + @staticmethod + def get_clause_target_and_filter_column_name( + parent_model: Type["Model"], target_model: Type["Model"], reverse: bool + ) -> Tuple[Type["Model"], str]: + if reverse: + field = target_model.resolve_relation_field(target_model, parent_model) + if issubclass(field, ormar.fields.ManyToManyField): + sub_field = target_model.resolve_relation_field( + field.through, parent_model + ) + return field.through, sub_field.get_alias() + return target_model, field.get_alias() + target_field = target_model.get_column_alias(target_model.Meta.pkname) + return target_model, target_field + + @staticmethod + def get_column_name_for_id_extraction( + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + use_raw: bool, + ) -> str: + if reverse: + column_name = parent_model.Meta.pkname +
return ( + parent_model.get_column_alias(column_name) if use_raw else column_name + ) + column = target_model.resolve_relation_field(parent_model, target_model) + return column.get_alias() if use_raw else column.name + + @classmethod + def get_filtered_names_to_extract(cls, prefetch_dict: Dict) -> List: + related_to_extract = [] + if prefetch_dict and prefetch_dict is not Ellipsis: + related_to_extract = [ + related + for related in cls.extract_related_names() + if related in prefetch_dict + ] + return related_to_extract + + def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]: + if target_field.virtual or issubclass( + target_field, ormar.fields.ManyToManyField + ): + return self.pk + related_name = self.resolve_relation_name(self, target_field.to) + related_model = getattr(self, related_name) + return None if not related_model else related_model.pk + @classmethod def extract_db_own_fields(cls) -> Set: related_names = cls.extract_related_names() @@ -155,8 +218,18 @@ class ModelTableProxy: @staticmethod def resolve_relation_name( # noqa CCR001 - item: Union["NewBaseModel", Type["NewBaseModel"]], - related: Union["NewBaseModel", Type["NewBaseModel"]], + item: Union[ + "NewBaseModel", + Type["NewBaseModel"], + "ModelTableProxy", + Type["ModelTableProxy"], + ], + related: Union[ + "NewBaseModel", + Type["NewBaseModel"], + "ModelTableProxy", + Type["ModelTableProxy"], + ], ) -> str: for name, field in item.Meta.model_fields.items(): if issubclass(field, ForeignKeyField): @@ -236,9 +309,7 @@ class ModelTableProxy: @staticmethod def _populate_pk_column( - model: Type["Model"], - columns: List[str], - use_alias: bool = False, + model: Type["Model"], columns: List[str], use_alias: bool = False, ) -> List[str]: pk_alias = ( model.get_column_alias(model.Meta.pkname) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py new file mode 100644 index 0000000..13ad785 --- /dev/null +++ b/ormar/queryset/prefetch_query.py @@ -0,0 +1,394 @@ +from typing import ( + Any, + Dict, + List, + Optional, + Sequence, + Set, + TYPE_CHECKING, + Tuple, + Type, + Union, +) + +import ormar +from ormar.fields import BaseField, ManyToManyField +from ormar.queryset.clause import QueryClause +from ormar.queryset.query import Query +from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict + +if TYPE_CHECKING: # pragma: no cover + from ormar import Model + + +def add_relation_field_to_fields( + fields: Union[Set[Any], Dict[Any, Any], None], related_field_name: str +) -> Union[Set[Any], Dict[Any, Any], None]: + if fields and related_field_name not in fields: + if isinstance(fields, dict): + fields[related_field_name] = ... 
+ elif isinstance(fields, set): + fields.add(related_field_name) + return fields + + +def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]: + sort_criteria = [ + (key, value) for key, value in orders_by.items() if isinstance(value, str) + ] + sort_criteria = sort_criteria[::-1] + for criteria in sort_criteria: + key, value = criteria + if value == "desc": + models.sort(key=lambda x: getattr(x, key), reverse=True) + else: + models.sort(key=lambda x: getattr(x, key)) + return models + + +def set_children_on_model( # noqa: CCR001 + model: "Model", + related: str, + children: Dict, + model_id: int, + models: Dict, + orders_by: Dict, +) -> None: + for key, child_models in children.items(): + if key == model_id: + models_to_set = [models[child] for child in sorted(child_models)] + if models_to_set: + if orders_by and any(isinstance(x, str) for x in orders_by.values()): + models_to_set = sort_models( + models=models_to_set, orders_by=orders_by + ) + for child in models_to_set: + setattr(model, related, child) + + +class PrefetchQuery: + def __init__( # noqa: CFQ002 + self, + model_cls: Type["Model"], + fields: Optional[Union[Dict, Set]], + exclude_fields: Optional[Union[Dict, Set]], + prefetch_related: List, + select_related: List, + orders_by: List, + ) -> None: + + self.model = model_cls + self.database = self.model.Meta.database + self._prefetch_related = prefetch_related + self._select_related = select_related + self._exclude_columns = exclude_fields + self._columns = fields + self.already_extracted: Dict = dict() + self.models: Dict = {} + self.select_dict = translate_list_to_dict(self._select_related) + self.orders_by = orders_by or [] + self.order_dict = translate_list_to_dict(self.orders_by, is_order=True) + + async def prefetch_related( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: + self.models = extract_models_to_dict_of_lists( + model_type=self.model, models=models, select_dict=self.select_dict + ) + self.models[self.model.get_name()] = models + return await self._prefetch_related_models(models=models, rows=rows) + + def _extract_ids_from_raw_data( + self, parent_model: Type["Model"], column_name: str + ) -> Set: + list_of_ids = set() + current_data = self.already_extracted.get(parent_model.get_name(), {}) + table_prefix = current_data.get("prefix", "") + column_name = (f"{table_prefix}_" if table_prefix else "") + column_name + for row in current_data.get("raw", []): + if row[column_name]: + list_of_ids.add(row[column_name]) + return list_of_ids + + def _extract_ids_from_preloaded_models( + self, parent_model: Type["Model"], column_name: str + ) -> Set: + list_of_ids = set() + for model in self.models.get(parent_model.get_name(), []): + child = getattr(model, column_name) + if isinstance(child, ormar.Model): + list_of_ids.add(child.pk) + else: + list_of_ids.add(child) + return list_of_ids + + def _extract_required_ids( + self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, + ) -> Set: + + use_raw = parent_model.get_name() not in self.models + + column_name = parent_model.get_column_name_for_id_extraction( + parent_model=parent_model, + target_model=target_model, + reverse=reverse, + use_raw=use_raw, + ) + + if use_raw: + return self._extract_ids_from_raw_data( + parent_model=parent_model, column_name=column_name + ) + + return self._extract_ids_from_preloaded_models( + parent_model=parent_model, column_name=column_name + ) + + def _get_filter_for_prefetch( + self, parent_model: Type["Model"], target_model: 
Type["Model"], reverse: bool, + ) -> List: + ids = self._extract_required_ids( + parent_model=parent_model, target_model=target_model, reverse=reverse, + ) + if ids: + ( + clause_target, + filter_column, + ) = parent_model.get_clause_target_and_filter_column_name( + parent_model=parent_model, target_model=target_model, reverse=reverse + ) + qryclause = QueryClause( + model_cls=clause_target, select_related=[], filter_clauses=[], + ) + kwargs = {f"{filter_column}__in": ids} + filter_clauses, _ = qryclause.filter(**kwargs) + return filter_clauses + return [] + + def _populate_nested_related( + self, model: "Model", prefetch_dict: Dict, orders_by: Dict, + ) -> "Model": + + related_to_extract = model.get_filtered_names_to_extract( + prefetch_dict=prefetch_dict + ) + + for related in related_to_extract: + target_field = model.Meta.model_fields[related] + target_model = target_field.to.get_name() + model_id = model.get_relation_model_id(target_field=target_field) + + if model_id is None: # pragma: no cover + continue + + field_name = model.get_related_field_name(target_field=target_field) + + children = self.already_extracted.get(target_model, {}).get(field_name, {}) + models = self.already_extracted.get(target_model, {}).get("pk_models", {}) + set_children_on_model( + model=model, + related=related, + children=children, + model_id=model_id, + models=models, + orders_by=orders_by.get(related, {}), + ) + + return model + + async def _prefetch_related_models( + self, models: Sequence["Model"], rows: List + ) -> Sequence["Model"]: + self.already_extracted = {self.model.get_name(): {"raw": rows}} + select_dict = translate_list_to_dict(self._select_related) + prefetch_dict = translate_list_to_dict(self._prefetch_related) + target_model = self.model + fields = self._columns + exclude_fields = self._exclude_columns + orders_by = self.order_dict + for related in prefetch_dict.keys(): + await self._extract_related_models( + related=related, + target_model=target_model, + prefetch_dict=prefetch_dict.get(related, {}), + select_dict=select_dict.get(related, {}), + fields=fields, + exclude_fields=exclude_fields, + orders_by=orders_by.get(related, {}), + ) + final_models = [] + for model in models: + final_models.append( + self._populate_nested_related( + model=model, prefetch_dict=prefetch_dict, orders_by=self.order_dict + ) + ) + return models + + async def _extract_related_models( # noqa: CFQ002, CCR001 + self, + related: str, + target_model: Type["Model"], + prefetch_dict: Dict, + select_dict: Dict, + fields: Union[Set[Any], Dict[Any, Any], None], + exclude_fields: Union[Set[Any], Dict[Any, Any], None], + orders_by: Dict, + ) -> None: + + fields = target_model.get_included(fields, related) + exclude_fields = target_model.get_excluded(exclude_fields, related) + target_field = target_model.Meta.model_fields[related] + reverse = False + if target_field.virtual or issubclass(target_field, ManyToManyField): + reverse = True + + parent_model = target_model + + filter_clauses = self._get_filter_for_prefetch( + parent_model=parent_model, target_model=target_field.to, reverse=reverse, + ) + if not filter_clauses: # related field is empty + return + + already_loaded = select_dict is Ellipsis or related in select_dict + + if not already_loaded: + # If not already loaded with select_related + related_field_name = parent_model.get_related_field_name( + target_field=target_field + ) + fields = add_relation_field_to_fields( + fields=fields, related_field_name=related_field_name + ) + table_prefix, rows = await 
self._run_prefetch_query( + target_field=target_field, + fields=fields, + exclude_fields=exclude_fields, + filter_clauses=filter_clauses, + ) + else: + rows = [] + table_prefix = "" + + if prefetch_dict and prefetch_dict is not Ellipsis: + for subrelated in prefetch_dict.keys(): + await self._extract_related_models( + related=subrelated, + target_model=target_field.to, + prefetch_dict=prefetch_dict.get(subrelated, {}), + select_dict=self._get_select_related_if_apply( + subrelated, select_dict + ), + fields=fields, + exclude_fields=exclude_fields, + orders_by=self._get_select_related_if_apply(subrelated, orders_by), + ) + + if not already_loaded: + self._populate_rows( + rows=rows, + parent_model=parent_model, + target_field=target_field, + table_prefix=table_prefix, + fields=fields, + exclude_fields=exclude_fields, + prefetch_dict=prefetch_dict, + orders_by=orders_by, + ) + else: + self._update_already_loaded_rows( + target_field=target_field, + prefetch_dict=prefetch_dict, + orders_by=orders_by, + ) + + async def _run_prefetch_query( + self, + target_field: Type["BaseField"], + fields: Union[Set[Any], Dict[Any, Any], None], + exclude_fields: Union[Set[Any], Dict[Any, Any], None], + filter_clauses: List, + ) -> Tuple[str, List]: + target_model = target_field.to + target_name = target_model.get_name() + select_related = [] + query_target = target_model + table_prefix = "" + if issubclass(target_field, ManyToManyField): + query_target = target_field.through + select_related = [target_name] + table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( + from_table=query_target.Meta.tablename, + to_table=target_field.to.Meta.tablename, + ) + self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix + + qry = Query( + model_cls=query_target, + select_related=select_related, + filter_clauses=filter_clauses, + exclude_clauses=[], + offset=None, + limit_count=None, + fields=fields, + exclude_fields=exclude_fields, + order_bys=None, + ) + expr = qry.build_select_expression() + # print(expr.compile(compile_kwargs={"literal_binds": True})) + rows = await self.database.fetch_all(expr) + self.already_extracted.setdefault(target_name, {}).update({"raw": rows}) + return table_prefix, rows + + @staticmethod + def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict: + return ( + select_dict.get(related, {}) + if (select_dict and select_dict is not Ellipsis and related in select_dict) + else {} + ) + + def _update_already_loaded_rows( # noqa: CFQ002 + self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict, + ) -> None: + target_model = target_field.to + for instance in self.models.get(target_model.get_name(), []): + self._populate_nested_related( + model=instance, prefetch_dict=prefetch_dict, orders_by=orders_by + ) + + def _populate_rows( # noqa: CFQ002 + self, + rows: List, + target_field: Type["BaseField"], + parent_model: Type["Model"], + table_prefix: str, + fields: Union[Set[Any], Dict[Any, Any], None], + exclude_fields: Union[Set[Any], Dict[Any, Any], None], + prefetch_dict: Dict, + orders_by: Dict, + ) -> None: + target_model = target_field.to + for row in rows: + field_name = parent_model.get_related_field_name(target_field=target_field) + item = target_model.extract_prefixed_table_columns( + item={}, + row=row, + table_prefix=table_prefix, + fields=fields, + exclude_fields=exclude_fields, + ) + instance = target_model(**item) + instance = self._populate_nested_related( + model=instance, prefetch_dict=prefetch_dict, 
orders_by=orders_by + ) + field_db_name = target_model.get_column_alias(field_name) + models = self.already_extracted[target_model.get_name()].setdefault( + "pk_models", {} + ) + if instance.pk not in models: + models[instance.pk] = instance + self.already_extracted[target_model.get_name()].setdefault( + field_name, dict() + ).setdefault(row[field_db_name], set()).add(instance.pk) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index b8b4aa3..02a0566 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -9,6 +9,7 @@ from ormar import MultipleMatches, NoMatch from ormar.exceptions import QueryDefinitionError from ormar.queryset import FilterQuery from ormar.queryset.clause import QueryClause +from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.query import Query from ormar.queryset.utils import update, update_dict_from_list @@ -30,11 +31,13 @@ class QuerySet: columns: Dict = None, exclude_columns: Dict = None, order_bys: List = None, + prefetch_related: List = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses self._select_related = [] if select_related is None else select_related + self._prefetch_related = [] if prefetch_related is None else prefetch_related self.limit_count = limit_count self.query_offset = offset self._columns = columns or {} @@ -48,8 +51,7 @@ class QuerySet: ) -> "QuerySet": if issubclass(owner, ormar.Model): return self.__class__(model_cls=owner) - else: # pragma nocover - return self.__class__() + return self.__class__() # pragma: no cover @property def model_meta(self) -> "ModelMeta": @@ -63,6 +65,19 @@ class QuerySet: raise ValueError("Model class of QuerySet is not initialized") return self.model_cls + async def _prefetch_related_models( + self, models: Sequence[Optional["Model"]], rows: List + ) -> Sequence[Optional["Model"]]: + query = PrefetchQuery( + model_cls=self.model, + fields=self._columns, + exclude_fields=self._exclude_columns, + prefetch_related=self._prefetch_related, + select_related=self._select_related, + orders_by=self.order_bys, + ) + return await query.prefetch_related(models=models, rows=rows) # type: ignore + def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: result_rows = [ self.model.from_row( @@ -148,6 +163,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003 @@ -168,6 +184,25 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, + ) + + def prefetch_related(self, related: Union[List, str]) -> "QuerySet": + if not isinstance(related, list): + related = [related] + + related = list(set(list(self._prefetch_related) + related)) + return self.__class__( + model_cls=self.model, + filter_clauses=self.filter_clauses, + exclude_clauses=self.exclude_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=self.query_offset, + columns=self._columns, + exclude_columns=self._exclude_columns, + order_bys=self.order_bys, + prefetch_related=related, ) def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": @@ -190,6 +225,7 @@ class QuerySet: columns=self._columns, exclude_columns=current_excluded, 
order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": @@ -212,6 +248,7 @@ class QuerySet: columns=current_included, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def order_by(self, columns: Union[List, str]) -> "QuerySet": @@ -229,6 +266,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=order_bys, + prefetch_related=self._prefetch_related, ) async def exists(self) -> bool: @@ -279,6 +317,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) def offset(self, offset: int) -> "QuerySet": @@ -292,6 +331,7 @@ class QuerySet: columns=self._columns, exclude_columns=self._exclude_columns, order_bys=self.order_bys, + prefetch_related=self._prefetch_related, ) async def first(self, **kwargs: Any) -> "Model": @@ -312,6 +352,8 @@ class QuerySet: rows = await self.database.fetch_all(expr) processed_rows = self._process_query_result_rows(rows) + if self._prefetch_related and processed_rows: + processed_rows = await self._prefetch_related_models(processed_rows, rows) self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore @@ -337,6 +379,8 @@ class QuerySet: expr = self.build_select_expression() rows = await self.database.fetch_all(expr) result_rows = self._process_query_result_rows(rows) + if self._prefetch_related and result_rows: + result_rows = await self._prefetch_related_models(result_rows, rows) return result_rows diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index c3c8fa9..bed2e25 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -1,6 +1,9 @@ import collections.abc import copy -from typing import Any, Dict, List, Set, Union +from typing import Any, Dict, List, Sequence, Set, TYPE_CHECKING, Type, Union + +if TYPE_CHECKING: # pragma no cover + from ormar import Model def check_node_not_dict_or_not_last_node( @@ -11,18 +14,28 @@ def check_node_not_dict_or_not_last_node( ) -def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001 +def translate_list_to_dict( # noqa: CCR001 + list_to_trans: Union[List, Set], is_order: bool = False +) -> Dict: new_dict: Dict = dict() for path in list_to_trans: current_level = new_dict parts = path.split("__") + def_val: Any = ... + if is_order: + if parts[0][0] == "-": + def_val = "desc" + parts[0] = parts[0][1:] + else: + def_val = "asc" + for part in parts: if check_node_not_dict_or_not_last_node( part=part, parts=parts, current_level=current_level ): current_level[part] = dict() elif part not in current_level: - current_level[part] = ... 
+ current_level[part] = def_val current_level = current_level[part] return new_dict @@ -55,3 +68,39 @@ def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> dict_to_update = translate_list_to_dict(list_to_update) update(updated_dict, dict_to_update) return updated_dict + + +def extract_nested_models( # noqa: CCR001 + model: "Model", model_type: Type["Model"], select_dict: Dict, extracted: Dict +) -> None: + follow = [rel for rel in model_type.extract_related_names() if rel in select_dict] + for related in follow: + child = getattr(model, related) + if child: + target_model = model_type.Meta.model_fields[related].to + if isinstance(child, list): + extracted.setdefault(target_model.get_name(), []).extend(child) + if select_dict[related] is not Ellipsis: + for sub_child in child: + extract_nested_models( + sub_child, target_model, select_dict[related], extracted, + ) + else: + extracted.setdefault(target_model.get_name(), []).append(child) + if select_dict[related] is not Ellipsis: + extract_nested_models( + child, target_model, select_dict[related], extracted, + ) + + +def extract_models_to_dict_of_lists( + model_type: Type["Model"], + models: Sequence["Model"], + select_dict: Dict, + extracted: Dict = None, +) -> Dict: + if not extracted: + extracted = dict() + for model in models: + extract_nested_models(model, model_type, select_dict, extracted) + return extracted diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 4dafef7..34fe974 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -57,10 +57,6 @@ class Organisation(ormar.Model): ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"]) -class Organization(object): - pass - - class Team(ormar.Model): class Meta: tablename = "teams" @@ -241,8 +237,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -250,8 +246,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -296,8 +292,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -329,8 +325,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py new file mode 100644 index 0000000..bfc09d7 --- /dev/null +++ b/tests/test_prefetch_related.py @@ -0,0 +1,311 @@ +from typing import List, Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class RandomSet(ormar.Model): + class Meta: + tablename = "randoms" + metadata = metadata + database = database + + id: int = ormar.Integer(name='random_id', primary_key=True) + name: str = ormar.String(max_length=100) + + +class Tonation(ormar.Model): + class Meta: + tablename = "tonations" + metadata = metadata + database = 
database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(name='tonation_name', max_length=100) + rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet) + + +class Division(ormar.Model): + class Meta: + tablename = "divisions" + metadata = metadata + database = database + + id: int = ormar.Integer(name='division_id', primary_key=True) + name: str = ormar.String(max_length=100, nullable=True) + + +class Shop(ormar.Model): + class Meta: + tablename = "shops" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=True) + division: Optional[Division] = ormar.ForeignKey(Division) + + +class AlbumShops(ormar.Model): + class Meta: + tablename = "albums_x_shops" + metadata = metadata + database = database + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=True) + shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops) + + +class Track(ormar.Model): + class Meta: + tablename = "tracks" + metadata = metadata + database = database + + id: int = ormar.Integer(name='track_id', primary_key=True) + album: Optional[Album] = ormar.ForeignKey(Album) + title: str = ormar.String(max_length=100) + position: int = ormar.Integer() + tonation: Optional[Tonation] = ormar.ForeignKey(Tonation, name='tonation_id') + + +class Cover(ormar.Model): + class Meta: + tablename = "covers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures", name='album_id') + title: str = ormar.String(max_length=100) + artist: str = ormar.String(max_length=200, nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_prefetch_related(): + async with database: + async with database.transaction(force_rollback=True): + album = Album(name="Malibu") + await album.save() + ton1 = await Tonation.objects.create(name='B-mol') + await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) + await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + + fantasies = Album(name="Fantasies") + await fantasies.save() + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + await Cover.objects.create(title='Cover3', album=fantasies, artist='Artist 3') + await Cover.objects.create(title='Cover4', album=fantasies, artist='Artist 4') + + album = await Album.objects.filter(name='Malibu').prefetch_related( + ['tracks__tonation', 'cover_pictures']).get() + assert len(album.tracks) == 3 + assert album.tracks[0].title == 'The Bird' + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].title == 'Cover1' + assert album.tracks[0].tonation.name == 
album.tracks[2].tonation.name == 'B-mol' + + albums = await Album.objects.prefetch_related('tracks').all() + assert len(albums[0].tracks) == 3 + assert len(albums[1].tracks) == 3 + assert albums[0].tracks[0].title == "The Bird" + assert albums[1].tracks[0].title == "Help I'm Alive" + + track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist == 'Artist 1' + + track = await Track.objects.prefetch_related(["album__cover_pictures"]).exclude_fields( + 'album__cover_pictures__artist').get(title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist is None + + tracks = await Track.objects.prefetch_related("album").all() + assert len(tracks) == 6 + + +@pytest.mark.asyncio +async def test_prefetch_related_with_many_to_many(): + async with database: + async with database.transaction(force_rollback=True): + div = await Division.objects.create(name='Div 1') + shop1 = await Shop.objects.create(name='Shop 1', division=div) + shop2 = await Shop.objects.create(name='Shop 2', division=div) + album = Album(name="Malibu") + await album.save() + await album.shops.add(shop1) + await album.shops.add(shop2) + + await Track.objects.create(album=album, title="The Bird", position=1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2) + await Track.objects.create(album=album, title="The Waters", position=3) + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + + track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get( + title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist == 'Artist 1' + + assert len(track.album.shops) == 2 + assert track.album.shops[0].name == 'Shop 1' + assert track.album.shops[0].division.name == 'Div 1' + + album2 = Album(name="Malibu 2") + await album2.save() + await album2.shops.add(shop1) + await album2.shops.add(shop2) + await Track.objects.create(album=album2, title="The Bird 2", position=1) + + tracks = await Track.objects.prefetch_related(["album__shops"]).all() + assert tracks[0].album.name == 'Malibu' + assert tracks[0].album.shops[0].name == "Shop 1" + assert tracks[3].album.name == 'Malibu 2' + assert tracks[3].album.shops[0].name == "Shop 1" + + assert tracks[0].album.shops[0] == tracks[3].album.shops[0] + assert id(tracks[0].album.shops[0]) == id(tracks[3].album.shops[0]) + tracks[0].album.shops[0].name = 'Dummy' + assert tracks[0].album.shops[0].name == tracks[3].album.shops[0].name + + +@pytest.mark.asyncio +async def test_prefetch_related_empty(): + async with database: + async with database.transaction(force_rollback=True): + await Track.objects.create(title="The Bird", position=1) + track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") + assert track.title == 'The Bird' + assert track.album is None + + +@pytest.mark.asyncio +async def test_prefetch_related_with_select_related(): + async with database: + async with database.transaction(force_rollback=True): + div = await Division.objects.create(name='Div 1') + shop1 = await Shop.objects.create(name='Shop 1', division=div) + shop2 = await 
Shop.objects.create(name='Shop 2', division=div) + album = Album(name="Malibu") + await album.save() + await album.shops.add(shop1) + await album.shops.add(shop2) + + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + + album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).get() + + assert len(album.tracks) == 0 + assert len(album.cover_pictures) == 2 + assert album.shops[0].division.name == 'Div 1' + + rand_set = await RandomSet.objects.create(name='Rand 1') + ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) + await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) + await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + + album = await Album.objects.select_related('tracks__tonation__rand_set').filter( + name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).order_by( + ['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get() + assert len(album.tracks) == 3 + assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].artist == 'Artist 2' + + assert len(album.shops) == 2 + assert album.shops[0].name == 'Shop 2' + assert album.shops[0].division.name == 'Div 1' + + track = await Track.objects.select_related('album').prefetch_related( + ["album__cover_pictures", "album__shops__division"]).get( + title="The Bird") + assert track.album.name == "Malibu" + assert len(track.album.cover_pictures) == 2 + assert track.album.cover_pictures[0].artist == 'Artist 1' + + assert len(track.album.shops) == 2 + assert track.album.shops[0].name == 'Shop 1' + assert track.album.shops[0].division.name == 'Div 1' + + +@pytest.mark.asyncio +async def test_prefetch_related_with_select_related_and_fields(): + async with database: + async with database.transaction(force_rollback=True): + div = await Division.objects.create(name='Div 1') + shop1 = await Shop.objects.create(name='Shop 1', division=div) + shop2 = await Shop.objects.create(name='Shop 2', division=div) + album = Album(name="Malibu") + await album.save() + await album.shops.add(shop1) + await album.shops.add(shop2) + await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') + await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + rand_set = await RandomSet.objects.create(name='Rand 1') + ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) + await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) + await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) + await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + + album = await Album.objects.select_related('tracks__tonation__rand_set').filter( + name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).exclude_fields({'shops': {'division': {'name'}}}).get() + assert len(album.tracks) == 3 + assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].artist == 'Artist 1' + + assert len(album.shops) == 2 + assert album.shops[0].name == 'Shop 1' + assert 
album.shops[0].division.name is None + + album = await Album.objects.select_related('tracks').filter( + name='Malibu').prefetch_related( + ['cover_pictures', 'shops__division']).fields( + {'name': ..., 'shops': {'division'}, 'cover_pictures': {'id': ..., 'title': ...}} + ).exclude_fields({'shops': {'division': {'name'}}}).get() + assert len(album.tracks) == 3 + assert len(album.cover_pictures) == 2 + assert album.cover_pictures[0].artist is None + assert album.cover_pictures[0].title is not None + + assert len(album.shops) == 2 + assert album.shops[0].name is None + assert album.shops[0].division is not None + assert album.shops[0].division.name is None diff --git a/tests/test_queryset_utils.py b/tests/test_queryset_utils.py index bac0a26..97695b9 100644 --- a/tests/test_queryset_utils.py +++ b/tests/test_queryset_utils.py @@ -1,5 +1,11 @@ +import databases +import sqlalchemy + +import ormar from ormar.models.excludable import Excludable +from ormar.queryset.prefetch_query import sort_models from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update +from tests.settings import DATABASE_URL def test_empty_excludable(): @@ -96,3 +102,43 @@ def test_updating_dict_inc_set_with_dict_inc_set(): "cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis}, "uu": Ellipsis, } + + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class SortModel(ormar.Model): + class Meta: + tablename = "sorts" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + sort_order: int = ormar.Integer() + + +def test_sorting_models(): + models = [ + SortModel(id=1, name='Alice', sort_order=0), + SortModel(id=2, name='Al', sort_order=1), + SortModel(id=3, name='Zake', sort_order=1), + SortModel(id=4, name='Will', sort_order=0), + SortModel(id=5, name='Al', sort_order=2), + SortModel(id=6, name='Alice', sort_order=2) + ] + orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'} + models = sort_models(models, orders_by) + assert models[5].name == 'Zake' + assert models[0].name == 'Al' + assert models[1].name == 'Al' + assert [model.id for model in models] == [5, 2, 6, 1, 4, 3] + + orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'} + models = sort_models(models, orders_by) + assert [model.id for model in models] == [2, 5, 1, 6, 4, 3] + + orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None} + models = sort_models(models, orders_by) + assert [model.id for model in models] == [1, 4, 2, 3, 5, 6]
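As a side note, a small illustrative sketch (not part of the diff) of how `order_by` strings are expected to be translated into the `orders_by` dict that `sort_models` consumes, based on the `translate_list_to_dict(..., is_order=True)` change above:

```python
from ormar.queryset.utils import translate_list_to_dict

# a leading "-" on the first segment marks descending order, otherwise "asc" is used
assert translate_list_to_dict(["name", "-sort_order"], is_order=True) == {
    "name": "asc",
    "sort_order": "desc",
}

# nested paths keep the dict structure, with the direction stored at the leaf
assert translate_list_to_dict(["-shops__name"], is_order=True) == {"shops": {"name": "desc"}}
```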