Files
ormar/ormar/relations/relation_manager.py
2020-12-15 11:55:07 +01:00

100 lines
3.5 KiB
Python

from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Type, TypeVar, Union
from weakref import proxy
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.relation import Relation, RelationType
from ormar.relations.utils import get_relations_sides_and_names
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.models import NewBaseModel
T = TypeVar("T", bound=Model)
class RelationsManager:
    """
    Keeps the ``Relation`` objects of a single model instance.

    One ``Relation`` is created per related field and stored under that
    field's name. The manager offers lookups by name, registration of a link
    between two models on both sides, and removal of existing links.
    """

    def __init__(
        self,
        related_fields: Optional[List[Type[ForeignKeyField]]] = None,
        owner: Optional["NewBaseModel"] = None,
    ) -> None:
        """
        Create a ``Relation`` for every related field of the owner model.

        :param related_fields: relation fields declared on the owner model;
            ``None`` is treated as "no related fields"
        :param owner: model instance owning this manager. Only a
            ``weakref.proxy`` to it is stored, so the manager does not keep its
            owner alive (presumably to avoid an owner <-> manager reference
            cycle).
        """
        # weakref.proxy(None) raises TypeError, which made the declared
        # default ``owner=None`` unusable; only proxy a real owner.
        self.owner = proxy(owner) if owner is not None else None
        self._related_fields = related_fields or []
        self._related_names = [field.name for field in self._related_fields]
        self._relations: Dict[str, Relation] = dict()
        for field in self._related_fields:
            self._add_relation(field)

    def _get_relation_type(self, field: Type[BaseField]) -> RelationType:
        """
        Return the kind of ``Relation`` a given relation field should produce.

        Many-to-many fields map to ``MULTIPLE``. Any other field maps to
        ``REVERSE`` when it is ``virtual`` and to ``PRIMARY`` otherwise.
        """
        if issubclass(field, ManyToManyField):
            return RelationType.MULTIPLE
        return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE

    def _add_relation(self, field: Type[BaseField]) -> None:
        """
        Create a ``Relation`` for ``field`` and store it under the field name.

        ``through`` is read with ``getattr`` because only some field classes
        (e.g. many-to-many) define it; other fields pass ``None``.
        """
        self._relations[field.name] = Relation(
            manager=self,
            type_=self._get_relation_type(field),
            field_name=field.name,
            to=field.to,
            through=getattr(field, "through", None),
        )

    def __contains__(self, item: str) -> bool:
        """Return True if ``item`` is the name of a related field of the owner."""
        return item in self._related_names

    def get(self, name: str) -> Optional[Union["T", Sequence["T"]]]:
        """
        Return the related model(s) held by the relation called ``name``.

        Delegates to ``Relation.get()``. Returns ``None`` when no relation with
        that name is registered.
        """
        relation = self._relations.get(name, None)
        if relation is not None:
            return relation.get()
        return None  # pragma nocover

    def _get(self, name: str) -> Optional[Relation]:
        """Return the ``Relation`` object registered as ``name``, or ``None``."""
        return self._relations.get(name)

    @staticmethod
    def add(
        parent: "Model",
        child: "Model",
        child_name: str,
        virtual: bool,
        relation_name: str,
    ) -> None:
        """
        Register a link between two models on both sides of a relation.

        The field ``relation_name`` is looked up in ``child.Meta.model_fields``.
        ``get_relations_sides_and_names`` then decides which model acts as the
        parent and which as the child, and returns the name of each side's
        relation. ``child`` is added to the parent's relation and ``parent`` is
        added to the child's relation. A side that has no matching relation is
        silently skipped.

        :param parent: one model of the pair
        :param child: the other model of the pair
        :param child_name: relation name used on the parent side (may be
            replaced by the helper)
        :param virtual: passed through to ``get_relations_sides_and_names``
        :param relation_name: name of the relation field on ``child``
        """
        to_field: Type[BaseField] = child.Meta.model_fields[relation_name]
        # The helper may swap the sides and rename the parent-side relation;
        # it also yields ``to_name``, the relation name on the child side.
        parent, child, child_name, to_name = get_relations_sides_and_names(
            to_field, parent, child, child_name, virtual
        )

        parent_relation = parent._orm._get(child_name)
        if parent_relation:
            parent_relation.add(child)  # type: ignore

        child_relation = child._orm._get(to_name)
        if child_relation:
            child_relation.add(parent)

    def remove(
        self, name: str, child: Union["NewBaseModel", Type["NewBaseModel"]]
    ) -> None:
        """
        Remove ``child`` from the relation called ``name``.

        Does nothing if no relation with that name is registered.
        """
        relation = self._get(name)
        if relation:
            relation.remove(child)

    @staticmethod
    def remove_parent(
        item: Union["NewBaseModel", Type["NewBaseModel"]], name: "Model"
    ) -> None:
        """
        Unlink ``item`` and a related model on both sides of their relation.

        The relation name is resolved in both directions with
        ``item.resolve_relation_name``. The link is removed from both managers,
        but only if ``item``'s manager knows the relation name.

        :param item: model whose relation to ``name`` is being cleared
        :param name: despite its name, the related *model instance* (not a
            string); the parameter name is kept for backward compatibility
        """
        related_model = name
        rel_name = item.resolve_relation_name(item, related_model)
        if rel_name in item._orm:
            relation_name = item.resolve_relation_name(related_model, item)
            item._orm.remove(rel_name, related_model)
            related_model._orm.remove(relation_name, item)