allow passing a dict and set to fields and exclude_fields, store it as dict

This commit is contained in:
collerek
2020-11-11 19:00:03 +01:00
parent 5552a8297f
commit 1242e5d600
12 changed files with 510 additions and 131 deletions

View File

@ -0,0 +1,39 @@
from typing import Dict, Set, Union
class Excludable:
@staticmethod
def get_excluded(
exclude: Union[Set, Dict, None], key: str = None
) -> Union[Set, Dict, None]:
if isinstance(exclude, dict):
return exclude.get(key, {})
return exclude
@staticmethod
def get_included(
include: Union[Set, Dict, None], key: str = None
) -> Union[Set, Dict, None]:
return Excludable.get_excluded(exclude=include, key=key)
@staticmethod
def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool:
if exclude is None:
return False
to_exclude = Excludable.get_excluded(exclude=exclude, key=key)
if isinstance(to_exclude, Set):
return key in to_exclude
elif to_exclude is ...:
return True
return False
@staticmethod
def is_included(include: Union[Set, Dict, None], key: str = None) -> bool:
if include is None:
return True
to_include = Excludable.get_included(include=include, key=key)
if isinstance(to_include, Set):
return key in to_include
elif to_include is ...:
return True
return False

View File

@ -1,5 +1,5 @@
import itertools import itertools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
import sqlalchemy import sqlalchemy
@ -47,8 +47,8 @@ class Model(NewBaseModel):
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
previous_table: str = None, previous_table: str = None,
fields: List = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: List = None, exclude_fields: Optional[Union[Dict, Set]] = None,
) -> Optional[T]: ) -> Optional[T]:
item: Dict[str, Any] = {} item: Dict[str, Any] = {}
@ -88,7 +88,6 @@ class Model(NewBaseModel):
table_prefix=table_prefix, table_prefix=table_prefix,
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
nested=table_prefix != "",
) )
instance: Optional[T] = cls(**item) if item.get( instance: Optional[T] = cls(**item) if item.get(
@ -103,13 +102,17 @@ class Model(NewBaseModel):
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
related_models: Any, related_models: Any,
previous_table: sqlalchemy.Table, previous_table: sqlalchemy.Table,
fields: List = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: List = None, exclude_fields: Optional[Union[Dict, Set]] = None,
) -> dict: ) -> dict:
for related in related_models: for related in related_models:
if isinstance(related_models, dict) and related_models[related]: if isinstance(related_models, dict) and related_models[related]:
first_part, remainder = related, related_models[related] first_part, remainder = related, related_models[related]
model_cls = cls.Meta.model_fields[first_part].to model_cls = cls.Meta.model_fields[first_part].to
fields = cls.get_included(fields, first_part)
exclude_fields = cls.get_excluded(exclude_fields, first_part)
child = model_cls.from_row( child = model_cls.from_row(
row, row,
related_models=remainder, related_models=remainder,
@ -120,6 +123,8 @@ class Model(NewBaseModel):
item[model_cls.get_column_name_from_alias(first_part)] = child item[model_cls.get_column_name_from_alias(first_part)] = child
else: else:
model_cls = cls.Meta.model_fields[related].to model_cls = cls.Meta.model_fields[related].to
fields = cls.get_included(fields, related)
exclude_fields = cls.get_excluded(exclude_fields, related)
child = model_cls.from_row( child = model_cls.from_row(
row, row,
previous_table=previous_table, previous_table=previous_table,
@ -136,16 +141,18 @@ class Model(NewBaseModel):
item: dict, item: dict,
row: sqlalchemy.engine.result.ResultProxy, row: sqlalchemy.engine.result.ResultProxy,
table_prefix: str, table_prefix: str,
fields: List = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: List = None, exclude_fields: Optional[Union[Dict, Set]] = None,
nested: bool = False,
) -> dict: ) -> dict:
# databases does not keep aliases in Record for postgres, change to raw row # databases does not keep aliases in Record for postgres, change to raw row
source = row._row if cls.db_backend_name() == "postgresql" else row source = row._row if cls.db_backend_name() == "postgresql" else row
selected_columns = cls.own_table_columns( selected_columns = cls.own_table_columns(
cls, fields or [], exclude_fields or [], nested=nested, use_alias=True model=cls,
fields=fields or {},
exclude_fields=exclude_fields or {},
use_alias=False,
) )
for column in cls.Meta.table.columns: for column in cls.Meta.table.columns:

View File

@ -1,6 +1,16 @@
import inspect import inspect
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union from typing import (
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
import ormar import ormar
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
@ -65,14 +75,14 @@ class ModelTableProxy:
@classmethod @classmethod
def get_column_alias(cls, field_name: str) -> str: def get_column_alias(cls, field_name: str) -> str:
field = cls.Meta.model_fields.get(field_name) field = cls.Meta.model_fields.get(field_name)
if field and field.alias is not None: if field is not None and field.alias is not None:
return field.alias return field.alias
return field_name return field_name
@classmethod @classmethod
def get_column_name_from_alias(cls, alias: str) -> str: def get_column_name_from_alias(cls, alias: str) -> str:
for field_name, field in cls.Meta.model_fields.items(): for field_name, field in cls.Meta.model_fields.items():
if field and field.alias == alias: if field is not None and field.alias == alias:
return field_name return field_name
return alias # if not found it's not an alias but actual name return alias # if not found it's not an alias but actual name
@ -211,59 +221,13 @@ class ModelTableProxy:
) )
return other return other
@staticmethod
def _get_not_nested_columns_from_fields(
model: Type["Model"],
fields: List,
exclude_fields: List,
column_names: List[str],
use_alias: bool = False,
) -> List[str]:
fields = [model.get_column_alias(k) if not use_alias else k for k in fields]
fields = fields or column_names
exclude_fields = [
model.get_column_alias(k) if not use_alias else k for k in exclude_fields
]
columns = [
name
for name in fields
if "__" not in name and name in column_names and name not in exclude_fields
]
return columns
@staticmethod
def _get_nested_columns_from_fields(
model: Type["Model"],
fields: List,
exclude_fields: List,
column_names: List[str],
use_alias: bool = False,
) -> List[str]:
model_name = f"{model.get_name()}__"
columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in fields
if f"{model.get_name()}__" in name
]
columns = columns or column_names
exclude_columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in exclude_fields
if f"{model.get_name()}__" in name
]
columns = [model.get_column_alias(k) if not use_alias else k for k in columns]
exclude_columns = [
model.get_column_alias(k) if not use_alias else k for k in exclude_columns
]
return [column for column in columns if column not in exclude_columns]
@staticmethod @staticmethod
def _populate_pk_column( def _populate_pk_column(
model: Type["Model"], columns: List[str], use_alias: bool = False, model: Type["Model"], columns: List[str], use_alias: bool = False,
) -> List[str]: ) -> List[str]:
pk_alias = ( pk_alias = (
model.get_column_alias(model.Meta.pkname) model.get_column_alias(model.Meta.pkname)
if not use_alias if use_alias
else model.Meta.pkname else model.Meta.pkname
) )
if pk_alias not in columns: if pk_alias not in columns:
@ -273,34 +237,30 @@ class ModelTableProxy:
@staticmethod @staticmethod
def own_table_columns( def own_table_columns(
model: Type["Model"], model: Type["Model"],
fields: List, fields: Optional[Union[Set, Dict]],
exclude_fields: List, exclude_fields: Optional[Union[Set, Dict]],
nested: bool = False,
use_alias: bool = False, use_alias: bool = False,
) -> List[str]: ) -> List[str]:
column_names = [ columns = [
model.get_column_name_from_alias(col.name) if use_alias else col.name model.get_column_name_from_alias(col.name) if not use_alias else col.name
for col in model.Meta.table.columns for col in model.Meta.table.columns
] ]
if not fields and not exclude_fields: field_names = [
return column_names model.get_column_name_from_alias(col.name)
for col in model.Meta.table.columns
if not nested: ]
columns = ModelTableProxy._get_not_nested_columns_from_fields( if fields:
model=model, columns = [
fields=fields, col
exclude_fields=exclude_fields, for col, name in zip(columns, field_names)
column_names=column_names, if model.is_included(fields, name)
use_alias=use_alias, ]
) if exclude_fields:
else: columns = [
columns = ModelTableProxy._get_nested_columns_from_fields( col
model=model, for col, name in zip(columns, field_names)
fields=fields, if not model.is_excluded(exclude_fields, name)
exclude_fields=exclude_fields, ]
column_names=column_names,
use_alias=use_alias,
)
# always has to return pk column # always has to return pk column
columns = ModelTableProxy._populate_pk_column( columns = ModelTableProxy._populate_pk_column(

View File

@ -23,6 +23,7 @@ from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.excludable import Excludable
from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy from ormar.models.modelproxy import ModelTableProxy
from ormar.relations.alias_manager import AliasManager from ormar.relations.alias_manager import AliasManager
@ -39,7 +40,9 @@ if TYPE_CHECKING: # pragma no cover
MappingIntStrAny = Mapping[IntStr, Any] MappingIntStrAny = Mapping[IntStr, Any]
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass): class NewBaseModel(
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
):
__slots__ = ("_orm_id", "_orm_saved", "_orm") __slots__ = ("_orm_id", "_orm_saved", "_orm")
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover

View File

@ -1,5 +1,15 @@
from collections import OrderedDict from collections import OrderedDict
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Tuple, Type from typing import (
Dict,
List,
NamedTuple,
Optional,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
@ -24,8 +34,8 @@ class SqlJoin:
used_aliases: List, used_aliases: List,
select_from: sqlalchemy.sql.select, select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column], columns: List[sqlalchemy.Column],
fields: List, fields: Optional[Union[Set, Dict]],
exclude_fields: List, exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List], order_columns: Optional[List],
sorted_orders: OrderedDict, sorted_orders: OrderedDict,
) -> None: ) -> None:
@ -49,21 +59,60 @@ class SqlJoin:
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}") return text(f"{left_part}={right_part}")
def build_join( @staticmethod
def update_inclusions(
model_cls: Type["Model"],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
nested_name: str,
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
fields = model_cls.get_included(fields, nested_name)
exclude_fields = model_cls.get_included(exclude_fields, nested_name)
return fields, exclude_fields
def build_join( # noqa: CCR001
self, item: str, join_parameters: JoinParameters self, item: str, join_parameters: JoinParameters
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]: ) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
for part in item.split("__"):
fields = self.fields
exclude_fields = self.exclude_fields
for index, part in enumerate(item.split("__")):
if issubclass( if issubclass(
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
): ):
_fields = join_parameters.model_cls.Meta.model_fields _fields = join_parameters.model_cls.Meta.model_fields
new_part = _fields[part].to.get_name() new_part = _fields[part].to.get_name()
self._switch_many_to_many_order_columns(part, new_part) self._switch_many_to_many_order_columns(part, new_part)
if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters( join_parameters = self._build_join_parameters(
part, join_parameters, is_multi=True part=part,
join_params=join_parameters,
is_multi=True,
fields=fields,
exclude_fields=exclude_fields,
) )
part = new_part part = new_part
join_parameters = self._build_join_parameters(part, join_parameters) if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters(
part=part,
join_params=join_parameters,
fields=fields,
exclude_fields=exclude_fields,
)
return ( return (
self.used_aliases, self.used_aliases,
@ -73,7 +122,12 @@ class SqlJoin:
) )
def _build_join_parameters( def _build_join_parameters(
self, part: str, join_params: JoinParameters, is_multi: bool = False self,
part: str,
join_params: JoinParameters,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
is_multi: bool = False,
) -> JoinParameters: ) -> JoinParameters:
if is_multi: if is_multi:
model_cls = join_params.model_cls.Meta.model_fields[part].through model_cls = join_params.model_cls.Meta.model_fields[part].through
@ -85,20 +139,30 @@ class SqlJoin:
join_params.from_table, to_table join_params.from_table, to_table
) )
if alias not in self.used_aliases: if alias not in self.used_aliases:
self._process_join(join_params, is_multi, model_cls, part, alias) self._process_join(
join_params=join_params,
is_multi=is_multi,
model_cls=model_cls,
part=part,
alias=alias,
fields=fields,
exclude_fields=exclude_fields,
)
previous_alias = alias previous_alias = alias
from_table = to_table from_table = to_table
prev_model = model_cls prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls) return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _process_join( def _process_join( # noqa: CFQ002
self, self,
join_params: JoinParameters, join_params: JoinParameters,
is_multi: bool, is_multi: bool,
model_cls: Type["Model"], model_cls: Type["Model"],
part: str, part: str,
alias: str, alias: str,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
) -> None: ) -> None:
to_table = model_cls.Meta.table.name to_table = model_cls.Meta.table.name
to_key, from_key = self.get_to_and_from_keys( to_key, from_key = self.get_to_and_from_keys(
@ -129,7 +193,10 @@ class SqlJoin:
) )
self_related_fields = model_cls.own_table_columns( self_related_fields = model_cls.own_table_columns(
model_cls, self.fields, self.exclude_fields, nested=True, model=model_cls,
fields=fields,
exclude_fields=exclude_fields,
use_alias=True,
) )
self.columns.extend( self.columns.extend(
self.relation_manager(model_cls).prefixed_columns( self.relation_manager(model_cls).prefixed_columns(

View File

@ -1,5 +1,6 @@
import copy
from collections import OrderedDict from collections import OrderedDict
from typing import List, Optional, TYPE_CHECKING, Tuple, Type from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
@ -21,8 +22,8 @@ class Query:
select_related: List, select_related: List,
limit_count: Optional[int], limit_count: Optional[int],
offset: Optional[int], offset: Optional[int],
fields: Optional[List], fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[List], exclude_fields: Optional[Union[Dict, Set]],
order_bys: Optional[List], order_bys: Optional[List],
) -> None: ) -> None:
self.query_offset = offset self.query_offset = offset
@ -30,8 +31,8 @@ class Query:
self._select_related = select_related[:] self._select_related = select_related[:]
self.filter_clauses = filter_clauses[:] self.filter_clauses = filter_clauses[:]
self.exclude_clauses = exclude_clauses[:] self.exclude_clauses = exclude_clauses[:]
self.fields = fields[:] if fields else [] self.fields = copy.deepcopy(fields) if fields else {}
self.exclude_fields = exclude_fields[:] if exclude_fields else [] self.exclude_fields = copy.deepcopy(exclude_fields) if exclude_fields else {}
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.Meta.table self.table = self.model_cls.Meta.table
@ -73,7 +74,10 @@ class Query:
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self_related_fields = self.model_cls.own_table_columns( self_related_fields = self.model_cls.own_table_columns(
self.model_cls, self.fields, self.exclude_fields model=self.model_cls,
fields=self.fields,
exclude_fields=self.exclude_fields,
use_alias=True,
) )
self.columns = self.model_cls.Meta.alias_manager.prefixed_columns( self.columns = self.model_cls.Meta.alias_manager.prefixed_columns(
"", self.table, self_related_fields "", self.table, self_related_fields
@ -87,13 +91,14 @@ class Query:
join_parameters = JoinParameters( join_parameters = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls self.model_cls, "", self.table.name, self.model_cls
) )
fields = self.model_cls.get_included(self.fields, item)
exclude_fields = self.model_cls.get_excluded(self.exclude_fields, item)
sql_join = SqlJoin( sql_join = SqlJoin(
used_aliases=self.used_aliases, used_aliases=self.used_aliases,
select_from=self.select_from, select_from=self.select_from,
columns=self.columns, columns=self.columns,
fields=self.fields, fields=fields,
exclude_fields=self.exclude_fields, exclude_fields=exclude_fields,
order_columns=self.order_columns, order_columns=self.order_columns,
sorted_orders=self.sorted_orders, sorted_orders=self.sorted_orders,
) )
@ -131,5 +136,5 @@ class Query:
self.select_from = [] self.select_from = []
self.columns = [] self.columns = []
self.used_aliases = [] self.used_aliases = []
self.fields = [] self.fields = {}
self.exclude_fields = [] self.exclude_fields = {}

View File

@ -1,4 +1,5 @@
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union import copy
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union
import databases import databases
import sqlalchemy import sqlalchemy
@ -10,6 +11,7 @@ from ormar.exceptions import QueryDefinitionError
from ormar.queryset import FilterQuery from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.utils import update, update_dict_from_list
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -26,8 +28,8 @@ class QuerySet:
select_related: List = None, select_related: List = None,
limit_count: int = None, limit_count: int = None,
offset: int = None, offset: int = None,
columns: List = None, columns: Dict = None,
exclude_columns: List = None, exclude_columns: Dict = None,
order_bys: List = None, order_bys: List = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
@ -36,8 +38,8 @@ class QuerySet:
self._select_related = [] if select_related is None else select_related self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count self.limit_count = limit_count
self.query_offset = offset self.query_offset = offset
self._columns = columns or [] self._columns = columns or {}
self._exclude_columns = exclude_columns or [] self._exclude_columns = exclude_columns or {}
self.order_bys = order_bys or [] self.order_bys = order_bys or []
def __get__( def __get__(
@ -169,11 +171,16 @@ class QuerySet:
order_bys=self.order_bys, order_bys=self.order_bys,
) )
def exclude_fields(self, columns: Union[List, str]) -> "QuerySet": def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
if not isinstance(columns, list): if isinstance(columns, str):
columns = [columns] columns = [columns]
columns = list(set(list(self._exclude_columns) + columns)) current_excluded = copy.deepcopy(self._exclude_columns)
if not isinstance(columns, dict):
current_excluded = update_dict_from_list(current_excluded, columns)
else:
current_excluded = update(current_excluded, columns)
return self.__class__( return self.__class__(
model_cls=self.model, model_cls=self.model,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
@ -182,15 +189,20 @@ class QuerySet:
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
columns=self._columns, columns=self._columns,
exclude_columns=columns, exclude_columns=current_excluded,
order_bys=self.order_bys, order_bys=self.order_bys,
) )
def fields(self, columns: Union[List, str]) -> "QuerySet": def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
if not isinstance(columns, list): if isinstance(columns, str):
columns = [columns] columns = [columns]
columns = list(set(list(self._columns) + columns)) current_included = copy.deepcopy(self._exclude_columns)
if not isinstance(columns, dict):
current_included = update_dict_from_list(current_included, columns)
else:
current_included = update(current_included, columns)
return self.__class__( return self.__class__(
model_cls=self.model, model_cls=self.model,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
@ -198,7 +210,7 @@ class QuerySet:
select_related=self._select_related, select_related=self._select_related,
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
columns=columns, columns=current_included,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
) )

57
ormar/queryset/utils.py Normal file
View File

@ -0,0 +1,57 @@
import collections.abc
import copy
from typing import Any, Dict, List, Set, Union
def check_node_not_dict_or_not_last_node(
part: str, parts: List, current_level: Any
) -> bool:
return (part not in current_level and part != parts[-1]) or (
part in current_level and not isinstance(current_level[part], dict)
)
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001
new_dict: Dict = dict()
for path in list_to_trans:
current_level = new_dict
parts = path.split("__")
for part in parts:
if check_node_not_dict_or_not_last_node(
part=part, parts=parts, current_level=current_level
):
current_level[part] = dict()
elif part not in current_level:
current_level[part] = ...
current_level = current_level[part]
return new_dict
def convert_set_to_required_dict(set_to_convert: set) -> Dict:
new_dict = dict()
for key in set_to_convert:
new_dict[key] = Ellipsis
return new_dict
def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
if current_dict is Ellipsis:
current_dict = dict()
for key, value in updating_dict.items():
if isinstance(value, collections.abc.Mapping):
old_key = current_dict.get(key, {})
if isinstance(old_key, set):
old_key = convert_set_to_required_dict(old_key)
current_dict[key] = update(old_key, value)
elif isinstance(value, set) and isinstance(current_dict.get(key), set):
current_dict[key] = current_dict.get(key).union(value)
else:
current_dict[key] = value
return current_dict
def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict:
updated_dict = copy.copy(curr_dict)
dict_to_update = translate_list_to_dict(list_to_update)
update(updated_dict, dict_to_update)
return updated_dict

View File

@ -117,8 +117,8 @@ async def test_working_with_aliases():
"first_name", "first_name",
"last_name", "last_name",
"born_year", "born_year",
"child__first_name", "children__first_name",
"child__last_name", "children__last_name",
] ]
) )
.get() .get()

