change queryset to work with column and table aliases

This commit is contained in:
collerek
2020-08-07 13:20:16 +02:00
parent 6efb56a2a0
commit 62475a1949
6 changed files with 194 additions and 25 deletions

BIN
.coverage

Binary file not shown.

View File

@ -2,6 +2,7 @@ import datetime
import decimal import decimal
from typing import Optional, List from typing import Optional, List
import orm
import sqlalchemy import sqlalchemy
from pydantic import Json from pydantic import Json
from pydantic.fields import ModelField from pydantic.fields import ModelField
@ -192,9 +193,14 @@ class ForeignKey(BaseField):
return to_column.get_column_type() return to_column.get_column_type()
def expand_relationship(self, value, child): def expand_relationship(self, value, child):
if not isinstance(value, (self.to, dict, int, str)): if not isinstance(value, (self.to, dict, int, str, list)) or (
isinstance(value, orm.models.Model) and not isinstance(value, self.to)):
raise RelationshipInstanceError( raise RelationshipInstanceError(
'Relationship model can be build only from orm.Model, dict and integer or string (pk).') 'Relationship model can be build only from orm.Model, dict and integer or string (pk).')
if isinstance(value, list) and not isinstance(value, self.to):
model = [self.expand_relationship(val, child) for val in value]
return model
if isinstance(value, self.to): if isinstance(value, self.to):
model = value model = value
elif isinstance(value, dict): elif isinstance(value, dict):

View File

@ -41,8 +41,8 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename
if field.primary_key: if field.primary_key:
pkname = field_name pkname = field_name
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's' reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's'
relation_name = name + '_' + field.to.__name__.lower() relation_name = name.lower().title() + '_' + field.to.__name__.lower()
relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename) relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename)
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
return pkname, columns, model_fields return pkname, columns, model_fields
@ -88,6 +88,8 @@ class ModelMetaclass(type):
attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)
attrs['__model_fields__'] = model_fields attrs['__model_fields__'] = model_fields
attrs['_orm_relationship_manager'] = relationship_manager
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
@ -105,13 +107,13 @@ class Model(list, metaclass=ModelMetaclass):
__fields__: Dict[str, pydantic.fields.ModelField] __fields__: Dict[str, pydantic.fields.ModelField]
__pydantic_model__: Type[BaseModel] __pydantic_model__: Type[BaseModel]
__pkname__: str __pkname__: str
_orm_relationship_manager: RelationshipManager
objects = qry.QuerySet() objects = qry.QuerySet()
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
self._orm_id: str = uuid.uuid4().hex self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False self._orm_saved: bool = False
self._orm_relationship_manager: RelationshipManager = relationship_manager
self.values: Optional[BaseModel] = None self.values: Optional[BaseModel] = None
if "pk" in kwargs: if "pk" in kwargs:
@ -129,7 +131,11 @@ class Model(list, metaclass=ModelMetaclass):
value = json.dumps(value) value = json.dumps(value)
except TypeError: # pragma no cover except TypeError: # pragma no cover
pass pass
value = self.__model_fields__[key].expand_relationship(value, self) value = self.__model_fields__[key].expand_relationship(value, self)
relation_key = self.__class__.__name__.title() + '_' + key
if not self._orm_relationship_manager.contains(relation_key, self):
setattr(self.values, key, value) setattr(self.values, key, value)
else: else:
super().__setattr__(key, value) super().__setattr__(key, value)
@ -152,25 +158,36 @@ class Model(list, metaclass=ModelMetaclass):
def __eq__(self, other): def __eq__(self, other):
return self.values.dict() == other.values.dict() return self.values.dict() == other.values.dict()
def __same__(self, other):
assert self.__class__ == other.__class__
return self._orm_id == other._orm_id or (
self.values is not None and other.values is not None and self.pk == other.pk)
def __repr__(self): # pragma no cover def __repr__(self): # pragma no cover
return self.values.__repr__() return self.values.__repr__()
@classmethod @classmethod
def from_row(cls, row, select_related: List = None) -> 'Model': def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model':
item = {} item = {}
select_related = select_related or [] select_related = select_related or []
table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name)
previous_table = cls.__table__.name
for related in select_related: for related in select_related:
if "__" in related: if "__" in related:
first_part, remainder = related.split("__", 1) first_part, remainder = related.split("__", 1)
model_cls = cls.__model_fields__[first_part].to model_cls = cls.__model_fields__[first_part].to
item[first_part] = model_cls.from_row(row, select_related=[remainder]) child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table)
item[first_part] = child
else: else:
model_cls = cls.__model_fields__[related].to model_cls = cls.__model_fields__[related].to
item[related] = model_cls.from_row(row) child = model_cls.from_row(row, previous_table=previous_table)
item[related] = child
for column in cls.__table__.columns: for column in cls.__table__.columns:
if column.name not in item: if column.name not in item:
item[column.name] = row[column] item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}']
return cls(**item) return cls(**item)
@ -202,7 +219,14 @@ class Model(list, metaclass=ModelMetaclass):
return self.__table__.primary_key.columns.values()[0] return self.__table__.primary_key.columns.values()[0]
def dict(self) -> Dict: def dict(self) -> Dict:
return self.values.dict() dict_instance = self.values.dict()
for field in self.extract_related_names():
nested_model = getattr(self, field)
if isinstance(nested_model, list):
dict_instance[field] = [x.dict() for x in nested_model]
else:
dict_instance[field] = nested_model.dict() if nested_model is not None else {}
return dict_instance
def from_dict(self, value_dict: Dict) -> None: def from_dict(self, value_dict: Dict) -> None:
for key, value in value_dict.items(): for key, value in value_dict.items():

