fix some code smells
This commit is contained in:
@ -68,13 +68,15 @@ class QuerySet:
|
||||
def table(self) -> sqlalchemy.Table:
|
||||
return self.model_cls.__table__
|
||||
|
||||
def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||
@staticmethod
|
||||
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||
return [
|
||||
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
|
||||
for column in table.columns
|
||||
]
|
||||
|
||||
def prefixed_table_name(self, alias: str, name: str) -> text:
|
||||
@staticmethod
|
||||
def prefixed_table_name(alias: str, name: str) -> text:
|
||||
return text(f"{name} {alias}_{name}")
|
||||
|
||||
def on_clause(
|
||||
@ -91,7 +93,7 @@ class QuerySet:
|
||||
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}'
|
||||
)
|
||||
|
||||
def build_join_parameters(
|
||||
def _build_join_parameters(
|
||||
self, part: str, join_params: JoinParameters
|
||||
) -> JoinParameters:
|
||||
model_cls = join_params.model_cls.__model_fields__[part].to
|
||||
@ -137,12 +139,12 @@ class QuerySet:
|
||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||
|
||||
@staticmethod
|
||||
def field_is_a_foreign_key_and_no_circular_reference(
|
||||
def _field_is_a_foreign_key_and_no_circular_reference(
|
||||
field: BaseField, field_name: str, rel_part: str
|
||||
) -> bool:
|
||||
return isinstance(field, ForeignKey) and field_name not in rel_part
|
||||
|
||||
def field_qualifies_to_deeper_search(
|
||||
def _field_qualifies_to_deeper_search(
|
||||
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
|
||||
) -> bool:
|
||||
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
|
||||
@ -155,7 +157,7 @@ class QuerySet:
|
||||
or (partial_match and not already_checked)
|
||||
) or not nested
|
||||
|
||||
def extract_auto_required_relations(
|
||||
def _extract_auto_required_relations(
|
||||
self,
|
||||
join_params: JoinParameters,
|
||||
rel_part: str = "",
|
||||
@ -163,7 +165,7 @@ class QuerySet:
|
||||
parent_virtual: bool = False,
|
||||
) -> None:
|
||||
for field_name, field in join_params.prev_model.__model_fields__.items():
|
||||
if self.field_is_a_foreign_key_and_no_circular_reference(
|
||||
if self._field_is_a_foreign_key_and_no_circular_reference(
|
||||
field, field_name, rel_part
|
||||
):
|
||||
rel_part = field_name if not rel_part else rel_part + "__" + field_name
|
||||
@ -171,7 +173,7 @@ class QuerySet:
|
||||
if rel_part not in self._select_related:
|
||||
self.auto_related.append("__".join(rel_part.split("__")[:-1]))
|
||||
rel_part = ""
|
||||
elif self.field_qualifies_to_deeper_search(
|
||||
elif self._field_qualifies_to_deeper_search(
|
||||
field, parent_virtual, nested, rel_part
|
||||
):
|
||||
join_params = JoinParameters(
|
||||
@ -180,7 +182,7 @@ class QuerySet:
|
||||
join_params.from_table,
|
||||
join_params.prev_model,
|
||||
)
|
||||
self.extract_auto_required_relations(
|
||||
self._extract_auto_required_relations(
|
||||
join_params=join_params,
|
||||
rel_part=rel_part,
|
||||
nested=True,
|
||||
@ -189,6 +191,41 @@ class QuerySet:
|
||||
else:
|
||||
rel_part = ""
|
||||
|
||||
def _include_auto_related_models(self) -> None:
|
||||
if self.auto_related:
|
||||
new_joins = []
|
||||
for join in self._select_related:
|
||||
if not any([x.startswith(join) for x in self.auto_related]):
|
||||
new_joins.append(join)
|
||||
self._select_related = new_joins + self.auto_related
|
||||
|
||||
def _apply_expression_modifiers(
|
||||
self, expr: sqlalchemy.sql.select
|
||||
) -> sqlalchemy.sql.select:
|
||||
if self.filter_clauses:
|
||||
if len(self.filter_clauses) == 1:
|
||||
clause = self.filter_clauses[0]
|
||||
else:
|
||||
clause = sqlalchemy.sql.and_(*self.filter_clauses)
|
||||
expr = expr.where(clause)
|
||||
|
||||
if self.limit_count:
|
||||
expr = expr.limit(self.limit_count)
|
||||
|
||||
if self.query_offset:
|
||||
expr = expr.offset(self.query_offset)
|
||||
|
||||
for order in self.order_bys:
|
||||
expr = expr.order_by(order)
|
||||
return expr
|
||||
|
||||
def _reset_query_parameters(self) -> None:
|
||||
self.select_from = None
|
||||
self.columns = None
|
||||
self.order_bys = None
|
||||
self.auto_related = []
|
||||
self.used_aliases = []
|
||||
|
||||
def build_select_expression(self) -> sqlalchemy.sql.select:
|
||||
self.columns = list(self.table.columns)
|
||||
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
|
||||
@ -207,14 +244,9 @@ class QuerySet:
|
||||
start_params = JoinParameters(
|
||||
self.model_cls, "", self.table.name, self.model_cls
|
||||
)
|
||||
self.extract_auto_required_relations(start_params)
|
||||
if self.auto_related:
|
||||
new_joins = []
|
||||
for join in self._select_related:
|
||||
if not any([x.startswith(join) for x in self.auto_related]):
|
||||
new_joins.append(join)
|
||||
self._select_related = new_joins + self.auto_related
|
||||
self._select_related.sort(key=lambda item: (-len(item), item))
|
||||
self._extract_auto_required_relations(start_params)
|
||||
self._include_auto_related_models()
|
||||
self._select_related.sort(key=lambda item: (-len(item), item))
|
||||
|
||||
for item in self._select_related:
|
||||
join_parameters = JoinParameters(
|
||||
@ -222,34 +254,15 @@ class QuerySet:
|
||||
)
|
||||
|
||||
for part in item.split("__"):
|
||||
join_parameters = self.build_join_parameters(part, join_parameters)
|
||||
join_parameters = self._build_join_parameters(part, join_parameters)
|
||||
|
||||
expr = sqlalchemy.sql.select(self.columns)
|
||||
expr = expr.select_from(self.select_from)
|
||||
|
||||
if self.filter_clauses:
|
||||
if len(self.filter_clauses) == 1:
|
||||
clause = self.filter_clauses[0]
|
||||
else:
|
||||
clause = sqlalchemy.sql.and_(*self.filter_clauses)
|
||||
expr = expr.where(clause)
|
||||
|
||||
if self.limit_count:
|
||||
expr = expr.limit(self.limit_count)
|
||||
|
||||
if self.query_offset:
|
||||
expr = expr.offset(self.query_offset)
|
||||
|
||||
for order in self.order_bys:
|
||||
expr = expr.order_by(order)
|
||||
expr = self._apply_expression_modifiers(expr)
|
||||
|
||||
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
||||
|
||||
self.select_from = None
|
||||
self.columns = None
|
||||
self.order_bys = None
|
||||
self.auto_related = []
|
||||
self.used_aliases = []
|
||||
self._reset_query_parameters()
|
||||
|
||||
return expr
|
||||
|
||||
@ -298,7 +311,6 @@ class QuerySet:
|
||||
model_cls = model_cls.__model_fields__[part].to
|
||||
previous_table = current_table
|
||||
|
||||
# print(table_prefix)
|
||||
table = model_cls.__table__
|
||||
column = model_cls.__table__.columns[field_name]
|
||||
|
||||
|
||||
@ -71,43 +71,47 @@ class RelationshipManager:
|
||||
child, parent = parent, proxy(child)
|
||||
else:
|
||||
child = proxy(child)
|
||||
parents_list = self._relations[
|
||||
parent_name.lower().title() + "_" + child_name + "s"
|
||||
].setdefault(parent_id, [])
|
||||
|
||||
parent_relation_name = parent_name.lower().title() + "_" + child_name + "s"
|
||||
parents_list = self._relations[parent_relation_name].setdefault(parent_id, [])
|
||||
self.append_related_model(parents_list, child)
|
||||
children_list = self._relations[
|
||||
child_name.lower().title() + "_" + parent_name
|
||||
].setdefault(child_id, [])
|
||||
|
||||
child_relation_name = child_name.lower().title() + "_" + parent_name
|
||||
children_list = self._relations[child_relation_name].setdefault(child_id, [])
|
||||
self.append_related_model(children_list, parent)
|
||||
|
||||
def append_related_model(
|
||||
self, relations_list: List["Model"], model: "Model"
|
||||
) -> None:
|
||||
for x in relations_list:
|
||||
@staticmethod
|
||||
def append_related_model(relations_list: List["Model"], model: "Model") -> None:
|
||||
for relation_child in relations_list:
|
||||
try:
|
||||
if x.__same__(model):
|
||||
if relation_child.__same__(model):
|
||||
return
|
||||
except ReferenceError:
|
||||
continue
|
||||
|
||||
relations_list.append(model)
|
||||
|
||||
def contains(self, relations_key: str, object: "Model") -> bool:
|
||||
def contains(self, relations_key: str, instance: "Model") -> bool:
|
||||
if relations_key in self._relations:
|
||||
return object._orm_id in self._relations[relations_key]
|
||||
return instance._orm_id in self._relations[relations_key]
|
||||
return False
|
||||
|
||||
def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]:
|
||||
def get(
|
||||
self, relations_key: str, instance: "Model"
|
||||
) -> Union["Model", List["Model"]]:
|
||||
if relations_key in self._relations:
|
||||
if object._orm_id in self._relations[relations_key]:
|
||||
if instance._orm_id in self._relations[relations_key]:
|
||||
if self._relations[relations_key]["type"] == "primary":
|
||||
return self._relations[relations_key][object._orm_id][0]
|
||||
return self._relations[relations_key][object._orm_id]
|
||||
return self._relations[relations_key][instance._orm_id][0]
|
||||
return self._relations[relations_key][instance._orm_id]
|
||||
|
||||
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
|
||||
for k, v in self._relations.items():
|
||||
if v["source_table"] == from_table and v["target_table"] == to_table:
|
||||
return self._relations[k]["table_alias"]
|
||||
for relation_name, relation in self._relations.items():
|
||||
if (
|
||||
relation["source_table"] == from_table
|
||||
and relation["target_table"] == to_table
|
||||
):
|
||||
return self._relations[relation_name]["table_alias"]
|
||||
return ""
|
||||
|
||||
def __str__(self) -> str: # pragma no cover
|
||||
|
||||
Reference in New Issue
Block a user