diff --git a/ormar/models/excludable.py b/ormar/models/excludable.py new file mode 100644 index 0000000..003f5d4 --- /dev/null +++ b/ormar/models/excludable.py @@ -0,0 +1,39 @@ +from typing import Dict, Set, Union + + +class Excludable: + @staticmethod + def get_excluded( + exclude: Union[Set, Dict, None], key: str = None + ) -> Union[Set, Dict, None]: + if isinstance(exclude, dict): + return exclude.get(key, {}) + return exclude + + @staticmethod + def get_included( + include: Union[Set, Dict, None], key: str = None + ) -> Union[Set, Dict, None]: + return Excludable.get_excluded(exclude=include, key=key) + + @staticmethod + def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool: + if exclude is None: + return False + to_exclude = Excludable.get_excluded(exclude=exclude, key=key) + if isinstance(to_exclude, Set): + return key in to_exclude + elif to_exclude is ...: + return True + return False + + @staticmethod + def is_included(include: Union[Set, Dict, None], key: str = None) -> bool: + if include is None: + return True + to_include = Excludable.get_included(include=include, key=key) + if isinstance(to_include, Set): + return key in to_include + elif to_include is ...: + return True + return False diff --git a/ormar/models/model.py b/ormar/models/model.py index 7cef4ee..688eed0 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,5 +1,5 @@ import itertools -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union import sqlalchemy @@ -47,8 +47,8 @@ class Model(NewBaseModel): select_related: List = None, related_models: Any = None, previous_table: str = None, - fields: List = None, - exclude_fields: List = None, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> Optional[T]: item: Dict[str, Any] = {} @@ -88,7 +88,6 @@ class Model(NewBaseModel): table_prefix=table_prefix, fields=fields, exclude_fields=exclude_fields, - nested=table_prefix != "", ) instance: Optional[T] = cls(**item) if item.get( @@ -103,13 +102,17 @@ class Model(NewBaseModel): row: sqlalchemy.engine.ResultProxy, related_models: Any, previous_table: sqlalchemy.Table, - fields: List = None, - exclude_fields: List = None, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> dict: for related in related_models: if isinstance(related_models, dict) and related_models[related]: first_part, remainder = related, related_models[related] model_cls = cls.Meta.model_fields[first_part].to + + fields = cls.get_included(fields, first_part) + exclude_fields = cls.get_excluded(exclude_fields, first_part) + child = model_cls.from_row( row, related_models=remainder, @@ -120,6 +123,8 @@ class Model(NewBaseModel): item[model_cls.get_column_name_from_alias(first_part)] = child else: model_cls = cls.Meta.model_fields[related].to + fields = cls.get_included(fields, related) + exclude_fields = cls.get_excluded(exclude_fields, related) child = model_cls.from_row( row, previous_table=previous_table, @@ -136,16 +141,18 @@ class Model(NewBaseModel): item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str, - fields: List = None, - exclude_fields: List = None, - nested: bool = False, + fields: Optional[Union[Dict, Set]] = None, + exclude_fields: Optional[Union[Dict, Set]] = None, ) -> dict: # databases does not keep aliases in Record for postgres, change to raw row source = row._row if cls.db_backend_name() == "postgresql" else row selected_columns = cls.own_table_columns( - cls, fields or [], exclude_fields or [], nested=nested, use_alias=True + model=cls, + fields=fields or {}, + exclude_fields=exclude_fields or {}, + use_alias=False, ) for column in cls.Meta.table.columns: diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index c39f112..4cbe4f5 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -1,6 +1,16 @@ import inspect from collections import OrderedDict -from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union +from typing import ( + Dict, + List, + Optional, + Sequence, + Set, + TYPE_CHECKING, + Type, + TypeVar, + Union, +) import ormar from ormar.exceptions import RelationshipInstanceError @@ -65,14 +75,14 @@ class ModelTableProxy: @classmethod def get_column_alias(cls, field_name: str) -> str: field = cls.Meta.model_fields.get(field_name) - if field and field.alias is not None: + if field is not None and field.alias is not None: return field.alias return field_name @classmethod def get_column_name_from_alias(cls, alias: str) -> str: for field_name, field in cls.Meta.model_fields.items(): - if field and field.alias == alias: + if field is not None and field.alias == alias: return field_name return alias # if not found it's not an alias but actual name @@ -211,59 +221,13 @@ class ModelTableProxy: ) return other - @staticmethod - def _get_not_nested_columns_from_fields( - model: Type["Model"], - fields: List, - exclude_fields: List, - column_names: List[str], - use_alias: bool = False, - ) -> List[str]: - fields = [model.get_column_alias(k) if not use_alias else k for k in fields] - fields = fields or column_names - exclude_fields = [ - model.get_column_alias(k) if not use_alias else k for k in exclude_fields - ] - columns = [ - name - for name in fields - if "__" not in name and name in column_names and name not in exclude_fields - ] - return columns - - @staticmethod - def _get_nested_columns_from_fields( - model: Type["Model"], - fields: List, - exclude_fields: List, - column_names: List[str], - use_alias: bool = False, - ) -> List[str]: - model_name = f"{model.get_name()}__" - columns = [ - name[(name.find(model_name) + len(model_name)) :] # noqa: E203 - for name in fields - if f"{model.get_name()}__" in name - ] - columns = columns or column_names - exclude_columns = [ - name[(name.find(model_name) + len(model_name)) :] # noqa: E203 - for name in exclude_fields - if f"{model.get_name()}__" in name - ] - columns = [model.get_column_alias(k) if not use_alias else k for k in columns] - exclude_columns = [ - model.get_column_alias(k) if not use_alias else k for k in exclude_columns - ] - return [column for column in columns if column not in exclude_columns] - @staticmethod def _populate_pk_column( model: Type["Model"], columns: List[str], use_alias: bool = False, ) -> List[str]: pk_alias = ( model.get_column_alias(model.Meta.pkname) - if not use_alias + if use_alias else model.Meta.pkname ) if pk_alias not in columns: @@ -273,34 +237,30 @@ class ModelTableProxy: @staticmethod def own_table_columns( model: Type["Model"], - fields: List, - exclude_fields: List, - nested: bool = False, + fields: Optional[Union[Set, Dict]], + exclude_fields: Optional[Union[Set, Dict]], use_alias: bool = False, ) -> List[str]: - column_names = [ - model.get_column_name_from_alias(col.name) if use_alias else col.name + columns = [ + model.get_column_name_from_alias(col.name) if not use_alias else col.name for col in model.Meta.table.columns ] - if not fields and not exclude_fields: - return column_names - - if not nested: - columns = ModelTableProxy._get_not_nested_columns_from_fields( - model=model, - fields=fields, - exclude_fields=exclude_fields, - column_names=column_names, - use_alias=use_alias, - ) - else: - columns = ModelTableProxy._get_nested_columns_from_fields( - model=model, - fields=fields, - exclude_fields=exclude_fields, - column_names=column_names, - use_alias=use_alias, - ) + field_names = [ + model.get_column_name_from_alias(col.name) + for col in model.Meta.table.columns + ] + if fields: + columns = [ + col + for col, name in zip(columns, field_names) + if model.is_included(fields, name) + ] + if exclude_fields: + columns = [ + col + for col, name in zip(columns, field_names) + if not model.is_excluded(exclude_fields, name) + ] # always has to return pk column columns = ModelTableProxy._populate_pk_column( diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 2d6ac22..75ef303 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -23,6 +23,7 @@ from pydantic import BaseModel import ormar # noqa I100 from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField +from ormar.models.excludable import Excludable from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.modelproxy import ModelTableProxy from ormar.relations.alias_manager import AliasManager @@ -39,7 +40,9 @@ if TYPE_CHECKING: # pragma no cover MappingIntStrAny = Mapping[IntStr, Any] -class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass): +class NewBaseModel( + pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass +): __slots__ = ("_orm_id", "_orm_saved", "_orm") if TYPE_CHECKING: # pragma no cover diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 496020e..101e708 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -1,5 +1,15 @@ from collections import OrderedDict -from typing import List, NamedTuple, Optional, TYPE_CHECKING, Tuple, Type +from typing import ( + Dict, + List, + NamedTuple, + Optional, + Set, + TYPE_CHECKING, + Tuple, + Type, + Union, +) import sqlalchemy from sqlalchemy import text @@ -24,8 +34,8 @@ class SqlJoin: used_aliases: List, select_from: sqlalchemy.sql.select, columns: List[sqlalchemy.Column], - fields: List, - exclude_fields: List, + fields: Optional[Union[Set, Dict]], + exclude_fields: Optional[Union[Set, Dict]], order_columns: Optional[List], sorted_orders: OrderedDict, ) -> None: @@ -49,21 +59,60 @@ class SqlJoin: right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") - def build_join( + @staticmethod + def update_inclusions( + model_cls: Type["Model"], + fields: Optional[Union[Set, Dict]], + exclude_fields: Optional[Union[Set, Dict]], + nested_name: str, + ) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]: + fields = model_cls.get_included(fields, nested_name) + exclude_fields = model_cls.get_included(exclude_fields, nested_name) + return fields, exclude_fields + + def build_join( # noqa: CCR001 self, item: str, join_parameters: JoinParameters ) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]: - for part in item.split("__"): + + fields = self.fields + exclude_fields = self.exclude_fields + + for index, part in enumerate(item.split("__")): if issubclass( join_parameters.model_cls.Meta.model_fields[part], ManyToManyField ): _fields = join_parameters.model_cls.Meta.model_fields new_part = _fields[part].to.get_name() self._switch_many_to_many_order_columns(part, new_part) + if index > 0: # nested joins + fields, exclude_fields = SqlJoin.update_inclusions( + model_cls=join_parameters.model_cls, + fields=fields, + exclude_fields=exclude_fields, + nested_name=part, + ) + join_parameters = self._build_join_parameters( - part, join_parameters, is_multi=True + part=part, + join_params=join_parameters, + is_multi=True, + fields=fields, + exclude_fields=exclude_fields, ) part = new_part - join_parameters = self._build_join_parameters(part, join_parameters) + if index > 0: # nested joins + fields, exclude_fields = SqlJoin.update_inclusions( + model_cls=join_parameters.model_cls, + fields=fields, + exclude_fields=exclude_fields, + nested_name=part, + ) + join_parameters = self._build_join_parameters( + part=part, + join_params=join_parameters, + fields=fields, + exclude_fields=exclude_fields, + ) return ( self.used_aliases, @@ -73,7 +122,12 @@ class SqlJoin: ) def _build_join_parameters( - self, part: str, join_params: JoinParameters, is_multi: bool = False + self, + part: str, + join_params: JoinParameters, + fields: Optional[Union[Set, Dict]], + exclude_fields: Optional[Union[Set, Dict]], + is_multi: bool = False, ) -> JoinParameters: if is_multi: model_cls = join_params.model_cls.Meta.model_fields[part].through @@ -85,20 +139,30 @@ class SqlJoin: join_params.from_table, to_table ) if alias not in self.used_aliases: - self._process_join(join_params, is_multi, model_cls, part, alias) + self._process_join( + join_params=join_params, + is_multi=is_multi, + model_cls=model_cls, + part=part, + alias=alias, + fields=fields, + exclude_fields=exclude_fields, + ) previous_alias = alias from_table = to_table prev_model = model_cls return JoinParameters(prev_model, previous_alias, from_table, model_cls) - def _process_join( + def _process_join( # noqa: CFQ002 self, join_params: JoinParameters, is_multi: bool, model_cls: Type["Model"], part: str, alias: str, + fields: Optional[Union[Set, Dict]], + exclude_fields: Optional[Union[Set, Dict]], ) -> None: to_table = model_cls.Meta.table.name to_key, from_key = self.get_to_and_from_keys( @@ -129,7 +193,10 @@ class SqlJoin: ) self_related_fields = model_cls.own_table_columns( - model_cls, self.fields, self.exclude_fields, nested=True, + model=model_cls, + fields=fields, + exclude_fields=exclude_fields, + use_alias=True, ) self.columns.extend( self.relation_manager(model_cls).prefixed_columns( diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 1bdf829..dc08b99 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -1,5 +1,6 @@ +import copy from collections import OrderedDict -from typing import List, Optional, TYPE_CHECKING, Tuple, Type +from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union import sqlalchemy from sqlalchemy import text @@ -21,8 +22,8 @@ class Query: select_related: List, limit_count: Optional[int], offset: Optional[int], - fields: Optional[List], - exclude_fields: Optional[List], + fields: Optional[Union[Dict, Set]], + exclude_fields: Optional[Union[Dict, Set]], order_bys: Optional[List], ) -> None: self.query_offset = offset @@ -30,8 +31,8 @@ class Query: self._select_related = select_related[:] self.filter_clauses = filter_clauses[:] self.exclude_clauses = exclude_clauses[:] - self.fields = fields[:] if fields else [] - self.exclude_fields = exclude_fields[:] if exclude_fields else [] + self.fields = copy.deepcopy(fields) if fields else {} + self.exclude_fields = copy.deepcopy(exclude_fields) if exclude_fields else {} self.model_cls = model_cls self.table = self.model_cls.Meta.table @@ -73,7 +74,10 @@ class Query: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: self_related_fields = self.model_cls.own_table_columns( - self.model_cls, self.fields, self.exclude_fields + model=self.model_cls, + fields=self.fields, + exclude_fields=self.exclude_fields, + use_alias=True, ) self.columns = self.model_cls.Meta.alias_manager.prefixed_columns( "", self.table, self_related_fields @@ -87,13 +91,14 @@ class Query: join_parameters = JoinParameters( self.model_cls, "", self.table.name, self.model_cls ) - + fields = self.model_cls.get_included(self.fields, item) + exclude_fields = self.model_cls.get_excluded(self.exclude_fields, item) sql_join = SqlJoin( used_aliases=self.used_aliases, select_from=self.select_from, columns=self.columns, - fields=self.fields, - exclude_fields=self.exclude_fields, + fields=fields, + exclude_fields=exclude_fields, order_columns=self.order_columns, sorted_orders=self.sorted_orders, ) @@ -131,5 +136,5 @@ class Query: self.select_from = [] self.columns = [] self.used_aliases = [] - self.fields = [] - self.exclude_fields = [] + self.fields = {} + self.exclude_fields = {} diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index a6aa9e7..ce2648e 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1,4 +1,5 @@ -from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union +import copy +from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union import databases import sqlalchemy @@ -10,6 +11,7 @@ from ormar.exceptions import QueryDefinitionError from ormar.queryset import FilterQuery from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query +from ormar.queryset.utils import update, update_dict_from_list if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -26,8 +28,8 @@ class QuerySet: select_related: List = None, limit_count: int = None, offset: int = None, - columns: List = None, - exclude_columns: List = None, + columns: Dict = None, + exclude_columns: Dict = None, order_bys: List = None, ) -> None: self.model_cls = model_cls @@ -36,8 +38,8 @@ class QuerySet: self._select_related = [] if select_related is None else select_related self.limit_count = limit_count self.query_offset = offset - self._columns = columns or [] - self._exclude_columns = exclude_columns or [] + self._columns = columns or {} + self._exclude_columns = exclude_columns or {} self.order_bys = order_bys or [] def __get__( @@ -169,11 +171,16 @@ class QuerySet: order_bys=self.order_bys, ) - def exclude_fields(self, columns: Union[List, str]) -> "QuerySet": - if not isinstance(columns, list): + def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": + if isinstance(columns, str): columns = [columns] - columns = list(set(list(self._exclude_columns) + columns)) + current_excluded = copy.deepcopy(self._exclude_columns) + if not isinstance(columns, dict): + current_excluded = update_dict_from_list(current_excluded, columns) + else: + current_excluded = update(current_excluded, columns) + return self.__class__( model_cls=self.model, filter_clauses=self.filter_clauses, @@ -182,15 +189,20 @@ class QuerySet: limit_count=self.limit_count, offset=self.query_offset, columns=self._columns, - exclude_columns=columns, + exclude_columns=current_excluded, order_bys=self.order_bys, ) - def fields(self, columns: Union[List, str]) -> "QuerySet": - if not isinstance(columns, list): + def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": + if isinstance(columns, str): columns = [columns] - columns = list(set(list(self._columns) + columns)) + current_included = copy.deepcopy(self._exclude_columns) + if not isinstance(columns, dict): + current_included = update_dict_from_list(current_included, columns) + else: + current_included = update(current_included, columns) + return self.__class__( model_cls=self.model, filter_clauses=self.filter_clauses, @@ -198,7 +210,7 @@ class QuerySet: select_related=self._select_related, limit_count=self.limit_count, offset=self.query_offset, - columns=columns, + columns=current_included, exclude_columns=self._exclude_columns, order_bys=self.order_bys, ) diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py new file mode 100644 index 0000000..c3c8fa9 --- /dev/null +++ b/ormar/queryset/utils.py @@ -0,0 +1,57 @@ +import collections.abc +import copy +from typing import Any, Dict, List, Set, Union + + +def check_node_not_dict_or_not_last_node( + part: str, parts: List, current_level: Any +) -> bool: + return (part not in current_level and part != parts[-1]) or ( + part in current_level and not isinstance(current_level[part], dict) + ) + + +def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001 + new_dict: Dict = dict() + for path in list_to_trans: + current_level = new_dict + parts = path.split("__") + for part in parts: + if check_node_not_dict_or_not_last_node( + part=part, parts=parts, current_level=current_level + ): + current_level[part] = dict() + elif part not in current_level: + current_level[part] = ... + current_level = current_level[part] + return new_dict + + +def convert_set_to_required_dict(set_to_convert: set) -> Dict: + new_dict = dict() + for key in set_to_convert: + new_dict[key] = Ellipsis + return new_dict + + +def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001 + if current_dict is Ellipsis: + current_dict = dict() + for key, value in updating_dict.items(): + if isinstance(value, collections.abc.Mapping): + old_key = current_dict.get(key, {}) + if isinstance(old_key, set): + old_key = convert_set_to_required_dict(old_key) + current_dict[key] = update(old_key, value) + elif isinstance(value, set) and isinstance(current_dict.get(key), set): + current_dict[key] = current_dict.get(key).union(value) + else: + current_dict[key] = value + return current_dict + + +def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict: + updated_dict = copy.copy(curr_dict) + dict_to_update = translate_list_to_dict(list_to_update) + update(updated_dict, dict_to_update) + return updated_dict diff --git a/tests/test_aliases.py b/tests/test_aliases.py index 81c4a8c..df472cc 100644 --- a/tests/test_aliases.py +++ b/tests/test_aliases.py @@ -117,8 +117,8 @@ async def test_working_with_aliases(): "first_name", "last_name", "born_year", - "child__first_name", - "child__last_name", + "children__first_name", + "children__last_name", ] ) .get() diff --git a/tests/test_excluding_subset_of_columns.py b/tests/test_excluding_subset_of_columns.py index 6b7a6a5..dd9ac57 100644 --- a/tests/test_excluding_subset_of_columns.py +++ b/tests/test_excluding_subset_of_columns.py @@ -80,10 +80,53 @@ async def test_selecting_subset(): all_cars = ( await Car.objects.select_related("manufacturer") .exclude_fields( - ["gearbox_type", "gears", "aircon_type", "year", "company__founded"] + [ + "gearbox_type", + "gears", + "aircon_type", + "year", + "manufacturer__founded", + ] ) .all() ) + for car in all_cars: + assert all( + getattr(car, x) is None + for x in ["year", "gearbox_type", "gears", "aircon_type"] + ) + assert car.manufacturer.name == "Toyota" + assert car.manufacturer.founded is None + + all_cars = ( + await Car.objects.select_related("manufacturer") + .exclude_fields( + { + "gearbox_type": ..., + "gears": ..., + "aircon_type": ..., + "year": ..., + "manufacturer": {"founded": ...}, + } + ) + .all() + ) + all_cars2 = ( + await Car.objects.select_related("manufacturer") + .exclude_fields( + { + "gearbox_type": ..., + "gears": ..., + "aircon_type": ..., + "year": ..., + "manufacturer": {"founded"}, + } + ) + .all() + ) + + assert all_cars == all_cars2 + for car in all_cars: assert all( getattr(car, x) is None @@ -119,7 +162,7 @@ async def test_selecting_subset(): all_cars_check2 = ( await Car.objects.select_related("manufacturer") .fields(["id", "name", "manufacturer"]) - .exclude_fields("company__founded") + .exclude_fields("manufacturer__founded") .all() ) for car in all_cars_check2: @@ -133,5 +176,5 @@ async def test_selecting_subset(): with pytest.raises(pydantic.error_wrappers.ValidationError): # cannot exclude mandatory model columns - company__name in this example await Car.objects.select_related("manufacturer").exclude_fields( - ["company__name"] + ["manufacturer__name"] ).all() diff --git a/tests/test_queryset_utils.py b/tests/test_queryset_utils.py new file mode 100644 index 0000000..bac0a26 --- /dev/null +++ b/tests/test_queryset_utils.py @@ -0,0 +1,98 @@ +from ormar.models.excludable import Excludable +from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update + + +def test_empty_excludable(): + assert Excludable.is_included(None, "key") # all fields included if empty + assert not Excludable.is_excluded(None, "key") # none field excluded if empty + + +def test_list_to_dict_translation(): + tet_list = ["aa", "bb", "cc__aa", "cc__bb", "cc__aa__xx", "cc__aa__yy"] + test = translate_list_to_dict(tet_list) + assert test == { + "aa": Ellipsis, + "bb": Ellipsis, + "cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis}, + } + + +def test_updating_dict_with_list(): + curr_dict = { + "aa": Ellipsis, + "bb": Ellipsis, + "cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis}, + } + list_to_update = ["ee", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"] + test = update_dict_from_list(curr_dict, list_to_update) + assert test == { + "aa": Ellipsis, + "bb": {"cc": Ellipsis}, + "cc": { + "aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis}, + "bb": Ellipsis, + }, + "ee": Ellipsis, + } + + +def test_updating_dict_inc_set_with_list(): + curr_dict = { + "aa": Ellipsis, + "bb": Ellipsis, + "cc": {"aa": {"xx", "yy"}, "bb": Ellipsis}, + } + list_to_update = ["uu", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"] + test = update_dict_from_list(curr_dict, list_to_update) + assert test == { + "aa": Ellipsis, + "bb": {"cc": Ellipsis}, + "cc": { + "aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis}, + "bb": Ellipsis, + }, + "uu": Ellipsis, + } + + +def test_updating_dict_inc_set_with_dict(): + curr_dict = { + "aa": Ellipsis, + "bb": Ellipsis, + "cc": {"aa": {"xx", "yy"}, "bb": Ellipsis}, + } + dict_to_update = { + "uu": Ellipsis, + "bb": {"cc", "dd"}, + "cc": {"aa": {"xx": {"oo": Ellipsis}, "oo": Ellipsis}}, + } + test = update(curr_dict, dict_to_update) + assert test == { + "aa": Ellipsis, + "bb": {"cc", "dd"}, + "cc": { + "aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis}, + "bb": Ellipsis, + }, + "uu": Ellipsis, + } + + +def test_updating_dict_inc_set_with_dict_inc_set(): + curr_dict = { + "aa": Ellipsis, + "bb": Ellipsis, + "cc": {"aa": {"xx", "yy"}, "bb": Ellipsis}, + } + dict_to_update = { + "uu": Ellipsis, + "bb": {"cc", "dd"}, + "cc": {"aa": {"xx", "oo", "zz", "ii"}}, + } + test = update(curr_dict, dict_to_update) + assert test == { + "aa": Ellipsis, + "bb": {"cc", "dd"}, + "cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis}, + "uu": Ellipsis, + } diff --git a/tests/test_selecting_subset_of_columns.py b/tests/test_selecting_subset_of_columns.py index 0f76f1a..e9133b6 100644 --- a/tests/test_selecting_subset_of_columns.py +++ b/tests/test_selecting_subset_of_columns.py @@ -1,4 +1,5 @@ -from typing import Optional +import itertools +from typing import Optional, List import databases import pydantic @@ -12,6 +13,35 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class NickNames(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq) + + class Company(ormar.Model): class Meta: tablename = "companies" @@ -21,6 +51,7 @@ class Company(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="company_name") founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ) class Car(ormar.Model): @@ -51,7 +82,14 @@ def create_test_database(): async def test_selecting_subset(): async with database: async with database.transaction(force_rollback=True): - toyota = await Company.objects.create(name="Toyota", founded=1937) + nick1 = await NickNames.objects.create(name="Nippon", is_lame=False) + nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True) + hq = await HQ.objects.create(name="Japan") + await hq.nicks.add(nick1) + await hq.nicks.add(nick2) + + toyota = await Company.objects.create(name="Toyota", founded=1937, hq=hq) + await Car.objects.create( manufacturer=toyota, name="Corolla", @@ -78,17 +116,66 @@ async def test_selecting_subset(): ) all_cars = ( - await Car.objects.select_related("manufacturer") - .fields(["id", "name", "company__name"]) + await Car.objects.select_related( + ["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"] + ) + .fields( + [ + "id", + "name", + "manufacturer__name", + "manufacturer__hq__name", + "manufacturer__hq__nicks__name", + ] + ) .all() ) - for car in all_cars: + + all_cars2 = ( + await Car.objects.select_related( + ["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"] + ) + .fields( + { + "id": ..., + "name": ..., + "manufacturer": { + "name": ..., + "hq": {"name": ..., "nicks": {"name": ...}}, + }, + } + ) + .all() + ) + + all_cars3 = ( + await Car.objects.select_related( + ["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"] + ) + .fields( + { + "id": ..., + "name": ..., + "manufacturer": { + "name": ..., + "hq": {"name": ..., "nicks": {"name"}}, + }, + } + ) + .all() + ) + assert all_cars3 == all_cars + + for car in itertools.chain(all_cars, all_cars2): assert all( getattr(car, x) is None for x in ["year", "gearbox_type", "gears", "aircon_type"] ) assert car.manufacturer.name == "Toyota" assert car.manufacturer.founded is None + assert car.manufacturer.hq.name == "Japan" + assert len(car.manufacturer.hq.nicks) == 2 + assert car.manufacturer.hq.nicks[0].is_lame is None all_cars = ( await Car.objects.select_related("manufacturer") @@ -103,6 +190,7 @@ async def test_selecting_subset(): ) assert car.manufacturer.name == "Toyota" assert car.manufacturer.founded == 1937 + assert car.manufacturer.hq.name is None all_cars_check = await Car.objects.select_related("manufacturer").all() for car in all_cars_check: @@ -116,5 +204,5 @@ async def test_selecting_subset(): with pytest.raises(pydantic.error_wrappers.ValidationError): # cannot exclude mandatory model columns - company__name in this example await Car.objects.select_related("manufacturer").fields( - ["id", "name", "company__founded"] + ["id", "name", "manufacturer__founded"] ).all()