From 62475a1949ba8aebf634961943116f6aabf1129c Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 7 Aug 2020 13:20:16 +0200 Subject: [PATCH] change queryset to work with column and table aliases --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 8 ++- orm/models.py | 42 +++++++++++--- orm/queryset.py | 101 ++++++++++++++++++++++++++++++--- orm/relations.py | 30 +++++++++- tests/test_same_table_joins.py | 38 ++++++++++++- 6 files changed, 194 insertions(+), 25 deletions(-) diff --git a/.coverage b/.coverage index ad89c9e7f4286dd936db7dc6b3d5e305f71bfba6..955ea24ea0b5b017d6968d3bd41402d31daa70d6 100644 GIT binary patch delta 280 zcmV+z0q6dJpaX!Q1F!}l3I_lWLJu(y6%Plq5fJwdlMpW=ATcvKF)=zcIS&E@8y9wE za&u{KZZ$44E-`^S0R$a_Ym-thSOhI_WMZ?WFJVv~yZ{gR59$x#54{hw4~Gv|4?_<- z4=xW04)YGu4!pAw5SR`Y6b=Ld2`Ub5^X8p5?|IM1e*qW(fB}=@j&Wxa2m}EMCJ25N z^Go~v9*+0c`+W&62Lu5LUI*IpKlX2TxB0$S+wXSUw!6o>9e?N8yY1`e{=E0+&)x02 zPklCk@|-y;1_S{KRtA3e^3L!7x%_P2{A-c{y|<69@7Z3qe|N#b?(UPSj~g3h|GL}V eyYK%eo4wg=Hk-|6v)Sx7zq|YKdcU*ak03y~Tzbp^ delta 255 zcmV)74`dHh4@(a|4;c>w z4&n~avk?%h4iga$1OW*o4sP@2&3oSS@$WeQ9+T3JaaRur1OW*k2!0jwOZ$C)_kIZ^ z2Lu5LN(XBBAN$+g{q45x?s&Ij8GE<;xj*my`Ez&s?o*yK2PXyu0SQnBe)sau@Bg{{ zY~K8vWT5Zuv%Y7`f4ksdcay4*8ym8J-R None: self._orm_id: str = uuid.uuid4().hex self._orm_saved: bool = False - self._orm_relationship_manager: RelationshipManager = relationship_manager self.values: Optional[BaseModel] = None if "pk" in kwargs: @@ -129,8 +131,12 @@ class Model(list, metaclass=ModelMetaclass): value = json.dumps(value) except TypeError: # pragma no cover pass + value = self.__model_fields__[key].expand_relationship(value, self) - setattr(self.values, key, value) + + relation_key = self.__class__.__name__.title() + '_' + key + if not self._orm_relationship_manager.contains(relation_key, self): + setattr(self.values, key, value) else: super().__setattr__(key, value) @@ -152,25 +158,36 @@ class Model(list, metaclass=ModelMetaclass): def __eq__(self, other): return self.values.dict() == other.values.dict() + def __same__(self, other): + assert self.__class__ == other.__class__ + return self._orm_id == other._orm_id or ( + self.values is not None and other.values is not None and self.pk == other.pk) + def __repr__(self): # pragma no cover return self.values.__repr__() @classmethod - def from_row(cls, row, select_related: List = None) -> 'Model': + def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model': + item = {} select_related = select_related or [] + + table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name) + previous_table = cls.__table__.name for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) model_cls = cls.__model_fields__[first_part].to - item[first_part] = model_cls.from_row(row, select_related=[remainder]) + child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table) + item[first_part] = child else: model_cls = cls.__model_fields__[related].to - item[related] = model_cls.from_row(row) + child = model_cls.from_row(row, previous_table=previous_table) + item[related] = child for column in cls.__table__.columns: if column.name not in item: - item[column.name] = row[column] + item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}'] return cls(**item) @@ -202,7 +219,14 @@ class Model(list, metaclass=ModelMetaclass): return self.__table__.primary_key.columns.values()[0] def dict(self) -> Dict: - return self.values.dict() + dict_instance = self.values.dict() + for field in self.extract_related_names(): + nested_model = getattr(self, field) + if isinstance(nested_model, list): + dict_instance[field] = [x.dict() for x in nested_model] + else: + dict_instance[field] = nested_model.dict() if nested_model is not None else {} + return dict_instance def from_dict(self, value_dict: Dict) -> None: for key, value in value_dict.items(): diff --git a/orm/queryset.py b/orm/queryset.py index d0d6b2c..1f01d7d 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -1,6 +1,7 @@ -from typing import List, TYPE_CHECKING +from typing import List, TYPE_CHECKING, Type import sqlalchemy +from sqlalchemy import text import orm from orm.exceptions import NoMatch, MultipleMatches @@ -24,13 +25,14 @@ FILTER_OPERATORS = { class QuerySet: ESCAPE_CHARACTERS = ['%', '_'] - def __init__(self, model_cls: 'Model' = None, filter_clauses: List = None, select_related: List = None, + def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, limit_count: int = None, offset: int = None): self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses self._select_related = [] if select_related is None else select_related self.limit_count = limit_count self.query_offset = offset + self.aliases_dict = dict() def __get__(self, instance, owner): return self.__class__(model_cls=owner) @@ -43,19 +45,56 @@ class QuerySet: def table(self): return self.model_cls.__table__ + def prefixed_columns(self, alias, table): + return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}') + for column in table.columns] + + def prefixed_table_name(self, alias, name): + return text(f'{name} {alias}_{name}') + + def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key): + return text(f'{alias}_{to_table}.{to_key}=' + f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') + def build_select_expression(self): tables = [self.table] + columns = list(self.table.columns) + order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] select_from = self.table for item in self._select_related: + previous_alias = '' + from_table = self.table.name + prev_model = self.model_cls model_cls = self.model_cls - select_from = self.table - for part in item.split("__"): - model_cls = model_cls.__model_fields__[part].to - select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) - tables.append(model_cls.__table__) - expr = sqlalchemy.sql.select(tables) + for part in item.split("__"): + + model_cls = model_cls.__model_fields__[part].to + to_table = model_cls.__table__.name + + alias = model_cls._orm_relationship_manager.resolve_relation_join(from_table, to_table) + + if prev_model.__model_fields__[part].virtual: + # TODO: change the key lookup + to_key = prev_model.__name__.lower() + from_key = model_cls.__pkname__ + else: + to_key = model_cls.__pkname__ + from_key = part + + on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key) + target_table = self.prefixed_table_name(alias, to_table) + select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause) + tables.append(model_cls.__table__) + order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) + columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + + previous_alias = alias + from_table = to_table + prev_model = model_cls + + expr = sqlalchemy.sql.select(columns) expr = expr.select_from(select_from) if self.filter_clauses: @@ -71,6 +110,9 @@ class QuerySet: if self.query_offset: expr = expr.offset(self.query_offset) + for order in order_bys: + expr = expr.order_by(order) + print(expr.compile(compile_kwargs={"literal_binds": True})) return expr @@ -83,6 +125,7 @@ class QuerySet: kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): + table_prefix = '' if "__" in key: parts = key.split("__") @@ -106,14 +149,22 @@ class QuerySet: # Walk the relationships to the actual model class # against which the comparison is being made. + previous_table = model_cls.__tablename__ for part in related_parts: + current_table = model_cls.__model_fields__[part].to.__tablename__ + table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table, + current_table) model_cls = model_cls.__model_fields__[part].to + previous_table = current_table + print(table_prefix) + table = model_cls.__table__ column = model_cls.__table__.columns[field_name] else: op = "exact" column = self.table.columns[key] + table = self.table # Map the operation code onto SQLAlchemy's ColumnElement # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement @@ -134,6 +185,13 @@ class QuerySet: clause = getattr(column, op_attr)(value) clause.modifiers['escape'] = '\\' if has_escaped_character else None + + clause_text = str(clause.compile(compile_kwargs={"literal_binds": True})) + alias = f'{table_prefix}_' if table_prefix else '' + aliased_name = f'{alias}{table.name}.{column.name}' + clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name) + clause = text(clause_text) + filter_clauses.append(clause) return self.__class__( @@ -212,11 +270,36 @@ class QuerySet: expr = self.build_select_expression() rows = await self.database.fetch_all(expr) - return [ + result_rows = [ self.model_cls.from_row(row, select_related=self._select_related) for row in rows ] + result_rows = self.merge_result_rows(result_rows) + + return result_rows + + @classmethod + def merge_result_rows(cls, result_rows): + merged_rows = [] + for index, model in enumerate(result_rows): + if index > 0 and model.pk == result_rows[index - 1].pk: + result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) + else: + merged_rows.append(model) + return merged_rows + + @classmethod + def merge_two_instances(cls, one: 'Model', other: 'Model'): + for field in one.__model_fields__.keys(): + print(field, one.dict(), other.dict()) + if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model): + setattr(other, field, getattr(one, field) + getattr(other, field)) + elif isinstance(getattr(one, field), orm.models.Model): + if getattr(one, field).pk == getattr(other, field).pk: + setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field))) + return other + async def create(self, **kwargs): new_kwargs = dict(**kwargs) diff --git a/orm/relations.py b/orm/relations.py index 31bb520..0888284 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -2,7 +2,7 @@ import pprint import string import uuid from random import choices -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List from weakref import proxy from sqlalchemy import text @@ -40,6 +40,7 @@ class RelationshipManager: self._relations[reverse_key] = get_relation_config('reverse', table_name, field) def deregister(self, model: 'Model'): + print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): if model.__class__.__name__.lower() in rel_type.lower(): if model._orm_id in self._relations[rel_type]: @@ -54,8 +55,25 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(child) - self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent) + print( + f'setting up relationship, {parent_id}, {child_id}, ' + f'{parent.__class__.__name__}, {child.__class__.__name__}, ' + f'{parent.pk if parent.values is not None else None}, ' + f'{child.pk if child.values is not None else None}') + parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) + self.append_related_model(parents_list, child) + children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) + self.append_related_model(children_list, parent) + + def append_related_model(self, relations_list: List['Model'], model: 'Model'): + for x in relations_list: + try: + if x.__same__(model): + return + except ReferenceError: + continue + + relations_list.append(model) def contains(self, relations_key: str, object: 'Model'): if relations_key in self._relations: @@ -69,6 +87,12 @@ class RelationshipManager: return self._relations[relations_key][object._orm_id][0] return self._relations[relations_key][object._orm_id] + def resolve_relation_join(self, from_table: str, to_table: str) -> str: + for k, v in self._relations.items(): + if v['source_table'] == from_table and v['target_table'] == to_table: + return self._relations[k]['table_alias'] + return '' + def __str__(self): # pragma no cover return pprint.pformat(self._relations, indent=4, width=1) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 155e6a4..30b34d5 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -9,6 +9,15 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class Department(orm.Model): + __tablename__ = "departments" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + class SchoolClass(orm.Model): __tablename__ = "schoolclasses" __metadata__ = metadata @@ -16,6 +25,7 @@ class SchoolClass(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) + department = orm.ForeignKey(Department) class Category(orm.Model): @@ -60,7 +70,8 @@ def create_test_database(): @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: - class1 = await SchoolClass.objects.create(name="Math") + department = await Department.objects.create(name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") await Student.objects.create(name="Jane", category=category, schoolclass=class1) @@ -80,7 +91,8 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio async def test_right_tables_join(): async with database: - class1 = await SchoolClass.objects.create(name="Math") + department = await Department.objects.create(name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") await Student.objects.create(name="Jane", category=category, schoolclass=class1) @@ -89,5 +101,25 @@ async def test_right_tables_join(): classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' assert classes[0].students[0].name == 'Jane' - breakpoint() + assert classes[0].teachers[0].category.name == 'Domestic' + + assert classes[0].students[0].category.name is None + await classes[0].students[0].category.load() + assert classes[0].students[0].category.name == 'Foreign' + + +@pytest.mark.asyncio +async def test_multiple_reverse_related_objects(): + async with database: + department = await Department.objects.create(name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Student.objects.create(name="Jack", category=category, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + + classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() + assert classes[0].name == 'Math' + assert classes[0].students[0].name == 'Jane' assert classes[0].teachers[0].category.name == 'Domestic'