From 8f179f763f36bc4ed6d3b72a005242afab58d2de Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 7 Aug 2020 19:34:17 +0200 Subject: [PATCH] add preloading of not nullable relations (and all chain inbetween) --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 8 +- orm/queryset.py | 152 ++++++++++++++++++++++++--------- orm/relations.py | 12 +-- tests/test_same_table_joins.py | 9 +- 5 files changed, 128 insertions(+), 53 deletions(-) diff --git a/.coverage b/.coverage index 807412ba39ac227fd469563a159bb1f53c3717b6..1cc206f5c0dcc201cc45878a1905994243a1dd21 100644 GIT binary patch delta 165 zcmV;W09yZmpaX!Q1F$MD2RS-3GdeUmvoSB#P#&%T5BU%358w~G53vu04@(a>4=)cQ z4*d@24#p0yvk?%A4j2{=1OW*y4sP@2oj31!&(D7c&wuay?*Ws~j(tM){(|?{`+W&G z2Lu5LatGSwe|+}0+ud!hZ-4K;bq^1E-+$X~+xWVByxZ|-j=kHye(ukEfBxLvzWdZ~ T_bbnt1Cxc1S^+P!;Ey0c=M+v4 delta 155 zcmV;M0A&AwpaX!Q1F$MD2Q@k}Fgi6dvoSB#P#(Ad5BU%358w~G53vu04^$684>%7i z4+akK4$ls_vk?%K4i^;;1OW*w4sP@2oj31!&(D9y`R|?oJ(I?ceL4O9j`!F5eF-lI z1OW+P2io#K_HTE$>AtVkcDvoS?e6hz$KN^jZu|PVKkxndb9ejhQ@{PMJZBD*g^yYR J3$x&lAV3tJNV@<4 diff --git a/orm/fields.py b/orm/fields.py index 69ba85f..aa40930 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -185,8 +185,8 @@ def create_dummy_instance(fk: Type['Model'], pk: int = None): class ForeignKey(BaseField): - def __init__(self, to, related_name: str = None, nullable: bool = True, virtual: bool = False): - super().__init__(nullable=nullable) + def __init__(self, to, name: str = None, related_name: str = None, nullable: bool = True, virtual: bool = False): + super().__init__(nullable=nullable, name=name) self.virtual = virtual self.related_name = related_name self.to = to @@ -230,6 +230,8 @@ class ForeignKey(BaseField): type_=Optional[child.__pydantic_model__], model_config=child.__pydantic_model__.__config__, class_validators=child.__pydantic_model__.__validators__) - model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) + model.__model_fields__[child_model_name] = ForeignKey(child.__class__, + name=child_model_name, + virtual=True) return model diff --git a/orm/queryset.py b/orm/queryset.py index 2eba123..95ad8f6 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -1,9 +1,10 @@ -from typing import List, TYPE_CHECKING, Type +from typing import List, TYPE_CHECKING, Type, NamedTuple import sqlalchemy from sqlalchemy import text import orm +from orm import ForeignKey from orm.exceptions import NoMatch, MultipleMatches if TYPE_CHECKING: # pragma no cover @@ -22,6 +23,13 @@ FILTER_OPERATORS = { } +class JoinParameters(NamedTuple): + prev_model: Type['Model'] + previous_alias: str + from_table: str + model_cls: Type['Model'] + + class QuerySet: ESCAPE_CHARACTERS = ['%', '_'] @@ -32,7 +40,13 @@ class QuerySet: self._select_related = [] if select_related is None else select_related self.limit_count = limit_count self.query_offset = offset - self.aliases_dict = dict() + + self.auto_related = [] + self.used_aliases = [] + + self.select_from = None + self.columns = None + self.order_bys = None def __get__(self, instance, owner): return self.__class__(model_cls=owner) @@ -56,52 +70,101 @@ class QuerySet: return text(f'{alias}_{to_table}.{to_key}=' f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') + def build_join_parameters(self, part, join_params: JoinParameters): + model_cls = join_params.model_cls.__model_fields__[part].to + to_table = model_cls.__table__.name + + alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table) + # print(f'resolving tables alias from {join_params.from_table}, to: {to_table} -> {alias}') + if alias not in self.used_aliases: + if join_params.prev_model.__model_fields__[part].virtual: + to_key = next((v for k, v in model_cls.__model_fields__.items() + if isinstance(v, ForeignKey) and v.to == join_params.prev_model), None).name + from_key = model_cls.__pkname__ + else: + to_key = model_cls.__pkname__ + from_key = part + + on_clause = self.on_clause(join_params.from_table, to_table, join_params.previous_alias, alias, to_key, + from_key) + target_table = self.prefixed_table_name(alias, to_table) + self.select_from = sqlalchemy.sql.outerjoin(self.select_from, target_table, on_clause) + self.order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) + self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + self.used_aliases.append(alias) + + previous_alias = alias + from_table = to_table + prev_model = model_cls + return JoinParameters(prev_model, previous_alias, from_table, model_cls) + + @staticmethod + def field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part) -> bool: + return isinstance(field, ForeignKey) and field_name not in rel_part + + def field_qualifies_to_deeper_search(self, field, parent_virtual, nested, rel_part) -> bool: + prev_part_of_related = "__".join(rel_part.split("__")[:-1]) + partial_match = any([x.startswith(prev_part_of_related) for x in self._select_related]) + already_checked = any([x.startswith(rel_part) for x in self.auto_related]) + return ((field.virtual and parent_virtual) or (partial_match and not already_checked)) or not nested + + def extract_auto_required_relations(self, join_params: JoinParameters, + rel_part: str = '', nested: bool = False, parent_virtual: bool = False): + # print(f'checking model {join_params.prev_model}, {rel_part}') + for field_name, field in join_params.prev_model.__model_fields__.items(): + # print(f'checking_field {field_name}') + if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part): + rel_part = field_name if not rel_part else rel_part + '__' + field_name + if not field.nullable: + # print(f'field {field_name} is not nullable, appending to auto, curr rel: {rel_part}') + if rel_part not in self._select_related: + self.auto_related.append("__".join(rel_part.split("__")[:-1])) + rel_part = '' + elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part): + # print( + # f'field {field_name} is nullable, going down, curr rel: ' + # f'{rel_part}, nested:{nested}, virtual:{field.virtual}, virtual_par:{parent_virtual}, ' + # f'injoin: {"__".join(rel_part.split("__")[:-1]) in self._select_related}') + join_params = JoinParameters(field.to, join_params.previous_alias, + join_params.from_table, join_params.prev_model) + self.extract_auto_required_relations(join_params=join_params, + rel_part=rel_part, nested=True, parent_virtual=field.virtual) + else: + # print( + # f'field {field_name} is out, going down, curr rel: ' + # f'{rel_part}, nested:{nested}, virtual:{field.virtual}, virtual_par:{parent_virtual}, ' + # f'injoin: {"__".join(rel_part.split("__")[:-1]) in self._select_related}') + rel_part = '' + def build_select_expression(self): - # tables = [self.table] - columns = list(self.table.columns) - order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] - select_from = self.table + self.columns = list(self.table.columns) + self.order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] + self.select_from = self.table for key in self.model_cls.__model_fields__: if not self.model_cls.__model_fields__[key].nullable \ and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \ and key not in self._select_related: - self._select_related.append(key) + self._select_related = [key] + self._select_related + + start_params = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) + self.extract_auto_required_relations(start_params) + if self.auto_related: + new_joins = [] + for join in self._select_related: + if not any([x.startswith(join) for x in self.auto_related]): + new_joins.append(join) + self._select_related = new_joins + self.auto_related + self._select_related.sort(key=lambda item: (-len(item), item)) for item in self._select_related: - previous_alias = '' - from_table = self.table.name - prev_model = self.model_cls - model_cls = self.model_cls + join_parameters = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) for part in item.split("__"): + join_parameters = self.build_join_parameters(part, join_parameters) - model_cls = model_cls.__model_fields__[part].to - to_table = model_cls.__table__.name - - alias = model_cls._orm_relationship_manager.resolve_relation_join(from_table, to_table) - - if prev_model.__model_fields__[part].virtual: - # TODO: change the key lookup - to_key = prev_model.__name__.lower() - from_key = model_cls.__pkname__ - else: - to_key = model_cls.__pkname__ - from_key = part - - on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key) - target_table = self.prefixed_table_name(alias, to_table) - select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause) - # tables.append(model_cls.__table__) - order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) - columns.extend(self.prefixed_columns(alias, model_cls.__table__)) - - previous_alias = alias - from_table = to_table - prev_model = model_cls - - expr = sqlalchemy.sql.select(columns) - expr = expr.select_from(select_from) + expr = sqlalchemy.sql.select(self.columns) + expr = expr.select_from(self.select_from) if self.filter_clauses: if len(self.filter_clauses) == 1: @@ -116,10 +179,17 @@ class QuerySet: if self.query_offset: expr = expr.offset(self.query_offset) - for order in order_bys: + for order in self.order_bys: expr = expr.order_by(order) - print(expr.compile(compile_kwargs={"literal_binds": True})) + # print(expr.compile(compile_kwargs={"literal_binds": True})) + + self.select_from = None + self.columns = None + self.order_bys = None + self.auto_related = [] + self.used_aliases = [] + return expr def filter(self, **kwargs): @@ -163,7 +233,7 @@ class QuerySet: model_cls = model_cls.__model_fields__[part].to previous_table = current_table - print(table_prefix) + # print(table_prefix) table = model_cls.__table__ column = model_cls.__table__.columns[field_name] @@ -298,7 +368,7 @@ class QuerySet: @classmethod def merge_two_instances(cls, one: 'Model', other: 'Model'): for field in one.__model_fields__.keys(): - print(field, one.dict(), other.dict()) + # print(field, one.dict(), other.dict()) if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), orm.models.Model): diff --git a/orm/relations.py b/orm/relations.py index 0888284..a1a6625 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -40,7 +40,7 @@ class RelationshipManager: self._relations[reverse_key] = get_relation_config('reverse', table_name, field) def deregister(self, model: 'Model'): - print(f'deregistering {model.__class__.__name__}, {model._orm_id}') + # print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): if model.__class__.__name__.lower() in rel_type.lower(): if model._orm_id in self._relations[rel_type]: @@ -55,11 +55,11 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - print( - f'setting up relationship, {parent_id}, {child_id}, ' - f'{parent.__class__.__name__}, {child.__class__.__name__}, ' - f'{parent.pk if parent.values is not None else None}, ' - f'{child.pk if child.values is not None else None}') + # print( + # f'setting up relationship, {parent_id}, {child_id}, ' + # f'{parent.__class__.__name__}, {child.__class__.__name__}, ' + # f'{parent.pk if parent.values is not None else None}, ' + # f'{child.pk if child.values is not None else None}') parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) self.append_related_model(parents_list, child) children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index f98e01b..ca4e7b7 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -77,15 +77,18 @@ async def test_model_multiple_instances_of_same_table_in_schema(): await Student.objects.create(name="Jane", category=category, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - classes = await SchoolClass.objects.select_related(['teachers', 'students']).all() + classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' assert classes[0].students[0].name == 'Jane' # related fields of main model are only populated by pk + # unless there is a required foreign key somewhere along the way + # since department is required for schoolclass it was pre loaded (again) # but you can load them anytime - assert classes[0].students[0].schoolclass.name is None - await classes[0].students[0].schoolclass.load() assert classes[0].students[0].schoolclass.name == 'Math' + assert classes[0].students[0].schoolclass.department.name is None + await classes[0].students[0].schoolclass.department.load() + assert classes[0].students[0].schoolclass.department.name == 'Math Department' @pytest.mark.asyncio