Add preloading of non-nullable relations (and of every relation in the chain in between)

This commit is contained in:
collerek
2020-08-07 19:34:17 +02:00
parent 3929dd6d73
commit 8f179f763f
5 changed files with 128 additions and 53 deletions

BIN
.coverage

Binary file not shown.

View File

@ -185,8 +185,8 @@ def create_dummy_instance(fk: Type['Model'], pk: int = None):
class ForeignKey(BaseField): class ForeignKey(BaseField):
def __init__(self, to, related_name: str = None, nullable: bool = True, virtual: bool = False): def __init__(self, to, name: str = None, related_name: str = None, nullable: bool = True, virtual: bool = False):
super().__init__(nullable=nullable) super().__init__(nullable=nullable, name=name)
self.virtual = virtual self.virtual = virtual
self.related_name = related_name self.related_name = related_name
self.to = to self.to = to
@ -230,6 +230,8 @@ class ForeignKey(BaseField):
type_=Optional[child.__pydantic_model__], type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__, model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__) class_validators=child.__pydantic_model__.__validators__)
model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) model.__model_fields__[child_model_name] = ForeignKey(child.__class__,
name=child_model_name,
virtual=True)
return model return model

View File

@ -1,9 +1,10 @@
from typing import List, TYPE_CHECKING, Type from typing import List, TYPE_CHECKING, Type, NamedTuple
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
import orm import orm
from orm import ForeignKey
from orm.exceptions import NoMatch, MultipleMatches from orm.exceptions import NoMatch, MultipleMatches
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -22,6 +23,13 @@ FILTER_OPERATORS = {
} }
class JoinParameters(NamedTuple):
    """Immutable state handed from one relation-join step to the next.

    A fresh instance is created for the root model before a relation path is
    walked, and each join step returns an updated instance for the next path
    segment.
    """

    # Model whose field metadata (e.g. the ``virtual`` flag) is consulted
    # for the relation segment being joined.
    prev_model: Type['Model']

    # Alias of the table joined in the previous step; the walk starts with
    # an empty string.
    previous_alias: str

    # Name of the table the next join starts from.
    from_table: str

    # Model whose ``__model_fields__`` resolves the next path segment.
    model_cls: Type['Model']
class QuerySet: class QuerySet:
ESCAPE_CHARACTERS = ['%', '_'] ESCAPE_CHARACTERS = ['%', '_']
@ -32,7 +40,13 @@ class QuerySet:
self._select_related = [] if select_related is None else select_related self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count self.limit_count = limit_count
self.query_offset = offset self.query_offset = offset
self.aliases_dict = dict()
self.auto_related = []
self.used_aliases = []
self.select_from = None
self.columns = None
self.order_bys = None
def __get__(self, instance, owner): def __get__(self, instance, owner):
return self.__class__(model_cls=owner) return self.__class__(model_cls=owner)
@ -56,52 +70,101 @@ class QuerySet:
return text(f'{alias}_{to_table}.{to_key}=' return text(f'{alias}_{to_table}.{to_key}='
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}')
def build_select_expression(self): def build_join_parameters(self, part, join_params: JoinParameters):
# tables = [self.table] model_cls = join_params.model_cls.__model_fields__[part].to
columns = list(self.table.columns)
order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')]
select_from = self.table
for key in self.model_cls.__model_fields__:
if not self.model_cls.__model_fields__[key].nullable \
and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \
and key not in self._select_related:
self._select_related.append(key)
for item in self._select_related:
previous_alias = ''
from_table = self.table.name
prev_model = self.model_cls
model_cls = self.model_cls
for part in item.split("__"):
model_cls = model_cls.__model_fields__[part].to
to_table = model_cls.__table__.name to_table = model_cls.__table__.name
alias = model_cls._orm_relationship_manager.resolve_relation_join(from_table, to_table) alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table)
# print(f'resolving tables alias from {join_params.from_table}, to: {to_table} -> {alias}')
if prev_model.__model_fields__[part].virtual: if alias not in self.used_aliases:
# TODO: change the key lookup if join_params.prev_model.__model_fields__[part].virtual:
to_key = prev_model.__name__.lower() to_key = next((v for k, v in model_cls.__model_fields__.items()
if isinstance(v, ForeignKey) and v.to == join_params.prev_model), None).name
from_key = model_cls.__pkname__ from_key = model_cls.__pkname__
else: else:
to_key = model_cls.__pkname__ to_key = model_cls.__pkname__
from_key = part from_key = part
on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key) on_clause = self.on_clause(join_params.from_table, to_table, join_params.previous_alias, alias, to_key,
from_key)
target_table = self.prefixed_table_name(alias, to_table) target_table = self.prefixed_table_name(alias, to_table)
select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause) self.select_from = sqlalchemy.sql.outerjoin(self.select_from, target_table, on_clause)
# tables.append(model_cls.__table__) self.order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}'))
order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) self.columns.extend(self.prefixed_columns(alias, model_cls.__table__))
columns.extend(self.prefixed_columns(alias, model_cls.__table__)) self.used_aliases.append(alias)
previous_alias = alias previous_alias = alias
from_table = to_table from_table = to_table
prev_model = model_cls prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
expr = sqlalchemy.sql.select(columns) @staticmethod
expr = expr.select_from(select_from) def field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part) -> bool:
return isinstance(field, ForeignKey) and field_name not in rel_part
def field_qualifies_to_deeper_search(self, field, parent_virtual, nested, rel_part) -> bool:
    """Tell whether the relation walk should descend into ``field``.

    The walk descends when any of the following holds:

    * the call is not nested (fields of the starting model are always
      explored);
    * ``field`` is virtual and was itself reached through a virtual field;
    * the parent path of ``rel_part`` is a prefix of some entry in
      ``self._select_related`` and no entry in ``self.auto_related``
      already starts with ``rel_part``.

    :param field: relation field being inspected; only ``field.virtual`` is read.
    :param parent_virtual: truthy when the field leading here was virtual.
    :param nested: falsy for the outermost call of the walk.
    :param rel_part: ``__``-separated relation path that ends at ``field``.
    :return: ``True`` when the walk should recurse into ``field``.
    """
    # Cheap checks first: they decide most calls without scanning any list.
    if not nested:
        return True
    if field.virtual and parent_virtual:
        return True

    # Path of the parent relation: everything before the last "__" segment.
    # It is an empty string for a single-segment path, and every string
    # starts with "", so such a path matches whenever any relation was
    # requested.
    parent_path = "__".join(rel_part.split("__")[:-1])

    # Generator expressions let any() stop at the first hit instead of
    # materialising a full list of booleans first.
    if not any(path.startswith(parent_path) for path in self._select_related):
        return False
    return not any(path.startswith(rel_part) for path in self.auto_related)
