Refactor in join in order to make possibility for nested duplicated relations (and it was a mess :D)

This commit is contained in:
collerek
2021-01-15 17:05:23 +01:00
parent d10141ba6f
commit 0fe95b0c7b
14 changed files with 271 additions and 303 deletions

View File

@ -150,11 +150,11 @@ Process order_by causes for non m2m relations.
- `fields (Optional[Union[Set, Dict]])`: fields to include - `fields (Optional[Union[Set, Dict]])`: fields to include
- `exclude_fields (Optional[Union[Set, Dict]])`: fields to exclude - `exclude_fields (Optional[Union[Set, Dict]])`: fields to exclude
<a name="queryset.join.SqlJoin._switch_many_to_many_order_columns"></a> <a name="queryset.join.SqlJoin._replace_many_to_many_order_by_columns"></a>
#### \_switch\_many\_to\_many\_order\_columns #### \_switch\_many\_to\_many\_order\_columns
```python ```python
| _switch_many_to_many_order_columns(part: str, new_part: str) -> None | _replace_many_to_many_order_by_columns(part: str, new_part: str) -> None
``` ```
Substitutes the name of the relation with actual model name in m2m order bys. Substitutes the name of the relation with actual model name in m2m order bys.

View File

@ -76,7 +76,7 @@ Since it can be a function you can set `default=datetime.datetime.now` and get c
response with `include`/`exclude` and `response_model_include`/`response_model_exclude` accordingly. response with `include`/`exclude` and `response_model_include`/`response_model_exclude` accordingly.
```python ```python
# <==part of code removed for clarity==> # <==related of code removed for clarity==>
class User(ormar.Model): class User(ormar.Model):
class Meta: class Meta:
tablename: str = "users2" tablename: str = "users2"
@ -93,14 +93,14 @@ class User(ormar.Model):
pydantic_only=True, default=datetime.datetime.now pydantic_only=True, default=datetime.datetime.now
) )
# <==part of code removed for clarity==> # <==related of code removed for clarity==>
app =FastAPI() app =FastAPI()
@app.post("/users/") @app.post("/users/")
async def create_user(user: User): async def create_user(user: User):
return await user.save() return await user.save()
# <==part of code removed for clarity==> # <==related of code removed for clarity==>
def test_excluding_fields_in_endpoints(): def test_excluding_fields_in_endpoints():
client = TestClient(app) client = TestClient(app)
@ -127,7 +127,7 @@ def test_excluding_fields_in_endpoints():
assert response.json().get("timestamp") == str(timestamp).replace(" ", "T") assert response.json().get("timestamp") == str(timestamp).replace(" ", "T")
# <==part of code removed for clarity==> # <==related of code removed for clarity==>
``` ```
#### Property fields #### Property fields
@ -190,7 +190,7 @@ in the response from `fastapi` and `dict()` and `json()` methods. You cannot pas
``` ```
```python ```python
# <==part of code removed for clarity==> # <==related of code removed for clarity==>
def gen_pass(): # note: NOT production ready def gen_pass(): # note: NOT production ready
choices = string.ascii_letters + string.digits + "!@#$%^&*()" choices = string.ascii_letters + string.digits + "!@#$%^&*()"
return "".join(random.choice(choices) for _ in range(20)) return "".join(random.choice(choices) for _ in range(20))
@ -215,7 +215,7 @@ class RandomModel(ormar.Model):
def full_name(self) -> str: def full_name(self) -> str:
return " ".join([self.first_name, self.last_name]) return " ".join([self.first_name, self.last_name])
# <==part of code removed for clarity==> # <==related of code removed for clarity==>
app =FastAPI() app =FastAPI()
# explicitly exclude property_field in this endpoint # explicitly exclude property_field in this endpoint
@ -223,7 +223,7 @@ app =FastAPI()
async def create_user(user: RandomModel): async def create_user(user: RandomModel):
return await user.save() return await user.save()
# <==part of code removed for clarity==> # <==related of code removed for clarity==>
def test_excluding_property_field_in_endpoints2(): def test_excluding_property_field_in_endpoints2():
client = TestClient(app) client = TestClient(app)
@ -241,7 +241,7 @@ def test_excluding_property_field_in_endpoints2():
# despite being decorated with property_field if you explictly exclude it it will be gone # despite being decorated with property_field if you explictly exclude it it will be gone
assert response.json().get("full_name") is None assert response.json().get("full_name") is None
# <==part of code removed for clarity==> # <==related of code removed for clarity==>
``` ```
#### Fields names vs Column names #### Fields names vs Column names

