allow passing a dict and set to fields and exclude_fields, store it as dict
This commit is contained in:
39
ormar/models/excludable.py
Normal file
39
ormar/models/excludable.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import Dict, Set, Union
|
||||
|
||||
|
||||
class Excludable:
|
||||
@staticmethod
|
||||
def get_excluded(
|
||||
exclude: Union[Set, Dict, None], key: str = None
|
||||
) -> Union[Set, Dict, None]:
|
||||
if isinstance(exclude, dict):
|
||||
return exclude.get(key, {})
|
||||
return exclude
|
||||
|
||||
@staticmethod
|
||||
def get_included(
|
||||
include: Union[Set, Dict, None], key: str = None
|
||||
) -> Union[Set, Dict, None]:
|
||||
return Excludable.get_excluded(exclude=include, key=key)
|
||||
|
||||
@staticmethod
|
||||
def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool:
|
||||
if exclude is None:
|
||||
return False
|
||||
to_exclude = Excludable.get_excluded(exclude=exclude, key=key)
|
||||
if isinstance(to_exclude, Set):
|
||||
return key in to_exclude
|
||||
elif to_exclude is ...:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_included(include: Union[Set, Dict, None], key: str = None) -> bool:
|
||||
if include is None:
|
||||
return True
|
||||
to_include = Excludable.get_included(include=include, key=key)
|
||||
if isinstance(to_include, Set):
|
||||
return key in to_include
|
||||
elif to_include is ...:
|
||||
return True
|
||||
return False
|
||||
@ -1,5 +1,5 @@
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
@ -47,8 +47,8 @@ class Model(NewBaseModel):
|
||||
select_related: List = None,
|
||||
related_models: Any = None,
|
||||
previous_table: str = None,
|
||||
fields: List = None,
|
||||
exclude_fields: List = None,
|
||||
fields: Optional[Union[Dict, Set]] = None,
|
||||
exclude_fields: Optional[Union[Dict, Set]] = None,
|
||||
) -> Optional[T]:
|
||||
|
||||
item: Dict[str, Any] = {}
|
||||
@ -88,7 +88,6 @@ class Model(NewBaseModel):
|
||||
table_prefix=table_prefix,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
nested=table_prefix != "",
|
||||
)
|
||||
|
||||
instance: Optional[T] = cls(**item) if item.get(
|
||||
@ -103,13 +102,17 @@ class Model(NewBaseModel):
|
||||
row: sqlalchemy.engine.ResultProxy,
|
||||
related_models: Any,
|
||||
previous_table: sqlalchemy.Table,
|
||||
fields: List = None,
|
||||
exclude_fields: List = None,
|
||||
fields: Optional[Union[Dict, Set]] = None,
|
||||
exclude_fields: Optional[Union[Dict, Set]] = None,
|
||||
) -> dict:
|
||||
for related in related_models:
|
||||
if isinstance(related_models, dict) and related_models[related]:
|
||||
first_part, remainder = related, related_models[related]
|
||||
model_cls = cls.Meta.model_fields[first_part].to
|
||||
|
||||
fields = cls.get_included(fields, first_part)
|
||||
exclude_fields = cls.get_excluded(exclude_fields, first_part)
|
||||
|
||||
child = model_cls.from_row(
|
||||
row,
|
||||
related_models=remainder,
|
||||
@ -120,6 +123,8 @@ class Model(NewBaseModel):
|
||||
item[model_cls.get_column_name_from_alias(first_part)] = child
|
||||
else:
|
||||
model_cls = cls.Meta.model_fields[related].to
|
||||
fields = cls.get_included(fields, related)
|
||||
exclude_fields = cls.get_excluded(exclude_fields, related)
|
||||
child = model_cls.from_row(
|
||||
row,
|
||||
previous_table=previous_table,
|
||||
@ -136,16 +141,18 @@ class Model(NewBaseModel):
|
||||
item: dict,
|
||||
row: sqlalchemy.engine.result.ResultProxy,
|
||||
table_prefix: str,
|
||||
fields: List = None,
|
||||
exclude_fields: List = None,
|
||||
nested: bool = False,
|
||||
fields: Optional[Union[Dict, Set]] = None,
|
||||
exclude_fields: Optional[Union[Dict, Set]] = None,
|
||||
) -> dict:
|
||||
|
||||
# databases does not keep aliases in Record for postgres, change to raw row
|
||||
source = row._row if cls.db_backend_name() == "postgresql" else row
|
||||
|
||||
selected_columns = cls.own_table_columns(
|
||||
cls, fields or [], exclude_fields or [], nested=nested, use_alias=True
|
||||
model=cls,
|
||||
fields=fields or {},
|
||||
exclude_fields=exclude_fields or {},
|
||||
use_alias=False,
|
||||
)
|
||||
|
||||
for column in cls.Meta.table.columns:
|
||||
|
||||
@ -1,6 +1,16 @@
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import RelationshipInstanceError
|
||||
@ -65,14 +75,14 @@ class ModelTableProxy:
|
||||
@classmethod
|
||||
def get_column_alias(cls, field_name: str) -> str:
|
||||
field = cls.Meta.model_fields.get(field_name)
|
||||
if field and field.alias is not None:
|
||||
if field is not None and field.alias is not None:
|
||||
return field.alias
|
||||
return field_name
|
||||
|
||||
@classmethod
|
||||
def get_column_name_from_alias(cls, alias: str) -> str:
|
||||
for field_name, field in cls.Meta.model_fields.items():
|
||||
if field and field.alias == alias:
|
||||
if field is not None and field.alias == alias:
|
||||
return field_name
|
||||
return alias # if not found it's not an alias but actual name
|
||||
|
||||
@ -211,59 +221,13 @@ class ModelTableProxy:
|
||||
)
|
||||
return other
|
||||
|
||||
@staticmethod
|
||||
def _get_not_nested_columns_from_fields(
|
||||
model: Type["Model"],
|
||||
fields: List,
|
||||
exclude_fields: List,
|
||||
column_names: List[str],
|
||||
use_alias: bool = False,
|
||||
) -> List[str]:
|
||||
fields = [model.get_column_alias(k) if not use_alias else k for k in fields]
|
||||
fields = fields or column_names
|
||||
exclude_fields = [
|
||||
model.get_column_alias(k) if not use_alias else k for k in exclude_fields
|
||||
]
|
||||
columns = [
|
||||
name
|
||||
for name in fields
|
||||
if "__" not in name and name in column_names and name not in exclude_fields
|
||||
]
|
||||
return columns
|
||||
|
||||
@staticmethod
|
||||
def _get_nested_columns_from_fields(
|
||||
model: Type["Model"],
|
||||
fields: List,
|
||||
exclude_fields: List,
|
||||
column_names: List[str],
|
||||
use_alias: bool = False,
|
||||
) -> List[str]:
|
||||
model_name = f"{model.get_name()}__"
|
||||
columns = [
|
||||
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
|
||||
for name in fields
|
||||
if f"{model.get_name()}__" in name
|
||||
]
|
||||
columns = columns or column_names
|
||||
exclude_columns = [
|
||||
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
|
||||
for name in exclude_fields
|
||||
if f"{model.get_name()}__" in name
|
||||
]
|
||||
columns = [model.get_column_alias(k) if not use_alias else k for k in columns]
|
||||
exclude_columns = [
|
||||
model.get_column_alias(k) if not use_alias else k for k in exclude_columns
|
||||
]
|
||||
return [column for column in columns if column not in exclude_columns]
|
||||
|
||||
@staticmethod
|
||||
def _populate_pk_column(
|
||||
model: Type["Model"], columns: List[str], use_alias: bool = False,
|
||||
) -> List[str]:
|
||||
pk_alias = (
|
||||
model.get_column_alias(model.Meta.pkname)
|
||||
if not use_alias
|
||||
if use_alias
|
||||
else model.Meta.pkname
|
||||
)
|
||||
if pk_alias not in columns:
|
||||
@ -273,34 +237,30 @@ class ModelTableProxy:
|
||||
@staticmethod
|
||||
def own_table_columns(
|
||||
model: Type["Model"],
|
||||
fields: List,
|
||||
exclude_fields: List,
|
||||
nested: bool = False,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
use_alias: bool = False,
|
||||
) -> List[str]:
|
||||
column_names = [
|
||||
model.get_column_name_from_alias(col.name) if use_alias else col.name
|
||||
columns = [
|
||||
model.get_column_name_from_alias(col.name) if not use_alias else col.name
|
||||
for col in model.Meta.table.columns
|
||||
]
|
||||
if not fields and not exclude_fields:
|
||||
return column_names
|
||||
|
||||
if not nested:
|
||||
columns = ModelTableProxy._get_not_nested_columns_from_fields(
|
||||
model=model,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
column_names=column_names,
|
||||
use_alias=use_alias,
|
||||
)
|
||||
else:
|
||||
columns = ModelTableProxy._get_nested_columns_from_fields(
|
||||
model=model,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
column_names=column_names,
|
||||
use_alias=use_alias,
|
||||
)
|
||||
field_names = [
|
||||
model.get_column_name_from_alias(col.name)
|
||||
for col in model.Meta.table.columns
|
||||
]
|
||||
if fields:
|
||||
columns = [
|
||||
col
|
||||
for col, name in zip(columns, field_names)
|
||||
if model.is_included(fields, name)
|
||||
]
|
||||
if exclude_fields:
|
||||
columns = [
|
||||
col
|
||||
for col, name in zip(columns, field_names)
|
||||
if not model.is_excluded(exclude_fields, name)
|
||||
]
|
||||
|
||||
# always has to return pk column
|
||||
columns = ModelTableProxy._populate_pk_column(
|
||||
|
||||
@ -23,6 +23,7 @@ from pydantic import BaseModel
|
||||
import ormar # noqa I100
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.models.excludable import Excludable
|
||||
from ormar.models.metaclass import ModelMeta, ModelMetaclass
|
||||
from ormar.models.modelproxy import ModelTableProxy
|
||||
from ormar.relations.alias_manager import AliasManager
|
||||
@ -39,7 +40,9 @@ if TYPE_CHECKING: # pragma no cover
|
||||
MappingIntStrAny = Mapping[IntStr, Any]
|
||||
|
||||
|
||||
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
|
||||
class NewBaseModel(
|
||||
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
|
||||
):
|
||||
__slots__ = ("_orm_id", "_orm_saved", "_orm")
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
|
||||
@ -1,5 +1,15 @@
|
||||
from collections import OrderedDict
|
||||
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Tuple, Type
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import text
|
||||
@ -24,8 +34,8 @@ class SqlJoin:
|
||||
used_aliases: List,
|
||||
select_from: sqlalchemy.sql.select,
|
||||
columns: List[sqlalchemy.Column],
|
||||
fields: List,
|
||||
exclude_fields: List,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
order_columns: Optional[List],
|
||||
sorted_orders: OrderedDict,
|
||||
) -> None:
|
||||
@ -49,21 +59,60 @@ class SqlJoin:
|
||||
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
||||
return text(f"{left_part}={right_part}")
|
||||
|
||||
def build_join(
|
||||
@staticmethod
|
||||
def update_inclusions(
|
||||
model_cls: Type["Model"],
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
nested_name: str,
|
||||
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
|
||||
fields = model_cls.get_included(fields, nested_name)
|
||||
exclude_fields = model_cls.get_included(exclude_fields, nested_name)
|
||||
return fields, exclude_fields
|
||||
|
||||
def build_join( # noqa: CCR001
|
||||
self, item: str, join_parameters: JoinParameters
|
||||
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
|
||||
for part in item.split("__"):
|
||||
|
||||
fields = self.fields
|
||||
exclude_fields = self.exclude_fields
|
||||
|
||||
for index, part in enumerate(item.split("__")):
|
||||
if issubclass(
|
||||
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
||||
):
|
||||
_fields = join_parameters.model_cls.Meta.model_fields
|
||||
new_part = _fields[part].to.get_name()
|
||||
self._switch_many_to_many_order_columns(part, new_part)
|
||||
if index > 0: # nested joins
|
||||
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||
model_cls=join_parameters.model_cls,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
nested_name=part,
|
||||
)
|
||||
|
||||
join_parameters = self._build_join_parameters(
|
||||
part, join_parameters, is_multi=True
|
||||
part=part,
|
||||
join_params=join_parameters,
|
||||
is_multi=True,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
)
|
||||
part = new_part
|
||||
join_parameters = self._build_join_parameters(part, join_parameters)
|
||||
if index > 0: # nested joins
|
||||
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||
model_cls=join_parameters.model_cls,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
nested_name=part,
|
||||
)
|
||||
join_parameters = self._build_join_parameters(
|
||||
part=part,
|
||||
join_params=join_parameters,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
)
|
||||
|
||||
return (
|
||||
self.used_aliases,
|
||||
@ -73,7 +122,12 @@ class SqlJoin:
|
||||
)
|
||||
|
||||
def _build_join_parameters(
|
||||
self, part: str, join_params: JoinParameters, is_multi: bool = False
|
||||
self,
|
||||
part: str,
|
||||
join_params: JoinParameters,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
is_multi: bool = False,
|
||||
) -> JoinParameters:
|
||||
if is_multi:
|
||||
model_cls = join_params.model_cls.Meta.model_fields[part].through
|
||||
@ -85,20 +139,30 @@ class SqlJoin:
|
||||
join_params.from_table, to_table
|
||||
)
|
||||
if alias not in self.used_aliases:
|
||||
self._process_join(join_params, is_multi, model_cls, part, alias)
|
||||
self._process_join(
|
||||
join_params=join_params,
|
||||
is_multi=is_multi,
|
||||
model_cls=model_cls,
|
||||
part=part,
|
||||
alias=alias,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
)
|
||||
|
||||
previous_alias = alias
|
||||
from_table = to_table
|
||||
prev_model = model_cls
|
||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||
|
||||
def _process_join(
|
||||
def _process_join( # noqa: CFQ002
|
||||
self,
|
||||
join_params: JoinParameters,
|
||||
is_multi: bool,
|
||||
model_cls: Type["Model"],
|
||||
part: str,
|
||||
alias: str,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
) -> None:
|
||||
to_table = model_cls.Meta.table.name
|
||||
to_key, from_key = self.get_to_and_from_keys(
|
||||
@ -129,7 +193,10 @@ class SqlJoin:
|
||||
)
|
||||
|
||||
self_related_fields = model_cls.own_table_columns(
|
||||
model_cls, self.fields, self.exclude_fields, nested=True,
|
||||
model=model_cls,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
use_alias=True,
|
||||
)
|
||||
self.columns.extend(
|
||||
self.relation_manager(model_cls).prefixed_columns(
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from typing import List, Optional, TYPE_CHECKING, Tuple, Type
|
||||
from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import text
|
||||
@ -21,8 +22,8 @@ class Query:
|
||||
select_related: List,
|
||||
limit_count: Optional[int],
|
||||
offset: Optional[int],
|
||||
fields: Optional[List],
|
||||
exclude_fields: Optional[List],
|
||||
fields: Optional[Union[Dict, Set]],
|
||||
exclude_fields: Optional[Union[Dict, Set]],
|
||||
order_bys: Optional[List],
|
||||
) -> None:
|
||||
self.query_offset = offset
|
||||
@ -30,8 +31,8 @@ class Query:
|
||||
self._select_related = select_related[:]
|
||||
self.filter_clauses = filter_clauses[:]
|
||||
self.exclude_clauses = exclude_clauses[:]
|
||||
self.fields = fields[:] if fields else []
|
||||
self.exclude_fields = exclude_fields[:] if exclude_fields else []
|
||||
self.fields = copy.deepcopy(fields) if fields else {}
|
||||
self.exclude_fields = copy.deepcopy(exclude_fields) if exclude_fields else {}
|
||||
|
||||
self.model_cls = model_cls
|
||||
self.table = self.model_cls.Meta.table
|
||||
@ -73,7 +74,10 @@ class Query:
|
||||
|
||||
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
|
||||
self_related_fields = self.model_cls.own_table_columns(
|
||||
self.model_cls, self.fields, self.exclude_fields
|
||||
model=self.model_cls,
|
||||
fields=self.fields,
|
||||
exclude_fields=self.exclude_fields,
|
||||
use_alias=True,
|
||||
)
|
||||
self.columns = self.model_cls.Meta.alias_manager.prefixed_columns(
|
||||
"", self.table, self_related_fields
|
||||
@ -87,13 +91,14 @@ class Query:
|
||||
join_parameters = JoinParameters(
|
||||
self.model_cls, "", self.table.name, self.model_cls
|
||||
)
|
||||
|
||||
fields = self.model_cls.get_included(self.fields, item)
|
||||
exclude_fields = self.model_cls.get_excluded(self.exclude_fields, item)
|
||||
sql_join = SqlJoin(
|
||||
used_aliases=self.used_aliases,
|
||||
select_from=self.select_from,
|
||||
columns=self.columns,
|
||||
fields=self.fields,
|
||||
exclude_fields=self.exclude_fields,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
order_columns=self.order_columns,
|
||||
sorted_orders=self.sorted_orders,
|
||||
)
|
||||
@ -131,5 +136,5 @@ class Query:
|
||||
self.select_from = []
|
||||
self.columns = []
|
||||
self.used_aliases = []
|
||||
self.fields = []
|
||||
self.exclude_fields = []
|
||||
self.fields = {}
|
||||
self.exclude_fields = {}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union
|
||||
import copy
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union
|
||||
|
||||
import databases
|
||||
import sqlalchemy
|
||||
@ -10,6 +11,7 @@ from ormar.exceptions import QueryDefinitionError
|
||||
from ormar.queryset import FilterQuery
|
||||
from ormar.queryset.clause import QueryClause
|
||||
from ormar.queryset.query import Query
|
||||
from ormar.queryset.utils import update, update_dict_from_list
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
@ -26,8 +28,8 @@ class QuerySet:
|
||||
select_related: List = None,
|
||||
limit_count: int = None,
|
||||
offset: int = None,
|
||||
columns: List = None,
|
||||
exclude_columns: List = None,
|
||||
columns: Dict = None,
|
||||
exclude_columns: Dict = None,
|
||||
order_bys: List = None,
|
||||
) -> None:
|
||||
self.model_cls = model_cls
|
||||
@ -36,8 +38,8 @@ class QuerySet:
|
||||
self._select_related = [] if select_related is None else select_related
|
||||
self.limit_count = limit_count
|
||||
self.query_offset = offset
|
||||
self._columns = columns or []
|
||||
self._exclude_columns = exclude_columns or []
|
||||
self._columns = columns or {}
|
||||
self._exclude_columns = exclude_columns or {}
|
||||
self.order_bys = order_bys or []
|
||||
|
||||
def __get__(
|
||||
@ -169,11 +171,16 @@ class QuerySet:
|
||||
order_bys=self.order_bys,
|
||||
)
|
||||
|
||||
def exclude_fields(self, columns: Union[List, str]) -> "QuerySet":
|
||||
if not isinstance(columns, list):
|
||||
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
|
||||
columns = list(set(list(self._exclude_columns) + columns))
|
||||
current_excluded = copy.deepcopy(self._exclude_columns)
|
||||
if not isinstance(columns, dict):
|
||||
current_excluded = update_dict_from_list(current_excluded, columns)
|
||||
else:
|
||||
current_excluded = update(current_excluded, columns)
|
||||
|
||||
return self.__class__(
|
||||
model_cls=self.model,
|
||||
filter_clauses=self.filter_clauses,
|
||||
@ -182,15 +189,20 @@ class QuerySet:
|
||||
limit_count=self.limit_count,
|
||||
offset=self.query_offset,
|
||||
columns=self._columns,
|
||||
exclude_columns=columns,
|
||||
exclude_columns=current_excluded,
|
||||
order_bys=self.order_bys,
|
||||
)
|
||||
|
||||
def fields(self, columns: Union[List, str]) -> "QuerySet":
|
||||
if not isinstance(columns, list):
|
||||
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
|
||||
columns = list(set(list(self._columns) + columns))
|
||||
current_included = copy.deepcopy(self._exclude_columns)
|
||||
if not isinstance(columns, dict):
|
||||
current_included = update_dict_from_list(current_included, columns)
|
||||
else:
|
||||
current_included = update(current_included, columns)
|
||||
|
||||
return self.__class__(
|
||||
model_cls=self.model,
|
||||
filter_clauses=self.filter_clauses,
|
||||
@ -198,7 +210,7 @@ class QuerySet:
|
||||
select_related=self._select_related,
|
||||
limit_count=self.limit_count,
|
||||
offset=self.query_offset,
|
||||
columns=columns,
|
||||
columns=current_included,
|
||||
exclude_columns=self._exclude_columns,
|
||||
order_bys=self.order_bys,
|
||||
)
|
||||
|
||||
57
ormar/queryset/utils.py
Normal file
57
ormar/queryset/utils.py
Normal file
@ -0,0 +1,57 @@
|
||||
import collections.abc
|
||||
import copy
|
||||
from typing import Any, Dict, List, Set, Union
|
||||
|
||||
|
||||
def check_node_not_dict_or_not_last_node(
|
||||
part: str, parts: List, current_level: Any
|
||||
) -> bool:
|
||||
return (part not in current_level and part != parts[-1]) or (
|
||||
part in current_level and not isinstance(current_level[part], dict)
|
||||
)
|
||||
|
||||
|
||||
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001
|
||||
new_dict: Dict = dict()
|
||||
for path in list_to_trans:
|
||||
current_level = new_dict
|
||||
parts = path.split("__")
|
||||
for part in parts:
|
||||
if check_node_not_dict_or_not_last_node(
|
||||
part=part, parts=parts, current_level=current_level
|
||||
):
|
||||
current_level[part] = dict()
|
||||
elif part not in current_level:
|
||||
current_level[part] = ...
|
||||
current_level = current_level[part]
|
||||
return new_dict
|
||||
|
||||
|
||||
def convert_set_to_required_dict(set_to_convert: set) -> Dict:
|
||||
new_dict = dict()
|
||||
for key in set_to_convert:
|
||||
new_dict[key] = Ellipsis
|
||||
return new_dict
|
||||
|
||||
|
||||
def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
|
||||
if current_dict is Ellipsis:
|
||||
current_dict = dict()
|
||||
for key, value in updating_dict.items():
|
||||
if isinstance(value, collections.abc.Mapping):
|
||||
old_key = current_dict.get(key, {})
|
||||
if isinstance(old_key, set):
|
||||
old_key = convert_set_to_required_dict(old_key)
|
||||
current_dict[key] = update(old_key, value)
|
||||
elif isinstance(value, set) and isinstance(current_dict.get(key), set):
|
||||
current_dict[key] = current_dict.get(key).union(value)
|
||||
else:
|
||||
current_dict[key] = value
|
||||
return current_dict
|
||||
|
||||
|
||||
def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict:
|
||||
updated_dict = copy.copy(curr_dict)
|
||||
dict_to_update = translate_list_to_dict(list_to_update)
|
||||
update(updated_dict, dict_to_update)
|
||||
return updated_dict
|
||||
@ -117,8 +117,8 @@ async def test_working_with_aliases():
|
||||
"first_name",
|
||||
"last_name",
|
||||
"born_year",
|
||||
"child__first_name",
|
||||
"child__last_name",
|
||||
"children__first_name",
|
||||
"children__last_name",
|
||||
]
|
||||
)
|
||||
.get()
|
||||
|
||||
@ -80,10 +80,53 @@ async def test_selecting_subset():
|
||||
all_cars = (
|
||||
await Car.objects.select_related("manufacturer")
|
||||
.exclude_fields(
|
||||
["gearbox_type", "gears", "aircon_type", "year", "company__founded"]
|
||||
[
|
||||
"gearbox_type",
|
||||
"gears",
|
||||
"aircon_type",
|
||||
"year",
|
||||
"manufacturer__founded",
|
||||
]
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for car in all_cars:
|
||||
assert all(
|
||||
getattr(car, x) is None
|
||||
for x in ["year", "gearbox_type", "gears", "aircon_type"]
|
||||
)
|
||||
assert car.manufacturer.name == "Toyota"
|
||||
assert car.manufacturer.founded is None
|
||||
|
||||
all_cars = (
|
||||
await Car.objects.select_related("manufacturer")
|
||||
.exclude_fields(
|
||||
{
|
||||
"gearbox_type": ...,
|
||||
"gears": ...,
|
||||
"aircon_type": ...,
|
||||
"year": ...,
|
||||
"manufacturer": {"founded": ...},
|
||||
}
|
||||
)
|
||||
.all()
|
||||
)
|
||||
all_cars2 = (
|
||||
await Car.objects.select_related("manufacturer")
|
||||
.exclude_fields(
|
||||
{
|
||||
"gearbox_type": ...,
|
||||
"gears": ...,
|
||||
"aircon_type": ...,
|
||||
"year": ...,
|
||||
"manufacturer": {"founded"},
|
||||
}
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
assert all_cars == all_cars2
|
||||
|
||||
for car in all_cars:
|
||||
assert all(
|
||||
getattr(car, x) is None
|
||||
@ -119,7 +162,7 @@ async def test_selecting_subset():
|
||||
all_cars_check2 = (
|
||||
await Car.objects.select_related("manufacturer")
|
||||
.fields(["id", "name", "manufacturer"])
|
||||
.exclude_fields("company__founded")
|
||||
.exclude_fields("manufacturer__founded")
|
||||
.all()
|
||||
)
|
||||
for car in all_cars_check2:
|
||||
@ -133,5 +176,5 @@ async def test_selecting_subset():
|
||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||
# cannot exclude mandatory model columns - company__name in this example
|
||||
await Car.objects.select_related("manufacturer").exclude_fields(
|
||||
["company__name"]
|
||||
["manufacturer__name"]
|
||||
).all()
|
||||
|
||||
98
tests/test_queryset_utils.py
Normal file
98
tests/test_queryset_utils.py
Normal file
@ -0,0 +1,98 @@
|
||||
from ormar.models.excludable import Excludable
|
||||
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
|
||||
|
||||
|
||||
def test_empty_excludable():
|
||||
assert Excludable.is_included(None, "key") # all fields included if empty
|
||||
assert not Excludable.is_excluded(None, "key") # none field excluded if empty
|
||||
|
||||
|
||||
def test_list_to_dict_translation():
|
||||
tet_list = ["aa", "bb", "cc__aa", "cc__bb", "cc__aa__xx", "cc__aa__yy"]
|
||||
test = translate_list_to_dict(tet_list)
|
||||
assert test == {
|
||||
"aa": Ellipsis,
|
||||
"bb": Ellipsis,
|
||||
"cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis},
|
||||
}
|
||||
|
||||
|
||||
def test_updating_dict_with_list():
|
||||
curr_dict = {
|
||||
"aa": Ellipsis,
|
||||
"bb": Ellipsis,
|
||||
"cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis},
|
||||
}
|
||||
list_to_update = ["ee", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"]
|
||||
test = update_dict_from_list(curr_dict, list_to_update)
|
||||
assert test == {
|
||||
"aa": Ellipsis,
|
||||
"bb": {"cc": Ellipsis},
|
||||
"cc": {
|
||||
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
|
||||
"bb": Ellipsis,
|
||||
},
|
||||
"ee": Ellipsis,
|
||||
}
|
||||
|
||||
|
||||
def test_updating_dict_inc_set_with_list():
|
||||
curr_dict = {
|
||||
"aa": Ellipsis,
|
||||
"bb": Ellipsis,
|
||||
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
|
||||
}
|
||||
list_to_update = ["uu", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"]
|
||||
test = update_dict_from_list(curr_dict, list_to_update)
|
||||
assert test == {
|
||||
"aa": Ellipsis,
|
||||
"bb": {"cc": Ellipsis},
|
||||
"cc": {
|
||||
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
|
||||
"bb": Ellipsis,
|
||||
},
|
||||
"uu": Ellipsis,
|
||||
}
|
||||
|
||||
|
||||
def test_updating_dict_inc_set_with_dict():
|
||||
curr_dict = {
|
||||
"aa": Ellipsis,
|
||||
"bb": Ellipsis,
|
||||
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
|
||||
}
|
||||
dict_to_update = {
|
||||
"uu": Ellipsis,
|
||||
"bb": {"cc", "dd"},
|
||||
"cc": {"aa": {"xx": {"oo": Ellipsis}, "oo": Ellipsis}},
|
||||
}
|
||||
test = update(curr_dict, dict_to_update)
|
||||
assert test == {
|
||||
"aa": Ellipsis,
|
||||
"bb": {"cc", "dd"},
|
||||
"cc": {
|
||||
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
|
||||
"bb": Ellipsis,
|
||||
},
|
||||
"uu": Ellipsis,
|
||||
}
|
||||
|
||||
|
||||
def test_updating_dict_inc_set_with_dict_inc_set():
|
||||
curr_dict = {
|
||||
"aa": Ellipsis,
|
||||
"bb": Ellipsis,
|
||||
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
|
||||
}
|
||||
dict_to_update = {
|
||||
"uu": Ellipsis,
|
||||
"bb": {"cc", "dd"},
|
||||
"cc": {"aa": {"xx", "oo", "zz", "ii"}},
|
||||
}
|
||||
test = update(curr_dict, dict_to_update)
|
||||
assert test == {
|
||||
"aa": Ellipsis,
|
||||
"bb": {"cc", "dd"},
|
||||
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
|
||||
"uu": Ellipsis,
|
||||
}
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
import itertools
|
||||
from typing import Optional, List
|
||||
|
||||
import databases
|
||||
import pydantic
|
||||
@ -12,6 +13,35 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class NickNames(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "nicks"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
|
||||
is_lame: bool = ormar.Boolean(nullable=True)
|
||||
|
||||
|
||||
class NicksHq(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "nicks_x_hq"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
|
||||
class HQ(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "hqs"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
|
||||
nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq)
|
||||
|
||||
|
||||
class Company(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "companies"
|
||||
@ -21,6 +51,7 @@ class Company(ormar.Model):
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
|
||||
founded: int = ormar.Integer(nullable=True)
|
||||
hq: HQ = ormar.ForeignKey(HQ)
|
||||
|
||||
|
||||
class Car(ormar.Model):
|
||||
@ -51,7 +82,14 @@ def create_test_database():
|
||||
async def test_selecting_subset():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
toyota = await Company.objects.create(name="Toyota", founded=1937)
|
||||
nick1 = await NickNames.objects.create(name="Nippon", is_lame=False)
|
||||
nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True)
|
||||
hq = await HQ.objects.create(name="Japan")
|
||||
await hq.nicks.add(nick1)
|
||||
await hq.nicks.add(nick2)
|
||||
|
||||
toyota = await Company.objects.create(name="Toyota", founded=1937, hq=hq)
|
||||
|
||||
await Car.objects.create(
|
||||
manufacturer=toyota,
|
||||
name="Corolla",
|
||||
@ -78,17 +116,66 @@ async def test_selecting_subset():
|
||||
)
|
||||
|
||||
all_cars = (
|
||||
await Car.objects.select_related("manufacturer")
|
||||
.fields(["id", "name", "company__name"])
|
||||
await Car.objects.select_related(
|
||||
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
|
||||
)
|
||||
.fields(
|
||||
[
|
||||
"id",
|
||||
"name",
|
||||
"manufacturer__name",
|
||||
"manufacturer__hq__name",
|
||||
"manufacturer__hq__nicks__name",
|
||||
]
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for car in all_cars:
|
||||
|
||||
all_cars2 = (
|
||||
await Car.objects.select_related(
|
||||
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
|
||||
)
|
||||
.fields(
|
||||
{
|
||||
"id": ...,
|
||||
"name": ...,
|
||||
"manufacturer": {
|
||||
"name": ...,
|
||||
"hq": {"name": ..., "nicks": {"name": ...}},
|
||||
},
|
||||
}
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
all_cars3 = (
|
||||
await Car.objects.select_related(
|
||||
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
|
||||
)
|
||||
.fields(
|
||||
{
|
||||
"id": ...,
|
||||
"name": ...,
|
||||
"manufacturer": {
|
||||
"name": ...,
|
||||
"hq": {"name": ..., "nicks": {"name"}},
|
||||
},
|
||||
}
|
||||
)
|
||||
.all()
|
||||
)
|
||||
assert all_cars3 == all_cars
|
||||
|
||||
for car in itertools.chain(all_cars, all_cars2):
|
||||
assert all(
|
||||
getattr(car, x) is None
|
||||
for x in ["year", "gearbox_type", "gears", "aircon_type"]
|
||||
)
|
||||
assert car.manufacturer.name == "Toyota"
|
||||
assert car.manufacturer.founded is None
|
||||
assert car.manufacturer.hq.name == "Japan"
|
||||
assert len(car.manufacturer.hq.nicks) == 2
|
||||
assert car.manufacturer.hq.nicks[0].is_lame is None
|
||||
|
||||
all_cars = (
|
||||
await Car.objects.select_related("manufacturer")
|
||||
@ -103,6 +190,7 @@ async def test_selecting_subset():
|
||||
)
|
||||
assert car.manufacturer.name == "Toyota"
|
||||
assert car.manufacturer.founded == 1937
|
||||
assert car.manufacturer.hq.name is None
|
||||
|
||||
all_cars_check = await Car.objects.select_related("manufacturer").all()
|
||||
for car in all_cars_check:
|
||||
@ -116,5 +204,5 @@ async def test_selecting_subset():
|
||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||
# cannot exclude mandatory model columns - company__name in this example
|
||||
await Car.objects.select_related("manufacturer").fields(
|
||||
["id", "name", "company__founded"]
|
||||
["id", "name", "manufacturer__founded"]
|
||||
).all()
|
||||
|
||||
Reference in New Issue
Block a user