wip - refactor of alias resolver - through models columns with fields are not properly handled yet

This commit is contained in:
collerek
2021-06-05 18:53:15 +02:00
parent b1b3d5cd92
commit 955ac48cdd
5 changed files with 358 additions and 196 deletions

View File

@ -26,6 +26,7 @@ from ormar.queryset.actions.order_action import OrderAction
from ormar.queryset.clause import FilterGroup, Prefix, QueryClause from ormar.queryset.clause import FilterGroup, Prefix, QueryClause
from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.prefetch_query import PrefetchQuery
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.reverse_alias_resolver import ReverseAliasResolver
from ormar.queryset.utils import get_relationship_alias_model_and_str from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -586,10 +587,12 @@ class QuerySet(Generic[T]):
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
if not rows: if not rows:
return [] return []
column_names = list(rows[0].keys()) alias_resolver = ReverseAliasResolver(
column_map = self._resolve_data_prefix_to_relation_str( select_related=self._select_related,
column_names=column_names excludable=self._excludable,
model_cls=self.model_cls,
) )
column_map = alias_resolver.resolve_columns(columns_names=list(rows[0].keys()))
result = [ result = [
{column_map.get(k): v for k, v in dict(x).items() if k in column_map} {column_map.get(k): v for k, v in dict(x).items() if k in column_map}
for x in rows for x in rows
@ -598,7 +601,7 @@ class QuerySet(Generic[T]):
return result return result
if _flatten and not self._excludable.include_entry_count() == 1: if _flatten and not self._excludable.include_entry_count() == 1:
raise QueryDefinitionError( raise QueryDefinitionError(
"You cannot flatten values_list if more than " "one field is selected!" "You cannot flatten values_list if more than one field is selected!"
) )
tuple_result = [tuple(x.values()) for x in result] tuple_result = [tuple(x.values()) for x in result]
return tuple_result if not _flatten else [x[0] for x in tuple_result] return tuple_result if not _flatten else [x[0] for x in tuple_result]
@ -625,50 +628,6 @@ class QuerySet(Generic[T]):
""" """
return await self.values(fields=fields, _as_dict=False, _flatten=flatten) return await self.values(fields=fields, _as_dict=False, _flatten=flatten)
def _resolve_data_prefix_to_relation_str(self, column_names: List[str]) -> Dict:
resolved_names = dict()
for column_name in column_names:
prefixes_map = self._create_prefixes_map()
column_parts = column_name.split("_")
potential_prefix = column_parts[0]
if potential_prefix in prefixes_map:
prefix = prefixes_map[potential_prefix]
allowed_columns = prefix.model_cls.own_table_columns(
model=prefix.model_cls,
excludable=self._excludable,
alias=prefix.table_prefix,
add_pk_columns=False,
)
new_column_name = "_".join(column_parts[1:])
if new_column_name in allowed_columns:
resolved_names[column_name] = f"{prefix.relation_str}__" + "_".join(
column_name.split("_")[1:]
)
else:
assert self.model_cls
allowed_columns = self.model_cls.own_table_columns(
model=self.model_cls,
excludable=self._excludable,
add_pk_columns=False,
)
if column_name in allowed_columns:
resolved_names[column_name] = column_name
return resolved_names
def _create_prefixes_map(self) -> Dict[str, Prefix]:
prefixes: List[Prefix] = []
for related in self._select_related:
related_split = related.split("__")
for index in range(len(related_split)):
prefix = Prefix(
self.model_cls, # type: ignore
*get_relationship_alias_model_and_str(
self.model_cls, related_split[0 : (index + 1)] # type: ignore
),
)
prefixes.append(prefix)
return {x.table_prefix: x for x in prefixes}
async def exists(self) -> bool: async def exists(self) -> bool:
""" """
Returns a bool value to confirm if there are rows matching the given criteria Returns a bool value to confirm if there are rows matching the given criteria

View File

@ -0,0 +1,82 @@
from typing import Dict, List, TYPE_CHECKING, Tuple, Type
from ormar.queryset.clause import Prefix
from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING:
from ormar import Model
from ormar.models.excludable import ExcludableItems
class ReverseAliasResolver:
def __init__(
self,
model_cls: Type["Model"],
excludable: "ExcludableItems",
select_related: List[str],
) -> None:
self.select_related = select_related
self.model_cls = model_cls
self.reversed_aliases = self.model_cls.Meta.alias_manager.reversed_aliases
self.excludable = excludable
def resolve_columns(self, columns_names: List[str]) -> Dict:
resolved_names = dict()
prefixes, target_models = self._create_prefixes_map()
for column_name in columns_names:
column_parts = column_name.split("_")
potential_prefix = column_parts[0]
if potential_prefix in self.reversed_aliases:
relation = self.reversed_aliases[potential_prefix]
relation_str = prefixes[relation]
target_model = target_models[relation]
allowed_columns = target_model.own_table_columns(
model=target_model,
excludable=self.excludable,
alias=potential_prefix,
add_pk_columns=False,
)
new_column_name = column_name.replace(f"{potential_prefix}_", "")
if new_column_name in allowed_columns:
resolved_names[column_name] = column_name.replace(
f"{potential_prefix}_", f"{relation_str}__"
)
else:
allowed_columns = self.model_cls.own_table_columns(
model=self.model_cls,
excludable=self.excludable,
add_pk_columns=False,
)
if column_name in allowed_columns:
resolved_names[column_name] = column_name
return resolved_names
def _create_prefixes_map(self) -> Tuple[Dict, Dict]:
prefixes: Dict = dict()
target_models: Dict = dict()
for related in self.select_related:
model_cls = self.model_cls
related_split = related.split("__")
related_str = ""
for related in related_split:
prefix_name = f"{model_cls.get_name()}_{related}"
new_related_str = (f"{related_str}__" if related_str else "") + related
prefixes[prefix_name] = new_related_str
field = model_cls.Meta.model_fields[related]
target_models[prefix_name] = field.to
if field.is_multi:
target_models[prefix_name] = field.through
new_through_str = (
f"{related_str}__" if related_str else ""
) + field.through.get_name()
prefixes[prefix_name] = new_through_str
prefix_name = (
f"{field.through.get_name()}_"
f"{field.default_target_field_name()}"
)
prefixes[prefix_name] = new_related_str
target_models[prefix_name] = field.to
model_cls = field.to
related_str = new_related_str
return prefixes, target_models

View File

@ -34,6 +34,7 @@ class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._aliases_new: Dict[str, str] = dict() self._aliases_new: Dict[str, str] = dict()
self._reversed_aliases: Dict[str, str] = dict()
def __contains__(self, item: str) -> bool: def __contains__(self, item: str) -> bool:
return self._aliases_new.__contains__(item) return self._aliases_new.__contains__(item)
@ -41,6 +42,14 @@ class AliasManager:
def __getitem__(self, key: str) -> Any: def __getitem__(self, key: str) -> Any:
return self._aliases_new.__getitem__(key) return self._aliases_new.__getitem__(key)
@property
def reversed_aliases(self):
if self._reversed_aliases:
return self._reversed_aliases
reversed_aliases = {v: k for k, v in self._aliases_new.items()}
self._reversed_aliases = reversed_aliases
return self._reversed_aliases
@staticmethod @staticmethod
def prefixed_columns( def prefixed_columns(
alias: str, table: sqlalchemy.Table, fields: List = None alias: str, table: sqlalchemy.Table, fields: List = None

View File

@ -1,3 +1,4 @@
import asyncio
import itertools import itertools
from typing import Optional, List from typing import Optional, List
@ -9,7 +10,7 @@ import sqlalchemy
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
@ -78,43 +79,54 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
@pytest.yield_fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def sample_data(event_loop, create_test_database):
async with database:
nick1 = await NickNames.objects.create(name="Nippon", is_lame=False)
nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True)
hq = await HQ.objects.create(name="Japan")
await hq.nicks.add(nick1)
await hq.nicks.add(nick2)
toyota = await Company.objects.create(name="Toyota", founded=1937, hq=hq)
await Car.objects.create(
manufacturer=toyota,
name="Corolla",
year=2020,
gearbox_type="Manual",
gears=5,
aircon_type="Manual",
)
await Car.objects.create(
manufacturer=toyota,
name="Yaris",
year=2019,
gearbox_type="Manual",
gears=5,
aircon_type="Manual",
)
await Car.objects.create(
manufacturer=toyota,
name="Supreme",
year=2020,
gearbox_type="Auto",
gears=6,
aircon_type="Auto",
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_selecting_subset(): async def test_selecting_subset():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
nick1 = await NickNames.objects.create(name="Nippon", is_lame=False)
nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True)
hq = await HQ.objects.create(name="Japan")
await hq.nicks.add(nick1)
await hq.nicks.add(nick2)
toyota = await Company.objects.create(name="Toyota", founded=1937, hq=hq)
await Car.objects.create(
manufacturer=toyota,
name="Corolla",
year=2020,
gearbox_type="Manual",
gears=5,
aircon_type="Manual",
)
await Car.objects.create(
manufacturer=toyota,
name="Yaris",
year=2019,
gearbox_type="Manual",
gears=5,
aircon_type="Manual",
)
await Car.objects.create(
manufacturer=toyota,
name="Supreme",
year=2020,
gearbox_type="Auto",
gears=6,
aircon_type="Auto",
)
all_cars = ( all_cars = (
await Car.objects.select_related(["manufacturer__hq__nicks"]) await Car.objects.select_related(["manufacturer__hq__nicks"])
.fields( .fields(

View File

@ -1,3 +1,4 @@
import asyncio
from typing import List, Optional from typing import List, Optional
import databases import databases
@ -7,7 +8,7 @@ import sqlalchemy
import ormar import ormar
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
@ -24,6 +25,15 @@ class User(ormar.Model):
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
class Role(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
users: List[User] = ormar.ManyToMany(User)
class Category(ormar.Model): class Category(ormar.Model):
class Meta(BaseMeta): class Meta(BaseMeta):
tablename = "categories" tablename = "categories"
@ -31,7 +41,7 @@ class Category(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40) name: str = ormar.String(max_length=40)
sort_order: int = ormar.Integer(nullable=True) sort_order: int = ormar.Integer(nullable=True)
created_by: Optional[User] = ormar.ForeignKey(User) created_by: Optional[User] = ormar.ForeignKey(User, related_name="categories")
class Post(ormar.Model): class Post(ormar.Model):
@ -51,123 +61,213 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
@pytest.mark.asyncio @pytest.yield_fixture(scope="module")
async def test_queryset_values(): def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def sample_data(event_loop, create_test_database):
async with database: async with database:
async with database.transaction(force_rollback=True): creator = await User(name="Anonymous").save()
creator = await User(name="Anonymous").save() admin = await Role(name="admin").save()
news = await Category(name="News", sort_order=0, created_by=creator).save() editor = await Role(name="editor").save()
await Post(name="Ormar strikes again!", category=news).save() await creator.roles.add(admin)
await Post(name="Why don't you use ormar yet?", category=news).save() await creator.roles.add(editor)
await Post(name="Check this out, ormar now for free", category=news).save() news = await Category(name="News", sort_order=0, created_by=creator).save()
await Post(name="Ormar strikes again!", category=news).save()
posts = await Post.objects.values() await Post(name="Why don't you use ormar yet?", category=news).save()
assert posts == [ await Post(name="Check this out, ormar now for free", category=news).save()
{"id": 1, "name": "Ormar strikes again!", "category": 1},
{"id": 2, "name": "Why don't you use ormar yet?", "category": 1},
{"id": 3, "name": "Check this out, ormar now for free", "category": 1},
]
posts = await Post.objects.select_related("category__created_by").values()
assert posts == [
{
"id": 1,
"name": "Ormar strikes again!",
"category": 1,
"category__id": 1,
"category__name": "News",
"category__sort_order": 0,
"category__created_by": 1,
"category__created_by__id": 1,
"category__created_by__name": "Anonymous",
},
{
"category": 1,
"id": 2,
"name": "Why don't you use ormar yet?",
"category__id": 1,
"category__name": "News",
"category__sort_order": 0,
"category__created_by": 1,
"category__created_by__id": 1,
"category__created_by__name": "Anonymous",
},
{
"id": 3,
"name": "Check this out, ormar now for free",
"category": 1,
"category__id": 1,
"category__name": "News",
"category__sort_order": 0,
"category__created_by": 1,
"category__created_by__id": 1,
"category__created_by__name": "Anonymous",
},
]
posts = await Post.objects.select_related("category__created_by").values(
["name", "category__name", "category__created_by__name"]
)
assert posts == [
{
"name": "Ormar strikes again!",
"category__name": "News",
"category__created_by__name": "Anonymous",
},
{
"name": "Why don't you use ormar yet?",
"category__name": "News",
"category__created_by__name": "Anonymous",
},
{
"name": "Check this out, ormar now for free",
"category__name": "News",
"category__created_by__name": "Anonymous",
},
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_queryset_values_list(): async def test_simple_queryset_values():
async with database: async with database:
async with database.transaction(force_rollback=True): posts = await Post.objects.values()
creator = await User(name="Anonymous").save() assert posts == [
news = await Category(name="News", sort_order=0, created_by=creator).save() {"id": 1, "name": "Ormar strikes again!", "category": 1},
await Post(name="Ormar strikes again!", category=news).save() {"id": 2, "name": "Why don't you use ormar yet?", "category": 1},
await Post(name="Why don't you use ormar yet?", category=news).save() {"id": 3, "name": "Check this out, ormar now for free", "category": 1},
await Post(name="Check this out, ormar now for free", category=news).save() ]
posts = await Post.objects.values_list()
assert posts == [
(1, "Ormar strikes again!", 1),
(2, "Why don't you use ormar yet?", 1),
(3, "Check this out, ormar now for free", 1),
]
posts = await Post.objects.select_related( @pytest.mark.asyncio
"category__created_by" async def test_queryset_values_nested_relation():
).values_list() async with database:
assert posts == [ posts = await Post.objects.select_related("category__created_by").values()
(1, "Ormar strikes again!", 1, 1, "News", 0, 1, 1, "Anonymous"), assert posts == [
(2, "Why don't you use ormar yet?", 1, 1, "News", 0, 1, 1, "Anonymous"), {
( "id": 1,
3, "name": "Ormar strikes again!",
"Check this out, ormar now for free", "category": 1,
1, "category__id": 1,
1, "category__name": "News",
"News", "category__sort_order": 0,
0, "category__created_by": 1,
1, "category__created_by__id": 1,
1, "category__created_by__name": "Anonymous",
"Anonymous", },
), {
] "category": 1,
"id": 2,
"name": "Why don't you use ormar yet?",
"category__id": 1,
"category__name": "News",
"category__sort_order": 0,
"category__created_by": 1,
"category__created_by__id": 1,
"category__created_by__name": "Anonymous",
},
{
"id": 3,
"name": "Check this out, ormar now for free",
"category": 1,
"category__id": 1,
"category__name": "News",
"category__sort_order": 0,
"category__created_by": 1,
"category__created_by__id": 1,
"category__created_by__name": "Anonymous",
},
]
posts = await Post.objects.select_related(
"category__created_by" @pytest.mark.asyncio
).values_list(["name", "category__name", "category__created_by__name"]) async def test_queryset_values_nested_relation_subset_of_fields():
assert posts == [ async with database:
("Ormar strikes again!", "News", "Anonymous"), posts = await Post.objects.select_related("category__created_by").values(
("Why don't you use ormar yet?", "News", "Anonymous"), ["name", "category__name", "category__created_by__name"]
("Check this out, ormar now for free", "News", "Anonymous"), )
] assert posts == [
{
"name": "Ormar strikes again!",
"category__name": "News",
"category__created_by__name": "Anonymous",
},
{
"name": "Why don't you use ormar yet?",
"category__name": "News",
"category__created_by__name": "Anonymous",
},
{
"name": "Check this out, ormar now for free",
"category__name": "News",
"category__created_by__name": "Anonymous",
},
]
@pytest.mark.asyncio
async def test_queryset_simple_values_list():
async with database:
posts = await Post.objects.values_list()
assert posts == [
(1, "Ormar strikes again!", 1),
(2, "Why don't you use ormar yet?", 1),
(3, "Check this out, ormar now for free", 1),
]
@pytest.mark.asyncio
async def test_queryset_nested_relation_values_list():
async with database:
posts = await Post.objects.select_related("category__created_by").values_list()
assert posts == [
(1, "Ormar strikes again!", 1, 1, "News", 0, 1, 1, "Anonymous"),
(2, "Why don't you use ormar yet?", 1, 1, "News", 0, 1, 1, "Anonymous"),
(
3,
"Check this out, ormar now for free",
1,
1,
"News",
0,
1,
1,
"Anonymous",
),
]
@pytest.mark.asyncio
async def test_queryset_nested_relation_subset_of_fields_values_list():
async with database:
posts = await Post.objects.select_related("category__created_by").values_list(
["name", "category__name", "category__created_by__name"]
)
assert posts == [
("Ormar strikes again!", "News", "Anonymous"),
("Why don't you use ormar yet?", "News", "Anonymous"),
("Check this out, ormar now for free", "News", "Anonymous"),
]
@pytest.mark.asyncio
async def test_m2m_values():
async with database:
user = await User.objects.select_related("roles").values()
assert user == [
{
"id": 1,
"name": "Anonymous",
"roleuser__id": 1,
"roleuser__role": 1,
"roleuser__user": 1,
"roles__id": 1,
"roles__name": "admin",
},
{
"id": 1,
"name": "Anonymous",
"roleuser__id": 2,
"roleuser__role": 2,
"roleuser__user": 1,
"roles__id": 2,
"roles__name": "editor",
},
]
@pytest.mark.asyncio
async def test_nested_m2m_values():
async with database:
user = (
await Role.objects.select_related("users__categories")
.filter(name="admin")
.values()
)
assert user == [
{
"id": 1,
"name": "admin",
"roleuser__id": 1,
"roleuser__role": 1,
"roleuser__user": 1,
"users__id": 1,
"users__name": "Anonymous",
"users__categories__id": 1,
"users__categories__name": "News",
"users__categories__sort_order": 0,
"users__categories__created_by": 1,
}
]
@pytest.mark.asyncio
async def test_nested_m2m_values_subset_of_fields():
async with database:
user = (
await Role.objects.select_related("users__categories")
.filter(name="admin")
.fields({"name": ..., "users": {"name": ...}})
.values()
)
assert user == [
{
"name": "admin",
"users__name": "Anonymous",
"users__categories__name": "News",
}
]