View File

@ -1,4 +1,5 @@
from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Type import itertools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
import ormar # noqa: I100 import ormar # noqa: I100
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
@ -109,3 +110,32 @@ def validate_related_names_in_relations( # noqa CCR001
f"\nTip: provide different related_name for FK and/or M2M fields" f"\nTip: provide different related_name for FK and/or M2M fields"
) )
previous_related_names.append(field.related_name) previous_related_names.append(field.related_name)
def group_related_list(list_: List) -> Dict:
"""
Translates the list of related strings into a dictionary.
That way nested models are grouped to traverse them in a right order
and to avoid repetition.
Sample: ["people__houses", "people__cars__models", "people__cars__colors"]
will become:
{'people': {'houses': [], 'cars': ['models', 'colors']}}
:param list_: list of related models used in select related
:type list_: List[str]
:return: list converted to dictionary to avoid repetition and group nested models
:rtype: Dict[str, List]
"""
test_dict: Dict[str, Any] = dict()
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped:
group_list = list(group)
new = [
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
]
if any("__" in x for x in new):
test_dict[key] = group_related_list(new)
else:
test_dict[key] = new
return test_dict

View File

@ -221,7 +221,7 @@ def update_attrs_and_fields(
:param attrs: new namespace for class being constructed :param attrs: new namespace for class being constructed
:type attrs: Dict :type attrs: Dict
:param new_attrs: part of the namespace extracted from parent class :param new_attrs: related of the namespace extracted from parent class
:type new_attrs: Dict :type new_attrs: Dict
:param model_fields: ormar fields in defined in current class :param model_fields: ormar fields in defined in current class
:type model_fields: Dict[str, BaseField] :type model_fields: Dict[str, BaseField]

View File

@ -1,4 +1,3 @@
import itertools
from typing import ( from typing import (
Any, Any,
Dict, Dict,
@ -18,38 +17,9 @@ import ormar.queryset # noqa I100
from ormar.exceptions import ModelPersistenceError, NoMatch from ormar.exceptions import ModelPersistenceError, NoMatch
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
from ormar.models.helpers.models import group_related_list
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
def group_related_list(list_: List) -> Dict:
"""
Translates the list of related strings into a dictionary.
That way nested models are grouped to traverse them in a right order
and to avoid repetition.
Sample: ["people__houses", "people__cars__models", "people__cars__colors"]
will become:
{'people': {'houses': [], 'cars': ['models', 'colors']}}
:param list_: list of related models used in select related
:type list_: List[str]
:return: list converted to dictionary to avoid repetition and group nested models
:rtype: Dict[str, List]
"""
test_dict: Dict[str, Any] = dict()
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped:
group_list = list(group)
new = [
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
]
if any("__" in x for x in new):
test_dict[key] = group_related_list(new)
else:
test_dict[key] = new
return test_dict
if TYPE_CHECKING: # pragma nocover if TYPE_CHECKING: # pragma nocover
from ormar import QuerySet from ormar import QuerySet
@ -73,9 +43,11 @@ class Model(NewBaseModel):
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
previous_model: Type[T] = None, previous_model: Type[T] = None,
source_model: Type[T] = None,
related_name: str = None, related_name: str = None,
fields: Optional[Union[Dict, Set]] = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None,
current_relation_str: str = None,
) -> Optional[T]: ) -> Optional[T]:
""" """
Model method to convert raw sql row from database into ormar.Model instance. Model method to convert raw sql row from database into ormar.Model instance.
@ -112,7 +84,10 @@ class Model(NewBaseModel):
item: Dict[str, Any] = {} item: Dict[str, Any] = {}
select_related = select_related or [] select_related = select_related or []
related_models = related_models or [] related_models = related_models or []
table_prefix = ""
if select_related: if select_related:
source_model = cls
related_models = group_related_list(select_related) related_models = group_related_list(select_related)
rel_name2 = related_name rel_name2 = related_name
@ -135,11 +110,15 @@ class Model(NewBaseModel):
previous_model = through_field.through # type: ignore previous_model = through_field.through # type: ignore
if previous_model and rel_name2: if previous_model and rel_name2:
table_prefix = cls.Meta.alias_manager.resolve_relation_alias( # TODO finish duplicated nested relation or remove this
previous_model, rel_name2 if current_relation_str and "__" in current_relation_str and source_model:
) table_prefix = cls.Meta.alias_manager.resolve_relation_alias(
else: from_model=source_model, relation_name=current_relation_str
table_prefix = "" )
if not table_prefix:
table_prefix = cls.Meta.alias_manager.resolve_relation_alias(
from_model=previous_model, relation_name=rel_name2
)
item = cls.populate_nested_models_from_row( item = cls.populate_nested_models_from_row(
item=item, item=item,
@ -147,6 +126,8 @@ class Model(NewBaseModel):
related_models=related_models, related_models=related_models,
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
current_relation_str=current_relation_str,
source_model=source_model,
) )
item = cls.extract_prefixed_table_columns( item = cls.extract_prefixed_table_columns(
item=item, item=item,
@ -163,8 +144,6 @@ class Model(NewBaseModel):
) )
instance = cls(**item) instance = cls(**item)
instance.set_save_status(True) instance.set_save_status(True)
else:
instance = None
return instance return instance
@classmethod @classmethod
@ -175,6 +154,8 @@ class Model(NewBaseModel):
related_models: Any, related_models: Any,
fields: Optional[Union[Dict, Set]] = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None,
current_relation_str: str = None,
source_model: Type[T] = None,
) -> dict: ) -> dict:
""" """
Traverses structure of related models and populates the nested models Traverses structure of related models and populates the nested models
@ -202,35 +183,31 @@ class Model(NewBaseModel):
and values are database values and values are database values
:rtype: Dict :rtype: Dict
""" """
for related in related_models: for related in related_models:
relation_str = (
"__".join([current_relation_str, related])
if current_relation_str
else related
)
fields = cls.get_included(fields, related)
exclude_fields = cls.get_excluded(exclude_fields, related)
model_cls = cls.Meta.model_fields[related].to
remainder = None
if isinstance(related_models, dict) and related_models[related]: if isinstance(related_models, dict) and related_models[related]:
first_part, remainder = related, related_models[related] remainder = related_models[related]
model_cls = cls.Meta.model_fields[first_part].to child = model_cls.from_row(
row,
fields = cls.get_included(fields, first_part) related_models=remainder,
exclude_fields = cls.get_excluded(exclude_fields, first_part) previous_model=cls,
related_name=related,
child = model_cls.from_row( fields=fields,
row, exclude_fields=exclude_fields,
related_models=remainder, current_relation_str=relation_str,
previous_model=cls, source_model=source_model,
related_name=related, )
fields=fields, item[model_cls.get_column_name_from_alias(related)] = child
exclude_fields=exclude_fields,
)
item[model_cls.get_column_name_from_alias(first_part)] = child
else:
model_cls = cls.Meta.model_fields[related].to
fields = cls.get_included(fields, related)
exclude_fields = cls.get_excluded(exclude_fields, related)
child = model_cls.from_row(
row,
previous_model=cls,
related_name=related,
fields=fields,
exclude_fields=exclude_fields,
)
item[model_cls.get_column_name_from_alias(related)] = child
return item return item
@ -251,7 +228,7 @@ class Model(NewBaseModel):
All joined tables have prefixes to allow duplicate column names, All joined tables have prefixes to allow duplicate column names,
as well as duplicated joins to the same table from multiple different tables. as well as duplicated joins to the same table from multiple different tables.
Extracted fields populates the item dict later used to construct a Model. Extracted fields populates the related dict later used to construct a Model.
Used in Model.from_row and PrefetchQuery._populate_rows methods. Used in Model.from_row and PrefetchQuery._populate_rows methods.

