diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 933920b..2c22326 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: 22.1.0 + rev: 22.3.0 hooks: - id: black exclude: ^(docs_src/|examples/) @@ -11,7 +11,7 @@ repos: exclude: ^(docs_src/|examples/|tests/) args: [ '--max-line-length=88' ] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v0.961 hooks: - id: mypy exclude: ^(docs_src/|examples/) diff --git a/README.md b/README.md index 525c87a..05632be 100644 --- a/README.md +++ b/README.md @@ -595,6 +595,7 @@ metadata.drop_all(engine) * `bulk_update(objects: List[Model], columns: List[str] = None) -> None` * `delete(*args, each: bool = False, **kwargs) -> int` * `all(*args, **kwargs) -> List[Optional[Model]]` +* `iterate(*args, **kwargs) -> AsyncGenerator[Model]` * `filter(*args, **kwargs) -> QuerySet` * `exclude(*args, **kwargs) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet` diff --git a/docs/index.md b/docs/index.md index 6ea70e8..5eaaacf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -604,6 +604,7 @@ metadata.drop_all(engine) * `bulk_update(objects: List[Model], columns: List[str] = None) -> None` * `delete(*args, each: bool = False, **kwargs) -> int` * `all(*args, **kwargs) -> List[Optional[Model]]` +* `iterate(*args, **kwargs) -> AsyncGenerator[Model]` * `filter(*args, **kwargs) -> QuerySet` * `exclude(*args, **kwargs) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet` diff --git a/docs/queries/read.md b/docs/queries/read.md index 9fa6794..68fa50a 100644 --- a/docs/queries/read.md +++ b/docs/queries/read.md @@ -173,6 +173,45 @@ tracks = await Track.objects.all() ``` +## iterate + +`iterate(*args, **kwargs) -> AsyncGenerator["Model"]` + +Return async iterable generator for all rows from a database for given model. 
+ +Passing args and/or kwargs is a shortcut and equals to calling `filter(*args, **kwargs).iterate()`. + +If there are no rows meeting the criteria an empty async generator is returned. + +```python +class Album(ormar.Model): + class Meta: + tablename = "album" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) +``` + +```python +await Album.objects.create(name='The Cat') +await Album.objects.create(name='The Dog') +# will asynchronously iterate all Album models yielding one main model at a time from the generator +async for album in Album.objects.iterate(): + print(album.name) + +# The Cat +# The Dog + +``` + +!!!warning + `iterate()` cannot be combined with `prefetch_related()`, + since these two optimizations do not make sense together. + + If `iterate()` & `prefetch_related()` are used together the `QueryDefinitionError` exception is raised. + ## Model methods Each model instance have a set of methods to `save`, `update` or `load` itself. @@ -235,4 +274,4 @@ objects from other side of the relation. To read more about `QuerysetProxy` visit [querysetproxy][querysetproxy] section -[querysetproxy]: ../relations/queryset-proxy.md \ No newline at end of file +[querysetproxy]: ../relations/queryset-proxy.md diff --git a/docs/relations/queryset-proxy.md b/docs/relations/queryset-proxy.md index 4d167b6..fcfc0a0 100644 --- a/docs/relations/queryset-proxy.md +++ b/docs/relations/queryset-proxy.md @@ -84,6 +84,23 @@ assert news_posts[0].author == guido !!!tip Read more in queries documentation [all][all] +### iterate + +`iterate(**kwargs) -> AsyncGenerator["Model"]` + +To iterate on related models use `iterate()` method. + +Note that you can filter the queryset, select related, exclude fields etc. like in normal query. 
+ +```python +# iterate on categories of this post with an async generator +async for category in post.categories.iterate(): + print(category.name) +``` + +!!!tip + Read more in queries documentation [iterate][iterate] + ## Insert/ update data into database ### create @@ -294,6 +311,7 @@ Returns a bool value to confirm if there are rows matching the given criteria (a [queries]: ../queries/index.md [get]: ../queries/read.md#get [all]: ../queries/read.md#all +[iterate]: ../queries/read.md#iterate [create]: ../queries/create.md#create [get_or_create]: ../queries/read.md#get_or_create [update_or_create]: ../queries/update.md#update_or_create diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 5ea70fa..5ae40b6 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -12,6 +12,7 @@ from typing import ( TypeVar, Union, cast, + AsyncGenerator, ) import databases @@ -1055,6 +1056,55 @@ class QuerySet(Generic[T]): return result_rows + async def iterate( # noqa: A003 + self, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator["T", None]: + """ + Return async iterable generator for all rows from a database for given model. + + Passing args and/or kwargs is a shortcut and equals to calling + `filter(*args, **kwargs).iterate()`. + + If there are no rows meeting the criteria an empty async generator is returned. 
+ + :param kwargs: fields names and proper value types + :type kwargs: Any + :return: asynchronous iterable generator of returned models + :rtype: AsyncGenerator[Model] + """ + + if self._prefetch_related: + raise QueryDefinitionError( + "Prefetch related queries are not supported in iterators" + ) + + if kwargs or args: + async for result in self.filter(*args, **kwargs).iterate(): + yield result + return + + expr = self.build_select_expression() + + rows: list = [] + last_primary_key = None + pk_alias = self.model.get_column_alias(self.model_meta.pkname) + + async for row in self.database.iterate(query=expr): + current_primary_key = row[pk_alias] + if last_primary_key == current_primary_key or last_primary_key is None: + last_primary_key = current_primary_key + rows.append(row) + continue + + yield self._process_query_result_rows(rows)[0] + last_primary_key = current_primary_key + rows = [row] + + if rows: + yield self._process_query_result_rows(rows)[0] + async def create(self, **kwargs: Any) -> "T": """ Creates the model instance, saves it in a database and returns the updates model diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index b49ca4f..e27934d 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -14,6 +14,7 @@ from typing import ( # noqa: I100, I201 TypeVar, Union, cast, + AsyncGenerator, ) import ormar # noqa: I100, I202 @@ -431,6 +432,28 @@ class QuerysetProxy(Generic[T]): self._register_related(all_items) return all_items + async def iterate( # noqa: A003 + self, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator["T", None]: + """ + Return async iterable generator for all rows from a database for given model. + + Passing args and/or kwargs is a shortcut and equals to calling + `filter(*args, **kwargs).iterate()`. + + If there are no rows meeting the criteria an empty async generator is returned. 
+ + :param kwargs: fields names and proper value types + :type kwargs: Any + :return: asynchronous iterable generator of returned models + :rtype: AsyncGenerator[Model] + """ + + async for item in self.queryset.iterate(*args, **kwargs): + yield item + async def create(self, **kwargs: Any) -> "T": """ Creates the model instance, saves it in a database and returns the updates model diff --git a/tests/test_model_definition/test_iterate.py b/tests/test_model_definition/test_iterate.py new file mode 100644 index 0000000..54fd324 --- /dev/null +++ b/tests/test_model_definition/test_iterate.py @@ -0,0 +1,268 @@ +import uuid +import databases +import pytest +import sqlalchemy + +import ormar +from ormar.exceptions import QueryDefinitionError +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class User(ormar.Model): + class Meta: + tablename = "users3" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, default="") + + +class User2(ormar.Model): + class Meta: + tablename = "users4" + metadata = metadata + database = database + + id: uuid.UUID = ormar.UUID( + uuid_format="string", primary_key=True, default=uuid.uuid4 + ) + name: str = ormar.String(max_length=100, default="") + + +class Task(ormar.Model): + class Meta: + tablename = "tasks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, default="") + user: User = ormar.ForeignKey(to=User) + + +class Task2(ormar.Model): + class Meta: + tablename = "tasks2" + metadata = metadata + database = database + + id: uuid.UUID = ormar.UUID( + uuid_format="string", primary_key=True, default=uuid.uuid4 + ) + name: str = ormar.String(max_length=100, default="") + user: User2 = ormar.ForeignKey(to=User2) + + +@pytest.fixture(autouse=True, scope="module") +def 
create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_empty_result(): + async with database: + async with database.transaction(force_rollback=True): + async for user in User.objects.iterate(): + pass # pragma: no cover + + +@pytest.mark.asyncio +async def test_model_iterator(): + async with database: + async with database.transaction(force_rollback=True): + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") + lucy = await User.objects.create(name="Lucy") + + async for user in User.objects.iterate(): + assert user in (tom, jane, lucy) + + +@pytest.mark.asyncio +async def test_model_iterator_filter(): + async with database: + async with database.transaction(force_rollback=True): + tom = await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + async for user in User.objects.iterate(name="Tom"): + assert user.name == tom.name + + +@pytest.mark.asyncio +async def test_model_iterator_relations(): + async with database: + async with database.transaction(force_rollback=True): + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") + lucy = await User.objects.create(name="Lucy") + + for user in tom, jane, lucy: + await Task.objects.create(name="task1", user=user) + await Task.objects.create(name="task2", user=user) + + results = [] + async for user in User.objects.select_related(User.tasks).iterate(): + assert len(user.tasks) == 2 + results.append(user) + + assert len(results) == 3 + + +@pytest.mark.asyncio +async def test_model_iterator_relations_queryset_proxy(): + async with database: + async with database.transaction(force_rollback=True): + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") + + for user in tom, jane: + await 
Task.objects.create(name="task1", user=user) + await Task.objects.create(name="task2", user=user) + + tom_tasks = [] + async for task in tom.tasks.iterate(): + assert task.name in ("task1", "task2") + tom_tasks.append(task) + + assert len(tom_tasks) == 2 + + jane_tasks = [] + async for task in jane.tasks.iterate(): + assert task.name in ("task1", "task2") + jane_tasks.append(task) + + assert len(jane_tasks) == 2 + + +@pytest.mark.asyncio +async def test_model_iterator_uneven_number_of_relations(): + async with database: + async with database.transaction(force_rollback=True): + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") + lucy = await User.objects.create(name="Lucy") + + for user in tom, jane: + await Task.objects.create(name="task1", user=user) + await Task.objects.create(name="task2", user=user) + + await Task.objects.create(name="task3", user=lucy) + expected_counts = {"Tom": 2, "Jane": 2, "Lucy": 1} + results = [] + async for user in User.objects.select_related(User.tasks).iterate(): + assert len(user.tasks) == expected_counts[user.name] + results.append(user) + + assert len(results) == 3 + + +@pytest.mark.asyncio +async def test_model_iterator_uuid_pk(): + async with database: + async with database.transaction(force_rollback=True): + tom = await User2.objects.create(name="Tom") + jane = await User2.objects.create(name="Jane") + lucy = await User2.objects.create(name="Lucy") + + async for user in User2.objects.iterate(): + assert user in (tom, jane, lucy) + + +@pytest.mark.asyncio +async def test_model_iterator_filter_uuid_pk(): + async with database: + async with database.transaction(force_rollback=True): + tom = await User2.objects.create(name="Tom") + await User2.objects.create(name="Jane") + await User2.objects.create(name="Lucy") + + async for user in User2.objects.iterate(name="Tom"): + assert user.name == tom.name + + +@pytest.mark.asyncio +async def test_model_iterator_relations_uuid_pk(): + async with 
database: + async with database.transaction(force_rollback=True): + tom = await User2.objects.create(name="Tom") + jane = await User2.objects.create(name="Jane") + lucy = await User2.objects.create(name="Lucy") + + for user in tom, jane, lucy: + await Task2.objects.create(name="task1", user=user) + await Task2.objects.create(name="task2", user=user) + + results = [] + async for user in User2.objects.select_related(User2.task2s).iterate(): + assert len(user.task2s) == 2 + results.append(user) + + assert len(results) == 3 + + +@pytest.mark.asyncio +async def test_model_iterator_relations_queryset_proxy_uuid_pk(): + async with database: + async with database.transaction(force_rollback=True): + tom = await User2.objects.create(name="Tom") + jane = await User2.objects.create(name="Jane") + + for user in tom, jane: + await Task2.objects.create(name="task1", user=user) + await Task2.objects.create(name="task2", user=user) + + tom_tasks = [] + async for task in tom.task2s.iterate(): + assert task.name in ("task1", "task2") + tom_tasks.append(task) + + assert len(tom_tasks) == 2 + + jane_tasks = [] + async for task in jane.task2s.iterate(): + assert task.name in ("task1", "task2") + jane_tasks.append(task) + + assert len(jane_tasks) == 2 + + +@pytest.mark.asyncio +async def test_model_iterator_uneven_number_of_relations_uuid_pk(): + async with database: + async with database.transaction(force_rollback=True): + tom = await User2.objects.create(name="Tom") + jane = await User2.objects.create(name="Jane") + lucy = await User2.objects.create(name="Lucy") + + for user in tom, jane: + await Task2.objects.create(name="task1", user=user) + await Task2.objects.create(name="task2", user=user) + + await Task2.objects.create(name="task3", user=lucy) + + expected_counts = {"Tom": 2, "Jane": 2, "Lucy": 1} + + results = [] + async for user in User2.objects.select_related(User2.task2s).iterate(): + assert len(user.task2s) == expected_counts[user.name] + results.append(user) + + assert 
len(results) == 3 + + +@pytest.mark.asyncio +async def test_model_iterator_with_prefetch_raises_error(): + async with database: + with pytest.raises(QueryDefinitionError): + async for user in User.objects.prefetch_related(User.tasks).iterate(): + pass # pragma: no cover