allow passing a dict and set to fields and exclude_fields, store it as dict

This commit is contained in:
collerek
2020-11-11 19:00:03 +01:00
parent 5552a8297f
commit 1242e5d600
12 changed files with 510 additions and 131 deletions

View File

@ -0,0 +1,39 @@
from typing import Dict, Set, Union
class Excludable:
@staticmethod
def get_excluded(
exclude: Union[Set, Dict, None], key: str = None
) -> Union[Set, Dict, None]:
if isinstance(exclude, dict):
return exclude.get(key, {})
return exclude
@staticmethod
def get_included(
include: Union[Set, Dict, None], key: str = None
) -> Union[Set, Dict, None]:
return Excludable.get_excluded(exclude=include, key=key)
@staticmethod
def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool:
if exclude is None:
return False
to_exclude = Excludable.get_excluded(exclude=exclude, key=key)
if isinstance(to_exclude, Set):
return key in to_exclude
elif to_exclude is ...:
return True
return False
@staticmethod
def is_included(include: Union[Set, Dict, None], key: str = None) -> bool:
if include is None:
return True
to_include = Excludable.get_included(include=include, key=key)
if isinstance(to_include, Set):
return key in to_include
elif to_include is ...:
return True
return False

View File

@ -1,5 +1,5 @@
import itertools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
import sqlalchemy
@ -47,8 +47,8 @@ class Model(NewBaseModel):
select_related: List = None,
related_models: Any = None,
previous_table: str = None,
fields: List = None,
exclude_fields: List = None,
fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None,
) -> Optional[T]:
item: Dict[str, Any] = {}
@ -88,7 +88,6 @@ class Model(NewBaseModel):
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields,
nested=table_prefix != "",
)
instance: Optional[T] = cls(**item) if item.get(
@ -103,13 +102,17 @@ class Model(NewBaseModel):
row: sqlalchemy.engine.ResultProxy,
related_models: Any,
previous_table: sqlalchemy.Table,
fields: List = None,
exclude_fields: List = None,
fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None,
) -> dict:
for related in related_models:
if isinstance(related_models, dict) and related_models[related]:
first_part, remainder = related, related_models[related]
model_cls = cls.Meta.model_fields[first_part].to
fields = cls.get_included(fields, first_part)
exclude_fields = cls.get_excluded(exclude_fields, first_part)
child = model_cls.from_row(
row,
related_models=remainder,
@ -120,6 +123,8 @@ class Model(NewBaseModel):
item[model_cls.get_column_name_from_alias(first_part)] = child
else:
model_cls = cls.Meta.model_fields[related].to
fields = cls.get_included(fields, related)
exclude_fields = cls.get_excluded(exclude_fields, related)
child = model_cls.from_row(
row,
previous_table=previous_table,
@ -136,16 +141,18 @@ class Model(NewBaseModel):
item: dict,
row: sqlalchemy.engine.result.ResultProxy,
table_prefix: str,
fields: List = None,
exclude_fields: List = None,
nested: bool = False,
fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None,
) -> dict:
# databases does not keep aliases in Record for postgres, change to raw row
source = row._row if cls.db_backend_name() == "postgresql" else row
selected_columns = cls.own_table_columns(
cls, fields or [], exclude_fields or [], nested=nested, use_alias=True
model=cls,
fields=fields or {},
exclude_fields=exclude_fields or {},
use_alias=False,
)
for column in cls.Meta.table.columns:

View File

@ -1,6 +1,16 @@
import inspect
from collections import OrderedDict
from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union
from typing import (
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
import ormar
from ormar.exceptions import RelationshipInstanceError
@ -65,14 +75,14 @@ class ModelTableProxy:
@classmethod
def get_column_alias(cls, field_name: str) -> str:
field = cls.Meta.model_fields.get(field_name)
if field and field.alias is not None:
if field is not None and field.alias is not None:
return field.alias
return field_name
@classmethod
def get_column_name_from_alias(cls, alias: str) -> str:
for field_name, field in cls.Meta.model_fields.items():
if field and field.alias == alias:
if field is not None and field.alias == alias:
return field_name
return alias # if not found it's not an alias but actual name
@ -211,59 +221,13 @@ class ModelTableProxy:
)
return other
@staticmethod
def _get_not_nested_columns_from_fields(
model: Type["Model"],
fields: List,
exclude_fields: List,
column_names: List[str],
use_alias: bool = False,
) -> List[str]:
fields = [model.get_column_alias(k) if not use_alias else k for k in fields]
fields = fields or column_names
exclude_fields = [
model.get_column_alias(k) if not use_alias else k for k in exclude_fields
]
columns = [
name
for name in fields
if "__" not in name and name in column_names and name not in exclude_fields
]
return columns
@staticmethod
def _get_nested_columns_from_fields(
model: Type["Model"],
fields: List,
exclude_fields: List,
column_names: List[str],
use_alias: bool = False,
) -> List[str]:
model_name = f"{model.get_name()}__"
columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in fields
if f"{model.get_name()}__" in name
]
columns = columns or column_names
exclude_columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in exclude_fields
if f"{model.get_name()}__" in name
]
columns = [model.get_column_alias(k) if not use_alias else k for k in columns]
exclude_columns = [
model.get_column_alias(k) if not use_alias else k for k in exclude_columns
]
return [column for column in columns if column not in exclude_columns]
@staticmethod
def _populate_pk_column(
model: Type["Model"], columns: List[str], use_alias: bool = False,
) -> List[str]:
pk_alias = (
model.get_column_alias(model.Meta.pkname)
if not use_alias
if use_alias
else model.Meta.pkname
)
if pk_alias not in columns:
@ -273,34 +237,30 @@ class ModelTableProxy:
@staticmethod
def own_table_columns(
model: Type["Model"],
fields: List,
exclude_fields: List,
nested: bool = False,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
use_alias: bool = False,
) -> List[str]:
column_names = [
model.get_column_name_from_alias(col.name) if use_alias else col.name
columns = [
model.get_column_name_from_alias(col.name) if not use_alias else col.name
for col in model.Meta.table.columns
]
if not fields and not exclude_fields:
return column_names
if not nested:
columns = ModelTableProxy._get_not_nested_columns_from_fields(
model=model,
fields=fields,
exclude_fields=exclude_fields,
column_names=column_names,
use_alias=use_alias,
)
else:
columns = ModelTableProxy._get_nested_columns_from_fields(
model=model,
fields=fields,
exclude_fields=exclude_fields,
column_names=column_names,
use_alias=use_alias,
)
field_names = [
model.get_column_name_from_alias(col.name)
for col in model.Meta.table.columns
]
if fields:
columns = [
col
for col, name in zip(columns, field_names)
if model.is_included(fields, name)
]
if exclude_fields:
columns = [
col
for col, name in zip(columns, field_names)
if not model.is_excluded(exclude_fields, name)
]
# always has to return pk column
columns = ModelTableProxy._populate_pk_column(

View File

@ -23,6 +23,7 @@ from pydantic import BaseModel
import ormar # noqa I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.excludable import Excludable
from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy
from ormar.relations.alias_manager import AliasManager
@ -39,7 +40,9 @@ if TYPE_CHECKING: # pragma no cover
MappingIntStrAny = Mapping[IntStr, Any]
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
class NewBaseModel(
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
):
__slots__ = ("_orm_id", "_orm_saved", "_orm")
if TYPE_CHECKING: # pragma no cover

View File

@ -1,5 +1,15 @@
from collections import OrderedDict
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Tuple, Type
from typing import (
Dict,
List,
NamedTuple,
Optional,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import sqlalchemy
from sqlalchemy import text
@ -24,8 +34,8 @@ class SqlJoin:
used_aliases: List,
select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column],
fields: List,
exclude_fields: List,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List],
sorted_orders: OrderedDict,
) -> None:
@ -49,21 +59,60 @@ class SqlJoin:
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}")
def build_join(
@staticmethod
def update_inclusions(
model_cls: Type["Model"],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
nested_name: str,
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
fields = model_cls.get_included(fields, nested_name)
exclude_fields = model_cls.get_included(exclude_fields, nested_name)
return fields, exclude_fields
def build_join( # noqa: CCR001
self, item: str, join_parameters: JoinParameters
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
for part in item.split("__"):
fields = self.fields
exclude_fields = self.exclude_fields
for index, part in enumerate(item.split("__")):
if issubclass(
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
):
_fields = join_parameters.model_cls.Meta.model_fields
new_part = _fields[part].to.get_name()
self._switch_many_to_many_order_columns(part, new_part)
if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters(
part, join_parameters, is_multi=True
part=part,
join_params=join_parameters,
is_multi=True,
fields=fields,
exclude_fields=exclude_fields,
)
part = new_part
join_parameters = self._build_join_parameters(part, join_parameters)
if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters(
part=part,
join_params=join_parameters,
fields=fields,
exclude_fields=exclude_fields,
)
return (
self.used_aliases,
@ -73,7 +122,12 @@ class SqlJoin:
)
def _build_join_parameters(
self, part: str, join_params: JoinParameters, is_multi: bool = False
self,
part: str,
join_params: JoinParameters,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
is_multi: bool = False,
) -> JoinParameters:
if is_multi:
model_cls = join_params.model_cls.Meta.model_fields[part].through
@ -85,20 +139,30 @@ class SqlJoin:
join_params.from_table, to_table
)
if alias not in self.used_aliases:
self._process_join(join_params, is_multi, model_cls, part, alias)
self._process_join(
join_params=join_params,
is_multi=is_multi,
model_cls=model_cls,
part=part,
alias=alias,
fields=fields,
exclude_fields=exclude_fields,
)
previous_alias = alias
from_table = to_table
prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _process_join(
def _process_join( # noqa: CFQ002
self,
join_params: JoinParameters,
is_multi: bool,
model_cls: Type["Model"],
part: str,
alias: str,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
) -> None:
to_table = model_cls.Meta.table.name
to_key, from_key = self.get_to_and_from_keys(
@ -129,7 +193,10 @@ class SqlJoin:
)
self_related_fields = model_cls.own_table_columns(
model_cls, self.fields, self.exclude_fields, nested=True,
model=model_cls,
fields=fields,
exclude_fields=exclude_fields,
use_alias=True,
)
self.columns.extend(
self.relation_manager(model_cls).prefixed_columns(

View File

@ -1,5 +1,6 @@
import copy
from collections import OrderedDict
from typing import List, Optional, TYPE_CHECKING, Tuple, Type
from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union
import sqlalchemy
from sqlalchemy import text
@ -21,8 +22,8 @@ class Query:
select_related: List,
limit_count: Optional[int],
offset: Optional[int],
fields: Optional[List],
exclude_fields: Optional[List],
fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[Union[Dict, Set]],
order_bys: Optional[List],
) -> None:
self.query_offset = offset
@ -30,8 +31,8 @@ class Query:
self._select_related = select_related[:]
self.filter_clauses = filter_clauses[:]
self.exclude_clauses = exclude_clauses[:]
self.fields = fields[:] if fields else []
self.exclude_fields = exclude_fields[:] if exclude_fields else []
self.fields = copy.deepcopy(fields) if fields else {}
self.exclude_fields = copy.deepcopy(exclude_fields) if exclude_fields else {}
self.model_cls = model_cls
self.table = self.model_cls.Meta.table
@ -73,7 +74,10 @@ class Query:
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self_related_fields = self.model_cls.own_table_columns(
self.model_cls, self.fields, self.exclude_fields
model=self.model_cls,
fields=self.fields,
exclude_fields=self.exclude_fields,
use_alias=True,
)
self.columns = self.model_cls.Meta.alias_manager.prefixed_columns(
"", self.table, self_related_fields
@ -87,13 +91,14 @@ class Query:
join_parameters = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
)
fields = self.model_cls.get_included(self.fields, item)
exclude_fields = self.model_cls.get_excluded(self.exclude_fields, item)
sql_join = SqlJoin(
used_aliases=self.used_aliases,
select_from=self.select_from,
columns=self.columns,
fields=self.fields,
exclude_fields=self.exclude_fields,
fields=fields,
exclude_fields=exclude_fields,
order_columns=self.order_columns,
sorted_orders=self.sorted_orders,
)
@ -131,5 +136,5 @@ class Query:
self.select_from = []
self.columns = []
self.used_aliases = []
self.fields = []
self.exclude_fields = []
self.fields = {}
self.exclude_fields = {}

View File

@ -1,4 +1,5 @@
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union
import copy
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union
import databases
import sqlalchemy
@ -10,6 +11,7 @@ from ormar.exceptions import QueryDefinitionError
from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query
from ormar.queryset.utils import update, update_dict_from_list
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -26,8 +28,8 @@ class QuerySet:
select_related: List = None,
limit_count: int = None,
offset: int = None,
columns: List = None,
exclude_columns: List = None,
columns: Dict = None,
exclude_columns: Dict = None,
order_bys: List = None,
) -> None:
self.model_cls = model_cls
@ -36,8 +38,8 @@ class QuerySet:
self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count
self.query_offset = offset
self._columns = columns or []
self._exclude_columns = exclude_columns or []
self._columns = columns or {}
self._exclude_columns = exclude_columns or {}
self.order_bys = order_bys or []
def __get__(
@ -169,11 +171,16 @@ class QuerySet:
order_bys=self.order_bys,
)
def exclude_fields(self, columns: Union[List, str]) -> "QuerySet":
if not isinstance(columns, list):
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
if isinstance(columns, str):
columns = [columns]
columns = list(set(list(self._exclude_columns) + columns))
current_excluded = copy.deepcopy(self._exclude_columns)
if not isinstance(columns, dict):
current_excluded = update_dict_from_list(current_excluded, columns)
else:
current_excluded = update(current_excluded, columns)
return self.__class__(
model_cls=self.model,
filter_clauses=self.filter_clauses,
@ -182,15 +189,20 @@ class QuerySet:
limit_count=self.limit_count,
offset=self.query_offset,
columns=self._columns,
exclude_columns=columns,
exclude_columns=current_excluded,
order_bys=self.order_bys,
)
def fields(self, columns: Union[List, str]) -> "QuerySet":
if not isinstance(columns, list):
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
if isinstance(columns, str):
columns = [columns]
columns = list(set(list(self._columns) + columns))
current_included = copy.deepcopy(self._exclude_columns)
if not isinstance(columns, dict):
current_included = update_dict_from_list(current_included, columns)
else:
current_included = update(current_included, columns)
return self.__class__(
model_cls=self.model,
filter_clauses=self.filter_clauses,
@ -198,7 +210,7 @@ class QuerySet:
select_related=self._select_related,
limit_count=self.limit_count,
offset=self.query_offset,
columns=columns,
columns=current_included,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
)

57
ormar/queryset/utils.py Normal file
View File

@ -0,0 +1,57 @@
import collections.abc
import copy
from typing import Any, Dict, List, Set, Union
def check_node_not_dict_or_not_last_node(
part: str, parts: List, current_level: Any
) -> bool:
return (part not in current_level and part != parts[-1]) or (
part in current_level and not isinstance(current_level[part], dict)
)
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001
new_dict: Dict = dict()
for path in list_to_trans:
current_level = new_dict
parts = path.split("__")
for part in parts:
if check_node_not_dict_or_not_last_node(
part=part, parts=parts, current_level=current_level
):
current_level[part] = dict()
elif part not in current_level:
current_level[part] = ...
current_level = current_level[part]
return new_dict
def convert_set_to_required_dict(set_to_convert: set) -> Dict:
new_dict = dict()
for key in set_to_convert:
new_dict[key] = Ellipsis
return new_dict
def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
if current_dict is Ellipsis:
current_dict = dict()
for key, value in updating_dict.items():
if isinstance(value, collections.abc.Mapping):
old_key = current_dict.get(key, {})
if isinstance(old_key, set):
old_key = convert_set_to_required_dict(old_key)
current_dict[key] = update(old_key, value)
elif isinstance(value, set) and isinstance(current_dict.get(key), set):
current_dict[key] = current_dict.get(key).union(value)
else:
current_dict[key] = value
return current_dict
def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict:
updated_dict = copy.copy(curr_dict)
dict_to_update = translate_list_to_dict(list_to_update)
update(updated_dict, dict_to_update)
return updated_dict