diff --git a/README.md b/README.md index 2a072ed..3aeb6bb 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,7 @@ assert len(tracks) == 1 * `count() -> int` * `exists() -> bool` * `fields(columns: Union[List, str]) -> QuerySet` +* `exclude_fields(columns: Union[List, str]) -> QuerySet` * `order_by(columns:Union[List, str]) -> QuerySet` #### Relation types diff --git a/docs/index.md b/docs/index.md index b4e766c..e4d9788 100644 --- a/docs/index.md +++ b/docs/index.md @@ -164,6 +164,7 @@ assert len(tracks) == 1 * `count() -> int` * `exists() -> bool` * `fields(columns: Union[List, str]) -> QuerySet` +* `exclude_fields(columns: Union[List, str]) -> QuerySet` * `order_by(columns:Union[List, str]) -> QuerySet` diff --git a/docs/queries.md b/docs/queries.md index ad82d40..957b02b 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -370,6 +370,33 @@ With `fields()` you can select subset of model columns to limit the data load. Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` +### exclude_fields + +`fields(columns: Union[List, str]) -> QuerySet` + +With `exclude_fields()` you can select subset of model columns that will be excluded to limit the data load. + +It's the oposite of `fields()` method. + +```python hl_lines="47 48 60 61 67" +--8<-- "../docs_src/queries/docs008.py" +``` + +!!!warning + Mandatory fields cannot be excluded as it will raise `ValidationError`, to exclude a field it has to be nullable. + +!!!tip + Pk column cannot be excluded - it's always auto added even if explicitly excluded. + + +!!!note + All methods that do not return the rows explicitly returns a QueySet instance so you can chain them together + + So operations like `filter()`, `select_related()`, `limit()` and `offset()` etc. can be chained. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` + + ### order_by `order_by(columns: Union[List, str]) -> QuerySet` diff --git a/docs_src/queries/docs008.py b/docs_src/queries/docs008.py new file mode 100644 index 0000000..52bbee7 --- /dev/null +++ b/docs_src/queries/docs008.py @@ -0,0 +1,68 @@ +import databases +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + founded: int = ormar.Integer(nullable=True) + + +class Car(ormar.Model): + class Meta: + tablename = "cars" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + manufacturer = ormar.ForeignKey(Company) + name: str = ormar.String(max_length=100) + year: int = ormar.Integer(nullable=True) + gearbox_type: str = ormar.String(max_length=20, nullable=True) + gears: int = ormar.Integer(nullable=True) + aircon_type: str = ormar.String(max_length=20, nullable=True) + + +# build some sample data +toyota = await Company.objects.create(name="Toyota", founded=1937) +await Car.objects.create(manufacturer=toyota, name="Corolla", year=2020, gearbox_type='Manual', gears=5, + aircon_type='Manual') +await Car.objects.create(manufacturer=toyota, name="Yaris", year=2019, gearbox_type='Manual', gears=5, + aircon_type='Manual') +await Car.objects.create(manufacturer=toyota, name="Supreme", year=2020, gearbox_type='Auto', gears=6, + aircon_type='Auto') + +# select manufacturer but only name - to include related models use notation {model_name}__{column} +all_cars = await Car.objects.select_related('manufacturer').exclude_fields( + ['year', 'gearbox_type', 'gears', 'aircon_type', 'company__founded']).all() +for car in all_cars: + # excluded columns will yield None + assert all(getattr(car, x) is None for x in ['year', 'gearbox_type', 'gears', 'aircon_type']) + # included column on related models will be available, pk column is always included + # even if you do not include it in fields list + assert car.manufacturer.name == 'Toyota' + # also in the nested related models - you cannot exclude pk - it's always auto added + assert car.manufacturer.founded is None + +# fields() can be called several times, building up the columns to select +# models selected in select_related but with no columns in fields list implies all fields +all_cars = await Car.objects.select_related('manufacturer').exclude_fields('year').exclude_fields( + ['gear', 'gearbox_type']).all() +# all fiels from company model are selected +assert all_cars[0].manufacturer.name == 'Toyota' +assert all_cars[0].manufacturer.founded == 1937 + +# cannot exclude mandatory model columns - company__name in this example +await Car.objects.select_related('manufacturer').exclude_fields(['company__name']).all() +# will raise pydantic ValidationError as company.name is required diff --git a/ormar/__init__.py b/ormar/__init__.py index 9bab42f..03476f4 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.4.3" +__version__ = "0.4.4" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/models/model.py b/ormar/models/model.py index d2bf904..bbd78cc 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -48,6 +48,7 @@ class Model(NewBaseModel): related_models: Any = None, previous_table: str = None, fields: List = None, + exclude_fields: List = None, ) -> Optional[T]: item: Dict[str, Any] = {} @@ -74,10 +75,20 @@ class Model(NewBaseModel): previous_table = cls.Meta.table.name item = cls.populate_nested_models_from_row( - item, row, related_models, previous_table, fields + item=item, + row=row, + related_models=related_models, + previous_table=previous_table, + fields=fields, + exclude_fields=exclude_fields, ) item = cls.extract_prefixed_table_columns( - item, row, table_prefix, fields, nested=table_prefix != "" + item=item, + row=row, + table_prefix=table_prefix, + fields=fields, + exclude_fields=exclude_fields, + nested=table_prefix != "", ) instance: Optional[T] = cls(**item) if item.get( @@ -86,13 +97,14 @@ class Model(NewBaseModel): return instance @classmethod - def populate_nested_models_from_row( + def populate_nested_models_from_row( # noqa: CFQ002 cls, item: dict, row: sqlalchemy.engine.ResultProxy, related_models: Any, previous_table: sqlalchemy.Table, fields: List = None, + exclude_fields: List = None, ) -> dict: for related in related_models: if isinstance(related_models, dict) and related_models[related]: @@ -103,12 +115,16 @@ class Model(NewBaseModel): related_models=remainder, previous_table=previous_table, fields=fields, + exclude_fields=exclude_fields, ) item[model_cls.get_column_name_from_alias(first_part)] = child else: model_cls = cls.Meta.model_fields[related].to child = model_cls.from_row( - row, previous_table=previous_table, fields=fields + row, + previous_table=previous_table, + fields=fields, + exclude_fields=exclude_fields, ) item[model_cls.get_column_name_from_alias(related)] = child @@ -121,6 +137,7 @@ class Model(NewBaseModel): row: sqlalchemy.engine.result.ResultProxy, table_prefix: str, fields: List = None, + exclude_fields: List = None, nested: bool = False, ) -> dict: @@ -128,8 +145,9 @@ class Model(NewBaseModel): source = row._row if cls.db_backend_name() == "postgresql" else row selected_columns = cls.own_table_columns( - cls, fields or [], nested=nested, use_alias=True + cls, fields or [], exclude_fields or [], nested=nested, use_alias=True ) + for column in cls.Meta.table.columns: alias = cls.get_column_name_from_alias(column.name) if alias not in item and alias in selected_columns: diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 9033eda..98e0ad6 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -219,16 +219,29 @@ class ModelTableProxy: def _get_not_nested_columns_from_fields( model: Type["Model"], fields: List, + exclude_fields: List, column_names: List[str], use_alias: bool = False, ) -> List[str]: fields = [model.get_column_alias(k) if not use_alias else k for k in fields] - columns = [name for name in fields if "__" not in name and name in column_names] + fields = fields or column_names + exclude_fields = [ + model.get_column_alias(k) if not use_alias else k for k in exclude_fields + ] + columns = [ + name + for name in fields + if "__" not in name and name in column_names and name not in exclude_fields + ] return columns @staticmethod def _get_nested_columns_from_fields( - model: Type["Model"], fields: List, use_alias: bool = False, + model: Type["Model"], + fields: List, + exclude_fields: List, + column_names: List[str], + use_alias: bool = False, ) -> List[str]: model_name = f"{model.get_name()}__" columns = [ @@ -236,37 +249,22 @@ class ModelTableProxy: for name in fields if f"{model.get_name()}__" in name ] + columns = columns or column_names + exclude_columns = [ + name[(name.find(model_name) + len(model_name)) :] # noqa: E203 + for name in exclude_fields + if f"{model.get_name()}__" in name + ] columns = [model.get_column_alias(k) if not use_alias else k for k in columns] - return columns + exclude_columns = [ + model.get_column_alias(k) if not use_alias else k for k in exclude_columns + ] + return [column for column in columns if column not in exclude_columns] @staticmethod - def own_table_columns( - model: Type["Model"], - fields: List, - nested: bool = False, - use_alias: bool = False, + def _populate_pk_column( + model: Type["Model"], columns: List[str], use_alias: bool = False, ) -> List[str]: - column_names = [ - model.get_column_name_from_alias(col.name) if use_alias else col.name - for col in model.Meta.table.columns - ] - if not fields: - return column_names - - if not nested: - columns = ModelTableProxy._get_not_nested_columns_from_fields( - model, fields, column_names, use_alias - ) - else: - columns = ModelTableProxy._get_nested_columns_from_fields( - model, fields, use_alias - ) - - # if the model is in select and no columns in fields, all implied - if not columns: - columns = column_names - - # always has to return pk column pk_alias = ( model.get_column_alias(model.Meta.pkname) if not use_alias @@ -275,3 +273,42 @@ class ModelTableProxy: if pk_alias not in columns: columns.append(pk_alias) return columns + + @staticmethod + def own_table_columns( + model: Type["Model"], + fields: List, + exclude_fields: List, + nested: bool = False, + use_alias: bool = False, + ) -> List[str]: + column_names = [ + model.get_column_name_from_alias(col.name) if use_alias else col.name + for col in model.Meta.table.columns + ] + if not fields and not exclude_fields: + return column_names + + if not nested: + columns = ModelTableProxy._get_not_nested_columns_from_fields( + model=model, + fields=fields, + exclude_fields=exclude_fields, + column_names=column_names, + use_alias=use_alias, + ) + else: + columns = ModelTableProxy._get_nested_columns_from_fields( + model=model, + fields=fields, + exclude_fields=exclude_fields, + column_names=column_names, + use_alias=use_alias, + ) + + # always has to return pk column + columns = ModelTableProxy._populate_pk_column( + model=model, columns=columns, use_alias=use_alias + ) + + return columns diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index b3bbb06..0cc6953 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -25,6 +25,7 @@ class SqlJoin: select_from: sqlalchemy.sql.select, columns: List[sqlalchemy.Column], fields: List, + exclude_fields: List, order_columns: Optional[List], sorted_orders: OrderedDict, ) -> None: @@ -32,6 +33,7 @@ class SqlJoin: self.select_from = select_from self.columns = columns self.fields = fields + self.exclude_fields = exclude_fields self.order_columns = order_columns self.sorted_orders = sorted_orders @@ -121,7 +123,7 @@ class SqlJoin: self.get_order_bys(alias, to_table, pkname_alias, part) self_related_fields = model_cls.own_table_columns( - model_cls, self.fields, nested=True, + model_cls, self.fields, self.exclude_fields, nested=True, ) self.columns.extend( self.relation_manager(model_cls).prefixed_columns( diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 8665829..ebe8925 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -22,6 +22,7 @@ class Query: limit_count: Optional[int], offset: Optional[int], fields: Optional[List], + exclude_fields: Optional[List], order_bys: Optional[List], ) -> None: self.query_offset = offset @@ -30,6 +31,7 @@ class Query: self.filter_clauses = filter_clauses[:] self.exclude_clauses = exclude_clauses[:] self.fields = fields[:] if fields else [] + self.exclude_fields = exclude_fields[:] if exclude_fields else [] self.model_cls = model_cls self.table = self.model_cls.Meta.table @@ -68,7 +70,7 @@ class Query: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: self_related_fields = self.model_cls.own_table_columns( - self.model_cls, self.fields + self.model_cls, self.fields, self.exclude_fields ) self.columns = self.model_cls.Meta.alias_manager.prefixed_columns( "", self.table, self_related_fields @@ -88,6 +90,7 @@ class Query: select_from=self.select_from, columns=self.columns, fields=self.fields, + exclude_fields=self.exclude_fields, order_columns=self.order_columns, sorted_orders=self.sorted_orders, ) @@ -126,3 +129,4 @@ class Query: self.columns = [] self.used_aliases = [] self.fields = [] + self.exclude_fields = [] diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 401f7a0..d45908e 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -27,6 +27,7 @@ class QuerySet: limit_count: int = None, offset: int = None, columns: List = None, + exclude_columns: List = None, order_bys: List = None, ) -> None: self.model_cls = model_cls @@ -36,6 +37,7 @@ class QuerySet: self.limit_count = limit_count self.query_offset = offset self._columns = columns or [] + self._exclude_columns = exclude_columns or [] self.order_bys = order_bys or [] def __get__( @@ -63,7 +65,10 @@ class QuerySet: def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: result_rows = [ self.model.from_row( - row, select_related=self._select_related, fields=self._columns + row=row, + select_related=self._select_related, + fields=self._columns, + exclude_fields=self._exclude_columns, ) for row in rows ] @@ -111,6 +116,7 @@ class QuerySet: offset=self.query_offset, limit_count=self.limit_count, fields=self._columns, + exclude_fields=self._exclude_columns, order_bys=self.order_bys, ) exp = qry.build_select_expression() @@ -139,6 +145,7 @@ class QuerySet: limit_count=self.limit_count, offset=self.query_offset, columns=self._columns, + exclude_columns=self._exclude_columns, order_bys=self.order_bys, ) @@ -158,6 +165,24 @@ class QuerySet: limit_count=self.limit_count, offset=self.query_offset, columns=self._columns, + exclude_columns=self._exclude_columns, + order_bys=self.order_bys, + ) + + def exclude_fields(self, columns: Union[List, str]) -> "QuerySet": + if not isinstance(columns, list): + columns = [columns] + + columns = list(set(list(self._exclude_columns) + columns)) + return self.__class__( + model_cls=self.model, + filter_clauses=self.filter_clauses, + exclude_clauses=self.exclude_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=self.query_offset, + columns=self._columns, + exclude_columns=columns, order_bys=self.order_bys, ) @@ -174,6 +199,7 @@ class QuerySet: limit_count=self.limit_count, offset=self.query_offset, columns=columns, + exclude_columns=self._exclude_columns, order_bys=self.order_bys, ) @@ -190,6 +216,7 @@ class QuerySet: limit_count=self.limit_count, offset=self.query_offset, columns=self._columns, + exclude_columns=self._exclude_columns, order_bys=order_bys, ) @@ -239,6 +266,7 @@ class QuerySet: limit_count=limit_count, offset=self.query_offset, columns=self._columns, + exclude_columns=self._exclude_columns, order_bys=self.order_bys, ) @@ -251,6 +279,7 @@ class QuerySet: limit_count=self.limit_count, offset=offset, columns=self._columns, + exclude_columns=self._exclude_columns, order_bys=self.order_bys, ) diff --git a/tests/test_excluding_subset_of_columns.py b/tests/test_excluding_subset_of_columns.py new file mode 100644 index 0000000..6679c12 --- /dev/null +++ b/tests/test_excluding_subset_of_columns.py @@ -0,0 +1,121 @@ +from typing import Optional + +import databases +import pydantic +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False) + founded: int = ormar.Integer(nullable=True) + + +class Car(ormar.Model): + class Meta: + tablename = "cars" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + manufacturer: Optional[Company] = ormar.ForeignKey(Company) + name: str = ormar.String(max_length=100) + year: int = ormar.Integer(nullable=True) + gearbox_type: str = ormar.String(max_length=20, nullable=True) + gears: int = ormar.Integer(nullable=True) + aircon_type: str = ormar.String(max_length=20, nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_selecting_subset(): + async with database: + async with database.transaction(force_rollback=True): + toyota = await Company.objects.create(name="Toyota", founded=1937) + await Car.objects.create( + manufacturer=toyota, + name="Corolla", + year=2020, + gearbox_type="Manual", + gears=5, + aircon_type="Manual", + ) + await Car.objects.create( + manufacturer=toyota, + name="Yaris", + year=2019, + gearbox_type="Manual", + gears=5, + aircon_type="Manual", + ) + await Car.objects.create( + manufacturer=toyota, + name="Supreme", + year=2020, + gearbox_type="Auto", + gears=6, + aircon_type="Auto", + ) + + all_cars = ( + await Car.objects.select_related("manufacturer") + .exclude_fields(["gearbox_type", "gears", "aircon_type", "year", "company__founded"]) + .all() + ) + for car in all_cars: + assert all( + getattr(car, x) is None + for x in ["year", "gearbox_type", "gears", "aircon_type"] + ) + assert car.manufacturer.name == "Toyota" + assert car.manufacturer.founded is None + + all_cars = ( + await Car.objects.select_related("manufacturer") + .exclude_fields("year") + .exclude_fields(["gearbox_type", "gears"]) + .exclude_fields("aircon_type") + .all() + ) + for car in all_cars: + assert all( + getattr(car, x) is None + for x in ["year", "gearbox_type", "gears", "aircon_type"] + ) + assert car.manufacturer.name == "Toyota" + assert car.manufacturer.founded == 1937 + + all_cars_check = await Car.objects.select_related("manufacturer").all() + for car in all_cars_check: + assert all( + getattr(car, x) is not None + for x in ["year", "gearbox_type", "gears", "aircon_type"] + ) + assert car.manufacturer.name == "Toyota" + assert car.manufacturer.founded == 1937 + + with pytest.raises(pydantic.error_wrappers.ValidationError): + # cannot exclude mandatory model columns - company__name in this example + await Car.objects.select_related("manufacturer").exclude_fields( + ["company__name"] + ).all()