remove auto related parsing, switch to relations on instance instead of relationship manager

This commit is contained in:
collerek
2020-08-24 11:15:59 +02:00
parent 9bbf6f93ed
commit 63a24e7d36
12 changed files with 295 additions and 209 deletions

BIN
.coverage

Binary file not shown.

View File

@ -107,9 +107,8 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def register_relation(cls, model: "Model", child: "Model") -> None: def register_relation(cls, model: "Model", child: "Model") -> None:
child_model_name = cls.related_name or child.get_name() model._orm.add(
model.Meta._orm_relationship_manager.add_relation( parent=model, child=child, child_name=cls.related_name, virtual=cls.virtual
model, child, child_model_name, virtual=cls.virtual
) )
@classmethod @classmethod

View File

@ -1,4 +1,5 @@
from typing import Any, List import itertools
from typing import Any, List, Tuple, Union
import sqlalchemy import sqlalchemy
@ -6,6 +7,21 @@ import ormar.queryset # noqa I100
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
def group_related_list(list_):
test_dict = dict()
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped:
group_list = list(group)
new = [
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
]
if any("__" in x for x in new):
test_dict[key] = group_related_list(new)
else:
test_dict[key] = new
return test_dict
class Model(NewBaseModel): class Model(NewBaseModel):
__abstract__ = False __abstract__ = False
@ -14,22 +30,27 @@ class Model(NewBaseModel):
cls, cls,
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
select_related: List = None, select_related: List = None,
related_models: Any = None,
previous_table: str = None, previous_table: str = None,
) -> "Model": ) -> Union["Model", Tuple["Model", dict]]:
item = {} item = {}
select_related = select_related or [] select_related = select_related or []
related_models = related_models or []
if select_related:
related_models = group_related_list(select_related)
table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join( table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join(
previous_table, cls.Meta.table.name previous_table, cls.Meta.table.name
) )
previous_table = cls.Meta.table.name previous_table = cls.Meta.table.name
for related in select_related: for related in related_models:
if "__" in related: if isinstance(related_models, dict) and related_models[related]:
first_part, remainder = related.split("__", 1) first_part, remainder = related, related_models[related]
model_cls = cls.Meta.model_fields[first_part].to model_cls = cls.Meta.model_fields[first_part].to
child = model_cls.from_row( child = model_cls.from_row(
row, select_related=[remainder], previous_table=previous_table row, related_models=remainder, previous_table=previous_table
) )
item[first_part] = child item[first_part] = child
else: else:
@ -43,7 +64,8 @@ class Model(NewBaseModel):
f'{table_prefix + "_" if table_prefix else ""}{column.name}' f'{table_prefix + "_" if table_prefix else ""}{column.name}'
] ]
return cls(**item) instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None
return instance
async def save(self) -> "Model": async def save(self) -> "Model":
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()

View File

@ -43,6 +43,18 @@ class ModelTableProxy:
related_names.add(name) related_names.add(name)
return related_names return related_names
@classmethod
def _extract_db_related_names(cls) -> Set:
related_names = set()
for name, field in cls.Meta.model_fields.items():
if (
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and not field.virtual
):
related_names.add(name)
return related_names
@classmethod @classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested: if nested:
@ -62,7 +74,7 @@ class ModelTableProxy:
self_fields = { self_fields = {
k: v for k, v in self_fields.items() if k in self.Meta.table.columns k: v for k, v in self_fields.items() if k in self.Meta.table.columns
} }
for field in self._extract_related_names(): for field in self._extract_db_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
if getattr(self, field) is not None: if getattr(self, field) is not None:
self_fields[field] = getattr(getattr(self, field), target_pk_name) self_fields[field] = getattr(getattr(self, field), target_pk_name)
@ -72,8 +84,8 @@ class ModelTableProxy:
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = [] merged_rows = []
for index, model in enumerate(result_rows): for index, model in enumerate(result_rows):
if index > 0 and model.pk == result_rows[index - 1].pk: if index > 0 and model.pk == merged_rows[-1].pk:
result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
else: else:
merged_rows.append(model) merged_rows.append(model)
return merged_rows return merged_rows

