change queryset to work with column and table aliases

This commit is contained in:
collerek
2020-08-07 13:20:16 +02:00
parent 6efb56a2a0
commit 62475a1949
6 changed files with 194 additions and 25 deletions

BIN
.coverage

Binary file not shown.

View File

@ -2,6 +2,7 @@ import datetime
import decimal
from typing import Optional, List
import orm
import sqlalchemy
from pydantic import Json
from pydantic.fields import ModelField
@ -192,9 +193,14 @@ class ForeignKey(BaseField):
return to_column.get_column_type()
def expand_relationship(self, value, child):
if not isinstance(value, (self.to, dict, int, str)):
if not isinstance(value, (self.to, dict, int, str, list)) or (
isinstance(value, orm.models.Model) and not isinstance(value, self.to)):
raise RelationshipInstanceError(
'Relationship model can be build only from orm.Model, dict and integer or string (pk).')
if isinstance(value, list) and not isinstance(value, self.to):
model = [self.expand_relationship(val, child) for val in value]
return model
if isinstance(value, self.to):
model = value
elif isinstance(value, dict):

View File

@ -41,8 +41,8 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename
if field.primary_key:
pkname = field_name
if isinstance(field, ForeignKey):
reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's'
relation_name = name + '_' + field.to.__name__.lower()
reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's'
relation_name = name.lower().title() + '_' + field.to.__name__.lower()
relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename)
columns.append(field.get_column(field_name))
return pkname, columns, model_fields
@ -88,6 +88,8 @@ class ModelMetaclass(type):
attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)
attrs['__model_fields__'] = model_fields
attrs['_orm_relationship_manager'] = relationship_manager
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
@ -105,13 +107,13 @@ class Model(list, metaclass=ModelMetaclass):
__fields__: Dict[str, pydantic.fields.ModelField]
__pydantic_model__: Type[BaseModel]
__pkname__: str
_orm_relationship_manager: RelationshipManager
objects = qry.QuerySet()
def __init__(self, *args, **kwargs) -> None:
self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False
self._orm_relationship_manager: RelationshipManager = relationship_manager
self.values: Optional[BaseModel] = None
if "pk" in kwargs:
@ -129,7 +131,11 @@ class Model(list, metaclass=ModelMetaclass):
value = json.dumps(value)
except TypeError: # pragma no cover
pass
value = self.__model_fields__[key].expand_relationship(value, self)
relation_key = self.__class__.__name__.title() + '_' + key
if not self._orm_relationship_manager.contains(relation_key, self):
setattr(self.values, key, value)
else:
super().__setattr__(key, value)
@ -152,25 +158,36 @@ class Model(list, metaclass=ModelMetaclass):
def __eq__(self, other):
return self.values.dict() == other.values.dict()
def __same__(self, other):
assert self.__class__ == other.__class__
return self._orm_id == other._orm_id or (
self.values is not None and other.values is not None and self.pk == other.pk)
def __repr__(self): # pragma no cover
return self.values.__repr__()
@classmethod
def from_row(cls, row, select_related: List = None) -> 'Model':
def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model':
item = {}
select_related = select_related or []
table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name)
previous_table = cls.__table__.name
for related in select_related:
if "__" in related:
first_part, remainder = related.split("__", 1)
model_cls = cls.__model_fields__[first_part].to
item[first_part] = model_cls.from_row(row, select_related=[remainder])
child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table)
item[first_part] = child
else:
model_cls = cls.__model_fields__[related].to
item[related] = model_cls.from_row(row)
child = model_cls.from_row(row, previous_table=previous_table)
item[related] = child
for column in cls.__table__.columns:
if column.name not in item:
item[column.name] = row[column]
item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}']
return cls(**item)
@ -202,7 +219,14 @@ class Model(list, metaclass=ModelMetaclass):
return self.__table__.primary_key.columns.values()[0]
def dict(self) -> Dict:
return self.values.dict()
dict_instance = self.values.dict()
for field in self.extract_related_names():
nested_model = getattr(self, field)
if isinstance(nested_model, list):
dict_instance[field] = [x.dict() for x in nested_model]
else:
dict_instance[field] = nested_model.dict() if nested_model is not None else {}
return dict_instance
def from_dict(self, value_dict: Dict) -> None:
for key, value in value_dict.items():

