wip - refactor of alias resolver - through models columns with fields are not properly handled yet

This commit is contained in:
collerek
2021-06-05 18:53:15 +02:00
parent b1b3d5cd92
commit 955ac48cdd
5 changed files with 358 additions and 196 deletions

View File

@ -26,6 +26,7 @@ from ormar.queryset.actions.order_action import OrderAction
from ormar.queryset.clause import FilterGroup, Prefix, QueryClause from ormar.queryset.clause import FilterGroup, Prefix, QueryClause
from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.prefetch_query import PrefetchQuery
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.reverse_alias_resolver import ReverseAliasResolver
from ormar.queryset.utils import get_relationship_alias_model_and_str from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -586,10 +587,12 @@ class QuerySet(Generic[T]):
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
if not rows: if not rows:
return [] return []
column_names = list(rows[0].keys()) alias_resolver = ReverseAliasResolver(
column_map = self._resolve_data_prefix_to_relation_str( select_related=self._select_related,
column_names=column_names excludable=self._excludable,
model_cls=self.model_cls,
) )
column_map = alias_resolver.resolve_columns(columns_names=list(rows[0].keys()))
result = [ result = [
{column_map.get(k): v for k, v in dict(x).items() if k in column_map} {column_map.get(k): v for k, v in dict(x).items() if k in column_map}
for x in rows for x in rows
@ -598,7 +601,7 @@ class QuerySet(Generic[T]):
return result return result
if _flatten and not self._excludable.include_entry_count() == 1: if _flatten and not self._excludable.include_entry_count() == 1:
raise QueryDefinitionError( raise QueryDefinitionError(
"You cannot flatten values_list if more than " "one field is selected!" "You cannot flatten values_list if more than one field is selected!"
) )
tuple_result = [tuple(x.values()) for x in result] tuple_result = [tuple(x.values()) for x in result]
return tuple_result if not _flatten else [x[0] for x in tuple_result] return tuple_result if not _flatten else [x[0] for x in tuple_result]
@ -625,50 +628,6 @@ class QuerySet(Generic[T]):
""" """
return await self.values(fields=fields, _as_dict=False, _flatten=flatten) return await self.values(fields=fields, _as_dict=False, _flatten=flatten)
def _resolve_data_prefix_to_relation_str(self, column_names: List[str]) -> Dict:
resolved_names = dict()
for column_name in column_names:
prefixes_map = self._create_prefixes_map()
column_parts = column_name.split("_")
potential_prefix = column_parts[0]
if potential_prefix in prefixes_map:
prefix = prefixes_map[potential_prefix]
allowed_columns = prefix.model_cls.own_table_columns(
model=prefix.model_cls,
excludable=self._excludable,
alias=prefix.table_prefix,
add_pk_columns=False,
)
new_column_name = "_".join(column_parts[1:])
if new_column_name in allowed_columns:
resolved_names[column_name] = f"{prefix.relation_str}__" + "_".join(
column_name.split("_")[1:]
)
else:
assert self.model_cls
allowed_columns = self.model_cls.own_table_columns(
model=self.model_cls,
excludable=self._excludable,
add_pk_columns=False,
)
if column_name in allowed_columns:
resolved_names[column_name] = column_name
return resolved_names
def _create_prefixes_map(self) -> Dict[str, Prefix]:
prefixes: List[Prefix] = []
for related in self._select_related:
related_split = related.split("__")
for index in range(len(related_split)):
prefix = Prefix(
self.model_cls, # type: ignore
*get_relationship_alias_model_and_str(
self.model_cls, related_split[0 : (index + 1)] # type: ignore
),
)
prefixes.append(prefix)
return {x.table_prefix: x for x in prefixes}
async def exists(self) -> bool: async def exists(self) -> bool:
""" """
Returns a bool value to confirm if there are rows matching the given criteria Returns a bool value to confirm if there are rows matching the given criteria

View File

@ -0,0 +1,82 @@
from typing import Dict, List, TYPE_CHECKING, Tuple, Type
from ormar.queryset.clause import Prefix
from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING:
from ormar import Model
from ormar.models.excludable import ExcludableItems
class ReverseAliasResolver:
def __init__(
self,
model_cls: Type["Model"],
excludable: "ExcludableItems",
select_related: List[str],
) -> None:
self.select_related = select_related
self.model_cls = model_cls
self.reversed_aliases = self.model_cls.Meta.alias_manager.reversed_aliases
self.excludable = excludable
def resolve_columns(self, columns_names: List[str]) -> Dict:
resolved_names = dict()
prefixes, target_models = self._create_prefixes_map()
for column_name in columns_names:
column_parts = column_name.split("_")
potential_prefix = column_parts[0]
if potential_prefix in self.reversed_aliases:
relation = self.reversed_aliases[potential_prefix]
relation_str = prefixes[relation]
target_model = target_models[relation]
allowed_columns = target_model.own_table_columns(
model=target_model,
excludable=self.excludable,
alias=potential_prefix,
add_pk_columns=False,
)
new_column_name = column_name.replace(f"{potential_prefix}_", "")
if new_column_name in allowed_columns:
resolved_names[column_name] = column_name.replace(
f"{potential_prefix}_", f"{relation_str}__"
)
else:
allowed_columns = self.model_cls.own_table_columns(
model=self.model_cls,
excludable=self.excludable,
add_pk_columns=False,
)
if column_name in allowed_columns:
resolved_names[column_name] = column_name
return resolved_names
def _create_prefixes_map(self) -> Tuple[Dict, Dict]:
prefixes: Dict = dict()
target_models: Dict = dict()
for related in self.select_related:
model_cls = self.model_cls
related_split = related.split("__")
related_str = ""
for related in related_split:
prefix_name = f"{model_cls.get_name()}_{related}"
new_related_str = (f"{related_str}__" if related_str else "") + related
prefixes[prefix_name] = new_related_str
field = model_cls.Meta.model_fields[related]
target_models[prefix_name] = field.to
if field.is_multi:
target_models[prefix_name] = field.through
new_through_str = (
f"{related_str}__" if related_str else ""
) + field.through.get_name()
prefixes[prefix_name] = new_through_str
prefix_name = (
f"{field.through.get_name()}_"
f"{field.default_target_field_name()}"
)
prefixes[prefix_name] = new_related_str
target_models[prefix_name] = field.to
model_cls = field.to
related_str = new_related_str
return prefixes, target_models

View File

@ -34,6 +34,7 @@ class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._aliases_new: Dict[str, str] = dict() self._aliases_new: Dict[str, str] = dict()
self._reversed_aliases: Dict[str, str] = dict()
def __contains__(self, item: str) -> bool: def __contains__(self, item: str) -> bool:
return self._aliases_new.__contains__(item) return self._aliases_new.__contains__(item)
@ -41,6 +42,14 @@ class AliasManager:
def __getitem__(self, key: str) -> Any: def __getitem__(self, key: str) -> Any:
return self._aliases_new.__getitem__(key) return self._aliases_new.__getitem__(key)
@property
def reversed_aliases(self):
if self._reversed_aliases:
return self._reversed_aliases
reversed_aliases = {v: k for k, v in self._aliases_new.items()}
self._reversed_aliases = reversed_aliases
return self._reversed_aliases
@staticmethod @staticmethod
def prefixed_columns( def prefixed_columns(
alias: str, table: sqlalchemy.Table, fields: List = None alias: str, table: sqlalchemy.Table, fields: List = None

View File

@ -1,3 +1,4 @@
import asyncio
import itertools import itertools
from typing import Optional, List from typing import Optional, List
@ -9,7 +10,7 @@ import sqlalchemy
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
@ -78,10 +79,16 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
@pytest.mark.asyncio @pytest.yield_fixture(scope="module")
async def test_selecting_subset(): def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def sample_data(event_loop, create_test_database):
async with database: async with database:
async with database.transaction(force_rollback=True):
nick1 = await NickNames.objects.create(name="Nippon", is_lame=False) nick1 = await NickNames.objects.create(name="Nippon", is_lame=False)
nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True) nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True)
hq = await HQ.objects.create(name="Japan") hq = await HQ.objects.create(name="Japan")
@ -115,6 +122,11 @@ async def test_selecting_subset():
aircon_type="Auto", aircon_type="Auto",
) )
@pytest.mark.asyncio
async def test_selecting_subset():
async with database:
async with database.transaction(force_rollback=True):
all_cars = ( all_cars = (
await Car.objects.select_related(["manufacturer__hq__nicks"]) await Car.objects.select_related(["manufacturer__hq__nicks"])
.fields( .fields(

View File

@ -1,3 +1,4 @@
import asyncio
from typing import List, Optional from typing import List, Optional
import databases import databases
@ -7,7 +8,7 @@ import sqlalchemy
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
@ -24,6 +25,15 @@ class User(ormar.Model):
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
class Role(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
users: List[User] = ormar.ManyToMany(User)
class Category(ormar.Model): class Category(ormar.Model):
class Meta(BaseMeta): class Meta(BaseMeta):
tablename = "categories" tablename = "categories"
@ -31,7 +41,7 @@ class Category(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40) name: str = ormar.String(max_length=40)
sort_order: int = ormar.Integer(nullable=True) sort_order: int = ormar.Integer(nullable=True)
created_by: Optional[User] = ormar.ForeignKey(User) created_by: Optional[User] = ormar.ForeignKey(User, related_name="categories")
class Post(ormar.Model): class Post(ormar.Model):
@ -51,16 +61,30 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
@pytest.mark.asyncio @pytest.yield_fixture(scope="module")
async def test_queryset_values(): def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def sample_data(event_loop, create_test_database):
async with database: async with database:
async with database.transaction(force_rollback=True):
creator = await User(name="Anonymous").save() creator = await User(name="Anonymous").save()
admin = await Role(name="admin").save()
editor = await Role(name="editor").save()
await creator.roles.add(admin)
await creator.roles.add(editor)
news = await Category(name="News", sort_order=0, created_by=creator).save() news = await Category(name="News", sort_order=0, created_by=creator).save()
await Post(name="Ormar strikes again!", category=news).save() await Post(name="Ormar strikes again!", category=news).save()
await Post(name="Why don't you use ormar yet?", category=news).save() await Post(name="Why don't you use ormar yet?", category=news).save()
await Post(name="Check this out, ormar now for free", category=news).save() await Post(name="Check this out, ormar now for free", category=news).save()
@pytest.mark.asyncio
async def test_simple_queryset_values():
async with database:
posts = await Post.objects.values() posts = await Post.objects.values()
assert posts == [ assert posts == [
{"id": 1, "name": "Ormar strikes again!", "category": 1}, {"id": 1, "name": "Ormar strikes again!", "category": 1},
@ -68,6 +92,10 @@ async def test_queryset_values():
{"id": 3, "name": "Check this out, ormar now for free", "category": 1}, {"id": 3, "name": "Check this out, ormar now for free", "category": 1},
] ]
@pytest.mark.asyncio
async def test_queryset_values_nested_relation():
async with database:
posts = await Post.objects.select_related("category__created_by").values() posts = await Post.objects.select_related("category__created_by").values()
assert posts == [ assert posts == [
{ {
@ -105,6 +133,10 @@ async def test_queryset_values():
}, },
] ]
@pytest.mark.asyncio
async def test_queryset_values_nested_relation_subset_of_fields():
async with database:
posts = await Post.objects.select_related("category__created_by").values( posts = await Post.objects.select_related("category__created_by").values(
["name", "category__name", "category__created_by__name"] ["name", "category__name", "category__created_by__name"]
) )
@ -128,15 +160,8 @@ async def test_queryset_values():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_queryset_values_list(): async def test_queryset_simple_values_list():
async with database: async with database:
async with database.transaction(force_rollback=True):
creator = await User(name="Anonymous").save()
news = await Category(name="News", sort_order=0, created_by=creator).save()
await Post(name="Ormar strikes again!", category=news).save()
await Post(name="Why don't you use ormar yet?", category=news).save()
await Post(name="Check this out, ormar now for free", category=news).save()
posts = await Post.objects.values_list() posts = await Post.objects.values_list()
assert posts == [ assert posts == [
(1, "Ormar strikes again!", 1), (1, "Ormar strikes again!", 1),
@ -144,9 +169,11 @@ async def test_queryset_values_list():
(3, "Check this out, ormar now for free", 1), (3, "Check this out, ormar now for free", 1),
] ]
posts = await Post.objects.select_related(
"category__created_by" @pytest.mark.asyncio
).values_list() async def test_queryset_nested_relation_values_list():
async with database:
posts = await Post.objects.select_related("category__created_by").values_list()
assert posts == [ assert posts == [
(1, "Ormar strikes again!", 1, 1, "News", 0, 1, 1, "Anonymous"), (1, "Ormar strikes again!", 1, 1, "News", 0, 1, 1, "Anonymous"),
(2, "Why don't you use ormar yet?", 1, 1, "News", 0, 1, 1, "Anonymous"), (2, "Why don't you use ormar yet?", 1, 1, "News", 0, 1, 1, "Anonymous"),
@ -163,11 +190,84 @@ async def test_queryset_values_list():
), ),
] ]
posts = await Post.objects.select_related(
"category__created_by" @pytest.mark.asyncio
).values_list(["name", "category__name", "category__created_by__name"]) async def test_queryset_nested_relation_subset_of_fields_values_list():
async with database:
posts = await Post.objects.select_related("category__created_by").values_list(
["name", "category__name", "category__created_by__name"]
)
assert posts == [ assert posts == [
("Ormar strikes again!", "News", "Anonymous"), ("Ormar strikes again!", "News", "Anonymous"),
("Why don't you use ormar yet?", "News", "Anonymous"), ("Why don't you use ormar yet?", "News", "Anonymous"),
("Check this out, ormar now for free", "News", "Anonymous"), ("Check this out, ormar now for free", "News", "Anonymous"),
] ]
@pytest.mark.asyncio
async def test_m2m_values():
async with database:
user = await User.objects.select_related("roles").values()
assert user == [
{
"id": 1,
"name": "Anonymous",
"roleuser__id": 1,
"roleuser__role": 1,
"roleuser__user": 1,
"roles__id": 1,
"roles__name": "admin",
},
{
"id": 1,
"name": "Anonymous",
"roleuser__id": 2,
"roleuser__role": 2,
"roleuser__user": 1,
"roles__id": 2,
"roles__name": "editor",
},
]
@pytest.mark.asyncio
async def test_nested_m2m_values():
async with database:
user = (
await Role.objects.select_related("users__categories")
.filter(name="admin")
.values()
)
assert user == [
{
"id": 1,
"name": "admin",
"roleuser__id": 1,
"roleuser__role": 1,
"roleuser__user": 1,
"users__id": 1,
"users__name": "Anonymous",
"users__categories__id": 1,
"users__categories__name": "News",
"users__categories__sort_order": 0,
"users__categories__created_by": 1,
}
]
@pytest.mark.asyncio
async def test_nested_m2m_values_subset_of_fields():
async with database:
user = (
await Role.objects.select_related("users__categories")
.filter(name="admin")
.fields({"name": ..., "users": {"name": ...}})
.values()
)
assert user == [
{
"name": "admin",
"users__name": "Anonymous",
"users__categories__name": "News",
}
]