From 22c4a0619c8ad834b81d56553b10b0185754cb7b Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 08:59:36 +0200 Subject: [PATCH] fix some code smells --- .coverage | Bin 53248 -> 53248 bytes README.md | 50 +++++++++--------- orm/queryset.py | 92 +++++++++++++++++++-------------- orm/relations.py | 44 +++++++++------- requirements.txt | 7 ++- tests/test_same_table_joins.py | 49 +++++++----------- 6 files changed, 127 insertions(+), 115 deletions(-) diff --git a/.coverage b/.coverage index 6b4bedb1fcbc2347fae84cb3ec780f161553b1dd..0de5a489869598a0ebded3ec7d3ded637646d989 100644 GIT binary patch delta 170 zcmV;b09F5hpaX!Q1F$MD2RJ%4IXW;fvoSB^P!BN=8V?c=1rFd2$PT9tk+TsHZ4MR` z4g>)SDh_V*=AAe1dC$-P=LgSs{*$SWb6yn)1OW*w2!6%91j!^2M*#6$_}Xxk+TsHZ4MX~ z4g>)SE)H(<=AAe1dC$*(2hV@+{O sqlalchemy.Table: return self.model_cls.__table__ - def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]: + @staticmethod + def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: return [ text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") for column in table.columns ] - def prefixed_table_name(self, alias: str, name: str) -> text: + @staticmethod + def prefixed_table_name(alias: str, name: str) -> text: return text(f"{name} {alias}_{name}") def on_clause( @@ -91,7 +93,7 @@ class QuerySet: f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}' ) - def build_join_parameters( + def _build_join_parameters( self, part: str, join_params: JoinParameters ) -> JoinParameters: model_cls = join_params.model_cls.__model_fields__[part].to @@ -137,12 +139,12 @@ class QuerySet: return JoinParameters(prev_model, previous_alias, from_table, model_cls) @staticmethod - def field_is_a_foreign_key_and_no_circular_reference( + def _field_is_a_foreign_key_and_no_circular_reference( field: BaseField, field_name: str, rel_part: str ) -> bool: return isinstance(field, ForeignKey) and field_name not in rel_part - def field_qualifies_to_deeper_search( + def _field_qualifies_to_deeper_search( self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) @@ -155,7 +157,7 @@ class QuerySet: or (partial_match and not already_checked) ) or not nested - def extract_auto_required_relations( + def _extract_auto_required_relations( self, join_params: JoinParameters, rel_part: str = "", @@ -163,7 +165,7 @@ class QuerySet: parent_virtual: bool = False, ) -> None: for field_name, field in join_params.prev_model.__model_fields__.items(): - if self.field_is_a_foreign_key_and_no_circular_reference( + if self._field_is_a_foreign_key_and_no_circular_reference( field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name @@ -171,7 +173,7 @@ class QuerySet: if rel_part not in self._select_related: self.auto_related.append("__".join(rel_part.split("__")[:-1])) rel_part = "" - elif self.field_qualifies_to_deeper_search( + elif self._field_qualifies_to_deeper_search( field, parent_virtual, nested, rel_part ): join_params = JoinParameters( @@ -180,7 +182,7 @@ class QuerySet: join_params.from_table, join_params.prev_model, ) - self.extract_auto_required_relations( + self._extract_auto_required_relations( join_params=join_params, rel_part=rel_part, nested=True, @@ -189,6 +191,41 @@ class QuerySet: else: rel_part = "" + def _include_auto_related_models(self) -> None: + if self.auto_related: + new_joins = [] + for join in self._select_related: + if not any([x.startswith(join) for x in self.auto_related]): + new_joins.append(join) + self._select_related = new_joins + self.auto_related + + def _apply_expression_modifiers( + self, expr: sqlalchemy.sql.select + ) -> sqlalchemy.sql.select: + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + for order in self.order_bys: + expr = expr.order_by(order) + return expr + + def _reset_query_parameters(self) -> None: + self.select_from = None + self.columns = None + self.order_bys = None + self.auto_related = [] + self.used_aliases = [] + def build_select_expression(self) -> sqlalchemy.sql.select: self.columns = list(self.table.columns) self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] @@ -207,14 +244,9 @@ class QuerySet: start_params = JoinParameters( self.model_cls, "", self.table.name, self.model_cls ) - self.extract_auto_required_relations(start_params) - if self.auto_related: - new_joins = [] - for join in self._select_related: - if not any([x.startswith(join) for x in self.auto_related]): - new_joins.append(join) - self._select_related = new_joins + self.auto_related - self._select_related.sort(key=lambda item: (-len(item), item)) + self._extract_auto_required_relations(start_params) + self._include_auto_related_models() + self._select_related.sort(key=lambda item: (-len(item), item)) for item in self._select_related: join_parameters = JoinParameters( @@ -222,34 +254,15 @@ class QuerySet: ) for part in item.split("__"): - join_parameters = self.build_join_parameters(part, join_parameters) + join_parameters = self._build_join_parameters(part, join_parameters) expr = sqlalchemy.sql.select(self.columns) expr = expr.select_from(self.select_from) - if self.filter_clauses: - if len(self.filter_clauses) == 1: - clause = self.filter_clauses[0] - else: - clause = sqlalchemy.sql.and_(*self.filter_clauses) - expr = expr.where(clause) - - if self.limit_count: - expr = expr.limit(self.limit_count) - - if self.query_offset: - expr = expr.offset(self.query_offset) - - for order in self.order_bys: - expr = expr.order_by(order) + expr = self._apply_expression_modifiers(expr) # print(expr.compile(compile_kwargs={"literal_binds": True})) - - self.select_from = None - self.columns = None - self.order_bys = None - self.auto_related = [] - self.used_aliases = [] + self._reset_query_parameters() return expr @@ -298,7 +311,6 @@ class QuerySet: model_cls = model_cls.__model_fields__[part].to previous_table = current_table - # print(table_prefix) table = model_cls.__table__ column = model_cls.__table__.columns[field_name] diff --git a/orm/relations.py b/orm/relations.py index b5741e1..3232c8c 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -71,43 +71,47 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - parents_list = self._relations[ - parent_name.lower().title() + "_" + child_name + "s" - ].setdefault(parent_id, []) + + parent_relation_name = parent_name.lower().title() + "_" + child_name + "s" + parents_list = self._relations[parent_relation_name].setdefault(parent_id, []) self.append_related_model(parents_list, child) - children_list = self._relations[ - child_name.lower().title() + "_" + parent_name - ].setdefault(child_id, []) + + child_relation_name = child_name.lower().title() + "_" + parent_name + children_list = self._relations[child_relation_name].setdefault(child_id, []) self.append_related_model(children_list, parent) - def append_related_model( - self, relations_list: List["Model"], model: "Model" - ) -> None: - for x in relations_list: + @staticmethod + def append_related_model(relations_list: List["Model"], model: "Model") -> None: + for relation_child in relations_list: try: - if x.__same__(model): + if relation_child.__same__(model): return except ReferenceError: continue relations_list.append(model) - def contains(self, relations_key: str, object: "Model") -> bool: + def contains(self, relations_key: str, instance: "Model") -> bool: if relations_key in self._relations: - return object._orm_id in self._relations[relations_key] + return instance._orm_id in self._relations[relations_key] return False - def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]: + def get( + self, relations_key: str, instance: "Model" + ) -> Union["Model", List["Model"]]: if relations_key in self._relations: - if object._orm_id in self._relations[relations_key]: + if instance._orm_id in self._relations[relations_key]: if self._relations[relations_key]["type"] == "primary": - return self._relations[relations_key][object._orm_id][0] - return self._relations[relations_key][object._orm_id] + return self._relations[relations_key][instance._orm_id][0] + return self._relations[relations_key][instance._orm_id] def resolve_relation_join(self, from_table: str, to_table: str) -> str: - for k, v in self._relations.items(): - if v["source_table"] == from_table and v["target_table"] == to_table: - return self._relations[k]["table_alias"] + for relation_name, relation in self._relations.items(): + if ( + relation["source_table"] == from_table + and relation["target_table"] == to_table + ): + return self._relations[relation_name]["table_alias"] return "" def __str__(self) -> str: # pragma no cover diff --git a/requirements.txt b/requirements.txt index fde4a73..807e704 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,9 @@ flake8-black flake8-bugbear flake8-import-order flake8-bandit -flake8-annotations \ No newline at end of file +flake8-annotations +flake8-builtins +flake8-variables-names +flake8-cognitive-complexity +flake8-functions +flake8-expression-complexity \ No newline at end of file diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index ca4e7b7..ea85409 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -67,16 +67,24 @@ def create_test_database(): metadata.drop_all(engine) -@pytest.mark.asyncio -async def test_model_multiple_instances_of_same_table_in_schema(): - async with database: - department = await Department.objects.create(id=1, name='Math Department') - class1 = await SchoolClass.objects.create(name="Math", department=department) - category = await Category.objects.create(name="Foreign") - category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) +@pytest.fixture() +async def init_relation(): + department = await Department.objects.create(id=1, name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Student.objects.create(name="Jack", category=category2, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + yield + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + +@pytest.mark.asyncio +async def test_model_multiple_instances_of_same_table_in_schema(init_relation): + async with database: classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' assert classes[0].students[0].name == 'Jane' @@ -92,18 +100,9 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio -async def test_right_tables_join(): +async def test_right_tables_join(init_relation): async with database: - department = await Department.objects.create(id=1, name='Math Department') - class1 = await SchoolClass.objects.create(name="Math", department=department) - category = await Category.objects.create(name="Foreign") - category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() - assert classes[0].name == 'Math' - assert classes[0].students[0].name == 'Jane' assert classes[0].teachers[0].category.name == 'Domestic' assert classes[0].students[0].category.name is None @@ -112,17 +111,9 @@ async def test_right_tables_join(): @pytest.mark.asyncio -async def test_multiple_reverse_related_objects(): +async def test_multiple_reverse_related_objects(init_relation): async with database: - department = await Department.objects.create(id=1, name='Math Department') - class1 = await SchoolClass.objects.create(name="Math", department=department) - category = await Category.objects.create(name="Foreign") - category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Student.objects.create(name="Jack", category=category, schoolclass=class1) - await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' - assert classes[0].students[0].name == 'Jane' + assert classes[0].students[1].name == 'Jack' assert classes[0].teachers[0].category.name == 'Domestic'