View File

@ -20,9 +20,10 @@ from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy from ormar.models.modelproxy import ModelTableProxy
from ormar.relations import AliasManager from ormar.relations import AliasManager, RelationsManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models.model import Model from ormar.models.model import Model
@ -34,7 +35,7 @@ if TYPE_CHECKING: # pragma no cover
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass): class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
__slots__ = ("_orm_id", "_orm_saved") __slots__ = ("_orm_id", "_orm_saved", "_orm")
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]] __model_fields__: Dict[str, TypeVar[BaseField]]
@ -46,6 +47,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
__metadata__: sqlalchemy.MetaData __metadata__: sqlalchemy.MetaData
__database__: databases.Database __database__: databases.Database
_orm_relationship_manager: AliasManager _orm_relationship_manager: AliasManager
_orm: RelationsManager
Meta: ModelMeta Meta: ModelMeta
# noinspection PyMissingConstructor # noinspection PyMissingConstructor
@ -53,6 +55,18 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
object.__setattr__(self, "_orm_saved", False) object.__setattr__(self, "_orm_saved", False)
object.__setattr__(
self,
"_orm",
RelationsManager(
related_fields=[
field
for name, field in self.Meta.model_fields.items()
if issubclass(field, ForeignKeyField)
],
owner=self,
),
)
pk_only = kwargs.pop("__pk_only__", False) pk_only = kwargs.pop("__pk_only__", False)
if "pk" in kwargs: if "pk" in kwargs:
@ -71,16 +85,12 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "__dict__", values) object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set) object.__setattr__(self, "__fields_set__", fields_set)
def __del__(self) -> None:
self.Meta._orm_relationship_manager.deregister(self)
def __setattr__(self, name: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None:
relation_key = self.get_name(title=True) + "_" + name
if name in self.__slots__: if name in self.__slots__:
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
elif name == "pk": elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
elif self.Meta._orm_relationship_manager.contains(relation_key, self): elif name in self._orm:
self.Meta.model_fields[name].expand_relationship(value, self) self.Meta.model_fields[name].expand_relationship(value, self)
else: else:
value = ( value = (
@ -91,24 +101,27 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
super().__setattr__(name, value) super().__setattr__(name, value)
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item != "__fields__" and item in self.__fields__: if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
related = self._extract_related_model_instead_of_field(item) return object.__getattribute__(self, item)
if related: elif item != "_extract_related_names" and item in self._extract_related_names():
return related return self._extract_related_model_instead_of_field(item)
value = object.__getattribute__(self, item) elif item == "pk":
return self.__dict__.get(self.Meta.pkname, None)
elif item != "__fields__" and item in self.__fields__:
value = self.__dict__.get(item, None)
value = self._convert_json(item, value, "loads") value = self._convert_json(item, value, "loads")
return value return value
return super().__getattribute__(item) return super().__getattribute__(item)
def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]: # def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
return self._extract_related_model_instead_of_field(item) # return self._extract_related_model_instead_of_field(item)
def _extract_related_model_instead_of_field( def _extract_related_model_instead_of_field(
self, item: str self, item: str
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
relation_key = self.get_name(title=True) + "_" + item # relation_key = self.get_name(title=True) + "_" + item
if self.Meta._orm_relationship_manager.contains(relation_key, self): if item in self._orm:
return self.Meta._orm_relationship_manager.get(relation_key, self) return self._orm.get(item)
def __same__(self, other: "Model") -> bool: def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__: # pragma no cover if self.__class__ != other.__class__: # pragma no cover
@ -128,10 +141,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
name = name.title() name = name.title()
return name return name
@property
def pk(self) -> Any:
return getattr(self, self.Meta.pkname)
@property @property
def pk_column(self) -> sqlalchemy.Column: def pk_column(self) -> sqlalchemy.Column:
return self.Meta.table.primary_key.columns.values()[0] return self.Meta.table.primary_key.columns.values()[0]
@ -177,7 +186,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
setattr(self, key, value) setattr(self, key, value)
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]:
if not self._is_conversion_to_json_needed(column_name): if not self._is_conversion_to_json_needed(column_name):
return value return value

View File

@ -5,7 +5,6 @@ from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.queryset.relationship_crawler import RelationshipCrawler
from ormar.relations import AliasManager from ormar.relations import AliasManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -52,14 +51,7 @@ class Query:
self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
self.select_from = self.table self.select_from = self.table
start_params = JoinParameters( self._select_related.sort(key=lambda item: (item, -len(item)))
self.model_cls, "", self.table.name, self.model_cls
)
self._select_related = RelationshipCrawler().discover_relations(
self._select_related, prev_model=start_params.prev_model
)
self._select_related.sort(key=lambda item: (-len(item), item))
for item in self._select_related: for item in self._select_related:
join_parameters = JoinParameters( join_parameters = JoinParameters(

View File

@ -138,7 +138,6 @@ class QuerySet:
self.model_cls.from_row(row, select_related=self._select_related) self.model_cls.from_row(row, select_related=self._select_related)
for row in rows for row in rows
] ]
result_rows = self.model_cls.merge_instances_list(result_rows) result_rows = self.model_cls.merge_instances_list(result_rows)
return result_rows return result_rows

View File

@ -1,87 +0,0 @@
from typing import List, TYPE_CHECKING, Type
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover
from ormar import Model
class RelationshipCrawler:
def __init__(self) -> None:
self._select_related = []
self.auto_related = []
self.already_checked = []
def discover_relations(
self, select_related: List, prev_model: Type["Model"]
) -> List[str]:
self._select_related = select_related
self._extract_auto_required_relations(prev_model=prev_model)
self._include_auto_related_models()
return self._select_related
@staticmethod
def _field_is_a_foreign_key_and_no_circular_reference(
field: Type[BaseField], field_name: str, rel_part: str
) -> bool:
return issubclass(field, ForeignKeyField) and field_name not in rel_part
def _field_qualifies_to_deeper_search(
self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any(
[x.startswith(prev_part_of_related) for x in self._select_related]
)
already_checked = any(
[x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]
)
return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def _extract_auto_required_relations(
self,
prev_model: Type["Model"],
rel_part: str = "",
nested: bool = False,
parent_virtual: bool = False,
) -> None:
for field_name, field in prev_model.Meta.model_fields.items():
if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part
):
rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable:
if rel_part not in self._select_related:
split_tables = rel_part.split("__")
new_related = (
"__".join(split_tables[:-1])
if len(split_tables) > 1
else rel_part
)
self.auto_related.append(new_related)
rel_part = ""
elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part
):
self._extract_auto_required_relations(
prev_model=field.to,
rel_part=rel_part,
nested=True,
parent_virtual=field.virtual,
)
else:
self.already_checked.append(rel_part)
rel_part = ""
def _include_auto_related_models(self) -> None:
if self.auto_related:
new_joins = []
for join in self._select_related:
if not any([x.startswith(join) for x in self.auto_related]):
new_joins.append(join)
self._select_related = new_joins + self.auto_related

View File

@ -1,23 +1,30 @@
import pprint import string
import string import string
import uuid import uuid
from enum import Enum
from random import choices from random import choices
from typing import List, TYPE_CHECKING, Union from typing import List, TYPE_CHECKING, Type
from weakref import proxy from weakref import proxy
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
from ormar.exceptions import RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField # noqa I100 from ormar.fields.foreign_key import ForeignKeyField # noqa I100
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import NewBaseModel, Model from ormar.models import Model
def get_table_alias() -> str: def get_table_alias() -> str:
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
class RelationType(Enum):
PRIMARY = 1
REVERSE = 2
class AliasManager: class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._relations = dict() self._relations = dict()
@ -42,78 +49,97 @@ class AliasManager:
table_name: str, table_name: str,
) -> None: ) -> None:
if relations_key not in self._relations: if relations_key not in self._relations:
self._relations[relations_key] = {"type": "primary"}
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
if reverse_key not in self._relations: if reverse_key not in self._relations:
self._relations[reverse_key] = {"type": "reverse"}
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
def deregister(self, model: "NewBaseModel") -> None:
for rel_type in self._relations.keys():
if model.get_name() in rel_type.lower():
if model._orm_id in self._relations[rel_type]:
del self._relations[rel_type][model._orm_id]
def add_relation(
self,
parent: "NewBaseModel",
child: "NewBaseModel",
child_model_name: str,
virtual: bool = False,
) -> None:
parent_id, child_id = parent._orm_id, child._orm_id
parent_name = parent.get_name(title=True)
child_name = (
child_model_name
if child.get_name() != child_model_name
else child.get_name() + "s"
)
if virtual:
child_name, parent_name = parent_name, child.get_name()
child_id, parent_id = parent_id, child_id
child, parent = parent, proxy(child)
child_name = child_name.lower() + "s"
else:
child = proxy(child)
parent_relation_name = parent_name.title() + "_" + child_name
parents_list = self._relations[parent_relation_name].setdefault(parent_id, [])
self.append_related_model(parents_list, child)
child_relation_name = child.get_name(title=True) + "_" + parent_name.lower()
children_list = self._relations[child_relation_name].setdefault(child_id, [])
self.append_related_model(children_list, parent)
@staticmethod
def append_related_model(relations_list: List["Model"], model: "Model") -> None:
for relation_child in relations_list:
try:
if relation_child.__same__(model):
return
except ReferenceError:
continue
relations_list.append(model)
def contains(self, relations_key: str, instance: "NewBaseModel") -> bool:
if relations_key in self._relations:
return instance._orm_id in self._relations[relations_key]
return False
def get(
self, relations_key: str, instance: "NewBaseModel"
) -> Union["Model", List["Model"]]:
if relations_key in self._relations:
if instance._orm_id in self._relations[relations_key]:
if self._relations[relations_key]["type"] == "primary":
return self._relations[relations_key][instance._orm_id][0]
return self._relations[relations_key][instance._orm_id]
def resolve_relation_join(self, from_table: str, to_table: str) -> str: def resolve_relation_join(self, from_table: str, to_table: str) -> str:
return self._aliases.get(f"{from_table}_{to_table}", "") return self._aliases.get(f"{from_table}_{to_table}", "")
def __str__(self) -> str: # pragma no cover
return pprint.pformat(self._relations, indent=4, width=1)
def __repr__(self) -> str: # pragma no cover class Relation:
return self.__str__() def __init__(self, type_: RelationType) -> None:
self._type = type_
self.related_models = [] if type_ == RelationType.REVERSE else None
def _find_existing(self, child):
for ind, relation_child in enumerate(self.related_models):
try:
if relation_child.__same__(child):
return ind
except ReferenceError: # pragma no cover
continue
return None
def add(self, child: "Model") -> None:
if self._type == RelationType.PRIMARY:
self.related_models = child
else:
if self._find_existing(child) is None:
self.related_models.append(child)
# def remove(self, child: "Model") -> None:
# if self._type == RelationType.PRIMARY:
# self.related_models = None
# else:
# position = self._find_existing(child)
# if position is not None:
# self.related_models.pop(position)
def get(self):
return self.related_models
class RelationsManager:
def __init__(
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
):
self.owner = owner
self._related_fields = related_fields or []
self._related_names = [field.name for field in self._related_fields]
self._relations = dict()
for field in self._related_fields:
self._relations[field.name] = Relation(
type_=RelationType.PRIMARY
if not field.virtual
else RelationType.REVERSE
)
def __contains__(self, item):
return item in self._related_names
def get(self, name):
relation = self._relations.get(name, None)
if relation:
return relation.get()
def _get(self, name):
relation = self._relations.get(name, None)
if relation:
return relation
def add(self, parent: "Model", child: "Model", child_name: str, virtual: bool):
to_field = next(
(
field
for field in child._orm._related_fields
if field.to == parent.__class__
),
None,
)
if not to_field: # pragma no cover
raise RelationshipInstanceError(
f"Model {child.__class__} does not have reference to model {parent.__class__}"
)
to_name = to_field.name
if virtual:
child_name, to_name = to_name, child_name or child.get_name()
child, parent = parent, proxy(child)
else:
child_name = child_name or child.get_name() + "s"
child = proxy(child)
parent._orm._get(child_name).add(child)
child._orm._get(to_name).add(parent)

