some refactors and cleanup
This commit is contained in:
@ -1,12 +1,15 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@ -38,6 +41,8 @@ class ModelTableProxy:
|
|||||||
Meta: ModelMeta
|
Meta: ModelMeta
|
||||||
_related_names: Set
|
_related_names: Set
|
||||||
_related_names_hash: Union[str, bytes]
|
_related_names_hash: Union[str, bytes]
|
||||||
|
pk: Any
|
||||||
|
get_name: Callable
|
||||||
|
|
||||||
def dict(self): # noqa A003
|
def dict(self): # noqa A003
|
||||||
raise NotImplementedError # pragma no cover
|
raise NotImplementedError # pragma no cover
|
||||||
@ -47,6 +52,66 @@ class ModelTableProxy:
|
|||||||
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
|
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
|
||||||
return self_fields
|
return self_fields
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_related_field_name(cls, target_field: Type["BaseField"]) -> str:
|
||||||
|
if issubclass(target_field, ormar.fields.ManyToManyField):
|
||||||
|
return cls.resolve_relation_name(target_field.through, cls)
|
||||||
|
if target_field.virtual:
|
||||||
|
return cls.resolve_relation_name(target_field.to, cls)
|
||||||
|
return target_field.to.Meta.pkname
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_clause_target_and_filter_column_name(
|
||||||
|
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
|
||||||
|
) -> Tuple[Type["Model"], str]:
|
||||||
|
if reverse:
|
||||||
|
field = target_model.resolve_relation_field(target_model, parent_model)
|
||||||
|
if issubclass(field, ormar.fields.ManyToManyField):
|
||||||
|
sub_field = target_model.resolve_relation_field(
|
||||||
|
field.through, parent_model
|
||||||
|
)
|
||||||
|
return field.through, sub_field.get_alias()
|
||||||
|
else:
|
||||||
|
return target_model, field.get_alias()
|
||||||
|
target_field = target_model.get_column_alias(target_model.Meta.pkname)
|
||||||
|
return target_model, target_field
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_column_name_for_id_extraction(
|
||||||
|
parent_model: Type["Model"],
|
||||||
|
target_model: Type["Model"],
|
||||||
|
reverse: bool,
|
||||||
|
use_raw: bool,
|
||||||
|
) -> str:
|
||||||
|
if reverse:
|
||||||
|
column_name = parent_model.Meta.pkname
|
||||||
|
return (
|
||||||
|
parent_model.get_column_alias(column_name) if use_raw else column_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
column = target_model.resolve_relation_field(parent_model, target_model)
|
||||||
|
return column.get_alias() if use_raw else column.name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_filtered_names_to_extract(cls, prefetch_dict: Dict) -> List:
|
||||||
|
related_to_extract = []
|
||||||
|
if prefetch_dict and prefetch_dict is not Ellipsis:
|
||||||
|
related_to_extract = [
|
||||||
|
related
|
||||||
|
for related in cls.extract_related_names()
|
||||||
|
if related in prefetch_dict
|
||||||
|
]
|
||||||
|
return related_to_extract
|
||||||
|
|
||||||
|
def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]:
|
||||||
|
if target_field.virtual or issubclass(
|
||||||
|
target_field, ormar.fields.ManyToManyField
|
||||||
|
):
|
||||||
|
return self.pk
|
||||||
|
related_name = self.resolve_relation_name(self, target_field.to)
|
||||||
|
related_model = getattr(self, related_name)
|
||||||
|
return None if not related_model else related_model.pk
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_db_own_fields(cls) -> Set:
|
def extract_db_own_fields(cls) -> Set:
|
||||||
related_names = cls.extract_related_names()
|
related_names = cls.extract_related_names()
|
||||||
@ -155,8 +220,18 @@ class ModelTableProxy:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_relation_name( # noqa CCR001
|
def resolve_relation_name( # noqa CCR001
|
||||||
item: Union["NewBaseModel", Type["NewBaseModel"]],
|
item: Union[
|
||||||
related: Union["NewBaseModel", Type["NewBaseModel"]],
|
"NewBaseModel",
|
||||||
|
Type["NewBaseModel"],
|
||||||
|
"ModelTableProxy",
|
||||||
|
Type["ModelTableProxy"],
|
||||||
|
],
|
||||||
|
related: Union[
|
||||||
|
"NewBaseModel",
|
||||||
|
Type["NewBaseModel"],
|
||||||
|
"ModelTableProxy",
|
||||||
|
Type["ModelTableProxy"],
|
||||||
|
],
|
||||||
) -> str:
|
) -> str:
|
||||||
for name, field in item.Meta.model_fields.items():
|
for name, field in item.Meta.model_fields.items():
|
||||||
if issubclass(field, ForeignKeyField):
|
if issubclass(field, ForeignKeyField):
|
||||||
|
|||||||
@ -21,6 +21,17 @@ if TYPE_CHECKING: # pragma: no cover
|
|||||||
from ormar import Model
|
from ormar import Model
|
||||||
|
|
||||||
|
|
||||||
|
def add_relation_field_to_fields(
|
||||||
|
fields: Union[Set[Any], Dict[Any, Any], None], related_field_name: str
|
||||||
|
) -> Union[Set[Any], Dict[Any, Any], None]:
|
||||||
|
if fields and related_field_name not in fields:
|
||||||
|
if isinstance(fields, dict):
|
||||||
|
fields[related_field_name] = ...
|
||||||
|
elif isinstance(fields, set):
|
||||||
|
fields.add(related_field_name)
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
class PrefetchQuery:
|
class PrefetchQuery:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -50,22 +61,6 @@ class PrefetchQuery:
|
|||||||
self.models[self.model.get_name()] = models
|
self.models[self.model.get_name()] = models
|
||||||
return await self._prefetch_related_models(models=models, rows=rows)
|
return await self._prefetch_related_models(models=models, rows=rows)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_column_name_for_id_extraction(
|
|
||||||
parent_model: Type["Model"],
|
|
||||||
target_model: Type["Model"],
|
|
||||||
reverse: bool,
|
|
||||||
use_raw: bool,
|
|
||||||
) -> str:
|
|
||||||
if reverse:
|
|
||||||
column_name = parent_model.Meta.pkname
|
|
||||||
return (
|
|
||||||
parent_model.get_column_alias(column_name) if use_raw else column_name
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
column = target_model.resolve_relation_field(parent_model, target_model)
|
|
||||||
return column.get_alias() if use_raw else column.name
|
|
||||||
|
|
||||||
def _extract_ids_from_raw_data(
|
def _extract_ids_from_raw_data(
|
||||||
self, parent_model: Type["Model"], column_name: str
|
self, parent_model: Type["Model"], column_name: str
|
||||||
) -> Set:
|
) -> Set:
|
||||||
@ -96,7 +91,7 @@ class PrefetchQuery:
|
|||||||
|
|
||||||
use_raw = parent_model.get_name() not in self.models
|
use_raw = parent_model.get_name() not in self.models
|
||||||
|
|
||||||
column_name = self._get_column_name_for_id_extraction(
|
column_name = parent_model.get_column_name_for_id_extraction(
|
||||||
parent_model=parent_model,
|
parent_model=parent_model,
|
||||||
target_model=target_model,
|
target_model=target_model,
|
||||||
reverse=reverse,
|
reverse=reverse,
|
||||||
@ -112,22 +107,6 @@ class PrefetchQuery:
|
|||||||
parent_model=parent_model, column_name=column_name
|
parent_model=parent_model, column_name=column_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_clause_target_and_filter_column_name(
|
|
||||||
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
|
|
||||||
) -> Tuple[Type["Model"], str]:
|
|
||||||
if reverse:
|
|
||||||
field = target_model.resolve_relation_field(target_model, parent_model)
|
|
||||||
if issubclass(field, ManyToManyField):
|
|
||||||
sub_field = target_model.resolve_relation_field(
|
|
||||||
field.through, parent_model
|
|
||||||
)
|
|
||||||
return field.through, sub_field.get_alias()
|
|
||||||
else:
|
|
||||||
return target_model, field.get_alias()
|
|
||||||
target_field = target_model.get_column_alias(target_model.Meta.pkname)
|
|
||||||
return target_model, target_field
|
|
||||||
|
|
||||||
def _get_filter_for_prefetch(
|
def _get_filter_for_prefetch(
|
||||||
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
|
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
|
||||||
) -> List:
|
) -> List:
|
||||||
@ -138,7 +117,7 @@ class PrefetchQuery:
|
|||||||
(
|
(
|
||||||
clause_target,
|
clause_target,
|
||||||
filter_column,
|
filter_column,
|
||||||
) = self._get_clause_target_and_filter_column_name(
|
) = parent_model.get_clause_target_and_filter_column_name(
|
||||||
parent_model=parent_model, target_model=target_model, reverse=reverse
|
parent_model=parent_model, target_model=target_model, reverse=reverse
|
||||||
)
|
)
|
||||||
qryclause = QueryClause(
|
qryclause = QueryClause(
|
||||||
@ -149,52 +128,21 @@ class PrefetchQuery:
|
|||||||
return filter_clauses
|
return filter_clauses
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_model_id(target_field: Type["BaseField"], model: "Model") -> Optional[int]:
|
|
||||||
if target_field.virtual or issubclass(target_field, ManyToManyField):
|
|
||||||
return model.pk
|
|
||||||
related_name = model.resolve_relation_name(model, target_field.to)
|
|
||||||
related_model = getattr(model, related_name)
|
|
||||||
return None if not related_model else related_model.pk
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_related_field_name(
|
|
||||||
target_field: Type["BaseField"], model: Union["Model", Type["Model"]]
|
|
||||||
) -> str:
|
|
||||||
if issubclass(target_field, ManyToManyField):
|
|
||||||
return model.resolve_relation_name(target_field.through, model)
|
|
||||||
if target_field.virtual:
|
|
||||||
return model.resolve_relation_name(target_field.to, model)
|
|
||||||
return target_field.to.Meta.pkname
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_names_to_extract(prefetch_dict: Dict, model: "Model") -> List:
|
|
||||||
related_to_extract = []
|
|
||||||
if prefetch_dict and prefetch_dict is not Ellipsis:
|
|
||||||
related_to_extract = [
|
|
||||||
related
|
|
||||||
for related in model.extract_related_names()
|
|
||||||
if related in prefetch_dict
|
|
||||||
]
|
|
||||||
return related_to_extract
|
|
||||||
|
|
||||||
def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model":
|
def _populate_nested_related(self, model: "Model", prefetch_dict: Dict) -> "Model":
|
||||||
|
|
||||||
related_to_extract = self._get_names_to_extract(
|
related_to_extract = model.get_filtered_names_to_extract(
|
||||||
prefetch_dict=prefetch_dict, model=model
|
prefetch_dict=prefetch_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
for related in related_to_extract:
|
for related in related_to_extract:
|
||||||
target_field = model.Meta.model_fields[related]
|
target_field = model.Meta.model_fields[related]
|
||||||
target_model = target_field.to.get_name()
|
target_model = target_field.to.get_name()
|
||||||
model_id = self._get_model_id(target_field=target_field, model=model)
|
model_id = model.get_relation_model_id(target_field=target_field)
|
||||||
|
|
||||||
if model_id is None: # pragma: no cover
|
if model_id is None: # pragma: no cover
|
||||||
continue
|
continue
|
||||||
|
|
||||||
field_name = self._get_related_field_name(
|
field_name = model.get_related_field_name(target_field=target_field)
|
||||||
target_field=target_field, model=model
|
|
||||||
)
|
|
||||||
|
|
||||||
children = self.already_extracted.get(target_model, {}).get(field_name, {})
|
children = self.already_extracted.get(target_model, {}).get(field_name, {})
|
||||||
self._set_children_on_model(
|
self._set_children_on_model(
|
||||||
@ -266,6 +214,12 @@ class PrefetchQuery:
|
|||||||
|
|
||||||
if not already_loaded:
|
if not already_loaded:
|
||||||
# If not already loaded with select_related
|
# If not already loaded with select_related
|
||||||
|
related_field_name = parent_model.get_related_field_name(
|
||||||
|
target_field=target_field
|
||||||
|
)
|
||||||
|
fields = add_relation_field_to_fields(
|
||||||
|
fields=fields, related_field_name=related_field_name
|
||||||
|
)
|
||||||
table_prefix, rows = await self._run_prefetch_query(
|
table_prefix, rows = await self._run_prefetch_query(
|
||||||
target_field=target_field,
|
target_field=target_field,
|
||||||
fields=fields,
|
fields=fields,
|
||||||
@ -371,9 +325,7 @@ class PrefetchQuery:
|
|||||||
) -> None:
|
) -> None:
|
||||||
target_model = target_field.to
|
target_model = target_field.to
|
||||||
for row in rows:
|
for row in rows:
|
||||||
field_name = self._get_related_field_name(
|
field_name = parent_model.get_related_field_name(target_field=target_field)
|
||||||
target_field=target_field, model=parent_model
|
|
||||||
)
|
|
||||||
item = target_model.extract_prefixed_table_columns(
|
item = target_model.extract_prefixed_table_columns(
|
||||||
item={},
|
item={},
|
||||||
row=row,
|
row=row,
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class Division(ormar.Model):
|
|||||||
database = database
|
database = database
|
||||||
|
|
||||||
id: int = ormar.Integer(name='division_id', primary_key=True)
|
id: int = ormar.Integer(name='division_id', primary_key=True)
|
||||||
name: str = ormar.String(max_length=100)
|
name: str = ormar.String(max_length=100, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class Shop(ormar.Model):
|
class Shop(ormar.Model):
|
||||||
@ -49,7 +49,7 @@ class Shop(ormar.Model):
|
|||||||
database = database
|
database = database
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
id: int = ormar.Integer(primary_key=True)
|
||||||
name: str = ormar.String(max_length=100)
|
name: str = ormar.String(max_length=100, nullable=True)
|
||||||
division: Optional[Division] = ormar.ForeignKey(Division)
|
division: Optional[Division] = ormar.ForeignKey(Division)
|
||||||
|
|
||||||
|
|
||||||
@ -67,7 +67,7 @@ class Album(ormar.Model):
|
|||||||
database = database
|
database = database
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
id: int = ormar.Integer(primary_key=True)
|
||||||
name: str = ormar.String(max_length=100)
|
name: str = ormar.String(max_length=100, nullable=True)
|
||||||
shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops)
|
shops: List[Shop] = ormar.ManyToMany(to=Shop, through=AlbumShops)
|
||||||
|
|
||||||
|
|
||||||
@ -243,3 +243,50 @@ async def test_prefetch_related_with_select_related():
|
|||||||
assert len(track.album.shops) == 2
|
assert len(track.album.shops) == 2
|
||||||
assert track.album.shops[0].name == 'Shop 1'
|
assert track.album.shops[0].name == 'Shop 1'
|
||||||
assert track.album.shops[0].division.name == 'Div 1'
|
assert track.album.shops[0].division.name == 'Div 1'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prefetch_related_with_select_related_and_fields():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
div = await Division.objects.create(name='Div 1')
|
||||||
|
shop1 = await Shop.objects.create(name='Shop 1', division=div)
|
||||||
|
shop2 = await Shop.objects.create(name='Shop 2', division=div)
|
||||||
|
album = Album(name="Malibu")
|
||||||
|
await album.save()
|
||||||
|
await album.shops.add(shop1)
|
||||||
|
await album.shops.add(shop2)
|
||||||
|
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1')
|
||||||
|
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2')
|
||||||
|
rand_set = await RandomSet.objects.create(name='Rand 1')
|
||||||
|
ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set)
|
||||||
|
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1)
|
||||||
|
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1)
|
||||||
|
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1)
|
||||||
|
|
||||||
|
album = await Album.objects.select_related('tracks__tonation__rand_set').filter(
|
||||||
|
name='Malibu').prefetch_related(
|
||||||
|
['cover_pictures', 'shops__division']).exclude_fields({'shops': {'division': {'name'}}}).get()
|
||||||
|
assert len(album.tracks) == 3
|
||||||
|
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
|
||||||
|
assert len(album.cover_pictures) == 2
|
||||||
|
assert album.cover_pictures[0].artist == 'Artist 1'
|
||||||
|
|
||||||
|
assert len(album.shops) == 2
|
||||||
|
assert album.shops[0].name == 'Shop 1'
|
||||||
|
assert album.shops[0].division.name is None
|
||||||
|
|
||||||
|
album = await Album.objects.select_related('tracks').filter(
|
||||||
|
name='Malibu').prefetch_related(
|
||||||
|
['cover_pictures', 'shops__division']).fields(
|
||||||
|
{'name': ..., 'shops': {'division'}, 'cover_pictures': {'id': ..., 'title': ...}}
|
||||||
|
).exclude_fields({'shops': {'division': {'name'}}}).get()
|
||||||
|
assert len(album.tracks) == 3
|
||||||
|
assert len(album.cover_pictures) == 2
|
||||||
|
assert album.cover_pictures[0].artist is None
|
||||||
|
assert album.cover_pictures[0].title is not None
|
||||||
|
|
||||||
|
assert len(album.shops) == 2
|
||||||
|
assert album.shops[0].name is None
|
||||||
|
assert album.shops[0].division is not None
|
||||||
|
assert album.shops[0].division.name is None
|
||||||
|
|||||||
Reference in New Issue
Block a user