WIP changes up to join redefinition pending - use fields instead of join_params
This commit is contained in:
@ -267,7 +267,7 @@ class BaseField(FieldInfo):
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def set_self_reference_flag(cls):
|
||||
def set_self_reference_flag(cls) -> None:
|
||||
"""
|
||||
Sets `self_reference` to True if field to and owner are same model.
|
||||
:return: None
|
||||
@ -301,3 +301,13 @@ class BaseField(FieldInfo):
|
||||
:return: None
|
||||
:rtype: None
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_related_name(cls) -> str:
|
||||
"""
|
||||
Returns name to use for reverse relation.
|
||||
It's either set as `related_name` or by default it's owner model. get_name + 's'
|
||||
:return: name of the related_name or default related name.
|
||||
:rtype: str
|
||||
"""
|
||||
return "" # pragma: no cover
|
||||
|
||||
@ -167,6 +167,8 @@ def ForeignKey( # noqa CFQ002
|
||||
"""
|
||||
|
||||
owner = kwargs.pop("owner", None)
|
||||
self_reference = kwargs.pop("self_reference", False)
|
||||
|
||||
if isinstance(to, ForwardRef):
|
||||
__type__ = to if not nullable else Optional[to]
|
||||
constraints: List = []
|
||||
@ -196,6 +198,7 @@ def ForeignKey( # noqa CFQ002
|
||||
onupdate=onupdate,
|
||||
ondelete=ondelete,
|
||||
owner=owner,
|
||||
self_reference=self_reference,
|
||||
)
|
||||
|
||||
return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from typing import Any, ForwardRef, List, Optional, TYPE_CHECKING, Tuple, Type, Union
|
||||
|
||||
from pydantic.typing import evaluate_forwardref
|
||||
|
||||
import ormar
|
||||
import ormar # noqa: I100
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
|
||||
@ -71,6 +70,7 @@ def ManyToMany(
|
||||
related_name = kwargs.pop("related_name", None)
|
||||
nullable = kwargs.pop("nullable", True)
|
||||
owner = kwargs.pop("owner", None)
|
||||
self_reference = kwargs.pop("self_reference", False)
|
||||
|
||||
if isinstance(to, ForwardRef):
|
||||
__type__ = to if not nullable else Optional[to]
|
||||
@ -96,6 +96,7 @@ def ManyToMany(
|
||||
default=None,
|
||||
server_default=None,
|
||||
owner=owner,
|
||||
self_reference=self_reference,
|
||||
)
|
||||
|
||||
return type("ManyToMany", (ManyToManyField, BaseField), namespace)
|
||||
|
||||
@ -109,6 +109,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
|
||||
virtual=True,
|
||||
related_name=model_field.name,
|
||||
owner=model_field.to,
|
||||
self_reference=model_field.self_reference,
|
||||
)
|
||||
# register foreign keys on through model
|
||||
adjust_through_many_to_many_model(model_field=model_field)
|
||||
@ -119,12 +120,11 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
|
||||
virtual=True,
|
||||
related_name=model_field.name,
|
||||
owner=model_field.to,
|
||||
self_reference=model_field.self_reference,
|
||||
)
|
||||
|
||||
|
||||
def register_relation_in_alias_manager(
|
||||
field: Type[ForeignKeyField], field_name: str
|
||||
) -> None:
|
||||
def register_relation_in_alias_manager(field: Type[ForeignKeyField]) -> None:
|
||||
"""
|
||||
Registers the relation (and reverse relation) in alias manager.
|
||||
The m2m relations require registration of through model between
|
||||
@ -136,8 +136,6 @@ def register_relation_in_alias_manager(
|
||||
|
||||
:param field: relation field
|
||||
:type field: ForeignKey or ManyToManyField class
|
||||
:param field_name: name of the relation key
|
||||
:type field_name: str
|
||||
"""
|
||||
if issubclass(field, ManyToManyField):
|
||||
if field.has_unresolved_forward_refs():
|
||||
|
||||
@ -587,9 +587,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
populate_meta_sqlalchemy_table_if_required(new_model.Meta)
|
||||
expand_reverse_relationships(new_model)
|
||||
for field_name, field in new_model.Meta.model_fields.items():
|
||||
register_relation_in_alias_manager(
|
||||
field=field, field_name=field_name
|
||||
)
|
||||
register_relation_in_alias_manager(field=field)
|
||||
|
||||
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
||||
field_name = new_model.Meta.pkname
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type
|
||||
|
||||
import ormar
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.models.mixins.relation_mixin import RelationMixin
|
||||
|
||||
|
||||
@ -37,10 +37,7 @@ class PrefetchQueryMixin(RelationMixin):
|
||||
:rtype: Tuple[Type[Model], str]
|
||||
"""
|
||||
if reverse:
|
||||
field_name = (
|
||||
parent_model.Meta.model_fields[related].related_name
|
||||
or parent_model.get_name() + "s"
|
||||
)
|
||||
field_name = parent_model.Meta.model_fields[related].get_related_name()
|
||||
field = target_model.Meta.model_fields[field_name]
|
||||
if issubclass(field, ormar.fields.ManyToManyField):
|
||||
field_name = field.default_target_field_name()
|
||||
@ -79,7 +76,7 @@ class PrefetchQueryMixin(RelationMixin):
|
||||
return column.get_alias() if use_raw else column.name
|
||||
|
||||
@classmethod
|
||||
def get_related_field_name(cls, target_field: Type["BaseField"]) -> str:
|
||||
def get_related_field_name(cls, target_field: Type["ForeignKeyField"]) -> str:
|
||||
"""
|
||||
Returns name of the relation field that should be used in prefetch query.
|
||||
This field is later used to register relation in prefetch query,
|
||||
@ -93,7 +90,7 @@ class PrefetchQueryMixin(RelationMixin):
|
||||
if issubclass(target_field, ormar.fields.ManyToManyField):
|
||||
return cls.get_name()
|
||||
if target_field.virtual:
|
||||
return target_field.related_name or cls.get_name() + "s"
|
||||
return target_field.get_related_name()
|
||||
return target_field.to.Meta.pkname
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -30,7 +30,7 @@ import sqlalchemy
|
||||
from pydantic import BaseModel
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar.exceptions import ModelError
|
||||
from ormar.exceptions import ModelError, ModelPersistenceError
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.models.helpers import register_relation_in_alias_manager
|
||||
@ -452,9 +452,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
field.evaluate_forward_ref(globalns=globalns, localns=localns)
|
||||
field.set_self_reference_flag()
|
||||
expand_reverse_relationship(model_field=field)
|
||||
register_relation_in_alias_manager(
|
||||
field=field, field_name=field_name,
|
||||
)
|
||||
register_relation_in_alias_manager(field=field)
|
||||
update_column_definition(model=cls, field=field)
|
||||
populate_meta_sqlalchemy_table_if_required(meta=cls.Meta)
|
||||
super().update_forward_refs(**localns)
|
||||
@ -731,9 +729,15 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
if self.get_column_alias(k) in self.Meta.table.columns
|
||||
}
|
||||
for field in self._extract_db_related_names():
|
||||
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
|
||||
relation_field = self.Meta.model_fields[field]
|
||||
target_pk_name = relation_field.to.Meta.pkname
|
||||
target_field = getattr(self, field)
|
||||
self_fields[field] = getattr(target_field, target_pk_name, None)
|
||||
if not relation_field.nullable and not self_fields[field]:
|
||||
raise ModelPersistenceError(
|
||||
f"You cannot save {relation_field.to.get_name()} "
|
||||
f"model without pk set!"
|
||||
)
|
||||
return self_fields
|
||||
|
||||
def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]:
|
||||
|
||||
@ -135,7 +135,7 @@ class SqlJoin:
|
||||
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
||||
):
|
||||
_fields = join_parameters.model_cls.Meta.model_fields
|
||||
new_part = _fields[part].to.get_name()
|
||||
new_part = _fields[part].default_target_field_name()
|
||||
self._switch_many_to_many_order_columns(part, new_part)
|
||||
if index > 0: # nested joins
|
||||
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||
@ -435,11 +435,9 @@ class SqlJoin:
|
||||
from_key = join_params.prev_model.get_column_alias(
|
||||
join_params.prev_model.Meta.pkname
|
||||
)
|
||||
breakpoint()
|
||||
elif join_params.prev_model.Meta.model_fields[part].virtual:
|
||||
to_field = (
|
||||
join_params.prev_model.Meta.model_fields[part].related_name
|
||||
or join_params.prev_model.get_name() + "s"
|
||||
)
|
||||
to_field = join_params.prev_model.Meta.model_fields[part].get_related_name()
|
||||
to_key = model_cls.get_column_alias(to_field)
|
||||
from_key = join_params.prev_model.get_column_alias(
|
||||
join_params.prev_model.Meta.pkname
|
||||
|
||||
@ -9,10 +9,12 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import ormar
|
||||
from ormar.fields import BaseField, ManyToManyField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.queryset.clause import QueryClause
|
||||
from ormar.queryset.query import Query
|
||||
from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict
|
||||
@ -314,6 +316,7 @@ class PrefetchQuery:
|
||||
|
||||
for related in related_to_extract:
|
||||
target_field = model.Meta.model_fields[related]
|
||||
target_field = cast(Type[ForeignKeyField], target_field)
|
||||
target_model = target_field.to.get_name()
|
||||
model_id = model.get_relation_model_id(target_field=target_field)
|
||||
|
||||
@ -421,6 +424,7 @@ class PrefetchQuery:
|
||||
fields = target_model.get_included(fields, related)
|
||||
exclude_fields = target_model.get_excluded(exclude_fields, related)
|
||||
target_field = target_model.Meta.model_fields[related]
|
||||
target_field = cast(Type[ForeignKeyField], target_field)
|
||||
reverse = False
|
||||
if target_field.virtual or issubclass(target_field, ManyToManyField):
|
||||
reverse = True
|
||||
@ -585,7 +589,7 @@ class PrefetchQuery:
|
||||
def _populate_rows( # noqa: CFQ002
|
||||
self,
|
||||
rows: List,
|
||||
target_field: Type["BaseField"],
|
||||
target_field: Type["ForeignKeyField"],
|
||||
parent_model: Type["Model"],
|
||||
table_prefix: str,
|
||||
fields: Union[Set[Any], Dict[Any, Any], None],
|
||||
|
||||
@ -197,7 +197,7 @@ class QuerySet:
|
||||
limit_raw_sql=self.limit_sql_raw,
|
||||
)
|
||||
exp = qry.build_select_expression()
|
||||
# print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
|
||||
print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
|
||||
return exp
|
||||
|
||||
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||
|
||||
@ -39,10 +39,9 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
self._queryset: Optional["QuerySet"] = qryset
|
||||
self.type_: "RelationType" = type_
|
||||
self._owner: "Model" = self.relation.manager.owner
|
||||
self.related_field_name = (
|
||||
self._owner.Meta.model_fields[self.relation.field_name].related_name
|
||||
or self._owner.get_name() + "s"
|
||||
)
|
||||
self.related_field_name = self._owner.Meta.model_fields[
|
||||
self.relation.field_name
|
||||
].get_related_name()
|
||||
self.related_field = self.relation.to.Meta.model_fields[self.related_field_name]
|
||||
self.owner_pk_value = self._owner.pk
|
||||
|
||||
@ -108,8 +107,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
:type child: Model
|
||||
"""
|
||||
model_cls = self.relation.through
|
||||
owner_column = self._owner.get_name()
|
||||
child_column = child.get_name()
|
||||
owner_column = self.related_field.default_target_field_name()
|
||||
child_column = self.related_field.default_source_field_name()
|
||||
kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
||||
if child.pk is None:
|
||||
raise ModelPersistenceError(
|
||||
@ -129,8 +128,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
|
||||
:type child: Model
|
||||
"""
|
||||
queryset = ormar.QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self._owner.get_name()
|
||||
child_column = child.get_name()
|
||||
owner_column = self.related_field.default_target_field_name()
|
||||
child_column = self.related_field.default_source_field_name()
|
||||
kwargs = {owner_column: self._owner, child_column: child}
|
||||
link_instance = await queryset.filter(**kwargs).get() # type: ignore
|
||||
await link_instance.delete()
|
||||
|
||||
@ -164,8 +164,6 @@ class RelationsManager:
|
||||
:param name: name of the relation
|
||||
:type name: str
|
||||
"""
|
||||
relation_name = (
|
||||
item.Meta.model_fields[name].related_name or item.get_name() + "s"
|
||||
)
|
||||
relation_name = item.Meta.model_fields[name].get_related_name()
|
||||
item._orm.remove(name, parent)
|
||||
parent._orm.remove(relation_name, item)
|
||||
|
||||
@ -42,9 +42,8 @@ class RelationProxy(list):
|
||||
if self._related_field_name:
|
||||
return self._related_field_name
|
||||
owner_field = self._owner.Meta.model_fields[self.field_name]
|
||||
self._related_field_name = (
|
||||
owner_field.related_name or self._owner.get_name() + "s"
|
||||
)
|
||||
self._related_field_name = owner_field.get_related_name()
|
||||
|
||||
return self._related_field_name
|
||||
|
||||
def __getattribute__(self, item: str) -> Any:
|
||||
|
||||
@ -35,34 +35,34 @@ Game = ForwardRef("Game")
|
||||
Child = ForwardRef("Child")
|
||||
|
||||
|
||||
class ChildFriends(ormar.Model):
|
||||
class ChildFriend(ormar.Model):
|
||||
class Meta(ModelMeta):
|
||||
metadata = metadata
|
||||
database = db
|
||||
|
||||
|
||||
# class Child(ormar.Model):
|
||||
# class Meta(ModelMeta):
|
||||
# metadata = metadata
|
||||
# database = db
|
||||
#
|
||||
# id: int = ormar.Integer(primary_key=True)
|
||||
# name: str = ormar.String(max_length=100)
|
||||
# favourite_game: Game = ormar.ForeignKey(Game, related_name="liked_by")
|
||||
# least_favourite_game: Game = ormar.ForeignKey(Game, related_name="not_liked_by")
|
||||
# friends: List[Child] = ormar.ManyToMany(Child, through=ChildFriends)
|
||||
#
|
||||
#
|
||||
# class Game(ormar.Model):
|
||||
# class Meta(ModelMeta):
|
||||
# metadata = metadata
|
||||
# database = db
|
||||
#
|
||||
# id: int = ormar.Integer(primary_key=True)
|
||||
# name: str = ormar.String(max_length=100)
|
||||
class Child(ormar.Model):
|
||||
class Meta(ModelMeta):
|
||||
metadata = metadata
|
||||
database = db
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100)
|
||||
favourite_game: Game = ormar.ForeignKey(Game, related_name="liked_by")
|
||||
least_favourite_game: Game = ormar.ForeignKey(Game, related_name="not_liked_by")
|
||||
friends = ormar.ManyToMany(Child, through=ChildFriend, related_name="also_friends")
|
||||
|
||||
|
||||
# Child.update_forward_refs()
|
||||
class Game(ormar.Model):
|
||||
class Meta(ModelMeta):
|
||||
metadata = metadata
|
||||
database = db
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100)
|
||||
|
||||
|
||||
Child.update_forward_refs()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
@ -125,22 +125,56 @@ async def test_self_relation():
|
||||
assert sam_check.employees[0].name == "Joe"
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_other_forwardref_relation():
|
||||
# checkers = await Game.objects.create(name="checkers")
|
||||
# uno = await Game(name="Uno").save()
|
||||
#
|
||||
# await Child(name="Billy", favourite_game=uno, least_favourite_game=checkers).save()
|
||||
# await Child(name="Kate", favourite_game=checkers, least_favourite_game=uno).save()
|
||||
#
|
||||
# billy_check = await Child.objects.select_related(
|
||||
# ["favourite_game", "least_favourite_game"]
|
||||
# ).get(name="Billy")
|
||||
# assert billy_check.favourite_game == uno
|
||||
# assert billy_check.least_favourite_game == checkers
|
||||
#
|
||||
# uno_check = await Game.objects.select_related(["liked_by", "not_liked_by"]).get(
|
||||
# name="Uno"
|
||||
# )
|
||||
# assert uno_check.liked_by[0].name == "Billy"
|
||||
# assert uno_check.not_liked_by[0].name == "Kate"
|
||||
@pytest.mark.asyncio
|
||||
async def test_other_forwardref_relation():
|
||||
checkers = await Game.objects.create(name="checkers")
|
||||
uno = await Game(name="Uno").save()
|
||||
|
||||
await Child(name="Billy", favourite_game=uno, least_favourite_game=checkers).save()
|
||||
await Child(name="Kate", favourite_game=checkers, least_favourite_game=uno).save()
|
||||
|
||||
billy_check = await Child.objects.select_related(
|
||||
["favourite_game", "least_favourite_game"]
|
||||
).get(name="Billy")
|
||||
assert billy_check.favourite_game == uno
|
||||
assert billy_check.least_favourite_game == checkers
|
||||
|
||||
uno_check = await Game.objects.select_related(["liked_by", "not_liked_by"]).get(
|
||||
name="Uno"
|
||||
)
|
||||
assert uno_check.liked_by[0].name == "Billy"
|
||||
assert uno_check.not_liked_by[0].name == "Kate"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_m2m_self_forwardref_relation():
|
||||
checkers = await Game.objects.create(name="checkers")
|
||||
uno = await Game(name="Uno").save()
|
||||
jenga = await Game(name="Jenga").save()
|
||||
|
||||
billy = await Child(
|
||||
name="Billy", favourite_game=uno, least_favourite_game=checkers
|
||||
).save()
|
||||
kate = await Child(
|
||||
name="Kate", favourite_game=checkers, least_favourite_game=uno
|
||||
).save()
|
||||
steve = await Child(
|
||||
name="Steve", favourite_game=jenga, least_favourite_game=uno
|
||||
).save()
|
||||
|
||||
await billy.friends.add(kate)
|
||||
await billy.friends.add(steve)
|
||||
|
||||
await steve.friends.add(kate)
|
||||
await steve.friends.add(billy)
|
||||
|
||||
billy_check = await Child.objects.select_related(
|
||||
[
|
||||
"friends",
|
||||
"favourite_game",
|
||||
"least_favourite_game",
|
||||
"friends__favourite_game",
|
||||
"friends__least_favourite_game",
|
||||
]
|
||||
).get(name="Billy")
|
||||
assert len(billy_check.friends) == 2
|
||||
|
||||
@ -80,6 +80,17 @@ async def cleanup():
|
||||
await Author.objects.delete(each=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_saved_raises_error(cleanup):
|
||||
async with database:
|
||||
guido = await Author(first_name="Guido", last_name="Van Rossum").save()
|
||||
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||
news = Category(name="News")
|
||||
|
||||
with pytest.raises(ModelPersistenceError):
|
||||
await post.categories.add(news)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigning_related_objects(cleanup):
|
||||
async with database:
|
||||
|
||||
@ -6,6 +6,7 @@ import sqlalchemy as sa
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import ModelPersistenceError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
metadata = sa.MetaData()
|
||||
@ -61,3 +62,15 @@ async def test_model_relationship():
|
||||
assert ws.id == 1
|
||||
assert ws.topic == "Topic 2"
|
||||
assert ws.category.name == "Foo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_relationship_with_not_saved():
|
||||
async with db:
|
||||
async with db.transaction(force_rollback=True):
|
||||
cat = Category(name="Foo", code=123)
|
||||
with pytest.raises(ModelPersistenceError):
|
||||
await Workshop(topic="Topic 1", category=cat).save()
|
||||
|
||||
with pytest.raises(ModelPersistenceError):
|
||||
await Workshop.objects.create(topic="Topic 1", category=cat)
|
||||
|
||||
Reference in New Issue
Block a user