diff --git a/ormar/fields/base.py b/ormar/fields/base.py index b11a7d7..966992e 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -267,7 +267,7 @@ class BaseField(FieldInfo): return value @classmethod - def set_self_reference_flag(cls): + def set_self_reference_flag(cls) -> None: """ Sets `self_reference` to True if field to and owner are same model. :return: None @@ -301,3 +301,13 @@ class BaseField(FieldInfo): :return: None :rtype: None """ + + @classmethod + def get_related_name(cls) -> str: + """ + Returns name to use for reverse relation. + It's either set as `related_name` or by default it's owner model. get_name + 's' + :return: name of the related_name or default related name. + :rtype: str + """ + return "" # pragma: no cover diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index e141e5c..e5d187d 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -167,6 +167,8 @@ def ForeignKey( # noqa CFQ002 """ owner = kwargs.pop("owner", None) + self_reference = kwargs.pop("self_reference", False) + if isinstance(to, ForwardRef): __type__ = to if not nullable else Optional[to] constraints: List = [] @@ -196,6 +198,7 @@ def ForeignKey( # noqa CFQ002 onupdate=onupdate, ondelete=ondelete, owner=owner, + self_reference=self_reference, ) return type("ForeignKey", (ForeignKeyField, BaseField), namespace) diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index b9b8049..29923d0 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -1,8 +1,7 @@ from typing import Any, ForwardRef, List, Optional, TYPE_CHECKING, Tuple, Type, Union from pydantic.typing import evaluate_forwardref - -import ormar +import ormar # noqa: I100 from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField @@ -71,6 +70,7 @@ def ManyToMany( related_name = kwargs.pop("related_name", None) nullable = kwargs.pop("nullable", True) owner = kwargs.pop("owner", None) + self_reference = kwargs.pop("self_reference", False) if isinstance(to, ForwardRef): __type__ = to if not nullable else Optional[to] @@ -96,6 +96,7 @@ def ManyToMany( default=None, server_default=None, owner=owner, + self_reference=self_reference, ) return type("ManyToMany", (ManyToManyField, BaseField), namespace) diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index 5acbb52..70a558c 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -109,6 +109,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None: virtual=True, related_name=model_field.name, owner=model_field.to, + self_reference=model_field.self_reference, ) # register foreign keys on through model adjust_through_many_to_many_model(model_field=model_field) @@ -119,12 +120,11 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None: virtual=True, related_name=model_field.name, owner=model_field.to, + self_reference=model_field.self_reference, ) -def register_relation_in_alias_manager( - field: Type[ForeignKeyField], field_name: str -) -> None: +def register_relation_in_alias_manager(field: Type[ForeignKeyField]) -> None: """ Registers the relation (and reverse relation) in alias manager. The m2m relations require registration of through model between @@ -136,8 +136,6 @@ def register_relation_in_alias_manager( :param field: relation field :type field: ForeignKey or ManyToManyField class - :param field_name: name of the relation key - :type field_name: str """ if issubclass(field, ManyToManyField): if field.has_unresolved_forward_refs(): diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 42be22c..d6dc1c5 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -587,9 +587,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): populate_meta_sqlalchemy_table_if_required(new_model.Meta) expand_reverse_relationships(new_model) for field_name, field in new_model.Meta.model_fields.items(): - register_relation_in_alias_manager( - field=field, field_name=field_name - ) + register_relation_in_alias_manager(field=field) if new_model.Meta.pkname not in attrs["__annotations__"]: field_name = new_model.Meta.pkname diff --git a/ormar/models/mixins/prefetch_mixin.py b/ormar/models/mixins/prefetch_mixin.py index 04a11c8..273dd01 100644 --- a/ormar/models/mixins/prefetch_mixin.py +++ b/ormar/models/mixins/prefetch_mixin.py @@ -1,7 +1,7 @@ from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type import ormar -from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField from ormar.models.mixins.relation_mixin import RelationMixin @@ -37,10 +37,7 @@ class PrefetchQueryMixin(RelationMixin): :rtype: Tuple[Type[Model], str] """ if reverse: - field_name = ( - parent_model.Meta.model_fields[related].related_name - or parent_model.get_name() + "s" - ) + field_name = parent_model.Meta.model_fields[related].get_related_name() field = target_model.Meta.model_fields[field_name] if issubclass(field, ormar.fields.ManyToManyField): field_name = field.default_target_field_name() @@ -79,7 +76,7 @@ class PrefetchQueryMixin(RelationMixin): return column.get_alias() if use_raw else column.name @classmethod - def get_related_field_name(cls, target_field: Type["BaseField"]) -> str: + def get_related_field_name(cls, target_field: Type["ForeignKeyField"]) -> str: """ Returns name of the relation field that should be used in prefetch query. This field is later used to register relation in prefetch query, @@ -93,7 +90,7 @@ class PrefetchQueryMixin(RelationMixin): if issubclass(target_field, ormar.fields.ManyToManyField): return cls.get_name() if target_field.virtual: - return target_field.related_name or cls.get_name() + "s" + return target_field.get_related_name() return target_field.to.Meta.pkname @classmethod diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index ed558e0..c062d7a 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -30,7 +30,7 @@ import sqlalchemy from pydantic import BaseModel import ormar # noqa I100 -from ormar.exceptions import ModelError +from ormar.exceptions import ModelError, ModelPersistenceError from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField from ormar.models.helpers import register_relation_in_alias_manager @@ -452,9 +452,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass field.evaluate_forward_ref(globalns=globalns, localns=localns) field.set_self_reference_flag() expand_reverse_relationship(model_field=field) - register_relation_in_alias_manager( - field=field, field_name=field_name, - ) + register_relation_in_alias_manager(field=field) update_column_definition(model=cls, field=field) populate_meta_sqlalchemy_table_if_required(meta=cls.Meta) super().update_forward_refs(**localns) @@ -731,9 +729,15 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass if self.get_column_alias(k) in self.Meta.table.columns } for field in self._extract_db_related_names(): - target_pk_name = self.Meta.model_fields[field].to.Meta.pkname + relation_field = self.Meta.model_fields[field] + target_pk_name = relation_field.to.Meta.pkname target_field = getattr(self, field) self_fields[field] = getattr(target_field, target_pk_name, None) + if not relation_field.nullable and not self_fields[field]: + raise ModelPersistenceError( + f"You cannot save {relation_field.to.get_name()} " + f"model without pk set!" + ) return self_fields def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]: diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index d7b2686..0991562 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -135,7 +135,7 @@ class SqlJoin: join_parameters.model_cls.Meta.model_fields[part], ManyToManyField ): _fields = join_parameters.model_cls.Meta.model_fields - new_part = _fields[part].to.get_name() + new_part = _fields[part].default_target_field_name() self._switch_many_to_many_order_columns(part, new_part) if index > 0: # nested joins fields, exclude_fields = SqlJoin.update_inclusions( @@ -435,11 +435,9 @@ class SqlJoin: from_key = join_params.prev_model.get_column_alias( join_params.prev_model.Meta.pkname ) + breakpoint() elif join_params.prev_model.Meta.model_fields[part].virtual: - to_field = ( - join_params.prev_model.Meta.model_fields[part].related_name - or join_params.prev_model.get_name() + "s" - ) + to_field = join_params.prev_model.Meta.model_fields[part].get_related_name() to_key = model_cls.get_column_alias(to_field) from_key = join_params.prev_model.get_column_alias( join_params.prev_model.Meta.pkname diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index e0574a4..dda3baa 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -9,10 +9,12 @@ from typing import ( Tuple, Type, Union, + cast, ) import ormar from ormar.fields import BaseField, ManyToManyField +from ormar.fields.foreign_key import ForeignKeyField from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict @@ -314,6 +316,7 @@ class PrefetchQuery: for related in related_to_extract: target_field = model.Meta.model_fields[related] + target_field = cast(Type[ForeignKeyField], target_field) target_model = target_field.to.get_name() model_id = model.get_relation_model_id(target_field=target_field) @@ -421,6 +424,7 @@ class PrefetchQuery: fields = target_model.get_included(fields, related) exclude_fields = target_model.get_excluded(exclude_fields, related) target_field = target_model.Meta.model_fields[related] + target_field = cast(Type[ForeignKeyField], target_field) reverse = False if target_field.virtual or issubclass(target_field, ManyToManyField): reverse = True @@ -585,7 +589,7 @@ class PrefetchQuery: def _populate_rows( # noqa: CFQ002 self, rows: List, - target_field: Type["BaseField"], + target_field: Type["ForeignKeyField"], parent_model: Type["Model"], table_prefix: str, fields: Union[Set[Any], Dict[Any, Any], None], diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 4940265..11387b5 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -197,7 +197,7 @@ class QuerySet: limit_raw_sql=self.limit_sql_raw, ) exp = qry.build_select_expression() - # print("\n", exp.compile(compile_kwargs={"literal_binds": True})) + print("\n", exp.compile(compile_kwargs={"literal_binds": True})) return exp def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003 diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 841071a..7003c7a 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -39,10 +39,9 @@ class QuerysetProxy(ormar.QuerySetProtocol): self._queryset: Optional["QuerySet"] = qryset self.type_: "RelationType" = type_ self._owner: "Model" = self.relation.manager.owner - self.related_field_name = ( - self._owner.Meta.model_fields[self.relation.field_name].related_name - or self._owner.get_name() + "s" - ) + self.related_field_name = self._owner.Meta.model_fields[ + self.relation.field_name + ].get_related_name() self.related_field = self.relation.to.Meta.model_fields[self.related_field_name] self.owner_pk_value = self._owner.pk @@ -108,8 +107,8 @@ class QuerysetProxy(ormar.QuerySetProtocol): :type child: Model """ model_cls = self.relation.through - owner_column = self._owner.get_name() - child_column = child.get_name() + owner_column = self.related_field.default_target_field_name() + child_column = self.related_field.default_source_field_name() kwargs = {owner_column: self._owner.pk, child_column: child.pk} if child.pk is None: raise ModelPersistenceError( @@ -129,8 +128,8 @@ class QuerysetProxy(ormar.QuerySetProtocol): :type child: Model """ queryset = ormar.QuerySet(model_cls=self.relation.through) - owner_column = self._owner.get_name() - child_column = child.get_name() + owner_column = self.related_field.default_target_field_name() + child_column = self.related_field.default_source_field_name() kwargs = {owner_column: self._owner, child_column: child} link_instance = await queryset.filter(**kwargs).get() # type: ignore await link_instance.delete() diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py index 2e2b733..511dd7b 100644 --- a/ormar/relations/relation_manager.py +++ b/ormar/relations/relation_manager.py @@ -164,8 +164,6 @@ class RelationsManager: :param name: name of the relation :type name: str """ - relation_name = ( - item.Meta.model_fields[name].related_name or item.get_name() + "s" - ) + relation_name = item.Meta.model_fields[name].get_related_name() item._orm.remove(name, parent) parent._orm.remove(relation_name, item) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 206db7e..1012155 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -42,9 +42,8 @@ class RelationProxy(list): if self._related_field_name: return self._related_field_name owner_field = self._owner.Meta.model_fields[self.field_name] - self._related_field_name = ( - owner_field.related_name or self._owner.get_name() + "s" - ) + self._related_field_name = owner_field.get_related_name() + return self._related_field_name def __getattribute__(self, item: str) -> Any: diff --git a/tests/test_forward_refs.py b/tests/test_forward_refs.py index 7a45fca..e0e5303 100644 --- a/tests/test_forward_refs.py +++ b/tests/test_forward_refs.py @@ -35,34 +35,34 @@ Game = ForwardRef("Game") Child = ForwardRef("Child") -class ChildFriends(ormar.Model): +class ChildFriend(ormar.Model): class Meta(ModelMeta): metadata = metadata database = db -# class Child(ormar.Model): -# class Meta(ModelMeta): -# metadata = metadata -# database = db -# -# id: int = ormar.Integer(primary_key=True) -# name: str = ormar.String(max_length=100) -# favourite_game: Game = ormar.ForeignKey(Game, related_name="liked_by") -# least_favourite_game: Game = ormar.ForeignKey(Game, related_name="not_liked_by") -# friends: List[Child] = ormar.ManyToMany(Child, through=ChildFriends) -# -# -# class Game(ormar.Model): -# class Meta(ModelMeta): -# metadata = metadata -# database = db -# -# id: int = ormar.Integer(primary_key=True) -# name: str = ormar.String(max_length=100) +class Child(ormar.Model): + class Meta(ModelMeta): + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + favourite_game: Game = ormar.ForeignKey(Game, related_name="liked_by") + least_favourite_game: Game = ormar.ForeignKey(Game, related_name="not_liked_by") + friends = ormar.ManyToMany(Child, through=ChildFriend, related_name="also_friends") -# Child.update_forward_refs() +class Game(ormar.Model): + class Meta(ModelMeta): + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +Child.update_forward_refs() @pytest.fixture(autouse=True, scope="module") @@ -125,22 +125,56 @@ async def test_self_relation(): assert sam_check.employees[0].name == "Joe" -# @pytest.mark.asyncio -# async def test_other_forwardref_relation(): -# checkers = await Game.objects.create(name="checkers") -# uno = await Game(name="Uno").save() -# -# await Child(name="Billy", favourite_game=uno, least_favourite_game=checkers).save() -# await Child(name="Kate", favourite_game=checkers, least_favourite_game=uno).save() -# -# billy_check = await Child.objects.select_related( -# ["favourite_game", "least_favourite_game"] -# ).get(name="Billy") -# assert billy_check.favourite_game == uno -# assert billy_check.least_favourite_game == checkers -# -# uno_check = await Game.objects.select_related(["liked_by", "not_liked_by"]).get( -# name="Uno" -# ) -# assert uno_check.liked_by[0].name == "Billy" -# assert uno_check.not_liked_by[0].name == "Kate" +@pytest.mark.asyncio +async def test_other_forwardref_relation(): + checkers = await Game.objects.create(name="checkers") + uno = await Game(name="Uno").save() + + await Child(name="Billy", favourite_game=uno, least_favourite_game=checkers).save() + await Child(name="Kate", favourite_game=checkers, least_favourite_game=uno).save() + + billy_check = await Child.objects.select_related( + ["favourite_game", "least_favourite_game"] + ).get(name="Billy") + assert billy_check.favourite_game == uno + assert billy_check.least_favourite_game == checkers + + uno_check = await Game.objects.select_related(["liked_by", "not_liked_by"]).get( + name="Uno" + ) + assert uno_check.liked_by[0].name == "Billy" + assert uno_check.not_liked_by[0].name == "Kate" + + +@pytest.mark.asyncio +async def test_m2m_self_forwardref_relation(): + checkers = await Game.objects.create(name="checkers") + uno = await Game(name="Uno").save() + jenga = await Game(name="Jenga").save() + + billy = await Child( + name="Billy", favourite_game=uno, least_favourite_game=checkers + ).save() + kate = await Child( + name="Kate", favourite_game=checkers, least_favourite_game=uno + ).save() + steve = await Child( + name="Steve", favourite_game=jenga, least_favourite_game=uno + ).save() + + await billy.friends.add(kate) + await billy.friends.add(steve) + + await steve.friends.add(kate) + await steve.friends.add(billy) + + billy_check = await Child.objects.select_related( + [ + "friends", + "favourite_game", + "least_favourite_game", + "friends__favourite_game", + "friends__least_favourite_game", + ] + ).get(name="Billy") + assert len(billy_check.friends) == 2 diff --git a/tests/test_many_to_many.py b/tests/test_many_to_many.py index 8d8b258..8b10eae 100644 --- a/tests/test_many_to_many.py +++ b/tests/test_many_to_many.py @@ -80,6 +80,17 @@ async def cleanup(): await Author.objects.delete(each=True) +@pytest.mark.asyncio +async def test_not_saved_raises_error(cleanup): + async with database: + guido = await Author(first_name="Guido", last_name="Van Rossum").save() + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = Category(name="News") + + with pytest.raises(ModelPersistenceError): + await post.categories.add(news) + + @pytest.mark.asyncio async def test_assigning_related_objects(cleanup): async with database: diff --git a/tests/test_saving_related.py b/tests/test_saving_related.py index 6dd4fd2..388e290 100644 --- a/tests/test_saving_related.py +++ b/tests/test_saving_related.py @@ -6,6 +6,7 @@ import sqlalchemy as sa from sqlalchemy import create_engine import ormar +from ormar.exceptions import ModelPersistenceError from tests.settings import DATABASE_URL metadata = sa.MetaData() @@ -61,3 +62,15 @@ async def test_model_relationship(): assert ws.id == 1 assert ws.topic == "Topic 2" assert ws.category.name == "Foo" + + +@pytest.mark.asyncio +async def test_model_relationship_with_not_saved(): + async with db: + async with db.transaction(force_rollback=True): + cat = Category(name="Foo", code=123) + with pytest.raises(ModelPersistenceError): + await Workshop(topic="Topic 1", category=cat).save() + + with pytest.raises(ModelPersistenceError): + await Workshop.objects.create(topic="Topic 1", category=cat)