# Walk the fields of ``join_params.prev_model`` and record in ``self.auto_related``
# the relation paths discovered through non-nullable foreign keys.
#
# Parameters:
#   join_params    -- JoinParameters whose ``prev_model`` supplies the fields to inspect.
#   rel_part       -- "__"-separated relation path built up so far ('' at the start).
#   nested         -- False for the outermost call; recursive calls pass True.
#   parent_virtual -- ``virtual`` flag of the field that led to this call.
#
# No value is returned; the result is the mutation of ``self.auto_related``.
#
# NOTE(review): leading indentation is not visible in this view, so the exact nesting
# of the if/elif/else chain and of the ``rel_part = ''`` resets below must be confirmed
# against the repository before this function is changed.
def extract_auto_required_relations(self, join_params: JoinParameters,
rel_part: str = '', nested: bool = False, parent_virtual: bool = False):
# print(f'checking model {join_params.prev_model}, {rel_part}')
for field_name, field in join_params.prev_model.__model_fields__.items():
# print(f'checking_field {field_name}')
# Only foreign keys whose name does not already occur in the current path are considered.
if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part):
# Extend the path with this field, using "__" as the separator.
rel_part = field_name if not rel_part else rel_part + '__' + field_name
if not field.nullable:
# print(f'field {field_name} is not nullable, appending to auto, curr rel: {rel_part}')
# When the full path was not requested explicitly, the path WITHOUT its last
# segment (the parent relation) is what gets recorded.
if rel_part not in self._select_related:
self.auto_related.append("__".join(rel_part.split("__")[:-1]))
# The path is cleared here -- presumably so the next sibling field starts fresh; TODO confirm nesting.
rel_part = ''
elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part):
# print(
# f'field {field_name} is nullable, going down, curr rel: '
# f'{rel_part}, nested:{nested}, virtual:{field.virtual}, virtual_par:{parent_virtual}, '
# f'injoin: {"__".join(rel_part.split("__")[:-1]) in self._select_related}')
# NOTE(review): ``join_params`` is rebound here, so any later use of it in this call
# sees the related model's parameters -- confirm this is intended.
join_params = JoinParameters(field.to, join_params.previous_alias,
join_params.from_table, join_params.prev_model)
# Recurse into the related model, carrying the current path and this field's virtual flag.
self.extract_auto_required_relations(join_params=join_params,
rel_part=rel_part, nested=True, parent_virtual=field.virtual)
else:
# print(
# f'field {field_name} is out, going down, curr rel: '
# f'{rel_part}, nested:{nested}, virtual:{field.virtual}, virtual_par:{parent_virtual}, '
# f'injoin: {"__".join(rel_part.split("__")[:-1]) in self._select_related}')
# The field is neither recorded nor followed; the path is cleared.
rel_part = ''
def build_select_expression(self):
self.columns = list(self.table.columns)
self.order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')]
self.select_from = self.table
for key in self.model_cls.__model_fields__:
if not self.model_cls.__model_fields__[key].nullable \
and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \
and key not in self._select_related:
self._select_related = [key] + self._select_related
start_params = JoinParameters(self.model_cls, '', self.table.name, self.model_cls)
self.extract_auto_required_relations(start_params)
if self.auto_related:
new_joins = []
for join in self._select_related:
if not any([x.startswith(join) for x in self.auto_related]):
new_joins.append(join)
self._select_related = new_joins + self.auto_related
self._select_related.sort(key=lambda item: (-len(item), item))
for item in self._select_related:
join_parameters = JoinParameters(self.model_cls, '', self.table.name, self.model_cls)
for part in item.split("__"):
join_parameters = self.build_join_parameters(part, join_parameters)
expr = sqlalchemy.sql.select(self.columns)
expr = expr.select_from(self.select_from)
if self.filter_clauses: if self.filter_clauses:
if len(self.filter_clauses) == 1: if len(self.filter_clauses) == 1:
@ -116,10 +179,17 @@ class QuerySet:
if self.query_offset: if self.query_offset:
expr = expr.offset(self.query_offset) expr = expr.offset(self.query_offset)
for order in order_bys: for order in self.order_bys:
expr = expr.order_by(order) expr = expr.order_by(order)
print(expr.compile(compile_kwargs={"literal_binds": True})) # print(expr.compile(compile_kwargs={"literal_binds": True}))
self.select_from = None
self.columns = None
self.order_bys = None
self.auto_related = []
self.used_aliases = []
return expr return expr
def filter(self, **kwargs): def filter(self, **kwargs):
@ -163,7 +233,7 @@ class QuerySet:
model_cls = model_cls.__model_fields__[part].to model_cls = model_cls.__model_fields__[part].to
previous_table = current_table previous_table = current_table
print(table_prefix) # print(table_prefix)
table = model_cls.__table__ table = model_cls.__table__
column = model_cls.__table__.columns[field_name] column = model_cls.__table__.columns[field_name]
@ -298,7 +368,7 @@ class QuerySet:
@classmethod @classmethod
def merge_two_instances(cls, one: 'Model', other: 'Model'): def merge_two_instances(cls, one: 'Model', other: 'Model'):
for field in one.__model_fields__.keys(): for field in one.__model_fields__.keys():
print(field, one.dict(), other.dict()) # print(field, one.dict(), other.dict())
if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model): if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model):
setattr(other, field, getattr(one, field) + getattr(other, field)) setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), orm.models.Model): elif isinstance(getattr(one, field), orm.models.Model):