View File

@ -80,10 +80,53 @@ async def test_selecting_subset():
all_cars = ( all_cars = (
await Car.objects.select_related("manufacturer") await Car.objects.select_related("manufacturer")
.exclude_fields( .exclude_fields(
["gearbox_type", "gears", "aircon_type", "year", "company__founded"] [
"gearbox_type",
"gears",
"aircon_type",
"year",
"manufacturer__founded",
]
) )
.all() .all()
) )
for car in all_cars:
assert all(
getattr(car, x) is None
for x in ["year", "gearbox_type", "gears", "aircon_type"]
)
assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded is None
all_cars = (
await Car.objects.select_related("manufacturer")
.exclude_fields(
{
"gearbox_type": ...,
"gears": ...,
"aircon_type": ...,
"year": ...,
"manufacturer": {"founded": ...},
}
)
.all()
)
all_cars2 = (
await Car.objects.select_related("manufacturer")
.exclude_fields(
{
"gearbox_type": ...,
"gears": ...,
"aircon_type": ...,
"year": ...,
"manufacturer": {"founded"},
}
)
.all()
)
assert all_cars == all_cars2
for car in all_cars: for car in all_cars:
assert all( assert all(
getattr(car, x) is None getattr(car, x) is None
@ -119,7 +162,7 @@ async def test_selecting_subset():
all_cars_check2 = ( all_cars_check2 = (
await Car.objects.select_related("manufacturer") await Car.objects.select_related("manufacturer")
.fields(["id", "name", "manufacturer"]) .fields(["id", "name", "manufacturer"])
.exclude_fields("company__founded") .exclude_fields("manufacturer__founded")
.all() .all()
) )
for car in all_cars_check2: for car in all_cars_check2:
@ -133,5 +176,5 @@ async def test_selecting_subset():
with pytest.raises(pydantic.error_wrappers.ValidationError): with pytest.raises(pydantic.error_wrappers.ValidationError):
# cannot exclude mandatory model columns - company__name in this example # cannot exclude mandatory model columns - company__name in this example
await Car.objects.select_related("manufacturer").exclude_fields( await Car.objects.select_related("manufacturer").exclude_fields(
["company__name"] ["manufacturer__name"]
).all() ).all()

View File

@ -0,0 +1,98 @@
from ormar.models.excludable import Excludable
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
def test_empty_excludable():
assert Excludable.is_included(None, "key") # all fields included if empty
assert not Excludable.is_excluded(None, "key") # none field excluded if empty
def test_list_to_dict_translation():
tet_list = ["aa", "bb", "cc__aa", "cc__bb", "cc__aa__xx", "cc__aa__yy"]
test = translate_list_to_dict(tet_list)
assert test == {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis},
}
def test_updating_dict_with_list():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis},
}
list_to_update = ["ee", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"]
test = update_dict_from_list(curr_dict, list_to_update)
assert test == {
"aa": Ellipsis,
"bb": {"cc": Ellipsis},
"cc": {
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
"bb": Ellipsis,
},
"ee": Ellipsis,
}
def test_updating_dict_inc_set_with_list():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
}
list_to_update = ["uu", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"]
test = update_dict_from_list(curr_dict, list_to_update)
assert test == {
"aa": Ellipsis,
"bb": {"cc": Ellipsis},
"cc": {
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
"bb": Ellipsis,
},
"uu": Ellipsis,
}
def test_updating_dict_inc_set_with_dict():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
}
dict_to_update = {
"uu": Ellipsis,
"bb": {"cc", "dd"},
"cc": {"aa": {"xx": {"oo": Ellipsis}, "oo": Ellipsis}},
}
test = update(curr_dict, dict_to_update)
assert test == {
"aa": Ellipsis,
"bb": {"cc", "dd"},
"cc": {
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
"bb": Ellipsis,
},
"uu": Ellipsis,
}
def test_updating_dict_inc_set_with_dict_inc_set():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
}
dict_to_update = {
"uu": Ellipsis,
"bb": {"cc", "dd"},
"cc": {"aa": {"xx", "oo", "zz", "ii"}},
}
test = update(curr_dict, dict_to_update)
assert test == {
"aa": Ellipsis,
"bb": {"cc", "dd"},
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
"uu": Ellipsis,
}

