remove auto related parsing, switch to relations on instance instead of relationship manager
This commit is contained in:
@ -107,9 +107,8 @@ class ForeignKeyField(BaseField):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_relation(cls, model: "Model", child: "Model") -> None:
|
def register_relation(cls, model: "Model", child: "Model") -> None:
|
||||||
child_model_name = cls.related_name or child.get_name()
|
model._orm.add(
|
||||||
model.Meta._orm_relationship_manager.add_relation(
|
parent=model, child=child, child_name=cls.related_name, virtual=cls.virtual
|
||||||
model, child, child_model_name, virtual=cls.virtual
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from typing import Any, List
|
import itertools
|
||||||
|
from typing import Any, List, Tuple, Union
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
@ -6,6 +7,21 @@ import ormar.queryset # noqa I100
|
|||||||
from ormar.models import NewBaseModel # noqa I100
|
from ormar.models import NewBaseModel # noqa I100
|
||||||
|
|
||||||
|
|
||||||
|
def group_related_list(list_):
|
||||||
|
test_dict = dict()
|
||||||
|
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
|
||||||
|
for key, group in grouped:
|
||||||
|
group_list = list(group)
|
||||||
|
new = [
|
||||||
|
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
|
||||||
|
]
|
||||||
|
if any("__" in x for x in new):
|
||||||
|
test_dict[key] = group_related_list(new)
|
||||||
|
else:
|
||||||
|
test_dict[key] = new
|
||||||
|
return test_dict
|
||||||
|
|
||||||
|
|
||||||
class Model(NewBaseModel):
|
class Model(NewBaseModel):
|
||||||
__abstract__ = False
|
__abstract__ = False
|
||||||
|
|
||||||
@ -14,22 +30,27 @@ class Model(NewBaseModel):
|
|||||||
cls,
|
cls,
|
||||||
row: sqlalchemy.engine.ResultProxy,
|
row: sqlalchemy.engine.ResultProxy,
|
||||||
select_related: List = None,
|
select_related: List = None,
|
||||||
|
related_models: Any = None,
|
||||||
previous_table: str = None,
|
previous_table: str = None,
|
||||||
) -> "Model":
|
) -> Union["Model", Tuple["Model", dict]]:
|
||||||
|
|
||||||
item = {}
|
item = {}
|
||||||
select_related = select_related or []
|
select_related = select_related or []
|
||||||
|
related_models = related_models or []
|
||||||
|
if select_related:
|
||||||
|
related_models = group_related_list(select_related)
|
||||||
|
|
||||||
table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join(
|
table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join(
|
||||||
previous_table, cls.Meta.table.name
|
previous_table, cls.Meta.table.name
|
||||||
)
|
)
|
||||||
|
|
||||||
previous_table = cls.Meta.table.name
|
previous_table = cls.Meta.table.name
|
||||||
for related in select_related:
|
for related in related_models:
|
||||||
if "__" in related:
|
if isinstance(related_models, dict) and related_models[related]:
|
||||||
first_part, remainder = related.split("__", 1)
|
first_part, remainder = related, related_models[related]
|
||||||
model_cls = cls.Meta.model_fields[first_part].to
|
model_cls = cls.Meta.model_fields[first_part].to
|
||||||
child = model_cls.from_row(
|
child = model_cls.from_row(
|
||||||
row, select_related=[remainder], previous_table=previous_table
|
row, related_models=remainder, previous_table=previous_table
|
||||||
)
|
)
|
||||||
item[first_part] = child
|
item[first_part] = child
|
||||||
else:
|
else:
|
||||||
@ -43,7 +64,8 @@ class Model(NewBaseModel):
|
|||||||
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
|
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
|
||||||
]
|
]
|
||||||
|
|
||||||
return cls(**item)
|
instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None
|
||||||
|
return instance
|
||||||
|
|
||||||
async def save(self) -> "Model":
|
async def save(self) -> "Model":
|
||||||
self_fields = self._extract_model_db_fields()
|
self_fields = self._extract_model_db_fields()
|
||||||
|
|||||||
@ -43,6 +43,18 @@ class ModelTableProxy:
|
|||||||
related_names.add(name)
|
related_names.add(name)
|
||||||
return related_names
|
return related_names
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_db_related_names(cls) -> Set:
|
||||||
|
related_names = set()
|
||||||
|
for name, field in cls.Meta.model_fields.items():
|
||||||
|
if (
|
||||||
|
inspect.isclass(field)
|
||||||
|
and issubclass(field, ForeignKeyField)
|
||||||
|
and not field.virtual
|
||||||
|
):
|
||||||
|
related_names.add(name)
|
||||||
|
return related_names
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
|
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
|
||||||
if nested:
|
if nested:
|
||||||
@ -62,7 +74,7 @@ class ModelTableProxy:
|
|||||||
self_fields = {
|
self_fields = {
|
||||||
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
|
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
|
||||||
}
|
}
|
||||||
for field in self._extract_related_names():
|
for field in self._extract_db_related_names():
|
||||||
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
|
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
|
||||||
if getattr(self, field) is not None:
|
if getattr(self, field) is not None:
|
||||||
self_fields[field] = getattr(getattr(self, field), target_pk_name)
|
self_fields[field] = getattr(getattr(self, field), target_pk_name)
|
||||||
@ -72,8 +84,8 @@ class ModelTableProxy:
|
|||||||
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
|
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
|
||||||
merged_rows = []
|
merged_rows = []
|
||||||
for index, model in enumerate(result_rows):
|
for index, model in enumerate(result_rows):
|
||||||
if index > 0 and model.pk == result_rows[index - 1].pk:
|
if index > 0 and model.pk == merged_rows[-1].pk:
|
||||||
result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
|
merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
|
||||||
else:
|
else:
|
||||||
merged_rows.append(model)
|
merged_rows.append(model)
|
||||||
return merged_rows
|
return merged_rows
|
||||||
|
|||||||
@ -20,9 +20,10 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
import ormar # noqa I100
|
import ormar # noqa I100
|
||||||
from ormar.fields import BaseField
|
from ormar.fields import BaseField
|
||||||
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
from ormar.models.metaclass import ModelMeta, ModelMetaclass
|
from ormar.models.metaclass import ModelMeta, ModelMetaclass
|
||||||
from ormar.models.modelproxy import ModelTableProxy
|
from ormar.models.modelproxy import ModelTableProxy
|
||||||
from ormar.relations import AliasManager
|
from ormar.relations import AliasManager, RelationsManager
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar.models.model import Model
|
from ormar.models.model import Model
|
||||||
@ -34,7 +35,7 @@ if TYPE_CHECKING: # pragma no cover
|
|||||||
|
|
||||||
|
|
||||||
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
|
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
|
||||||
__slots__ = ("_orm_id", "_orm_saved")
|
__slots__ = ("_orm_id", "_orm_saved", "_orm")
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
__model_fields__: Dict[str, TypeVar[BaseField]]
|
__model_fields__: Dict[str, TypeVar[BaseField]]
|
||||||
@ -46,6 +47,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
__metadata__: sqlalchemy.MetaData
|
__metadata__: sqlalchemy.MetaData
|
||||||
__database__: databases.Database
|
__database__: databases.Database
|
||||||
_orm_relationship_manager: AliasManager
|
_orm_relationship_manager: AliasManager
|
||||||
|
_orm: RelationsManager
|
||||||
Meta: ModelMeta
|
Meta: ModelMeta
|
||||||
|
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
@ -53,6 +55,18 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
|
|
||||||
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
|
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
|
||||||
object.__setattr__(self, "_orm_saved", False)
|
object.__setattr__(self, "_orm_saved", False)
|
||||||
|
object.__setattr__(
|
||||||
|
self,
|
||||||
|
"_orm",
|
||||||
|
RelationsManager(
|
||||||
|
related_fields=[
|
||||||
|
field
|
||||||
|
for name, field in self.Meta.model_fields.items()
|
||||||
|
if issubclass(field, ForeignKeyField)
|
||||||
|
],
|
||||||
|
owner=self,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
pk_only = kwargs.pop("__pk_only__", False)
|
pk_only = kwargs.pop("__pk_only__", False)
|
||||||
if "pk" in kwargs:
|
if "pk" in kwargs:
|
||||||
@ -71,16 +85,12 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
object.__setattr__(self, "__dict__", values)
|
object.__setattr__(self, "__dict__", values)
|
||||||
object.__setattr__(self, "__fields_set__", fields_set)
|
object.__setattr__(self, "__fields_set__", fields_set)
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
self.Meta._orm_relationship_manager.deregister(self)
|
|
||||||
|
|
||||||
def __setattr__(self, name: str, value: Any) -> None:
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
relation_key = self.get_name(title=True) + "_" + name
|
|
||||||
if name in self.__slots__:
|
if name in self.__slots__:
|
||||||
object.__setattr__(self, name, value)
|
object.__setattr__(self, name, value)
|
||||||
elif name == "pk":
|
elif name == "pk":
|
||||||
object.__setattr__(self, self.Meta.pkname, value)
|
object.__setattr__(self, self.Meta.pkname, value)
|
||||||
elif self.Meta._orm_relationship_manager.contains(relation_key, self):
|
elif name in self._orm:
|
||||||
self.Meta.model_fields[name].expand_relationship(value, self)
|
self.Meta.model_fields[name].expand_relationship(value, self)
|
||||||
else:
|
else:
|
||||||
value = (
|
value = (
|
||||||
@ -91,24 +101,27 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
def __getattribute__(self, item: str) -> Any:
|
def __getattribute__(self, item: str) -> Any:
|
||||||
if item != "__fields__" and item in self.__fields__:
|
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
|
||||||
related = self._extract_related_model_instead_of_field(item)
|
return object.__getattribute__(self, item)
|
||||||
if related:
|
elif item != "_extract_related_names" and item in self._extract_related_names():
|
||||||
return related
|
return self._extract_related_model_instead_of_field(item)
|
||||||
value = object.__getattribute__(self, item)
|
elif item == "pk":
|
||||||
|
return self.__dict__.get(self.Meta.pkname, None)
|
||||||
|
elif item != "__fields__" and item in self.__fields__:
|
||||||
|
value = self.__dict__.get(item, None)
|
||||||
value = self._convert_json(item, value, "loads")
|
value = self._convert_json(item, value, "loads")
|
||||||
return value
|
return value
|
||||||
return super().__getattribute__(item)
|
return super().__getattribute__(item)
|
||||||
|
|
||||||
def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
|
# def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
|
||||||
return self._extract_related_model_instead_of_field(item)
|
# return self._extract_related_model_instead_of_field(item)
|
||||||
|
|
||||||
def _extract_related_model_instead_of_field(
|
def _extract_related_model_instead_of_field(
|
||||||
self, item: str
|
self, item: str
|
||||||
) -> Optional[Union["Model", List["Model"]]]:
|
) -> Optional[Union["Model", List["Model"]]]:
|
||||||
relation_key = self.get_name(title=True) + "_" + item
|
# relation_key = self.get_name(title=True) + "_" + item
|
||||||
if self.Meta._orm_relationship_manager.contains(relation_key, self):
|
if item in self._orm:
|
||||||
return self.Meta._orm_relationship_manager.get(relation_key, self)
|
return self._orm.get(item)
|
||||||
|
|
||||||
def __same__(self, other: "Model") -> bool:
|
def __same__(self, other: "Model") -> bool:
|
||||||
if self.__class__ != other.__class__: # pragma no cover
|
if self.__class__ != other.__class__: # pragma no cover
|
||||||
@ -128,10 +141,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
name = name.title()
|
name = name.title()
|
||||||
return name
|
return name
|
||||||
|
|
||||||
@property
|
|
||||||
def pk(self) -> Any:
|
|
||||||
return getattr(self, self.Meta.pkname)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pk_column(self) -> sqlalchemy.Column:
|
def pk_column(self) -> sqlalchemy.Column:
|
||||||
return self.Meta.table.primary_key.columns.values()[0]
|
return self.Meta.table.primary_key.columns.values()[0]
|
||||||
@ -177,7 +186,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]:
|
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]:
|
||||||
|
|
||||||
if not self._is_conversion_to_json_needed(column_name):
|
if not self._is_conversion_to_json_needed(column_name):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,6 @@ from sqlalchemy import text
|
|||||||
|
|
||||||
import ormar # noqa I100
|
import ormar # noqa I100
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
from ormar.queryset.relationship_crawler import RelationshipCrawler
|
|
||||||
from ormar.relations import AliasManager
|
from ormar.relations import AliasManager
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
@ -52,14 +51,7 @@ class Query:
|
|||||||
self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
|
self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
|
||||||
self.select_from = self.table
|
self.select_from = self.table
|
||||||
|
|
||||||
start_params = JoinParameters(
|
self._select_related.sort(key=lambda item: (item, -len(item)))
|
||||||
self.model_cls, "", self.table.name, self.model_cls
|
|
||||||
)
|
|
||||||
|
|
||||||
self._select_related = RelationshipCrawler().discover_relations(
|
|
||||||
self._select_related, prev_model=start_params.prev_model
|
|
||||||
)
|
|
||||||
self._select_related.sort(key=lambda item: (-len(item), item))
|
|
||||||
|
|
||||||
for item in self._select_related:
|
for item in self._select_related:
|
||||||
join_parameters = JoinParameters(
|
join_parameters = JoinParameters(
|
||||||
|
|||||||
@ -138,7 +138,6 @@ class QuerySet:
|
|||||||
self.model_cls.from_row(row, select_related=self._select_related)
|
self.model_cls.from_row(row, select_related=self._select_related)
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
|
|
||||||
result_rows = self.model_cls.merge_instances_list(result_rows)
|
result_rows = self.model_cls.merge_instances_list(result_rows)
|
||||||
|
|
||||||
return result_rows
|
return result_rows
|
||||||
|
|||||||
@ -1,87 +0,0 @@
|
|||||||
from typing import List, TYPE_CHECKING, Type
|
|
||||||
|
|
||||||
from ormar.fields import BaseField
|
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
|
||||||
from ormar import Model
|
|
||||||
|
|
||||||
|
|
||||||
class RelationshipCrawler:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._select_related = []
|
|
||||||
self.auto_related = []
|
|
||||||
self.already_checked = []
|
|
||||||
|
|
||||||
def discover_relations(
|
|
||||||
self, select_related: List, prev_model: Type["Model"]
|
|
||||||
) -> List[str]:
|
|
||||||
self._select_related = select_related
|
|
||||||
self._extract_auto_required_relations(prev_model=prev_model)
|
|
||||||
self._include_auto_related_models()
|
|
||||||
return self._select_related
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _field_is_a_foreign_key_and_no_circular_reference(
|
|
||||||
field: Type[BaseField], field_name: str, rel_part: str
|
|
||||||
) -> bool:
|
|
||||||
return issubclass(field, ForeignKeyField) and field_name not in rel_part
|
|
||||||
|
|
||||||
def _field_qualifies_to_deeper_search(
|
|
||||||
self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
|
|
||||||
) -> bool:
|
|
||||||
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
|
|
||||||
partial_match = any(
|
|
||||||
[x.startswith(prev_part_of_related) for x in self._select_related]
|
|
||||||
)
|
|
||||||
already_checked = any(
|
|
||||||
[x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
(field.virtual and parent_virtual)
|
|
||||||
or (partial_match and not already_checked)
|
|
||||||
) or not nested
|
|
||||||
|
|
||||||
def _extract_auto_required_relations(
|
|
||||||
self,
|
|
||||||
prev_model: Type["Model"],
|
|
||||||
rel_part: str = "",
|
|
||||||
nested: bool = False,
|
|
||||||
parent_virtual: bool = False,
|
|
||||||
) -> None:
|
|
||||||
for field_name, field in prev_model.Meta.model_fields.items():
|
|
||||||
if self._field_is_a_foreign_key_and_no_circular_reference(
|
|
||||||
field, field_name, rel_part
|
|
||||||
):
|
|
||||||
rel_part = field_name if not rel_part else rel_part + "__" + field_name
|
|
||||||
if not field.nullable:
|
|
||||||
if rel_part not in self._select_related:
|
|
||||||
split_tables = rel_part.split("__")
|
|
||||||
new_related = (
|
|
||||||
"__".join(split_tables[:-1])
|
|
||||||
if len(split_tables) > 1
|
|
||||||
else rel_part
|
|
||||||
)
|
|
||||||
self.auto_related.append(new_related)
|
|
||||||
rel_part = ""
|
|
||||||
elif self._field_qualifies_to_deeper_search(
|
|
||||||
field, parent_virtual, nested, rel_part
|
|
||||||
):
|
|
||||||
|
|
||||||
self._extract_auto_required_relations(
|
|
||||||
prev_model=field.to,
|
|
||||||
rel_part=rel_part,
|
|
||||||
nested=True,
|
|
||||||
parent_virtual=field.virtual,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.already_checked.append(rel_part)
|
|
||||||
rel_part = ""
|
|
||||||
|
|
||||||
def _include_auto_related_models(self) -> None:
|
|
||||||
if self.auto_related:
|
|
||||||
new_joins = []
|
|
||||||
for join in self._select_related:
|
|
||||||
if not any([x.startswith(join) for x in self.auto_related]):
|
|
||||||
new_joins.append(join)
|
|
||||||
self._select_related = new_joins + self.auto_related
|
|
||||||
@ -1,23 +1,30 @@
|
|||||||
import pprint
|
import string
|
||||||
import string
|
import string
|
||||||
import uuid
|
import uuid
|
||||||
|
from enum import Enum
|
||||||
from random import choices
|
from random import choices
|
||||||
from typing import List, TYPE_CHECKING, Union
|
from typing import List, TYPE_CHECKING, Type
|
||||||
from weakref import proxy
|
from weakref import proxy
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from ormar.exceptions import RelationshipInstanceError
|
||||||
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
|
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar.models import NewBaseModel, Model
|
from ormar.models import Model
|
||||||
|
|
||||||
|
|
||||||
def get_table_alias() -> str:
|
def get_table_alias() -> str:
|
||||||
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
|
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
|
||||||
|
|
||||||
|
|
||||||
|
class RelationType(Enum):
|
||||||
|
PRIMARY = 1
|
||||||
|
REVERSE = 2
|
||||||
|
|
||||||
|
|
||||||
class AliasManager:
|
class AliasManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._relations = dict()
|
self._relations = dict()
|
||||||
@ -42,78 +49,97 @@ class AliasManager:
|
|||||||
table_name: str,
|
table_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
if relations_key not in self._relations:
|
if relations_key not in self._relations:
|
||||||
self._relations[relations_key] = {"type": "primary"}
|
|
||||||
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
|
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
|
||||||
if reverse_key not in self._relations:
|
if reverse_key not in self._relations:
|
||||||
self._relations[reverse_key] = {"type": "reverse"}
|
|
||||||
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
|
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
|
||||||
|
|
||||||
def deregister(self, model: "NewBaseModel") -> None:
|
|
||||||
for rel_type in self._relations.keys():
|
|
||||||
if model.get_name() in rel_type.lower():
|
|
||||||
if model._orm_id in self._relations[rel_type]:
|
|
||||||
del self._relations[rel_type][model._orm_id]
|
|
||||||
|
|
||||||
def add_relation(
|
|
||||||
self,
|
|
||||||
parent: "NewBaseModel",
|
|
||||||
child: "NewBaseModel",
|
|
||||||
child_model_name: str,
|
|
||||||
virtual: bool = False,
|
|
||||||
) -> None:
|
|
||||||
parent_id, child_id = parent._orm_id, child._orm_id
|
|
||||||
parent_name = parent.get_name(title=True)
|
|
||||||
child_name = (
|
|
||||||
child_model_name
|
|
||||||
if child.get_name() != child_model_name
|
|
||||||
else child.get_name() + "s"
|
|
||||||
)
|
|
||||||
if virtual:
|
|
||||||
child_name, parent_name = parent_name, child.get_name()
|
|
||||||
child_id, parent_id = parent_id, child_id
|
|
||||||
child, parent = parent, proxy(child)
|
|
||||||
child_name = child_name.lower() + "s"
|
|
||||||
else:
|
|
||||||
child = proxy(child)
|
|
||||||
|
|
||||||
parent_relation_name = parent_name.title() + "_" + child_name
|
|
||||||
parents_list = self._relations[parent_relation_name].setdefault(parent_id, [])
|
|
||||||
self.append_related_model(parents_list, child)
|
|
||||||
|
|
||||||
child_relation_name = child.get_name(title=True) + "_" + parent_name.lower()
|
|
||||||
children_list = self._relations[child_relation_name].setdefault(child_id, [])
|
|
||||||
self.append_related_model(children_list, parent)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def append_related_model(relations_list: List["Model"], model: "Model") -> None:
|
|
||||||
for relation_child in relations_list:
|
|
||||||
try:
|
|
||||||
if relation_child.__same__(model):
|
|
||||||
return
|
|
||||||
except ReferenceError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
relations_list.append(model)
|
|
||||||
|
|
||||||
def contains(self, relations_key: str, instance: "NewBaseModel") -> bool:
|
|
||||||
if relations_key in self._relations:
|
|
||||||
return instance._orm_id in self._relations[relations_key]
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get(
|
|
||||||
self, relations_key: str, instance: "NewBaseModel"
|
|
||||||
) -> Union["Model", List["Model"]]:
|
|
||||||
if relations_key in self._relations:
|
|
||||||
if instance._orm_id in self._relations[relations_key]:
|
|
||||||
if self._relations[relations_key]["type"] == "primary":
|
|
||||||
return self._relations[relations_key][instance._orm_id][0]
|
|
||||||
return self._relations[relations_key][instance._orm_id]
|
|
||||||
|
|
||||||
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
|
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
|
||||||
return self._aliases.get(f"{from_table}_{to_table}", "")
|
return self._aliases.get(f"{from_table}_{to_table}", "")
|
||||||
|
|
||||||
def __str__(self) -> str: # pragma no cover
|
|
||||||
return pprint.pformat(self._relations, indent=4, width=1)
|
|
||||||
|
|
||||||
def __repr__(self) -> str: # pragma no cover
|
class Relation:
|
||||||
return self.__str__()
|
def __init__(self, type_: RelationType) -> None:
|
||||||
|
self._type = type_
|
||||||
|
self.related_models = [] if type_ == RelationType.REVERSE else None
|
||||||
|
|
||||||
|
def _find_existing(self, child):
|
||||||
|
for ind, relation_child in enumerate(self.related_models):
|
||||||
|
try:
|
||||||
|
if relation_child.__same__(child):
|
||||||
|
return ind
|
||||||
|
except ReferenceError: # pragma no cover
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add(self, child: "Model") -> None:
|
||||||
|
if self._type == RelationType.PRIMARY:
|
||||||
|
self.related_models = child
|
||||||
|
else:
|
||||||
|
if self._find_existing(child) is None:
|
||||||
|
self.related_models.append(child)
|
||||||
|
|
||||||
|
# def remove(self, child: "Model") -> None:
|
||||||
|
# if self._type == RelationType.PRIMARY:
|
||||||
|
# self.related_models = None
|
||||||
|
# else:
|
||||||
|
# position = self._find_existing(child)
|
||||||
|
# if position is not None:
|
||||||
|
# self.related_models.pop(position)
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
return self.related_models
|
||||||
|
|
||||||
|
|
||||||
|
class RelationsManager:
|
||||||
|
def __init__(
|
||||||
|
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
|
||||||
|
):
|
||||||
|
self.owner = owner
|
||||||
|
self._related_fields = related_fields or []
|
||||||
|
self._related_names = [field.name for field in self._related_fields]
|
||||||
|
self._relations = dict()
|
||||||
|
for field in self._related_fields:
|
||||||
|
self._relations[field.name] = Relation(
|
||||||
|
type_=RelationType.PRIMARY
|
||||||
|
if not field.virtual
|
||||||
|
else RelationType.REVERSE
|
||||||
|
)
|
||||||
|
|
||||||
|
def __contains__(self, item):
|
||||||
|
return item in self._related_names
|
||||||
|
|
||||||
|
def get(self, name):
|
||||||
|
relation = self._relations.get(name, None)
|
||||||
|
if relation:
|
||||||
|
return relation.get()
|
||||||
|
|
||||||
|
def _get(self, name):
|
||||||
|
relation = self._relations.get(name, None)
|
||||||
|
if relation:
|
||||||
|
return relation
|
||||||
|
|
||||||
|
def add(self, parent: "Model", child: "Model", child_name: str, virtual: bool):
|
||||||
|
to_field = next(
|
||||||
|
(
|
||||||
|
field
|
||||||
|
for field in child._orm._related_fields
|
||||||
|
if field.to == parent.__class__
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not to_field: # pragma no cover
|
||||||
|
raise RelationshipInstanceError(
|
||||||
|
f"Model {child.__class__} does not have reference to model {parent.__class__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
to_name = to_field.name
|
||||||
|
if virtual:
|
||||||
|
child_name, to_name = to_name, child_name or child.get_name()
|
||||||
|
child, parent = parent, proxy(child)
|
||||||
|
else:
|
||||||
|
child_name = child_name or child.get_name() + "s"
|
||||||
|
child = proxy(child)
|
||||||
|
|
||||||
|
parent._orm._get(child_name).add(child)
|
||||||
|
child._orm._get(to_name).add(parent)
|
||||||
|
|||||||
@ -131,7 +131,7 @@ async def test_model_crud():
|
|||||||
|
|
||||||
album1 = await Album.objects.get(name="Malibu")
|
album1 = await Album.objects.get(name="Malibu")
|
||||||
assert album1.pk == 1
|
assert album1.pk == 1
|
||||||
assert album1.tracks is None
|
assert album1.tracks == []
|
||||||
|
|
||||||
await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4)
|
await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4)
|
||||||
|
|
||||||
|
|||||||
110
tests/test_more_same_table_joins.py
Normal file
110
tests/test_more_same_table_joins.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import databases
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
import ormar
|
||||||
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
class Department(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "departments"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True, autoincrement=False)
|
||||||
|
name: ormar.String(max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class SchoolClass(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "schoolclasses"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class Category(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "categories"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=100)
|
||||||
|
department: ormar.ForeignKey(Department, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Student(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "students"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=100)
|
||||||
|
schoolclass: ormar.ForeignKey(SchoolClass)
|
||||||
|
category: ormar.ForeignKey(Category, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Teacher(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "teachers"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=100)
|
||||||
|
schoolclass: ormar.ForeignKey(SchoolClass)
|
||||||
|
category: ormar.ForeignKey(Category, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def event_loop():
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
|
async def create_test_database():
|
||||||
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
metadata.create_all(engine)
|
||||||
|
department = await Department.objects.create(id=1, name="Math Department")
|
||||||
|
department2 = await Department.objects.create(id=2, name="Law Department")
|
||||||
|
class1 = await SchoolClass.objects.create(name="Math")
|
||||||
|
class2 = await SchoolClass.objects.create(name="Logic")
|
||||||
|
category = await Category.objects.create(name="Foreign", department=department)
|
||||||
|
category2 = await Category.objects.create(name="Domestic", department=department2)
|
||||||
|
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
|
||||||
|
await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
|
||||||
|
await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
|
||||||
|
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
|
||||||
|
yield
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_multiple_instances_of_same_table_in_schema():
|
||||||
|
async with database:
|
||||||
|
classes = await SchoolClass.objects.select_related(
|
||||||
|
["teachers__category__department", "students"]
|
||||||
|
).all()
|
||||||
|
assert classes[0].name == "Math"
|
||||||
|
assert classes[0].students[0].name == "Jane"
|
||||||
|
assert len(classes[0].dict().get("students")) == 2
|
||||||
|
assert classes[0].teachers[0].category.department.name == 'Law Department'
|
||||||
|
|
||||||
|
assert classes[0].students[0].category.pk is not None
|
||||||
|
assert classes[0].students[0].category.name is None
|
||||||
|
await classes[0].students[0].category.load()
|
||||||
|
await classes[0].students[0].category.department.load()
|
||||||
|
assert classes[0].students[0].category.department.name == 'Math Department'
|
||||||
@ -79,11 +79,14 @@ async def create_test_database():
|
|||||||
metadata.drop_all(engine)
|
metadata.drop_all(engine)
|
||||||
metadata.create_all(engine)
|
metadata.create_all(engine)
|
||||||
department = await Department.objects.create(id=1, name="Math Department")
|
department = await Department.objects.create(id=1, name="Math Department")
|
||||||
|
department2 = await Department.objects.create(id=2, name="Law Department")
|
||||||
class1 = await SchoolClass.objects.create(name="Math", department=department)
|
class1 = await SchoolClass.objects.create(name="Math", department=department)
|
||||||
|
class2 = await SchoolClass.objects.create(name="Logic", department=department2)
|
||||||
category = await Category.objects.create(name="Foreign")
|
category = await Category.objects.create(name="Foreign")
|
||||||
category2 = await Category.objects.create(name="Domestic")
|
category2 = await Category.objects.create(name="Domestic")
|
||||||
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
|
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
|
||||||
await Student.objects.create(name="Jack", category=category2, schoolclass=class1)
|
await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
|
||||||
|
await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
|
||||||
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
|
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
|
||||||
yield
|
yield
|
||||||
metadata.drop_all(engine)
|
metadata.drop_all(engine)
|
||||||
@ -100,15 +103,15 @@ async def test_model_multiple_instances_of_same_table_in_schema():
|
|||||||
|
|
||||||
assert len(classes[0].dict().get("students")) == 2
|
assert len(classes[0].dict().get("students")) == 2
|
||||||
|
|
||||||
# related fields of main model are only populated by pk
|
# since it's going from schoolclass => teacher => schoolclass (same class) department is already populated
|
||||||
# unless there is a required foreign key somewhere along the way
|
|
||||||
# since department is required for schoolclass it was pre loaded (again)
|
|
||||||
# but you can load them anytime
|
|
||||||
assert classes[0].students[0].schoolclass.name == "Math"
|
assert classes[0].students[0].schoolclass.name == "Math"
|
||||||
assert classes[0].students[0].schoolclass.department.name is None
|
assert classes[0].students[0].schoolclass.department.name is None
|
||||||
await classes[0].students[0].schoolclass.department.load()
|
await classes[0].students[0].schoolclass.department.load()
|
||||||
assert classes[0].students[0].schoolclass.department.name == "Math Department"
|
assert classes[0].students[0].schoolclass.department.name == "Math Department"
|
||||||
|
|
||||||
|
await classes[1].students[0].schoolclass.department.load()
|
||||||
|
assert classes[1].students[0].schoolclass.department.name == "Law Department"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_right_tables_join():
|
async def test_right_tables_join():
|
||||||
@ -130,5 +133,7 @@ async def test_multiple_reverse_related_objects():
|
|||||||
["teachers__category", "students__category"]
|
["teachers__category", "students__category"]
|
||||||
).all()
|
).all()
|
||||||
assert classes[0].name == "Math"
|
assert classes[0].name == "Math"
|
||||||
assert classes[0].students[1].name == "Jack"
|
assert classes[0].students[1].name == "Judy"
|
||||||
|
assert classes[0].students[0].category.name == "Foreign"
|
||||||
|
assert classes[0].students[1].category.name == "Domestic"
|
||||||
assert classes[0].teachers[0].category.name == "Domestic"
|
assert classes[0].teachers[0].category.name == "Domestic"
|
||||||
|
|||||||
Reference in New Issue
Block a user