working basic many to many relationships
This commit is contained in:
@ -9,6 +9,7 @@ from ormar.fields import (
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
ManyToMany,
|
||||
String,
|
||||
Text,
|
||||
Time,
|
||||
@ -28,6 +29,7 @@ __all__ = [
|
||||
"Date",
|
||||
"Decimal",
|
||||
"Float",
|
||||
"ManyToMany",
|
||||
"Model",
|
||||
"ModelDefinitionError",
|
||||
"ModelNotSet",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from ormar.fields.base import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKey
|
||||
from ormar.fields.many_to_many import ManyToMany
|
||||
from ormar.fields.model_fields import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
@ -27,5 +28,6 @@ __all__ = [
|
||||
"Float",
|
||||
"Time",
|
||||
"ForeignKey",
|
||||
"ManyToMany",
|
||||
"BaseField",
|
||||
]
|
||||
|
||||
@ -22,6 +22,7 @@ class BaseField:
|
||||
index: bool
|
||||
unique: bool
|
||||
pydantic_only: bool
|
||||
virtual: bool = False
|
||||
|
||||
default: Any
|
||||
server_default: Any
|
||||
|
||||
@ -22,7 +22,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
|
||||
return fk(**init_dict)
|
||||
|
||||
|
||||
def ForeignKey(
|
||||
def ForeignKey( # noqa CFQ002
|
||||
to: Type["Model"],
|
||||
*,
|
||||
name: str = None,
|
||||
@ -30,6 +30,8 @@ def ForeignKey(
|
||||
nullable: bool = True,
|
||||
related_name: str = None,
|
||||
virtual: bool = False,
|
||||
onupdate: str = None,
|
||||
ondelete: str = None,
|
||||
) -> Type[object]:
|
||||
fk_string = to.Meta.tablename + "." + to.Meta.pkname
|
||||
to_field = to.__fields__[to.Meta.pkname]
|
||||
@ -37,7 +39,11 @@ def ForeignKey(
|
||||
to=to,
|
||||
name=name,
|
||||
nullable=nullable,
|
||||
constraints=[sqlalchemy.schema.ForeignKey(fk_string)],
|
||||
constraints=[
|
||||
sqlalchemy.schema.ForeignKey(
|
||||
fk_string, ondelete=ondelete, onupdate=onupdate
|
||||
)
|
||||
],
|
||||
unique=unique,
|
||||
column_type=to_field.type_.column_type,
|
||||
related_name=related_name,
|
||||
@ -117,7 +123,7 @@ class ForeignKeyField(BaseField):
|
||||
cls, value: Any, child: "Model", to_register: bool = True
|
||||
) -> Optional[Union["Model", List["Model"]]]:
|
||||
if value is None:
|
||||
return None
|
||||
return None if not cls.virtual else []
|
||||
|
||||
constructors = {
|
||||
f"{cls.to.__name__}": cls._register_existing_model,
|
||||
|
||||
40
ormar/fields/many_to_many.py
Normal file
40
ormar/fields/many_to_many.py
Normal file
@ -0,0 +1,40 @@
|
||||
from typing import TYPE_CHECKING, Type
|
||||
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.models import Model
|
||||
|
||||
|
||||
def ManyToMany(
|
||||
to: Type["Model"],
|
||||
through: Type["Model"],
|
||||
*,
|
||||
name: str = None,
|
||||
unique: bool = False,
|
||||
related_name: str = None,
|
||||
virtual: bool = False,
|
||||
) -> Type[object]:
|
||||
to_field = to.__fields__[to.Meta.pkname]
|
||||
namespace = dict(
|
||||
to=to,
|
||||
through=through,
|
||||
name=name,
|
||||
nullable=True,
|
||||
unique=unique,
|
||||
column_type=to_field.type_.column_type,
|
||||
related_name=related_name,
|
||||
virtual=virtual,
|
||||
primary_key=False,
|
||||
index=False,
|
||||
pydantic_only=False,
|
||||
default=None,
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
return type("ManyToMany", (ManyToManyField, BaseField), namespace)
|
||||
|
||||
|
||||
class ManyToManyField(ForeignKeyField):
|
||||
through: Type["Model"]
|
||||
@ -1,16 +1,18 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
|
||||
|
||||
import databases
|
||||
import pydantic
|
||||
import sqlalchemy
|
||||
from pydantic import BaseConfig
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.fields import FieldInfo, ModelField
|
||||
|
||||
from ormar import ForeignKey, ModelDefinitionError # noqa I100
|
||||
from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.fields.many_to_many import ManyToMany, ManyToManyField
|
||||
from ormar.queryset import QuerySet
|
||||
from ormar.relations import AliasManager
|
||||
from ormar.relations.alias_manager import AliasManager
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
@ -30,7 +32,14 @@ class ModelMeta:
|
||||
|
||||
|
||||
def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
|
||||
alias_manager.add_relation_type(field, table_name)
|
||||
alias_manager.add_relation_type(field.to.Meta.tablename, table_name)
|
||||
|
||||
|
||||
def register_many_to_many_relation_on_build(table_name: str, field: ManyToMany) -> None:
|
||||
alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
|
||||
alias_manager.add_relation_type(
|
||||
field.through.Meta.tablename, field.to.Meta.tablename
|
||||
)
|
||||
|
||||
|
||||
def reverse_field_not_already_registered(
|
||||
@ -51,17 +60,74 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
|
||||
if reverse_field_not_already_registered(
|
||||
child, child_model_name, parent_model
|
||||
):
|
||||
register_reverse_model_fields(parent_model, child, child_model_name)
|
||||
register_reverse_model_fields(
|
||||
parent_model, child, child_model_name, model_field
|
||||
)
|
||||
|
||||
|
||||
def register_reverse_model_fields(
|
||||
model: Type["Model"], child: Type["Model"], child_model_name: str
|
||||
model: Type["Model"],
|
||||
child: Type["Model"],
|
||||
child_model_name: str,
|
||||
model_field: Type["ForeignKeyField"],
|
||||
) -> None:
|
||||
if issubclass(model_field, ManyToManyField):
|
||||
model.Meta.model_fields[child_model_name] = ManyToMany(
|
||||
child, through=model_field.through, name=child_model_name, virtual=True
|
||||
)
|
||||
# register foreign keys on through model
|
||||
adjust_through_many_to_many_model(model, child, model_field)
|
||||
else:
|
||||
model.Meta.model_fields[child_model_name] = ForeignKey(
|
||||
child, name=child_model_name, virtual=True
|
||||
)
|
||||
|
||||
|
||||
def adjust_through_many_to_many_model(
|
||||
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField]
|
||||
) -> None:
|
||||
model_field.through.Meta.model_fields[model.get_name()] = ForeignKey(
|
||||
model, name=model.get_name(), ondelete="CASCADE"
|
||||
)
|
||||
model_field.through.Meta.model_fields[child.get_name()] = ForeignKey(
|
||||
child, name=child.get_name(), ondelete="CASCADE"
|
||||
)
|
||||
|
||||
create_and_append_m2m_fk(model, model_field)
|
||||
create_and_append_m2m_fk(child, model_field)
|
||||
|
||||
create_pydantic_field(model.get_name(), model, model_field)
|
||||
create_pydantic_field(child.get_name(), child, model_field)
|
||||
|
||||
|
||||
def create_pydantic_field(
|
||||
field_name: str, model: Type["Model"], model_field: Type[ManyToManyField]
|
||||
) -> None:
|
||||
model_field.through.__fields__[field_name] = ModelField(
|
||||
name=field_name,
|
||||
type_=Optional[model],
|
||||
model_config=model.__config__,
|
||||
required=False,
|
||||
class_validators=model.__validators__,
|
||||
)
|
||||
|
||||
|
||||
def create_and_append_m2m_fk(
|
||||
model: Type["Model"], model_field: Type[ManyToManyField]
|
||||
) -> None:
|
||||
column = sqlalchemy.Column(
|
||||
model.get_name(),
|
||||
model.Meta.table.columns.get(model.Meta.pkname).type,
|
||||
sqlalchemy.schema.ForeignKey(
|
||||
model.Meta.tablename + "." + model.Meta.pkname,
|
||||
ondelete="CASCADE",
|
||||
onupdate="CASCADE",
|
||||
),
|
||||
)
|
||||
model_field.through.Meta.columns.append(column)
|
||||
model_field.through.Meta.table.append_column(column)
|
||||
|
||||
|
||||
def check_pk_column_validity(
|
||||
field_name: str, field: BaseField, pkname: str
|
||||
) -> Optional[str]:
|
||||
@ -77,17 +143,34 @@ def sqlalchemy_columns_from_model_fields(
|
||||
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
|
||||
columns = []
|
||||
pkname = None
|
||||
if len(model_fields.keys()) == 0:
|
||||
model_fields["id"] = Integer(name="id", primary_key=True)
|
||||
logging.warning(
|
||||
"Table {table_name} had no fields so auto "
|
||||
"Integer primary key named `id` created."
|
||||
)
|
||||
for field_name, field in model_fields.items():
|
||||
if field.primary_key:
|
||||
pkname = check_pk_column_validity(field_name, field, pkname)
|
||||
if not field.pydantic_only:
|
||||
if (
|
||||
not field.pydantic_only
|
||||
and not field.virtual
|
||||
and not issubclass(field, ManyToManyField)
|
||||
):
|
||||
columns.append(field.get_column(field_name))
|
||||
if issubclass(field, ForeignKeyField):
|
||||
register_relation_on_build(table_name, field)
|
||||
|
||||
register_relation_in_alias_manager(table_name, field)
|
||||
return pkname, columns
|
||||
|
||||
|
||||
def register_relation_in_alias_manager(
|
||||
table_name: str, field: Type[ForeignKeyField]
|
||||
) -> None:
|
||||
if issubclass(field, ManyToManyField):
|
||||
register_many_to_many_relation_on_build(table_name, field)
|
||||
elif issubclass(field, ForeignKeyField):
|
||||
register_relation_on_build(table_name, field)
|
||||
|
||||
|
||||
def populate_default_pydantic_field_value(
|
||||
type_: Type[BaseField], field: str, attrs: dict
|
||||
) -> dict:
|
||||
@ -109,15 +192,11 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict:
|
||||
return attrs
|
||||
|
||||
|
||||
def extract_annotations_and_module(
|
||||
attrs: dict, new_model: "ModelMetaclass", bases: Tuple
|
||||
) -> dict:
|
||||
annotations = attrs.get("__annotations__") or new_model.__annotations__
|
||||
attrs["__annotations__"] = annotations
|
||||
def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict:
|
||||
attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get(
|
||||
"__annotations__", {}
|
||||
)
|
||||
attrs = populate_pydantic_default_values(attrs)
|
||||
|
||||
attrs["__module__"] = attrs["__module__"] or bases[0].__module__
|
||||
attrs["__annotations__"] = attrs["__annotations__"] or bases[0].__annotations__
|
||||
return attrs
|
||||
|
||||
|
||||
@ -175,20 +254,26 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
|
||||
|
||||
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
||||
|
||||
attrs["Config"] = get_pydantic_base_orm_config()
|
||||
attrs = extract_annotations_and_default_vals(attrs, bases)
|
||||
new_model = super().__new__( # type: ignore
|
||||
mcs, name, bases, attrs
|
||||
)
|
||||
# breakpoint()
|
||||
|
||||
if hasattr(new_model, "Meta"):
|
||||
|
||||
attrs = extract_annotations_and_module(attrs, new_model, bases)
|
||||
# attrs = extract_annotations_and_default_vals(attrs, bases)
|
||||
new_model = populate_meta_orm_model_fields(attrs, new_model)
|
||||
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
||||
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
|
||||
expand_reverse_relationships(new_model)
|
||||
|
||||
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
||||
field_name = new_model.Meta.pkname
|
||||
field = Integer(name=field_name, primary_key=True)
|
||||
attrs["__annotations__"][field_name] = field
|
||||
populate_default_pydantic_field_value(field, field_name, attrs)
|
||||
|
||||
new_model = super().__new__( # type: ignore
|
||||
mcs, name, bases, attrs
|
||||
)
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, List, Tuple, Union
|
||||
import sqlalchemy
|
||||
|
||||
import ormar.queryset # noqa I100
|
||||
from ormar.fields.many_to_many import ManyToManyField
|
||||
from ormar.models import NewBaseModel # noqa I100
|
||||
|
||||
|
||||
@ -40,10 +41,19 @@ class Model(NewBaseModel):
|
||||
if select_related:
|
||||
related_models = group_related_list(select_related)
|
||||
|
||||
# breakpoint()
|
||||
if (
|
||||
previous_table
|
||||
and previous_table in cls.Meta.model_fields
|
||||
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
|
||||
):
|
||||
previous_table = cls.Meta.model_fields[
|
||||
previous_table
|
||||
].through.Meta.tablename
|
||||
|
||||
table_prefix = cls.Meta.alias_manager.resolve_relation_join(
|
||||
previous_table, cls.Meta.table.name
|
||||
)
|
||||
|
||||
previous_table = cls.Meta.table.name
|
||||
|
||||
item = cls.populate_nested_models_from_row(
|
||||
|
||||
@ -23,7 +23,8 @@ from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.models.metaclass import ModelMeta, ModelMetaclass
|
||||
from ormar.models.modelproxy import ModelTableProxy
|
||||
from ormar.relations import AliasManager, RelationsManager
|
||||
from ormar.relations.alias_manager import AliasManager
|
||||
from ormar.relations.relation import RelationsManager
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.models.model import Model
|
||||
@ -96,13 +97,16 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
kwargs.get(related), self, to_register=True
|
||||
)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
|
||||
if name in self.__slots__:
|
||||
object.__setattr__(self, name, value)
|
||||
elif name == "pk":
|
||||
object.__setattr__(self, self.Meta.pkname, value)
|
||||
elif name in self._orm:
|
||||
model = self.Meta.model_fields[name].expand_relationship(value, self)
|
||||
if isinstance(self.__dict__.get(name), list):
|
||||
self.__dict__[name].append(model)
|
||||
else:
|
||||
self.__dict__[name] = model
|
||||
else:
|
||||
value = (
|
||||
@ -131,15 +135,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
if item in self._orm:
|
||||
return self._orm.get(item)
|
||||
|
||||
def __eq__(self, other: "Model") -> bool:
|
||||
if isinstance(other, NewBaseModel):
|
||||
return self.__same__(other)
|
||||
return super().__eq__(other) # pragma no cover
|
||||
|
||||
def __same__(self, other: "Model") -> bool:
|
||||
return (
|
||||
self._orm_id == other._orm_id
|
||||
or self.__dict__ == other.__dict__
|
||||
or self.dict() == other.dict()
|
||||
or (self.pk == other.pk and self.pk is not None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls, title: bool = False, lower: bool = True) -> str:
|
||||
def get_name(cls, lower: bool = True) -> str:
|
||||
name = cls.__name__
|
||||
if lower:
|
||||
name = name.lower()
|
||||
|
||||
@ -5,6 +5,7 @@ from sqlalchemy import text
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar.exceptions import QueryDefinitionError
|
||||
from ormar.fields.many_to_many import ManyToManyField
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
@ -128,6 +129,10 @@ class QueryClause:
|
||||
# against which the comparison is being made.
|
||||
previous_table = model_cls.Meta.tablename
|
||||
for part in related_parts:
|
||||
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
|
||||
previous_table = model_cls.Meta.model_fields[
|
||||
part
|
||||
].through.Meta.tablename
|
||||
current_table = model_cls.Meta.model_fields[part].to.Meta.tablename
|
||||
manager = model_cls.Meta.alias_manager
|
||||
table_prefix = manager.resolve_relation_join(previous_table, current_table)
|
||||
|
||||
@ -4,8 +4,10 @@ import sqlalchemy
|
||||
from sqlalchemy import text
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.relations import AliasManager
|
||||
from ormar.fields.many_to_many import ManyToManyField
|
||||
from ormar.relations.alias_manager import AliasManager
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
@ -63,6 +65,15 @@ class Query:
|
||||
)
|
||||
|
||||
for part in item.split("__"):
|
||||
if issubclass(
|
||||
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
||||
):
|
||||
_fields = join_parameters.model_cls.Meta.model_fields
|
||||
new_part = _fields[part].to.get_name()
|
||||
join_parameters = self._build_join_parameters(
|
||||
part, join_parameters, is_multi=True
|
||||
)
|
||||
part = new_part
|
||||
join_parameters = self._build_join_parameters(part, join_parameters)
|
||||
|
||||
expr = sqlalchemy.sql.select(self.columns)
|
||||
@ -83,9 +94,17 @@ class Query:
|
||||
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
||||
return text(f"{left_part}={right_part}")
|
||||
|
||||
def _is_target_relation_key(
|
||||
self, field: BaseField, target_model: Type["Model"]
|
||||
) -> bool:
|
||||
return issubclass(field, ForeignKeyField) and field.to.Meta == target_model.Meta
|
||||
|
||||
def _build_join_parameters(
|
||||
self, part: str, join_params: JoinParameters
|
||||
self, part: str, join_params: JoinParameters, is_multi: bool = False
|
||||
) -> JoinParameters:
|
||||
if is_multi:
|
||||
model_cls = join_params.model_cls.Meta.model_fields[part].through
|
||||
else:
|
||||
model_cls = join_params.model_cls.Meta.model_fields[part].to
|
||||
to_table = model_cls.Meta.table.name
|
||||
|
||||
@ -93,13 +112,12 @@ class Query:
|
||||
join_params.from_table, to_table
|
||||
)
|
||||
if alias not in self.used_aliases:
|
||||
if join_params.prev_model.Meta.model_fields[part].virtual:
|
||||
if join_params.prev_model.Meta.model_fields[part].virtual or is_multi:
|
||||
to_key = next(
|
||||
(
|
||||
v
|
||||
for k, v in model_cls.Meta.model_fields.items()
|
||||
if issubclass(v, ForeignKeyField)
|
||||
and v.to == join_params.prev_model
|
||||
if self._is_target_relation_key(v, join_params.prev_model)
|
||||
),
|
||||
None,
|
||||
).name
|
||||
@ -129,16 +147,19 @@ class Query:
|
||||
prev_model = model_cls
|
||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||
|
||||
def _apply_expression_modifiers(
|
||||
self, expr: sqlalchemy.sql.select
|
||||
) -> sqlalchemy.sql.select:
|
||||
def filter(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: # noqa A003
|
||||
if self.filter_clauses:
|
||||
if len(self.filter_clauses) == 1:
|
||||
clause = self.filter_clauses[0]
|
||||
else:
|
||||
clause = sqlalchemy.sql.and_(*self.filter_clauses)
|
||||
expr = expr.where(clause)
|
||||
return expr
|
||||
|
||||
def _apply_expression_modifiers(
|
||||
self, expr: sqlalchemy.sql.select
|
||||
) -> sqlalchemy.sql.select:
|
||||
expr = self.filter(expr)
|
||||
if self.limit_count:
|
||||
expr = expr.limit(self.limit_count)
|
||||
|
||||
|
||||
@ -48,6 +48,7 @@ class QuerySet:
|
||||
limit_count=self.limit_count,
|
||||
)
|
||||
exp = qry.build_select_expression()
|
||||
# print(exp.compile(compile_kwargs={"literal_binds": True}))
|
||||
return exp
|
||||
|
||||
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||
@ -70,7 +71,7 @@ class QuerySet:
|
||||
if not isinstance(related, (list, tuple)):
|
||||
related = [related]
|
||||
|
||||
related = list(self._select_related) + related
|
||||
related = list(set(list(self._select_related) + related))
|
||||
return self.__class__(
|
||||
model_cls=self.model_cls,
|
||||
filter_clauses=self.filter_clauses,
|
||||
@ -82,13 +83,28 @@ class QuerySet:
|
||||
async def exists(self) -> bool:
|
||||
expr = self.build_select_expression()
|
||||
expr = sqlalchemy.exists(expr).select()
|
||||
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
||||
return await self.database.fetch_val(expr)
|
||||
|
||||
async def count(self) -> int:
|
||||
expr = self.build_select_expression().alias("subquery_for_count")
|
||||
expr = sqlalchemy.func.count().select().select_from(expr)
|
||||
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
||||
return await self.database.fetch_val(expr)
|
||||
|
||||
async def delete(self, **kwargs: Any) -> int:
|
||||
if kwargs:
|
||||
return await self.filter(**kwargs).delete()
|
||||
qry = Query(
|
||||
model_cls=self.model_cls,
|
||||
select_related=self._select_related,
|
||||
filter_clauses=self.filter_clauses,
|
||||
offset=self.query_offset,
|
||||
limit_count=self.limit_count,
|
||||
)
|
||||
expr = qry.filter(self.table.delete())
|
||||
return await self.database.execute(expr)
|
||||
|
||||
def limit(self, limit_count: int) -> "QuerySet":
|
||||
return self.__class__(
|
||||
model_cls=self.model_cls,
|
||||
@ -143,6 +159,7 @@ class QuerySet:
|
||||
return await self.filter(**kwargs).all()
|
||||
|
||||
expr = self.build_select_expression()
|
||||
# breakpoint()
|
||||
rows = await self.database.fetch_all(expr)
|
||||
result_rows = [
|
||||
self.model_cls.from_row(row, select_related=self._select_related)
|
||||
|
||||
@ -1,198 +0,0 @@
|
||||
import string
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from random import choices
|
||||
from typing import List, Optional, TYPE_CHECKING, Type, Union
|
||||
from weakref import proxy
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import text
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar.exceptions import RelationshipInstanceError # noqa I100
|
||||
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
|
||||
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.models import Model
|
||||
|
||||
|
||||
def get_table_alias() -> str:
|
||||
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
|
||||
|
||||
|
||||
class RelationType(Enum):
|
||||
PRIMARY = 1
|
||||
REVERSE = 2
|
||||
|
||||
|
||||
class AliasManager:
|
||||
def __init__(self) -> None:
|
||||
self._aliases = dict()
|
||||
|
||||
@staticmethod
|
||||
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||
return [
|
||||
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
|
||||
for column in table.columns
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def prefixed_table_name(alias: str, name: str) -> text:
|
||||
return text(f"{name} {alias}_{name}")
|
||||
|
||||
def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None:
|
||||
if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases:
|
||||
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
|
||||
if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases:
|
||||
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
|
||||
|
||||
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
|
||||
return self._aliases.get(f"{from_table}_{to_table}", "")
|
||||
|
||||
|
||||
class RelationProxy(list):
|
||||
def __init__(self, relation: "Relation") -> None:
|
||||
super(RelationProxy, self).__init__()
|
||||
self.relation = relation
|
||||
self._owner = self.relation.manager.owner
|
||||
|
||||
def remove(self, item: "Model") -> None:
|
||||
super().remove(item)
|
||||
rel_name = item.resolve_relation_name(item, self._owner)
|
||||
item._orm._get(rel_name).remove(self._owner)
|
||||
|
||||
def append(self, item: "Model") -> None:
|
||||
super().append(item)
|
||||
|
||||
def add(self, item: "Model") -> None:
|
||||
rel_name = item.resolve_relation_name(item, self._owner)
|
||||
setattr(item, rel_name, self._owner)
|
||||
|
||||
|
||||
class Relation:
|
||||
def __init__(self, manager: "RelationsManager", type_: RelationType) -> None:
|
||||
self.manager = manager
|
||||
self._owner = manager.owner
|
||||
self._type = type_
|
||||
self.related_models = (
|
||||
RelationProxy(relation=self) if type_ == RelationType.REVERSE else None
|
||||
)
|
||||
|
||||
def _find_existing(self, child: "Model") -> Optional[int]:
|
||||
for ind, relation_child in enumerate(self.related_models[:]):
|
||||
try:
|
||||
if relation_child.__same__(child):
|
||||
return ind
|
||||
except ReferenceError: # pragma no cover
|
||||
self.related_models.pop(ind)
|
||||
return None
|
||||
|
||||
def add(self, child: "Model") -> None:
|
||||
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
||||
if self._type == RelationType.PRIMARY:
|
||||
self.related_models = child
|
||||
self._owner.__dict__[relation_name] = child
|
||||
else:
|
||||
if self._find_existing(child) is None:
|
||||
self.related_models.append(child)
|
||||
rel = self._owner.__dict__.get(relation_name, [])
|
||||
rel.append(child)
|
||||
self._owner.__dict__[relation_name] = rel
|
||||
|
||||
def remove(self, child: "Model") -> None:
|
||||
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
||||
if self._type == RelationType.PRIMARY:
|
||||
if self.related_models.__same__(child):
|
||||
self.related_models = None
|
||||
del self._owner.__dict__[relation_name]
|
||||
else:
|
||||
position = self._find_existing(child)
|
||||
if position is not None:
|
||||
self.related_models.pop(position)
|
||||
del self._owner.__dict__[relation_name][position]
|
||||
|
||||
def get(self) -> Union[List["Model"], "Model"]:
|
||||
return self.related_models
|
||||
|
||||
def __repr__(self) -> str: # pragma no cover
|
||||
return str(self.related_models)
|
||||
|
||||
|
||||
class RelationsManager:
|
||||
def __init__(
|
||||
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
|
||||
) -> None:
|
||||
self.owner = owner
|
||||
self._related_fields = related_fields or []
|
||||
self._related_names = [field.name for field in self._related_fields]
|
||||
self._relations = dict()
|
||||
for field in self._related_fields:
|
||||
self._add_relation(field)
|
||||
|
||||
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
|
||||
self._relations[field.name] = Relation(
|
||||
manager=self,
|
||||
type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE,
|
||||
)
|
||||
|
||||
def __contains__(self, item: str) -> bool:
|
||||
return item in self._related_names
|
||||
|
||||
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
|
||||
relation = self._relations.get(name, None)
|
||||
if relation:
|
||||
return relation.get()
|
||||
|
||||
def _get(self, name: str) -> Optional[Relation]:
|
||||
relation = self._relations.get(name, None)
|
||||
if relation:
|
||||
return relation
|
||||
|
||||
@staticmethod
|
||||
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
|
||||
to_field = next(
|
||||
(
|
||||
field
|
||||
for field in child._orm._related_fields
|
||||
if field.to == parent.__class__ or field.to.Meta == parent.Meta
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not to_field: # pragma no cover
|
||||
raise RelationshipInstanceError(
|
||||
f"Model {child.__class__} does not have "
|
||||
f"reference to model {parent.__class__}"
|
||||
)
|
||||
|
||||
to_name = to_field.name
|
||||
if virtual:
|
||||
child_name, to_name = to_name, child_name or child.get_name()
|
||||
child, parent = parent, proxy(child)
|
||||
else:
|
||||
child_name = child_name or child.get_name() + "s"
|
||||
child = proxy(child)
|
||||
|
||||
parent_relation = parent._orm._get(child_name)
|
||||
if not parent_relation:
|
||||
ormar.models.expand_reverse_relationships(child.__class__)
|
||||
name = parent.resolve_relation_name(parent, child)
|
||||
field = parent.Meta.model_fields[name]
|
||||
parent._orm._add_relation(field)
|
||||
parent_relation = parent._orm._get(child_name)
|
||||
parent_relation.add(child)
|
||||
child._orm._get(to_name).add(parent)
|
||||
|
||||
def remove(self, name: str, child: "Model") -> None:
|
||||
relation = self._get(name)
|
||||
relation.remove(child)
|
||||
|
||||
@staticmethod
|
||||
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
|
||||
related_model = name
|
||||
name = item.resolve_relation_name(item, related_model)
|
||||
if name in item._orm:
|
||||
relation_name = item.resolve_relation_name(related_model, item)
|
||||
item._orm.remove(name, related_model)
|
||||
related_model._orm.remove(relation_name, item)
|
||||
3
ormar/relations/__init__.py
Normal file
3
ormar/relations/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from ormar.relations.alias_manager import AliasManager
|
||||
|
||||
__all__ = ["AliasManager"]
|
||||
36
ormar/relations/alias_manager.py
Normal file
36
ormar/relations/alias_manager.py
Normal file
@ -0,0 +1,36 @@
|
||||
import string
|
||||
import uuid
|
||||
from random import choices
|
||||
from typing import List
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
def get_table_alias() -> str:
|
||||
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
|
||||
|
||||
|
||||
class AliasManager:
|
||||
def __init__(self) -> None:
|
||||
self._aliases = dict()
|
||||
|
||||
@staticmethod
|
||||
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||
return [
|
||||
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
|
||||
for column in table.columns
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def prefixed_table_name(alias: str, name: str) -> text:
|
||||
return text(f"{name} {alias}_{name}")
|
||||
|
||||
def add_relation_type(self, to_table_name: str, table_name: str,) -> None:
|
||||
if f"{table_name}_{to_table_name}" not in self._aliases:
|
||||
self._aliases[f"{table_name}_{to_table_name}"] = get_table_alias()
|
||||
if f"{to_table_name}_{table_name}" not in self._aliases:
|
||||
self._aliases[f"{to_table_name}_{table_name}"] = get_table_alias()
|
||||
|
||||
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
|
||||
return self._aliases.get(f"{from_table}_{to_table}", "")
|
||||
332
ormar/relations/relation.py
Normal file
332
ormar/relations/relation.py
Normal file
@ -0,0 +1,332 @@
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union
|
||||
from weakref import proxy
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar.exceptions import RelationshipInstanceError # noqa I100
|
||||
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
|
||||
from ormar.fields.many_to_many import ManyToManyField
|
||||
from ormar.queryset import QuerySet
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
|
||||
|
||||
class RelationType(Enum):
|
||||
PRIMARY = 1
|
||||
REVERSE = 2
|
||||
MULTIPLE = 3
|
||||
|
||||
|
||||
class QuerysetProxy:
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
relation: "Relation"
|
||||
|
||||
def __init__(self, relation: "Relation") -> None:
|
||||
self.relation = relation
|
||||
self.queryset = None
|
||||
|
||||
def _assign_child_to_parent(self, child: "Model") -> None:
|
||||
owner = self.relation._owner
|
||||
rel_name = owner.resolve_relation_name(owner, child)
|
||||
setattr(owner, rel_name, child)
|
||||
|
||||
def _register_related(self, child: Union["Model", List["Model"]]) -> None:
|
||||
if isinstance(child, list):
|
||||
for subchild in child:
|
||||
self._assign_child_to_parent(subchild)
|
||||
else:
|
||||
self._assign_child_to_parent(child)
|
||||
|
||||
async def create_through_instance(self, child: "Model") -> None:
|
||||
queryset = QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self.relation._owner.get_name()
|
||||
child_column = child.get_name()
|
||||
kwargs = {owner_column: self.relation._owner, child_column: child}
|
||||
await queryset.create(**kwargs)
|
||||
|
||||
async def delete_through_instance(self, child: "Model") -> None:
|
||||
queryset = QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self.relation._owner.get_name()
|
||||
child_column = child.get_name()
|
||||
kwargs = {owner_column: self.relation._owner, child_column: child}
|
||||
link_instance = await queryset.filter(**kwargs).get()
|
||||
await link_instance.delete()
|
||||
|
||||
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||
return self.queryset.filter(**kwargs)
|
||||
|
||||
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
|
||||
return self.queryset.select_related(related)
|
||||
|
||||
async def exists(self) -> bool:
|
||||
return await self.queryset.exists()
|
||||
|
||||
async def count(self) -> int:
|
||||
return await self.queryset.count()
|
||||
|
||||
async def clear(self) -> int:
|
||||
queryset = QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self.relation._owner.get_name()
|
||||
kwargs = {owner_column: self.relation._owner}
|
||||
return await queryset.delete(**kwargs)
|
||||
|
||||
def limit(self, limit_count: int) -> "QuerySet":
|
||||
return self.queryset.limit(limit_count)
|
||||
|
||||
def offset(self, offset: int) -> "QuerySet":
|
||||
return self.queryset.offset(offset)
|
||||
|
||||
async def first(self, **kwargs: Any) -> "Model":
|
||||
first = await self.queryset.first(**kwargs)
|
||||
self._register_related(first)
|
||||
return first
|
||||
|
||||
async def get(self, **kwargs: Any) -> "Model":
|
||||
get = await self.queryset.get(**kwargs)
|
||||
self._register_related(get)
|
||||
return get
|
||||
|
||||
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
|
||||
all_items = await self.queryset.all(**kwargs)
|
||||
self._register_related(all_items)
|
||||
return all_items
|
||||
|
||||
async def create(self, **kwargs: Any) -> "Model":
|
||||
create = await self.queryset.create(**kwargs)
|
||||
self._register_related(create)
|
||||
await self.create_through_instance(create)
|
||||
return create
|
||||
|
||||
|
||||
class RelationProxy(list):
|
||||
def __init__(self, relation: "Relation") -> None:
|
||||
super(RelationProxy, self).__init__()
|
||||
self.relation = relation
|
||||
self._owner = self.relation.manager.owner
|
||||
self.queryset_proxy = QuerysetProxy(relation=self.relation)
|
||||
|
||||
def __getattribute__(self, item: str) -> Any:
|
||||
if item in ["count", "clear"]:
|
||||
if not self.queryset_proxy.queryset:
|
||||
self.queryset_proxy.queryset = self._set_queryset()
|
||||
return getattr(self.queryset_proxy, item)
|
||||
return super().__getattribute__(item)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
if not self.queryset_proxy.queryset:
|
||||
self.queryset_proxy.queryset = self._set_queryset()
|
||||
return getattr(self.queryset_proxy, item)
|
||||
|
||||
def _set_queryset(self) -> QuerySet:
|
||||
owner_table = self.relation._owner.Meta.tablename
|
||||
pkname = self.relation._owner.Meta.pkname
|
||||
pk_value = self.relation._owner.pk
|
||||
if not pk_value:
|
||||
raise RelationshipInstanceError(
|
||||
"You cannot query many to many relationship on unsaved model."
|
||||
)
|
||||
kwargs = {f"{owner_table}__{pkname}": pk_value}
|
||||
queryset = (
|
||||
QuerySet(model_cls=self.relation.to)
|
||||
.select_related(owner_table)
|
||||
.filter(**kwargs)
|
||||
)
|
||||
return queryset
|
||||
|
||||
async def remove(self, item: "Model") -> None:
|
||||
super().remove(item)
|
||||
rel_name = item.resolve_relation_name(item, self._owner)
|
||||
item._orm._get(rel_name).remove(self._owner)
|
||||
if self.relation._type == RelationType.MULTIPLE:
|
||||
await self.queryset_proxy.delete_through_instance(item)
|
||||
|
||||
def append(self, item: "Model") -> None:
|
||||
super().append(item)
|
||||
|
||||
async def add(self, item: "Model") -> None:
|
||||
if self.relation._type == RelationType.MULTIPLE:
|
||||
await self.queryset_proxy.create_through_instance(item)
|
||||
rel_name = item.resolve_relation_name(item, self._owner)
|
||||
setattr(item, rel_name, self._owner)
|
||||
|
||||
|
||||
class Relation:
|
||||
def __init__(
|
||||
self,
|
||||
manager: "RelationsManager",
|
||||
type_: RelationType,
|
||||
to: Type["Model"],
|
||||
through: Type["Model"] = None,
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self._owner = manager.owner
|
||||
self._type = type_
|
||||
self.to = to
|
||||
self.through = through
|
||||
self.related_models = (
|
||||
RelationProxy(relation=self)
|
||||
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
|
||||
else None
|
||||
)
|
||||
|
||||
def _find_existing(self, child: "Model") -> Optional[int]:
|
||||
for ind, relation_child in enumerate(self.related_models[:]):
|
||||
try:
|
||||
if relation_child.__same__(child):
|
||||
return ind
|
||||
except ReferenceError: # pragma no cover
|
||||
self.related_models.pop(ind)
|
||||
return None
|
||||
|
||||
def add(self, child: "Model") -> None:
|
||||
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
||||
if self._type == RelationType.PRIMARY:
|
||||
self.related_models = child
|
||||
self._owner.__dict__[relation_name] = child
|
||||
else:
|
||||
if self._find_existing(child) is None:
|
||||
self.related_models.append(child)
|
||||
rel = self._owner.__dict__.get(relation_name, [])
|
||||
rel = rel or []
|
||||
if not isinstance(rel, list):
|
||||
rel = [rel]
|
||||
rel.append(child)
|
||||
self._owner.__dict__[relation_name] = rel
|
||||
|
||||
def remove(self, child: "Model") -> None:
|
||||
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
||||
if self._type == RelationType.PRIMARY:
|
||||
if self.related_models.__same__(child):
|
||||
self.related_models = None
|
||||
del self._owner.__dict__[relation_name]
|
||||
else:
|
||||
position = self._find_existing(child)
|
||||
if position is not None:
|
||||
self.related_models.pop(position)
|
||||
del self._owner.__dict__[relation_name][position]
|
||||
|
||||
def get(self) -> Union[List["Model"], "Model"]:
|
||||
return self.related_models
|
||||
|
||||
def __repr__(self) -> str: # pragma no cover
|
||||
return str(self.related_models)
|
||||
|
||||
|
||||
class RelationsManager:
|
||||
def __init__(
|
||||
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
|
||||
) -> None:
|
||||
self.owner = proxy(owner)
|
||||
self._related_fields = related_fields or []
|
||||
self._related_names = [field.name for field in self._related_fields]
|
||||
self._relations = dict()
|
||||
for field in self._related_fields:
|
||||
self._add_relation(field)
|
||||
|
||||
def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType:
|
||||
if issubclass(field, ManyToManyField):
|
||||
return RelationType.MULTIPLE
|
||||
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
|
||||
|
||||
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
|
||||
self._relations[field.name] = Relation(
|
||||
manager=self,
|
||||
type_=self._get_relation_type(field),
|
||||
to=field.to,
|
||||
through=getattr(field, "through", None),
|
||||
)
|
||||
|
||||
def __contains__(self, item: str) -> bool:
|
||||
return item in self._related_names
|
||||
|
||||
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
|
||||
relation = self._relations.get(name, None)
|
||||
if relation is not None:
|
||||
return relation.get()
|
||||
|
||||
def _get(self, name: str) -> Optional[Relation]:
|
||||
relation = self._relations.get(name, None)
|
||||
if relation is not None:
|
||||
return relation
|
||||
|
||||
@staticmethod
|
||||
def register_missing_relation(
|
||||
parent: "Model", child: "Model", child_name: str
|
||||
) -> Relation:
|
||||
ormar.models.expand_reverse_relationships(child.__class__)
|
||||
name = parent.resolve_relation_name(parent, child)
|
||||
field = parent.Meta.model_fields[name]
|
||||
parent._orm._add_relation(field)
|
||||
parent_relation = parent._orm._get(child_name)
|
||||
return parent_relation
|
||||
|
||||
@staticmethod
|
||||
def get_relations_sides_and_names(
|
||||
to_field: Type[ForeignKeyField],
|
||||
parent: "Model",
|
||||
child: "Model",
|
||||
child_name: str,
|
||||
virtual: bool,
|
||||
) -> Tuple["Model", "Model", str, str]:
|
||||
to_name = to_field.name
|
||||
if issubclass(to_field, ManyToManyField):
|
||||
child_name, to_name = (
|
||||
child.resolve_relation_name(parent, child),
|
||||
child.resolve_relation_name(child, parent),
|
||||
)
|
||||
child = proxy(child)
|
||||
elif virtual:
|
||||
child_name, to_name = to_name, child_name or child.get_name()
|
||||
child, parent = parent, proxy(child)
|
||||
else:
|
||||
child_name = child_name or child.get_name() + "s"
|
||||
child = proxy(child)
|
||||
return parent, child, child_name, to_name
|
||||
|
||||
@staticmethod
|
||||
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
|
||||
to_field = next(
|
||||
(
|
||||
field
|
||||
for field in child._orm._related_fields
|
||||
if field.to == parent.__class__ or field.to.Meta == parent.Meta
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not to_field: # pragma no cover
|
||||
raise RelationshipInstanceError(
|
||||
f"Model {child.__class__} does not have "
|
||||
f"reference to model {parent.__class__}"
|
||||
)
|
||||
|
||||
(
|
||||
parent,
|
||||
child,
|
||||
child_name,
|
||||
to_name,
|
||||
) = RelationsManager.get_relations_sides_and_names(
|
||||
to_field, parent, child, child_name, virtual
|
||||
)
|
||||
|
||||
parent_relation = parent._orm._get(child_name)
|
||||
if not parent_relation:
|
||||
parent_relation = RelationsManager.register_missing_relation(
|
||||
parent, child, child_name
|
||||
)
|
||||
parent_relation.add(child)
|
||||
child._orm._get(to_name).add(parent)
|
||||
|
||||
def remove(self, name: str, child: "Model") -> None:
|
||||
relation = self._get(name)
|
||||
relation.remove(child)
|
||||
|
||||
@staticmethod
|
||||
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
|
||||
related_model = name
|
||||
name = item.resolve_relation_name(item, related_model)
|
||||
if name in item._orm:
|
||||
relation_name = item.resolve_relation_name(related_model, item)
|
||||
item._orm.remove(name, related_model)
|
||||
related_model._orm.remove(relation_name, item)
|
||||
@ -1,5 +1,3 @@
|
||||
import gc
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
@ -179,7 +177,7 @@ async def test_model_removal_from_relations():
|
||||
await track3.save()
|
||||
|
||||
assert len(album.tracks) == 3
|
||||
album.tracks.remove(track1)
|
||||
await album.tracks.remove(track1)
|
||||
assert len(album.tracks) == 2
|
||||
assert track1.album is None
|
||||
|
||||
@ -187,7 +185,7 @@ async def test_model_removal_from_relations():
|
||||
track1 = await Track.objects.get(title="The Birdman")
|
||||
assert track1.album is None
|
||||
|
||||
album.tracks.add(track1)
|
||||
await album.tracks.add(track1)
|
||||
assert len(album.tracks) == 3
|
||||
assert track1.album == album
|
||||
|
||||
|
||||
178
tests/test_many_to_many.py
Normal file
178
tests/test_many_to_many.py
Normal file
@ -0,0 +1,178 @@
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import RelationshipInstanceError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class Author(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "authors"
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
id: ormar.Integer(primary_key=True)
|
||||
first_name: ormar.String(max_length=80)
|
||||
last_name: ormar.String(max_length=80)
|
||||
|
||||
|
||||
class Category(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "categories"
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
id: ormar.Integer(primary_key=True)
|
||||
name: ormar.String(max_length=40)
|
||||
|
||||
|
||||
class PostCategory(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "posts_categories"
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
|
||||
class Post(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "posts"
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
id: ormar.Integer(primary_key=True)
|
||||
title: ormar.String(max_length=200)
|
||||
categories: ormar.ManyToMany(Category, through=PostCategory)
|
||||
author: ormar.ForeignKey(Author)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def cleanup():
|
||||
yield
|
||||
await PostCategory.objects.delete()
|
||||
await Post.objects.delete()
|
||||
await Category.objects.delete()
|
||||
await Author.objects.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigning_related_objects(cleanup):
|
||||
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||
news = await Category.objects.create(name="News")
|
||||
|
||||
# Add a category to a post.
|
||||
await post.categories.add(news)
|
||||
# or from the other end:
|
||||
await news.posts.add(post)
|
||||
|
||||
# Creating related object from instance:
|
||||
await post.categories.create(name="Tips")
|
||||
assert len(post.categories) == 2
|
||||
|
||||
post_categories = await post.categories.all()
|
||||
assert len(post_categories) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quering_of_the_m2m_models(cleanup):
|
||||
# orm can do this already.
|
||||
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||
news = await Category.objects.create(name="News")
|
||||
# tl;dr: `post.categories` exposes the QuerySet API.
|
||||
|
||||
await post.categories.add(news)
|
||||
|
||||
post_categories = await post.categories.all()
|
||||
assert len(post_categories) == 1
|
||||
|
||||
assert news == await post.categories.get(name="News")
|
||||
|
||||
num_posts = await news.posts.count()
|
||||
assert num_posts == 1
|
||||
|
||||
posts_about_m2m = await news.posts.filter(title__contains="M2M").all()
|
||||
assert len(posts_about_m2m) == 1
|
||||
assert posts_about_m2m[0] == post
|
||||
posts_about_python = await Post.objects.filter(categories__name="python").all()
|
||||
assert len(posts_about_python) == 0
|
||||
|
||||
# Traversal of relationships: which categories has Guido contributed to?
|
||||
category = await Category.objects.filter(posts__author=guido).get()
|
||||
assert category == news
|
||||
# or:
|
||||
category2 = await Category.objects.filter(posts__author__first_name="Guido").get()
|
||||
assert category2 == news
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removal_of_the_relations(cleanup):
|
||||
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||
news = await Category.objects.create(name="News")
|
||||
await post.categories.add(news)
|
||||
assert len(await post.categories.all()) == 1
|
||||
await post.categories.remove(news)
|
||||
assert len(await post.categories.all()) == 0
|
||||
# or:
|
||||
await news.posts.add(post)
|
||||
assert len(await news.posts.all()) == 1
|
||||
await news.posts.remove(post)
|
||||
assert len(await news.posts.all()) == 0
|
||||
|
||||
# Remove all related objects:
|
||||
await post.categories.add(news)
|
||||
await post.categories.clear()
|
||||
assert len(await post.categories.all()) == 0
|
||||
|
||||
# post would also lose 'news' category when running:
|
||||
await post.categories.add(news)
|
||||
await news.delete()
|
||||
assert len(await post.categories.all()) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selecting_related(cleanup):
|
||||
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||
news = await Category.objects.create(name="News")
|
||||
recent = await Category.objects.create(name="Recent")
|
||||
await post.categories.add(news)
|
||||
await post.categories.add(recent)
|
||||
assert len(await post.categories.all()) == 2
|
||||
# Loads categories and posts (2 queries) and perform the join in Python.
|
||||
categories = await Category.objects.select_related("posts").all()
|
||||
# No extra queries needed => no more `await`s required.
|
||||
for category in categories:
|
||||
assert category.posts[0] == post
|
||||
|
||||
news_posts = await news.posts.select_related("author").all()
|
||||
assert news_posts[0].author == guido
|
||||
|
||||
assert (await post.categories.limit(1).all())[0] == news
|
||||
assert (await post.categories.offset(1).limit(1).all())[0] == recent
|
||||
|
||||
assert await post.categories.first() == news
|
||||
|
||||
assert await post.categories.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selecting_related_fail_without_saving(cleanup):
|
||||
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||
post = Post(title="Hello, M2M", author=guido)
|
||||
with pytest.raises(RelationshipInstanceError):
|
||||
await post.categories.all()
|
||||
Reference in New Issue
Block a user