View File

@ -1,6 +1,7 @@
from typing import List, TYPE_CHECKING from typing import List, TYPE_CHECKING, Type
import sqlalchemy import sqlalchemy
from sqlalchemy import text
import orm import orm
from orm.exceptions import NoMatch, MultipleMatches from orm.exceptions import NoMatch, MultipleMatches
@ -24,13 +25,14 @@ FILTER_OPERATORS = {
class QuerySet: class QuerySet:
ESCAPE_CHARACTERS = ['%', '_'] ESCAPE_CHARACTERS = ['%', '_']
def __init__(self, model_cls: 'Model' = None, filter_clauses: List = None, select_related: List = None, def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None,
limit_count: int = None, offset: int = None): limit_count: int = None, offset: int = None):
self.model_cls = model_cls self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count self.limit_count = limit_count
self.query_offset = offset self.query_offset = offset
self.aliases_dict = dict()
def __get__(self, instance, owner): def __get__(self, instance, owner):
return self.__class__(model_cls=owner) return self.__class__(model_cls=owner)
@ -43,19 +45,56 @@ class QuerySet:
def table(self): def table(self):
return self.model_cls.__table__ return self.model_cls.__table__
def prefixed_columns(self, alias, table):
    """Build one aliased select fragment per column of ``table``.

    Each fragment is a ``text`` clause reading
    ``<alias>_<table.name>.<column.name> as <alias>_<column.name>``.

    :param alias: prefix put in front of both the table name and the
        resulting column label
    :param table: object exposing ``name`` and an iterable ``columns``
        whose items expose ``name``
    :return: list of ``text`` clauses, in the order of ``table.columns``
    """
    fragments = []
    for column in table.columns:
        qualified_column = f'{alias}_{table.name}.{column.name}'
        column_label = f'{alias}_{column.name}'
        fragments.append(text(f'{qualified_column} as {column_label}'))
    return fragments
def prefixed_table_name(self, alias, name):
    """Return a ``text`` clause naming table ``name`` as ``<alias>_<name>``.

    The produced fragment is ``<name> <alias>_<name>``, i.e. the table
    name followed by its prefixed alias.
    """
    aliased_name = f'{alias}_{name}'
    return text(f'{name} {aliased_name}')