View File

@ -1,6 +1,7 @@
from typing import List, TYPE_CHECKING
from typing import List, TYPE_CHECKING, Type
import sqlalchemy
from sqlalchemy import text
import orm
from orm.exceptions import NoMatch, MultipleMatches
@ -24,13 +25,14 @@ FILTER_OPERATORS = {
class QuerySet:
ESCAPE_CHARACTERS = ['%', '_']
def __init__(self, model_cls: 'Model' = None, filter_clauses: List = None, select_related: List = None,
def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None,
limit_count: int = None, offset: int = None):
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count
self.query_offset = offset
self.aliases_dict = dict()
def __get__(self, instance, owner):
return self.__class__(model_cls=owner)
@ -43,19 +45,56 @@ class QuerySet:
def table(self):
return self.model_cls.__table__
def prefixed_columns(self, alias, table):
return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}')
for column in table.columns]
def prefixed_table_name(self, alias, name):
return text(f'{name} {alias}_{name}')
def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key):
return text(f'{alias}_{to_table}.{to_key}='
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}')
def build_select_expression(self):
tables = [self.table]
columns = list(self.table.columns)
order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')]
select_from = self.table
for item in self._select_related:
previous_alias = ''
from_table = self.table.name
prev_model = self.model_cls
model_cls = self.model_cls
select_from = self.table
for part in item.split("__"):
model_cls = model_cls.__model_fields__[part].to
select_from = sqlalchemy.sql.join(select_from, model_cls.__table__)
tables.append(model_cls.__table__)
expr = sqlalchemy.sql.select(tables)
for part in item.split("__"):
model_cls = model_cls.__model_fields__[part].to
to_table = model_cls.__table__.name
alias = model_cls._orm_relationship_manager.resolve_relation_join(from_table, to_table)
if prev_model.__model_fields__[part].virtual:
# TODO: change the key lookup
to_key = prev_model.__name__.lower()
from_key = model_cls.__pkname__
else:
to_key = model_cls.__pkname__
from_key = part
on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key)
target_table = self.prefixed_table_name(alias, to_table)
select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause)
tables.append(model_cls.__table__)
order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}'))
columns.extend(self.prefixed_columns(alias, model_cls.__table__))
previous_alias = alias
from_table = to_table
prev_model = model_cls
expr = sqlalchemy.sql.select(columns)
expr = expr.select_from(select_from)
if self.filter_clauses:
@ -71,6 +110,9 @@ class QuerySet:
if self.query_offset:
expr = expr.offset(self.query_offset)
for order in order_bys:
expr = expr.order_by(order)
print(expr.compile(compile_kwargs={"literal_binds": True}))
return expr
@ -83,6 +125,7 @@ class QuerySet:
kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items():
table_prefix = ''
if "__" in key:
parts = key.split("__")
@ -106,14 +149,22 @@ class QuerySet:
# Walk the relationships to the actual model class
# against which the comparison is being made.
previous_table = model_cls.__tablename__
for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__
table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table,
current_table)
model_cls = model_cls.__model_fields__[part].to
previous_table = current_table
print(table_prefix)
table = model_cls.__table__
column = model_cls.__table__.columns[field_name]
else:
op = "exact"
column = self.table.columns[key]
table = self.table
# Map the operation code onto SQLAlchemy's ColumnElement
# https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
@ -134,6 +185,13 @@ class QuerySet:
clause = getattr(column, op_attr)(value)
clause.modifiers['escape'] = '\\' if has_escaped_character else None
clause_text = str(clause.compile(compile_kwargs={"literal_binds": True}))
alias = f'{table_prefix}_' if table_prefix else ''
aliased_name = f'{alias}{table.name}.{column.name}'
clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name)
clause = text(clause_text)
filter_clauses.append(clause)
return self.__class__(
@ -212,11 +270,36 @@ class QuerySet:
expr = self.build_select_expression()
rows = await self.database.fetch_all(expr)
return [
result_rows = [
self.model_cls.from_row(row, select_related=self._select_related)
for row in rows
]
result_rows = self.merge_result_rows(result_rows)
return result_rows
@classmethod
def merge_result_rows(cls, result_rows):
merged_rows = []
for index, model in enumerate(result_rows):
if index > 0 and model.pk == result_rows[index - 1].pk:
result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
else:
merged_rows.append(model)
return merged_rows
@classmethod
def merge_two_instances(cls, one: 'Model', other: 'Model'):
for field in one.__model_fields__.keys():
print(field, one.dict(), other.dict())
if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model):
setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), orm.models.Model):
if getattr(one, field).pk == getattr(other, field).pk:
setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field)))
return other
async def create(self, **kwargs):
new_kwargs = dict(**kwargs)

