some refactors and cleanup
This commit is contained in:
@ -21,6 +21,17 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
from ormar import Model
|
||||
|
||||
|
||||
def add_relation_field_to_fields(
|
||||
fields: Union[Set[Any], Dict[Any, Any], None], related_field_name: str
|
||||
) -> Union[Set[Any], Dict[Any, Any], None]:
|
||||
if fields and related_field_name not in fields:
|
||||
if isinstance(fields, dict):
|
||||
fields[related_field_name] = ...
|
||||
elif isinstance(fields, set):
|
||||
fields.add(related_field_name)
|
||||
return fields
|
||||
|
||||
|
||||
class PrefetchQuery:
|
||||
def __init__(
|
||||
self,
|
||||
@ -50,22 +61,6 @@ class PrefetchQuery:
|
||||
self.models[self.model.get_name()] = models
|
||||
return await self._prefetch_related_models(models=models, rows=rows)
|
||||
|
||||
@staticmethod
|
||||
def _get_column_name_for_id_extraction(
|
||||
parent_model: Type["Model"],
|
||||
target_model: Type["Model"],
|
||||
reverse: bool,
|
||||
use_raw: bool,
|
||||
) -> str:
|
||||
if reverse:
|
||||
column_name = parent_model.Meta.pkname
|
||||
return (
|
||||
parent_model.get_column_alias(column_name) if use_raw else column_name
|
||||
)
|
||||
else:
|
||||
column = target_model.resolve_relation_field(parent_model, target_model)
|
||||
return column.get_alias() if use_raw else column.name
|
||||
|
||||
def _extract_ids_from_raw_data(
|
||||
self, parent_model: Type["Model"], column_name: str
|
||||
) -> Set:
|
||||
@ -96,7 +91,7 @@ class PrefetchQuery:
|
||||
|
||||
use_raw = parent_model.get_name() not in self.models
|
||||
|
||||
column_name = self._get_column_name_for_id_extraction(
|
||||
column_name = parent_model.get_column_name_for_id_extraction(
|
||||
parent_model=parent_model,
|
||||
target_model=target_model,
|
||||
reverse=reverse,
|
||||
@ -112,22 +107,6 @@ class PrefetchQuery:
|
||||
parent_model=parent_model, column_name=column_name
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_clause_target_and_filter_column_name(
|
||||
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
|
||||
) -> Tuple[Type["Model"], str]:
|
||||
if reverse:
|
||||
field = target_model.resolve_relation_field(target_model, parent_model)
|
||||
if issubclass(field, ManyToManyField):
|
||||
sub_field = target_model.resolve_relation_field(
|
||||
field.through, parent_model
|
||||
)
|
||||
return field.through, sub_field.get_alias()
|
||||
else:
|
||||
return target_model, field.get_alias()
|
||||
target_field = target_model.get_column_alias(target_model.Meta.pkname)
|
||||
return target_model, target_field
|
||||
|
||||
def _get_filter_for_prefetch(
|
||||
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
|
||||
) -> List:
|
||||
@ -138,7 +117,7 @@ class PrefetchQuery:
|
||||
(
|
||||
clause_target,
|
||||
filter_column,
|
||||
) = self._get_clause_target_and_filter_column_name(
|
||||
) = parent_model.get_clause_target_and_filter_column_name(
|
||||
parent_model=parent_model, target_model=target_model, reverse=reverse
|
||||
)
|
||||
qryclause = QueryClause(
|
||||
@ -149,52 +128,21 @@ class PrefetchQuery:
|
||||
return filter_clauses
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_model_id(target_field: Type["BaseField"], model: "Model") -> Optional[int]:
|
||||
if target_field.virtual or issubclass(target_field, ManyToManyField):
|
||||
return model.pk
|
||||
related_name = model.resolve_relation_name(model, target_field.to)
|
||||
related_model = getattr(model, related_name)
|
||||
return None if not related_model else related_model.pk
|
||||
|
||||
@staticmethod
|
||||
def _get_related_field_name(
|
||||
target_field: Type["BaseField"], model: Union["Model", Type["Model"]]
|
||||
) -> str:
|
||||
if issubclass(target_field, ManyToManyField):
|
||||
return model.resolve_relation_name(target_field.through, model)
|
||||
if target_field.virtual:
|
||||
return model.resolve_relation_name(target_field.to, model)
|
||||
return target_field.to.Meta.pkname
|
||||
|
||||
@staticmethod
|
||||
def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List:
|
||||
related_to_extract = []
|
||||
if prefetch_dict and prefetch_dict is not Ellipsis:
|
||||
related_to_extract = [
|
||||
related
|
||||
for related in model.extract_related_names()
|
||||
if related in prefetch_dict
|
||||
]
|
||||
return related_to_extract
|
||||
|
||||
def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model":
|
||||
|
||||
related_to_extract = self._get_names_to_extract(
|
||||
prefetch_dict=prefetch_dict, model=model
|
||||
related_to_extract = model.get_filtered_names_to_extract(
|
||||
prefetch_dict=prefetch_dict
|
||||
)
|
||||
|
||||
for related in related_to_extract:
|
||||
target_field = model.Meta.model_fields[related]
|
||||
target_model = target_field.to.get_name()
|
||||
model_id = self._get_model_id(target_field=target_field, model=model)
|
||||
model_id = model.get_relation_model_id(target_field=target_field)
|
||||
|
||||
if model_id is None: # pragma: no cover
|
||||
continue
|
||||
|
||||
field_name = self._get_related_field_name(
|
||||
target_field=target_field, model=model
|
||||
)
|
||||
field_name = model.get_related_field_name(target_field=target_field)
|
||||
|
||||
children = self.already_extracted.get(target_model, {}).get(field_name, {})
|
||||
self._set_children_on_model(
|
||||
@ -266,6 +214,12 @@ class PrefetchQuery:
|
||||
|
||||
if not already_loaded:
|
||||
# If not already loaded with select_related
|
||||
related_field_name = parent_model.get_related_field_name(
|
||||
target_field=target_field
|
||||
)
|
||||
fields = add_relation_field_to_fields(
|
||||
fields=fields, related_field_name=related_field_name
|
||||
)
|
||||
table_prefix, rows = await self._run_prefetch_query(
|
||||
target_field=target_field,
|
||||
fields=fields,
|
||||
@ -371,9 +325,7 @@ class PrefetchQuery:
|
||||
) -> None:
|
||||
target_model = target_field.to
|
||||
for row in rows:
|
||||
field_name = self._get_related_field_name(
|
||||
target_field=target_field, model=parent_model
|
||||
)
|
||||
field_name = parent_model.get_related_field_name(target_field=target_field)
|
||||
item = target_model.extract_prefixed_table_columns(
|
||||
item={},
|
||||
row=row,
|
||||
|
||||
Reference in New Issue
Block a user