def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key):
    """Build the join condition between a source table and an aliased target.

    The result is a ``text`` clause of the form
    ``<alias>_<to_table>.<to_key>=<prefix><from_table>.<from_key>`` where
    ``<prefix>`` is ``<previous_alias>_`` when ``previous_alias`` is truthy
    and empty otherwise.
    """
    if previous_alias:
        source_prefix = previous_alias + "_"
    else:
        source_prefix = ""
    target_side = f'{alias}_{to_table}.{to_key}'
    source_side = f'{source_prefix}{from_table}.{from_key}'
    return text(f'{target_side}={source_side}')
def build_select_expression(self): def build_select_expression(self):
tables = [self.table] tables = [self.table]
columns = list(self.table.columns)
order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')]
select_from = self.table select_from = self.table
for item in self._select_related: for item in self._select_related:
previous_alias = ''
from_table = self.table.name
prev_model = self.model_cls
model_cls = self.model_cls model_cls = self.model_cls
select_from = self.table
for part in item.split("__"):
model_cls = model_cls.__model_fields__[part].to
select_from = sqlalchemy.sql.join(select_from, model_cls.__table__)
tables.append(model_cls.__table__)
expr = sqlalchemy.sql.select(tables) for part in item.split("__"):
model_cls = model_cls.__model_fields__[part].to
to_table = model_cls.__table__.name
alias = model_cls._orm_relationship_manager.resolve_relation_join(from_table, to_table)
if prev_model.__model_fields__[part].virtual:
# TODO: change the key lookup
to_key = prev_model.__name__.lower()
from_key = model_cls.__pkname__
else:
to_key = model_cls.__pkname__
from_key = part
on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key)
target_table = self.prefixed_table_name(alias, to_table)
select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause)
tables.append(model_cls.__table__)
order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}'))
columns.extend(self.prefixed_columns(alias, model_cls.__table__))
previous_alias = alias
from_table = to_table
prev_model = model_cls
expr = sqlalchemy.sql.select(columns)
expr = expr.select_from(select_from) expr = expr.select_from(select_from)
if self.filter_clauses: if self.filter_clauses:
@ -71,6 +110,9 @@ class QuerySet:
if self.query_offset: if self.query_offset:
expr = expr.offset(self.query_offset) expr = expr.offset(self.query_offset)
for order in order_bys:
expr = expr.order_by(order)
print(expr.compile(compile_kwargs={"literal_binds": True})) print(expr.compile(compile_kwargs={"literal_binds": True}))
return expr return expr
@ -83,6 +125,7 @@ class QuerySet:
kwargs[pk_name] = kwargs.pop("pk") kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items(): for key, value in kwargs.items():
table_prefix = ''
if "__" in key: if "__" in key:
parts = key.split("__") parts = key.split("__")
@ -106,14 +149,22 @@ class QuerySet:
# Walk the relationships to the actual model class # Walk the relationships to the actual model class
# against which the comparison is being made. # against which the comparison is being made.
previous_table = model_cls.__tablename__
for part in related_parts: for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__
table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table,
current_table)
model_cls = model_cls.__model_fields__[part].to model_cls = model_cls.__model_fields__[part].to
previous_table = current_table
print(table_prefix)
table = model_cls.__table__
column = model_cls.__table__.columns[field_name] column = model_cls.__table__.columns[field_name]
else: else:
op = "exact" op = "exact"
column = self.table.columns[key] column = self.table.columns[key]
table = self.table
# Map the operation code onto SQLAlchemy's ColumnElement # Map the operation code onto SQLAlchemy's ColumnElement
# https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
@ -134,6 +185,13 @@ class QuerySet:
clause = getattr(column, op_attr)(value) clause = getattr(column, op_attr)(value)
clause.modifiers['escape'] = '\\' if has_escaped_character else None clause.modifiers['escape'] = '\\' if has_escaped_character else None
clause_text = str(clause.compile(compile_kwargs={"literal_binds": True}))
alias = f'{table_prefix}_' if table_prefix else ''
aliased_name = f'{alias}{table.name}.{column.name}'
clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name)
clause = text(clause_text)
filter_clauses.append(clause) filter_clauses.append(clause)
return self.__class__( return self.__class__(
@ -212,11 +270,36 @@ class QuerySet:
expr = self.build_select_expression() expr = self.build_select_expression()
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
return [ result_rows = [
self.model_cls.from_row(row, select_related=self._select_related) self.model_cls.from_row(row, select_related=self._select_related)
for row in rows for row in rows
] ]
result_rows = self.merge_result_rows(result_rows)
return result_rows
@classmethod
def merge_result_rows(cls, result_rows):
    """Collapse consecutive rows sharing a primary key into one instance.

    Rows are walked in order; a row whose ``pk`` equals the ``pk`` of the
    most recently kept instance is folded into that instance through
    ``cls.merge_two_instances``, otherwise it starts a new entry.  Only
    adjacent duplicates are merged, so the caller is expected to pass rows
    already grouped by primary key.

    The input list itself is left untouched: the merged instance is stored
    in the output list.  (Writing it back into ``result_rows[-1]`` would
    overwrite the last, not yet visited, input row and silently drop it.)

    :param result_rows: list of model instances exposing ``pk``
    :return: new list with one instance per run of equal primary keys
    """
    merged_rows = []
    for model in result_rows:
        if merged_rows and model.pk == merged_rows[-1].pk:
            merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
        else:
            merged_rows.append(model)
    return merged_rows
@classmethod
def merge_two_instances(cls, one: 'Model', other: 'Model'):
    """Fold the related objects of ``one`` into ``other`` and return ``other``.

    Every field named in ``one.__model_fields__`` is inspected:

    * when the value on ``one`` is a model instance and its ``pk`` equals
      the ``pk`` of the matching value on ``other``, the two nested models
      are merged recursively and the result is stored on ``other``;
    * when the value on ``one`` is a plain list (not a model, which is
      itself a ``list`` subclass), it is concatenated in front of the list
      held by ``other`` and stored on ``other``;
    * any other value is left alone.

    ``other`` is modified in place.

    :param one: instance whose related objects are being merged in
    :param other: instance that accumulates the merged relations
    :return: ``other``
    """
    for field in one.__model_fields__:
        # Read the attribute once; it is needed for both the type checks
        # and the merge itself.
        one_value = getattr(one, field)
        if isinstance(one_value, orm.models.Model):
            other_value = getattr(other, field)
            if one_value.pk == other_value.pk:
                setattr(other, field, cls.merge_two_instances(one_value, other_value))
        elif isinstance(one_value, list):
            setattr(other, field, one_value + getattr(other, field))
    return other
async def create(self, **kwargs): async def create(self, **kwargs):
new_kwargs = dict(**kwargs) new_kwargs = dict(**kwargs)

View File

@ -2,7 +2,7 @@ import pprint
import string import string
import uuid import uuid
from random import choices from random import choices
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, List
from weakref import proxy from weakref import proxy
from sqlalchemy import text from sqlalchemy import text
@ -40,6 +40,7 @@ class RelationshipManager:
self._relations[reverse_key] = get_relation_config('reverse', table_name, field) self._relations[reverse_key] = get_relation_config('reverse', table_name, field)
def deregister(self, model: 'Model'): def deregister(self, model: 'Model'):
print(f'deregistering {model.__class__.__name__}, {model._orm_id}')
for rel_type in self._relations.keys(): for rel_type in self._relations.keys():
if model.__class__.__name__.lower() in rel_type.lower(): if model.__class__.__name__.lower() in rel_type.lower():
if model._orm_id in self._relations[rel_type]: if model._orm_id in self._relations[rel_type]:
@ -54,8 +55,25 @@ class RelationshipManager:
child, parent = parent, proxy(child) child, parent = parent, proxy(child)
else: else:
child = proxy(child) child = proxy(child)
self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(child) print(
self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent) f'setting up relationship, {parent_id}, {child_id}, '
f'{parent.__class__.__name__}, {child.__class__.__name__}, '
f'{parent.pk if parent.values is not None else None}, '
f'{child.pk if child.values is not None else None}')
parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, [])
self.append_related_model(parents_list, child)
children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, [])
self.append_related_model(children_list, parent)
def append_related_model(self, relations_list: List['Model'], model: 'Model'):
    """Append ``model`` to ``relations_list`` unless an equivalent entry exists.

    Each existing entry is asked, via its ``__same__`` method, whether it
    represents ``model``.  Entries that raise ``ReferenceError`` are
    ignored (presumably weak proxies whose target is gone).  The list is
    modified in place and nothing is returned.
    """
    already_present = False
    for existing in relations_list:
        try:
            already_present = existing.__same__(model)
        except ReferenceError:
            # Treat an unreachable entry as "not the same" and keep looking.
            already_present = False
        if already_present:
            break
    if not already_present:
        relations_list.append(model)