View File

@ -40,7 +40,7 @@ class RelationshipManager:
self._relations[reverse_key] = get_relation_config('reverse', table_name, field) self._relations[reverse_key] = get_relation_config('reverse', table_name, field)
def deregister(self, model: 'Model'): def deregister(self, model: 'Model'):
print(f'deregistering {model.__class__.__name__}, {model._orm_id}') # print(f'deregistering {model.__class__.__name__}, {model._orm_id}')
for rel_type in self._relations.keys(): for rel_type in self._relations.keys():
if model.__class__.__name__.lower() in rel_type.lower(): if model.__class__.__name__.lower() in rel_type.lower():
if model._orm_id in self._relations[rel_type]: if model._orm_id in self._relations[rel_type]:
@ -55,11 +55,11 @@ class RelationshipManager:
child, parent = parent, proxy(child) child, parent = parent, proxy(child)
else: else:
child = proxy(child) child = proxy(child)
print( # print(
f'setting up relationship, {parent_id}, {child_id}, ' # f'setting up relationship, {parent_id}, {child_id}, '
f'{parent.__class__.__name__}, {child.__class__.__name__}, ' # f'{parent.__class__.__name__}, {child.__class__.__name__}, '
f'{parent.pk if parent.values is not None else None}, ' # f'{parent.pk if parent.values is not None else None}, '
f'{child.pk if child.values is not None else None}') # f'{child.pk if child.values is not None else None}')
parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, [])
self.append_related_model(parents_list, child) self.append_related_model(parents_list, child)
children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, [])

View File

@ -77,15 +77,18 @@ async def test_model_multiple_instances_of_same_table_in_schema():
await Student.objects.create(name="Jane", category=category, schoolclass=class1) await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
classes = await SchoolClass.objects.select_related(['teachers', 'students']).all() classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all()
assert classes[0].name == 'Math' assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane' assert classes[0].students[0].name == 'Jane'
# related fields of main model are only populated by pk # related fields of main model are only populated by pk
# unless there is a required foreign key somewhere along the way
# since department is required for schoolclass it was pre loaded (again)
# but you can load them anytime # but you can load them anytime
assert classes[0].students[0].schoolclass.name is None
await classes[0].students[0].schoolclass.load()
assert classes[0].students[0].schoolclass.name == 'Math' assert classes[0].students[0].schoolclass.name == 'Math'
assert classes[0].students[0].schoolclass.department.name is None
await classes[0].students[0].schoolclass.department.load()
assert classes[0].students[0].schoolclass.department.name == 'Math Department'
@pytest.mark.asyncio @pytest.mark.asyncio