Merge pull request #4 from collerek/many_to_many

Many to many
This commit is contained in:
collerek
2020-09-14 22:37:22 +07:00
committed by GitHub
19 changed files with 880 additions and 254 deletions

BIN
.coverage

Binary file not shown.

View File

@ -175,6 +175,86 @@ tracks = await Track.objects.limit(1).all()
assert len(tracks) == 1 assert len(tracks) == 1
``` ```
Since version >=0.3 Ormar supports also many to many relationships
```python
import databases
import ormar
import sqlalchemy
database = databases.Database("sqlite:///db.sqlite")
metadata = sqlalchemy.MetaData()
class Author(ormar.Model):
class Meta:
tablename = "authors"
database = database
metadata = metadata
id: ormar.Integer(primary_key=True)
first_name: ormar.String(max_length=80)
last_name: ormar.String(max_length=80)
class Category(ormar.Model):
class Meta:
tablename = "categories"
database = database
metadata = metadata
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=40)
class PostCategory(ormar.Model):
class Meta:
tablename = "posts_categories"
database = database
metadata = metadata
class Post(ormar.Model):
class Meta:
tablename = "posts"
database = database
metadata = metadata
id: ormar.Integer(primary_key=True)
title: ormar.String(max_length=200)
categories: ormar.ManyToMany(Category, through=PostCategory)
author: ormar.ForeignKey(Author)
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
# Add a category to a post.
await post.categories.add(news)
# or from the other end:
await news.posts.add(post)
# Creating related object from instance:
await post.categories.create(name="Tips")
assert len(await post.categories.all()) == 2
# Many to many relation exposes a list of related models
# and an API of the Queryset:
assert news == await post.categories.get(name="News")
# with all Queryset methods - filtering, selecting related, counting etc.
await news.posts.filter(title__contains="M2M").all()
await Category.objects.filter(posts__author=guido).get()
# related models of many to many relation can be prefetched
news_posts = await news.posts.select_related("author").all()
assert news_posts[0].author == guido
# Removal of the relationship by one
await news.posts.remove(post)
# or all at once
await news.posts.clear()
```
## Data types ## Data types
The following keyword arguments are supported on all field types. The following keyword arguments are supported on all field types.

View File