View File

@ -194,7 +194,9 @@ class QueryClause:
previous_model = through_field.through previous_model = through_field.through
part2 = through_field.default_target_field_name() # type: ignore part2 = through_field.default_target_field_name() # type: ignore
manager = model_cls.Meta.alias_manager manager = model_cls.Meta.alias_manager
table_prefix = manager.resolve_relation_alias(previous_model, part2) table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=part2
)
model_cls = model_cls.Meta.model_fields[part].to model_cls = model_cls.Meta.model_fields[part].to
previous_model = model_cls previous_model = model_cls
return select_related, table_prefix, model_cls return select_related, table_prefix, model_cls

View File

@ -1,8 +1,8 @@
from collections import OrderedDict from collections import OrderedDict
from typing import ( from typing import (
Any,
Dict, Dict,
List, List,
NamedTuple,
Optional, Optional,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
@ -14,24 +14,13 @@ from typing import (
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
from ormar.fields import ManyToManyField # noqa I100 from ormar.fields import BaseField, ManyToManyField # noqa I100
from ormar.relations import AliasManager from ormar.relations import AliasManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
class JoinParameters(NamedTuple):
"""
Named tuple that holds set of parameters passed during join construction.
"""
prev_model: Type["Model"]
previous_alias: str
from_table: str
model_cls: Type["Model"]
class SqlJoin: class SqlJoin:
def __init__( # noqa: CFQ002 def __init__( # noqa: CFQ002
self, self,
@ -42,7 +31,12 @@ class SqlJoin:
exclude_fields: Optional[Union[Set, Dict]], exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List], order_columns: Optional[List],
sorted_orders: OrderedDict, sorted_orders: OrderedDict,
main_model: Type["Model"],
related_models: Any = None,
own_alias: str = "",
) -> None: ) -> None:
self.own_alias = own_alias
self.related_models = related_models or []
self.used_aliases = used_aliases self.used_aliases = used_aliases
self.select_from = select_from self.select_from = select_from
self.columns = columns self.columns = columns
@ -50,18 +44,17 @@ class SqlJoin:
self.exclude_fields = exclude_fields self.exclude_fields = exclude_fields
self.order_columns = order_columns self.order_columns = order_columns
self.sorted_orders = sorted_orders self.sorted_orders = sorted_orders
self.main_model = main_model
@staticmethod @property
def alias_manager(model_cls: Type["Model"]) -> AliasManager: def alias_manager(self) -> AliasManager:
""" """
Shortcut for ormars model AliasManager stored on Meta. Shortcut for ormar's model AliasManager stored on Meta.
:param model_cls: ormar Model class
:type model_cls: Type[Model]
:return: alias manager from model's Meta :return: alias manager from model's Meta
:rtype: AliasManager :rtype: AliasManager
""" """
return model_cls.Meta.alias_manager return self.main_model.Meta.alias_manager
@staticmethod @staticmethod
def on_clause( def on_clause(
@ -86,33 +79,32 @@ class SqlJoin:
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}") return text(f"{left_part}={right_part}")
@staticmethod def process_deeper_join(
def update_inclusions( self, related_name: str, model_cls: Type["Model"], remainder: Any, alias: str,
model_cls: Type["Model"], ) -> None:
fields: Optional[Union[Set, Dict]], sql_join = SqlJoin(
exclude_fields: Optional[Union[Set, Dict]], used_aliases=self.used_aliases,
nested_name: str, select_from=self.select_from,
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]: columns=self.columns,
""" fields=self.main_model.get_excluded(self.fields, related_name),
Extract nested fields and exclude_fields if applicable. exclude_fields=self.main_model.get_excluded(
self.exclude_fields, related_name
:param model_cls: ormar model class ),
:type model_cls: Type["Model"] order_columns=self.order_columns,
:param fields: fields to include sorted_orders=self.sorted_orders,
:type fields: Optional[Union[Set, Dict]] main_model=model_cls,
:param exclude_fields: fields to exclude related_models=remainder,
:type exclude_fields: Optional[Union[Set, Dict]] own_alias=alias,
:param nested_name: name of the nested field )
:type nested_name: str (
:return: updated exclude and include fields from nested objects self.used_aliases,
:rtype: Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]] self.select_from,
""" self.columns,
fields = model_cls.get_included(fields, nested_name) self.sorted_orders,
exclude_fields = model_cls.get_excluded(exclude_fields, nested_name) ) = sql_join.build_join(related_name)
return fields, exclude_fields
def build_join( # noqa: CCR001 def build_join( # noqa: CCR001
self, item: str, join_parameters: JoinParameters self, related: str
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]: ) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
""" """
Main external access point for building a join. Main external access point for building a join.
@ -120,59 +112,61 @@ class SqlJoin:
handles switching to through models for m2m relations, returns updated lists of handles switching to through models for m2m relations, returns updated lists of
used_aliases and sort_orders. used_aliases and sort_orders.
:param item: string with join definition :param related: string with join definition
:type item: str :type related: str
:param join_parameters: parameters from previous/ current join
:type join_parameters: JoinParameters
:return: list of used aliases, select from, list of aliased columns, sort orders :return: list of used aliases, select from, list of aliased columns, sort orders
:rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict] :rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict]
""" """
fields = self.fields target_field = self.main_model.Meta.model_fields[related]
exclude_fields = self.exclude_fields prev_model = self.main_model
# TODO: Finish refactoring here
if issubclass(target_field, ManyToManyField):
new_part = self.process_m2m_related_name_change(
target_field=target_field, related=related
)
self._replace_many_to_many_order_by_columns(related, new_part)
for index, part in enumerate(item.split("__")): model_cls = target_field.through
if issubclass( alias = self.alias_manager.resolve_relation_alias(
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField from_model=prev_model, relation_name=related
)
if alias not in self.used_aliases:
self._process_join(
model_cls=model_cls,
related=related,
alias=alias,
target_field=target_field,
)
related = new_part
self.own_alias = alias
prev_model = model_cls
target_field = target_field.through.Meta.model_fields[related]
model_cls = target_field.to
alias = model_cls.Meta.alias_manager.resolve_relation_alias(
from_model=prev_model, relation_name=related
)
if alias not in self.used_aliases:
self._process_join(
model_cls=model_cls,
prev_model=prev_model,
related=related,
alias=alias,
target_field=target_field,
)
for related_name in self.related_models:
remainder = None
if (
isinstance(self.related_models, dict)
and self.related_models[related_name]
): ):
_fields = join_parameters.model_cls.Meta.model_fields remainder = self.related_models[related_name]
target_field = _fields[part] self.process_deeper_join(
if ( related_name=related_name,
target_field.self_reference model_cls=model_cls,
and part == target_field.self_reference_primary remainder=remainder,
): alias=alias,
new_part = target_field.default_source_field_name() # type: ignore
else:
new_part = target_field.default_target_field_name() # type: ignore
self._switch_many_to_many_order_columns(part, new_part)
if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters(
part=part,
join_params=join_parameters,
is_multi=True,
fields=fields,
exclude_fields=exclude_fields,
)
part = new_part
if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters(
part=part,
join_params=join_parameters,
fields=fields,
exclude_fields=exclude_fields,
) )
return ( return (
@ -182,65 +176,44 @@ class SqlJoin:
self.sorted_orders, self.sorted_orders,
) )
def _build_join_parameters( @staticmethod
self, def process_m2m_related_name_change(
part: str, target_field: Type[ManyToManyField], related: str, reverse: bool = False
join_params: JoinParameters, ) -> str:
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
is_multi: bool = False,
) -> JoinParameters:
""" """
Updates used_aliases to not join multiple times to the same table. Extracts relation name to link join through the Through model declared on
Updates join parameters with new values. relation field.
:param part: part of the join str definition Changes the same names in order_by queries if they are present.
:type part: str
:param join_params: parameters from previous/ current join :param reverse: flag if it's on_clause lookup - use reverse fields
:type join_params: JoinParameters :type reverse: bool
:param fields: fields to include :param target_field: relation field
:type fields: Optional[Union[Set, Dict]] :type target_field: Type[ManyToManyField]
:param exclude_fields: fields to exclude :param related: name of the relation
:type exclude_fields: Optional[Union[Set, Dict]] :type related: str
:param is_multi: flag if the relation is m2m :return: new relation name switched to through model field
:type is_multi: bool :rtype: str
:return: updated join parameters
:rtype: ormar.queryset.join.JoinParameters
""" """
if is_multi: is_primary_self_ref = (
model_cls = join_params.model_cls.Meta.model_fields[part].through target_field.self_reference
else: and related == target_field.self_reference_primary
model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.Meta.table.name
alias = model_cls.Meta.alias_manager.resolve_relation_alias(
join_params.prev_model, part
) )
if alias not in self.used_aliases: if (is_primary_self_ref and not reverse) or (
self._process_join( not is_primary_self_ref and reverse
join_params=join_params, ):
is_multi=is_multi, new_part = target_field.default_source_field_name() # type: ignore
model_cls=model_cls, else:
part=part, new_part = target_field.default_target_field_name() # type: ignore
alias=alias, return new_part
fields=fields,
exclude_fields=exclude_fields,
)
previous_alias = alias
from_table = to_table
prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _process_join( # noqa: CFQ002 def _process_join( # noqa: CFQ002
self, self,
join_params: JoinParameters,
is_multi: bool,
model_cls: Type["Model"], model_cls: Type["Model"],
part: str, related: str,
alias: str, alias: str,
fields: Optional[Union[Set, Dict]], target_field: Type[BaseField],
exclude_fields: Optional[Union[Set, Dict]], prev_model: Type["Model"] = None,
) -> None: ) -> None:
""" """
Resolves to and from column names and table names. Resolves to and from column names and table names.
@ -255,63 +228,53 @@ class SqlJoin:
Process order_by causes for non m2m relations. Process order_by causes for non m2m relations.
:param join_params: parameters from previous/ current join
:type join_params: JoinParameters
:param is_multi: flag if it's m2m relation
:type is_multi: bool
:param model_cls: :param model_cls:
:type model_cls: ormar.models.metaclass.ModelMetaclass :type model_cls: ormar.models.metaclass.ModelMetaclass
:param part: name of the field used in join :param related: name of the field used in join
:type part: str :type related: str
:param alias: alias of the current join :param alias: alias of the current join
:type alias: str :type alias: str
:param fields: fields to include
:type fields: Optional[Union[Set, Dict]]
:param exclude_fields: fields to exclude
:type exclude_fields: Optional[Union[Set, Dict]]
""" """
to_table = model_cls.Meta.table.name to_table = model_cls.Meta.table.name
to_key, from_key = self.get_to_and_from_keys( to_key, from_key = self.get_to_and_from_keys(related, target_field)
join_params, is_multi, model_cls, part
) prev_model = prev_model or self.main_model
on_clause = self.on_clause( on_clause = self.on_clause(
previous_alias=join_params.previous_alias, previous_alias=self.own_alias,
alias=alias, alias=alias,
from_clause=f"{join_params.from_table}.{from_key}", from_clause=f"{prev_model.Meta.tablename}.{from_key}",
to_clause=f"{to_table}.{to_key}", to_clause=f"{to_table}.{to_key}",
) )
target_table = self.alias_manager(model_cls).prefixed_table_name( target_table = self.alias_manager.prefixed_table_name(alias, to_table)
alias, to_table
)
self.select_from = sqlalchemy.sql.outerjoin( self.select_from = sqlalchemy.sql.outerjoin(
self.select_from, target_table, on_clause self.select_from, target_table, on_clause
) )
pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname) pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname)
if not is_multi: if not issubclass(target_field, ManyToManyField):
self.get_order_bys( self.get_order_bys(
alias=alias, alias=alias,
to_table=to_table, to_table=to_table,
pkname_alias=pkname_alias, pkname_alias=pkname_alias,
part=part, part=related,
model_cls=model_cls, model_cls=model_cls,
) )
self_related_fields = model_cls.own_table_columns( self_related_fields = model_cls.own_table_columns(
model=model_cls, model=model_cls,
fields=fields, fields=self.fields,
exclude_fields=exclude_fields, exclude_fields=self.exclude_fields,
use_alias=True, use_alias=True,
) )
self.columns.extend( self.columns.extend(
self.alias_manager(model_cls).prefixed_columns( self.alias_manager.prefixed_columns(
alias, model_cls.Meta.table, self_related_fields alias, model_cls.Meta.table, self_related_fields
) )
) )
self.used_aliases.append(alias) self.used_aliases.append(alias)
def _switch_many_to_many_order_columns(self, part: str, new_part: str) -> None: def _replace_many_to_many_order_by_columns(self, part: str, new_part: str) -> None:
""" """
Substitutes the name of the relation with actual model name in m2m order bys. Substitutes the name of the relation with actual model name in m2m order bys.
@ -325,7 +288,7 @@ class SqlJoin:
x.split("__") for x in self.order_columns if "__" in x x.split("__") for x in self.order_columns if "__" in x
] ]
for condition in split_order_columns: for condition in split_order_columns:
if condition[-2] == part or condition[-2][1:] == part: if self._check_if_condition_apply(condition, part):
condition[-2] = condition[-2].replace(part, new_part) condition[-2] = condition[-2].replace(part, new_part)
self.order_columns = [x for x in self.order_columns if "__" not in x] + [ self.order_columns = [x for x in self.order_columns if "__" not in x] + [
"__".join(x) for x in split_order_columns "__".join(x) for x in split_order_columns
@ -413,51 +376,34 @@ class SqlJoin:
order = text(f"{alias}_{to_table}.{pkname_alias}") order = text(f"{alias}_{to_table}.{pkname_alias}")
self.sorted_orders[f"{alias}.{pkname_alias}"] = order self.sorted_orders[f"{alias}.{pkname_alias}"] = order
@staticmethod
def get_to_and_from_keys( def get_to_and_from_keys(
join_params: JoinParameters, self, related: str, target_field: Type[BaseField]
is_multi: bool,
model_cls: Type["Model"],
part: str,
) -> Tuple[str, str]: ) -> Tuple[str, str]:
""" """
Based on the relation type, name of the relation and previous models and parts Based on the relation type, name of the relation and previous models and parts
stored in JoinParameters it resolves the current to and from keys, which are stored in JoinParameters it resolves the current to and from keys, which are
different for ManyToMany relation, ForeignKey and reverse part of relations. different for ManyToMany relation, ForeignKey and reverse related of relations.
:param join_params: parameters from previous/ current join :param target_field: relation field
:type join_params: JoinParameters :type target_field: Type[ForeignKeyField]
:param is_multi: flag if the relation is of m2m type :param related: name of the current relation join
:type is_multi: bool :type related: str
:param model_cls: ormar model class
:type model_cls: Type[Model]
:param part: name of the current relation join
:type part: str
:return: to key and from key :return: to key and from key
:rtype: Tuple[str, str] :rtype: Tuple[str, str]
""" """
if is_multi: if issubclass(target_field, ManyToManyField):
target_field = join_params.model_cls.Meta.model_fields[part] to_key = self.process_m2m_related_name_change(
if ( target_field=target_field, related=related, reverse=True
target_field.self_reference
and part == target_field.self_reference_primary
):
to_key = target_field.default_target_field_name() # type: ignore
else:
to_key = target_field.default_source_field_name() # type: ignore
from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname
) )
from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname)
elif join_params.prev_model.Meta.model_fields[part].virtual: elif target_field.virtual:
to_field = join_params.prev_model.Meta.model_fields[part].get_related_name() to_field = target_field.get_related_name()
to_key = model_cls.get_column_alias(to_field) to_key = target_field.to.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias( from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname)
join_params.prev_model.Meta.pkname
)
else: else:
to_key = model_cls.get_column_alias(model_cls.Meta.pkname) to_key = target_field.to.get_column_alias(target_field.to.Meta.pkname)
from_key = join_params.prev_model.get_column_alias(part) from_key = self.main_model.get_column_alias(related)
return to_key, from_key return to_key, from_key

