som types fixes, fix for wrong prefixes in model_row for complex relations, test load_all with repeating tables, add docs

This commit is contained in:
collerek
2021-03-04 13:12:07 +01:00
parent a8ae50276e
commit 4e27f07a7e
11 changed files with 90 additions and 29 deletions

View File

@ -15,6 +15,10 @@
in `ManyToMany` relations and in reverse `ForeignKey` relations. Note that update like in `QuerySet` `update` returns number of in `ManyToMany` relations and in reverse `ForeignKey` relations. Note that update like in `QuerySet` `update` returns number of
updated models and **does not update related models in place** on parent model. To get the refreshed data on parent model you need to refresh updated models and **does not update related models in place** on parent model. To get the refreshed data on parent model you need to refresh
the related models (i.e. `await model_instance.related.all()`) the related models (i.e. `await model_instance.related.all()`)
* Add `load_all(follow=False, exclude=None)` model method that allows to load current instance of the model
with all related models in one call. By default it loads only directly related models but setting
`follow=True` causes traversing the tree (avoiding loops). You can also pass `exclude` parameter
that works the same as `QuerySet.exclude_fields()` method.
* Added possibility to add more fields on `Through` model for `ManyToMany` relationships: * Added possibility to add more fields on `Through` model for `ManyToMany` relationships:
* name of the through model field is the lowercase name of the Through class * name of the through model field is the lowercase name of the Through class
* you can pass additional fields when calling `add(child, **kwargs)` on relation (on `QuerysetProxy`) * you can pass additional fields when calling `add(child, **kwargs)` on relation (on `QuerysetProxy`)

View File

@ -460,7 +460,7 @@ class ForeignKeyField(BaseField):
return model return model
@classmethod @classmethod
def get_relation_name(cls) -> str: def get_relation_name(cls) -> str: # pragma: no cover
""" """
Returns name of the relation, which can be a own name or through model Returns name of the relation, which can be a own name or through model
names for m2m models names for m2m models
@ -471,7 +471,7 @@ class ForeignKeyField(BaseField):
return cls.name return cls.name
@classmethod @classmethod
def get_source_model(cls) -> Type["Model"]: def get_source_model(cls) -> Type["Model"]: # pragma: no cover
""" """
Returns model from which the relation comes -> either owner or through model Returns model from which the relation comes -> either owner or through model

View File

@ -1,3 +1,4 @@
import collections
import itertools import itertools
import sqlite3 import sqlite3
from typing import Any, Dict, List, TYPE_CHECKING, Tuple, Type from typing import Any, Dict, List, TYPE_CHECKING, Tuple, Type
@ -123,7 +124,7 @@ def extract_annotations_and_default_vals(attrs: Dict) -> Tuple[Dict, Dict]:
return attrs, model_fields return attrs, model_fields
def group_related_list(list_: List) -> Dict: def group_related_list(list_: List) -> collections.OrderedDict:
""" """
Translates the list of related strings into a dictionary. Translates the list of related strings into a dictionary.
That way nested models are grouped to traverse them in a right order That way nested models are grouped to traverse them in a right order
@ -152,7 +153,9 @@ def group_related_list(list_: List) -> Dict:
result_dict[key] = group_related_list(new) result_dict[key] = group_related_list(new)
else: else:
result_dict.setdefault(key, []).extend(new) result_dict.setdefault(key, []).extend(new)
return {k: v for k, v in sorted(result_dict.items(), key=lambda item: len(item[1]))} return collections.OrderedDict(
sorted(result_dict.items(), key=lambda item: len(item[1]))
)
def meta_field_not_set(model: Type["Model"], field_name: str) -> bool: def meta_field_not_set(model: Type["Model"], field_name: str) -> bool:

View File