View File

@ -1,4 +1,5 @@
from typing import Optional import itertools
from typing import Optional, List
import databases import databases
import pydantic import pydantic
@ -12,6 +13,35 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
class NickNames(ormar.Model):
class Meta:
tablename = "nicks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
is_lame: bool = ormar.Boolean(nullable=True)
class NicksHq(ormar.Model):
class Meta:
tablename = "nicks_x_hq"
metadata = metadata
database = database
class HQ(ormar.Model):
class Meta:
tablename = "hqs"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq)
class Company(ormar.Model): class Company(ormar.Model):
class Meta: class Meta:
tablename = "companies" tablename = "companies"
@ -21,6 +51,7 @@ class Company(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="company_name") name: str = ormar.String(max_length=100, nullable=False, name="company_name")
founded: int = ormar.Integer(nullable=True) founded: int = ormar.Integer(nullable=True)
hq: HQ = ormar.ForeignKey(HQ)
class Car(ormar.Model): class Car(ormar.Model):
@ -51,7 +82,14 @@ def create_test_database():
async def test_selecting_subset(): async def test_selecting_subset():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
toyota = await Company.objects.create(name="Toyota", founded=1937) nick1 = await NickNames.objects.create(name="Nippon", is_lame=False)
nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True)
hq = await HQ.objects.create(name="Japan")
await hq.nicks.add(nick1)
await hq.nicks.add(nick2)
toyota = await Company.objects.create(name="Toyota", founded=1937, hq=hq)
await Car.objects.create( await Car.objects.create(
manufacturer=toyota, manufacturer=toyota,
name="Corolla", name="Corolla",
@ -78,17 +116,66 @@ async def test_selecting_subset():
) )
all_cars = ( all_cars = (
await Car.objects.select_related("manufacturer") await Car.objects.select_related(
.fields(["id", "name", "company__name"]) ["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
)
.fields(
[
"id",
"name",
"manufacturer__name",
"manufacturer__hq__name",
"manufacturer__hq__nicks__name",
]
)
.all() .all()
) )
for car in all_cars:
all_cars2 = (
await Car.objects.select_related(
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
)
.fields(
{
"id": ...,
"name": ...,
"manufacturer": {
"name": ...,
"hq": {"name": ..., "nicks": {"name": ...}},
},
}
)
.all()
)
all_cars3 = (
await Car.objects.select_related(
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
)
.fields(
{
"id": ...,
"name": ...,
"manufacturer": {
"name": ...,
"hq": {"name": ..., "nicks": {"name"}},
},
}
)
.all()
)
assert all_cars3 == all_cars
for car in itertools.chain(all_cars, all_cars2):
assert all( assert all(
getattr(car, x) is None getattr(car, x) is None
for x in ["year", "gearbox_type", "gears", "aircon_type"] for x in ["year", "gearbox_type", "gears", "aircon_type"]
) )
assert car.manufacturer.name == "Toyota" assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded is None assert car.manufacturer.founded is None
assert car.manufacturer.hq.name == "Japan"
assert len(car.manufacturer.hq.nicks) == 2
assert car.manufacturer.hq.nicks[0].is_lame is None
all_cars = ( all_cars = (
await Car.objects.select_related("manufacturer") await Car.objects.select_related("manufacturer")
@ -103,6 +190,7 @@ async def test_selecting_subset():
) )
assert car.manufacturer.name == "Toyota" assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded == 1937 assert car.manufacturer.founded == 1937
assert car.manufacturer.hq.name is None
all_cars_check = await Car.objects.select_related("manufacturer").all() all_cars_check = await Car.objects.select_related("manufacturer").all()
for car in all_cars_check: for car in all_cars_check:
@ -116,5 +204,5 @@ async def test_selecting_subset():
with pytest.raises(pydantic.error_wrappers.ValidationError): with pytest.raises(pydantic.error_wrappers.ValidationError):
# cannot exclude mandatory model columns - company__name in this example # cannot exclude mandatory model columns - company__name in this example
await Car.objects.select_related("manufacturer").fields( await Car.objects.select_related("manufacturer").fields(
["id", "name", "company__founded"] ["id", "name", "manufacturer__founded"]
).all() ).all()