start to refactor fields and eclude_fields into ExcludableItems to simplify access

This commit is contained in:
collerek
2021-02-26 17:47:52 +01:00
parent 7bf781098f
commit ad9d065c6d
9 changed files with 434 additions and 20 deletions

162
ormar/models/excludable.py Normal file
View File

@ -0,0 +1,162 @@
from dataclasses import dataclass, field
from typing import Dict, List, Set, TYPE_CHECKING, Tuple, Type, Union
from ormar.queryset.utils import get_relationship_alias_model_and_str
if TYPE_CHECKING: # pragma: no cover
from ormar import Model
@dataclass
class Excludable:
include: Set = field(default_factory=set)
exclude: Set = field(default_factory=set)
def set_values(self, value: Set, is_exclude: bool) -> None:
prop = "exclude" if is_exclude else "include"
if ... in getattr(self, prop) or ... in value:
setattr(self, prop, {...})
else:
current_value = getattr(self, prop)
current_value.update(value)
setattr(self, prop, current_value)
def is_included(self, key: str) -> bool:
return (... in self.include or key in self.include) if self.include else True
def is_excluded(self, key: str) -> bool:
return (... in self.exclude or key in self.exclude) if self.exclude else False
class ExcludableItems:
"""
Keeps a dictionary of Excludables by alias + model_name keys
to allow quick lookup by nested models without need to travers
deeply nested dictionaries and passing include/exclude around
"""
def __init__(self) -> None:
self.items: Dict[str, Excludable] = dict()
def get(self, model_cls: Type["Model"], alias: str = "") -> Excludable:
key = f"{alias + '_' if alias else ''}{model_cls.get_name(lower=True)}"
return self.items.get(key, Excludable())
def build(
self,
items: Union[List[str], str, Tuple[str], Set[str], Dict],
model_cls: Type["Model"],
is_exclude: bool = False,
) -> None:
if isinstance(items, str):
items = {items}
if isinstance(items, Dict):
self._traverse_dict(
values=items,
source_model=model_cls,
model_cls=model_cls,
is_exclude=is_exclude,
)
else:
items = set(items)
nested_items = set(x for x in items if "__" in x)
items.difference_update(nested_items)
self._set_excludes(
items=items,
model_name=model_cls.get_name(lower=True),
is_exclude=is_exclude,
)
if nested_items:
self._traverse_list(
values=nested_items, model_cls=model_cls, is_exclude=is_exclude
)
def _set_excludes(
self, items: Set, model_name: str, is_exclude: bool, alias: str = ""
) -> None:
key = f"{alias + '_' if alias else ''}{model_name}"
excludable = self.items.get(key)
if not excludable:
excludable = Excludable()
excludable.set_values(value=items, is_exclude=is_exclude)
self.items[key] = excludable
def _traverse_dict( # noqa: CFQ002
self,
values: Dict,
source_model: Type["Model"],
model_cls: Type["Model"],
is_exclude: bool,
related_items: List = None,
alias: str = "",
) -> None:
self_fields = set()
related_items = related_items[:] if related_items else []
for key, value in values.items():
if value is ...:
self_fields.add(key)
elif isinstance(value, set):
related_items.append(key)
(
table_prefix,
target_model,
_,
_,
) = get_relationship_alias_model_and_str(
source_model=source_model, related_parts=related_items
)
self._set_excludes(
items=value,
model_name=target_model.get_name(),
is_exclude=is_exclude,
alias=table_prefix,
)
else:
# dict
related_items.append(key)
(
table_prefix,
target_model,
_,
_,
) = get_relationship_alias_model_and_str(
source_model=source_model, related_parts=related_items
)
self._traverse_dict(
values=value,
source_model=source_model,
model_cls=target_model,
is_exclude=is_exclude,
related_items=related_items,
alias=table_prefix,
)
if self_fields:
self._set_excludes(
items=self_fields,
model_name=model_cls.get_name(),
is_exclude=is_exclude,
alias=alias,
)
def _traverse_list(
self, values: Set[str], model_cls: Type["Model"], is_exclude: bool
) -> None:
# here we have only nested related keys
for key in values:
key_split = key.split("__")
related_items, field_name = key_split[:-1], key_split[-1]
(table_prefix, target_model, _, _) = get_relationship_alias_model_and_str(
source_model=model_cls, related_parts=related_items
)
self._set_excludes(
items={field_name},
model_name=target_model.get_name(),
is_exclude=is_exclude,
alias=table_prefix,
)

