Update get_or_create method

This commit is contained in:
Mojix Coder
2022-02-01 09:44:07 +03:30
parent 4ed267e5c3
commit fc32001fe7
9 changed files with 92 additions and 33 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Tuple, Union
try: try:
from typing import Protocol from typing import Protocol
@ -55,7 +55,11 @@ class QuerySetProtocol(Protocol): # pragma: nocover
async def update(self, each: bool = False, **kwargs: Any) -> int: async def update(self, each: bool = False, **kwargs: Any) -> int:
... ...
async def get_or_create(self, **kwargs: Any) -> "Model": async def get_or_create(
self,
_defaults: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Tuple["Model", bool]:
... ...
async def update_or_create(self, **kwargs: Any) -> "Model": async def update_or_create(self, **kwargs: Any) -> "Model":

View File

@ -972,26 +972,36 @@ class QuerySet(Generic[T]):
self.check_single_result_rows_count(processed_rows) self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore return processed_rows[0] # type: ignore
async def get_or_create(self, *args: Any, **kwargs: Any) -> "T": async def get_or_create(
self,
_defaults: Optional[Dict[str, Any]] = None,
*args: Any,
**kwargs: Any,
) -> Tuple["T", bool]:
""" """
Combination of create and get methods. Combination of create and get methods.
Tries to get a row meeting the criteria for kwargs Tries to get a row meeting the criteria for kwargs
and if `NoMatch` exception is raised and if `NoMatch` exception is raised
it creates a new one with given kwargs. it creates a new one with given kwargs and _defaults.
Passing a criteria is actually calling filter(*args, **kwargs) method described Passing a criteria is actually calling filter(*args, **kwargs) method described
below. below.
:param kwargs: fields names and proper value types :param kwargs: fields names and proper value types
:type kwargs: Any :type kwargs: Any
:return: returned or created Model :param _defaults: default values for creating object
:rtype: Model :type _defaults: Optional[Dict[str, Any]]
:return: model instance and a boolean
:rtype: Tuple("T", bool)
""" """
try: try:
return await self.get(*args, **kwargs) return await self.get(*args, **kwargs), False
except NoMatch: except NoMatch:
return await self.create(**kwargs) if _defaults is None:
return await self.create(**kwargs), True
else:
return await self.create(**kwargs, **_defaults), True
async def update_or_create(self, **kwargs: Any) -> "T": async def update_or_create(self, **kwargs: Any) -> "T":
""" """

View File

@ -11,6 +11,7 @@ from typing import ( # noqa: I100, I201
TYPE_CHECKING, TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
Tuple,
Union, Union,
cast, cast,
) )
@ -483,23 +484,33 @@ class QuerysetProxy(Generic[T]):
) )
return len(children) return len(children)
async def get_or_create(self, *args: Any, **kwargs: Any) -> "T": async def get_or_create(
self,
_defaults: Optional[Dict[str, Any]] = None,
*args: Any,
**kwargs: Any,
) -> Tuple["T", bool]:
""" """
Combination of create and get methods. Combination of create and get methods.
Tries to get a row meeting the criteria fro kwargs Tries to get a row meeting the criteria fro kwargs
and if `NoMatch` exception is raised and if `NoMatch` exception is raised
it creates a new one with given kwargs. it creates a new one with given kwargs and _defaults.
:param kwargs: fields names and proper value types :param kwargs: fields names and proper value types
:type kwargs: Any :type kwargs: Any
:return: returned or created Model :param _defaults: default values for creating object
:rtype: Model :type _defaults: Optional[Dict[str, Any]]
:return: model instance and a boolean
:rtype: Tuple("T", bool)
""" """
try: try:
return await self.get(*args, **kwargs) return await self.get(*args, **kwargs), False
except ormar.NoMatch: except NoMatch:
return await self.create(**kwargs) if _defaults is None:
return await self.create(**kwargs), True
else:
return await self.create(**kwargs, **_defaults), True
async def update_or_create(self, **kwargs: Any) -> "T": async def update_or_create(self, **kwargs: Any) -> "T":
""" """

View File

@ -102,7 +102,7 @@ def test_payload():
"native_name": "Thailand", "native_name": "Thailand",
} }
resp = client.post("/", json=payload, headers={"application-type": "json"}) resp = client.post("/", json=payload, headers={"application-type": "json"})
print(resp.content) # print(resp.content)
assert resp.status_code == 201 assert resp.status_code == 201
resp_country = Country(**resp.json()) resp_country = Country(**resp.json())

View File

@ -157,15 +157,17 @@ async def test_bulk_operations_and_fields():
async def test_working_with_aliases_get_or_create(): async def test_working_with_aliases_get_or_create():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
artist = await Artist.objects.get_or_create( artist, created = await Artist.objects.get_or_create(
first_name="Teddy", last_name="Bear", born_year=2020 first_name="Teddy", last_name="Bear", born_year=2020
) )
assert artist.pk is not None assert artist.pk is not None
assert created is True
artist2 = await Artist.objects.get_or_create( artist2, created = await Artist.objects.get_or_create(
first_name="Teddy", last_name="Bear", born_year=2020 first_name="Teddy", last_name="Bear", born_year=2020
) )
assert artist == artist2 assert artist == artist2
assert created is False
art3 = artist2.dict() art3 = artist2.dict()
art3["born_year"] = 2019 art3["born_year"] = 2019