View File

@ -131,7 +131,7 @@ async def test_model_crud():
album1 = await Album.objects.get(name="Malibu") album1 = await Album.objects.get(name="Malibu")
assert album1.pk == 1 assert album1.pk == 1
assert album1.tracks is None assert album1.tracks == []
await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4) await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4)

View File

@ -0,0 +1,110 @@
import asyncio
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Department(ormar.Model):
class Meta:
tablename = "departments"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True, autoincrement=False)
name: ormar.String(max_length=100)
class SchoolClass(ormar.Model):
class Meta:
tablename = "schoolclasses"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
class Category(ormar.Model):
class Meta:
tablename = "categories"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
department: ormar.ForeignKey(Department, nullable=False)
class Student(ormar.Model):
class Meta:
tablename = "students"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
schoolclass: ormar.ForeignKey(SchoolClass)
category: ormar.ForeignKey(Category, nullable=True)
class Teacher(ormar.Model):
class Meta:
tablename = "teachers"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
schoolclass: ormar.ForeignKey(SchoolClass)
category: ormar.ForeignKey(Category, nullable=True)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
department = await Department.objects.create(id=1, name="Math Department")
department2 = await Department.objects.create(id=2, name="Law Department")
class1 = await SchoolClass.objects.create(name="Math")
class2 = await SchoolClass.objects.create(name="Logic")
category = await Category.objects.create(name="Foreign", department=department)
category2 = await Category.objects.create(name="Domestic", department=department2)
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema():
async with database:
classes = await SchoolClass.objects.select_related(
["teachers__category__department", "students"]
).all()
assert classes[0].name == "Math"
assert classes[0].students[0].name == "Jane"
assert len(classes[0].dict().get("students")) == 2
assert classes[0].teachers[0].category.department.name == 'Law Department'
assert classes[0].students[0].category.pk is not None
assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load()
await classes[0].students[0].category.department.load()
assert classes[0].students[0].category.department.name == 'Math Department'