View File

@ -42,6 +42,9 @@ class FilterAction(QueryAction):
super().__init__(query_str=filter_str, model_cls=model_cls) super().__init__(query_str=filter_str, model_cls=model_cls)
self.filter_value = value self.filter_value = value
self._escape_characters_in_clause() self._escape_characters_in_clause()
self.is_source_model_filter = False
if self.source_model == self.target_model and "__" not in self.related_str:
self.is_source_model_filter = True
def has_escaped_characters(self) -> bool: def has_escaped_characters(self) -> bool:
"""Check if value is a string that contains characters to escape""" """Check if value is a string that contains characters to escape"""

View File

@ -177,12 +177,12 @@ class Query:
filters_to_use = [ filters_to_use = [
filter_clause filter_clause
for filter_clause in self.filter_clauses for filter_clause in self.filter_clauses
if filter_clause.table_prefix == "" if filter_clause.is_source_model_filter
] ]
excludes_to_use = [ excludes_to_use = [
filter_clause filter_clause
for filter_clause in self.exclude_clauses for filter_clause in self.exclude_clauses
if filter_clause.table_prefix == "" if filter_clause.is_source_model_filter
] ]
sorts_to_use = { sorts_to_use = {
k: v for k, v in self.sorted_orders.items() if k.is_source_model_order k: v for k, v in self.sorted_orders.items() if k.is_source_model_order

View File

@ -410,6 +410,7 @@ class QuerySet(Generic[T]):
if isinstance(columns, str): if isinstance(columns, str):
columns = [columns] columns = [columns]
# TODO: Flatten all excludes into one dict-like structure with alias + model key
current_included = self._columns current_included = self._columns
if not isinstance(columns, dict): if not isinstance(columns, dict):
current_included = update_dict_from_list(current_included, columns) current_included = update_dict_from_list(current_included, columns)

View File

@ -230,12 +230,12 @@ def get_relationship_alias_model_and_str(
""" """
table_prefix = "" table_prefix = ""
is_through = False is_through = False
model_cls = source_model target_model = source_model
previous_model = model_cls previous_model = target_model
previous_models = [model_cls] previous_models = [target_model]
manager = model_cls.Meta.alias_manager manager = target_model.Meta.alias_manager
for relation in related_parts[:]: for relation in related_parts[:]:
related_field = model_cls.Meta.model_fields[relation] related_field = target_model.Meta.model_fields[relation]
if related_field.is_through: if related_field.is_through:
# through is always last - cannot go further # through is always last - cannot go further
@ -256,10 +256,10 @@ def get_relationship_alias_model_and_str(
table_prefix = manager.resolve_relation_alias( table_prefix = manager.resolve_relation_alias(
from_model=previous_model, relation_name=relation from_model=previous_model, relation_name=relation
) )
model_cls = related_field.to target_model = related_field.to
previous_model = model_cls previous_model = target_model
if not is_through: if not is_through:
previous_models.append(previous_model) previous_models.append(previous_model)
relation_str = "__".join(related_parts) relation_str = "__".join(related_parts)
return table_prefix, model_cls, relation_str, is_through return table_prefix, target_model, relation_str, is_through

View File

@ -330,13 +330,12 @@ class QuerysetProxy(Generic[T]):
through_kwargs = kwargs.pop(self.through_model_name, {}) through_kwargs = kwargs.pop(self.through_model_name, {})
children = await self.queryset.all() children = await self.queryset.all()
for child in children: for child in children:
if child: await child.update(**kwargs) # type: ignore
await child.update(**kwargs) if self.type_ == ormar.RelationType.MULTIPLE and through_kwargs:
if self.type_ == ormar.RelationType.MULTIPLE and through_kwargs: await self.update_through_instance(
await self.update_through_instance( child=child, # type: ignore
child=child, # type: ignore **through_kwargs,
**through_kwargs, )
)
return len(children) return len(children)
async def get_or_create(self, **kwargs: Any) -> "T": async def get_or_create(self, **kwargs: Any) -> "T":

View File

@ -0,0 +1,218 @@
from typing import List, Optional
import databases
import sqlalchemy
import ormar
from ormar.models.excludable import ExcludableItems
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
database = database
metadata = metadata
class NickNames(ormar.Model):
class Meta(BaseMeta):
tablename = "nicks"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
is_lame: bool = ormar.Boolean(nullable=True)
class NicksHq(ormar.Model):
class Meta(BaseMeta):
tablename = "nicks_x_hq"
class HQ(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq)
class Company(ormar.Model):
class Meta(BaseMeta):
tablename = "companies"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
founded: int = ormar.Integer(nullable=True)
hq: HQ = ormar.ForeignKey(HQ)
class Car(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
manufacturer: Optional[Company] = ormar.ForeignKey(Company)
name: str = ormar.String(max_length=100)
year: int = ormar.Integer(nullable=True)
gearbox_type: str = ormar.String(max_length=20, nullable=True)
gears: int = ormar.Integer(nullable=True)
aircon_type: str = ormar.String(max_length=20, nullable=True)
def compare_results(excludable):
car_excludable = excludable.get(Car)
assert car_excludable.exclude == {"year", "gearbox_type", "gears", "aircon_type"}
assert car_excludable.include == set()
assert car_excludable.is_excluded("year")
alias = Company.Meta.alias_manager.resolve_relation_alias(Car, "manufacturer")
manu_excludable = excludable.get(Company, alias=alias)
assert manu_excludable.exclude == {"founded"}
assert manu_excludable.include == set()
assert manu_excludable.is_excluded("founded")
def compare_results_include(excludable):
manager = Company.Meta.alias_manager
car_excludable = excludable.get(Car)
assert car_excludable.include == {"id", "name"}
assert car_excludable.exclude == set()
assert car_excludable.is_included("name")
assert not car_excludable.is_included("gears")
alias = manager.resolve_relation_alias(Car, "manufacturer")
manu_excludable = excludable.get(Company, alias=alias)
assert manu_excludable.include == {"name"}
assert manu_excludable.exclude == set()
assert manu_excludable.is_included("name")
assert not manu_excludable.is_included("founded")
alias = manager.resolve_relation_alias(Company, "hq")
hq_excludable = excludable.get(HQ, alias=alias)
assert hq_excludable.include == {"name"}
assert hq_excludable.exclude == set()
alias = manager.resolve_relation_alias(NicksHq, "nicknames")
nick_excludable = excludable.get(NickNames, alias=alias)
assert nick_excludable.include == {"name"}
assert nick_excludable.exclude == set()
def test_excluding_fields_from_list():
fields = [
"gearbox_type",
"gears",
"aircon_type",
"year",
"manufacturer__founded",
]
excludable = ExcludableItems()
excludable.build(items=fields, model_cls=Car, is_exclude=True)
compare_results(excludable)
def test_excluding_fields_from_dict():
fields = {
"gearbox_type": ...,
"gears": ...,
"aircon_type": ...,
"year": ...,
"manufacturer": {"founded": ...},
}
excludable = ExcludableItems()
excludable.build(items=fields, model_cls=Car, is_exclude=True)
compare_results(excludable)
def test_excluding_fields_from_dict_with_set():
fields = {
"gearbox_type": ...,
"gears": ...,
"aircon_type": ...,
"year": ...,
"manufacturer": {"founded"},
}
excludable = ExcludableItems()
excludable.build(items=fields, model_cls=Car, is_exclude=True)
compare_results(excludable)
def test_gradual_build_from_lists():
fields_col = [
"year",
["gearbox_type", "gears"],
"aircon_type",
["manufacturer__founded"],
]
excludable = ExcludableItems()
for fields in fields_col:
excludable.build(items=fields, model_cls=Car, is_exclude=True)
compare_results(excludable)
def test_nested_includes():
fields = [
"id",
"name",
"manufacturer__name",
"manufacturer__hq__name",
"manufacturer__hq__nicks__name",
]
excludable = ExcludableItems()
excludable.build(items=fields, model_cls=Car, is_exclude=False)
compare_results_include(excludable)
def test_nested_includes_from_dict():
fields = {
"id": ...,
"name": ...,
"manufacturer": {"name": ..., "hq": {"name": ..., "nicks": {"name": ...}},},
}
excludable = ExcludableItems()
excludable.build(items=fields, model_cls=Car, is_exclude=False)
compare_results_include(excludable)
def test_nested_includes_from_dict_with_set():
fields = {
"id": ...,
"name": ...,
"manufacturer": {"name": ..., "hq": {"name": ..., "nicks": {"name"}},},
}
excludable = ExcludableItems()
excludable.build(items=fields, model_cls=Car, is_exclude=False)
compare_results_include(excludable)
def test_includes_and_excludes_combo():
fields_inc1 = ["id", "name", "year", "gearbox_type", "gears"]
fields_inc2 = {"manufacturer": {"name"}}
fields_exc1 = {"manufacturer__founded"}
fields_exc2 = "aircon_type"
excludable = ExcludableItems()
excludable.build(items=fields_inc1, model_cls=Car, is_exclude=False)
excludable.build(items=fields_inc2, model_cls=Car, is_exclude=False)
excludable.build(items=fields_exc1, model_cls=Car, is_exclude=True)
excludable.build(items=fields_exc2, model_cls=Car, is_exclude=True)
car_excludable = excludable.get(Car)
assert car_excludable.include == {"id", "name", "year", "gearbox_type", "gears"}
assert car_excludable.exclude == {"aircon_type"}
assert car_excludable.is_excluded("aircon_type")
assert car_excludable.is_included("name")
alias = Company.Meta.alias_manager.resolve_relation_alias(Car, "manufacturer")
manu_excludable = excludable.get(Company, alias=alias)
assert manu_excludable.include == {"name"}
assert manu_excludable.exclude == {"founded"}
assert manu_excludable.is_excluded("founded")

View File

@ -1,4 +1,4 @@
from typing import Any, TYPE_CHECKING from typing import Any
import databases import databases
import pytest import pytest
@ -293,6 +293,37 @@ async def test_update_through_from_related() -> Any:
assert post2.categories[2].postcategory.sort_order == 4 assert post2.categories[2].postcategory.sort_order == 4
@pytest.mark.asyncio
@pytest.mark.skip # TODO: Restore after finished exclude refactor
async def test_excluding_fields_on_through_model() -> Any:
async with database:
post = await Post(title="Test post").save()
await post.categories.create(
name="Test category1",
postcategory={"sort_order": 2, "param_name": "volume"},
)
await post.categories.create(
name="Test category2", postcategory={"sort_order": 1, "param_name": "area"}
)
await post.categories.create(
name="Test category3",
postcategory={"sort_order": 3, "param_name": "velocity"},
)
post2 = (
await Post.objects.select_related("categories")
.exclude_fields("postcategory__param_name")
.order_by("postcategory__sort_order")
.get()
)
assert len(post2.categories) == 3
assert post2.categories[0].postcategory.param_name is None
assert post2.categories[0].postcategory.sort_order == 1
assert post2.categories[2].postcategory.param_name is None
assert post2.categories[2].postcategory.sort_order == 3
# TODO: check/ modify following # TODO: check/ modify following
# add to fields with class lower name (V) # add to fields with class lower name (V)

View File

@ -204,8 +204,8 @@ async def test_selecting_subset():
all_cars_dummy = ( all_cars_dummy = (
await Car.objects.select_related("manufacturer") await Car.objects.select_related("manufacturer")
.fields(["id", "name", "year", "gearbox_type", "gears", "aircon_type"]) .fields(["id", "name", "year", "gearbox_type", "gears", "aircon_type"])
.fields({"manufacturer": ...}) # .fields({"manufacturer": ...})
.exclude_fields({"manufacturer": ...}) # .exclude_fields({"manufacturer": ...})
.fields({"manufacturer": {"name"}}) .fields({"manufacturer": {"name"}})
.exclude_fields({"manufacturer__founded"}) .exclude_fields({"manufacturer__founded"})
.all() .all()