From 955ac48cdd1cf878354cd842603fa71b76a9da71 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 5 Jun 2021 18:53:15 +0200 Subject: [PATCH] wip - refactor of alias resolver - through models columns with fields are not properly handled yet --- ormar/queryset/queryset.py | 55 +-- ormar/queryset/reverse_alias_resolver.py | 82 +++++ ormar/relations/alias_manager.py | 9 + .../test_selecting_subset_of_columns.py | 80 +++-- .../test_values_and_values_list.py | 328 ++++++++++++------ 5 files changed, 358 insertions(+), 196 deletions(-) create mode 100644 ormar/queryset/reverse_alias_resolver.py diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index f4aeff2..2d1ca79 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -26,6 +26,7 @@ from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.clause import FilterGroup, Prefix, QueryClause from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.query import Query +from ormar.queryset.reverse_alias_resolver import ReverseAliasResolver from ormar.queryset.utils import get_relationship_alias_model_and_str if TYPE_CHECKING: # pragma no cover @@ -586,10 +587,12 @@ class QuerySet(Generic[T]): rows = await self.database.fetch_all(expr) if not rows: return [] - column_names = list(rows[0].keys()) - column_map = self._resolve_data_prefix_to_relation_str( - column_names=column_names + alias_resolver = ReverseAliasResolver( + select_related=self._select_related, + excludable=self._excludable, + model_cls=self.model_cls, ) + column_map = alias_resolver.resolve_columns(columns_names=list(rows[0].keys())) result = [ {column_map.get(k): v for k, v in dict(x).items() if k in column_map} for x in rows @@ -598,7 +601,7 @@ class QuerySet(Generic[T]): return result if _flatten and not self._excludable.include_entry_count() == 1: raise QueryDefinitionError( - "You cannot flatten values_list if more than " "one field is selected!" + "You cannot flatten values_list if more than one field is selected!" ) tuple_result = [tuple(x.values()) for x in result] return tuple_result if not _flatten else [x[0] for x in tuple_result] @@ -625,50 +628,6 @@ class QuerySet(Generic[T]): """ return await self.values(fields=fields, _as_dict=False, _flatten=flatten) - def _resolve_data_prefix_to_relation_str(self, column_names: List[str]) -> Dict: - resolved_names = dict() - for column_name in column_names: - prefixes_map = self._create_prefixes_map() - column_parts = column_name.split("_") - potential_prefix = column_parts[0] - if potential_prefix in prefixes_map: - prefix = prefixes_map[potential_prefix] - allowed_columns = prefix.model_cls.own_table_columns( - model=prefix.model_cls, - excludable=self._excludable, - alias=prefix.table_prefix, - add_pk_columns=False, - ) - new_column_name = "_".join(column_parts[1:]) - if new_column_name in allowed_columns: - resolved_names[column_name] = f"{prefix.relation_str}__" + "_".join( - column_name.split("_")[1:] - ) - else: - assert self.model_cls - allowed_columns = self.model_cls.own_table_columns( - model=self.model_cls, - excludable=self._excludable, - add_pk_columns=False, - ) - if column_name in allowed_columns: - resolved_names[column_name] = column_name - return resolved_names - - def _create_prefixes_map(self) -> Dict[str, Prefix]: - prefixes: List[Prefix] = [] - for related in self._select_related: - related_split = related.split("__") - for index in range(len(related_split)): - prefix = Prefix( - self.model_cls, # type: ignore - *get_relationship_alias_model_and_str( - self.model_cls, related_split[0 : (index + 1)] # type: ignore - ), - ) - prefixes.append(prefix) - return {x.table_prefix: x for x in prefixes} - async def exists(self) -> bool: """ Returns a bool value to confirm if there are rows matching the given criteria diff --git a/ormar/queryset/reverse_alias_resolver.py b/ormar/queryset/reverse_alias_resolver.py new file mode 100644 index 0000000..af916bc --- /dev/null +++ b/ormar/queryset/reverse_alias_resolver.py @@ -0,0 +1,82 @@ +from typing import Dict, List, TYPE_CHECKING, Tuple, Type + +from ormar.queryset.clause import Prefix +from ormar.queryset.utils import get_relationship_alias_model_and_str + +if TYPE_CHECKING: + from ormar import Model + from ormar.models.excludable import ExcludableItems + + +class ReverseAliasResolver: + def __init__( + self, + model_cls: Type["Model"], + excludable: "ExcludableItems", + select_related: List[str], + ) -> None: + self.select_related = select_related + self.model_cls = model_cls + self.reversed_aliases = self.model_cls.Meta.alias_manager.reversed_aliases + self.excludable = excludable + + def resolve_columns(self, columns_names: List[str]) -> Dict: + resolved_names = dict() + prefixes, target_models = self._create_prefixes_map() + for column_name in columns_names: + column_parts = column_name.split("_") + potential_prefix = column_parts[0] + if potential_prefix in self.reversed_aliases: + relation = self.reversed_aliases[potential_prefix] + relation_str = prefixes[relation] + target_model = target_models[relation] + allowed_columns = target_model.own_table_columns( + model=target_model, + excludable=self.excludable, + alias=potential_prefix, + add_pk_columns=False, + ) + new_column_name = column_name.replace(f"{potential_prefix}_", "") + if new_column_name in allowed_columns: + resolved_names[column_name] = column_name.replace( + f"{potential_prefix}_", f"{relation_str}__" + ) + else: + allowed_columns = self.model_cls.own_table_columns( + model=self.model_cls, + excludable=self.excludable, + add_pk_columns=False, + ) + if column_name in allowed_columns: + resolved_names[column_name] = column_name + + return resolved_names + + def _create_prefixes_map(self) -> Tuple[Dict, Dict]: + prefixes: Dict = dict() + target_models: Dict = dict() + for related in self.select_related: + model_cls = self.model_cls + related_split = related.split("__") + related_str = "" + for related in related_split: + prefix_name = f"{model_cls.get_name()}_{related}" + new_related_str = (f"{related_str}__" if related_str else "") + related + prefixes[prefix_name] = new_related_str + field = model_cls.Meta.model_fields[related] + target_models[prefix_name] = field.to + if field.is_multi: + target_models[prefix_name] = field.through + new_through_str = ( + f"{related_str}__" if related_str else "" + ) + field.through.get_name() + prefixes[prefix_name] = new_through_str + prefix_name = ( + f"{field.through.get_name()}_" + f"{field.default_target_field_name()}" + ) + prefixes[prefix_name] = new_related_str + target_models[prefix_name] = field.to + model_cls = field.to + related_str = new_related_str + return prefixes, target_models diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index adb978d..2ad8733 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -34,6 +34,7 @@ class AliasManager: def __init__(self) -> None: self._aliases_new: Dict[str, str] = dict() + self._reversed_aliases: Dict[str, str] = dict() def __contains__(self, item: str) -> bool: return self._aliases_new.__contains__(item) @@ -41,6 +42,14 @@ class AliasManager: def __getitem__(self, key: str) -> Any: return self._aliases_new.__getitem__(key) + @property + def reversed_aliases(self): + if self._reversed_aliases: + return self._reversed_aliases + reversed_aliases = {v: k for k, v in self._aliases_new.items()} + self._reversed_aliases = reversed_aliases + return self._reversed_aliases + @staticmethod def prefixed_columns( alias: str, table: sqlalchemy.Table, fields: List = None diff --git a/tests/test_queries/test_selecting_subset_of_columns.py b/tests/test_queries/test_selecting_subset_of_columns.py index 809b508..78168db 100644 --- a/tests/test_queries/test_selecting_subset_of_columns.py +++ b/tests/test_queries/test_selecting_subset_of_columns.py @@ -1,3 +1,4 @@ +import asyncio import itertools from typing import Optional, List @@ -9,7 +10,7 @@ import sqlalchemy import ormar from tests.settings import DATABASE_URL -database = databases.Database(DATABASE_URL, force_rollback=True) +database = databases.Database(DATABASE_URL) metadata = sqlalchemy.MetaData() @@ -78,43 +79,54 @@ def create_test_database(): metadata.drop_all(engine) +@pytest.yield_fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True, scope="module") +async def sample_data(event_loop, create_test_database): + async with database: + nick1 = await NickNames.objects.create(name="Nippon", is_lame=False) + nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True) + hq = await HQ.objects.create(name="Japan") + await hq.nicks.add(nick1) + await hq.nicks.add(nick2) + + toyota = await Company.objects.create(name="Toyota", founded=1937, hq=hq) + + await Car.objects.create( + manufacturer=toyota, + name="Corolla", + year=2020, + gearbox_type="Manual", + gears=5, + aircon_type="Manual", + ) + await Car.objects.create( + manufacturer=toyota, + name="Yaris", + year=2019, + gearbox_type="Manual", + gears=5, + aircon_type="Manual", + ) + await Car.objects.create( + manufacturer=toyota, + name="Supreme", + year=2020, + gearbox_type="Auto", + gears=6, + aircon_type="Auto", + ) + + @pytest.mark.asyncio async def test_selecting_subset(): async with database: async with database.transaction(force_rollback=True): - nick1 = await NickNames.objects.create(name="Nippon", is_lame=False) - nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True) - hq = await HQ.objects.create(name="Japan") - await hq.nicks.add(nick1) - await hq.nicks.add(nick2) - - toyota = await Company.objects.create(name="Toyota", founded=1937, hq=hq) - - await Car.objects.create( - manufacturer=toyota, - name="Corolla", - year=2020, - gearbox_type="Manual", - gears=5, - aircon_type="Manual", - ) - await Car.objects.create( - manufacturer=toyota, - name="Yaris", - year=2019, - gearbox_type="Manual", - gears=5, - aircon_type="Manual", - ) - await Car.objects.create( - manufacturer=toyota, - name="Supreme", - year=2020, - gearbox_type="Auto", - gears=6, - aircon_type="Auto", - ) - all_cars = ( await Car.objects.select_related(["manufacturer__hq__nicks"]) .fields( diff --git a/tests/test_queries/test_values_and_values_list.py b/tests/test_queries/test_values_and_values_list.py index 6051d80..178ade1 100644 --- a/tests/test_queries/test_values_and_values_list.py +++ b/tests/test_queries/test_values_and_values_list.py @@ -1,3 +1,4 @@ +import asyncio from typing import List, Optional import databases @@ -7,7 +8,7 @@ import sqlalchemy import ormar from tests.settings import DATABASE_URL -database = databases.Database(DATABASE_URL, force_rollback=True) +database = databases.Database(DATABASE_URL) metadata = sqlalchemy.MetaData() @@ -24,6 +25,15 @@ class User(ormar.Model): name: str = ormar.String(max_length=100) +class Role(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + users: List[User] = ormar.ManyToMany(User) + + class Category(ormar.Model): class Meta(BaseMeta): tablename = "categories" @@ -31,7 +41,7 @@ class Category(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=40) sort_order: int = ormar.Integer(nullable=True) - created_by: Optional[User] = ormar.ForeignKey(User) + created_by: Optional[User] = ormar.ForeignKey(User, related_name="categories") class Post(ormar.Model): @@ -51,123 +61,213 @@ def create_test_database(): metadata.drop_all(engine) -@pytest.mark.asyncio -async def test_queryset_values(): +@pytest.yield_fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True, scope="module") +async def sample_data(event_loop, create_test_database): async with database: - async with database.transaction(force_rollback=True): - creator = await User(name="Anonymous").save() - news = await Category(name="News", sort_order=0, created_by=creator).save() - await Post(name="Ormar strikes again!", category=news).save() - await Post(name="Why don't you use ormar yet?", category=news).save() - await Post(name="Check this out, ormar now for free", category=news).save() - - posts = await Post.objects.values() - assert posts == [ - {"id": 1, "name": "Ormar strikes again!", "category": 1}, - {"id": 2, "name": "Why don't you use ormar yet?", "category": 1}, - {"id": 3, "name": "Check this out, ormar now for free", "category": 1}, - ] - - posts = await Post.objects.select_related("category__created_by").values() - assert posts == [ - { - "id": 1, - "name": "Ormar strikes again!", - "category": 1, - "category__id": 1, - "category__name": "News", - "category__sort_order": 0, - "category__created_by": 1, - "category__created_by__id": 1, - "category__created_by__name": "Anonymous", - }, - { - "category": 1, - "id": 2, - "name": "Why don't you use ormar yet?", - "category__id": 1, - "category__name": "News", - "category__sort_order": 0, - "category__created_by": 1, - "category__created_by__id": 1, - "category__created_by__name": "Anonymous", - }, - { - "id": 3, - "name": "Check this out, ormar now for free", - "category": 1, - "category__id": 1, - "category__name": "News", - "category__sort_order": 0, - "category__created_by": 1, - "category__created_by__id": 1, - "category__created_by__name": "Anonymous", - }, - ] - - posts = await Post.objects.select_related("category__created_by").values( - ["name", "category__name", "category__created_by__name"] - ) - assert posts == [ - { - "name": "Ormar strikes again!", - "category__name": "News", - "category__created_by__name": "Anonymous", - }, - { - "name": "Why don't you use ormar yet?", - "category__name": "News", - "category__created_by__name": "Anonymous", - }, - { - "name": "Check this out, ormar now for free", - "category__name": "News", - "category__created_by__name": "Anonymous", - }, - ] + creator = await User(name="Anonymous").save() + admin = await Role(name="admin").save() + editor = await Role(name="editor").save() + await creator.roles.add(admin) + await creator.roles.add(editor) + news = await Category(name="News", sort_order=0, created_by=creator).save() + await Post(name="Ormar strikes again!", category=news).save() + await Post(name="Why don't you use ormar yet?", category=news).save() + await Post(name="Check this out, ormar now for free", category=news).save() @pytest.mark.asyncio -async def test_queryset_values_list(): +async def test_simple_queryset_values(): async with database: - async with database.transaction(force_rollback=True): - creator = await User(name="Anonymous").save() - news = await Category(name="News", sort_order=0, created_by=creator).save() - await Post(name="Ormar strikes again!", category=news).save() - await Post(name="Why don't you use ormar yet?", category=news).save() - await Post(name="Check this out, ormar now for free", category=news).save() + posts = await Post.objects.values() + assert posts == [ + {"id": 1, "name": "Ormar strikes again!", "category": 1}, + {"id": 2, "name": "Why don't you use ormar yet?", "category": 1}, + {"id": 3, "name": "Check this out, ormar now for free", "category": 1}, + ] - posts = await Post.objects.values_list() - assert posts == [ - (1, "Ormar strikes again!", 1), - (2, "Why don't you use ormar yet?", 1), - (3, "Check this out, ormar now for free", 1), - ] - posts = await Post.objects.select_related( - "category__created_by" - ).values_list() - assert posts == [ - (1, "Ormar strikes again!", 1, 1, "News", 0, 1, 1, "Anonymous"), - (2, "Why don't you use ormar yet?", 1, 1, "News", 0, 1, 1, "Anonymous"), - ( - 3, - "Check this out, ormar now for free", - 1, - 1, - "News", - 0, - 1, - 1, - "Anonymous", - ), - ] +@pytest.mark.asyncio +async def test_queryset_values_nested_relation(): + async with database: + posts = await Post.objects.select_related("category__created_by").values() + assert posts == [ + { + "id": 1, + "name": "Ormar strikes again!", + "category": 1, + "category__id": 1, + "category__name": "News", + "category__sort_order": 0, + "category__created_by": 1, + "category__created_by__id": 1, + "category__created_by__name": "Anonymous", + }, + { + "category": 1, + "id": 2, + "name": "Why don't you use ormar yet?", + "category__id": 1, + "category__name": "News", + "category__sort_order": 0, + "category__created_by": 1, + "category__created_by__id": 1, + "category__created_by__name": "Anonymous", + }, + { + "id": 3, + "name": "Check this out, ormar now for free", + "category": 1, + "category__id": 1, + "category__name": "News", + "category__sort_order": 0, + "category__created_by": 1, + "category__created_by__id": 1, + "category__created_by__name": "Anonymous", + }, + ] - posts = await Post.objects.select_related( - "category__created_by" - ).values_list(["name", "category__name", "category__created_by__name"]) - assert posts == [ - ("Ormar strikes again!", "News", "Anonymous"), - ("Why don't you use ormar yet?", "News", "Anonymous"), - ("Check this out, ormar now for free", "News", "Anonymous"), - ] + +@pytest.mark.asyncio +async def test_queryset_values_nested_relation_subset_of_fields(): + async with database: + posts = await Post.objects.select_related("category__created_by").values( + ["name", "category__name", "category__created_by__name"] + ) + assert posts == [ + { + "name": "Ormar strikes again!", + "category__name": "News", + "category__created_by__name": "Anonymous", + }, + { + "name": "Why don't you use ormar yet?", + "category__name": "News", + "category__created_by__name": "Anonymous", + }, + { + "name": "Check this out, ormar now for free", + "category__name": "News", + "category__created_by__name": "Anonymous", + }, + ] + + +@pytest.mark.asyncio +async def test_queryset_simple_values_list(): + async with database: + posts = await Post.objects.values_list() + assert posts == [ + (1, "Ormar strikes again!", 1), + (2, "Why don't you use ormar yet?", 1), + (3, "Check this out, ormar now for free", 1), + ] + + +@pytest.mark.asyncio +async def test_queryset_nested_relation_values_list(): + async with database: + posts = await Post.objects.select_related("category__created_by").values_list() + assert posts == [ + (1, "Ormar strikes again!", 1, 1, "News", 0, 1, 1, "Anonymous"), + (2, "Why don't you use ormar yet?", 1, 1, "News", 0, 1, 1, "Anonymous"), + ( + 3, + "Check this out, ormar now for free", + 1, + 1, + "News", + 0, + 1, + 1, + "Anonymous", + ), + ] + + +@pytest.mark.asyncio +async def test_queryset_nested_relation_subset_of_fields_values_list(): + async with database: + posts = await Post.objects.select_related("category__created_by").values_list( + ["name", "category__name", "category__created_by__name"] + ) + assert posts == [ + ("Ormar strikes again!", "News", "Anonymous"), + ("Why don't you use ormar yet?", "News", "Anonymous"), + ("Check this out, ormar now for free", "News", "Anonymous"), + ] + + +@pytest.mark.asyncio +async def test_m2m_values(): + async with database: + user = await User.objects.select_related("roles").values() + assert user == [ + { + "id": 1, + "name": "Anonymous", + "roleuser__id": 1, + "roleuser__role": 1, + "roleuser__user": 1, + "roles__id": 1, + "roles__name": "admin", + }, + { + "id": 1, + "name": "Anonymous", + "roleuser__id": 2, + "roleuser__role": 2, + "roleuser__user": 1, + "roles__id": 2, + "roles__name": "editor", + }, + ] + + +@pytest.mark.asyncio +async def test_nested_m2m_values(): + async with database: + user = ( + await Role.objects.select_related("users__categories") + .filter(name="admin") + .values() + ) + assert user == [ + { + "id": 1, + "name": "admin", + "roleuser__id": 1, + "roleuser__role": 1, + "roleuser__user": 1, + "users__id": 1, + "users__name": "Anonymous", + "users__categories__id": 1, + "users__categories__name": "News", + "users__categories__sort_order": 0, + "users__categories__created_by": 1, + } + ] + + +@pytest.mark.asyncio +async def test_nested_m2m_values_subset_of_fields(): + async with database: + user = ( + await Role.objects.select_related("users__categories") + .filter(name="admin") + .fields({"name": ..., "users": {"name": ...}}) + .values() + ) + assert user == [ + { + "name": "admin", + "users__name": "Anonymous", + "users__categories__name": "News", + } + ]