diff --git a/.coverage b/.coverage index ad89c9e..955ea24 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/fields.py b/orm/fields.py index 6f9c221..2723b82 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -2,6 +2,7 @@ import datetime import decimal from typing import Optional, List +import orm import sqlalchemy from pydantic import Json from pydantic.fields import ModelField @@ -192,9 +193,14 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship(self, value, child): - if not isinstance(value, (self.to, dict, int, str)): + if not isinstance(value, (self.to, dict, int, str, list)) or ( + isinstance(value, orm.models.Model) and not isinstance(value, self.to)): raise RelationshipInstanceError( 'Relationship model can be build only from orm.Model, dict and integer or string (pk).') + if isinstance(value, list) and not isinstance(value, self.to): + model = [self.expand_relationship(val, child) for val in value] + return model + if isinstance(value, self.to): model = value elif isinstance(value, dict): diff --git a/orm/models.py b/orm/models.py index 508ba00..0d2ff0e 100644 --- a/orm/models.py +++ b/orm/models.py @@ -41,8 +41,8 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename if field.primary_key: pkname = field_name if isinstance(field, ForeignKey): - reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's' - relation_name = name + '_' + field.to.__name__.lower() + reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's' + relation_name = name.lower().title() + '_' + field.to.__name__.lower() relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename) columns.append(field.get_column(field_name)) return pkname, columns, model_fields @@ -88,6 +88,8 @@ class ModelMetaclass(type): attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) attrs['__model_fields__'] = model_fields + attrs['_orm_relationship_manager'] = relationship_manager + new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) @@ -105,13 +107,13 @@ class Model(list, metaclass=ModelMetaclass): __fields__: Dict[str, pydantic.fields.ModelField] __pydantic_model__: Type[BaseModel] __pkname__: str + _orm_relationship_manager: RelationshipManager objects = qry.QuerySet() def __init__(self, *args, **kwargs) -> None: self._orm_id: str = uuid.uuid4().hex self._orm_saved: bool = False - self._orm_relationship_manager: RelationshipManager = relationship_manager self.values: Optional[BaseModel] = None if "pk" in kwargs: @@ -129,8 +131,12 @@ class Model(list, metaclass=ModelMetaclass): value = json.dumps(value) except TypeError: # pragma no cover pass + value = self.__model_fields__[key].expand_relationship(value, self) - setattr(self.values, key, value) + + relation_key = self.__class__.__name__.title() + '_' + key + if not self._orm_relationship_manager.contains(relation_key, self): + setattr(self.values, key, value) else: super().__setattr__(key, value) @@ -152,25 +158,36 @@ class Model(list, metaclass=ModelMetaclass): def __eq__(self, other): return self.values.dict() == other.values.dict() + def __same__(self, other): + assert self.__class__ == other.__class__ + return self._orm_id == other._orm_id or ( + self.values is not None and other.values is not None and self.pk == other.pk) + def __repr__(self): # pragma no cover return self.values.__repr__() @classmethod - def from_row(cls, row, select_related: List = None) -> 'Model': + def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model': + item = {} select_related = select_related or [] + + table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name) + previous_table = cls.__table__.name for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) model_cls = cls.__model_fields__[first_part].to - item[first_part] = model_cls.from_row(row, select_related=[remainder]) + child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table) + item[first_part] = child else: model_cls = cls.__model_fields__[related].to - item[related] = model_cls.from_row(row) + child = model_cls.from_row(row, previous_table=previous_table) + item[related] = child for column in cls.__table__.columns: if column.name not in item: - item[column.name] = row[column] + item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}'] return cls(**item) @@ -202,7 +219,14 @@ class Model(list, metaclass=ModelMetaclass): return self.__table__.primary_key.columns.values()[0] def dict(self) -> Dict: - return self.values.dict() + dict_instance = self.values.dict() + for field in self.extract_related_names(): + nested_model = getattr(self, field) + if isinstance(nested_model, list): + dict_instance[field] = [x.dict() for x in nested_model] + else: + dict_instance[field] = nested_model.dict() if nested_model is not None else {} + return dict_instance def from_dict(self, value_dict: Dict) -> None: for key, value in value_dict.items(): diff --git a/orm/queryset.py b/orm/queryset.py index d0d6b2c..1f01d7d 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -1,6 +1,7 @@ -from typing import List, TYPE_CHECKING +from typing import List, TYPE_CHECKING, Type import sqlalchemy +from sqlalchemy import text import orm from orm.exceptions import NoMatch, MultipleMatches @@ -24,13 +25,14 @@ FILTER_OPERATORS = { class QuerySet: ESCAPE_CHARACTERS = ['%', '_'] - def __init__(self, model_cls: 'Model' = None, filter_clauses: List = None, select_related: List = None, + def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, limit_count: int = None, offset: int = None): self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses self._select_related = [] if select_related is None else select_related self.limit_count = limit_count self.query_offset = offset + self.aliases_dict = dict() def __get__(self, instance, owner): return self.__class__(model_cls=owner) @@ -43,19 +45,56 @@ class QuerySet: def table(self): return self.model_cls.__table__ + def prefixed_columns(self, alias, table): + return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}') + for column in table.columns] + + def prefixed_table_name(self, alias, name): + return text(f'{name} {alias}_{name}') + + def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key): + return text(f'{alias}_{to_table}.{to_key}=' + f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') + def build_select_expression(self): tables = [self.table] + columns = list(self.table.columns) + order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] select_from = self.table for item in self._select_related: + previous_alias = '' + from_table = self.table.name + prev_model = self.model_cls model_cls = self.model_cls - select_from = self.table - for part in item.split("__"): - model_cls = model_cls.__model_fields__[part].to - select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) - tables.append(model_cls.__table__) - expr = sqlalchemy.sql.select(tables) + for part in item.split("__"): + + model_cls = model_cls.__model_fields__[part].to + to_table = model_cls.__table__.name + + alias = model_cls._orm_relationship_manager.resolve_relation_join(from_table, to_table) + + if prev_model.__model_fields__[part].virtual: + # TODO: change the key lookup + to_key = prev_model.__name__.lower() + from_key = model_cls.__pkname__ + else: + to_key = model_cls.__pkname__ + from_key = part + + on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key) + target_table = self.prefixed_table_name(alias, to_table) + select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause) + tables.append(model_cls.__table__) + order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) + columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + + previous_alias = alias + from_table = to_table + prev_model = model_cls + + expr = sqlalchemy.sql.select(columns) expr = expr.select_from(select_from) if self.filter_clauses: @@ -71,6 +110,9 @@ class QuerySet: if self.query_offset: expr = expr.offset(self.query_offset) + for order in order_bys: + expr = expr.order_by(order) + print(expr.compile(compile_kwargs={"literal_binds": True})) return expr @@ -83,6 +125,7 @@ class QuerySet: kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): + table_prefix = '' if "__" in key: parts = key.split("__") @@ -106,14 +149,22 @@ class QuerySet: # Walk the relationships to the actual model class # against which the comparison is being made. + previous_table = model_cls.__tablename__ for part in related_parts: + current_table = model_cls.__model_fields__[part].to.__tablename__ + table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table, + current_table) model_cls = model_cls.__model_fields__[part].to + previous_table = current_table + print(table_prefix) + table = model_cls.__table__ column = model_cls.__table__.columns[field_name] else: op = "exact" column = self.table.columns[key] + table = self.table # Map the operation code onto SQLAlchemy's ColumnElement # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement @@ -134,6 +185,13 @@ class QuerySet: clause = getattr(column, op_attr)(value) clause.modifiers['escape'] = '\\' if has_escaped_character else None + + clause_text = str(clause.compile(compile_kwargs={"literal_binds": True})) + alias = f'{table_prefix}_' if table_prefix else '' + aliased_name = f'{alias}{table.name}.{column.name}' + clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name) + clause = text(clause_text) + filter_clauses.append(clause) return self.__class__( @@ -212,11 +270,36 @@ class QuerySet: expr = self.build_select_expression() rows = await self.database.fetch_all(expr) - return [ + result_rows = [ self.model_cls.from_row(row, select_related=self._select_related) for row in rows ] + result_rows = self.merge_result_rows(result_rows) + + return result_rows + + @classmethod + def merge_result_rows(cls, result_rows): + merged_rows = [] + for index, model in enumerate(result_rows): + if index > 0 and model.pk == result_rows[index - 1].pk: + result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) + else: + merged_rows.append(model) + return merged_rows + + @classmethod + def merge_two_instances(cls, one: 'Model', other: 'Model'): + for field in one.__model_fields__.keys(): + print(field, one.dict(), other.dict()) + if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model): + setattr(other, field, getattr(one, field) + getattr(other, field)) + elif isinstance(getattr(one, field), orm.models.Model): + if getattr(one, field).pk == getattr(other, field).pk: + setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field))) + return other + async def create(self, **kwargs): new_kwargs = dict(**kwargs) diff --git a/orm/relations.py b/orm/relations.py index 31bb520..0888284 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -2,7 +2,7 @@ import pprint import string import uuid from random import choices -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List from weakref import proxy from sqlalchemy import text @@ -40,6 +40,7 @@ class RelationshipManager: self._relations[reverse_key] = get_relation_config('reverse', table_name, field) def deregister(self, model: 'Model'): + print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): if model.__class__.__name__.lower() in rel_type.lower(): if model._orm_id in self._relations[rel_type]: @@ -54,8 +55,25 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(child) - self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent) + print( + f'setting up relationship, {parent_id}, {child_id}, ' + f'{parent.__class__.__name__}, {child.__class__.__name__}, ' + f'{parent.pk if parent.values is not None else None}, ' + f'{child.pk if child.values is not None else None}') + parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) + self.append_related_model(parents_list, child) + children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) + self.append_related_model(children_list, parent) + + def append_related_model(self, relations_list: List['Model'], model: 'Model'): + for x in relations_list: + try: + if x.__same__(model): + return + except ReferenceError: + continue + + relations_list.append(model) def contains(self, relations_key: str, object: 'Model'): if relations_key in self._relations: @@ -69,6 +87,12 @@ class RelationshipManager: return self._relations[relations_key][object._orm_id][0] return self._relations[relations_key][object._orm_id] + def resolve_relation_join(self, from_table: str, to_table: str) -> str: + for k, v in self._relations.items(): + if v['source_table'] == from_table and v['target_table'] == to_table: + return self._relations[k]['table_alias'] + return '' + def __str__(self): # pragma no cover return pprint.pformat(self._relations, indent=4, width=1) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 155e6a4..30b34d5 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -9,6 +9,15 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class Department(orm.Model): + __tablename__ = "departments" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + class SchoolClass(orm.Model): __tablename__ = "schoolclasses" __metadata__ = metadata @@ -16,6 +25,7 @@ class SchoolClass(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) + department = orm.ForeignKey(Department) class Category(orm.Model): @@ -60,7 +70,8 @@ def create_test_database(): @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: - class1 = await SchoolClass.objects.create(name="Math") + department = await Department.objects.create(name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") await Student.objects.create(name="Jane", category=category, schoolclass=class1) @@ -80,7 +91,8 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio async def test_right_tables_join(): async with database: - class1 = await SchoolClass.objects.create(name="Math") + department = await Department.objects.create(name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") await Student.objects.create(name="Jane", category=category, schoolclass=class1) @@ -89,5 +101,25 @@ async def test_right_tables_join(): classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' assert classes[0].students[0].name == 'Jane' - breakpoint() + assert classes[0].teachers[0].category.name == 'Domestic' + + assert classes[0].students[0].category.name is None + await classes[0].students[0].category.load() + assert classes[0].students[0].category.name == 'Foreign' + + +@pytest.mark.asyncio +async def test_multiple_reverse_related_objects(): + async with database: + department = await Department.objects.create(name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Student.objects.create(name="Jack", category=category, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + + classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() + assert classes[0].name == 'Math' + assert classes[0].students[0].name == 'Jane' assert classes[0].teachers[0].category.name == 'Domestic'