diff --git a/.coverage b/.coverage index 73c7553..86117f6 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/__init__.py b/ormar/__init__.py index ed4ab31..bfa2e87 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -16,6 +16,7 @@ from ormar.fields import ( ) from ormar.models import Model from ormar.queryset import QuerySet +from ormar.relations import RelationType __version__ = "0.3.0" __all__ = [ diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index d30901c..ee94013 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -24,7 +24,7 @@ from ormar.fields.foreign_key import ForeignKeyField from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.modelproxy import ModelTableProxy from ormar.relations.alias_manager import AliasManager -from ormar.relations.relation import RelationsManager +from ormar.relations.relation_manager import RelationsManager if TYPE_CHECKING: # pragma no cover from ormar.models.model import Model diff --git a/ormar/relations/__init__.py b/ormar/relations/__init__.py index 788d2f5..54860f5 100644 --- a/ormar/relations/__init__.py +++ b/ormar/relations/__init__.py @@ -1,4 +1,5 @@ -from ormar.relations.relation import Relation +from ormar.relations.relation import Relation, RelationType from ormar.relations.alias_manager import AliasManager +from ormar.relations.relation_manager import RelationsManager -__all__ = ["AliasManager"] +__all__ = ["AliasManager", "Relation", "RelationsManager", "RelationType"] diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py index bd97cb1..5ec709c 100644 --- a/ormar/relations/relation.py +++ b/ormar/relations/relation.py @@ -1,15 +1,14 @@ from enum import Enum -from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union -from weakref import proxy +from typing import List, Optional, TYPE_CHECKING, Type, Union import ormar # noqa I100 from ormar.exceptions import RelationshipInstanceError # noqa I100 from ormar.fields.foreign_key import ForeignKeyField # noqa I100 -from ormar.fields.many_to_many import ManyToManyField -from ormar.relations.querysetproxy import QuerysetProxy +from ormar.relations.relation_proxy import RelationProxy if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.relations import RelationsManager class RelationType(Enum): @@ -18,58 +17,6 @@ class RelationType(Enum): MULTIPLE = 3 -class RelationProxy(list): - def __init__(self, relation: "Relation") -> None: - super(RelationProxy, self).__init__() - self.relation = relation - self._owner = self.relation.manager.owner - self.queryset_proxy = QuerysetProxy(relation=self.relation) - - def __getattribute__(self, item: str) -> Any: - if item in ["count", "clear"]: - if not self.queryset_proxy.queryset: - self.queryset_proxy.queryset = self._set_queryset() - return getattr(self.queryset_proxy, item) - return super().__getattribute__(item) - - def __getattr__(self, item: str) -> Any: - if not self.queryset_proxy.queryset: - self.queryset_proxy.queryset = self._set_queryset() - return getattr(self.queryset_proxy, item) - - def _set_queryset(self) -> "QuerySet": - owner_table = self.relation._owner.Meta.tablename - pkname = self.relation._owner.Meta.pkname - pk_value = self.relation._owner.pk - if not pk_value: - raise RelationshipInstanceError( - "You cannot query many to many relationship on unsaved model." - ) - kwargs = {f"{owner_table}__{pkname}": pk_value} - queryset = ( - ormar.QuerySet(model_cls=self.relation.to) - .select_related(owner_table) - .filter(**kwargs) - ) - return queryset - - async def remove(self, item: "Model") -> None: - super().remove(item) - rel_name = item.resolve_relation_name(item, self._owner) - item._orm._get(rel_name).remove(self._owner) - if self.relation._type == RelationType.MULTIPLE: - await self.queryset_proxy.delete_through_instance(item) - - def append(self, item: "Model") -> None: - super().append(item) - - async def add(self, item: "Model") -> None: - if self.relation._type == RelationType.MULTIPLE: - await self.queryset_proxy.create_through_instance(item) - rel_name = item.resolve_relation_name(item, self._owner) - setattr(item, rel_name, self._owner) - - class Relation: def __init__( self, @@ -132,120 +79,3 @@ class Relation: return str(self.related_models) -class RelationsManager: - def __init__( - self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None - ) -> None: - self.owner = proxy(owner) - self._related_fields = related_fields or [] - self._related_names = [field.name for field in self._related_fields] - self._relations = dict() - for field in self._related_fields: - self._add_relation(field) - - def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType: - if issubclass(field, ManyToManyField): - return RelationType.MULTIPLE - return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE - - def _add_relation(self, field: Type[ForeignKeyField]) -> None: - self._relations[field.name] = Relation( - manager=self, - type_=self._get_relation_type(field), - to=field.to, - through=getattr(field, "through", None), - ) - - def __contains__(self, item: str) -> bool: - return item in self._related_names - - def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]: - relation = self._relations.get(name, None) - if relation is not None: - return relation.get() - - def _get(self, name: str) -> Optional[Relation]: - relation = self._relations.get(name, None) - if relation is not None: - return relation - - @staticmethod - def register_missing_relation( - parent: "Model", child: "Model", child_name: str - ) -> Relation: - ormar.models.expand_reverse_relationships(child.__class__) - name = parent.resolve_relation_name(parent, child) - field = parent.Meta.model_fields[name] - parent._orm._add_relation(field) - parent_relation = parent._orm._get(child_name) - return parent_relation - - @staticmethod - def get_relations_sides_and_names( - to_field: Type[ForeignKeyField], - parent: "Model", - child: "Model", - child_name: str, - virtual: bool, - ) -> Tuple["Model", "Model", str, str]: - to_name = to_field.name - if issubclass(to_field, ManyToManyField): - child_name, to_name = ( - child.resolve_relation_name(parent, child), - child.resolve_relation_name(child, parent), - ) - child = proxy(child) - elif virtual: - child_name, to_name = to_name, child_name or child.get_name() - child, parent = parent, proxy(child) - else: - child_name = child_name or child.get_name() + "s" - child = proxy(child) - return parent, child, child_name, to_name - - @staticmethod - def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None: - to_field = next( - ( - field - for field in child._orm._related_fields - if field.to == parent.__class__ or field.to.Meta == parent.Meta - ), - None, - ) - - if not to_field: # pragma no cover - raise RelationshipInstanceError( - f"Model {child.__class__} does not have " - f"reference to model {parent.__class__}" - ) - - ( - parent, - child, - child_name, - to_name, - ) = RelationsManager.get_relations_sides_and_names( - to_field, parent, child, child_name, virtual - ) - - parent_relation = parent._orm._get(child_name) - if not parent_relation: - parent_relation = RelationsManager.register_missing_relation( - parent, child, child_name - ) - parent_relation.add(child) - child._orm._get(to_name).add(parent) - - def remove(self, name: str, child: "Model") -> None: - relation = self._get(name) - relation.remove(child) - - @staticmethod - def remove_parent(item: "Model", name: Union[str, "Model"]) -> None: - related_model = name - name = item.resolve_relation_name(item, related_model) - if name in item._orm: - relation_name = item.resolve_relation_name(related_model, item) - item._orm.remove(name, related_model) - related_model._orm.remove(relation_name, item) diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py new file mode 100644 index 0000000..8dc1e21 --- /dev/null +++ b/ormar/relations/relation_manager.py @@ -0,0 +1,128 @@ +from _weakref import proxy +from typing import List, Type, Optional, Union, Tuple + +import ormar +from ormar.exceptions import RelationshipInstanceError +from ormar.fields.foreign_key import ForeignKeyField +from ormar.fields.many_to_many import ManyToManyField +from ormar.relations import Relation +from ormar.relations.relation import RelationType + + +class RelationsManager: + def __init__( + self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None + ) -> None: + self.owner = proxy(owner) + self._related_fields = related_fields or [] + self._related_names = [field.name for field in self._related_fields] + self._relations = dict() + for field in self._related_fields: + self._add_relation(field) + + def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType: + if issubclass(field, ManyToManyField): + return RelationType.MULTIPLE + return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE + + def _add_relation(self, field: Type[ForeignKeyField]) -> None: + self._relations[field.name] = Relation( + manager=self, + type_=self._get_relation_type(field), + to=field.to, + through=getattr(field, "through", None), + ) + + def __contains__(self, item: str) -> bool: + return item in self._related_names + + def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]: + relation = self._relations.get(name, None) + if relation is not None: + return relation.get() + + def _get(self, name: str) -> Optional[Relation]: + relation = self._relations.get(name, None) + if relation is not None: + return relation + + @staticmethod + def register_missing_relation( + parent: "Model", child: "Model", child_name: str + ) -> Relation: + ormar.models.expand_reverse_relationships(child.__class__) + name = parent.resolve_relation_name(parent, child) + field = parent.Meta.model_fields[name] + parent._orm._add_relation(field) + parent_relation = parent._orm._get(child_name) + return parent_relation + + @staticmethod + def get_relations_sides_and_names( + to_field: Type[ForeignKeyField], + parent: "Model", + child: "Model", + child_name: str, + virtual: bool, + ) -> Tuple["Model", "Model", str, str]: + to_name = to_field.name + if issubclass(to_field, ManyToManyField): + child_name, to_name = ( + child.resolve_relation_name(parent, child), + child.resolve_relation_name(child, parent), + ) + child = proxy(child) + elif virtual: + child_name, to_name = to_name, child_name or child.get_name() + child, parent = parent, proxy(child) + else: + child_name = child_name or child.get_name() + "s" + child = proxy(child) + return parent, child, child_name, to_name + + @staticmethod + def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None: + to_field = next( + ( + field + for field in child._orm._related_fields + if field.to == parent.__class__ or field.to.Meta == parent.Meta + ), + None, + ) + + if not to_field: # pragma no cover + raise RelationshipInstanceError( + f"Model {child.__class__} does not have " + f"reference to model {parent.__class__}" + ) + + ( + parent, + child, + child_name, + to_name, + ) = RelationsManager.get_relations_sides_and_names( + to_field, parent, child, child_name, virtual + ) + + parent_relation = parent._orm._get(child_name) + if not parent_relation: + parent_relation = RelationsManager.register_missing_relation( + parent, child, child_name + ) + parent_relation.add(child) + child._orm._get(to_name).add(parent) + + def remove(self, name: str, child: "Model") -> None: + relation = self._get(name) + relation.remove(child) + + @staticmethod + def remove_parent(item: "Model", name: Union[str, "Model"]) -> None: + related_model = name + name = item.resolve_relation_name(item, related_model) + if name in item._orm: + relation_name = item.resolve_relation_name(related_model, item) + item._orm.remove(name, related_model) + related_model._orm.remove(relation_name, item) \ No newline at end of file diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py new file mode 100644 index 0000000..a3bdebd --- /dev/null +++ b/ormar/relations/relation_proxy.py @@ -0,0 +1,62 @@ +from typing import Any, TYPE_CHECKING + +import ormar +from ormar.exceptions import RelationshipInstanceError +from ormar.relations.querysetproxy import QuerysetProxy + +if TYPE_CHECKING: # pragma no cover + from ormar import Model + from ormar.relations import Relation + from ormar.queryset import QuerySet + + +class RelationProxy(list): + def __init__(self, relation: "Relation") -> None: + super(RelationProxy, self).__init__() + self.relation = relation + self._owner = self.relation.manager.owner + self.queryset_proxy = QuerysetProxy(relation=self.relation) + + def __getattribute__(self, item: str) -> Any: + if item in ["count", "clear"]: + if not self.queryset_proxy.queryset: + self.queryset_proxy.queryset = self._set_queryset() + return getattr(self.queryset_proxy, item) + return super().__getattribute__(item) + + def __getattr__(self, item: str) -> Any: + if not self.queryset_proxy.queryset: + self.queryset_proxy.queryset = self._set_queryset() + return getattr(self.queryset_proxy, item) + + def _set_queryset(self) -> "QuerySet": + owner_table = self.relation._owner.Meta.tablename + pkname = self.relation._owner.Meta.pkname + pk_value = self.relation._owner.pk + if not pk_value: + raise RelationshipInstanceError( + "You cannot query many to many relationship on unsaved model." + ) + kwargs = {f"{owner_table}__{pkname}": pk_value} + queryset = ( + ormar.QuerySet(model_cls=self.relation.to) + .select_related(owner_table) + .filter(**kwargs) + ) + return queryset + + async def remove(self, item: "Model") -> None: + super().remove(item) + rel_name = item.resolve_relation_name(item, self._owner) + item._orm._get(rel_name).remove(self._owner) + if self.relation._type == ormar.RelationType.MULTIPLE: + await self.queryset_proxy.delete_through_instance(item) + + def append(self, item: "Model") -> None: + super().append(item) + + async def add(self, item: "Model") -> None: + if self.relation._type == ormar.RelationType.MULTIPLE: + await self.queryset_proxy.create_through_instance(item) + rel_name = item.resolve_relation_name(item, self._owner) + setattr(item, rel_name, self._owner)