def contains(self, relations_key: str, object: 'Model'): def contains(self, relations_key: str, object: 'Model'):
if relations_key in self._relations: if relations_key in self._relations:
@ -69,6 +87,12 @@ class RelationshipManager:
return self._relations[relations_key][object._orm_id][0] return self._relations[relations_key][object._orm_id][0]
return self._relations[relations_key][object._orm_id] return self._relations[relations_key][object._orm_id]
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
    """Look up the table alias registered for a join between two tables.

    The registered relations are scanned in insertion order and the
    ``'table_alias'`` of the first one whose ``'source_table'`` equals
    ``from_table`` and whose ``'target_table'`` equals ``to_table`` is
    returned.

    :return: the matching alias, or ``''`` when no relation matches
    """
    matching_aliases = (
        relation['table_alias']
        for relation in self._relations.values()
        if relation['source_table'] == from_table and relation['target_table'] == to_table
    )
    return next(matching_aliases, '')
def __str__(self): # pragma no cover def __str__(self): # pragma no cover
return pprint.pformat(self._relations, indent=4, width=1) return pprint.pformat(self._relations, indent=4, width=1)

View File

@ -9,6 +9,15 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
class Department(orm.Model):
    # Test fixture model stored in the "departments" table; SchoolClass
    # refers to it through its ``department`` foreign key.
    __tablename__ = "departments"
    # Shares the module-level metadata and database objects with the other
    # test models.
    __metadata__ = metadata
    __database__ = database

    # Integer primary key.
    id = orm.Integer(primary_key=True)
    # Department name, declared with length=100.
    name = orm.String(length=100)