@ -9,13 +9,14 @@ from ormar.fields import (
ForeignKey, ForeignKey,
Integer, Integer,
JSON, JSON,
ManyToMany,
String, String,
Text, Text,
Time, Time,
) )
from ormar.models import Model from ormar.models import Model
__version__ = "0.2.2" __version__ = "0.3.0"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",
@ -28,6 +29,7 @@ __all__ = [
"Date", "Date",
"Decimal", "Decimal",
"Float", "Float",
"ManyToMany",
"Model", "Model",
"ModelDefinitionError", "ModelDefinitionError",
"ModelNotSet", "ModelNotSet",

View File

@ -1,5 +1,6 @@
from ormar.fields.base import BaseField from ormar.fields.base import BaseField
from ormar.fields.foreign_key import ForeignKey from ormar.fields.foreign_key import ForeignKey
from ormar.fields.many_to_many import ManyToMany
from ormar.fields.model_fields import ( from ormar.fields.model_fields import (
BigInteger, BigInteger,
Boolean, Boolean,
@ -27,5 +28,6 @@ __all__ = [
"Float", "Float",
"Time", "Time",
"ForeignKey", "ForeignKey",
"ManyToMany",
"BaseField", "BaseField",
] ]

View File

@ -22,6 +22,7 @@ class BaseField:
index: bool index: bool
unique: bool unique: bool
pydantic_only: bool pydantic_only: bool
virtual: bool = False
default: Any default: Any
server_default: Any server_default: Any
@ -34,8 +35,7 @@ class BaseField:
default = cls.default if cls.default is not None else cls.server_default default = cls.default if cls.default is not None else cls.server_default
if callable(default): if callable(default):
return Field(default_factory=default) return Field(default_factory=default)
else: return Field(default=default)
return Field(default=default)
return None return None
@classmethod @classmethod

View File

@ -22,7 +22,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
return fk(**init_dict) return fk(**init_dict)
def ForeignKey( def ForeignKey( # noqa CFQ002
to: Type["Model"], to: Type["Model"],
*, *,
name: str = None, name: str = None,
@ -30,6 +30,8 @@ def ForeignKey(
nullable: bool = True, nullable: bool = True,
related_name: str = None, related_name: str = None,
virtual: bool = False, virtual: bool = False,
onupdate: str = None,
ondelete: str = None,
) -> Type[object]: ) -> Type[object]:
fk_string = to.Meta.tablename + "." + to.Meta.pkname fk_string = to.Meta.tablename + "." + to.Meta.pkname
to_field = to.__fields__[to.Meta.pkname] to_field = to.__fields__[to.Meta.pkname]
@ -37,7 +39,11 @@ def ForeignKey(
to=to, to=to,
name=name, name=name,
nullable=nullable, nullable=nullable,
constraints=[sqlalchemy.schema.ForeignKey(fk_string)], constraints=[
sqlalchemy.schema.ForeignKey(
fk_string, ondelete=ondelete, onupdate=onupdate
)
],
unique=unique, unique=unique,
column_type=to_field.type_.column_type, column_type=to_field.type_.column_type,
related_name=related_name, related_name=related_name,
@ -117,7 +123,7 @@ class ForeignKeyField(BaseField):
cls, value: Any, child: "Model", to_register: bool = True cls, value: Any, child: "Model", to_register: bool = True
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
if value is None: if value is None:
return None return None if not cls.virtual else []
constructors = { constructors = {
f"{cls.to.__name__}": cls._register_existing_model, f"{cls.to.__name__}": cls._register_existing_model,

View File

@ -0,0 +1,40 @@
from typing import TYPE_CHECKING, Type
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover
from ormar.models import Model
def ManyToMany(
to: Type["Model"],
through: Type["Model"],
*,
name: str = None,
unique: bool = False,
related_name: str = None,
virtual: bool = False,
) -> Type[object]:
to_field = to.__fields__[to.Meta.pkname]
namespace = dict(
to=to,
through=through,
name=name,
nullable=True,
unique=unique,
column_type=to_field.type_.column_type,
related_name=related_name,
virtual=virtual,
primary_key=False,
index=False,
pydantic_only=False,
default=None,
server_default=None,
)
return type("ManyToMany", (ManyToManyField, BaseField), namespace)
class ManyToManyField(ForeignKeyField):
through: Type["Model"]

View File

@ -1,16 +1,18 @@
import logging
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
import databases import databases
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseConfig from pydantic import BaseConfig
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo, ModelField
from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToMany, ManyToManyField
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
from ormar.relations import AliasManager from ormar.relations.alias_manager import AliasManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -30,7 +32,14 @@ class ModelMeta:
def register_relation_on_build(table_name: str, field: ForeignKey) -> None: def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
alias_manager.add_relation_type(field, table_name) alias_manager.add_relation_type(field.to.Meta.tablename, table_name)
def register_many_to_many_relation_on_build(table_name: str, field: ManyToMany) -> None:
alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
alias_manager.add_relation_type(
field.through.Meta.tablename, field.to.Meta.tablename
)
def reverse_field_not_already_registered( def reverse_field_not_already_registered(
@ -51,15 +60,72 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
if reverse_field_not_already_registered( if reverse_field_not_already_registered(
child, child_model_name, parent_model child, child_model_name, parent_model
): ):
register_reverse_model_fields(parent_model, child, child_model_name) register_reverse_model_fields(
parent_model, child, child_model_name, model_field
)
def register_reverse_model_fields( def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str model: Type["Model"],
child: Type["Model"],
child_model_name: str,
model_field: Type["ForeignKeyField"],
) -> None: ) -> None:
model.Meta.model_fields[child_model_name] = ForeignKey( if issubclass(model_field, ManyToManyField):
child, name=child_model_name, virtual=True model.Meta.model_fields[child_model_name] = ManyToMany(
child, through=model_field.through, name=child_model_name, virtual=True
)
# register foreign keys on through model
adjust_through_many_to_many_model(model, child, model_field)
else:
model.Meta.model_fields[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True
)
def adjust_through_many_to_many_model(
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField]
) -> None:
model_field.through.Meta.model_fields[model.get_name()] = ForeignKey(
model, name=model.get_name(), ondelete="CASCADE"
) )
model_field.through.Meta.model_fields[child.get_name()] = ForeignKey(
child, name=child.get_name(), ondelete="CASCADE"
)
create_and_append_m2m_fk(model, model_field)
create_and_append_m2m_fk(child, model_field)
create_pydantic_field(model.get_name(), model, model_field)
create_pydantic_field(child.get_name(), child, model_field)
def create_pydantic_field(
field_name: str, model: Type["Model"], model_field: Type[ManyToManyField]
) -> None:
model_field.through.__fields__[field_name] = ModelField(
name=field_name,
type_=Optional[model],
model_config=model.__config__,
required=False,
class_validators=model.__validators__,
)
def create_and_append_m2m_fk(
model: Type["Model"], model_field: Type[ManyToManyField]
) -> None:
column = sqlalchemy.Column(
model.get_name(),
model.Meta.table.columns.get(model.Meta.pkname).type,
sqlalchemy.schema.ForeignKey(
model.Meta.tablename + "." + model.Meta.pkname,
ondelete="CASCADE",
onupdate="CASCADE",
),
)
model_field.through.Meta.columns.append(column)
model_field.through.Meta.table.append_column(column)
def check_pk_column_validity( def check_pk_column_validity(
@ -77,17 +143,34 @@ def sqlalchemy_columns_from_model_fields(
) -> Tuple[Optional[str], List[sqlalchemy.Column]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
columns = [] columns = []
pkname = None pkname = None
if len(model_fields.keys()) == 0:
model_fields["id"] = Integer(name="id", primary_key=True)
logging.warning(
"Table {table_name} had no fields so auto "
"Integer primary key named `id` created."
)
for field_name, field in model_fields.items(): for field_name, field in model_fields.items():
if field.primary_key: if field.primary_key:
pkname = check_pk_column_validity(field_name, field, pkname) pkname = check_pk_column_validity(field_name, field, pkname)
if not field.pydantic_only: if (
not field.pydantic_only
and not field.virtual
and not issubclass(field, ManyToManyField)
):
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
if issubclass(field, ForeignKeyField): register_relation_in_alias_manager(table_name, field)
register_relation_on_build(table_name, field)
return pkname, columns return pkname, columns
def register_relation_in_alias_manager(
table_name: str, field: Type[ForeignKeyField]
) -> None:
if issubclass(field, ManyToManyField):
register_many_to_many_relation_on_build(table_name, field)
elif issubclass(field, ForeignKeyField):
register_relation_on_build(table_name, field)
def populate_default_pydantic_field_value( def populate_default_pydantic_field_value(
type_: Type[BaseField], field: str, attrs: dict type_: Type[BaseField], field: str, attrs: dict
) -> dict: ) -> dict:
@ -109,15 +192,11 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict:
return attrs return attrs
def extract_annotations_and_module( def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict:
attrs: dict, new_model: "ModelMetaclass", bases: Tuple attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get(
) -> dict: "__annotations__", {}
annotations = attrs.get("__annotations__") or new_model.__annotations__ )
attrs["__annotations__"] = annotations
attrs = populate_pydantic_default_values(attrs) attrs = populate_pydantic_default_values(attrs)
attrs["__module__"] = attrs["__module__"] or bases[0].__module__
attrs["__annotations__"] = attrs["__annotations__"] or bases[0].__annotations__
return attrs return attrs
@ -175,20 +254,26 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class ModelMetaclass(pydantic.main.ModelMetaclass): class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
attrs["Config"] = get_pydantic_base_orm_config() attrs["Config"] = get_pydantic_base_orm_config()
attrs = extract_annotations_and_default_vals(attrs, bases)
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
# breakpoint()
if hasattr(new_model, "Meta"): if hasattr(new_model, "Meta"):
# attrs = extract_annotations_and_default_vals(attrs, bases)
attrs = extract_annotations_and_module(attrs, new_model, bases)
new_model = populate_meta_orm_model_fields(attrs, new_model) new_model = populate_meta_orm_model_fields(attrs, new_model)
new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model)
new_model = populate_meta_sqlalchemy_table_if_required(new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model)
expand_reverse_relationships(new_model) expand_reverse_relationships(new_model)
if new_model.Meta.pkname not in attrs["__annotations__"]:
field_name = new_model.Meta.pkname
field = Integer(name=field_name, primary_key=True)
attrs["__annotations__"][field_name] = field
populate_default_pydantic_field_value(field, field_name, attrs)
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )

View File

@ -4,6 +4,7 @@ from typing import Any, List, Tuple, Union
import sqlalchemy import sqlalchemy
import ormar.queryset # noqa I100 import ormar.queryset # noqa I100
from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
@ -40,10 +41,19 @@ class Model(NewBaseModel):
if select_related: if select_related:
related_models = group_related_list(select_related) related_models = group_related_list(select_related)
# breakpoint()
if (
previous_table
and previous_table in cls.Meta.model_fields
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
):
previous_table = cls.Meta.model_fields[
previous_table
].through.Meta.tablename
table_prefix = cls.Meta.alias_manager.resolve_relation_join( table_prefix = cls.Meta.alias_manager.resolve_relation_join(
previous_table, cls.Meta.table.name previous_table, cls.Meta.table.name
) )
previous_table = cls.Meta.table.name previous_table = cls.Meta.table.name
item = cls.populate_nested_models_from_row( item = cls.populate_nested_models_from_row(

View File

@ -23,7 +23,8 @@ from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy from ormar.models.modelproxy import ModelTableProxy
from ormar.relations import AliasManager, RelationsManager from ormar.relations.alias_manager import AliasManager
from ormar.relations.relation import RelationsManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models.model import Model from ormar.models.model import Model
@ -96,14 +97,17 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
kwargs.get(related), self, to_register=True kwargs.get(related), self, to_register=True
) )
def __setattr__(self, name: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
if name in self.__slots__: if name in self.__slots__:
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
elif name == "pk": elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
elif name in self._orm: elif name in self._orm:
model = self.Meta.model_fields[name].expand_relationship(value, self) model = self.Meta.model_fields[name].expand_relationship(value, self)
self.__dict__[name] = model if isinstance(self.__dict__.get(name), list):
self.__dict__[name].append(model)
else:
self.__dict__[name] = model
else: else:
value = ( value = (
self._convert_json(name, value, "dumps") self._convert_json(name, value, "dumps")
@ -115,11 +119,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"): if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
return object.__getattribute__(self, item) return object.__getattribute__(self, item)
elif item != "_extract_related_names" and item in self._extract_related_names(): if item != "_extract_related_names" and item in self._extract_related_names():
return self._extract_related_model_instead_of_field(item) return self._extract_related_model_instead_of_field(item)
elif item == "pk": if item == "pk":
return self.__dict__.get(self.Meta.pkname, None) return self.__dict__.get(self.Meta.pkname, None)
elif item != "__fields__" and item in self.__fields__: if item != "__fields__" and item in self.__fields__:
value = self.__dict__.get(item, None) value = self.__dict__.get(item, None)
value = self._convert_json(item, value, "loads") value = self._convert_json(item, value, "loads")
return value return value
@ -131,15 +135,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if item in self._orm: if item in self._orm:
return self._orm.get(item) return self._orm.get(item)
def __eq__(self, other: "Model") -> bool:
if isinstance(other, NewBaseModel):
return self.__same__(other)
return super().__eq__(other) # pragma no cover
def __same__(self, other: "Model") -> bool: def __same__(self, other: "Model") -> bool:
return ( return (
self._orm_id == other._orm_id self._orm_id == other._orm_id
or self.__dict__ == other.__dict__ or self.dict() == other.dict()
or (self.pk == other.pk and self.pk is not None) or (self.pk == other.pk and self.pk is not None)
) )
@classmethod @classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str: def get_name(cls, lower: bool = True) -> str:
name = cls.__name__ name = cls.__name__
if lower: if lower:
name = name.lower() name = name.lower()

View File

@ -5,6 +5,7 @@ from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar.exceptions import QueryDefinitionError from ormar.exceptions import QueryDefinitionError
from ormar.fields.many_to_many import ManyToManyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -128,6 +129,10 @@ class QueryClause:
# against which the comparison is being made. # against which the comparison is being made.
previous_table = model_cls.Meta.tablename previous_table = model_cls.Meta.tablename
for part in related_parts: for part in related_parts:
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
previous_table = model_cls.Meta.model_fields[
part
].through.Meta.tablename
current_table = model_cls.Meta.model_fields[part].to.Meta.tablename current_table = model_cls.Meta.model_fields[part].to.Meta.tablename
manager = model_cls.Meta.alias_manager manager = model_cls.Meta.alias_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table) table_prefix = manager.resolve_relation_join(previous_table, current_table)

View File

@ -4,8 +4,10 @@ import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.relations import AliasManager from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.alias_manager import AliasManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -63,6 +65,15 @@ class Query:
) )
for part in item.split("__"): for part in item.split("__"):
if issubclass(
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
):
_fields = join_parameters.model_cls.Meta.model_fields
new_part = _fields[part].to.get_name()
join_parameters = self._build_join_parameters(
part, join_parameters, is_multi=True
)
part = new_part
join_parameters = self._build_join_parameters(part, join_parameters) join_parameters = self._build_join_parameters(part, join_parameters)
expr = sqlalchemy.sql.select(self.columns) expr = sqlalchemy.sql.select(self.columns)
@ -83,23 +94,30 @@ class Query:
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}") return text(f"{left_part}={right_part}")
def _is_target_relation_key(
self, field: BaseField, target_model: Type["Model"]
) -> bool:
return issubclass(field, ForeignKeyField) and field.to.Meta == target_model.Meta
def _build_join_parameters( def _build_join_parameters(
self, part: str, join_params: JoinParameters self, part: str, join_params: JoinParameters, is_multi: bool = False
) -> JoinParameters: ) -> JoinParameters:
model_cls = join_params.model_cls.Meta.model_fields[part].to if is_multi:
model_cls = join_params.model_cls.Meta.model_fields[part].through
else:
model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.Meta.table.name to_table = model_cls.Meta.table.name
alias = model_cls.Meta.alias_manager.resolve_relation_join( alias = model_cls.Meta.alias_manager.resolve_relation_join(
join_params.from_table, to_table join_params.from_table, to_table
) )
if alias not in self.used_aliases: if alias not in self.used_aliases:
if join_params.prev_model.Meta.model_fields[part].virtual: if join_params.prev_model.Meta.model_fields[part].virtual or is_multi:
to_key = next( to_key = next(
( (
v v
for k, v in model_cls.Meta.model_fields.items() for k, v in model_cls.Meta.model_fields.items()
if issubclass(v, ForeignKeyField) if self._is_target_relation_key(v, join_params.prev_model)
and v.to == join_params.prev_model
), ),
None, None,
).name ).name
@ -129,16 +147,19 @@ class Query:
prev_model = model_cls prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls) return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _apply_expression_modifiers( def filter(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: # noqa A003
self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select:
if self.filter_clauses: if self.filter_clauses:
if len(self.filter_clauses) == 1: if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0] clause = self.filter_clauses[0]
else: else:
clause = sqlalchemy.sql.and_(*self.filter_clauses) clause = sqlalchemy.sql.and_(*self.filter_clauses)
expr = expr.where(clause) expr = expr.where(clause)
return expr
def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select:
expr = self.filter(expr)
if self.limit_count: if self.limit_count:
expr = expr.limit(self.limit_count) expr = expr.limit(self.limit_count)

View File

@ -48,6 +48,7 @@ class QuerySet:
limit_count=self.limit_count, limit_count=self.limit_count,
) )
exp = qry.build_select_expression() exp = qry.build_select_expression()
# print(exp.compile(compile_kwargs={"literal_binds": True}))
return exp return exp
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -70,7 +71,7 @@ class QuerySet:
if not isinstance(related, (list, tuple)): if not isinstance(related, (list, tuple)):
related = [related] related = [related]
related = list(self._select_related) + related related = list(set(list(self._select_related) + related))
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model_cls,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
@ -82,13 +83,28 @@ class QuerySet:
async def exists(self) -> bool: async def exists(self) -> bool:
expr = self.build_select_expression() expr = self.build_select_expression()
expr = sqlalchemy.exists(expr).select() expr = sqlalchemy.exists(expr).select()
# print(expr.compile(compile_kwargs={"literal_binds": True}))
return await self.database.fetch_val(expr) return await self.database.fetch_val(expr)
async def count(self) -> int: async def count(self) -> int:
expr = self.build_select_expression().alias("subquery_for_count") expr = self.build_select_expression().alias("subquery_for_count")
expr = sqlalchemy.func.count().select().select_from(expr) expr = sqlalchemy.func.count().select().select_from(expr)
# print(expr.compile(compile_kwargs={"literal_binds": True}))
return await self.database.fetch_val(expr) return await self.database.fetch_val(expr)
async def delete(self, **kwargs: Any) -> int:
if kwargs:
return await self.filter(**kwargs).delete()
qry = Query(
model_cls=self.model_cls,
select_related=self._select_related,
filter_clauses=self.filter_clauses,
offset=self.query_offset,
limit_count=self.limit_count,
)
expr = qry.filter(self.table.delete())
return await self.database.execute(expr)
def limit(self, limit_count: int) -> "QuerySet": def limit(self, limit_count: int) -> "QuerySet":
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model_cls,
@ -118,11 +134,11 @@ class QuerySet:
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "Model":
if kwargs: if kwargs:
return await self.filter(**kwargs).get() return await self.filter(**kwargs).get()
if not self.filter_clauses:
expr = self.build_select_expression().limit(2)
else: else:
if not self.filter_clauses: expr = self.build_select_expression()
expr = self.build_select_expression().limit(2)
else:
expr = self.build_select_expression()
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
@ -143,6 +159,7 @@ class QuerySet:
return await self.filter(**kwargs).all() return await self.filter(**kwargs).all()
expr = self.build_select_expression() expr = self.build_select_expression()
# breakpoint()
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
result_rows = [ result_rows = [
self.model_cls.from_row(row, select_related=self._select_related) self.model_cls.from_row(row, select_related=self._select_related)

View File

@ -1,198 +0,0 @@
import string
import uuid
from enum import Enum
from random import choices
from typing import List, Optional, TYPE_CHECKING, Type, Union
from weakref import proxy
import sqlalchemy
from sqlalchemy import text
import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
if TYPE_CHECKING: # pragma no cover
from ormar.models import Model
def get_table_alias() -> str:
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
class RelationType(Enum):
PRIMARY = 1
REVERSE = 2
class AliasManager:
def __init__(self) -> None:
self._aliases = dict()
@staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
]
@staticmethod
def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None:
if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases:
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases:
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
return self._aliases.get(f"{from_table}_{to_table}", "")
class RelationProxy(list):
def __init__(self, relation: "Relation") -> None:
super(RelationProxy, self).__init__()
self.relation = relation
self._owner = self.relation.manager.owner
def remove(self, item: "Model") -> None:
super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
item._orm._get(rel_name).remove(self._owner)
def append(self, item: "Model") -> None:
super().append(item)
def add(self, item: "Model") -> None:
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
class Relation:
def __init__(self, manager: "RelationsManager", type_: RelationType) -> None:
self.manager = manager
self._owner = manager.owner
self._type = type_
self.related_models = (
RelationProxy(relation=self) if type_ == RelationType.REVERSE else None
)
def _find_existing(self, child: "Model") -> Optional[int]:
for ind, relation_child in enumerate(self.related_models[:]):
try:
if relation_child.__same__(child):
return ind
except ReferenceError: # pragma no cover
self.related_models.pop(ind)
return None
def add(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
self.related_models = child
self._owner.__dict__[relation_name] = child
else:
if self._find_existing(child) is None:
self.related_models.append(child)
rel = self._owner.__dict__.get(relation_name, [])
rel.append(child)
self._owner.__dict__[relation_name] = rel
def remove(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
if self.related_models.__same__(child):
self.related_models = None
del self._owner.__dict__[relation_name]
else:
position = self._find_existing(child)
if position is not None:
self.related_models.pop(position)
del self._owner.__dict__[relation_name][position]
def get(self) -> Union[List["Model"], "Model"]:
return self.related_models
def __repr__(self) -> str: # pragma no cover
return str(self.related_models)
class RelationsManager:
def __init__(
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
) -> None:
self.owner = owner
self._related_fields = related_fields or []
self._related_names = [field.name for field in self._related_fields]
self._relations = dict()
for field in self._related_fields:
self._add_relation(field)
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
self._relations[field.name] = Relation(
manager=self,
type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE,
)
def __contains__(self, item: str) -> bool:
return item in self._related_names
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
relation = self._relations.get(name, None)
if relation:
return relation.get()
def _get(self, name: str) -> Optional[Relation]:
relation = self._relations.get(name, None)
if relation:
return relation
@staticmethod
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
to_field = next(
(
field
for field in child._orm._related_fields
if field.to == parent.__class__ or field.to.Meta == parent.Meta
),
None,
)
if not to_field: # pragma no cover
raise RelationshipInstanceError(
f"Model {child.__class__} does not have "
f"reference to model {parent.__class__}"
)
to_name = to_field.name
if virtual:
child_name, to_name = to_name, child_name or child.get_name()
child, parent = parent, proxy(child)
else:
child_name = child_name or child.get_name() + "s"
child = proxy(child)
parent_relation = parent._orm._get(child_name)
if not parent_relation:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
parent._orm._add_relation(field)
parent_relation = parent._orm._get(child_name)
parent_relation.add(child)
child._orm._get(to_name).add(parent)
def remove(self, name: str, child: "Model") -> None:
relation = self._get(name)
relation.remove(child)
@staticmethod
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
related_model = name
name = item.resolve_relation_name(item, related_model)
if name in item._orm:
relation_name = item.resolve_relation_name(related_model, item)
item._orm.remove(name, related_model)
related_model._orm.remove(relation_name, item)

View File

@ -0,0 +1,3 @@
from ormar.relations.alias_manager import AliasManager
__all__ = ["AliasManager"]

View File

@ -0,0 +1,36 @@
import string
import uuid
from random import choices
from typing import List
import sqlalchemy
from sqlalchemy import text
def get_table_alias() -> str:
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
class AliasManager:
def __init__(self) -> None:
self._aliases = dict()
@staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
]
@staticmethod
def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
def add_relation_type(self, to_table_name: str, table_name: str,) -> None:
if f"{table_name}_{to_table_name}" not in self._aliases:
self._aliases[f"{table_name}_{to_table_name}"] = get_table_alias()
if f"{to_table_name}_{table_name}" not in self._aliases:
self._aliases[f"{to_table_name}_{table_name}"] = get_table_alias()
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
return self._aliases.get(f"{from_table}_{to_table}", "")

332
ormar/relations/relation.py Normal file
View File

@ -0,0 +1,332 @@
from enum import Enum
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union
from weakref import proxy
import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
from ormar.fields.many_to_many import ManyToManyField
from ormar.queryset import QuerySet
if TYPE_CHECKING: # pragma no cover
from ormar import Model
class RelationType(Enum):
PRIMARY = 1
REVERSE = 2
MULTIPLE = 3
class QuerysetProxy:
if TYPE_CHECKING: # pragma no cover
relation: "Relation"
def __init__(self, relation: "Relation") -> None:
self.relation = relation
self.queryset = None
def _assign_child_to_parent(self, child: "Model") -> None:
owner = self.relation._owner
rel_name = owner.resolve_relation_name(owner, child)
setattr(owner, rel_name, child)
def _register_related(self, child: Union["Model", List["Model"]]) -> None:
if isinstance(child, list):
for subchild in child:
self._assign_child_to_parent(subchild)
else:
self._assign_child_to_parent(child)
async def create_through_instance(self, child: "Model") -> None:
queryset = QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child}
await queryset.create(**kwargs)
async def delete_through_instance(self, child: "Model") -> None:
queryset = QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child}
link_instance = await queryset.filter(**kwargs).get()
await link_instance.delete()
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
return self.queryset.filter(**kwargs)
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
return self.queryset.select_related(related)
async def exists(self) -> bool:
return await self.queryset.exists()
async def count(self) -> int:
return await self.queryset.count()
async def clear(self) -> int:
queryset = QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
kwargs = {owner_column: self.relation._owner}
return await queryset.delete(**kwargs)
def limit(self, limit_count: int) -> "QuerySet":
return self.queryset.limit(limit_count)
def offset(self, offset: int) -> "QuerySet":
return self.queryset.offset(offset)
async def first(self, **kwargs: Any) -> "Model":
first = await self.queryset.first(**kwargs)
self._register_related(first)
return first
async def get(self, **kwargs: Any) -> "Model":
get = await self.queryset.get(**kwargs)
self._register_related(get)
return get
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
all_items = await self.queryset.all(**kwargs)
self._register_related(all_items)
return all_items
async def create(self, **kwargs: Any) -> "Model":
create = await self.queryset.create(**kwargs)
self._register_related(create)
await self.create_through_instance(create)
return create
class RelationProxy(list):
def __init__(self, relation: "Relation") -> None:
super(RelationProxy, self).__init__()
self.relation = relation
self._owner = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation)
def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]:
if not self.queryset_proxy.queryset:
self.queryset_proxy.queryset = self._set_queryset()
return getattr(self.queryset_proxy, item)
return super().__getattribute__(item)
def __getattr__(self, item: str) -> Any:
if not self.queryset_proxy.queryset:
self.queryset_proxy.queryset = self._set_queryset()
return getattr(self.queryset_proxy, item)
def _set_queryset(self) -> QuerySet:
owner_table = self.relation._owner.Meta.tablename
pkname = self.relation._owner.Meta.pkname
pk_value = self.relation._owner.pk
if not pk_value:
raise RelationshipInstanceError(
"You cannot query many to many relationship on unsaved model."
)
kwargs = {f"{owner_table}__{pkname}": pk_value}
queryset = (
QuerySet(model_cls=self.relation.to)
.select_related(owner_table)
.filter(**kwargs)
)
return queryset
async def remove(self, item: "Model") -> None:
super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
item._orm._get(rel_name).remove(self._owner)
if self.relation._type == RelationType.MULTIPLE:
await self.queryset_proxy.delete_through_instance(item)
def append(self, item: "Model") -> None:
super().append(item)
async def add(self, item: "Model") -> None:
if self.relation._type == RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
class Relation:
def __init__(
self,
manager: "RelationsManager",
type_: RelationType,
to: Type["Model"],
through: Type["Model"] = None,
) -> None:
self.manager = manager
self._owner = manager.owner
self._type = type_
self.to = to
self.through = through
self.related_models = (
RelationProxy(relation=self)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None
)
def _find_existing(self, child: "Model") -> Optional[int]:
for ind, relation_child in enumerate(self.related_models[:]):
try:
if relation_child.__same__(child):
return ind
except ReferenceError: # pragma no cover
self.related_models.pop(ind)
return None
def add(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
self.related_models = child
self._owner.__dict__[relation_name] = child
else:
if self._find_existing(child) is None:
self.related_models.append(child)
rel = self._owner.__dict__.get(relation_name, [])
rel = rel or []
if not isinstance(rel, list):
rel = [rel]
rel.append(child)
self._owner.__dict__[relation_name] = rel
def remove(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
if self.related_models.__same__(child):
self.related_models = None
del self._owner.__dict__[relation_name]
else:
position = self._find_existing(child)
if position is not None:
self.related_models.pop(position)
del self._owner.__dict__[relation_name][position]
def get(self) -> Union[List["Model"], "Model"]:
return self.related_models
def __repr__(self) -> str: # pragma no cover
return str(self.related_models)
class RelationsManager:
def __init__(
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
) -> None:
self.owner = proxy(owner)
self._related_fields = related_fields or []
self._related_names = [field.name for field in self._related_fields]
self._relations = dict()
for field in self._related_fields:
self._add_relation(field)
def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType:
if issubclass(field, ManyToManyField):
return RelationType.MULTIPLE
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
self._relations[field.name] = Relation(
manager=self,
type_=self._get_relation_type(field),
to=field.to,
through=getattr(field, "through", None),
)
def __contains__(self, item: str) -> bool:
return item in self._related_names
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
relation = self._relations.get(name, None)
if relation is not None:
return relation.get()
def _get(self, name: str) -> Optional[Relation]:
relation = self._relations.get(name, None)
if relation is not None:
return relation
@staticmethod
def register_missing_relation(
parent: "Model", child: "Model", child_name: str
) -> Relation:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
parent._orm._add_relation(field)
parent_relation = parent._orm._get(child_name)
return parent_relation
@staticmethod
def get_relations_sides_and_names(
to_field: Type[ForeignKeyField],
parent: "Model",
child: "Model",
child_name: str,
virtual: bool,
) -> Tuple["Model", "Model", str, str]:
to_name = to_field.name
if issubclass(to_field, ManyToManyField):
child_name, to_name = (
child.resolve_relation_name(parent, child),
child.resolve_relation_name(child, parent),
)
child = proxy(child)
elif virtual:
child_name, to_name = to_name, child_name or child.get_name()
child, parent = parent, proxy(child)
else:
child_name = child_name or child.get_name() + "s"
child = proxy(child)
return parent, child, child_name, to_name
@staticmethod
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
to_field = next(
(
field
for field in child._orm._related_fields
if field.to == parent.__class__ or field.to.Meta == parent.Meta
),
None,
)
if not to_field: # pragma no cover
raise RelationshipInstanceError(
f"Model {child.__class__} does not have "
f"reference to model {parent.__class__}"
)
(
parent,
child,
child_name,
to_name,
) = RelationsManager.get_relations_sides_and_names(
to_field, parent, child, child_name, virtual
)
parent_relation = parent._orm._get(child_name)
if not parent_relation:
parent_relation = RelationsManager.register_missing_relation(
parent, child, child_name
)
parent_relation.add(child)
child._orm._get(to_name).add(parent)
def remove(self, name: str, child: "Model") -> None:
relation = self._get(name)
relation.remove(child)
@staticmethod
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
related_model = name
name = item.resolve_relation_name(item, related_model)
if name in item._orm:
relation_name = item.resolve_relation_name(related_model, item)
item._orm.remove(name, related_model)
related_model._orm.remove(relation_name, item)

View File

@ -1,5 +1,3 @@
import gc
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
@ -179,7 +177,7 @@ async def test_model_removal_from_relations():
await track3.save() await track3.save()
assert len(album.tracks) == 3 assert len(album.tracks) == 3
album.tracks.remove(track1) await album.tracks.remove(track1)
assert len(album.tracks) == 2 assert len(album.tracks) == 2
assert track1.album is None assert track1.album is None
@ -187,7 +185,7 @@ async def test_model_removal_from_relations():
track1 = await Track.objects.get(title="The Birdman") track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None assert track1.album is None
album.tracks.add(track1) await album.tracks.add(track1)
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert track1.album == album assert track1.album == album

178
tests/test_many_to_many.py Normal file
View File

@ -0,0 +1,178 @@
import databases
import pytest
import sqlalchemy
import ormar
from ormar.exceptions import RelationshipInstanceError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Author(ormar.Model):
class Meta:
tablename = "authors"
database = database
metadata = metadata
id: ormar.Integer(primary_key=True)
first_name: ormar.String(max_length=80)
last_name: ormar.String(max_length=80)
class Category(ormar.Model):
class Meta:
tablename = "categories"
database = database
metadata = metadata
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=40)
class PostCategory(ormar.Model):
class Meta:
tablename = "posts_categories"
database = database
metadata = metadata
class Post(ormar.Model):
class Meta:
tablename = "posts"
database = database
metadata = metadata
id: ormar.Integer(primary_key=True)
title: ormar.String(max_length=200)
categories: ormar.ManyToMany(Category, through=PostCategory)
author: ormar.ForeignKey(Author)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture(scope="function")
async def cleanup():
yield
await PostCategory.objects.delete()
await Post.objects.delete()
await Category.objects.delete()
await Author.objects.delete()
@pytest.mark.asyncio
async def test_assigning_related_objects(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
# Add a category to a post.
await post.categories.add(news)
# or from the other end:
await news.posts.add(post)
# Creating related object from instance:
await post.categories.create(name="Tips")
assert len(post.categories) == 2
post_categories = await post.categories.all()
assert len(post_categories) == 2
@pytest.mark.asyncio
async def test_quering_of_the_m2m_models(cleanup):
# orm can do this already.
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
# tl;dr: `post.categories` exposes the QuerySet API.
await post.categories.add(news)
post_categories = await post.categories.all()
assert len(post_categories) == 1
assert news == await post.categories.get(name="News")
num_posts = await news.posts.count()
assert num_posts == 1
posts_about_m2m = await news.posts.filter(title__contains="M2M").all()
assert len(posts_about_m2m) == 1
assert posts_about_m2m[0] == post
posts_about_python = await Post.objects.filter(categories__name="python").all()
assert len(posts_about_python) == 0
# Traversal of relationships: which categories has Guido contributed to?
category = await Category.objects.filter(posts__author=guido).get()
assert category == news
# or:
category2 = await Category.objects.filter(posts__author__first_name="Guido").get()
assert category2 == news
@pytest.mark.asyncio
async def test_removal_of_the_relations(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
await post.categories.add(news)
assert len(await post.categories.all()) == 1
await post.categories.remove(news)
assert len(await post.categories.all()) == 0
# or:
await news.posts.add(post)
assert len(await news.posts.all()) == 1
await news.posts.remove(post)
assert len(await news.posts.all()) == 0
# Remove all related objects:
await post.categories.add(news)
await post.categories.clear()
assert len(await post.categories.all()) == 0
# post would also lose 'news' category when running:
await post.categories.add(news)
await news.delete()
assert len(await post.categories.all()) == 0
@pytest.mark.asyncio
async def test_selecting_related(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
recent = await Category.objects.create(name="Recent")
await post.categories.add(news)
await post.categories.add(recent)
assert len(await post.categories.all()) == 2
# Loads categories and posts (2 queries) and perform the join in Python.
categories = await Category.objects.select_related("posts").all()
# No extra queries needed => no more `await`s required.
for category in categories:
assert category.posts[0] == post
news_posts = await news.posts.select_related("author").all()
assert news_posts[0].author == guido
assert (await post.categories.limit(1).all())[0] == news
assert (await post.categories.offset(1).limit(1).all())[0] == recent
assert await post.categories.first() == news
assert await post.categories.exists()
@pytest.mark.asyncio
async def test_selecting_related_fail_without_saving(cleanup):
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = Post(title="Hello, M2M", author=guido)
with pytest.raises(RelationshipInstanceError):
await post.categories.all()