diff --git a/ormar/protocols/queryset_protocol.py b/ormar/protocols/queryset_protocol.py index 397f58b..e2ba329 100644 --- a/ormar/protocols/queryset_protocol.py +++ b/ormar/protocols/queryset_protocol.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Tuple, Union try: from typing import Protocol @@ -55,7 +55,11 @@ class QuerySetProtocol(Protocol): # pragma: nocover async def update(self, each: bool = False, **kwargs: Any) -> int: ... - async def get_or_create(self, **kwargs: Any) -> "Model": + async def get_or_create( + self, + _defaults: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Tuple["Model", bool]: ... async def update_or_create(self, **kwargs: Any) -> "Model": diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index beae8d7..95fc97c 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -972,26 +972,36 @@ class QuerySet(Generic[T]): self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore - async def get_or_create(self, *args: Any, **kwargs: Any) -> "T": + async def get_or_create( + self, + _defaults: Optional[Dict[str, Any]] = None, + *args: Any, + **kwargs: Any, + ) -> Tuple["T", bool]: """ Combination of create and get methods. Tries to get a row meeting the criteria for kwargs and if `NoMatch` exception is raised - it creates a new one with given kwargs. + it creates a new one with given kwargs and _defaults. Passing a criteria is actually calling filter(*args, **kwargs) method described below. :param kwargs: fields names and proper value types :type kwargs: Any - :return: returned or created Model - :rtype: Model + :param _defaults: default values for creating object + :type _defaults: Optional[Dict[str, Any]] + :return: model instance and a boolean + :rtype: Tuple("T", bool) """ try: - return await self.get(*args, **kwargs) + return await self.get(*args, **kwargs), False except NoMatch: - return await self.create(**kwargs) + if _defaults is None: + return await self.create(**kwargs), True + else: + return await self.create(**kwargs, **_defaults), True async def update_or_create(self, **kwargs: Any) -> "T": """ diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 8170cf3..ec3f9fd 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -11,6 +11,7 @@ from typing import ( # noqa: I100, I201 TYPE_CHECKING, Type, TypeVar, + Tuple, Union, cast, ) @@ -483,23 +484,33 @@ class QuerysetProxy(Generic[T]): ) return len(children) - async def get_or_create(self, *args: Any, **kwargs: Any) -> "T": + async def get_or_create( + self, + _defaults: Optional[Dict[str, Any]] = None, + *args: Any, + **kwargs: Any, + ) -> Tuple["T", bool]: """ Combination of create and get methods. Tries to get a row meeting the criteria fro kwargs and if `NoMatch` exception is raised - it creates a new one with given kwargs. + it creates a new one with given kwargs and _defaults. :param kwargs: fields names and proper value types :type kwargs: Any - :return: returned or created Model - :rtype: Model + :param _defaults: default values for creating object + :type _defaults: Optional[Dict[str, Any]] + :return: model instance and a boolean + :rtype: Tuple("T", bool) """ try: - return await self.get(*args, **kwargs) - except ormar.NoMatch: - return await self.create(**kwargs) + return await self.get(*args, **kwargs), False + except NoMatch: + if _defaults is None: + return await self.create(**kwargs), True + else: + return await self.create(**kwargs, **_defaults), True async def update_or_create(self, **kwargs: Any) -> "T": """ diff --git a/tests/test_fastapi/test_m2m_forwardref.py b/tests/test_fastapi/test_m2m_forwardref.py index bfccaef..c75791a 100644 --- a/tests/test_fastapi/test_m2m_forwardref.py +++ b/tests/test_fastapi/test_m2m_forwardref.py @@ -102,7 +102,7 @@ def test_payload(): "native_name": "Thailand", } resp = client.post("/", json=payload, headers={"application-type": "json"}) - print(resp.content) + # print(resp.content) assert resp.status_code == 201 resp_country = Country(**resp.json()) diff --git a/tests/test_model_definition/test_aliases.py b/tests/test_model_definition/test_aliases.py index bb6b40f..d403304 100644 --- a/tests/test_model_definition/test_aliases.py +++ b/tests/test_model_definition/test_aliases.py @@ -157,15 +157,17 @@ async def test_bulk_operations_and_fields(): async def test_working_with_aliases_get_or_create(): async with database: async with database.transaction(force_rollback=True): - artist = await Artist.objects.get_or_create( + artist, created = await Artist.objects.get_or_create( first_name="Teddy", last_name="Bear", born_year=2020 ) assert artist.pk is not None + assert created is True - artist2 = await Artist.objects.get_or_create( + artist2, created = await Artist.objects.get_or_create( first_name="Teddy", last_name="Bear", born_year=2020 ) assert artist == artist2 + assert created is False art3 = artist2.dict() art3["born_year"] = 2019 diff --git a/tests/test_model_definition/test_save_status.py b/tests/test_model_definition/test_save_status.py index 9762810..0bbda39 100644 --- a/tests/test_model_definition/test_save_status.py +++ b/tests/test_model_definition/test_save_status.py @@ -195,12 +195,14 @@ async def test_queryset_methods(): comps = await Company.objects.all() assert [comp.saved for comp in comps] - comp2 = await Company.objects.get_or_create(name="Banzai_new", founded=2001) + comp2, created = await Company.objects.get_or_create(name="Banzai_new", founded=2001) assert comp2.saved + assert created is True - comp3 = await Company.objects.get_or_create(name="Banzai", founded=1988) + comp3, created = await Company.objects.get_or_create(name="Banzai", founded=1988) assert comp3.saved assert comp3.pk == comp.pk + assert created is False update_dict = comp.dict() update_dict["founded"] = 2010 diff --git a/tests/test_queries/test_queryproxy_on_m2m_models.py b/tests/test_queries/test_queryproxy_on_m2m_models.py index a91c4f8..1fe8fc5 100644 --- a/tests/test_queries/test_queryproxy_on_m2m_models.py +++ b/tests/test_queries/test_queryproxy_on_m2m_models.py @@ -102,20 +102,23 @@ async def test_queryset_methods(): await post.categories.add(news) await post.categories.add(breaking) - category = await post.categories.get_or_create(name="News") + category, created = await post.categories.get_or_create(name="News") assert category == news assert len(post.categories) == 1 + assert created is False - category = await post.categories.get_or_create(name="Breaking News") + category, created = await post.categories.get_or_create(name="Breaking News") assert category != breaking assert category.pk is not None assert len(post.categories) == 2 + assert created is True await post.categories.update_or_create(pk=category.pk, name="Urgent News") assert len(post.categories) == 2 - cat = await post.categories.get_or_create(name="Urgent News") + cat, created = await post.categories.get_or_create(name="Urgent News") assert cat.pk == category.pk assert len(post.categories) == 1 + assert created is False await post.categories.remove(cat) await cat.delete() diff --git a/tests/test_queries/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py index 7176010..ee308dc 100644 --- a/tests/test_queries/test_queryset_level_methods.py +++ b/tests/test_queries/test_queryset_level_methods.py @@ -166,17 +166,18 @@ async def test_delete_and_update(): @pytest.mark.asyncio async def test_get_or_create(): async with database: - tom = await Book.objects.get_or_create( + tom, created = await Book.objects.get_or_create( title="Volume I", author="Anonymous", genre="Fiction" ) assert await Book.objects.count() == 1 + assert created is True - assert ( - await Book.objects.get_or_create( - title="Volume I", author="Anonymous", genre="Fiction" - ) - == tom + second_tom, created = await Book.objects.get_or_create( + title="Volume I", author="Anonymous", genre="Fiction" ) + + assert second_tom.pk == tom.pk + assert created is False assert await Book.objects.count() == 1 assert await Book.objects.create( @@ -188,6 +189,28 @@ async def test_get_or_create(): ) +@pytest.mark.asyncio +async def test_get_or_create_with_defaults(): + async with database: + book, created = await Book.objects.get_or_create( + title="Nice book", _defaults={"author": "Mojix", "genre": "Historic"} + ) + assert created is True + assert book.author == "Mojix" + assert book.title == "Nice book" + assert book.genre == "Historic" + + book2, created = await Book.objects.get_or_create( + author="Mojix", _defaults={"title": "Book2"} + ) + assert created is False + assert book2 == book + assert book2.title == "Nice book" + assert book2.author == "Mojix" + assert book2.genre == "Historic" + assert await Book.objects.count() == 1 + + @pytest.mark.asyncio async def test_update_or_create(): async with database: diff --git a/tests/test_queries/test_reverse_fk_queryset.py b/tests/test_queries/test_reverse_fk_queryset.py index 80193bc..25f54a0 100644 --- a/tests/test_queries/test_reverse_fk_queryset.py +++ b/tests/test_queries/test_reverse_fk_queryset.py @@ -87,22 +87,26 @@ async def test_quering_by_reverse_fk(): assert await album.tracks.exists() assert await album.tracks.count() == 3 - track = await album.tracks.get_or_create( + track, created = await album.tracks.get_or_create( title="The Bird", position=1, play_count=30 ) assert track == track1 + assert created is False assert len(album.tracks) == 1 - track = await album.tracks.get_or_create( - title="The Bird2", position=4, play_count=5 + track, created = await album.tracks.get_or_create( + title="The Bird2", _defaults={"position": 4, "play_count": 5} ) assert track != track1 + assert created is True assert track.pk is not None + assert track.position == 4 and track.play_count == 5 assert len(album.tracks) == 2 await album.tracks.update_or_create(pk=track.pk, play_count=50) assert len(album.tracks) == 2 - track = await album.tracks.get_or_create(title="The Bird2") + track, created = await album.tracks.get_or_create(title="The Bird2") + assert created is False assert track.play_count == 50 assert len(album.tracks) == 1