View File

@ -79,11 +79,14 @@ async def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
department = await Department.objects.create(id=1, name="Math Department") department = await Department.objects.create(id=1, name="Math Department")
department2 = await Department.objects.create(id=2, name="Law Department")
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)
class2 = await SchoolClass.objects.create(name="Logic", department=department2)
category = await Category.objects.create(name="Foreign") category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic") category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1) await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class1) await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield yield
metadata.drop_all(engine) metadata.drop_all(engine)
@ -100,15 +103,15 @@ async def test_model_multiple_instances_of_same_table_in_schema():
assert len(classes[0].dict().get("students")) == 2 assert len(classes[0].dict().get("students")) == 2
# related fields of main model are only populated by pk # since it's going from schoolclass => teacher => schoolclass (same class) department is already populated
# unless there is a required foreign key somewhere along the way
# since department is required for schoolclass it was pre loaded (again)
# but you can load them anytime
assert classes[0].students[0].schoolclass.name == "Math" assert classes[0].students[0].schoolclass.name == "Math"
assert classes[0].students[0].schoolclass.department.name is None assert classes[0].students[0].schoolclass.department.name is None
await classes[0].students[0].schoolclass.department.load() await classes[0].students[0].schoolclass.department.load()
assert classes[0].students[0].schoolclass.department.name == "Math Department" assert classes[0].students[0].schoolclass.department.name == "Math Department"
await classes[1].students[0].schoolclass.department.load()
assert classes[1].students[0].schoolclass.department.name == "Law Department"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_right_tables_join(): async def test_right_tables_join():
@ -130,5 +133,7 @@ async def test_multiple_reverse_related_objects():
["teachers__category", "students__category"] ["teachers__category", "students__category"]
).all() ).all()
assert classes[0].name == "Math" assert classes[0].name == "Math"
assert classes[0].students[1].name == "Jack" assert classes[0].students[1].name == "Judy"
assert classes[0].students[0].category.name == "Foreign"
assert classes[0].students[1].category.name == "Domestic"
assert classes[0].teachers[0].category.name == "Domestic" assert classes[0].teachers[0].category.name == "Domestic"