working basic many to many relationships

This commit is contained in:
collerek
2020-09-14 17:13:27 +02:00
parent 58c3627be7
commit 4674f625df
18 changed files with 791 additions and 244 deletions

BIN
.coverage

Binary file not shown.

View File

@ -9,6 +9,7 @@ from ormar.fields import (
ForeignKey,
Integer,
JSON,
ManyToMany,
String,
Text,
Time,
@ -28,6 +29,7 @@ __all__ = [
"Date",
"Decimal",
"Float",
"ManyToMany",
"Model",
"ModelDefinitionError",
"ModelNotSet",

View File

@ -1,5 +1,6 @@
from ormar.fields.base import BaseField
from ormar.fields.foreign_key import ForeignKey
from ormar.fields.many_to_many import ManyToMany
from ormar.fields.model_fields import (
BigInteger,
Boolean,
@ -27,5 +28,6 @@ __all__ = [
"Float",
"Time",
"ForeignKey",
"ManyToMany",
"BaseField",
]

View File

@ -22,6 +22,7 @@ class BaseField:
index: bool
unique: bool
pydantic_only: bool
virtual: bool = False
default: Any
server_default: Any

View File

@ -22,7 +22,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
return fk(**init_dict)
def ForeignKey(
def ForeignKey( # noqa CFQ002
to: Type["Model"],
*,
name: str = None,
@ -30,6 +30,8 @@ def ForeignKey(
nullable: bool = True,
related_name: str = None,
virtual: bool = False,
onupdate: str = None,
ondelete: str = None,
) -> Type[object]:
fk_string = to.Meta.tablename + "." + to.Meta.pkname
to_field = to.__fields__[to.Meta.pkname]
@ -37,7 +39,11 @@ def ForeignKey(
to=to,
name=name,
nullable=nullable,
constraints=[sqlalchemy.schema.ForeignKey(fk_string)],
constraints=[
sqlalchemy.schema.ForeignKey(
fk_string, ondelete=ondelete, onupdate=onupdate
)
],
unique=unique,
column_type=to_field.type_.column_type,
related_name=related_name,
@ -117,7 +123,7 @@ class ForeignKeyField(BaseField):
cls, value: Any, child: "Model", to_register: bool = True
) -> Optional[Union["Model", List["Model"]]]:
if value is None:
return None
return None if not cls.virtual else []
constructors = {
f"{cls.to.__name__}": cls._register_existing_model,

View File

@ -0,0 +1,40 @@
from typing import TYPE_CHECKING, Type
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover
from ormar.models import Model
def ManyToMany(
to: Type["Model"],
through: Type["Model"],
*,
name: str = None,
unique: bool = False,
related_name: str = None,
virtual: bool = False,
) -> Type[object]:
to_field = to.__fields__[to.Meta.pkname]
namespace = dict(
to=to,
through=through,
name=name,
nullable=True,
unique=unique,
column_type=to_field.type_.column_type,
related_name=related_name,
virtual=virtual,
primary_key=False,
index=False,
pydantic_only=False,
default=None,
server_default=None,
)
return type("ManyToMany", (ManyToManyField, BaseField), namespace)
class ManyToManyField(ForeignKeyField):
through: Type["Model"]

View File

@ -1,16 +1,18 @@
import logging
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
import databases
import pydantic
import sqlalchemy
from pydantic import BaseConfig
from pydantic.fields import FieldInfo
from pydantic.fields import FieldInfo, ModelField
from ormar import ForeignKey, ModelDefinitionError # noqa I100
from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToMany, ManyToManyField
from ormar.queryset import QuerySet
from ormar.relations import AliasManager
from ormar.relations.alias_manager import AliasManager
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -30,7 +32,14 @@ class ModelMeta:
def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
alias_manager.add_relation_type(field, table_name)
alias_manager.add_relation_type(field.to.Meta.tablename, table_name)
def register_many_to_many_relation_on_build(table_name: str, field: ManyToMany) -> None:
alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
alias_manager.add_relation_type(
field.through.Meta.tablename, field.to.Meta.tablename
)
def reverse_field_not_already_registered(
@ -51,15 +60,72 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
if reverse_field_not_already_registered(
child, child_model_name, parent_model
):
register_reverse_model_fields(parent_model, child, child_model_name)
register_reverse_model_fields(
parent_model, child, child_model_name, model_field
)
def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str
model: Type["Model"],
child: Type["Model"],
child_model_name: str,
model_field: Type["ForeignKeyField"],
) -> None:
model.Meta.model_fields[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True
if issubclass(model_field, ManyToManyField):
model.Meta.model_fields[child_model_name] = ManyToMany(
child, through=model_field.through, name=child_model_name, virtual=True
)
# register foreign keys on through model
adjust_through_many_to_many_model(model, child, model_field)
else:
model.Meta.model_fields[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True
)
def adjust_through_many_to_many_model(
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField]
) -> None:
model_field.through.Meta.model_fields[model.get_name()] = ForeignKey(
model, name=model.get_name(), ondelete="CASCADE"
)
model_field.through.Meta.model_fields[child.get_name()] = ForeignKey(
child, name=child.get_name(), ondelete="CASCADE"
)
create_and_append_m2m_fk(model, model_field)
create_and_append_m2m_fk(child, model_field)
create_pydantic_field(model.get_name(), model, model_field)
create_pydantic_field(child.get_name(), child, model_field)
def create_pydantic_field(
field_name: str, model: Type["Model"], model_field: Type[ManyToManyField]
) -> None:
model_field.through.__fields__[field_name] = ModelField(
name=field_name,
type_=Optional[model],
model_config=model.__config__,
required=False,
class_validators=model.__validators__,
)
def create_and_append_m2m_fk(
model: Type["Model"], model_field: Type[ManyToManyField]
) -> None:
column = sqlalchemy.Column(
model.get_name(),
model.Meta.table.columns.get(model.Meta.pkname).type,
sqlalchemy.schema.ForeignKey(
model.Meta.tablename + "." + model.Meta.pkname,
ondelete="CASCADE",
onupdate="CASCADE",
),
)
model_field.through.Meta.columns.append(column)
model_field.through.Meta.table.append_column(column)
def check_pk_column_validity(
@ -77,17 +143,34 @@ def sqlalchemy_columns_from_model_fields(
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
columns = []
pkname = None
if len(model_fields.keys()) == 0:
model_fields["id"] = Integer(name="id", primary_key=True)
logging.warning(
"Table {table_name} had no fields so auto "
"Integer primary key named `id` created."
)
for field_name, field in model_fields.items():
if field.primary_key:
pkname = check_pk_column_validity(field_name, field, pkname)
if not field.pydantic_only:
if (
not field.pydantic_only
and not field.virtual
and not issubclass(field, ManyToManyField)
):
columns.append(field.get_column(field_name))
if issubclass(field, ForeignKeyField):
register_relation_on_build(table_name, field)
register_relation_in_alias_manager(table_name, field)
return pkname, columns
def register_relation_in_alias_manager(
table_name: str, field: Type[ForeignKeyField]
) -> None:
if issubclass(field, ManyToManyField):
register_many_to_many_relation_on_build(table_name, field)
elif issubclass(field, ForeignKeyField):
register_relation_on_build(table_name, field)
def populate_default_pydantic_field_value(
type_: Type[BaseField], field: str, attrs: dict
) -> dict:
@ -109,15 +192,11 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict:
return attrs
def extract_annotations_and_module(
attrs: dict, new_model: "ModelMetaclass", bases: Tuple
) -> dict:
annotations = attrs.get("__annotations__") or new_model.__annotations__
attrs["__annotations__"] = annotations
def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict:
attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get(
"__annotations__", {}
)
attrs = populate_pydantic_default_values(attrs)
attrs["__module__"] = attrs["__module__"] or bases[0].__module__
attrs["__annotations__"] = attrs["__annotations__"] or bases[0].__annotations__
return attrs
@ -175,20 +254,26 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
attrs["Config"] = get_pydantic_base_orm_config()
attrs = extract_annotations_and_default_vals(attrs, bases)
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
# breakpoint()
if hasattr(new_model, "Meta"):
attrs = extract_annotations_and_module(attrs, new_model, bases)
# attrs = extract_annotations_and_default_vals(attrs, bases)
new_model = populate_meta_orm_model_fields(attrs, new_model)
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
expand_reverse_relationships(new_model)
if new_model.Meta.pkname not in attrs["__annotations__"]:
field_name = new_model.Meta.pkname
field = Integer(name=field_name, primary_key=True)
attrs["__annotations__"][field_name] = field
populate_default_pydantic_field_value(field, field_name, attrs)
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)

View File

@ -4,6 +4,7 @@ from typing import Any, List, Tuple, Union
import sqlalchemy
import ormar.queryset # noqa I100
from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100
@ -40,10 +41,19 @@ class Model(NewBaseModel):
if select_related:
related_models = group_related_list(select_related)
# breakpoint()
if (
previous_table
and previous_table in cls.Meta.model_fields
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
):
previous_table = cls.Meta.model_fields[
previous_table
].through.Meta.tablename
table_prefix = cls.Meta.alias_manager.resolve_relation_join(
previous_table, cls.Meta.table.name
)
previous_table = cls.Meta.table.name
item = cls.populate_nested_models_from_row(

View File

@ -23,7 +23,8 @@ from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy
from ormar.relations import AliasManager, RelationsManager
from ormar.relations.alias_manager import AliasManager
from ormar.relations.relation import RelationsManager
if TYPE_CHECKING: # pragma no cover
from ormar.models.model import Model
@ -96,14 +97,17 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
kwargs.get(related), self, to_register=True
)
def __setattr__(self, name: str, value: Any) -> None:
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
if name in self.__slots__:
object.__setattr__(self, name, value)
elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value)
elif name in self._orm:
model = self.Meta.model_fields[name].expand_relationship(value, self)
self.__dict__[name] = model
if isinstance(self.__dict__.get(name), list):
self.__dict__[name].append(model)
else:
self.__dict__[name] = model
else:
value = (
self._convert_json(name, value, "dumps")
@ -131,15 +135,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if item in self._orm:
return self._orm.get(item)
def __eq__(self, other: "Model") -> bool:
if isinstance(other, NewBaseModel):
return self.__same__(other)
return super().__eq__(other) # pragma no cover
def __same__(self, other: "Model") -> bool:
return (
self._orm_id == other._orm_id
or self.__dict__ == other.__dict__
or self.dict() == other.dict()
or (self.pk == other.pk and self.pk is not None)
)
@classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str:
def get_name(cls, lower: bool = True) -> str:
name = cls.__name__
if lower:
name = name.lower()

View File

@ -5,6 +5,7 @@ from sqlalchemy import text
import ormar # noqa I100
from ormar.exceptions import QueryDefinitionError
from ormar.fields.many_to_many import ManyToManyField
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -128,6 +129,10 @@ class QueryClause:
# against which the comparison is being made.
previous_table = model_cls.Meta.tablename
for part in related_parts:
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
previous_table = model_cls.Meta.model_fields[
part
].through.Meta.tablename
current_table = model_cls.Meta.model_fields[part].to.Meta.tablename
manager = model_cls.Meta.alias_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table)

View File

@ -4,8 +4,10 @@ import sqlalchemy
from sqlalchemy import text
import ormar # noqa I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.relations import AliasManager
from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.alias_manager import AliasManager
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -63,6 +65,15 @@ class Query:
)
for part in item.split("__"):
if issubclass(
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
):
_fields = join_parameters.model_cls.Meta.model_fields
new_part = _fields[part].to.get_name()
join_parameters = self._build_join_parameters(
part, join_parameters, is_multi=True
)
part = new_part
join_parameters = self._build_join_parameters(part, join_parameters)
expr = sqlalchemy.sql.select(self.columns)
@ -83,23 +94,30 @@ class Query:
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}")
def _is_target_relation_key(
self, field: BaseField, target_model: Type["Model"]
) -> bool:
return issubclass(field, ForeignKeyField) and field.to.Meta == target_model.Meta
def _build_join_parameters(
self, part: str, join_params: JoinParameters
self, part: str, join_params: JoinParameters, is_multi: bool = False
) -> JoinParameters:
model_cls = join_params.model_cls.Meta.model_fields[part].to
if is_multi:
model_cls = join_params.model_cls.Meta.model_fields[part].through
else:
model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.Meta.table.name
alias = model_cls.Meta.alias_manager.resolve_relation_join(
join_params.from_table, to_table
)
if alias not in self.used_aliases:
if join_params.prev_model.Meta.model_fields[part].virtual:
if join_params.prev_model.Meta.model_fields[part].virtual or is_multi:
to_key = next(
(
v
for k, v in model_cls.Meta.model_fields.items()
if issubclass(v, ForeignKeyField)
and v.to == join_params.prev_model
if self._is_target_relation_key(v, join_params.prev_model)
),
None,
).name
@ -129,16 +147,19 @@ class Query:
prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select:
def filter(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: # noqa A003
if self.filter_clauses:
if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0]
else:
clause = sqlalchemy.sql.and_(*self.filter_clauses)
expr = expr.where(clause)
return expr
def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select:
expr = self.filter(expr)
if self.limit_count:
expr = expr.limit(self.limit_count)

View File

@ -48,6 +48,7 @@ class QuerySet:
limit_count=self.limit_count,
)
exp = qry.build_select_expression()
# print(exp.compile(compile_kwargs={"literal_binds": True}))
return exp
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -70,7 +71,7 @@ class QuerySet:
if not isinstance(related, (list, tuple)):
related = [related]
related = list(self._select_related) + related
related = list(set(list(self._select_related) + related))
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
@ -82,13 +83,28 @@ class QuerySet:
async def exists(self) -> bool:
expr = self.build_select_expression()
expr = sqlalchemy.exists(expr).select()
# print(expr.compile(compile_kwargs={"literal_binds": True}))
return await self.database.fetch_val(expr)
async def count(self) -> int:
expr = self.build_select_expression().alias("subquery_for_count")
expr = sqlalchemy.func.count().select().select_from(expr)
# print(expr.compile(compile_kwargs={"literal_binds": True}))
return await self.database.fetch_val(expr)
async def delete(self, **kwargs: Any) -> int:
if kwargs:
return await self.filter(**kwargs).delete()
qry = Query(
model_cls=self.model_cls,
select_related=self._select_related,
filter_clauses=self.filter_clauses,
offset=self.query_offset,
limit_count=self.limit_count,
)
expr = qry.filter(self.table.delete())
return await self.database.execute(expr)
def limit(self, limit_count: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
@ -143,6 +159,7 @@ class QuerySet:
return await self.filter(**kwargs).all()
expr = self.build_select_expression()
# breakpoint()
rows = await self.database.fetch_all(expr)
result_rows = [
self.model_cls.from_row(row, select_related=self._select_related)

View File

@ -1,198 +0,0 @@
import string
import uuid
from enum import Enum
from random import choices
from typing import List, Optional, TYPE_CHECKING, Type, Union
from weakref import proxy
import sqlalchemy
from sqlalchemy import text
import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
if TYPE_CHECKING: # pragma no cover
from ormar.models import Model
def get_table_alias() -> str:
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
class RelationType(Enum):
PRIMARY = 1
REVERSE = 2
class AliasManager:
def __init__(self) -> None:
self._aliases = dict()
@staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
]
@staticmethod
def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None:
if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases:
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases:
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
return self._aliases.get(f"{from_table}_{to_table}", "")
class RelationProxy(list):
def __init__(self, relation: "Relation") -> None:
super(RelationProxy, self).__init__()
self.relation = relation
self._owner = self.relation.manager.owner
def remove(self, item: "Model") -> None:
super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
item._orm._get(rel_name).remove(self._owner)
def append(self, item: "Model") -> None:
super().append(item)
def add(self, item: "Model") -> None:
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
class Relation:
def __init__(self, manager: "RelationsManager", type_: RelationType) -> None:
self.manager = manager
self._owner = manager.owner
self._type = type_
self.related_models = (
RelationProxy(relation=self) if type_ == RelationType.REVERSE else None
)
def _find_existing(self, child: "Model") -> Optional[int]:
for ind, relation_child in enumerate(self.related_models[:]):
try:
if relation_child.__same__(child):
return ind
except ReferenceError: # pragma no cover
self.related_models.pop(ind)
return None
def add(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
self.related_models = child
self._owner.__dict__[relation_name] = child
else:
if self._find_existing(child) is None:
self.related_models.append(child)
rel = self._owner.__dict__.get(relation_name, [])
rel.append(child)
self._owner.__dict__[relation_name] = rel
def remove(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
if self.related_models.__same__(child):
self.related_models = None
del self._owner.__dict__[relation_name]
else:
position = self._find_existing(child)
if position is not None:
self.related_models.pop(position)
del self._owner.__dict__[relation_name][position]
def get(self) -> Union[List["Model"], "Model"]:
return self.related_models
def __repr__(self) -> str: # pragma no cover
return str(self.related_models)
class RelationsManager:
def __init__(
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
) -> None:
self.owner = owner
self._related_fields = related_fields or []
self._related_names = [field.name for field in self._related_fields]
self._relations = dict()
for field in self._related_fields:
self._add_relation(field)
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
self._relations[field.name] = Relation(
manager=self,
type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE,
)
def __contains__(self, item: str) -> bool:
return item in self._related_names
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
relation = self._relations.get(name, None)
if relation:
return relation.get()
def _get(self, name: str) -> Optional[Relation]:
relation = self._relations.get(name, None)
if relation:
return relation
@staticmethod
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
to_field = next(
(
field
for field in child._orm._related_fields
if field.to == parent.__class__ or field.to.Meta == parent.Meta
),
None,
)
if not to_field: # pragma no cover
raise RelationshipInstanceError(
f"Model {child.__class__} does not have "
f"reference to model {parent.__class__}"
)
to_name = to_field.name
if virtual:
child_name, to_name = to_name, child_name or child.get_name()
child, parent = parent, proxy(child)
else:
child_name = child_name or child.get_name() + "s"
child = proxy(child)
parent_relation = parent._orm._get(child_name)
if not parent_relation:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
parent._orm._add_relation(field)
parent_relation = parent._orm._get(child_name)
parent_relation.add(child)
child._orm._get(to_name).add(parent)
def remove(self, name: str, child: "Model") -> None:
relation = self._get(name)
relation.remove(child)
@staticmethod
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
related_model = name
name = item.resolve_relation_name(item, related_model)
if name in item._orm:
relation_name = item.resolve_relation_name(related_model, item)
item._orm.remove(name, related_model)
related_model._orm.remove(relation_name, item)

View File

@ -0,0 +1,3 @@
from ormar.relations.alias_manager import AliasManager
__all__ = ["AliasManager"]

View File

@ -0,0 +1,36 @@
import string
import uuid
from random import choices
from typing import List
import sqlalchemy
from sqlalchemy import text
def get_table_alias() -> str:
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
class AliasManager:
def __init__(self) -> None:
self._aliases = dict()
@staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
]
@staticmethod
def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
def add_relation_type(self, to_table_name: str, table_name: str,) -> None:
if f"{table_name}_{to_table_name}" not in self._aliases:
self._aliases[f"{table_name}_{to_table_name}"] = get_table_alias()
if f"{to_table_name}_{table_name}" not in self._aliases:
self._aliases[f"{to_table_name}_{table_name}"] = get_table_alias()
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
return self._aliases.get(f"{from_table}_{to_table}", "")

332
ormar/relations/relation.py Normal file
View File

@ -0,0 +1,332 @@
from enum import Enum
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union
from weakref import proxy
import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
from ormar.fields.many_to_many import ManyToManyField
from ormar.queryset import QuerySet
if TYPE_CHECKING: # pragma no cover
from ormar import Model
class RelationType(Enum):
PRIMARY = 1
REVERSE = 2
MULTIPLE = 3
class QuerysetProxy:
if TYPE_CHECKING: # pragma no cover
relation: "Relation"
def __init__(self, relation: "Relation") -> None:
self.relation = relation
self.queryset = None
def _assign_child_to_parent(self, child: "Model") -> None:
owner = self.relation._owner
rel_name = owner.resolve_relation_name(owner, child)
setattr(owner, rel_name, child)
def _register_related(self, child: Union["Model", List["Model"]]) -> None:
if isinstance(child, list):
for subchild in child:
self._assign_child_to_parent(subchild)
else:
self._assign_child_to_parent(child)
async def create_through_instance(self, child: "Model") -> None:
queryset = QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child}
await queryset.create(**kwargs)
async def delete_through_instance(self, child: "Model") -> None:
queryset = QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child}
link_instance = await queryset.filter(**kwargs).get()
await link_instance.delete()
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
return self.queryset.filter(**kwargs)
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
return self.queryset.select_related(related)
async def exists(self) -> bool:
return await self.queryset.exists()
async def count(self) -> int:
return await self.queryset.count()
async def clear(self) -> int:
queryset = QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
kwargs = {owner_column: self.relation._owner}
return await queryset.delete(**kwargs)
def limit(self, limit_count: int) -> "QuerySet":
return self.queryset.limit(limit_count)
def offset(self, offset: int) -> "QuerySet":
return self.queryset.offset(offset)
async def first(self, **kwargs: Any) -> "Model":
first = await self.queryset.first(**kwargs)
self._register_related(first)
return first
async def get(self, **kwargs: Any) -> "Model":
get = await self.queryset.get(**kwargs)
self._register_related(get)
return get
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
all_items = await self.queryset.all(**kwargs)
self._register_related(all_items)
return all_items
async def create(self, **kwargs: Any) -> "Model":
create = await self.queryset.create(**kwargs)
self._register_related(create)
await self.create_through_instance(create)
return create
class RelationProxy(list):
def __init__(self, relation: "Relation") -> None:
super(RelationProxy, self).__init__()
self.relation = relation
self._owner = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation)
def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]:
if not self.queryset_proxy.queryset:
self.queryset_proxy.queryset = self._set_queryset()
return getattr(self.queryset_proxy, item)
return super().__getattribute__(item)
def __getattr__(self, item: str) -> Any:
if not self.queryset_proxy.queryset:
self.queryset_proxy.queryset = self._set_queryset()
return getattr(self.queryset_proxy, item)
def _set_queryset(self) -> QuerySet:
owner_table = self.relation._owner.Meta.tablename
pkname = self.relation._owner.Meta.pkname
pk_value = self.relation._owner.pk
if not pk_value:
raise RelationshipInstanceError(
"You cannot query many to many relationship on unsaved model."
)
kwargs = {f"{owner_table}__{pkname}": pk_value}
queryset = (
QuerySet(model_cls=self.relation.to)
.select_related(owner_table)
.filter(**kwargs)
)
return queryset
async def remove(self, item: "Model") -> None:
super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
item._orm._get(rel_name).remove(self._owner)
if self.relation._type == RelationType.MULTIPLE:
await self.queryset_proxy.delete_through_instance(item)
def append(self, item: "Model") -> None:
super().append(item)
async def add(self, item: "Model") -> None:
if self.relation._type == RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
class Relation:
def __init__(
self,
manager: "RelationsManager",
type_: RelationType,
to: Type["Model"],
through: Type["Model"] = None,
) -> None:
self.manager = manager
self._owner = manager.owner
self._type = type_
self.to = to
self.through = through
self.related_models = (
RelationProxy(relation=self)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None
)
def _find_existing(self, child: "Model") -> Optional[int]:
for ind, relation_child in enumerate(self.related_models[:]):
try:
if relation_child.__same__(child):
return ind
except ReferenceError: # pragma no cover
self.related_models.pop(ind)
return None
def add(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
self.related_models = child
self._owner.__dict__[relation_name] = child
else:
if self._find_existing(child) is None:
self.related_models.append(child)
rel = self._owner.__dict__.get(relation_name, [])
rel = rel or []
if not isinstance(rel, list):
rel = [rel]
rel.append(child)
self._owner.__dict__[relation_name] = rel
def remove(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
if self.related_models.__same__(child):
self.related_models = None
del self._owner.__dict__[relation_name]
else:
position = self._find_existing(child)
if position is not None:
self.related_models.pop(position)
del self._owner.__dict__[relation_name][position]
def get(self) -> Union[List["Model"], "Model"]:
return self.related_models
def __repr__(self) -> str: # pragma no cover
return str(self.related_models)
class RelationsManager:
def __init__(
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
) -> None:
self.owner = proxy(owner)
self._related_fields = related_fields or []
self._related_names = [field.name for field in self._related_fields]
self._relations = dict()
for field in self._related_fields:
self._add_relation(field)
def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType:
if issubclass(field, ManyToManyField):
return RelationType.MULTIPLE
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
self._relations[field.name] = Relation(
manager=self,
type_=self._get_relation_type(field),
to=field.to,
through=getattr(field, "through", None),
)
def __contains__(self, item: str) -> bool:
return item in self._related_names
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
relation = self._relations.get(name, None)
if relation is not None:
return relation.get()
def _get(self, name: str) -> Optional[Relation]:
relation = self._relations.get(name, None)
if relation is not None:
return relation
@staticmethod
def register_missing_relation(
parent: "Model", child: "Model", child_name: str
) -> Relation:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
parent._orm._add_relation(field)
parent_relation = parent._orm._get(child_name)
return parent_relation
@staticmethod
def get_relations_sides_and_names(
to_field: Type[ForeignKeyField],
parent: "Model",
child: "Model",
child_name: str,
virtual: bool,
) -> Tuple["Model", "Model", str, str]:
to_name = to_field.name
if issubclass(to_field, ManyToManyField):
child_name, to_name = (
child.resolve_relation_name(parent, child),
child.resolve_relation_name(child, parent),
)
child = proxy(child)
elif virtual:
child_name, to_name = to_name, child_name or child.get_name()
child, parent = parent, proxy(child)
else:
child_name = child_name or child.get_name() + "s"
child = proxy(child)
return parent, child, child_name, to_name
@staticmethod
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
to_field = next(
(
field
for field in child._orm._related_fields
if field.to == parent.__class__ or field.to.Meta == parent.Meta
),
None,
)
if not to_field: # pragma no cover
raise RelationshipInstanceError(
f"Model {child.__class__} does not have "
f"reference to model {parent.__class__}"
)
(
parent,
child,
child_name,
to_name,
) = RelationsManager.get_relations_sides_and_names(
to_field, parent, child, child_name, virtual
)
parent_relation = parent._orm._get(child_name)
if not parent_relation:
parent_relation = RelationsManager.register_missing_relation(
parent, child, child_name
)
parent_relation.add(child)
child._orm._get(to_name).add(parent)
def remove(self, name: str, child: "Model") -> None:
relation = self._get(name)
relation.remove(child)
@staticmethod
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
related_model = name
name = item.resolve_relation_name(item, related_model)
if name in item._orm:
relation_name = item.resolve_relation_name(related_model, item)
item._orm.remove(name, related_model)
related_model._orm.remove(relation_name, item)

View File

@ -1,5 +1,3 @@
import gc
import databases
import pytest
import sqlalchemy
@ -179,7 +177,7 @@ async def test_model_removal_from_relations():
await track3.save()
assert len(album.tracks) == 3
album.tracks.remove(track1)
await album.tracks.remove(track1)
assert len(album.tracks) == 2
assert track1.album is None
@ -187,7 +185,7 @@ async def test_model_removal_from_relations():
track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None
album.tracks.add(track1)
await album.tracks.add(track1)
assert len(album.tracks) == 3
assert track1.album == album

178
tests/test_many_to_many.py Normal file
View File

@ -0,0 +1,178 @@
import databases
import pytest
import sqlalchemy
import ormar
from ormar.exceptions import RelationshipInstanceError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Author(ormar.Model):
class Meta:
tablename = "authors"
database = database
metadata = metadata
id: ormar.Integer(primary_key=True)
first_name: ormar.String(max_length=80)
last_name: ormar.String(max_length=80)
class Category(ormar.Model):
class Meta:
tablename = "categories"
database = database
metadata = metadata
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=40)
class PostCategory(ormar.Model):
class Meta:
tablename = "posts_categories"
database = database
metadata = metadata
class Post(ormar.Model):
class Meta:
tablename = "posts"
database = database
metadata = metadata
id: ormar.Integer(primary_key=True)
title: ormar.String(max_length=200)
categories: ormar.ManyToMany(Category, through=PostCategory)
author: ormar.ForeignKey(Author)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture(scope="function")
async def cleanup():
yield
await PostCategory.objects.delete()
await Post.objects.delete()
await Category.objects.delete()
await Author.objects.delete()
@pytest.mark.asyncio
async def test_assigning_related_objects(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
# Add a category to a post.
await post.categories.add(news)
# or from the other end:
await news.posts.add(post)
# Creating related object from instance:
await post.categories.create(name="Tips")
assert len(post.categories) == 2
post_categories = await post.categories.all()
assert len(post_categories) == 2
@pytest.mark.asyncio
async def test_quering_of_the_m2m_models(cleanup):
# orm can do this already.
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
# tl;dr: `post.categories` exposes the QuerySet API.
await post.categories.add(news)
post_categories = await post.categories.all()
assert len(post_categories) == 1
assert news == await post.categories.get(name="News")
num_posts = await news.posts.count()
assert num_posts == 1
posts_about_m2m = await news.posts.filter(title__contains="M2M").all()
assert len(posts_about_m2m) == 1
assert posts_about_m2m[0] == post
posts_about_python = await Post.objects.filter(categories__name="python").all()
assert len(posts_about_python) == 0
# Traversal of relationships: which categories has Guido contributed to?
category = await Category.objects.filter(posts__author=guido).get()
assert category == news
# or:
category2 = await Category.objects.filter(posts__author__first_name="Guido").get()
assert category2 == news
@pytest.mark.asyncio
async def test_removal_of_the_relations(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
await post.categories.add(news)
assert len(await post.categories.all()) == 1
await post.categories.remove(news)
assert len(await post.categories.all()) == 0
# or:
await news.posts.add(post)
assert len(await news.posts.all()) == 1
await news.posts.remove(post)
assert len(await news.posts.all()) == 0
# Remove all related objects:
await post.categories.add(news)
await post.categories.clear()
assert len(await post.categories.all()) == 0
# post would also lose 'news' category when running:
await post.categories.add(news)
await news.delete()
assert len(await post.categories.all()) == 0
@pytest.mark.asyncio
async def test_selecting_related(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
recent = await Category.objects.create(name="Recent")
await post.categories.add(news)
await post.categories.add(recent)
assert len(await post.categories.all()) == 2
# Loads categories and posts (2 queries) and perform the join in Python.
categories = await Category.objects.select_related("posts").all()
# No extra queries needed => no more `await`s required.
for category in categories:
assert category.posts[0] == post
news_posts = await news.posts.select_related("author").all()
assert news_posts[0].author == guido
assert (await post.categories.limit(1).all())[0] == news
assert (await post.categories.offset(1).limit(1).all())[0] == recent
assert await post.categories.first() == news
assert await post.categories.exists()
@pytest.mark.asyncio
async def test_selecting_related_fail_without_saving(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = Post(title="Hello, M2M", author=guido)
with pytest.raises(RelationshipInstanceError):
await post.categories.all()