View File

@ -2,7 +2,7 @@ import pprint
import string
import uuid
from random import choices
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List
from weakref import proxy
from sqlalchemy import text
@ -40,6 +40,7 @@ class RelationshipManager:
self._relations[reverse_key] = get_relation_config('reverse', table_name, field)
def deregister(self, model: 'Model'):
print(f'deregistering {model.__class__.__name__}, {model._orm_id}')
for rel_type in self._relations.keys():
if model.__class__.__name__.lower() in rel_type.lower():
if model._orm_id in self._relations[rel_type]:
@ -54,8 +55,25 @@ class RelationshipManager:
child, parent = parent, proxy(child)
else:
child = proxy(child)
self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(child)
self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent)
print(
f'setting up relationship, {parent_id}, {child_id}, '
f'{parent.__class__.__name__}, {child.__class__.__name__}, '
f'{parent.pk if parent.values is not None else None}, '
f'{child.pk if child.values is not None else None}')
parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, [])
self.append_related_model(parents_list, child)
children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, [])
self.append_related_model(children_list, parent)
def append_related_model(self, relations_list: List['Model'], model: 'Model'):
for x in relations_list:
try:
if x.__same__(model):
return
except ReferenceError:
continue
relations_list.append(model)
def contains(self, relations_key: str, object: 'Model'):
if relations_key in self._relations:
@ -69,6 +87,12 @@ class RelationshipManager:
return self._relations[relations_key][object._orm_id][0]
return self._relations[relations_key][object._orm_id]
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
for k, v in self._relations.items():
if v['source_table'] == from_table and v['target_table'] == to_table:
return self._relations[k]['table_alias']
return ''
def __str__(self): # pragma no cover
return pprint.pformat(self._relations, indent=4, width=1)

View File

@ -9,6 +9,15 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Department(orm.Model):
__tablename__ = "departments"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
class SchoolClass(orm.Model):
__tablename__ = "schoolclasses"
__metadata__ = metadata
@ -16,6 +25,7 @@ class SchoolClass(orm.Model):
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
department = orm.ForeignKey(Department)
class Category(orm.Model):
@ -60,7 +70,8 @@ def create_test_database():
@pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema():
async with database:
class1 = await SchoolClass.objects.create(name="Math")
department = await Department.objects.create(name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
@ -80,7 +91,8 @@ async def test_model_multiple_instances_of_same_table_in_schema():
@pytest.mark.asyncio
async def test_right_tables_join():
async with database:
class1 = await SchoolClass.objects.create(name="Math")
department = await Department.objects.create(name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
@ -89,5 +101,25 @@ async def test_right_tables_join():
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all()
assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane'
breakpoint()
assert classes[0].teachers[0].category.name == 'Domestic'
assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load()
assert classes[0].students[0].category.name == 'Foreign'
@pytest.mark.asyncio
async def test_multiple_reverse_related_objects():
async with database:
department = await Department.objects.create(name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Student.objects.create(name="Jack", category=category, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all()
assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane'
assert classes[0].teachers[0].category.name == 'Domestic'