@ -1,5 +1,13 @@
import inspect import inspect
from typing import List, Optional, Set, TYPE_CHECKING, Type, Union from typing import (
Callable,
List,
Optional,
Set,
TYPE_CHECKING,
Type,
Union,
)
class RelationMixin: class RelationMixin:
@ -13,6 +21,7 @@ class RelationMixin:
Meta: ModelMeta Meta: ModelMeta
_related_names: Optional[Set] _related_names: Optional[Set]
_related_fields: Optional[List] _related_fields: Optional[List]
get_name: Callable
@classmethod @classmethod
def extract_db_own_fields(cls) -> Set: def extract_db_own_fields(cls) -> Set:
@ -122,7 +131,8 @@ class RelationMixin:
@classmethod @classmethod
def _iterate_related_models( def _iterate_related_models(
cls, cls,
visited: Set[Union[Type["Model"], Type["RelationMixin"]]] = None, visited: Set[str] = None,
source_visited: Set[str] = None,
source_relation: str = None, source_relation: str = None,
source_model: Union[Type["Model"], Type["RelationMixin"]] = None, source_model: Union[Type["Model"], Type["RelationMixin"]] = None,
) -> List[str]: ) -> List[str]:
@ -139,22 +149,24 @@ class RelationMixin:
:return: list of relation strings to be passed to select_related :return: list of relation strings to be passed to select_related
:rtype: List[str] :rtype: List[str]
""" """
visited = visited or set() source_visited = source_visited or set()
visited.add(cls) if not source_model:
source_visited = cls._populate_source_model_prefixes()
relations = cls.extract_related_names() relations = cls.extract_related_names()
processed_relations = [] processed_relations = []
for relation in relations: for relation in relations:
target_model = cls.Meta.model_fields[relation].to target_model = cls.Meta.model_fields[relation].to
if source_model and target_model == source_model: if source_model and target_model == source_model:
continue continue
if target_model not in visited: if target_model not in source_visited or not source_model:
visited.add(target_model)
deep_relations = target_model._iterate_related_models( deep_relations = target_model._iterate_related_models(
visited=visited, source_relation=relation, source_model=cls visited=visited,
source_visited=source_visited,
source_relation=relation,
source_model=cls,
) )
processed_relations.extend(deep_relations) processed_relations.extend(deep_relations)
# TODO add test for circular deps else:
else: # pragma: no cover
processed_relations.append(relation) processed_relations.append(relation)
if processed_relations: if processed_relations:
final_relations = [ final_relations = [
@ -163,5 +175,13 @@ class RelationMixin:
] ]
else: else:
final_relations = [source_relation] if source_relation else [] final_relations = [source_relation] if source_relation else []
return final_relations return final_relations
@classmethod
def _populate_source_model_prefixes(cls) -> Set:
relations = cls.extract_related_names()
visited = {cls}
for relation in relations:
target_model = cls.Meta.model_fields[relation].to
visited.add(target_model)
return visited

View File

