diff --git a/ormar/__init__.py b/ormar/__init__.py index 328e894..4868e8f 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -54,7 +54,7 @@ from ormar.fields import ( UUID, UniqueColumns, ) # noqa: I100 -from ormar.models import Model +from ormar.models import ExcludableItems, Model from ormar.models.metaclass import ModelMeta from ormar.queryset import OrderAction, QuerySet from ormar.relations import RelationType @@ -107,4 +107,5 @@ __all__ = [ "ManyToManyField", "ForeignKeyField", "OrderAction", + "ExcludableItems", ] diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 155b1c9..1fada90 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -93,11 +93,10 @@ class BaseField(FieldInfo): :rtype: bool """ return ( - field_name not in ["default", "default_factory", "alias", - "allow_mutation"] - and not field_name.startswith("__") - and hasattr(cls, field_name) - and not callable(getattr(cls, field_name)) + field_name not in ["default", "default_factory", "alias", "allow_mutation"] + and not field_name.startswith("__") + and hasattr(cls, field_name) + and not callable(getattr(cls, field_name)) ) @classmethod @@ -206,7 +205,7 @@ class BaseField(FieldInfo): :rtype: bool """ return cls.default is not None or ( - cls.server_default is not None and use_server + cls.server_default is not None and use_server ) @classmethod @@ -239,7 +238,7 @@ class BaseField(FieldInfo): ondelete=con.ondelete, onupdate=con.onupdate, name=f"fk_{cls.owner.Meta.tablename}_{cls.to.Meta.tablename}" - f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}", + f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}", ) for con in cls.constraints ] @@ -272,10 +271,10 @@ class BaseField(FieldInfo): @classmethod def expand_relationship( - cls, - value: Any, - child: Union["Model", "NewBaseModel"], - to_register: bool = True, + cls, + value: Any, + child: Union["Model", "NewBaseModel"], + to_register: bool = True, ) -> Any: """ Function overwritten for relations, in basic field the value is returned as is. @@ -303,7 +302,7 @@ class BaseField(FieldInfo): :rtype: None """ if cls.owner is not None and ( - cls.owner == cls.to or cls.owner.Meta == cls.to.Meta + cls.owner == cls.to or cls.owner.Meta == cls.to.Meta ): cls.self_reference = True cls.self_reference_primary = cls.name diff --git a/ormar/fields/through_field.py b/ormar/fields/through_field.py index e5e4a24..b25e94b 100644 --- a/ormar/fields/through_field.py +++ b/ormar/fields/through_field.py @@ -17,11 +17,9 @@ if TYPE_CHECKING: # pragma no cover def Through( # noqa CFQ002 to: "ToType", *, name: str = None, related_name: str = None, **kwargs: Any, ) -> Any: - # TODO: clean docstring """ - Despite a name it's a function that returns constructed ForeignKeyField. - This function is actually used in model declaration (as ormar.ForeignKey(ToModel)). - + Despite a name it's a function that returns constructed ThroughField. + It's a special field populated only for m2m relations. Accepts number of relation setting parameters as well as all BaseField ones. :param to: target related ormar Model @@ -30,15 +28,13 @@ def Through( # noqa CFQ002 :type name: str :param related_name: name of reversed FK relation populated for you on to model :type related_name: str - :param virtual: marks if relation is virtual. It is for reversed FK and auto generated FK on through model in Many2Many relations. - :type virtual: bool :param kwargs: all other args to be populated by BaseField :type kwargs: Any :return: ormar ForeignKeyField with relation to selected model :rtype: ForeignKeyField """ - + nullable = kwargs.pop("nullable", False) owner = kwargs.pop("owner", None) namespace = dict( __type__=to, @@ -49,7 +45,7 @@ def Through( # noqa CFQ002 related_name=related_name, virtual=True, owner=owner, - nullable=False, + nullable=nullable, unique=False, column_type=None, primary_key=False, diff --git a/ormar/models/__init__.py b/ormar/models/__init__.py index 58372b7..0ecd9cc 100644 --- a/ormar/models/__init__.py +++ b/ormar/models/__init__.py @@ -7,5 +7,6 @@ ass well as vast number of helper functions for pydantic, sqlalchemy and relatio from ormar.models.newbasemodel import NewBaseModel # noqa I100 from ormar.models.model_row import ModelRow # noqa I100 from ormar.models.model import Model, T # noqa I100 +from ormar.models.excludable import ExcludableItems # noqa I100 -__all__ = ["T", "NewBaseModel", "Model", "ModelRow"] +__all__ = ["T", "NewBaseModel", "Model", "ModelRow", "ExcludableItems"] diff --git a/ormar/models/excludable.py b/ormar/models/excludable.py index b8832d2..203d8a0 100644 --- a/ormar/models/excludable.py +++ b/ormar/models/excludable.py @@ -7,19 +7,12 @@ if TYPE_CHECKING: # pragma: no cover from ormar import Model +# TODO: Add docstrings @dataclass class Excludable: include: Set = field(default_factory=set) exclude: Set = field(default_factory=set) - @property - def include_all(self): - return ... in self.include - - @property - def exclude_all(self): - return ... in self.exclude - def get_copy(self) -> "Excludable": _copy = self.__class__() _copy.include = {x for x in self.include} @@ -28,12 +21,9 @@ class Excludable: def set_values(self, value: Set, is_exclude: bool) -> None: prop = "exclude" if is_exclude else "include" - if ... in getattr(self, prop) or ... in value: - setattr(self, prop, {...}) - else: - current_value = getattr(self, prop) - current_value.update(value) - setattr(self, prop, current_value) + current_value = getattr(self, prop) + current_value.update(value) + setattr(self, prop, current_value) def is_included(self, key: str) -> bool: return (... in self.include or key in self.include) if self.include else True @@ -61,13 +51,17 @@ class ExcludableItems: def get(self, model_cls: Type["Model"], alias: str = "") -> Excludable: key = f"{alias + '_' if alias else ''}{model_cls.get_name(lower=True)}" - return self.items.get(key, Excludable()) + excludable = self.items.get(key) + if not excludable: + excludable = Excludable() + self.items[key] = excludable + return excludable def build( - self, - items: Union[List[str], str, Tuple[str], Set[str], Dict], - model_cls: Type["Model"], - is_exclude: bool = False, + self, + items: Union[List[str], str, Tuple[str], Set[str], Dict], + model_cls: Type["Model"], + is_exclude: bool = False, ) -> None: if isinstance(items, str): @@ -96,7 +90,7 @@ class ExcludableItems: ) def _set_excludes( - self, items: Set, model_name: str, is_exclude: bool, alias: str = "" + self, items: Set, model_name: str, is_exclude: bool, alias: str = "" ) -> None: key = f"{alias + '_' if alias else ''}{model_name}" @@ -107,13 +101,13 @@ class ExcludableItems: self.items[key] = excludable def _traverse_dict( # noqa: CFQ002 - self, - values: Dict, - source_model: Type["Model"], - model_cls: Type["Model"], - is_exclude: bool, - related_items: List = None, - alias: str = "", + self, + values: Dict, + source_model: Type["Model"], + model_cls: Type["Model"], + is_exclude: bool, + related_items: List = None, + alias: str = "", ) -> None: self_fields = set() @@ -122,14 +116,13 @@ class ExcludableItems: if value is ...: self_fields.add(key) elif isinstance(value, set): - related_items.append(key) ( table_prefix, target_model, _, _, ) = get_relationship_alias_model_and_str( - source_model=source_model, related_parts=related_items + source_model=source_model, related_parts=related_items + [key] ) self._set_excludes( items=value, @@ -165,7 +158,7 @@ class ExcludableItems: ) def _traverse_list( - self, values: Set[str], model_cls: Type["Model"], is_exclude: bool + self, values: Set[str], model_cls: Type["Model"], is_exclude: bool ) -> None: # here we have only nested related keys diff --git a/ormar/models/mixins/excludable_mixin.py b/ormar/models/mixins/excludable_mixin.py index 4b25035..a7850d5 100644 --- a/ormar/models/mixins/excludable_mixin.py +++ b/ormar/models/mixins/excludable_mixin.py @@ -4,12 +4,12 @@ from typing import ( Dict, List, Mapping, - Optional, Set, TYPE_CHECKING, Type, TypeVar, - Union, cast, + Union, + cast, ) from ormar.models.excludable import ExcludableItems @@ -36,7 +36,7 @@ class ExcludableMixin(RelationMixin): @staticmethod def get_child( - items: Union[Set, Dict, None], key: str = None + items: Union[Set, Dict, None], key: str = None ) -> Union[Set, Dict, None]: """ Used to get nested dictionaries keys if they exists otherwise returns @@ -52,89 +52,11 @@ class ExcludableMixin(RelationMixin): return items.get(key, {}) return items - @staticmethod - def get_excluded( - exclude: Union[Set, Dict, None], key: str = None - ) -> Union[Set, Dict, None]: - """ - Proxy to ExcludableMixin.get_child for exclusions. - - :param exclude: bag of items to exclude - :type exclude: Union[Set, Dict, None] - :param key: name of the child to extract - :type key: str - :return: child extracted from items if exists - :rtype: Union[Set, Dict, None] - """ - return ExcludableMixin.get_child(items=exclude, key=key) - - @staticmethod - def get_included( - include: Union[Set, Dict, None], key: str = None - ) -> Union[Set, Dict, None]: - """ - Proxy to ExcludableMixin.get_child for inclusions. - - :param include: bag of items to include - :type include: Union[Set, Dict, None] - :param key: name of the child to extract - :type key: str - :return: child extracted from items if exists - :rtype: Union[Set, Dict, None] - """ - return ExcludableMixin.get_child(items=include, key=key) - - @staticmethod - def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool: - """ - Checks if given key should be excluded on model/ dict. - - :param exclude: bag of items to exclude - :type exclude: Union[Set, Dict, None] - :param key: name of the child to extract - :type key: str - :return: child extracted from items if exists - :rtype: Union[Set, Dict, None] - """ - if exclude is None: - return False - if exclude is Ellipsis: # pragma: nocover - return True - to_exclude = ExcludableMixin.get_excluded(exclude=exclude, key=key) - if isinstance(to_exclude, Set): - return key in to_exclude - if to_exclude is ...: - return True - return False - - @staticmethod - def is_included(include: Union[Set, Dict, None], key: str = None) -> bool: - """ - Checks if given key should be included on model/ dict. - - :param include: bag of items to include - :type include: Union[Set, Dict, None] - :param key: name of the child to extract - :type key: str - :return: child extracted from items if exists - :rtype: Union[Set, Dict, None] - """ - if include is None: - return True - if include is Ellipsis: - return True - to_include = ExcludableMixin.get_included(include=include, key=key) - if isinstance(to_include, Set): - return key in to_include - if to_include is ...: - return True - return False - @staticmethod def _populate_pk_column( - model: Union[Type["Model"], Type["ModelRow"]], - columns: List[str], - use_alias: bool = False, + model: Union[Type["Model"], Type["ModelRow"]], + columns: List[str], + use_alias: bool = False, ) -> List[str]: """ Adds primary key column/alias (depends on use_alias flag) to list of @@ -160,13 +82,12 @@ class ExcludableMixin(RelationMixin): @classmethod def own_table_columns( - cls, - model: Union[Type["Model"], Type["ModelRow"]], - excludable: ExcludableItems, - alias: str = '', - use_alias: bool = False, + cls, + model: Union[Type["Model"], Type["ModelRow"]], + excludable: ExcludableItems, + alias: str = "", + use_alias: bool = False, ) -> List[str]: - # TODO update docstring """ Returns list of aliases or field names for given model. Aliases/names switch is use_alias flag. @@ -176,6 +97,10 @@ class ExcludableMixin(RelationMixin): Primary key field is always added and cannot be excluded (will be added anyway). + :param alias: relation prefix + :type alias: str + :param excludable: structure of fields to include and exclude + :type excludable: ExcludableItems :param model: model on columns are selected :type model: Type["Model"] :param use_alias: flag if aliases or field names should be used @@ -183,7 +108,7 @@ class ExcludableMixin(RelationMixin): :return: list of column field names or aliases :rtype: List[str] """ - model_excludable = excludable.get(model_cls=model, alias=alias) + model_excludable = excludable.get(model_cls=model, alias=alias) # type: ignore columns = [ model.get_column_name_from_alias(col.name) if not use_alias else col.name for col in model.Meta.table.columns @@ -214,9 +139,9 @@ class ExcludableMixin(RelationMixin): @classmethod def _update_excluded_with_related_not_required( - cls, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None], - nested: bool = False, + cls, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None], + nested: bool = False, ) -> Union[Set, Dict]: """ Used during generation of the dict(). @@ -243,11 +168,7 @@ class ExcludableMixin(RelationMixin): return exclude @classmethod - def get_names_to_exclude( - cls, - excludable: ExcludableItems, - alias: str - ) -> Set: + def get_names_to_exclude(cls, excludable: ExcludableItems, alias: str) -> Set: """ Returns a set of models field names that should be explicitly excluded during model initialization. @@ -268,7 +189,7 @@ class ExcludableMixin(RelationMixin): model = cast(Type["Model"], cls) model_excludable = excludable.get(model_cls=model, alias=alias) fields_names = cls.extract_db_own_fields() - if model_excludable.include and model_excludable.include_all: + if model_excludable.include: fields_to_keep = model_excludable.include.intersection(fields_names) else: fields_to_keep = fields_names diff --git a/ormar/models/model_row.py b/ormar/models/model_row.py index 6a6cb0e..1b15e6d 100644 --- a/ormar/models/model_row.py +++ b/ormar/models/model_row.py @@ -3,11 +3,9 @@ from typing import ( Dict, List, Optional, - Set, TYPE_CHECKING, Type, TypeVar, - Union, cast, ) @@ -17,7 +15,6 @@ from ormar.models import NewBaseModel # noqa: I202 from ormar.models.excludable import ExcludableItems from ormar.models.helpers.models import group_related_list - if TYPE_CHECKING: # pragma: no cover from ormar.fields import ForeignKeyField from ormar.models import T @@ -36,6 +33,7 @@ class ModelRow(NewBaseModel): related_field: Type["ForeignKeyField"] = None, excludable: ExcludableItems = None, current_relation_str: str = "", + proxy_source_model: Optional[Type["ModelRow"]] = None, ) -> Optional[T]: """ Model method to convert raw sql row from database into ormar.Model instance. @@ -91,12 +89,10 @@ class ModelRow(NewBaseModel): excludable=excludable, current_relation_str=current_relation_str, source_model=source_model, + proxy_source_model=proxy_source_model, # type: ignore ) item = cls.extract_prefixed_table_columns( - item=item, - row=row, - table_prefix=table_prefix, - excludable=excludable + item=item, row=row, table_prefix=table_prefix, excludable=excludable ) instance: Optional[T] = None @@ -117,6 +113,7 @@ class ModelRow(NewBaseModel): related_models: Any, excludable: ExcludableItems, current_relation_str: str = None, + proxy_source_model: Type[T] = None, ) -> dict: """ Traverses structure of related models and populates the nested models @@ -165,20 +162,22 @@ class ModelRow(NewBaseModel): excludable=excludable, current_relation_str=relation_str, source_model=source_model, + proxy_source_model=proxy_source_model, ) item[model_cls.get_column_name_from_alias(related)] = child if field.is_multi and child: - # TODO: way to figure out which side should be populated? through_name = cls.Meta.model_fields[related].through.get_name() - # for now it's nested dict, should be instance? through_child = cls.populate_through_instance( row=row, related=related, through_name=through_name, - excludable=excludable + excludable=excludable, ) - item[through_name] = through_child - setattr(child, through_name, through_child) + + if child.__class__ != proxy_source_model: + setattr(child, through_name, through_child) + else: + item[through_name] = through_child child.set_save_status(True) return item @@ -189,19 +188,24 @@ class ModelRow(NewBaseModel): row: sqlalchemy.engine.ResultProxy, through_name: str, related: str, - excludable: ExcludableItems - ) -> Dict: - # TODO: fix excludes and includes and docstring + excludable: ExcludableItems, + ) -> "ModelRow": model_cls = cls.Meta.model_fields[through_name].to table_prefix = cls.Meta.alias_manager.resolve_relation_alias( from_model=cls, relation_name=related ) - child = model_cls.extract_prefixed_table_columns( - item={}, - row=row, - excludable=excludable, - table_prefix=table_prefix + # remove relations on through field + model_excludable = excludable.get(model_cls=model_cls, alias=table_prefix) + model_excludable.set_values( + value=model_cls.extract_related_names(), is_exclude=True ) + child_dict = model_cls.extract_prefixed_table_columns( + item={}, row=row, excludable=excludable, table_prefix=table_prefix + ) + child_dict["__excluded__"] = model_cls.get_names_to_exclude( + excludable=excludable, alias=table_prefix + ) + child = model_cls(**child_dict) # type: ignore return child @classmethod @@ -210,7 +214,7 @@ class ModelRow(NewBaseModel): item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str, - excludable: ExcludableItems + excludable: ExcludableItems, ) -> Dict: """ Extracts own fields from raw sql result, using a given prefix. @@ -242,10 +246,7 @@ class ModelRow(NewBaseModel): source = row._row if cls.db_backend_name() == "postgresql" else row selected_columns = cls.own_table_columns( - model=cls, - excludable=excludable, - alias=table_prefix, - use_alias=False, + model=cls, excludable=excludable, alias=table_prefix, use_alias=False, ) for column in cls.Meta.table.columns: diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index a6f1e93..b9e71df 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -1,14 +1,11 @@ from collections import OrderedDict from typing import ( Any, - Dict, List, Optional, - Set, TYPE_CHECKING, Tuple, Type, - Union, ) import sqlalchemy @@ -16,12 +13,12 @@ from sqlalchemy import text import ormar # noqa I100 from ormar.exceptions import RelationshipInstanceError -from ormar.models.excludable import ExcludableItems from ormar.relations import AliasManager if TYPE_CHECKING: # pragma no cover from ormar import Model from ormar.queryset import OrderAction + from ormar.models.excludable import ExcludableItems class SqlJoin: @@ -30,7 +27,7 @@ class SqlJoin: used_aliases: List, select_from: sqlalchemy.sql.select, columns: List[sqlalchemy.Column], - excludable: ExcludableItems, + excludable: "ExcludableItems", order_columns: Optional[List["OrderAction"]], sorted_orders: OrderedDict, main_model: Type["Model"], @@ -44,7 +41,7 @@ class SqlJoin: self.related_models = related_models or [] self.select_from = select_from self.columns = columns - self.excludable=excludable + self.excludable = excludable self.order_columns = order_columns self.sorted_orders = sorted_orders self.main_model = main_model @@ -296,7 +293,6 @@ class SqlJoin: self._get_order_bys() - # TODO: fix fields and exclusions for through model? self_related_fields = self.next_model.own_table_columns( model=self.next_model, excludable=self.excludable, diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 88fc8e3..a661b73 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -1,19 +1,15 @@ from typing import ( - Any, Dict, List, - Optional, Sequence, Set, TYPE_CHECKING, Tuple, Type, - Union, cast, ) import ormar -from ormar.models.excludable import ExcludableItems from ormar.queryset.clause import QueryClause from ormar.queryset.query import Query from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict @@ -22,29 +18,7 @@ if TYPE_CHECKING: # pragma: no cover from ormar import Model from ormar.fields import ForeignKeyField, BaseField from ormar.queryset import OrderAction - - -def add_relation_field_to_fields( - fields: Union[Set[Any], Dict[Any, Any], None], related_field_name: str -) -> Union[Set[Any], Dict[Any, Any], None]: - """ - Adds related field into fields to include as otherwise it would be skipped. - Related field is added only if fields are already populated. - Empty fields implies all fields. - - :param fields: Union[Set[Any], Dict[Any, Any], None] - :type fields: Dict - :param related_field_name: name of the field with relation - :type related_field_name: str - :return: updated fields dict - :rtype: Union[Set[Any], Dict[Any, Any], None] - """ - if fields and related_field_name not in fields: - if isinstance(fields, dict): - fields[related_field_name] = ... - elif isinstance(fields, set): - fields.add(related_field_name) - return fields + from ormar.models.excludable import ExcludableItems def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]: @@ -74,12 +48,12 @@ def sort_models(models: List["Model"], orders_by: Dict) -> List["Model"]: def set_children_on_model( # noqa: CCR001 - model: "Model", - related: str, - children: Dict, - model_id: int, - models: Dict, - orders_by: Dict, + model: "Model", + related: str, + children: Dict, + model_id: int, + models: Dict, + orders_by: Dict, ) -> None: """ Extract ids of child models by given relation id key value. @@ -124,12 +98,12 @@ class PrefetchQuery: """ def __init__( # noqa: CFQ002 - self, - model_cls: Type["Model"], - excludable: ExcludableItems, - prefetch_related: List, - select_related: List, - orders_by: List["OrderAction"], + self, + model_cls: Type["Model"], + excludable: "ExcludableItems", + prefetch_related: List, + select_related: List, + orders_by: List["OrderAction"], ) -> None: self.model = model_cls @@ -147,7 +121,7 @@ class PrefetchQuery: ) async def prefetch_related( - self, models: Sequence["Model"], rows: List + self, models: Sequence["Model"], rows: List ) -> Sequence["Model"]: """ Main entry point for prefetch_query. @@ -172,7 +146,7 @@ class PrefetchQuery: return await self._prefetch_related_models(models=models, rows=rows) def _extract_ids_from_raw_data( - self, parent_model: Type["Model"], column_name: str + self, parent_model: Type["Model"], column_name: str ) -> Set: """ Iterates over raw rows and extract id values of relation columns by using @@ -195,7 +169,7 @@ class PrefetchQuery: return list_of_ids def _extract_ids_from_preloaded_models( - self, parent_model: Type["Model"], column_name: str + self, parent_model: Type["Model"], column_name: str ) -> Set: """ Extracts relation ids from already populated models if they were included @@ -218,7 +192,7 @@ class PrefetchQuery: return list_of_ids def _extract_required_ids( - self, parent_model: Type["Model"], reverse: bool, related: str, + self, parent_model: Type["Model"], reverse: bool, related: str, ) -> Set: """ Delegates extraction of the fields to either get ids from raw sql response @@ -252,11 +226,11 @@ class PrefetchQuery: ) def _get_filter_for_prefetch( - self, - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool, - related: str, + self, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + related: str, ) -> List: """ Populates where clause with condition to return only models within the @@ -297,7 +271,7 @@ class PrefetchQuery: return [] def _populate_nested_related( - self, model: "Model", prefetch_dict: Dict, orders_by: Dict, + self, model: "Model", prefetch_dict: Dict, orders_by: Dict, ) -> "Model": """ Populates all related models children of parent model that are @@ -341,7 +315,7 @@ class PrefetchQuery: return model async def _prefetch_related_models( - self, models: Sequence["Model"], rows: List + self, models: Sequence["Model"], rows: List ) -> Sequence["Model"]: """ Main method of the query. @@ -385,13 +359,13 @@ class PrefetchQuery: return models async def _extract_related_models( # noqa: CFQ002, CCR001 - self, - related: str, - target_model: Type["Model"], - prefetch_dict: Dict, - select_dict: Dict, - excludable: ExcludableItems, - orders_by: Dict, + self, + related: str, + target_model: Type["Model"], + prefetch_dict: Dict, + select_dict: Dict, + excludable: "ExcludableItems", + orders_by: Dict, ) -> None: """ Constructs queries with required ids and extracts data with fields that should @@ -443,15 +417,16 @@ class PrefetchQuery: related_field_name = parent_model.get_related_field_name( target_field=target_field ) - table_prefix, rows = await self._run_prefetch_query( + table_prefix, exclude_prefix, rows = await self._run_prefetch_query( target_field=target_field, excludable=excludable, filter_clauses=filter_clauses, - related_field_name=related_field_name + related_field_name=related_field_name, ) else: rows = [] table_prefix = "" + exclude_prefix = "" if prefetch_dict and prefetch_dict is not Ellipsis: for subrelated in prefetch_dict.keys(): @@ -472,6 +447,7 @@ class PrefetchQuery: parent_model=parent_model, target_field=target_field, table_prefix=table_prefix, + exclude_prefix=exclude_prefix, excludable=excludable, prefetch_dict=prefetch_dict, orders_by=orders_by, @@ -484,12 +460,12 @@ class PrefetchQuery: ) async def _run_prefetch_query( - self, - target_field: Type["BaseField"], - excludable: ExcludableItems, - filter_clauses: List, - related_field_name: str - ) -> Tuple[str, List]: + self, + target_field: Type["BaseField"], + excludable: "ExcludableItems", + filter_clauses: List, + related_field_name: str, + ) -> Tuple[str, str, List]: """ Actually runs the queries against the database and populates the raw response for given related model. @@ -509,17 +485,22 @@ class PrefetchQuery: select_related = [] query_target = target_model table_prefix = "" + exclude_prefix = target_field.to.Meta.alias_manager.resolve_relation_alias( + from_model=target_field.owner, relation_name=target_field.name + ) if target_field.is_multi: query_target = target_field.through select_related = [target_name] table_prefix = target_field.to.Meta.alias_manager.resolve_relation_alias( from_model=query_target, relation_name=target_name ) + exclude_prefix = table_prefix self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix - model_excludable = excludable.get(model_cls=target_model, alias=table_prefix) + model_excludable = excludable.get(model_cls=target_model, alias=exclude_prefix) if model_excludable.include and not model_excludable.is_included( - related_field_name): + related_field_name + ): model_excludable.set_values({related_field_name}, is_exclude=False) qry = Query( @@ -537,7 +518,7 @@ class PrefetchQuery: # print(expr.compile(compile_kwargs={"literal_binds": True})) rows = await self.database.fetch_all(expr) self.already_extracted.setdefault(target_name, {}).update({"raw": rows}) - return table_prefix, rows + return table_prefix, exclude_prefix, rows @staticmethod def _get_select_related_if_apply(related: str, select_dict: Dict) -> Dict: @@ -559,7 +540,7 @@ class PrefetchQuery: ) def _update_already_loaded_rows( # noqa: CFQ002 - self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict, + self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict, ) -> None: """ Updates models that are already loaded, usually children of children. @@ -578,14 +559,15 @@ class PrefetchQuery: ) def _populate_rows( # noqa: CFQ002 - self, - rows: List, - target_field: Type["ForeignKeyField"], - parent_model: Type["Model"], - table_prefix: str, - excludable: ExcludableItems, - prefetch_dict: Dict, - orders_by: Dict, + self, + rows: List, + target_field: Type["ForeignKeyField"], + parent_model: Type["Model"], + table_prefix: str, + exclude_prefix: str, + excludable: "ExcludableItems", + prefetch_dict: Dict, + orders_by: Dict, ) -> None: """ Instantiates children models extracted from given relation. @@ -617,13 +599,10 @@ class PrefetchQuery: # TODO Fix fields field_name = parent_model.get_related_field_name(target_field=target_field) item = target_model.extract_prefixed_table_columns( - item={}, - row=row, - table_prefix=table_prefix, - excludable=excludable, + item={}, row=row, table_prefix=table_prefix, excludable=excludable, ) item["__excluded__"] = target_model.get_names_to_exclude( - excludable=excludable, alias=table_prefix + excludable=excludable, alias=exclude_prefix ) instance = target_model(**item) instance = self._populate_nested_related( diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 2e88212..0987bac 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -1,12 +1,10 @@ -import copy from collections import OrderedDict -from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union +from typing import List, Optional, TYPE_CHECKING, Tuple, Type import sqlalchemy from sqlalchemy import text import ormar # noqa I100 -from ormar.models.excludable import ExcludableItems from ormar.models.helpers.models import group_related_list from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery from ormar.queryset.actions.filter_action import FilterAction @@ -15,20 +13,21 @@ from ormar.queryset.join import SqlJoin if TYPE_CHECKING: # pragma no cover from ormar import Model from ormar.queryset import OrderAction + from ormar.models.excludable import ExcludableItems class Query: def __init__( # noqa CFQ002 - self, - model_cls: Type["Model"], - filter_clauses: List[FilterAction], - exclude_clauses: List[FilterAction], - select_related: List, - limit_count: Optional[int], - offset: Optional[int], - excludable: ExcludableItems, - order_bys: Optional[List["OrderAction"]], - limit_raw_sql: bool, + self, + model_cls: Type["Model"], + filter_clauses: List[FilterAction], + exclude_clauses: List[FilterAction], + select_related: List, + limit_count: Optional[int], + offset: Optional[int], + excludable: "ExcludableItems", + order_bys: Optional[List["OrderAction"]], + limit_raw_sql: bool, ) -> None: self.query_offset = offset self.limit_count = limit_count @@ -103,9 +102,7 @@ class Query: :rtype: sqlalchemy.sql.selectable.Select """ self_related_fields = self.model_cls.own_table_columns( - model=self.model_cls, - excludable=self.excludable, - use_alias=True, + model=self.model_cls, excludable=self.excludable, use_alias=True, ) self.columns = self.model_cls.Meta.alias_manager.prefixed_columns( "", self.table, self_related_fields @@ -191,7 +188,7 @@ class Query: return expr def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select + self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: """ Receives the select query (might be join) and applies: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 1c93590..f4f9fd0 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -20,18 +20,17 @@ from sqlalchemy import bindparam import ormar # noqa I100 from ormar import MultipleMatches, NoMatch from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError -from ormar.models.excludable import ExcludableItems from ormar.queryset import FilterQuery from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.clause import QueryClause from ormar.queryset.prefetch_query import PrefetchQuery from ormar.queryset.query import Query -from ormar.queryset.utils import update, update_dict_from_list if TYPE_CHECKING: # pragma no cover from ormar.models import T from ormar.models.metaclass import ModelMeta from ormar.relations.querysetproxy import QuerysetProxy + from ormar.models.excludable import ExcludableItems else: T = TypeVar("T") @@ -42,18 +41,20 @@ class QuerySet(Generic[T]): """ def __init__( # noqa CFQ002 - self, - model_cls: Optional[Type[T]] = None, - filter_clauses: List = None, - exclude_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, - excludable: ExcludableItems = None, - order_bys: List = None, - prefetch_related: List = None, - limit_raw_sql: bool = False, + self, + model_cls: Optional[Type[T]] = None, + filter_clauses: List = None, + exclude_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + excludable: "ExcludableItems" = None, + order_bys: List = None, + prefetch_related: List = None, + limit_raw_sql: bool = False, + proxy_source_model: Optional[Type[T]] = None, ) -> None: + self.proxy_source_model = proxy_source_model self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses @@ -61,14 +62,14 @@ class QuerySet(Generic[T]): self._prefetch_related = [] if prefetch_related is None else prefetch_related self.limit_count = limit_count self.query_offset = offset - self._excludable = excludable or ExcludableItems() + self._excludable = excludable or ormar.ExcludableItems() self.order_bys = order_bys or [] self.limit_sql_raw = limit_raw_sql def __get__( - self, - instance: Optional[Union["QuerySet", "QuerysetProxy"]], - owner: Union[Type[T], Type["QuerysetProxy"]], + self, + instance: Optional[Union["QuerySet", "QuerysetProxy"]], + owner: Union[Type[T], Type["QuerysetProxy"]], ) -> "QuerySet": if issubclass(owner, ormar.Model): if owner.Meta.requires_ref_update: @@ -105,8 +106,53 @@ class QuerySet(Generic[T]): raise ValueError("Model class of QuerySet is not initialized") return self.model_cls + def rebuild_self( # noqa: CFQ002 + self, + filter_clauses: List = None, + exclude_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + excludable: "ExcludableItems" = None, + order_bys: List = None, + prefetch_related: List = None, + limit_raw_sql: bool = None, + proxy_source_model: Optional[Type[T]] = None, + ) -> "QuerySet": + """ + Method that returns new instance of queryset based on passed params, + all not passed params are taken from current values. + """ + overwrites = { + "select_related": "_select_related", + "offset": "query_offset", + "excludable": "_excludable", + "prefetch_related": "_prefetch_related", + "limit_raw_sql": "limit_sql_raw", + } + passed_args = locals() + + def replace_if_none(arg_name: str) -> Any: + if passed_args.get(arg_name) is None: + return getattr(self, overwrites.get(arg_name, arg_name)) + return passed_args.get(arg_name) + + return self.__class__( + model_cls=self.model_cls, + filter_clauses=replace_if_none("filter_clauses"), + exclude_clauses=replace_if_none("exclude_clauses"), + select_related=replace_if_none("select_related"), + limit_count=replace_if_none("limit_count"), + offset=replace_if_none("offset"), + excludable=replace_if_none("excludable"), + order_bys=replace_if_none("order_bys"), + prefetch_related=replace_if_none("prefetch_related"), + limit_raw_sql=replace_if_none("limit_raw_sql"), + proxy_source_model=replace_if_none("proxy_source_model"), + ) + async def _prefetch_related_models( - self, models: Sequence[Optional["T"]], rows: List + self, models: Sequence[Optional["T"]], rows: List ) -> Sequence[Optional["T"]]: """ Performs prefetch query for selected models names. @@ -142,6 +188,7 @@ class QuerySet(Generic[T]): select_related=self._select_related, excludable=self._excludable, source_model=self.model, + proxy_source_model=self.proxy_source_model, ) for row in rows ] @@ -183,7 +230,7 @@ class QuerySet(Generic[T]): return self.model_meta.table def build_select_expression( - self, limit: int = None, offset: int = None, order_bys: List = None, + self, limit: int = None, offset: int = None, order_bys: List = None, ) -> sqlalchemy.sql.select: """ Constructs the actual database query used in the QuerySet. @@ -254,17 +301,10 @@ class QuerySet(Generic[T]): exclude_clauses = self.exclude_clauses filter_clauses = filter_clauses - return self.__class__( - model_cls=self.model, + return self.rebuild_self( filter_clauses=filter_clauses, exclude_clauses=exclude_clauses, select_related=select_related, - limit_count=self.limit_count, - offset=self.query_offset, - excludable=self._excludable, - order_bys=self.order_bys, - prefetch_related=self._prefetch_related, - limit_raw_sql=self.limit_sql_raw, ) def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003 @@ -309,18 +349,7 @@ class QuerySet(Generic[T]): related = [related] related = list(set(list(self._select_related) + related)) - return self.__class__( - model_cls=self.model, - filter_clauses=self.filter_clauses, - exclude_clauses=self.exclude_clauses, - select_related=related, - limit_count=self.limit_count, - offset=self.query_offset, - excludable=self._excludable, - order_bys=self.order_bys, - prefetch_related=self._prefetch_related, - limit_raw_sql=self.limit_sql_raw, - ) + return self.rebuild_self(select_related=related,) def prefetch_related(self, related: Union[List, str]) -> "QuerySet": """ @@ -344,21 +373,11 @@ class QuerySet(Generic[T]): related = [related] related = list(set(list(self._prefetch_related) + related)) - return self.__class__( - model_cls=self.model, - filter_clauses=self.filter_clauses, - exclude_clauses=self.exclude_clauses, - select_related=self._select_related, - limit_count=self.limit_count, - offset=self.query_offset, - excludable=self._excludable, - order_bys=self.order_bys, - prefetch_related=related, - limit_raw_sql=self.limit_sql_raw, - ) + return self.rebuild_self(prefetch_related=related,) - def fields(self, columns: Union[List, str, Set, Dict], - _is_exclude: bool = False) -> "QuerySet": + def fields( + self, columns: Union[List, str, Set, Dict], _is_exclude: bool = False + ) -> "QuerySet": """ With `fields()` you can select subset of model columns to limit the data load. @@ -396,29 +415,22 @@ class QuerySet(Generic[T]): To include whole nested model specify model related field name and ellipsis. + :param _is_exclude: flag if it's exclude or include operation + :type _is_exclude: bool :param columns: columns to include :type columns: Union[List, str, Set, Dict] :return: QuerySet :rtype: QuerySet """ - excludable = ExcludableItems.from_excludable(self._excludable) - excludable.build(items=columns, - model_cls=self.model_cls, - is_exclude=_is_exclude) - - return self.__class__( - model_cls=self.model, - filter_clauses=self.filter_clauses, - exclude_clauses=self.exclude_clauses, - select_related=self._select_related, - limit_count=self.limit_count, - offset=self.query_offset, - excludable=excludable, - order_bys=self.order_bys, - prefetch_related=self._prefetch_related, - limit_raw_sql=self.limit_sql_raw, + excludable = ormar.ExcludableItems.from_excludable(self._excludable) + excludable.build( + items=columns, + model_cls=self.model_cls, # type: ignore + is_exclude=_is_exclude, ) + return self.rebuild_self(excludable=excludable,) + def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": """ With `exclude_fields()` you can select subset of model columns that will @@ -489,18 +501,7 @@ class QuerySet(Generic[T]): ] order_bys = self.order_bys + [x for x in orders_by if x not in self.order_bys] - return self.__class__( - model_cls=self.model, - filter_clauses=self.filter_clauses, - exclude_clauses=self.exclude_clauses, - select_related=self._select_related, - limit_count=self.limit_count, - offset=self.query_offset, - excludable=self._excludable, - order_bys=order_bys, - prefetch_related=self._prefetch_related, - limit_raw_sql=self.limit_sql_raw, - ) + return self.rebuild_self(order_bys=order_bys,) async def exists(self) -> bool: """ @@ -601,18 +602,7 @@ class QuerySet(Generic[T]): limit_count = page_size query_offset = (page - 1) * page_size - return self.__class__( - model_cls=self.model, - filter_clauses=self.filter_clauses, - exclude_clauses=self.exclude_clauses, - select_related=self._select_related, - limit_count=limit_count, - offset=query_offset, - excludable=self._excludable, - order_bys=self.order_bys, - prefetch_related=self._prefetch_related, - limit_raw_sql=self.limit_sql_raw, - ) + return self.rebuild_self(limit_count=limit_count, offset=query_offset,) def limit(self, limit_count: int, limit_raw_sql: bool = None) -> "QuerySet": """ @@ -629,18 +619,7 @@ class QuerySet(Generic[T]): :rtype: QuerySet """ limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql - return self.__class__( - model_cls=self.model, - filter_clauses=self.filter_clauses, - exclude_clauses=self.exclude_clauses, - select_related=self._select_related, - limit_count=limit_count, - offset=self.query_offset, - excludable=self._excludable, - order_bys=self.order_bys, - prefetch_related=self._prefetch_related, - limit_raw_sql=limit_raw_sql, - ) + return self.rebuild_self(limit_count=limit_count, limit_raw_sql=limit_raw_sql,) def offset(self, offset: int, limit_raw_sql: bool = None) -> "QuerySet": """ @@ -657,18 +636,7 @@ class QuerySet(Generic[T]): :rtype: QuerySet """ limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql - return self.__class__( - model_cls=self.model, - filter_clauses=self.filter_clauses, - exclude_clauses=self.exclude_clauses, - select_related=self._select_related, - limit_count=self.limit_count, - offset=offset, - excludable=self._excludable, - order_bys=self.order_bys, - prefetch_related=self._prefetch_related, - limit_raw_sql=limit_raw_sql, - ) + return self.rebuild_self(offset=offset, limit_raw_sql=limit_raw_sql,) async def first(self, **kwargs: Any) -> T: """ @@ -687,12 +655,12 @@ class QuerySet(Generic[T]): expr = self.build_select_expression( limit=1, order_bys=[ - OrderAction( - order_str=f"{self.model.Meta.pkname}", - model_cls=self.model_cls, # type: ignore - ) - ] - + self.order_bys, + OrderAction( + order_str=f"{self.model.Meta.pkname}", + model_cls=self.model_cls, # type: ignore + ) + ] + + self.order_bys, ) rows = await self.database.fetch_all(expr) processed_rows = self._process_query_result_rows(rows) @@ -723,12 +691,12 @@ class QuerySet(Generic[T]): expr = self.build_select_expression( limit=1, order_bys=[ - OrderAction( - order_str=f"-{self.model.Meta.pkname}", - model_cls=self.model_cls, # type: ignore - ) - ] - + self.order_bys, + OrderAction( + order_str=f"-{self.model.Meta.pkname}", + model_cls=self.model_cls, # type: ignore + ) + ] + + self.order_bys, ) else: expr = self.build_select_expression() @@ -831,9 +799,9 @@ class QuerySet(Generic[T]): # refresh server side defaults if any( - field.server_default is not None - for name, field in self.model.Meta.model_fields.items() - if name not in kwargs + field.server_default is not None + for name, field in self.model.Meta.model_fields.items() + if name not in kwargs ): instance = await instance.load() instance.set_save_status(True) @@ -868,7 +836,7 @@ class QuerySet(Generic[T]): objt.set_save_status(True) async def bulk_update( # noqa: CCR001 - self, objects: List[T], columns: List[str] = None + self, objects: List[T], columns: List[str] = None ) -> None: """ Performs bulk update in one database session to speed up the process. diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 58d6e9e..596f594 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -119,7 +119,9 @@ class RelationProxy(list): self._check_if_model_saved() kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk} queryset = ( - ormar.QuerySet(model_cls=self.relation.to) + ormar.QuerySet( + model_cls=self.relation.to, proxy_source_model=self._owner.__class__ + ) .select_related(related_field.name) .filter(**kwargs) ) diff --git a/tests/test_choices_schema.py b/tests/test_choices_schema.py index 0cf1852..978c3f6 100644 --- a/tests/test_choices_schema.py +++ b/tests/test_choices_schema.py @@ -22,7 +22,7 @@ uuid1 = uuid.uuid4() uuid2 = uuid.uuid4() -class TestEnum(Enum): +class EnumTest(Enum): val1 = "Val1" val2 = "Val2" @@ -56,7 +56,7 @@ class Organisation(ormar.Model): ) random_json: pydantic.Json = ormar.JSON(choices=["aa", '{"aa":"bb"}']) random_uuid: uuid.UUID = ormar.UUID(choices=[uuid1, uuid2]) - enum_string: str = ormar.String(max_length=100, choices=list(TestEnum)) + enum_string: str = ormar.String(max_length=100, choices=list(EnumTest)) @app.on_event("startup") @@ -110,7 +110,7 @@ def test_all_endpoints(): "random_decimal": 12.4, "random_json": '{"aa":"bb"}', "random_uuid": str(uuid1), - "enum_string": TestEnum.val1.value, + "enum_string": EnumTest.val1.value, }, ) diff --git a/tests/test_m2m_through_fields.py b/tests/test_m2m_through_fields.py index 8dd8bba..6c4211a 100644 --- a/tests/test_m2m_through_fields.py +++ b/tests/test_m2m_through_fields.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List, Sequence, cast import databases import pytest @@ -131,6 +131,7 @@ async def test_getting_additional_fields_from_queryset() -> Any: ) await post.categories.all() + assert post.postcategory is None assert post.categories[0].postcategory.sort_order == 1 assert post.categories[1].postcategory.sort_order == 2 @@ -138,8 +139,31 @@ async def test_getting_additional_fields_from_queryset() -> Any: categories__name="Test category2" ) assert post2.categories[0].postcategory.sort_order == 2 - # if TYPE_CHECKING: - # reveal_type(post2) + + +@pytest.mark.asyncio +async def test_only_one_side_has_through() -> Any: + async with database: + post = await Post(title="Test post").save() + await post.categories.create( + name="Test category1", postcategory={"sort_order": 1} + ) + await post.categories.create( + name="Test category2", postcategory={"sort_order": 2} + ) + + post2 = await Post.objects.select_related("categories").get() + assert post2.postcategory is None + assert post2.categories[0].postcategory is not None + + await post2.categories.all() + assert post2.postcategory is None + assert post2.categories[0].postcategory is not None + + categories = await Category.objects.select_related("posts").all() + categories = cast(Sequence[Category], categories) + assert categories[0].postcategory is None + assert categories[0].posts[0].postcategory is not None @pytest.mark.asyncio @@ -294,7 +318,6 @@ async def test_update_through_from_related() -> Any: @pytest.mark.asyncio -@pytest.mark.skip # TODO: Restore after finished exclude refactor async def test_excluding_fields_on_through_model() -> Any: async with database: post = await Post(title="Test post").save() @@ -323,6 +346,17 @@ async def test_excluding_fields_on_through_model() -> Any: assert post2.categories[2].postcategory.param_name is None assert post2.categories[2].postcategory.sort_order == 3 + post3 = ( + await Post.objects.select_related("categories") + .fields({"postcategory": ..., "title": ...}) + .exclude_fields({"postcategory": {"param_name", "sort_order"}}) + .get() + ) + assert len(post3.categories) == 3 + for category in post3.categories: + assert category.postcategory.param_name is None + assert category.postcategory.sort_order is None + # TODO: check/ modify following @@ -337,9 +371,9 @@ async def test_excluding_fields_on_through_model() -> Any: # ordering by in order_by (V) # updating in query (V) # updating from querysetproxy (V) +# including/excluding in fields? # modifying from instance (both sides?) (X) <= no, the loaded one doesn't have relations -# including/excluding in fields? # allowing to change fk fields names in through model? # make through optional? auto-generated for cases other fields are missing? diff --git a/tests/test_queryset_utils.py b/tests/test_queryset_utils.py index daae2b4..cd96dc8 100644 --- a/tests/test_queryset_utils.py +++ b/tests/test_queryset_utils.py @@ -8,11 +8,6 @@ from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, from tests.settings import DATABASE_URL -def test_empty_excludable(): - assert ExcludableMixin.is_included(None, "key") # all fields included if empty - assert not ExcludableMixin.is_excluded(None, "key") # none field excluded if empty - - def test_list_to_dict_translation(): tet_list = ["aa", "bb", "cc__aa", "cc__bb", "cc__aa__xx", "cc__aa__yy"] test = translate_list_to_dict(tet_list)