View File

@ -526,7 +526,7 @@ class PrefetchQuery:
query_target = target_field.through query_target = target_field.through
select_related = [target_name] select_related = [target_name]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_alias( table_prefix = target_field.to.Meta.alias_manager.resolve_relation_alias(
query_target, target_name from_model=query_target, relation_name=target_name
) )
self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix
@ -551,14 +551,14 @@ class PrefetchQuery:
@staticmethod @staticmethod
def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict: def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict:
""" """
Extract nested part of select_related dictionary to extract models nested Extract nested related of select_related dictionary to extract models nested
deeper on related model and already loaded in select related query. deeper on related model and already loaded in select related query.
:param related: name of the relation :param related: name of the relation
:type related: str :type related: str
:param select_dict: dictionary of select related models in main query :param select_dict: dictionary of select related models in main query
:type select_dict: Dict :type select_dict: Dict
:return: dictionary with nested part of select related :return: dictionary with nested related of select related
:rtype: Dict :rtype: Dict
""" """
return ( return (

View File

@ -6,8 +6,9 @@ import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar.models.helpers.models import group_related_list
from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery
from ormar.queryset.join import JoinParameters, SqlJoin from ormar.queryset.join import SqlJoin
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -140,14 +141,16 @@ class Query:
else: else:
self.select_from = self.table self.select_from = self.table
# TODO: Refactor to convert to nested dict like in from_row in model
self._select_related.sort(key=lambda item: (item, -len(item))) self._select_related.sort(key=lambda item: (item, -len(item)))
related_models = group_related_list(self._select_related)
for item in self._select_related: for related in related_models:
join_parameters = JoinParameters( fields = self.model_cls.get_included(self.fields, related)
self.model_cls, "", self.table.name, self.model_cls exclude_fields = self.model_cls.get_excluded(self.exclude_fields, related)
) remainder = None
fields = self.model_cls.get_included(self.fields, item) if isinstance(related_models, dict) and related_models[related]:
exclude_fields = self.model_cls.get_excluded(self.exclude_fields, item) remainder = related_models[related]
sql_join = SqlJoin( sql_join = SqlJoin(
used_aliases=self.used_aliases, used_aliases=self.used_aliases,
select_from=self.select_from, select_from=self.select_from,
@ -156,6 +159,8 @@ class Query:
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
order_columns=self.order_columns, order_columns=self.order_columns,
sorted_orders=self.sorted_orders, sorted_orders=self.sorted_orders,
main_model=self.model_cls,
related_models=remainder,
) )
( (
@ -163,14 +168,14 @@ class Query:
self.select_from, self.select_from,
self.columns, self.columns,
self.sorted_orders, self.sorted_orders,
) = sql_join.build_join(item, join_parameters) ) = sql_join.build_join(related)
expr = sqlalchemy.sql.select(self.columns) expr = sqlalchemy.sql.select(self.columns)
expr = expr.select_from(self.select_from) expr = expr.select_from(self.select_from)
expr = self._apply_expression_modifiers(expr) expr = self._apply_expression_modifiers(expr)
# print(expr.compile(compile_kwargs={"literal_binds": True})) # print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters() self._reset_query_parameters()
return expr return expr

View File

@ -113,6 +113,19 @@ class AliasManager:
if child_key not in self._aliases_new: if child_key not in self._aliases_new:
self._aliases_new[child_key] = get_table_alias() self._aliases_new[child_key] = get_table_alias()
def add_alias(self, alias_key: str) -> str:
"""
Adds alias to the dictionary of aliases under given key.
:param alias_key: key of relation to generate alias for
:type alias_key: str
:return: generated alias
:rtype: str
"""
alias = get_table_alias()
self._aliases_new[alias_key] = alias
return alias
def resolve_relation_alias( def resolve_relation_alias(
self, from_model: Type["Model"], relation_name: str self, from_model: Type["Model"], relation_name: str
) -> str: ) -> str:

View File

@ -127,7 +127,7 @@ class RelationProxy(list):
self, item: "Model", keep_reversed: bool = True self, item: "Model", keep_reversed: bool = True
) -> None: ) -> None:
""" """
Removes the item from relation with parent. Removes the related from relation with parent.
Through models are automatically deleted for m2m relations. Through models are automatically deleted for m2m relations.