View File

@ -195,12 +195,14 @@ async def test_queryset_methods():
comps = await Company.objects.all() comps = await Company.objects.all()
assert [comp.saved for comp in comps] assert [comp.saved for comp in comps]
comp2 = await Company.objects.get_or_create(name="Banzai_new", founded=2001) comp2, created = await Company.objects.get_or_create(name="Banzai_new", founded=2001)
assert comp2.saved assert comp2.saved
assert created is True
comp3 = await Company.objects.get_or_create(name="Banzai", founded=1988) comp3, created = await Company.objects.get_or_create(name="Banzai", founded=1988)
assert comp3.saved assert comp3.saved
assert comp3.pk == comp.pk assert comp3.pk == comp.pk
assert created is False
update_dict = comp.dict() update_dict = comp.dict()
update_dict["founded"] = 2010 update_dict["founded"] = 2010

View File

@ -102,20 +102,23 @@ async def test_queryset_methods():
await post.categories.add(news) await post.categories.add(news)
await post.categories.add(breaking) await post.categories.add(breaking)
category = await post.categories.get_or_create(name="News") category, created = await post.categories.get_or_create(name="News")
assert category == news assert category == news
assert len(post.categories) == 1 assert len(post.categories) == 1
assert created is False
category = await post.categories.get_or_create(name="Breaking News") category, created = await post.categories.get_or_create(name="Breaking News")
assert category != breaking assert category != breaking
assert category.pk is not None assert category.pk is not None
assert len(post.categories) == 2 assert len(post.categories) == 2
assert created is True
await post.categories.update_or_create(pk=category.pk, name="Urgent News") await post.categories.update_or_create(pk=category.pk, name="Urgent News")
assert len(post.categories) == 2 assert len(post.categories) == 2
cat = await post.categories.get_or_create(name="Urgent News") cat, created = await post.categories.get_or_create(name="Urgent News")
assert cat.pk == category.pk assert cat.pk == category.pk
assert len(post.categories) == 1 assert len(post.categories) == 1
assert created is False
await post.categories.remove(cat) await post.categories.remove(cat)
await cat.delete() await cat.delete()

View File

@ -166,17 +166,18 @@ async def test_delete_and_update():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_or_create(): async def test_get_or_create():
async with database: async with database:
tom = await Book.objects.get_or_create( tom, created = await Book.objects.get_or_create(
title="Volume I", author="Anonymous", genre="Fiction" title="Volume I", author="Anonymous", genre="Fiction"
) )
assert await Book.objects.count() == 1 assert await Book.objects.count() == 1
assert created is True
assert ( second_tom, created = await Book.objects.get_or_create(
await Book.objects.get_or_create( title="Volume I", author="Anonymous", genre="Fiction"
title="Volume I", author="Anonymous", genre="Fiction"
)
== tom
) )
assert second_tom.pk == tom.pk
assert created is False
assert await Book.objects.count() == 1 assert await Book.objects.count() == 1
assert await Book.objects.create( assert await Book.objects.create(
@ -188,6 +189,28 @@ async def test_get_or_create():
) )
@pytest.mark.asyncio
async def test_get_or_create_with_defaults():
async with database:
book, created = await Book.objects.get_or_create(
title="Nice book", _defaults={"author": "Mojix", "genre": "Historic"}
)
assert created is True
assert book.author == "Mojix"
assert book.title == "Nice book"
assert book.genre == "Historic"
book2, created = await Book.objects.get_or_create(
author="Mojix", _defaults={"title": "Book2"}
)
assert created is False
assert book2 == book
assert book2.title == "Nice book"
assert book2.author == "Mojix"
assert book2.genre == "Historic"
assert await Book.objects.count() == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_or_create(): async def test_update_or_create():
async with database: async with database:

View File

@ -87,22 +87,26 @@ async def test_quering_by_reverse_fk():
assert await album.tracks.exists() assert await album.tracks.exists()
assert await album.tracks.count() == 3 assert await album.tracks.count() == 3
track = await album.tracks.get_or_create( track, created = await album.tracks.get_or_create(
title="The Bird", position=1, play_count=30 title="The Bird", position=1, play_count=30
) )
assert track == track1 assert track == track1
assert created is False
assert len(album.tracks) == 1 assert len(album.tracks) == 1
track = await album.tracks.get_or_create( track, created = await album.tracks.get_or_create(
title="The Bird2", position=4, play_count=5 title="The Bird2", _defaults={"position": 4, "play_count": 5}
) )
assert track != track1 assert track != track1
assert created is True
assert track.pk is not None assert track.pk is not None
assert track.position == 4 and track.play_count == 5
assert len(album.tracks) == 2 assert len(album.tracks) == 2
await album.tracks.update_or_create(pk=track.pk, play_count=50) await album.tracks.update_or_create(pk=track.pk, play_count=50)
assert len(album.tracks) == 2 assert len(album.tracks) == 2
track = await album.tracks.get_or_create(title="The Bird2") track, created = await album.tracks.get_or_create(title="The Bird2")
assert created is False
assert track.play_count == 50 assert track.play_count == 50
assert len(album.tracks) == 1 assert len(album.tracks) == 1