diff --git a/.gitignore b/.gitignore index fc07f13..6c5114b 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ dist site profile.py *.db +*.db-journal diff --git a/ormar/models/model.py b/ormar/models/model.py index c20368d..1535113 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -22,7 +22,7 @@ class Model(ModelRow): __abstract__ = False if TYPE_CHECKING: # pragma nocover Meta: ModelMeta - objects: "QuerySet" + objects: "QuerySet[Model]" def __repr__(self) -> str: # pragma nocover _repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()} diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index cc20807..d400c0f 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -310,7 +310,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass :rtype: Optional[Union[Model, List[Model]]] """ if item in self._orm: - return self._orm.get(item) + return self._orm.get(item) # type: ignore return None # pragma no cover def __eq__(self, other: object) -> bool: diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index e52ae4a..b5a3c5b 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -16,6 +16,7 @@ class Prefix: table_prefix: str model_cls: Type["Model"] relation_str: str + is_through: bool @property def alias_key(self) -> str: diff --git a/ormar/queryset/filter_action.py b/ormar/queryset/filter_action.py index 4f26864..d2d8e45 100644 --- a/ormar/queryset/filter_action.py +++ b/ormar/queryset/filter_action.py @@ -53,6 +53,7 @@ class FilterAction: self.table_prefix = "" self.source_model = model_cls self.target_model = model_cls + self.is_through = False self._determine_filter_target_table() self._escape_characters_in_clause() @@ -100,6 +101,7 @@ class FilterAction: self.table_prefix, self.target_model, self.related_str, + self.is_through, ) = get_relationship_alias_model_and_str(self.source_model, self.related_parts) def _escape_characters_in_clause(self) -> None: diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index e90e49d..4626fbd 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -290,9 +290,8 @@ class SqlJoin: self.get_order_bys( to_table=to_table, pkname_alias=pkname_alias, ) - else: - self.select_through_model_fields() + # TODO: fix fields and exclusions for through model? self_related_fields = self.next_model.own_table_columns( model=self.next_model, fields=self.fields, @@ -306,24 +305,6 @@ class SqlJoin: ) self.used_aliases.append(self.next_alias) - def select_through_model_fields(self) -> None: - # TODO: add docstring - next_alias = self.alias_manager.resolve_relation_alias( - from_model=self.target_field.owner, relation_name=self.relation_name - ) - # TODO: fix fields and exclusions - self_related_fields = self.target_field.through.own_table_columns( - model=self.target_field.through, - fields=None, - exclude_fields=self.target_field.through.extract_related_names(), - use_alias=True, - ) - self.columns.extend( - self.alias_manager.prefixed_columns( - next_alias, self.target_field.through.Meta.table, self_related_fields - ) - ) - def _replace_many_to_many_order_by_columns(self, part: str, new_part: str) -> None: """ Substitutes the name of the relation with actual model name in m2m order bys. diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index f1cbf43..6b98028 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -12,7 +12,6 @@ from typing import ( Union, ) - if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -218,7 +217,7 @@ def extract_models_to_dict_of_lists( def get_relationship_alias_model_and_str( source_model: Type["Model"], related_parts: List -) -> Tuple[str, Type["Model"], str]: +) -> Tuple[str, Type["Model"], str, bool]: """ Walks the relation to retrieve the actual model on which the clause should be constructed, extracts alias based on last relation leading to target model. @@ -230,11 +229,19 @@ def get_relationship_alias_model_and_str( :rtype: Tuple[str, Type["Model"], str] """ table_prefix = "" + is_through = False model_cls = source_model previous_model = model_cls manager = model_cls.Meta.alias_manager - for relation in related_parts: + for relation in related_parts[:]: related_field = model_cls.Meta.model_fields[relation] + if related_field.is_through: + is_through = True + related_parts = [ + x.replace(relation, related_field.related_name) if x == relation else x + for x in related_parts + ] + relation = related_field.related_name if related_field.is_multi: previous_model = related_field.through relation = related_field.default_target_field_name() # type: ignore @@ -245,4 +252,4 @@ def get_relationship_alias_model_and_str( previous_model = model_cls relation_str = "__".join(related_parts) - return table_prefix, model_cls, relation_str + return table_prefix, model_cls, relation_str, is_through diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 031684b..85b7832 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -139,7 +139,7 @@ class QuerysetProxy(Generic[T]): :param child: child model instance :type child: Model """ - queryset = ormar.QuerySet(model_cls=self.relation.through) + queryset = ormar.QuerySet(model_cls=self.relation.through) # type: ignore owner_column = self.related_field.default_target_field_name() # type: ignore child_column = self.related_field.default_source_field_name() # type: ignore kwargs = {owner_column: self._owner, child_column: child} @@ -187,10 +187,10 @@ class QuerysetProxy(Generic[T]): :rtype: int """ if self.type_ == ormar.RelationType.MULTIPLE: - queryset = ormar.QuerySet(model_cls=self.relation.through) + queryset = ormar.QuerySet(model_cls=self.relation.through) # type: ignore owner_column = self._owner.get_name() else: - queryset = ormar.QuerySet(model_cls=self.relation.to) + queryset = ormar.QuerySet(model_cls=self.relation.to) # type: ignore owner_column = self.related_field.name kwargs = {owner_column: self._owner} self._clean_items_on_load() diff --git a/test.db-journal b/test.db-journal deleted file mode 100644 index f553864..0000000 Binary files a/test.db-journal and /dev/null differ diff --git a/tests/test_m2m_through_fields.py b/tests/test_m2m_through_fields.py index 279d1a8..9d7f38a 100644 --- a/tests/test_m2m_through_fields.py +++ b/tests/test_m2m_through_fields.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, TYPE_CHECKING import databases import pytest @@ -31,6 +31,7 @@ class PostCategory(ormar.Model): id: int = ormar.Integer(primary_key=True) sort_order: int = ormar.Integer(nullable=True) + param_name: str = ormar.String(default="Name", max_length=200) class Post(ormar.Model): @@ -109,10 +110,6 @@ async def test_setting_additional_fields_on_through_model_in_create(): assert postcat.sort_order == 2 -def process_post(post: Post): - pass - - @pytest.mark.asyncio async def test_getting_additional_fields_from_queryset() -> Any: async with database: @@ -132,9 +129,35 @@ async def test_getting_additional_fields_from_queryset() -> Any: categories__name="Test category2" ) assert post2.categories[0].postcategory.sort_order == 2 - process_post(post2) + # if TYPE_CHECKING: + # reveal_type(post2) +@pytest.mark.asyncio +async def test_filtering_by_through_model() -> Any: + async with database: + post = await Post(title="Test post").save() + await post.categories.create( + name="Test category1", + postcategory={"sort_order": 1, "param_name": "volume"}, + ) + await post.categories.create( + name="Test category2", postcategory={"sort_order": 2, "param_name": "area"} + ) + + post2 = ( + await Post.objects.filter(postcategory__sort_order__gt=1) + .select_related("categories") + .get() + ) + assert len(post2.categories) == 1 + assert post2.categories[0].postcategory.sort_order == 2 + + post3 = await Post.objects.filter( + categories__postcategory__param_name="volume").get() + assert len(post3.categories) == 1 + assert post3.categories[0].postcategory.param_name == "volume" + # TODO: check/ modify following # add to fields with class lower name (V) @@ -143,9 +166,9 @@ async def test_getting_additional_fields_from_queryset() -> Any: # creating in queryset proxy (dict with through name and kwargs) (V) # loading the data into model instance of though model (V) <- fix fields ane exclude # accessing from instance (V) <- no both sides only nested one is relevant, fix one side +# filtering in filter (through name normally) (V) < - table prefix from normal relation, check if is_through needed # updating in query -# sorting in filter (special __through__ notation?) # ordering by in order_by # modifying from instance (both sides?) # including/excluding in fields?