View File

@ -190,7 +190,7 @@ async def test_m2m_self_forwardref_relation(cleanup):
# await steve.friends.add(billy) # await steve.friends.add(billy)
billy_check = await Child.objects.select_related( billy_check = await Child.objects.select_related(
["friends", "favourite_game", "least_favourite_game",] ["friends", "favourite_game", "least_favourite_game"]
).get(name="Billy") ).get(name="Billy")
assert len(billy_check.friends) == 2 assert len(billy_check.friends) == 2
assert billy_check.friends[0].name == "Kate" assert billy_check.friends[0].name == "Kate"
@ -200,5 +200,6 @@ async def test_m2m_self_forwardref_relation(cleanup):
kate_check = await Child.objects.select_related(["also_friends",]).get( kate_check = await Child.objects.select_related(["also_friends",]).get(
name="Kate" name="Kate"
) )
assert len(kate_check.also_friends) == 1 assert len(kate_check.also_friends) == 1
assert kate_check.also_friends[0].name == "Billy" assert kate_check.also_friends[0].name == "Billy"

View File

@ -280,7 +280,7 @@ async def test_sort_order_on_many_to_many():
assert users[1].cars[3].name == "Buggy" assert users[1].cars[3].name == "Buggy"
users = ( users = (
await User.objects.select_related(["cars", "cars__factory"]) await User.objects.select_related(["cars__factory"])
.order_by(["-cars__factory__name", "cars__name"]) .order_by(["-cars__factory__name", "cars__name"])
.all() .all()
) )

View File

@ -116,9 +116,7 @@ async def test_selecting_subset():
) )
all_cars = ( all_cars = (
await Car.objects.select_related( await Car.objects.select_related(["manufacturer__hq__nicks"])
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
)
.fields( .fields(
[ [
"id", "id",
@ -132,9 +130,7 @@ async def test_selecting_subset():
) )
all_cars2 = ( all_cars2 = (
await Car.objects.select_related( await Car.objects.select_related(["manufacturer__hq__nicks"])
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
)
.fields( .fields(
{ {
"id": ..., "id": ...,
@ -149,9 +145,7 @@ async def test_selecting_subset():
) )
all_cars3 = ( all_cars3 = (
await Car.objects.select_related( await Car.objects.select_related(["manufacturer__hq__nicks"])
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
)
.fields( .fields(
{ {
"id": ..., "id": ...,