From 63a24e7d3640cb2143061c824e625d50336060f5 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 24 Aug 2020 11:15:59 +0200 Subject: [PATCH] remove auto related parsing, switch to relations on instance instead of relationship manager --- .coverage | Bin 53248 -> 53248 bytes ormar/fields/foreign_key.py | 5 +- ormar/models/model.py | 36 ++++-- ormar/models/modelproxy.py | 18 ++- ormar/models/newbasemodel.py | 52 ++++---- ormar/queryset/query.py | 10 +- ormar/queryset/queryset.py | 1 - ormar/queryset/relationship_crawler.py | 87 ------------- ormar/relations.py | 166 ++++++++++++++----------- tests/test_foreign_keys.py | 2 +- tests/test_more_same_table_joins.py | 110 ++++++++++++++++ tests/test_same_table_joins.py | 17 ++- 12 files changed, 295 insertions(+), 209 deletions(-) delete mode 100644 ormar/queryset/relationship_crawler.py create mode 100644 tests/test_more_same_table_joins.py diff --git a/.coverage b/.coverage index aa902116fbfd267068fc27ff6f75732d6109d2f5..0cd63857c69e6da9ef7ec5d430768f9a355b2980 100644 GIT binary patch delta 781 zcmZvYOH30%9LBrswm^4xrtd8ewc-O}p`{2)h{bv^K^6?$41yJCB{h^KJj4bH=o#C? zXks)`JZM6MVA6yd(e{8OAR13%gg^sP8V}Hf1_XiDDRS`OX}<6G&HQJ+=^rHf2gz3w zom4D>7HEV@DE7%^8vp^LE~`*yGL@JLOY%)HiOk?`+1uJ3LQ9dMz+k##Ca6Pq1yp4m zP07_|=jFB98hkNhij0^-$rhb@i5X^km?|cn{y`7ZZu&e;NT=}R0V#T4Zd13{JMGqv z`i5q^^`_uxX*aaG&S_D++)QxD%o`{TDe6QP{v>+9LrP7GoY)*i)4YsQks@K=#CSQO zM6wD!r65I(sQ+>%idCjCg3aEHr=6B34jn*{fURx9Z5K}TGW4uM?*o-vNoS@-&;@y@ z1wVlg^ngq3BI9O8fPx)jFR+m9V3y$*rW!tHGUye0gzlv+bdq!$*SIG|lQoU1|I|oB zu^Ka_HIVg)xO$w%7bnVb#bjr~tFgD*>^B@Ojt=~lIMEc6PMl=na=2*`DeRjVFC$b4 z*6Jyxk!(aubtmY9Vv59m?Ad=i$_7uCnaERJ@|HY{1=LOa>Sz_Mn95B?=@(+ k>x3m_^ap3-6VA`U@!hyTIKGjA1n)(vD;Oqxuf1iKKhESBfB*mh delta 776 zcmaKnUr19?9LIOPSBJazo^!X|{n2LGqy$$R=9WV%CKPF#NN=W%VLHtD=h`;^jC9jp zDgu2tL|I}&^$?L71FgjQpi)t04@!STM=wg)A4*VkZtKhkAA35#&*yi};rs0xB)bO5 zk(5jc5Wx~?h8hThQZ5EivLREOmzirW$jdD-<-sI!pQ_dA_S6YYCjBveZnu%3GOpRt zS;KHJZV+?CG!k1f74Mk`jEgB?0R5SMNW17lNmMe8KfED9Z_d>qmz5{fC}7nSDl}v@ zQ%X|Qh@M*yBHG4L3Q{EOg2g7Kcv4h}LBo2SUTbPH1E(P38Aim+P7x)$Xdc2B$knmqwC~&xoK8TA>cW(4xdb=6 zg?1m7vLmD&%M9sWd@Q9!6fTDCe~Q^7T4ElFyY{OVz!Z1}`oMJ%09U|ikPVX9AM8BN zf0s2W(z0q+h+_Adq8(7_1Tv9vDf!KvBiGRdcVWJ{?0JDL+F%vS*}FoJ=p;T9@7=e_k@tgJ(a>^ z0=$Hi<5MeljgCcI%Hk8@5R}QVIV}5i>2GW?9N!3q{lAlWDRwzhziy>-C?44!8y*i# z1gspFc*|9E1j52_AUrS^OM}=L;n!Ppeqm|iLaj9#iazaZn+T6aH0V!{DY@*nqfd-< qy?P}*BKqim`O~pB8Qs_!AK!_OP1djFBcb;UwYhZ?24D0RTmAu@Hy3#T diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 9d052a7..deea69b 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -107,9 +107,8 @@ class ForeignKeyField(BaseField): @classmethod def register_relation(cls, model: "Model", child: "Model") -> None: - child_model_name = cls.related_name or child.get_name() - model.Meta._orm_relationship_manager.add_relation( - model, child, child_model_name, virtual=cls.virtual + model._orm.add( + parent=model, child=child, child_name=cls.related_name, virtual=cls.virtual ) @classmethod diff --git a/ormar/models/model.py b/ormar/models/model.py index 84c3404..8be4ca9 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,4 +1,5 @@ -from typing import Any, List +import itertools +from typing import Any, List, Tuple, Union import sqlalchemy @@ -6,6 +7,21 @@ import ormar.queryset # noqa I100 from ormar.models import NewBaseModel # noqa I100 +def group_related_list(list_): + test_dict = dict() + grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0]) + for key, group in grouped: + group_list = list(group) + new = [ + "__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1 + ] + if any("__" in x for x in new): + test_dict[key] = group_related_list(new) + else: + test_dict[key] = new + return test_dict + + class Model(NewBaseModel): __abstract__ = False @@ -14,22 +30,27 @@ class Model(NewBaseModel): cls, row: sqlalchemy.engine.ResultProxy, select_related: List = None, + related_models: Any = None, previous_table: str = None, - ) -> "Model": + ) -> Union["Model", Tuple["Model", dict]]: item = {} select_related = select_related or [] + related_models = related_models or [] + if select_related: + related_models = group_related_list(select_related) table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join( previous_table, cls.Meta.table.name ) + previous_table = cls.Meta.table.name - for related in select_related: - if "__" in related: - first_part, remainder = related.split("__", 1) + for related in related_models: + if isinstance(related_models, dict) and related_models[related]: + first_part, remainder = related, related_models[related] model_cls = cls.Meta.model_fields[first_part].to child = model_cls.from_row( - row, select_related=[remainder], previous_table=previous_table + row, related_models=remainder, previous_table=previous_table ) item[first_part] = child else: @@ -43,7 +64,8 @@ class Model(NewBaseModel): f'{table_prefix + "_" if table_prefix else ""}{column.name}' ] - return cls(**item) + instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None + return instance async def save(self) -> "Model": self_fields = self._extract_model_db_fields() diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index d9b99f3..c954a5f 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -43,6 +43,18 @@ class ModelTableProxy: related_names.add(name) return related_names + @classmethod + def _extract_db_related_names(cls) -> Set: + related_names = set() + for name, field in cls.Meta.model_fields.items(): + if ( + inspect.isclass(field) + and issubclass(field, ForeignKeyField) + and not field.virtual + ): + related_names.add(name) + return related_names + @classmethod def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: if nested: @@ -62,7 +74,7 @@ class ModelTableProxy: self_fields = { k: v for k, v in self_fields.items() if k in self.Meta.table.columns } - for field in self._extract_related_names(): + for field in self._extract_db_related_names(): target_pk_name = self.Meta.model_fields[field].to.Meta.pkname if getattr(self, field) is not None: self_fields[field] = getattr(getattr(self, field), target_pk_name) @@ -72,8 +84,8 @@ class ModelTableProxy: def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: merged_rows = [] for index, model in enumerate(result_rows): - if index > 0 and model.pk == result_rows[index - 1].pk: - result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) + if index > 0 and model.pk == merged_rows[-1].pk: + merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) else: merged_rows.append(model) return merged_rows diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 0f09893..582529b 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -20,9 +20,10 @@ from pydantic import BaseModel import ormar # noqa I100 from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.modelproxy import ModelTableProxy -from ormar.relations import AliasManager +from ormar.relations import AliasManager, RelationsManager if TYPE_CHECKING: # pragma no cover from ormar.models.model import Model @@ -34,7 +35,7 @@ if TYPE_CHECKING: # pragma no cover class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass): - __slots__ = ("_orm_id", "_orm_saved") + __slots__ = ("_orm_id", "_orm_saved", "_orm") if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, TypeVar[BaseField]] @@ -46,6 +47,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass __metadata__: sqlalchemy.MetaData __database__: databases.Database _orm_relationship_manager: AliasManager + _orm: RelationsManager Meta: ModelMeta # noinspection PyMissingConstructor @@ -53,6 +55,18 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_saved", False) + object.__setattr__( + self, + "_orm", + RelationsManager( + related_fields=[ + field + for name, field in self.Meta.model_fields.items() + if issubclass(field, ForeignKeyField) + ], + owner=self, + ), + ) pk_only = kwargs.pop("__pk_only__", False) if "pk" in kwargs: @@ -71,16 +85,12 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass object.__setattr__(self, "__dict__", values) object.__setattr__(self, "__fields_set__", fields_set) - def __del__(self) -> None: - self.Meta._orm_relationship_manager.deregister(self) - def __setattr__(self, name: str, value: Any) -> None: - relation_key = self.get_name(title=True) + "_" + name if name in self.__slots__: object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) - elif self.Meta._orm_relationship_manager.contains(relation_key, self): + elif name in self._orm: self.Meta.model_fields[name].expand_relationship(value, self) else: value = ( @@ -91,24 +101,27 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass super().__setattr__(name, value) def __getattribute__(self, item: str) -> Any: - if item != "__fields__" and item in self.__fields__: - related = self._extract_related_model_instead_of_field(item) - if related: - return related - value = object.__getattribute__(self, item) + if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"): + return object.__getattribute__(self, item) + elif item != "_extract_related_names" and item in self._extract_related_names(): + return self._extract_related_model_instead_of_field(item) + elif item == "pk": + return self.__dict__.get(self.Meta.pkname, None) + elif item != "__fields__" and item in self.__fields__: + value = self.__dict__.get(item, None) value = self._convert_json(item, value, "loads") return value return super().__getattribute__(item) - def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]: - return self._extract_related_model_instead_of_field(item) + # def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]: + # return self._extract_related_model_instead_of_field(item) def _extract_related_model_instead_of_field( self, item: str ) -> Optional[Union["Model", List["Model"]]]: - relation_key = self.get_name(title=True) + "_" + item - if self.Meta._orm_relationship_manager.contains(relation_key, self): - return self.Meta._orm_relationship_manager.get(relation_key, self) + # relation_key = self.get_name(title=True) + "_" + item + if item in self._orm: + return self._orm.get(item) def __same__(self, other: "Model") -> bool: if self.__class__ != other.__class__: # pragma no cover @@ -128,10 +141,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass name = name.title() return name - @property - def pk(self) -> Any: - return getattr(self, self.Meta.pkname) - @property def pk_column(self) -> sqlalchemy.Column: return self.Meta.table.primary_key.columns.values()[0] @@ -177,7 +186,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass setattr(self, key, value) def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: - if not self._is_conversion_to_json_needed(column_name): return value diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 6b04b6d..9ab8638 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -5,7 +5,6 @@ from sqlalchemy import text import ormar # noqa I100 from ormar.fields.foreign_key import ForeignKeyField -from ormar.queryset.relationship_crawler import RelationshipCrawler from ormar.relations import AliasManager if TYPE_CHECKING: # pragma no cover @@ -52,14 +51,7 @@ class Query: self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.select_from = self.table - start_params = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) - - self._select_related = RelationshipCrawler().discover_relations( - self._select_related, prev_model=start_params.prev_model - ) - self._select_related.sort(key=lambda item: (-len(item), item)) + self._select_related.sort(key=lambda item: (item, -len(item))) for item in self._select_related: join_parameters = JoinParameters( diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 7ee3599..c22b6ad 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -138,7 +138,6 @@ class QuerySet: self.model_cls.from_row(row, select_related=self._select_related) for row in rows ] - result_rows = self.model_cls.merge_instances_list(result_rows) return result_rows diff --git a/ormar/queryset/relationship_crawler.py b/ormar/queryset/relationship_crawler.py deleted file mode 100644 index 7f8d055..0000000 --- a/ormar/queryset/relationship_crawler.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import List, TYPE_CHECKING, Type - -from ormar.fields import BaseField -from ormar.fields.foreign_key import ForeignKeyField - -if TYPE_CHECKING: # pragma no cover - from ormar import Model - - -class RelationshipCrawler: - def __init__(self) -> None: - self._select_related = [] - self.auto_related = [] - self.already_checked = [] - - def discover_relations( - self, select_related: List, prev_model: Type["Model"] - ) -> List[str]: - self._select_related = select_related - self._extract_auto_required_relations(prev_model=prev_model) - self._include_auto_related_models() - return self._select_related - - @staticmethod - def _field_is_a_foreign_key_and_no_circular_reference( - field: Type[BaseField], field_name: str, rel_part: str - ) -> bool: - return issubclass(field, ForeignKeyField) and field_name not in rel_part - - def _field_qualifies_to_deeper_search( - self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str - ) -> bool: - prev_part_of_related = "__".join(rel_part.split("__")[:-1]) - partial_match = any( - [x.startswith(prev_part_of_related) for x in self._select_related] - ) - already_checked = any( - [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] - ) - return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested - - def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, - ) -> None: - for field_name, field in prev_model.Meta.model_fields.items(): - if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part - ): - rel_part = field_name if not rel_part else rel_part + "__" + field_name - if not field.nullable: - if rel_part not in self._select_related: - split_tables = rel_part.split("__") - new_related = ( - "__".join(split_tables[:-1]) - if len(split_tables) > 1 - else rel_part - ) - self.auto_related.append(new_related) - rel_part = "" - elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part - ): - - self._extract_auto_required_relations( - prev_model=field.to, - rel_part=rel_part, - nested=True, - parent_virtual=field.virtual, - ) - else: - self.already_checked.append(rel_part) - rel_part = "" - - def _include_auto_related_models(self) -> None: - if self.auto_related: - new_joins = [] - for join in self._select_related: - if not any([x.startswith(join) for x in self.auto_related]): - new_joins.append(join) - self._select_related = new_joins + self.auto_related diff --git a/ormar/relations.py b/ormar/relations.py index 6177f26..df3dd26 100644 --- a/ormar/relations.py +++ b/ormar/relations.py @@ -1,23 +1,30 @@ -import pprint +import string import string import uuid +from enum import Enum from random import choices -from typing import List, TYPE_CHECKING, Union +from typing import List, TYPE_CHECKING, Type from weakref import proxy import sqlalchemy from sqlalchemy import text +from ormar.exceptions import RelationshipInstanceError from ormar.fields.foreign_key import ForeignKeyField # noqa I100 if TYPE_CHECKING: # pragma no cover - from ormar.models import NewBaseModel, Model + from ormar.models import Model def get_table_alias() -> str: return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] +class RelationType(Enum): + PRIMARY = 1 + REVERSE = 2 + + class AliasManager: def __init__(self) -> None: self._relations = dict() @@ -42,78 +49,97 @@ class AliasManager: table_name: str, ) -> None: if relations_key not in self._relations: - self._relations[relations_key] = {"type": "primary"} self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() if reverse_key not in self._relations: - self._relations[reverse_key] = {"type": "reverse"} self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() - def deregister(self, model: "NewBaseModel") -> None: - for rel_type in self._relations.keys(): - if model.get_name() in rel_type.lower(): - if model._orm_id in self._relations[rel_type]: - del self._relations[rel_type][model._orm_id] - - def add_relation( - self, - parent: "NewBaseModel", - child: "NewBaseModel", - child_model_name: str, - virtual: bool = False, - ) -> None: - parent_id, child_id = parent._orm_id, child._orm_id - parent_name = parent.get_name(title=True) - child_name = ( - child_model_name - if child.get_name() != child_model_name - else child.get_name() + "s" - ) - if virtual: - child_name, parent_name = parent_name, child.get_name() - child_id, parent_id = parent_id, child_id - child, parent = parent, proxy(child) - child_name = child_name.lower() + "s" - else: - child = proxy(child) - - parent_relation_name = parent_name.title() + "_" + child_name - parents_list = self._relations[parent_relation_name].setdefault(parent_id, []) - self.append_related_model(parents_list, child) - - child_relation_name = child.get_name(title=True) + "_" + parent_name.lower() - children_list = self._relations[child_relation_name].setdefault(child_id, []) - self.append_related_model(children_list, parent) - - @staticmethod - def append_related_model(relations_list: List["Model"], model: "Model") -> None: - for relation_child in relations_list: - try: - if relation_child.__same__(model): - return - except ReferenceError: - continue - - relations_list.append(model) - - def contains(self, relations_key: str, instance: "NewBaseModel") -> bool: - if relations_key in self._relations: - return instance._orm_id in self._relations[relations_key] - return False - - def get( - self, relations_key: str, instance: "NewBaseModel" - ) -> Union["Model", List["Model"]]: - if relations_key in self._relations: - if instance._orm_id in self._relations[relations_key]: - if self._relations[relations_key]["type"] == "primary": - return self._relations[relations_key][instance._orm_id][0] - return self._relations[relations_key][instance._orm_id] - def resolve_relation_join(self, from_table: str, to_table: str) -> str: return self._aliases.get(f"{from_table}_{to_table}", "") - def __str__(self) -> str: # pragma no cover - return pprint.pformat(self._relations, indent=4, width=1) - def __repr__(self) -> str: # pragma no cover - return self.__str__() +class Relation: + def __init__(self, type_: RelationType) -> None: + self._type = type_ + self.related_models = [] if type_ == RelationType.REVERSE else None + + def _find_existing(self, child): + for ind, relation_child in enumerate(self.related_models): + try: + if relation_child.__same__(child): + return ind + except ReferenceError: # pragma no cover + continue + return None + + def add(self, child: "Model") -> None: + if self._type == RelationType.PRIMARY: + self.related_models = child + else: + if self._find_existing(child) is None: + self.related_models.append(child) + + # def remove(self, child: "Model") -> None: + # if self._type == RelationType.PRIMARY: + # self.related_models = None + # else: + # position = self._find_existing(child) + # if position is not None: + # self.related_models.pop(position) + + def get(self): + return self.related_models + + +class RelationsManager: + def __init__( + self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None + ): + self.owner = owner + self._related_fields = related_fields or [] + self._related_names = [field.name for field in self._related_fields] + self._relations = dict() + for field in self._related_fields: + self._relations[field.name] = Relation( + type_=RelationType.PRIMARY + if not field.virtual + else RelationType.REVERSE + ) + + def __contains__(self, item): + return item in self._related_names + + def get(self, name): + relation = self._relations.get(name, None) + if relation: + return relation.get() + + def _get(self, name): + relation = self._relations.get(name, None) + if relation: + return relation + + def add(self, parent: "Model", child: "Model", child_name: str, virtual: bool): + to_field = next( + ( + field + for field in child._orm._related_fields + if field.to == parent.__class__ + ), + None, + ) + + if not to_field: # pragma no cover + raise RelationshipInstanceError( + f"Model {child.__class__} does not have reference to model {parent.__class__}" + ) + + to_name = to_field.name + if virtual: + child_name, to_name = to_name, child_name or child.get_name() + child, parent = parent, proxy(child) + else: + child_name = child_name or child.get_name() + "s" + child = proxy(child) + + parent._orm._get(child_name).add(child) + child._orm._get(to_name).add(parent) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 6463eb6..5ae364c 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -131,7 +131,7 @@ async def test_model_crud(): album1 = await Album.objects.get(name="Malibu") assert album1.pk == 1 - assert album1.tracks is None + assert album1.tracks == [] await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4) diff --git a/tests/test_more_same_table_joins.py b/tests/test_more_same_table_joins.py new file mode 100644 index 0000000..21bf385 --- /dev/null +++ b/tests/test_more_same_table_joins.py @@ -0,0 +1,110 @@ +import asyncio + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Department(ormar.Model): + class Meta: + tablename = "departments" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True, autoincrement=False) + name: ormar.String(max_length=100) + + +class SchoolClass(ormar.Model): + class Meta: + tablename = "schoolclasses" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + department: ormar.ForeignKey(Department, nullable=False) + + +class Student(ormar.Model): + class Meta: + tablename = "students" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) + + +class Teacher(ormar.Model): + class Meta: + tablename = "teachers" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + department = await Department.objects.create(id=1, name="Math Department") + department2 = await Department.objects.create(id=2, name="Law Department") + class1 = await SchoolClass.objects.create(name="Math") + class2 = await SchoolClass.objects.create(name="Logic") + category = await Category.objects.create(name="Foreign", department=department) + category2 = await Category.objects.create(name="Domestic", department=department2) + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Student.objects.create(name="Judy", category=category2, schoolclass=class1) + await Student.objects.create(name="Jack", category=category2, schoolclass=class2) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_model_multiple_instances_of_same_table_in_schema(): + async with database: + classes = await SchoolClass.objects.select_related( + ["teachers__category__department", "students"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[0].name == "Jane" + assert len(classes[0].dict().get("students")) == 2 + assert classes[0].teachers[0].category.department.name == 'Law Department' + + assert classes[0].students[0].category.pk is not None + assert classes[0].students[0].category.name is None + await classes[0].students[0].category.load() + await classes[0].students[0].category.department.load() + assert classes[0].students[0].category.department.name == 'Math Department' diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 13e2185..5b8ffd0 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -79,11 +79,14 @@ async def create_test_database(): metadata.drop_all(engine) metadata.create_all(engine) department = await Department.objects.create(id=1, name="Math Department") + department2 = await Department.objects.create(id=2, name="Law Department") class1 = await SchoolClass.objects.create(name="Math", department=department) + class2 = await SchoolClass.objects.create(name="Logic", department=department2) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Student.objects.create(name="Jack", category=category2, schoolclass=class1) + await Student.objects.create(name="Judy", category=category2, schoolclass=class1) + await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) yield metadata.drop_all(engine) @@ -100,15 +103,15 @@ async def test_model_multiple_instances_of_same_table_in_schema(): assert len(classes[0].dict().get("students")) == 2 - # related fields of main model are only populated by pk - # unless there is a required foreign key somewhere along the way - # since department is required for schoolclass it was pre loaded (again) - # but you can load them anytime + # since it's going from schoolclass => teacher => schoolclass (same class) department is already populated assert classes[0].students[0].schoolclass.name == "Math" assert classes[0].students[0].schoolclass.department.name is None await classes[0].students[0].schoolclass.department.load() assert classes[0].students[0].schoolclass.department.name == "Math Department" + await classes[1].students[0].schoolclass.department.load() + assert classes[1].students[0].schoolclass.department.name == "Law Department" + @pytest.mark.asyncio async def test_right_tables_join(): @@ -130,5 +133,7 @@ async def test_multiple_reverse_related_objects(): ["teachers__category", "students__category"] ).all() assert classes[0].name == "Math" - assert classes[0].students[1].name == "Jack" + assert classes[0].students[1].name == "Judy" + assert classes[0].students[0].category.name == "Foreign" + assert classes[0].students[1].category.name == "Domestic" assert classes[0].teachers[0].category.name == "Domestic"