From 4e27f07a7e06ca838adddeb01fa9ce501a1dd96e Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 4 Mar 2021 13:12:07 +0100 Subject: [PATCH] som types fixes, fix for wrong prefixes in model_row for complex relations, test load_all with repeating tables, add docs --- docs/releases.md | 4 +++ ormar/fields/foreign_key.py | 4 +-- ormar/models/helpers/models.py | 7 ++-- ormar/models/mixins/relation_mixin.py | 40 +++++++++++++++++------ ormar/models/model.py | 16 +++++---- ormar/models/model_row.py | 25 +++++++++++--- ormar/queryset/queryset.py | 2 +- ormar/relations/relation.py | 2 +- tests/test_excluding_fields_in_fastapi.py | 3 +- tests/test_m2m_through_fields.py | 2 +- tests/test_more_same_table_joins.py | 14 ++++++++ 11 files changed, 90 insertions(+), 29 deletions(-) diff --git a/docs/releases.md b/docs/releases.md index ef6342a..2671b3d 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -15,6 +15,10 @@ in `ManyToMany` relations and in reverse `ForeignKey` relations. Note that update like in `QuerySet` `update` returns number of updated models and **does not update related models in place** on parent model. To get the refreshed data on parent model you need to refresh the related models (i.e. `await model_instance.related.all()`) +* Add `load_all(follow=False, exclude=None)` model method that allows to load current instance of the model + with all related models in one call. By default it loads only directly related models but setting + `follow=True` causes traversing the tree (avoiding loops). You can also pass `exclude` parameter + that works the same as `QuerySet.exclude_fields()` method. * Added possibility to add more fields on `Through` model for `ManyToMany` relationships: * name of the through model field is the lowercase name of the Through class * you can pass additional fields when calling `add(child, **kwargs)` on relation (on `QuerysetProxy`) diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 643b88d..65a9df0 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -460,7 +460,7 @@ class ForeignKeyField(BaseField): return model @classmethod - def get_relation_name(cls) -> str: + def get_relation_name(cls) -> str: # pragma: no cover """ Returns name of the relation, which can be a own name or through model names for m2m models @@ -471,7 +471,7 @@ class ForeignKeyField(BaseField): return cls.name @classmethod - def get_source_model(cls) -> Type["Model"]: + def get_source_model(cls) -> Type["Model"]: # pragma: no cover """ Returns model from which the relation comes -> either owner or through model diff --git a/ormar/models/helpers/models.py b/ormar/models/helpers/models.py index 6d67e91..e0b5d3c 100644 --- a/ormar/models/helpers/models.py +++ b/ormar/models/helpers/models.py @@ -1,3 +1,4 @@ +import collections import itertools import sqlite3 from typing import Any, Dict, List, TYPE_CHECKING, Tuple, Type @@ -123,7 +124,7 @@ def extract_annotations_and_default_vals(attrs: Dict) -> Tuple[Dict, Dict]: return attrs, model_fields -def group_related_list(list_: List) -> Dict: +def group_related_list(list_: List) -> collections.OrderedDict: """ Translates the list of related strings into a dictionary. That way nested models are grouped to traverse them in a right order @@ -152,7 +153,9 @@ def group_related_list(list_: List) -> Dict: result_dict[key] = group_related_list(new) else: result_dict.setdefault(key, []).extend(new) - return {k: v for k, v in sorted(result_dict.items(), key=lambda item: len(item[1]))} + return collections.OrderedDict( + sorted(result_dict.items(), key=lambda item: len(item[1])) + ) def meta_field_not_set(model: Type["Model"], field_name: str) -> bool: diff --git a/ormar/models/mixins/relation_mixin.py b/ormar/models/mixins/relation_mixin.py index 4fdd9e7..53fab3c 100644 --- a/ormar/models/mixins/relation_mixin.py +++ b/ormar/models/mixins/relation_mixin.py @@ -1,5 +1,13 @@ import inspect -from typing import List, Optional, Set, TYPE_CHECKING, Type, Union +from typing import ( + Callable, + List, + Optional, + Set, + TYPE_CHECKING, + Type, + Union, +) class RelationMixin: @@ -13,6 +21,7 @@ class RelationMixin: Meta: ModelMeta _related_names: Optional[Set] _related_fields: Optional[List] + get_name: Callable @classmethod def extract_db_own_fields(cls) -> Set: @@ -122,7 +131,8 @@ class RelationMixin: @classmethod def _iterate_related_models( cls, - visited: Set[Union[Type["Model"], Type["RelationMixin"]]] = None, + visited: Set[str] = None, + source_visited: Set[str] = None, source_relation: str = None, source_model: Union[Type["Model"], Type["RelationMixin"]] = None, ) -> List[str]: @@ -139,22 +149,24 @@ class RelationMixin: :return: list of relation strings to be passed to select_related :rtype: List[str] """ - visited = visited or set() - visited.add(cls) + source_visited = source_visited or set() + if not source_model: + source_visited = cls._populate_source_model_prefixes() relations = cls.extract_related_names() processed_relations = [] for relation in relations: target_model = cls.Meta.model_fields[relation].to if source_model and target_model == source_model: continue - if target_model not in visited: - visited.add(target_model) + if target_model not in source_visited or not source_model: deep_relations = target_model._iterate_related_models( - visited=visited, source_relation=relation, source_model=cls + visited=visited, + source_visited=source_visited, + source_relation=relation, + source_model=cls, ) processed_relations.extend(deep_relations) - # TODO add test for circular deps - else: # pragma: no cover + else: processed_relations.append(relation) if processed_relations: final_relations = [ @@ -163,5 +175,13 @@ class RelationMixin: ] else: final_relations = [source_relation] if source_relation else [] - return final_relations + + @classmethod + def _populate_source_model_prefixes(cls) -> Set: + relations = cls.extract_related_names() + visited = {cls} + for relation in relations: + target_model = cls.Meta.model_fields[relation].to + visited.add(target_model) + return visited diff --git a/ormar/models/model.py b/ormar/models/model.py index 26ef420..48b9f58 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -5,6 +5,7 @@ from typing import ( Set, TYPE_CHECKING, Tuple, + TypeVar, Union, ) @@ -17,6 +18,8 @@ from ormar.models.model_row import ModelRow if TYPE_CHECKING: # pragma nocover from ormar import QuerySet +T = TypeVar("T", bound="Model") + class Model(ModelRow): __abstract__ = False @@ -28,7 +31,7 @@ class Model(ModelRow): _repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()} return f"{self.__class__.__name__}({str(_repr)})" - async def upsert(self, **kwargs: Any) -> "Model": + async def upsert(self: T, **kwargs: Any) -> T: """ Performs either a save or an update depending on the presence of the pk. If the pk field is filled it's an update, otherwise the save is performed. @@ -43,7 +46,7 @@ class Model(ModelRow): return await self.save() return await self.update(**kwargs) - async def save(self) -> "Model": + async def save(self: T) -> T: """ Performs a save of given Model instance. If primary key is already saved, db backend will throw integrity error. @@ -189,7 +192,7 @@ class Model(ModelRow): update_count += 1 return update_count, visited - async def update(self, **kwargs: Any) -> "Model": + async def update(self: T, **kwargs: Any) -> T: """ Performs update of Model instance in the database. Fields can be updated before or you can pass them as kwargs. @@ -248,7 +251,7 @@ class Model(ModelRow): await self.signals.post_delete.send(sender=self.__class__, instance=self) return result - async def load(self) -> "Model": + async def load(self: T) -> T: """ Allow to refresh existing Models fields from database. Be careful as the related models can be overwritten by pk_only models in load. @@ -270,8 +273,8 @@ class Model(ModelRow): return self async def load_all( - self, follow: bool = False, exclude: Union[List, str, Set, Dict] = None - ) -> "Model": + self: T, follow: bool = False, exclude: Union[List, str, Set, Dict] = None + ) -> T: """ Allow to refresh existing Models fields from database. Performs refresh of the related models fields. @@ -303,7 +306,6 @@ class Model(ModelRow): if follow: relations = self._iterate_related_models() queryset = self.__class__.objects - print(relations) if exclude: queryset = queryset.exclude_fields(exclude) instance = await queryset.select_related(relations).get(pk=self.pk) diff --git a/ormar/models/model_row.py b/ormar/models/model_row.py index a14d418..d9c674d 100644 --- a/ormar/models/model_row.py +++ b/ormar/models/model_row.py @@ -31,6 +31,7 @@ class ModelRow(NewBaseModel): excludable: ExcludableItems = None, current_relation_str: str = "", proxy_source_model: Optional[Type["Model"]] = None, + used_prefixes: List[str] = None, ) -> Optional["Model"]: """ Model method to convert raw sql row from database into ormar.Model instance. @@ -45,6 +46,8 @@ class ModelRow(NewBaseModel): where rows are populated in a different way as they do not have nested models in result. + :param used_prefixes: list of already extracted prefixes + :type used_prefixes: List[str] :param proxy_source_model: source model from which querysetproxy is constructed :type proxy_source_model: Optional[Type["ModelRow"]] :param excludable: structure of fields to include and exclude @@ -68,17 +71,28 @@ class ModelRow(NewBaseModel): select_related = select_related or [] related_models = related_models or [] table_prefix = "" + used_prefixes = used_prefixes if used_prefixes is not None else [] excludable = excludable or ExcludableItems() if select_related: related_models = group_related_list(select_related) if related_field: - table_prefix = cls.Meta.alias_manager.resolve_relation_alias_after_complex( - source_model=source_model, - relation_str=current_relation_str, - relation_field=related_field, + if related_field.is_multi: + previous_model = related_field.through + else: + previous_model = related_field.owner + table_prefix = cls.Meta.alias_manager.resolve_relation_alias( + from_model=previous_model, relation_name=related_field.name ) + if not table_prefix or table_prefix in used_prefixes: + manager = cls.Meta.alias_manager + table_prefix = manager.resolve_relation_alias_after_complex( + source_model=source_model, + relation_str=current_relation_str, + relation_field=related_field, + ) + used_prefixes.append(table_prefix) item = cls._populate_nested_models_from_row( item=item, @@ -89,6 +103,7 @@ class ModelRow(NewBaseModel): source_model=source_model, # type: ignore proxy_source_model=proxy_source_model, # type: ignore table_prefix=table_prefix, + used_prefixes=used_prefixes, ) item = cls.extract_prefixed_table_columns( item=item, row=row, table_prefix=table_prefix, excludable=excludable @@ -112,6 +127,7 @@ class ModelRow(NewBaseModel): related_models: Any, excludable: ExcludableItems, table_prefix: str, + used_prefixes: List[str], current_relation_str: str = None, proxy_source_model: Type["Model"] = None, ) -> dict: @@ -170,6 +186,7 @@ class ModelRow(NewBaseModel): current_relation_str=relation_str, source_model=source_model, proxy_source_model=proxy_source_model, + used_prefixes=used_prefixes, ) item[model_cls.get_column_name_from_alias(related)] = child if field.is_multi and child: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 8adc4f3..d0679f5 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -344,7 +344,7 @@ class QuerySet: if not isinstance(related, list): related = [related] - related = list(set(list(self._select_related) + related)) + related = sorted(list(set(list(self._select_related) + related))) return self.rebuild_self(select_related=related,) def prefetch_related(self, related: Union[List, str]) -> "QuerySet": diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py index 0f9be3d..bb7abd1 100644 --- a/ormar/relations/relation.py +++ b/ormar/relations/relation.py @@ -74,7 +74,7 @@ class Relation: self._owner.__dict__[self.field_name] = None elif self.related_models is not None: self.related_models._clear() - self._owner.__dict__[self.field_name] = [] + self._owner.__dict__[self.field_name] = None @property def through(self) -> Type["Model"]: diff --git a/tests/test_excluding_fields_in_fastapi.py b/tests/test_excluding_fields_in_fastapi.py index dd4dd84..1f0950f 100644 --- a/tests/test_excluding_fields_in_fastapi.py +++ b/tests/test_excluding_fields_in_fastapi.py @@ -124,7 +124,8 @@ async def create_user(user: User): @app.post("/users2/", response_model=User) async def create_user2(user: User): - return (await user.save()).dict(exclude={"password"}) + user = await user.save() + return user.dict(exclude={"password"}) @app.post("/users3/", response_model=UserBase) diff --git a/tests/test_m2m_through_fields.py b/tests/test_m2m_through_fields.py index 3f600e0..4ed7023 100644 --- a/tests/test_m2m_through_fields.py +++ b/tests/test_m2m_through_fields.py @@ -1,4 +1,4 @@ -from typing import Any, List, Sequence, cast +from typing import Any, Sequence, cast import databases import pytest diff --git a/tests/test_more_same_table_joins.py b/tests/test_more_same_table_joins.py index 9dc086e..b991d13 100644 --- a/tests/test_more_same_table_joins.py +++ b/tests/test_more_same_table_joins.py @@ -108,3 +108,17 @@ async def test_model_multiple_instances_of_same_table_in_schema(): assert len(classes[0].dict().get("students")) == 2 assert classes[0].teachers[0].category.department.name == "Law Department" assert classes[0].students[0].category.department.name == "Math Department" + + +@pytest.mark.asyncio +async def test_load_all_multiple_instances_of_same_table_in_schema(): + async with database: + await create_data() + math_class = await SchoolClass.objects.get(name="Math") + assert math_class.name == "Math" + + await math_class.load_all(follow=True) + assert math_class.students[0].name == "Jane" + assert len(math_class.dict().get("students")) == 2 + assert math_class.teachers[0].category.department.name == "Law Department" + assert math_class.students[0].category.department.name == "Math Department"