class SchoolClass(orm.Model): class SchoolClass(orm.Model):
__tablename__ = "schoolclasses" __tablename__ = "schoolclasses"
__metadata__ = metadata __metadata__ = metadata
@ -16,6 +25,7 @@ class SchoolClass(orm.Model):
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True)
name = orm.String(length=100) name = orm.String(length=100)
department = orm.ForeignKey(Department)
class Category(orm.Model): class Category(orm.Model):
@ -60,7 +70,8 @@ def create_test_database():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(): async def test_model_multiple_instances_of_same_table_in_schema():
async with database: async with database:
class1 = await SchoolClass.objects.create(name="Math") department = await Department.objects.create(name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign") category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic") category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1) await Student.objects.create(name="Jane", category=category, schoolclass=class1)
@ -80,7 +91,8 @@ async def test_model_multiple_instances_of_same_table_in_schema():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_right_tables_join(): async def test_right_tables_join():
async with database: async with database:
class1 = await SchoolClass.objects.create(name="Math") department = await Department.objects.create(name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign") category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic") category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1) await Student.objects.create(name="Jane", category=category, schoolclass=class1)
@ -89,5 +101,25 @@ async def test_right_tables_join():
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all()
assert classes[0].name == 'Math' assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane' assert classes[0].students[0].name == 'Jane'
breakpoint() assert classes[0].teachers[0].category.name == 'Domestic'
assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load()
assert classes[0].students[0].category.name == 'Foreign'
@pytest.mark.asyncio
async def test_multiple_reverse_related_objects():
async with database:
department = await Department.objects.create(name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Student.objects.create(name="Jack", category=category, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all()
assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane'
assert classes[0].teachers[0].category.name == 'Domestic' assert classes[0].teachers[0].category.name == 'Domestic'