80
README.md
80
README.md
@ -175,6 +175,86 @@ tracks = await Track.objects.limit(1).all()
|
|||||||
assert len(tracks) == 1
|
assert len(tracks) == 1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Since version >=0.3 Ormar supports also many to many relationships
|
||||||
|
```python
|
||||||
|
import databases
|
||||||
|
import ormar
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
database = databases.Database("sqlite:///db.sqlite")
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
class Author(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "authors"
|
||||||
|
database = database
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
first_name: ormar.String(max_length=80)
|
||||||
|
last_name: ormar.String(max_length=80)
|
||||||
|
|
||||||
|
|
||||||
|
class Category(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "categories"
|
||||||
|
database = database
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=40)
|
||||||
|
|
||||||
|
|
||||||
|
class PostCategory(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "posts_categories"
|
||||||
|
database = database
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
|
||||||
|
class Post(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "posts"
|
||||||
|
database = database
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
title: ormar.String(max_length=200)
|
||||||
|
categories: ormar.ManyToMany(Category, through=PostCategory)
|
||||||
|
author: ormar.ForeignKey(Author)
|
||||||
|
|
||||||
|
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||||
|
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||||
|
news = await Category.objects.create(name="News")
|
||||||
|
|
||||||
|
# Add a category to a post.
|
||||||
|
await post.categories.add(news)
|
||||||
|
# or from the other end:
|
||||||
|
await news.posts.add(post)
|
||||||
|
|
||||||
|
# Creating related object from instance:
|
||||||
|
await post.categories.create(name="Tips")
|
||||||
|
assert len(await post.categories.all()) == 2
|
||||||
|
|
||||||
|
# Many to many relation exposes a list of related models
|
||||||
|
# and an API of the Queryset:
|
||||||
|
assert news == await post.categories.get(name="News")
|
||||||
|
|
||||||
|
# with all Queryset methods - filtering, selecting related, counting etc.
|
||||||
|
await news.posts.filter(title__contains="M2M").all()
|
||||||
|
await Category.objects.filter(posts__author=guido).get()
|
||||||
|
|
||||||
|
# related models of many to many relation can be prefetched
|
||||||
|
news_posts = await news.posts.select_related("author").all()
|
||||||
|
assert news_posts[0].author == guido
|
||||||
|
|
||||||
|
# Removal of the relationship by one
|
||||||
|
await news.posts.remove(post)
|
||||||
|
# or all at once
|
||||||
|
await news.posts.clear()
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
## Data types
|
## Data types
|
||||||
|
|
||||||
The following keyword arguments are supported on all field types.
|
The following keyword arguments are supported on all field types.
|
||||||
|
|||||||
@ -9,13 +9,14 @@ from ormar.fields import (
|
|||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
JSON,
|
JSON,
|
||||||
|
ManyToMany,
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
Time,
|
Time,
|
||||||
)
|
)
|
||||||
from ormar.models import Model
|
from ormar.models import Model
|
||||||
|
|
||||||
__version__ = "0.2.2"
|
__version__ = "0.3.0"
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Integer",
|
"Integer",
|
||||||
"BigInteger",
|
"BigInteger",
|
||||||
@ -28,6 +29,7 @@ __all__ = [
|
|||||||
"Date",
|
"Date",
|
||||||
"Decimal",
|
"Decimal",
|
||||||
"Float",
|
"Float",
|
||||||
|
"ManyToMany",
|
||||||
"Model",
|
"Model",
|
||||||
"ModelDefinitionError",
|
"ModelDefinitionError",
|
||||||
"ModelNotSet",
|
"ModelNotSet",
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from ormar.fields.base import BaseField
|
from ormar.fields.base import BaseField
|
||||||
from ormar.fields.foreign_key import ForeignKey
|
from ormar.fields.foreign_key import ForeignKey
|
||||||
|
from ormar.fields.many_to_many import ManyToMany
|
||||||
from ormar.fields.model_fields import (
|
from ormar.fields.model_fields import (
|
||||||
BigInteger,
|
BigInteger,
|
||||||
Boolean,
|
Boolean,
|
||||||
@ -27,5 +28,6 @@ __all__ = [
|
|||||||
"Float",
|
"Float",
|
||||||
"Time",
|
"Time",
|
||||||
"ForeignKey",
|
"ForeignKey",
|
||||||
|
"ManyToMany",
|
||||||
"BaseField",
|
"BaseField",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -22,6 +22,7 @@ class BaseField:
|
|||||||
index: bool
|
index: bool
|
||||||
unique: bool
|
unique: bool
|
||||||
pydantic_only: bool
|
pydantic_only: bool
|
||||||
|
virtual: bool = False
|
||||||
|
|
||||||
default: Any
|
default: Any
|
||||||
server_default: Any
|
server_default: Any
|
||||||
@ -34,7 +35,6 @@ class BaseField:
|
|||||||
default = cls.default if cls.default is not None else cls.server_default
|
default = cls.default if cls.default is not None else cls.server_default
|
||||||
if callable(default):
|
if callable(default):
|
||||||
return Field(default_factory=default)
|
return Field(default_factory=default)
|
||||||
else:
|
|
||||||
return Field(default=default)
|
return Field(default=default)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
|
|||||||
return fk(**init_dict)
|
return fk(**init_dict)
|
||||||
|
|
||||||
|
|
||||||
def ForeignKey(
|
def ForeignKey( # noqa CFQ002
|
||||||
to: Type["Model"],
|
to: Type["Model"],
|
||||||
*,
|
*,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
@ -30,6 +30,8 @@ def ForeignKey(
|
|||||||
nullable: bool = True,
|
nullable: bool = True,
|
||||||
related_name: str = None,
|
related_name: str = None,
|
||||||
virtual: bool = False,
|
virtual: bool = False,
|
||||||
|
onupdate: str = None,
|
||||||
|
ondelete: str = None,
|
||||||
) -> Type[object]:
|
) -> Type[object]:
|
||||||
fk_string = to.Meta.tablename + "." + to.Meta.pkname
|
fk_string = to.Meta.tablename + "." + to.Meta.pkname
|
||||||
to_field = to.__fields__[to.Meta.pkname]
|
to_field = to.__fields__[to.Meta.pkname]
|
||||||
@ -37,7 +39,11 @@ def ForeignKey(
|
|||||||
to=to,
|
to=to,
|
||||||
name=name,
|
name=name,
|
||||||
nullable=nullable,
|
nullable=nullable,
|
||||||
constraints=[sqlalchemy.schema.ForeignKey(fk_string)],
|
constraints=[
|
||||||
|
sqlalchemy.schema.ForeignKey(
|
||||||
|
fk_string, ondelete=ondelete, onupdate=onupdate
|
||||||
|
)
|
||||||
|
],
|
||||||
unique=unique,
|
unique=unique,
|
||||||
column_type=to_field.type_.column_type,
|
column_type=to_field.type_.column_type,
|
||||||
related_name=related_name,
|
related_name=related_name,
|
||||||
@ -117,7 +123,7 @@ class ForeignKeyField(BaseField):
|
|||||||
cls, value: Any, child: "Model", to_register: bool = True
|
cls, value: Any, child: "Model", to_register: bool = True
|
||||||
) -> Optional[Union["Model", List["Model"]]]:
|
) -> Optional[Union["Model", List["Model"]]]:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None if not cls.virtual else []
|
||||||
|
|
||||||
constructors = {
|
constructors = {
|
||||||
f"{cls.to.__name__}": cls._register_existing_model,
|
f"{cls.to.__name__}": cls._register_existing_model,
|
||||||
|
|||||||
40
ormar/fields/many_to_many.py
Normal file
40
ormar/fields/many_to_many.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from typing import TYPE_CHECKING, Type
|
||||||
|
|
||||||
|
from ormar.fields import BaseField
|
||||||
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma no cover
|
||||||
|
from ormar.models import Model
|
||||||
|
|
||||||
|
|
||||||
|
def ManyToMany(
|
||||||
|
to: Type["Model"],
|
||||||
|
through: Type["Model"],
|
||||||
|
*,
|
||||||
|
name: str = None,
|
||||||
|
unique: bool = False,
|
||||||
|
related_name: str = None,
|
||||||
|
virtual: bool = False,
|
||||||
|
) -> Type[object]:
|
||||||
|
to_field = to.__fields__[to.Meta.pkname]
|
||||||
|
namespace = dict(
|
||||||
|
to=to,
|
||||||
|
through=through,
|
||||||
|
name=name,
|
||||||
|
nullable=True,
|
||||||
|
unique=unique,
|
||||||
|
column_type=to_field.type_.column_type,
|
||||||
|
related_name=related_name,
|
||||||
|
virtual=virtual,
|
||||||
|
primary_key=False,
|
||||||
|
index=False,
|
||||||
|
pydantic_only=False,
|
||||||
|
default=None,
|
||||||
|
server_default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return type("ManyToMany", (ManyToManyField, BaseField), namespace)
|
||||||
|
|
||||||
|
|
||||||
|
class ManyToManyField(ForeignKeyField):
|
||||||
|
through: Type["Model"]
|
||||||
@ -1,16 +1,18 @@
|
|||||||
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pydantic
|
import pydantic
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from pydantic import BaseConfig
|
from pydantic import BaseConfig
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo, ModelField
|
||||||
|
|
||||||
from ormar import ForeignKey, ModelDefinitionError # noqa I100
|
from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100
|
||||||
from ormar.fields import BaseField
|
from ormar.fields import BaseField
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
|
from ormar.fields.many_to_many import ManyToMany, ManyToManyField
|
||||||
from ormar.queryset import QuerySet
|
from ormar.queryset import QuerySet
|
||||||
from ormar.relations import AliasManager
|
from ormar.relations.alias_manager import AliasManager
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -30,7 +32,14 @@ class ModelMeta:
|
|||||||
|
|
||||||
|
|
||||||
def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
|
def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
|
||||||
alias_manager.add_relation_type(field, table_name)
|
alias_manager.add_relation_type(field.to.Meta.tablename, table_name)
|
||||||
|
|
||||||
|
|
||||||
|
def register_many_to_many_relation_on_build(table_name: str, field: ManyToMany) -> None:
|
||||||
|
alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
|
||||||
|
alias_manager.add_relation_type(
|
||||||
|
field.through.Meta.tablename, field.to.Meta.tablename
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reverse_field_not_already_registered(
|
def reverse_field_not_already_registered(
|
||||||
@ -51,17 +60,74 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
|
|||||||
if reverse_field_not_already_registered(
|
if reverse_field_not_already_registered(
|
||||||
child, child_model_name, parent_model
|
child, child_model_name, parent_model
|
||||||
):
|
):
|
||||||
register_reverse_model_fields(parent_model, child, child_model_name)
|
register_reverse_model_fields(
|
||||||
|
parent_model, child, child_model_name, model_field
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_reverse_model_fields(
|
def register_reverse_model_fields(
|
||||||
model: Type["Model"], child: Type["Model"], child_model_name: str
|
model: Type["Model"],
|
||||||
|
child: Type["Model"],
|
||||||
|
child_model_name: str,
|
||||||
|
model_field: Type["ForeignKeyField"],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if issubclass(model_field, ManyToManyField):
|
||||||
|
model.Meta.model_fields[child_model_name] = ManyToMany(
|
||||||
|
child, through=model_field.through, name=child_model_name, virtual=True
|
||||||
|
)
|
||||||
|
# register foreign keys on through model
|
||||||
|
adjust_through_many_to_many_model(model, child, model_field)
|
||||||
|
else:
|
||||||
model.Meta.model_fields[child_model_name] = ForeignKey(
|
model.Meta.model_fields[child_model_name] = ForeignKey(
|
||||||
child, name=child_model_name, virtual=True
|
child, name=child_model_name, virtual=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_through_many_to_many_model(
|
||||||
|
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField]
|
||||||
|
) -> None:
|
||||||
|
model_field.through.Meta.model_fields[model.get_name()] = ForeignKey(
|
||||||
|
model, name=model.get_name(), ondelete="CASCADE"
|
||||||
|
)
|
||||||
|
model_field.through.Meta.model_fields[child.get_name()] = ForeignKey(
|
||||||
|
child, name=child.get_name(), ondelete="CASCADE"
|
||||||
|
)
|
||||||
|
|
||||||
|
create_and_append_m2m_fk(model, model_field)
|
||||||
|
create_and_append_m2m_fk(child, model_field)
|
||||||
|
|
||||||
|
create_pydantic_field(model.get_name(), model, model_field)
|
||||||
|
create_pydantic_field(child.get_name(), child, model_field)
|
||||||
|
|
||||||
|
|
||||||
|
def create_pydantic_field(
|
||||||
|
field_name: str, model: Type["Model"], model_field: Type[ManyToManyField]
|
||||||
|
) -> None:
|
||||||
|
model_field.through.__fields__[field_name] = ModelField(
|
||||||
|
name=field_name,
|
||||||
|
type_=Optional[model],
|
||||||
|
model_config=model.__config__,
|
||||||
|
required=False,
|
||||||
|
class_validators=model.__validators__,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_append_m2m_fk(
|
||||||
|
model: Type["Model"], model_field: Type[ManyToManyField]
|
||||||
|
) -> None:
|
||||||
|
column = sqlalchemy.Column(
|
||||||
|
model.get_name(),
|
||||||
|
model.Meta.table.columns.get(model.Meta.pkname).type,
|
||||||
|
sqlalchemy.schema.ForeignKey(
|
||||||
|
model.Meta.tablename + "." + model.Meta.pkname,
|
||||||
|
ondelete="CASCADE",
|
||||||
|
onupdate="CASCADE",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model_field.through.Meta.columns.append(column)
|
||||||
|
model_field.through.Meta.table.append_column(column)
|
||||||
|
|
||||||
|
|
||||||
def check_pk_column_validity(
|
def check_pk_column_validity(
|
||||||
field_name: str, field: BaseField, pkname: str
|
field_name: str, field: BaseField, pkname: str
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
@ -77,17 +143,34 @@ def sqlalchemy_columns_from_model_fields(
|
|||||||
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
|
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
|
||||||
columns = []
|
columns = []
|
||||||
pkname = None
|
pkname = None
|
||||||
|
if len(model_fields.keys()) == 0:
|
||||||
|
model_fields["id"] = Integer(name="id", primary_key=True)
|
||||||
|
logging.warning(
|
||||||
|
"Table {table_name} had no fields so auto "
|
||||||
|
"Integer primary key named `id` created."
|
||||||
|
)
|
||||||
for field_name, field in model_fields.items():
|
for field_name, field in model_fields.items():
|
||||||
if field.primary_key:
|
if field.primary_key:
|
||||||
pkname = check_pk_column_validity(field_name, field, pkname)
|
pkname = check_pk_column_validity(field_name, field, pkname)
|
||||||
if not field.pydantic_only:
|
if (
|
||||||
|
not field.pydantic_only
|
||||||
|
and not field.virtual
|
||||||
|
and not issubclass(field, ManyToManyField)
|
||||||
|
):
|
||||||
columns.append(field.get_column(field_name))
|
columns.append(field.get_column(field_name))
|
||||||
if issubclass(field, ForeignKeyField):
|
register_relation_in_alias_manager(table_name, field)
|
||||||
register_relation_on_build(table_name, field)
|
|
||||||
|
|
||||||
return pkname, columns
|
return pkname, columns
|
||||||
|
|
||||||
|
|
||||||
|
def register_relation_in_alias_manager(
|
||||||
|
table_name: str, field: Type[ForeignKeyField]
|
||||||
|
) -> None:
|
||||||
|
if issubclass(field, ManyToManyField):
|
||||||
|
register_many_to_many_relation_on_build(table_name, field)
|
||||||
|
elif issubclass(field, ForeignKeyField):
|
||||||
|
register_relation_on_build(table_name, field)
|
||||||
|
|
||||||
|
|
||||||
def populate_default_pydantic_field_value(
|
def populate_default_pydantic_field_value(
|
||||||
type_: Type[BaseField], field: str, attrs: dict
|
type_: Type[BaseField], field: str, attrs: dict
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@ -109,15 +192,11 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict:
|
|||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
|
|
||||||
def extract_annotations_and_module(
|
def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict:
|
||||||
attrs: dict, new_model: "ModelMetaclass", bases: Tuple
|
attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get(
|
||||||
) -> dict:
|
"__annotations__", {}
|
||||||
annotations = attrs.get("__annotations__") or new_model.__annotations__
|
)
|
||||||
attrs["__annotations__"] = annotations
|
|
||||||
attrs = populate_pydantic_default_values(attrs)
|
attrs = populate_pydantic_default_values(attrs)
|
||||||
|
|
||||||
attrs["__module__"] = attrs["__module__"] or bases[0].__module__
|
|
||||||
attrs["__annotations__"] = attrs["__annotations__"] or bases[0].__annotations__
|
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
|
|
||||||
@ -175,20 +254,26 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
|
|||||||
|
|
||||||
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||||
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
||||||
|
|
||||||
attrs["Config"] = get_pydantic_base_orm_config()
|
attrs["Config"] = get_pydantic_base_orm_config()
|
||||||
|
attrs = extract_annotations_and_default_vals(attrs, bases)
|
||||||
new_model = super().__new__( # type: ignore
|
new_model = super().__new__( # type: ignore
|
||||||
mcs, name, bases, attrs
|
mcs, name, bases, attrs
|
||||||
)
|
)
|
||||||
|
# breakpoint()
|
||||||
|
|
||||||
if hasattr(new_model, "Meta"):
|
if hasattr(new_model, "Meta"):
|
||||||
|
# attrs = extract_annotations_and_default_vals(attrs, bases)
|
||||||
attrs = extract_annotations_and_module(attrs, new_model, bases)
|
|
||||||
new_model = populate_meta_orm_model_fields(attrs, new_model)
|
new_model = populate_meta_orm_model_fields(attrs, new_model)
|
||||||
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
||||||
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
|
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
|
||||||
expand_reverse_relationships(new_model)
|
expand_reverse_relationships(new_model)
|
||||||
|
|
||||||
|
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
||||||
|
field_name = new_model.Meta.pkname
|
||||||
|
field = Integer(name=field_name, primary_key=True)
|
||||||
|
attrs["__annotations__"][field_name] = field
|
||||||
|
populate_default_pydantic_field_value(field, field_name, attrs)
|
||||||
|
|
||||||
new_model = super().__new__( # type: ignore
|
new_model = super().__new__( # type: ignore
|
||||||
mcs, name, bases, attrs
|
mcs, name, bases, attrs
|
||||||
)
|
)
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Any, List, Tuple, Union
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import ormar.queryset # noqa I100
|
import ormar.queryset # noqa I100
|
||||||
|
from ormar.fields.many_to_many import ManyToManyField
|
||||||
from ormar.models import NewBaseModel # noqa I100
|
from ormar.models import NewBaseModel # noqa I100
|
||||||
|
|
||||||
|
|
||||||
@ -40,10 +41,19 @@ class Model(NewBaseModel):
|
|||||||
if select_related:
|
if select_related:
|
||||||
related_models = group_related_list(select_related)
|
related_models = group_related_list(select_related)
|
||||||
|
|
||||||
|
# breakpoint()
|
||||||
|
if (
|
||||||
|
previous_table
|
||||||
|
and previous_table in cls.Meta.model_fields
|
||||||
|
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
|
||||||
|
):
|
||||||
|
previous_table = cls.Meta.model_fields[
|
||||||
|
previous_table
|
||||||
|
].through.Meta.tablename
|
||||||
|
|
||||||
table_prefix = cls.Meta.alias_manager.resolve_relation_join(
|
table_prefix = cls.Meta.alias_manager.resolve_relation_join(
|
||||||
previous_table, cls.Meta.table.name
|
previous_table, cls.Meta.table.name
|
||||||
)
|
)
|
||||||
|
|
||||||
previous_table = cls.Meta.table.name
|
previous_table = cls.Meta.table.name
|
||||||
|
|
||||||
item = cls.populate_nested_models_from_row(
|
item = cls.populate_nested_models_from_row(
|
||||||
|
|||||||
@ -23,7 +23,8 @@ from ormar.fields import BaseField
|
|||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
from ormar.models.metaclass import ModelMeta, ModelMetaclass
|
from ormar.models.metaclass import ModelMeta, ModelMetaclass
|
||||||
from ormar.models.modelproxy import ModelTableProxy
|
from ormar.models.modelproxy import ModelTableProxy
|
||||||
from ormar.relations import AliasManager, RelationsManager
|
from ormar.relations.alias_manager import AliasManager
|
||||||
|
from ormar.relations.relation import RelationsManager
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar.models.model import Model
|
from ormar.models.model import Model
|
||||||
@ -96,13 +97,16 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
kwargs.get(related), self, to_register=True
|
kwargs.get(related), self, to_register=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def __setattr__(self, name: str, value: Any) -> None:
|
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
|
||||||
if name in self.__slots__:
|
if name in self.__slots__:
|
||||||
object.__setattr__(self, name, value)
|
object.__setattr__(self, name, value)
|
||||||
elif name == "pk":
|
elif name == "pk":
|
||||||
object.__setattr__(self, self.Meta.pkname, value)
|
object.__setattr__(self, self.Meta.pkname, value)
|
||||||
elif name in self._orm:
|
elif name in self._orm:
|
||||||
model = self.Meta.model_fields[name].expand_relationship(value, self)
|
model = self.Meta.model_fields[name].expand_relationship(value, self)
|
||||||
|
if isinstance(self.__dict__.get(name), list):
|
||||||
|
self.__dict__[name].append(model)
|
||||||
|
else:
|
||||||
self.__dict__[name] = model
|
self.__dict__[name] = model
|
||||||
else:
|
else:
|
||||||
value = (
|
value = (
|
||||||
@ -115,11 +119,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
def __getattribute__(self, item: str) -> Any:
|
def __getattribute__(self, item: str) -> Any:
|
||||||
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
|
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
|
||||||
return object.__getattribute__(self, item)
|
return object.__getattribute__(self, item)
|
||||||
elif item != "_extract_related_names" and item in self._extract_related_names():
|
if item != "_extract_related_names" and item in self._extract_related_names():
|
||||||
return self._extract_related_model_instead_of_field(item)
|
return self._extract_related_model_instead_of_field(item)
|
||||||
elif item == "pk":
|
if item == "pk":
|
||||||
return self.__dict__.get(self.Meta.pkname, None)
|
return self.__dict__.get(self.Meta.pkname, None)
|
||||||
elif item != "__fields__" and item in self.__fields__:
|
if item != "__fields__" and item in self.__fields__:
|
||||||
value = self.__dict__.get(item, None)
|
value = self.__dict__.get(item, None)
|
||||||
value = self._convert_json(item, value, "loads")
|
value = self._convert_json(item, value, "loads")
|
||||||
return value
|
return value
|
||||||
@ -131,15 +135,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
if item in self._orm:
|
if item in self._orm:
|
||||||
return self._orm.get(item)
|
return self._orm.get(item)
|
||||||
|
|
||||||
|
def __eq__(self, other: "Model") -> bool:
|
||||||
|
if isinstance(other, NewBaseModel):
|
||||||
|
return self.__same__(other)
|
||||||
|
return super().__eq__(other) # pragma no cover
|
||||||
|
|
||||||
def __same__(self, other: "Model") -> bool:
|
def __same__(self, other: "Model") -> bool:
|
||||||
return (
|
return (
|
||||||
self._orm_id == other._orm_id
|
self._orm_id == other._orm_id
|
||||||
or self.__dict__ == other.__dict__
|
or self.dict() == other.dict()
|
||||||
or (self.pk == other.pk and self.pk is not None)
|
or (self.pk == other.pk and self.pk is not None)
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(cls, title: bool = False, lower: bool = True) -> str:
|
def get_name(cls, lower: bool = True) -> str:
|
||||||
name = cls.__name__
|
name = cls.__name__
|
||||||
if lower:
|
if lower:
|
||||||
name = name.lower()
|
name = name.lower()
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from sqlalchemy import text
|
|||||||
|
|
||||||
import ormar # noqa I100
|
import ormar # noqa I100
|
||||||
from ormar.exceptions import QueryDefinitionError
|
from ormar.exceptions import QueryDefinitionError
|
||||||
|
from ormar.fields.many_to_many import ManyToManyField
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -128,6 +129,10 @@ class QueryClause:
|
|||||||
# against which the comparison is being made.
|
# against which the comparison is being made.
|
||||||
previous_table = model_cls.Meta.tablename
|
previous_table = model_cls.Meta.tablename
|
||||||
for part in related_parts:
|
for part in related_parts:
|
||||||
|
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
|
||||||
|
previous_table = model_cls.Meta.model_fields[
|
||||||
|
part
|
||||||
|
].through.Meta.tablename
|
||||||
current_table = model_cls.Meta.model_fields[part].to.Meta.tablename
|
current_table = model_cls.Meta.model_fields[part].to.Meta.tablename
|
||||||
manager = model_cls.Meta.alias_manager
|
manager = model_cls.Meta.alias_manager
|
||||||
table_prefix = manager.resolve_relation_join(previous_table, current_table)
|
table_prefix = manager.resolve_relation_join(previous_table, current_table)
|
||||||
|
|||||||
@ -4,8 +4,10 @@ import sqlalchemy
|
|||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
import ormar # noqa I100
|
import ormar # noqa I100
|
||||||
|
from ormar.fields import BaseField
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
from ormar.relations import AliasManager
|
from ormar.fields.many_to_many import ManyToManyField
|
||||||
|
from ormar.relations.alias_manager import AliasManager
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -63,6 +65,15 @@ class Query:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for part in item.split("__"):
|
for part in item.split("__"):
|
||||||
|
if issubclass(
|
||||||
|
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
||||||
|
):
|
||||||
|
_fields = join_parameters.model_cls.Meta.model_fields
|
||||||
|
new_part = _fields[part].to.get_name()
|
||||||
|
join_parameters = self._build_join_parameters(
|
||||||
|
part, join_parameters, is_multi=True
|
||||||
|
)
|
||||||
|
part = new_part
|
||||||
join_parameters = self._build_join_parameters(part, join_parameters)
|
join_parameters = self._build_join_parameters(part, join_parameters)
|
||||||
|
|
||||||
expr = sqlalchemy.sql.select(self.columns)
|
expr = sqlalchemy.sql.select(self.columns)
|
||||||
@ -83,9 +94,17 @@ class Query:
|
|||||||
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
||||||
return text(f"{left_part}={right_part}")
|
return text(f"{left_part}={right_part}")
|
||||||
|
|
||||||
|
def _is_target_relation_key(
|
||||||
|
self, field: BaseField, target_model: Type["Model"]
|
||||||
|
) -> bool:
|
||||||
|
return issubclass(field, ForeignKeyField) and field.to.Meta == target_model.Meta
|
||||||
|
|
||||||
def _build_join_parameters(
|
def _build_join_parameters(
|
||||||
self, part: str, join_params: JoinParameters
|
self, part: str, join_params: JoinParameters, is_multi: bool = False
|
||||||
) -> JoinParameters:
|
) -> JoinParameters:
|
||||||
|
if is_multi:
|
||||||
|
model_cls = join_params.model_cls.Meta.model_fields[part].through
|
||||||
|
else:
|
||||||
model_cls = join_params.model_cls.Meta.model_fields[part].to
|
model_cls = join_params.model_cls.Meta.model_fields[part].to
|
||||||
to_table = model_cls.Meta.table.name
|
to_table = model_cls.Meta.table.name
|
||||||
|
|
||||||
@ -93,13 +112,12 @@ class Query:
|
|||||||
join_params.from_table, to_table
|
join_params.from_table, to_table
|
||||||
)
|
)
|
||||||
if alias not in self.used_aliases:
|
if alias not in self.used_aliases:
|
||||||
if join_params.prev_model.Meta.model_fields[part].virtual:
|
if join_params.prev_model.Meta.model_fields[part].virtual or is_multi:
|
||||||
to_key = next(
|
to_key = next(
|
||||||
(
|
(
|
||||||
v
|
v
|
||||||
for k, v in model_cls.Meta.model_fields.items()
|
for k, v in model_cls.Meta.model_fields.items()
|
||||||
if issubclass(v, ForeignKeyField)
|
if self._is_target_relation_key(v, join_params.prev_model)
|
||||||
and v.to == join_params.prev_model
|
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
).name
|
).name
|
||||||
@ -129,16 +147,19 @@ class Query:
|
|||||||
prev_model = model_cls
|
prev_model = model_cls
|
||||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||||
|
|
||||||
def _apply_expression_modifiers(
|
def filter(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: # noqa A003
|
||||||
self, expr: sqlalchemy.sql.select
|
|
||||||
) -> sqlalchemy.sql.select:
|
|
||||||
if self.filter_clauses:
|
if self.filter_clauses:
|
||||||
if len(self.filter_clauses) == 1:
|
if len(self.filter_clauses) == 1:
|
||||||
clause = self.filter_clauses[0]
|
clause = self.filter_clauses[0]
|
||||||
else:
|
else:
|
||||||
clause = sqlalchemy.sql.and_(*self.filter_clauses)
|
clause = sqlalchemy.sql.and_(*self.filter_clauses)
|
||||||
expr = expr.where(clause)
|
expr = expr.where(clause)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def _apply_expression_modifiers(
|
||||||
|
self, expr: sqlalchemy.sql.select
|
||||||
|
) -> sqlalchemy.sql.select:
|
||||||
|
expr = self.filter(expr)
|
||||||
if self.limit_count:
|
if self.limit_count:
|
||||||
expr = expr.limit(self.limit_count)
|
expr = expr.limit(self.limit_count)
|
||||||
|
|
||||||
|
|||||||
@ -48,6 +48,7 @@ class QuerySet:
|
|||||||
limit_count=self.limit_count,
|
limit_count=self.limit_count,
|
||||||
)
|
)
|
||||||
exp = qry.build_select_expression()
|
exp = qry.build_select_expression()
|
||||||
|
# print(exp.compile(compile_kwargs={"literal_binds": True}))
|
||||||
return exp
|
return exp
|
||||||
|
|
||||||
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||||
@ -70,7 +71,7 @@ class QuerySet:
|
|||||||
if not isinstance(related, (list, tuple)):
|
if not isinstance(related, (list, tuple)):
|
||||||
related = [related]
|
related = [related]
|
||||||
|
|
||||||
related = list(self._select_related) + related
|
related = list(set(list(self._select_related) + related))
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model_cls,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
@ -82,13 +83,28 @@ class QuerySet:
|
|||||||
async def exists(self) -> bool:
|
async def exists(self) -> bool:
|
||||||
expr = self.build_select_expression()
|
expr = self.build_select_expression()
|
||||||
expr = sqlalchemy.exists(expr).select()
|
expr = sqlalchemy.exists(expr).select()
|
||||||
|
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
||||||
return await self.database.fetch_val(expr)
|
return await self.database.fetch_val(expr)
|
||||||
|
|
||||||
async def count(self) -> int:
|
async def count(self) -> int:
|
||||||
expr = self.build_select_expression().alias("subquery_for_count")
|
expr = self.build_select_expression().alias("subquery_for_count")
|
||||||
expr = sqlalchemy.func.count().select().select_from(expr)
|
expr = sqlalchemy.func.count().select().select_from(expr)
|
||||||
|
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
||||||
return await self.database.fetch_val(expr)
|
return await self.database.fetch_val(expr)
|
||||||
|
|
||||||
|
async def delete(self, **kwargs: Any) -> int:
|
||||||
|
if kwargs:
|
||||||
|
return await self.filter(**kwargs).delete()
|
||||||
|
qry = Query(
|
||||||
|
model_cls=self.model_cls,
|
||||||
|
select_related=self._select_related,
|
||||||
|
filter_clauses=self.filter_clauses,
|
||||||
|
offset=self.query_offset,
|
||||||
|
limit_count=self.limit_count,
|
||||||
|
)
|
||||||
|
expr = qry.filter(self.table.delete())
|
||||||
|
return await self.database.execute(expr)
|
||||||
|
|
||||||
def limit(self, limit_count: int) -> "QuerySet":
|
def limit(self, limit_count: int) -> "QuerySet":
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model_cls,
|
||||||
@ -118,7 +134,7 @@ class QuerySet:
|
|||||||
async def get(self, **kwargs: Any) -> "Model":
|
async def get(self, **kwargs: Any) -> "Model":
|
||||||
if kwargs:
|
if kwargs:
|
||||||
return await self.filter(**kwargs).get()
|
return await self.filter(**kwargs).get()
|
||||||
else:
|
|
||||||
if not self.filter_clauses:
|
if not self.filter_clauses:
|
||||||
expr = self.build_select_expression().limit(2)
|
expr = self.build_select_expression().limit(2)
|
||||||
else:
|
else:
|
||||||
@ -143,6 +159,7 @@ class QuerySet:
|
|||||||
return await self.filter(**kwargs).all()
|
return await self.filter(**kwargs).all()
|
||||||
|
|
||||||
expr = self.build_select_expression()
|
expr = self.build_select_expression()
|
||||||
|
# breakpoint()
|
||||||
rows = await self.database.fetch_all(expr)
|
rows = await self.database.fetch_all(expr)
|
||||||
result_rows = [
|
result_rows = [
|
||||||
self.model_cls.from_row(row, select_related=self._select_related)
|
self.model_cls.from_row(row, select_related=self._select_related)
|
||||||
|
|||||||
@ -1,198 +0,0 @@
|
|||||||
import string
|
|
||||||
import uuid
|
|
||||||
from enum import Enum
|
|
||||||
from random import choices
|
|
||||||
from typing import List, Optional, TYPE_CHECKING, Type, Union
|
|
||||||
from weakref import proxy
|
|
||||||
|
|
||||||
import sqlalchemy
|
|
||||||
from sqlalchemy import text
|
|
||||||
|
|
||||||
import ormar # noqa I100
|
|
||||||
from ormar.exceptions import RelationshipInstanceError # noqa I100
|
|
||||||
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
|
||||||
from ormar.models import Model
|
|
||||||
|
|
||||||
|
|
||||||
def get_table_alias() -> str:
|
|
||||||
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
|
|
||||||
|
|
||||||
|
|
||||||
class RelationType(Enum):
|
|
||||||
PRIMARY = 1
|
|
||||||
REVERSE = 2
|
|
||||||
|
|
||||||
|
|
||||||
class AliasManager:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._aliases = dict()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
|
||||||
return [
|
|
||||||
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
|
|
||||||
for column in table.columns
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def prefixed_table_name(alias: str, name: str) -> text:
|
|
||||||
return text(f"{name} {alias}_{name}")
|
|
||||||
|
|
||||||
def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None:
|
|
||||||
if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases:
|
|
||||||
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
|
|
||||||
if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases:
|
|
||||||
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
|
|
||||||
|
|
||||||
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
|
|
||||||
return self._aliases.get(f"{from_table}_{to_table}", "")
|
|
||||||
|
|
||||||
|
|
||||||
class RelationProxy(list):
|
|
||||||
def __init__(self, relation: "Relation") -> None:
|
|
||||||
super(RelationProxy, self).__init__()
|
|
||||||
self.relation = relation
|
|
||||||
self._owner = self.relation.manager.owner
|
|
||||||
|
|
||||||
def remove(self, item: "Model") -> None:
|
|
||||||
super().remove(item)
|
|
||||||
rel_name = item.resolve_relation_name(item, self._owner)
|
|
||||||
item._orm._get(rel_name).remove(self._owner)
|
|
||||||
|
|
||||||
def append(self, item: "Model") -> None:
|
|
||||||
super().append(item)
|
|
||||||
|
|
||||||
def add(self, item: "Model") -> None:
|
|
||||||
rel_name = item.resolve_relation_name(item, self._owner)
|
|
||||||
setattr(item, rel_name, self._owner)
|
|
||||||
|
|
||||||
|
|
||||||
class Relation:
|
|
||||||
def __init__(self, manager: "RelationsManager", type_: RelationType) -> None:
|
|
||||||
self.manager = manager
|
|
||||||
self._owner = manager.owner
|
|
||||||
self._type = type_
|
|
||||||
self.related_models = (
|
|
||||||
RelationProxy(relation=self) if type_ == RelationType.REVERSE else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _find_existing(self, child: "Model") -> Optional[int]:
|
|
||||||
for ind, relation_child in enumerate(self.related_models[:]):
|
|
||||||
try:
|
|
||||||
if relation_child.__same__(child):
|
|
||||||
return ind
|
|
||||||
except ReferenceError: # pragma no cover
|
|
||||||
self.related_models.pop(ind)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def add(self, child: "Model") -> None:
|
|
||||||
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
|
||||||
if self._type == RelationType.PRIMARY:
|
|
||||||
self.related_models = child
|
|
||||||
self._owner.__dict__[relation_name] = child
|
|
||||||
else:
|
|
||||||
if self._find_existing(child) is None:
|
|
||||||
self.related_models.append(child)
|
|
||||||
rel = self._owner.__dict__.get(relation_name, [])
|
|
||||||
rel.append(child)
|
|
||||||
self._owner.__dict__[relation_name] = rel
|
|
||||||
|
|
||||||
def remove(self, child: "Model") -> None:
|
|
||||||
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
|
||||||
if self._type == RelationType.PRIMARY:
|
|
||||||
if self.related_models.__same__(child):
|
|
||||||
self.related_models = None
|
|
||||||
del self._owner.__dict__[relation_name]
|
|
||||||
else:
|
|
||||||
position = self._find_existing(child)
|
|
||||||
if position is not None:
|
|
||||||
self.related_models.pop(position)
|
|
||||||
del self._owner.__dict__[relation_name][position]
|
|
||||||
|
|
||||||
def get(self) -> Union[List["Model"], "Model"]:
|
|
||||||
return self.related_models
|
|
||||||
|
|
||||||
def __repr__(self) -> str: # pragma no cover
|
|
||||||
return str(self.related_models)
|
|
||||||
|
|
||||||
|
|
||||||
class RelationsManager:
|
|
||||||
def __init__(
|
|
||||||
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
|
|
||||||
) -> None:
|
|
||||||
self.owner = owner
|
|
||||||
self._related_fields = related_fields or []
|
|
||||||
self._related_names = [field.name for field in self._related_fields]
|
|
||||||
self._relations = dict()
|
|
||||||
for field in self._related_fields:
|
|
||||||
self._add_relation(field)
|
|
||||||
|
|
||||||
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
|
|
||||||
self._relations[field.name] = Relation(
|
|
||||||
manager=self,
|
|
||||||
type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __contains__(self, item: str) -> bool:
|
|
||||||
return item in self._related_names
|
|
||||||
|
|
||||||
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
|
|
||||||
relation = self._relations.get(name, None)
|
|
||||||
if relation:
|
|
||||||
return relation.get()
|
|
||||||
|
|
||||||
def _get(self, name: str) -> Optional[Relation]:
|
|
||||||
relation = self._relations.get(name, None)
|
|
||||||
if relation:
|
|
||||||
return relation
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
|
|
||||||
to_field = next(
|
|
||||||
(
|
|
||||||
field
|
|
||||||
for field in child._orm._related_fields
|
|
||||||
if field.to == parent.__class__ or field.to.Meta == parent.Meta
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not to_field: # pragma no cover
|
|
||||||
raise RelationshipInstanceError(
|
|
||||||
f"Model {child.__class__} does not have "
|
|
||||||
f"reference to model {parent.__class__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
to_name = to_field.name
|
|
||||||
if virtual:
|
|
||||||
child_name, to_name = to_name, child_name or child.get_name()
|
|
||||||
child, parent = parent, proxy(child)
|
|
||||||
else:
|
|
||||||
child_name = child_name or child.get_name() + "s"
|
|
||||||
child = proxy(child)
|
|
||||||
|
|
||||||
parent_relation = parent._orm._get(child_name)
|
|
||||||
if not parent_relation:
|
|
||||||
ormar.models.expand_reverse_relationships(child.__class__)
|
|
||||||
name = parent.resolve_relation_name(parent, child)
|
|
||||||
field = parent.Meta.model_fields[name]
|
|
||||||
parent._orm._add_relation(field)
|
|
||||||
parent_relation = parent._orm._get(child_name)
|
|
||||||
parent_relation.add(child)
|
|
||||||
child._orm._get(to_name).add(parent)
|
|
||||||
|
|
||||||
def remove(self, name: str, child: "Model") -> None:
|
|
||||||
relation = self._get(name)
|
|
||||||
relation.remove(child)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
|
|
||||||
related_model = name
|
|
||||||
name = item.resolve_relation_name(item, related_model)
|
|
||||||
if name in item._orm:
|
|
||||||
relation_name = item.resolve_relation_name(related_model, item)
|
|
||||||
item._orm.remove(name, related_model)
|
|
||||||
related_model._orm.remove(relation_name, item)
|
|
||||||
3
ormar/relations/__init__.py
Normal file
3
ormar/relations/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from ormar.relations.alias_manager import AliasManager
|
||||||
|
|
||||||
|
__all__ = ["AliasManager"]
|
||||||
36
ormar/relations/alias_manager.py
Normal file
36
ormar/relations/alias_manager.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import string
|
||||||
|
import uuid
|
||||||
|
from random import choices
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
|
||||||
|
def get_table_alias() -> str:
|
||||||
|
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
|
||||||
|
|
||||||
|
|
||||||
|
class AliasManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._aliases = dict()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||||
|
return [
|
||||||
|
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
|
||||||
|
for column in table.columns
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prefixed_table_name(alias: str, name: str) -> text:
|
||||||
|
return text(f"{name} {alias}_{name}")
|
||||||
|
|
||||||
|
def add_relation_type(self, to_table_name: str, table_name: str,) -> None:
|
||||||
|
if f"{table_name}_{to_table_name}" not in self._aliases:
|
||||||
|
self._aliases[f"{table_name}_{to_table_name}"] = get_table_alias()
|
||||||
|
if f"{to_table_name}_{table_name}" not in self._aliases:
|
||||||
|
self._aliases[f"{to_table_name}_{table_name}"] = get_table_alias()
|
||||||
|
|
||||||
|
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
|
||||||
|
return self._aliases.get(f"{from_table}_{to_table}", "")
|
||||||
332
ormar/relations/relation.py
Normal file
332
ormar/relations/relation.py
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union
|
||||||
|
from weakref import proxy
|
||||||
|
|
||||||
|
import ormar # noqa I100
|
||||||
|
from ormar.exceptions import RelationshipInstanceError # noqa I100
|
||||||
|
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
|
||||||
|
from ormar.fields.many_to_many import ManyToManyField
|
||||||
|
from ormar.queryset import QuerySet
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma no cover
|
||||||
|
from ormar import Model
|
||||||
|
|
||||||
|
|
||||||
|
class RelationType(Enum):
|
||||||
|
PRIMARY = 1
|
||||||
|
REVERSE = 2
|
||||||
|
MULTIPLE = 3
|
||||||
|
|
||||||
|
|
||||||
|
class QuerysetProxy:
|
||||||
|
if TYPE_CHECKING: # pragma no cover
|
||||||
|
relation: "Relation"
|
||||||
|
|
||||||
|
def __init__(self, relation: "Relation") -> None:
|
||||||
|
self.relation = relation
|
||||||
|
self.queryset = None
|
||||||
|
|
||||||
|
def _assign_child_to_parent(self, child: "Model") -> None:
|
||||||
|
owner = self.relation._owner
|
||||||
|
rel_name = owner.resolve_relation_name(owner, child)
|
||||||
|
setattr(owner, rel_name, child)
|
||||||
|
|
||||||
|
def _register_related(self, child: Union["Model", List["Model"]]) -> None:
|
||||||
|
if isinstance(child, list):
|
||||||
|
for subchild in child:
|
||||||
|
self._assign_child_to_parent(subchild)
|
||||||
|
else:
|
||||||
|
self._assign_child_to_parent(child)
|
||||||
|
|
||||||
|
async def create_through_instance(self, child: "Model") -> None:
|
||||||
|
queryset = QuerySet(model_cls=self.relation.through)
|
||||||
|
owner_column = self.relation._owner.get_name()
|
||||||
|
child_column = child.get_name()
|
||||||
|
kwargs = {owner_column: self.relation._owner, child_column: child}
|
||||||
|
await queryset.create(**kwargs)
|
||||||
|
|
||||||
|
async def delete_through_instance(self, child: "Model") -> None:
|
||||||
|
queryset = QuerySet(model_cls=self.relation.through)
|
||||||
|
owner_column = self.relation._owner.get_name()
|
||||||
|
child_column = child.get_name()
|
||||||
|
kwargs = {owner_column: self.relation._owner, child_column: child}
|
||||||
|
link_instance = await queryset.filter(**kwargs).get()
|
||||||
|
await link_instance.delete()
|
||||||
|
|
||||||
|
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||||
|
return self.queryset.filter(**kwargs)
|
||||||
|
|
||||||
|
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
|
||||||
|
return self.queryset.select_related(related)
|
||||||
|
|
||||||
|
async def exists(self) -> bool:
|
||||||
|
return await self.queryset.exists()
|
||||||
|
|
||||||
|
async def count(self) -> int:
|
||||||
|
return await self.queryset.count()
|
||||||
|
|
||||||
|
async def clear(self) -> int:
|
||||||
|
queryset = QuerySet(model_cls=self.relation.through)
|
||||||
|
owner_column = self.relation._owner.get_name()
|
||||||
|
kwargs = {owner_column: self.relation._owner}
|
||||||
|
return await queryset.delete(**kwargs)
|
||||||
|
|
||||||
|
def limit(self, limit_count: int) -> "QuerySet":
|
||||||
|
return self.queryset.limit(limit_count)
|
||||||
|
|
||||||
|
def offset(self, offset: int) -> "QuerySet":
|
||||||
|
return self.queryset.offset(offset)
|
||||||
|
|
||||||
|
async def first(self, **kwargs: Any) -> "Model":
|
||||||
|
first = await self.queryset.first(**kwargs)
|
||||||
|
self._register_related(first)
|
||||||
|
return first
|
||||||
|
|
||||||
|
async def get(self, **kwargs: Any) -> "Model":
|
||||||
|
get = await self.queryset.get(**kwargs)
|
||||||
|
self._register_related(get)
|
||||||
|
return get
|
||||||
|
|
||||||
|
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
|
||||||
|
all_items = await self.queryset.all(**kwargs)
|
||||||
|
self._register_related(all_items)
|
||||||
|
return all_items
|
||||||
|
|
||||||
|
async def create(self, **kwargs: Any) -> "Model":
|
||||||
|
create = await self.queryset.create(**kwargs)
|
||||||
|
self._register_related(create)
|
||||||
|
await self.create_through_instance(create)
|
||||||
|
return create
|
||||||
|
|
||||||
|
|
||||||
|
class RelationProxy(list):
|
||||||
|
def __init__(self, relation: "Relation") -> None:
|
||||||
|
super(RelationProxy, self).__init__()
|
||||||
|
self.relation = relation
|
||||||
|
self._owner = self.relation.manager.owner
|
||||||
|
self.queryset_proxy = QuerysetProxy(relation=self.relation)
|
||||||
|
|
||||||
|
def __getattribute__(self, item: str) -> Any:
|
||||||
|
if item in ["count", "clear"]:
|
||||||
|
if not self.queryset_proxy.queryset:
|
||||||
|
self.queryset_proxy.queryset = self._set_queryset()
|
||||||
|
return getattr(self.queryset_proxy, item)
|
||||||
|
return super().__getattribute__(item)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
if not self.queryset_proxy.queryset:
|
||||||
|
self.queryset_proxy.queryset = self._set_queryset()
|
||||||
|
return getattr(self.queryset_proxy, item)
|
||||||
|
|
||||||
|
def _set_queryset(self) -> QuerySet:
|
||||||
|
owner_table = self.relation._owner.Meta.tablename
|
||||||
|
pkname = self.relation._owner.Meta.pkname
|
||||||
|
pk_value = self.relation._owner.pk
|
||||||
|
if not pk_value:
|
||||||
|
raise RelationshipInstanceError(
|
||||||
|
"You cannot query many to many relationship on unsaved model."
|
||||||
|
)
|
||||||
|
kwargs = {f"{owner_table}__{pkname}": pk_value}
|
||||||
|
queryset = (
|
||||||
|
QuerySet(model_cls=self.relation.to)
|
||||||
|
.select_related(owner_table)
|
||||||
|
.filter(**kwargs)
|
||||||
|
)
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
async def remove(self, item: "Model") -> None:
|
||||||
|
super().remove(item)
|
||||||
|
rel_name = item.resolve_relation_name(item, self._owner)
|
||||||
|
item._orm._get(rel_name).remove(self._owner)
|
||||||
|
if self.relation._type == RelationType.MULTIPLE:
|
||||||
|
await self.queryset_proxy.delete_through_instance(item)
|
||||||
|
|
||||||
|
def append(self, item: "Model") -> None:
|
||||||
|
super().append(item)
|
||||||
|
|
||||||
|
async def add(self, item: "Model") -> None:
|
||||||
|
if self.relation._type == RelationType.MULTIPLE:
|
||||||
|
await self.queryset_proxy.create_through_instance(item)
|
||||||
|
rel_name = item.resolve_relation_name(item, self._owner)
|
||||||
|
setattr(item, rel_name, self._owner)
|
||||||
|
|
||||||
|
|
||||||
|
class Relation:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
manager: "RelationsManager",
|
||||||
|
type_: RelationType,
|
||||||
|
to: Type["Model"],
|
||||||
|
through: Type["Model"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.manager = manager
|
||||||
|
self._owner = manager.owner
|
||||||
|
self._type = type_
|
||||||
|
self.to = to
|
||||||
|
self.through = through
|
||||||
|
self.related_models = (
|
||||||
|
RelationProxy(relation=self)
|
||||||
|
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _find_existing(self, child: "Model") -> Optional[int]:
|
||||||
|
for ind, relation_child in enumerate(self.related_models[:]):
|
||||||
|
try:
|
||||||
|
if relation_child.__same__(child):
|
||||||
|
return ind
|
||||||
|
except ReferenceError: # pragma no cover
|
||||||
|
self.related_models.pop(ind)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add(self, child: "Model") -> None:
|
||||||
|
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
||||||
|
if self._type == RelationType.PRIMARY:
|
||||||
|
self.related_models = child
|
||||||
|
self._owner.__dict__[relation_name] = child
|
||||||
|
else:
|
||||||
|
if self._find_existing(child) is None:
|
||||||
|
self.related_models.append(child)
|
||||||
|
rel = self._owner.__dict__.get(relation_name, [])
|
||||||
|
rel = rel or []
|
||||||
|
if not isinstance(rel, list):
|
||||||
|
rel = [rel]
|
||||||
|
rel.append(child)
|
||||||
|
self._owner.__dict__[relation_name] = rel
|
||||||
|
|
||||||
|
def remove(self, child: "Model") -> None:
|
||||||
|
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
||||||
|
if self._type == RelationType.PRIMARY:
|
||||||
|
if self.related_models.__same__(child):
|
||||||
|
self.related_models = None
|
||||||
|
del self._owner.__dict__[relation_name]
|
||||||
|
else:
|
||||||
|
position = self._find_existing(child)
|
||||||
|
if position is not None:
|
||||||
|
self.related_models.pop(position)
|
||||||
|
del self._owner.__dict__[relation_name][position]
|
||||||
|
|
||||||
|
def get(self) -> Union[List["Model"], "Model"]:
|
||||||
|
return self.related_models
|
||||||
|
|
||||||
|
def __repr__(self) -> str: # pragma no cover
|
||||||
|
return str(self.related_models)
|
||||||
|
|
||||||
|
|
||||||
|
class RelationsManager:
|
||||||
|
def __init__(
|
||||||
|
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
|
||||||
|
) -> None:
|
||||||
|
self.owner = proxy(owner)
|
||||||
|
self._related_fields = related_fields or []
|
||||||
|
self._related_names = [field.name for field in self._related_fields]
|
||||||
|
self._relations = dict()
|
||||||
|
for field in self._related_fields:
|
||||||
|
self._add_relation(field)
|
||||||
|
|
||||||
|
def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType:
|
||||||
|
if issubclass(field, ManyToManyField):
|
||||||
|
return RelationType.MULTIPLE
|
||||||
|
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
|
||||||
|
|
||||||
|
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
|
||||||
|
self._relations[field.name] = Relation(
|
||||||
|
manager=self,
|
||||||
|
type_=self._get_relation_type(field),
|
||||||
|
to=field.to,
|
||||||
|
through=getattr(field, "through", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __contains__(self, item: str) -> bool:
|
||||||
|
return item in self._related_names
|
||||||
|
|
||||||
|
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
|
||||||
|
relation = self._relations.get(name, None)
|
||||||
|
if relation is not None:
|
||||||
|
return relation.get()
|
||||||
|
|
||||||
|
def _get(self, name: str) -> Optional[Relation]:
|
||||||
|
relation = self._relations.get(name, None)
|
||||||
|
if relation is not None:
|
||||||
|
return relation
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def register_missing_relation(
|
||||||
|
parent: "Model", child: "Model", child_name: str
|
||||||
|
) -> Relation:
|
||||||
|
ormar.models.expand_reverse_relationships(child.__class__)
|
||||||
|
name = parent.resolve_relation_name(parent, child)
|
||||||
|
field = parent.Meta.model_fields[name]
|
||||||
|
parent._orm._add_relation(field)
|
||||||
|
parent_relation = parent._orm._get(child_name)
|
||||||
|
return parent_relation
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_relations_sides_and_names(
|
||||||
|
to_field: Type[ForeignKeyField],
|
||||||
|
parent: "Model",
|
||||||
|
child: "Model",
|
||||||
|
child_name: str,
|
||||||
|
virtual: bool,
|
||||||
|
) -> Tuple["Model", "Model", str, str]:
|
||||||
|
to_name = to_field.name
|
||||||
|
if issubclass(to_field, ManyToManyField):
|
||||||
|
child_name, to_name = (
|
||||||
|
child.resolve_relation_name(parent, child),
|
||||||
|
child.resolve_relation_name(child, parent),
|
||||||
|
)
|
||||||
|
child = proxy(child)
|
||||||
|
elif virtual:
|
||||||
|
child_name, to_name = to_name, child_name or child.get_name()
|
||||||
|
child, parent = parent, proxy(child)
|
||||||
|
else:
|
||||||
|
child_name = child_name or child.get_name() + "s"
|
||||||
|
child = proxy(child)
|
||||||
|
return parent, child, child_name, to_name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
|
||||||
|
to_field = next(
|
||||||
|
(
|
||||||
|
field
|
||||||
|
for field in child._orm._related_fields
|
||||||
|
if field.to == parent.__class__ or field.to.Meta == parent.Meta
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not to_field: # pragma no cover
|
||||||
|
raise RelationshipInstanceError(
|
||||||
|
f"Model {child.__class__} does not have "
|
||||||
|
f"reference to model {parent.__class__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
parent,
|
||||||
|
child,
|
||||||
|
child_name,
|
||||||
|
to_name,
|
||||||
|
) = RelationsManager.get_relations_sides_and_names(
|
||||||
|
to_field, parent, child, child_name, virtual
|
||||||
|
)
|
||||||
|
|
||||||
|
parent_relation = parent._orm._get(child_name)
|
||||||
|
if not parent_relation:
|
||||||
|
parent_relation = RelationsManager.register_missing_relation(
|
||||||
|
parent, child, child_name
|
||||||
|
)
|
||||||
|
parent_relation.add(child)
|
||||||
|
child._orm._get(to_name).add(parent)
|
||||||
|
|
||||||
|
def remove(self, name: str, child: "Model") -> None:
|
||||||
|
relation = self._get(name)
|
||||||
|
relation.remove(child)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
|
||||||
|
related_model = name
|
||||||
|
name = item.resolve_relation_name(item, related_model)
|
||||||
|
if name in item._orm:
|
||||||
|
relation_name = item.resolve_relation_name(related_model, item)
|
||||||
|
item._orm.remove(name, related_model)
|
||||||
|
related_model._orm.remove(relation_name, item)
|
||||||
@ -1,5 +1,3 @@
|
|||||||
import gc
|
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
@ -179,7 +177,7 @@ async def test_model_removal_from_relations():
|
|||||||
await track3.save()
|
await track3.save()
|
||||||
|
|
||||||
assert len(album.tracks) == 3
|
assert len(album.tracks) == 3
|
||||||
album.tracks.remove(track1)
|
await album.tracks.remove(track1)
|
||||||
assert len(album.tracks) == 2
|
assert len(album.tracks) == 2
|
||||||
assert track1.album is None
|
assert track1.album is None
|
||||||
|
|
||||||
@ -187,7 +185,7 @@ async def test_model_removal_from_relations():
|
|||||||
track1 = await Track.objects.get(title="The Birdman")
|
track1 = await Track.objects.get(title="The Birdman")
|
||||||
assert track1.album is None
|
assert track1.album is None
|
||||||
|
|
||||||
album.tracks.add(track1)
|
await album.tracks.add(track1)
|
||||||
assert len(album.tracks) == 3
|
assert len(album.tracks) == 3
|
||||||
assert track1.album == album
|
assert track1.album == album
|
||||||
|
|
||||||
|
|||||||
178
tests/test_many_to_many.py
Normal file
178
tests/test_many_to_many.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
import databases
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
import ormar
|
||||||
|
from ormar.exceptions import RelationshipInstanceError
|
||||||
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
class Author(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "authors"
|
||||||
|
database = database
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
first_name: ormar.String(max_length=80)
|
||||||
|
last_name: ormar.String(max_length=80)
|
||||||
|
|
||||||
|
|
||||||
|
class Category(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "categories"
|
||||||
|
database = database
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
name: ormar.String(max_length=40)
|
||||||
|
|
||||||
|
|
||||||
|
class PostCategory(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "posts_categories"
|
||||||
|
database = database
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
|
||||||
|
class Post(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "posts"
|
||||||
|
database = database
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
id: ormar.Integer(primary_key=True)
|
||||||
|
title: ormar.String(max_length=200)
|
||||||
|
categories: ormar.ManyToMany(Category, through=PostCategory)
|
||||||
|
author: ormar.ForeignKey(Author)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
|
def create_test_database():
|
||||||
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||||
|
metadata.create_all(engine)
|
||||||
|
yield
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def cleanup():
|
||||||
|
yield
|
||||||
|
await PostCategory.objects.delete()
|
||||||
|
await Post.objects.delete()
|
||||||
|
await Category.objects.delete()
|
||||||
|
await Author.objects.delete()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assigning_related_objects(cleanup):
|
||||||
|
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||||
|
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||||
|
news = await Category.objects.create(name="News")
|
||||||
|
|
||||||
|
# Add a category to a post.
|
||||||
|
await post.categories.add(news)
|
||||||
|
# or from the other end:
|
||||||
|
await news.posts.add(post)
|
||||||
|
|
||||||
|
# Creating related object from instance:
|
||||||
|
await post.categories.create(name="Tips")
|
||||||
|
assert len(post.categories) == 2
|
||||||
|
|
||||||
|
post_categories = await post.categories.all()
|
||||||
|
assert len(post_categories) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_quering_of_the_m2m_models(cleanup):
|
||||||
|
# orm can do this already.
|
||||||
|
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||||
|
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||||
|
news = await Category.objects.create(name="News")
|
||||||
|
# tl;dr: `post.categories` exposes the QuerySet API.
|
||||||
|
|
||||||
|
await post.categories.add(news)
|
||||||
|
|
||||||
|
post_categories = await post.categories.all()
|
||||||
|
assert len(post_categories) == 1
|
||||||
|
|
||||||
|
assert news == await post.categories.get(name="News")
|
||||||
|
|
||||||
|
num_posts = await news.posts.count()
|
||||||
|
assert num_posts == 1
|
||||||
|
|
||||||
|
posts_about_m2m = await news.posts.filter(title__contains="M2M").all()
|
||||||
|
assert len(posts_about_m2m) == 1
|
||||||
|
assert posts_about_m2m[0] == post
|
||||||
|
posts_about_python = await Post.objects.filter(categories__name="python").all()
|
||||||
|
assert len(posts_about_python) == 0
|
||||||
|
|
||||||
|
# Traversal of relationships: which categories has Guido contributed to?
|
||||||
|
category = await Category.objects.filter(posts__author=guido).get()
|
||||||
|
assert category == news
|
||||||
|
# or:
|
||||||
|
category2 = await Category.objects.filter(posts__author__first_name="Guido").get()
|
||||||
|
assert category2 == news
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_removal_of_the_relations(cleanup):
|
||||||
|
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||||
|
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||||
|
news = await Category.objects.create(name="News")
|
||||||
|
await post.categories.add(news)
|
||||||
|
assert len(await post.categories.all()) == 1
|
||||||
|
await post.categories.remove(news)
|
||||||
|
assert len(await post.categories.all()) == 0
|
||||||
|
# or:
|
||||||
|
await news.posts.add(post)
|
||||||
|
assert len(await news.posts.all()) == 1
|
||||||
|
await news.posts.remove(post)
|
||||||
|
assert len(await news.posts.all()) == 0
|
||||||
|
|
||||||
|
# Remove all related objects:
|
||||||
|
await post.categories.add(news)
|
||||||
|
await post.categories.clear()
|
||||||
|
assert len(await post.categories.all()) == 0
|
||||||
|
|
||||||
|
# post would also lose 'news' category when running:
|
||||||
|
await post.categories.add(news)
|
||||||
|
await news.delete()
|
||||||
|
assert len(await post.categories.all()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_selecting_related(cleanup):
|
||||||
|
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||||
|
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||||
|
news = await Category.objects.create(name="News")
|
||||||
|
recent = await Category.objects.create(name="Recent")
|
||||||
|
await post.categories.add(news)
|
||||||
|
await post.categories.add(recent)
|
||||||
|
assert len(await post.categories.all()) == 2
|
||||||
|
# Loads categories and posts (2 queries) and perform the join in Python.
|
||||||
|
categories = await Category.objects.select_related("posts").all()
|
||||||
|
# No extra queries needed => no more `await`s required.
|
||||||
|
for category in categories:
|
||||||
|
assert category.posts[0] == post
|
||||||
|
|
||||||
|
news_posts = await news.posts.select_related("author").all()
|
||||||
|
assert news_posts[0].author == guido
|
||||||
|
|
||||||
|
assert (await post.categories.limit(1).all())[0] == news
|
||||||
|
assert (await post.categories.offset(1).limit(1).all())[0] == recent
|
||||||
|
|
||||||
|
assert await post.categories.first() == news
|
||||||
|
|
||||||
|
assert await post.categories.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_selecting_related_fail_without_saving(cleanup):
|
||||||
|
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
|
||||||
|
post = Post(title="Hello, M2M", author=guido)
|
||||||
|
with pytest.raises(RelationshipInstanceError):
|
||||||
|
await post.categories.all()
|
||||||
Reference in New Issue
Block a user