diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 243e66e..9f1930c 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -150,30 +150,15 @@ class PrefetchQuery: return [] @staticmethod - def _get_model_id_and_field_name( - target_field: Type["BaseField"], model: "Model" - ) -> Tuple[bool, Optional[str], Optional[int]]: - if target_field.virtual: - is_multi = False - field_name = model.resolve_relation_name(target_field.to, model) - model_id = model.pk - elif issubclass(target_field, ManyToManyField): - is_multi = True - field_name = model.resolve_relation_name(target_field.through, model) - model_id = model.pk - else: - is_multi = False - related_name = model.resolve_relation_name(model, target_field.to) - related_model = getattr(model, related_name) - if not related_model: - return is_multi, None, None - model_id = related_model.pk - field_name = target_field.to.Meta.pkname - - return is_multi, field_name, model_id + def _get_model_id(target_field: Type["BaseField"], model: "Model") -> Optional[int]: + if target_field.virtual or issubclass(target_field, ManyToManyField): + return model.pk + related_name = model.resolve_relation_name(model, target_field.to) + related_model = getattr(model, related_name) + return None if not related_model else related_model.pk @staticmethod - def _get_group_field_name( + def _get_related_field_name( target_field: Type["BaseField"], model: Union["Model", Type["Model"]] ) -> str: if issubclass(target_field, ManyToManyField): @@ -202,13 +187,15 @@ class PrefetchQuery: for related in related_to_extract: target_field = model.Meta.model_fields[related] target_model = target_field.to.get_name() - is_multi, field_name, model_id = self._get_model_id_and_field_name( + model_id = self._get_model_id(target_field=target_field, model=model) + + if model_id is None: # pragma: no cover + continue + + field_name = self._get_related_field_name( target_field=target_field, model=model ) - if field_name is None or model_id is None: # pragma: no cover - continue - children = self.already_extracted.get(target_model, {}).get(field_name, {}) self._set_children_on_model( model=model, related=related, children=children, model_id=model_id @@ -384,7 +371,7 @@ class PrefetchQuery: ) -> None: target_model = target_field.to for row in rows: - field_name = self._get_group_field_name( + field_name = self._get_related_field_name( target_field=target_field, model=parent_model ) item = target_model.extract_prefixed_table_columns(