cleanup by related field

This commit is contained in:
collerek
2020-11-25 16:56:54 +01:00
parent f2fe41d38a
commit e0223f8a22

View File

@ -10,7 +10,6 @@ from typing import (
Union, Union,
) )
import ormar
from ormar.fields import BaseField, ManyToManyField from ormar.fields import BaseField, ManyToManyField
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
@ -44,10 +43,9 @@ class PrefetchQuery:
target_model: Type["Model"], target_model: Type["Model"],
reverse: bool, reverse: bool,
) -> Set: ) -> Set:
raw_rows = already_extracted.get(parent_model.get_name(), {}).get("raw", []) current_data = already_extracted.get(parent_model.get_name(), {})
table_prefix = already_extracted.get(parent_model.get_name(), {}).get( raw_rows = current_data.get("raw", [])
"prefix", "" table_prefix = current_data.get("prefix", "")
)
if reverse: if reverse:
column_name = parent_model.get_column_alias(parent_model.Meta.pkname) column_name = parent_model.get_column_alias(parent_model.Meta.pkname)
else: else:
@ -105,23 +103,33 @@ class PrefetchQuery:
target_field: Type["BaseField"], model: "Model" target_field: Type["BaseField"], model: "Model"
) -> Tuple[bool, Optional[str], Optional[int]]: ) -> Tuple[bool, Optional[str], Optional[int]]:
if target_field.virtual: if target_field.virtual:
reverse = True is_multi = False
field_name = model.resolve_relation_name(target_field.to, model) field_name = model.resolve_relation_name(target_field.to, model)
model_id = model.pk model_id = model.pk
elif issubclass(target_field, ManyToManyField): elif issubclass(target_field, ManyToManyField):
reverse = True is_multi = True
field_name = model.resolve_relation_name(target_field.through, model) field_name = model.resolve_relation_name(target_field.through, model)
model_id = model.pk model_id = model.pk
else: else:
reverse = False is_multi = False
related_name = model.resolve_relation_name(model, target_field.to) related_name = model.resolve_relation_name(model, target_field.to)
related_model = getattr(model, related_name) related_model = getattr(model, related_name)
if not related_model: if not related_model:
return reverse, None, None return is_multi, None, None
model_id = related_model.pk model_id = related_model.pk
field_name = target_field.to.Meta.pkname field_name = target_field.to.Meta.pkname
return reverse, field_name, model_id return is_multi, field_name, model_id
@staticmethod
def _get_group_field_name(
target_field: Type["BaseField"], model: Type["Model"]
) -> str:
if issubclass(target_field, ManyToManyField):
return model.resolve_relation_name(target_field.through, model)
if target_field.virtual:
return model.resolve_relation_name(target_field.to, model)
return target_field.to.Meta.pkname
@staticmethod @staticmethod
def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List: def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List:
@ -146,40 +154,17 @@ class PrefetchQuery:
for related in related_to_extract: for related in related_to_extract:
target_field = model.Meta.model_fields[related] target_field = model.Meta.model_fields[related]
target_model = target_field.to.get_name() target_model = target_field.to.get_name()
reverse, field_name, model_id = PrefetchQuery._get_model_id_and_field_name( is_multi, field_name, model_id = PrefetchQuery._get_model_id_and_field_name(
target_field=target_field, model=model target_field=target_field, model=model
) )
if not field_name:
continue
if ( children = already_extracted.get(target_model, {}).get(field_name, {})
target_model in already_extracted for key, child_models in children.items():
and already_extracted[target_model]["models"] if key == model_id:
): for child in child_models:
for key, child_model in already_extracted[target_model][ setattr(model, related, child)
"models"
].items():
if issubclass(target_field, ManyToManyField):
ind = next(
i
if key == x[target_field.to.get_column_alias(field_name)]
else -1
for i, x in enumerate(
already_extracted[target_model]["raw"]
)
)
raw_data = already_extracted[target_model]["raw"][ind]
if (
raw_data
and field_name in raw_data
and raw_data[field_name] == model_id
):
setattr(model, related, child_model)
elif isinstance(getattr(child_model, field_name), ormar.Model):
if getattr(child_model, field_name).pk == model_id:
setattr(model, related, child_model)
elif getattr(child_model, field_name) == model_id:
setattr(model, related, child_model)
return model return model
@ -203,7 +188,7 @@ class PrefetchQuery:
fields = self._columns fields = self._columns
exclude_fields = self._exclude_columns exclude_fields = self._exclude_columns
for related in prefetch_dict.keys(): for related in prefetch_dict.keys():
await self._extract_related_models( subrelated = await self._extract_related_models(
related=related, related=related,
target_model=target_model, target_model=target_model,
prefetch_dict=prefetch_dict.get(related), prefetch_dict=prefetch_dict.get(related),
@ -212,6 +197,7 @@ class PrefetchQuery:
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
) )
print(related, subrelated)
final_models = [] final_models = []
for model in models: for model in models:
final_models.append( final_models.append(
@ -288,7 +274,7 @@ class PrefetchQuery:
if prefetch_dict and prefetch_dict is not Ellipsis: if prefetch_dict and prefetch_dict is not Ellipsis:
for subrelated in prefetch_dict.keys(): for subrelated in prefetch_dict.keys():
await self._extract_related_models( submodels = await self._extract_related_models(
related=subrelated, related=subrelated,
target_model=target_model, target_model=target_model,
prefetch_dict=prefetch_dict.get(subrelated), prefetch_dict=prefetch_dict.get(subrelated),
@ -299,8 +285,13 @@ class PrefetchQuery:
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
) )
print(subrelated, submodels)
for row in rows: for row in rows:
field_name = PrefetchQuery._get_group_field_name(
target_field=target_field, model=parent_model
)
print("TEST", field_name, target_model, row[field_name])
item = target_model.extract_prefixed_table_columns( item = target_model.extract_prefixed_table_columns(
item={}, item={},
row=row, row=row,
@ -314,4 +305,9 @@ class PrefetchQuery:
already_extracted=already_extracted, already_extracted=already_extracted,
prefetch_dict=prefetch_dict, prefetch_dict=prefetch_dict,
) )
already_extracted[target_model.get_name()].setdefault(
field_name, dict()
).setdefault(row[field_name], []).append(instance)
already_extracted[target_model.get_name()]["models"][instance.pk] = instance already_extracted[target_model.get_name()]["models"][instance.pk] = instance
return already_extracted[target_model.get_name()]["models"]