@ -5,6 +5,7 @@ from typing import (
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Tuple, Tuple,
TypeVar,
Union, Union,
) )
@ -17,6 +18,8 @@ from ormar.models.model_row import ModelRow
if TYPE_CHECKING: # pragma nocover if TYPE_CHECKING: # pragma nocover
from ormar import QuerySet from ormar import QuerySet
T = TypeVar("T", bound="Model")
class Model(ModelRow): class Model(ModelRow):
__abstract__ = False __abstract__ = False
@ -28,7 +31,7 @@ class Model(ModelRow):
_repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()} _repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()}
return f"{self.__class__.__name__}({str(_repr)})" return f"{self.__class__.__name__}({str(_repr)})"
async def upsert(self, **kwargs: Any) -> "Model": async def upsert(self: T, **kwargs: Any) -> T:
""" """
Performs either a save or an update depending on the presence of the pk. Performs either a save or an update depending on the presence of the pk.
If the pk field is filled it's an update, otherwise the save is performed. If the pk field is filled it's an update, otherwise the save is performed.
@ -43,7 +46,7 @@ class Model(ModelRow):
return await self.save() return await self.save()
return await self.update(**kwargs) return await self.update(**kwargs)
async def save(self) -> "Model": async def save(self: T) -> T:
""" """
Performs a save of given Model instance. Performs a save of given Model instance.
If primary key is already saved, db backend will throw integrity error. If primary key is already saved, db backend will throw integrity error.
@ -189,7 +192,7 @@ class Model(ModelRow):
update_count += 1 update_count += 1
return update_count, visited return update_count, visited
async def update(self, **kwargs: Any) -> "Model": async def update(self: T, **kwargs: Any) -> T:
""" """
Performs update of Model instance in the database. Performs update of Model instance in the database.
Fields can be updated before or you can pass them as kwargs. Fields can be updated before or you can pass them as kwargs.
@ -248,7 +251,7 @@ class Model(ModelRow):
await self.signals.post_delete.send(sender=self.__class__, instance=self) await self.signals.post_delete.send(sender=self.__class__, instance=self)
return result return result
async def load(self) -> "Model": async def load(self: T) -> T:
""" """
Allow to refresh existing Models fields from database. Allow to refresh existing Models fields from database.
Be careful as the related models can be overwritten by pk_only models in load. Be careful as the related models can be overwritten by pk_only models in load.
@ -270,8 +273,8 @@ class Model(ModelRow):
return self return self
async def load_all( async def load_all(
self, follow: bool = False, exclude: Union[List, str, Set, Dict] = None self: T, follow: bool = False, exclude: Union[List, str, Set, Dict] = None
) -> "Model": ) -> T:
""" """
Allow to refresh existing Models fields from database. Allow to refresh existing Models fields from database.
Performs refresh of the related models fields. Performs refresh of the related models fields.
@ -303,7 +306,6 @@ class Model(ModelRow):
if follow: if follow:
relations = self._iterate_related_models() relations = self._iterate_related_models()
queryset = self.__class__.objects queryset = self.__class__.objects
print(relations)
if exclude: if exclude:
queryset = queryset.exclude_fields(exclude) queryset = queryset.exclude_fields(exclude)
instance = await queryset.select_related(relations).get(pk=self.pk) instance = await queryset.select_related(relations).get(pk=self.pk)

View File

@ -31,6 +31,7 @@ class ModelRow(NewBaseModel):
excludable: ExcludableItems = None, excludable: ExcludableItems = None,
current_relation_str: str = "", current_relation_str: str = "",
proxy_source_model: Optional[Type["Model"]] = None, proxy_source_model: Optional[Type["Model"]] = None,
used_prefixes: List[str] = None,
) -> Optional["Model"]: ) -> Optional["Model"]:
""" """
Model method to convert raw sql row from database into ormar.Model instance. Model method to convert raw sql row from database into ormar.Model instance.
@ -45,6 +46,8 @@ class ModelRow(NewBaseModel):
where rows are populated in a different way as they do not have where rows are populated in a different way as they do not have
nested models in result. nested models in result.
:param used_prefixes: list of already extracted prefixes
:type used_prefixes: List[str]
:param proxy_source_model: source model from which querysetproxy is constructed :param proxy_source_model: source model from which querysetproxy is constructed
:type proxy_source_model: Optional[Type["ModelRow"]] :type proxy_source_model: Optional[Type["ModelRow"]]
:param excludable: structure of fields to include and exclude :param excludable: structure of fields to include and exclude
@ -68,17 +71,28 @@ class ModelRow(NewBaseModel):
select_related = select_related or [] select_related = select_related or []
related_models = related_models or [] related_models = related_models or []
table_prefix = "" table_prefix = ""
used_prefixes = used_prefixes if used_prefixes is not None else []
excludable = excludable or ExcludableItems() excludable = excludable or ExcludableItems()
if select_related: if select_related:
related_models = group_related_list(select_related) related_models = group_related_list(select_related)
if related_field: if related_field:
table_prefix = cls.Meta.alias_manager.resolve_relation_alias_after_complex( if related_field.is_multi:
previous_model = related_field.through
else:
previous_model = related_field.owner
table_prefix = cls.Meta.alias_manager.resolve_relation_alias(
from_model=previous_model, relation_name=related_field.name
)
if not table_prefix or table_prefix in used_prefixes:
manager = cls.Meta.alias_manager
table_prefix = manager.resolve_relation_alias_after_complex(
source_model=source_model, source_model=source_model,
relation_str=current_relation_str, relation_str=current_relation_str,
relation_field=related_field, relation_field=related_field,
) )
used_prefixes.append(table_prefix)
item = cls._populate_nested_models_from_row( item = cls._populate_nested_models_from_row(
item=item, item=item,
@ -89,6 +103,7 @@ class ModelRow(NewBaseModel):
source_model=source_model, # type: ignore source_model=source_model, # type: ignore
proxy_source_model=proxy_source_model, # type: ignore proxy_source_model=proxy_source_model, # type: ignore
table_prefix=table_prefix, table_prefix=table_prefix,
used_prefixes=used_prefixes,
) )
item = cls.extract_prefixed_table_columns( item = cls.extract_prefixed_table_columns(
item=item, row=row, table_prefix=table_prefix, excludable=excludable item=item, row=row, table_prefix=table_prefix, excludable=excludable
@ -112,6 +127,7 @@ class ModelRow(NewBaseModel):
related_models: Any, related_models: Any,
excludable: ExcludableItems, excludable: ExcludableItems,
table_prefix: str, table_prefix: str,
used_prefixes: List[str],
current_relation_str: str = None, current_relation_str: str = None,
proxy_source_model: Type["Model"] = None, proxy_source_model: Type["Model"] = None,
) -> dict: ) -> dict:
@ -170,6 +186,7 @@ class ModelRow(NewBaseModel):
current_relation_str=relation_str, current_relation_str=relation_str,
source_model=source_model, source_model=source_model,
proxy_source_model=proxy_source_model, proxy_source_model=proxy_source_model,
used_prefixes=used_prefixes,
) )
item[model_cls.get_column_name_from_alias(related)] = child item[model_cls.get_column_name_from_alias(related)] = child
if field.is_multi and child: if field.is_multi and child:

View File

@ -344,7 +344,7 @@ class QuerySet:
if not isinstance(related, list): if not isinstance(related, list):
related = [related] related = [related]
related = list(set(list(self._select_related) + related)) related = sorted(list(set(list(self._select_related) + related)))
return self.rebuild_self(select_related=related,) return self.rebuild_self(select_related=related,)
def prefetch_related(self, related: Union[List, str]) -> "QuerySet": def prefetch_related(self, related: Union[List, str]) -> "QuerySet":

View File

@ -74,7 +74,7 @@ class Relation:
self._owner.__dict__[self.field_name] = None self._owner.__dict__[self.field_name] = None
elif self.related_models is not None: elif self.related_models is not None:
self.related_models._clear() self.related_models._clear()
self._owner.__dict__[self.field_name] = [] self._owner.__dict__[self.field_name] = None
@property @property
def through(self) -> Type["Model"]: def through(self) -> Type["Model"]:

View File

@ -124,7 +124,8 @@ async def create_user(user: User):
@app.post("/users2/", response_model=User) @app.post("/users2/", response_model=User)
async def create_user2(user: User): async def create_user2(user: User):
return (await user.save()).dict(exclude={"password"}) user = await user.save()
return user.dict(exclude={"password"})
@app.post("/users3/", response_model=UserBase) @app.post("/users3/", response_model=UserBase)

View File

@ -1,4 +1,4 @@
from typing import Any, List, Sequence, cast from typing import Any, Sequence, cast
import databases import databases
import pytest import pytest

View File

@ -108,3 +108,17 @@ async def test_model_multiple_instances_of_same_table_in_schema():
assert len(classes[0].dict().get("students")) == 2 assert len(classes[0].dict().get("students")) == 2
assert classes[0].teachers[0].category.department.name == "Law Department" assert classes[0].teachers[0].category.department.name == "Law Department"
assert classes[0].students[0].category.department.name == "Math Department" assert classes[0].students[0].category.department.name == "Math Department"
@pytest.mark.asyncio
async def test_load_all_multiple_instances_of_same_table_in_schema():
async with database:
await create_data()
math_class = await SchoolClass.objects.get(name="Math")
assert math_class.name == "Math"
await math_class.load_all(follow=True)
assert math_class.students[0].name == "Jane"
assert len(math_class.dict().get("students")) == 2
assert math_class.teachers[0].category.department.name == "Law Department"
assert math_class.students[0].category.department.name == "Math Department"