add exclude_fields queryset method

This commit is contained in:
collerek
2020-11-10 16:54:24 +01:00
parent 9413e51e6e
commit eafc6862f4
11 changed files with 346 additions and 38 deletions

View File

@ -164,6 +164,7 @@ assert len(tracks) == 1
* `count() -> int` * `count() -> int`
* `exists() -> bool` * `exists() -> bool`
* `fields(columns: Union[List, str]) -> QuerySet` * `fields(columns: Union[List, str]) -> QuerySet`
* `exclude_fields(columns: Union[List, str]) -> QuerySet`
* `order_by(columns:Union[List, str]) -> QuerySet` * `order_by(columns:Union[List, str]) -> QuerySet`
#### Relation types #### Relation types

View File

@ -164,6 +164,7 @@ assert len(tracks) == 1
* `count() -> int` * `count() -> int`
* `exists() -> bool` * `exists() -> bool`
* `fields(columns: Union[List, str]) -> QuerySet` * `fields(columns: Union[List, str]) -> QuerySet`
* `exclude_fields(columns: Union[List, str]) -> QuerySet`
* `order_by(columns:Union[List, str]) -> QuerySet` * `order_by(columns:Union[List, str]) -> QuerySet`

View File

@ -370,6 +370,33 @@ With `fields()` you can select subset of model columns to limit the data load.
Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()`
### exclude_fields
`fields(columns: Union[List, str]) -> QuerySet`
With `exclude_fields()` you can select subset of model columns that will be excluded to limit the data load.
It's the oposite of `fields()` method.
```python hl_lines="47 48 60 61 67"
--8<-- "../docs_src/queries/docs008.py"
```
!!!warning
Mandatory fields cannot be excluded as it will raise `ValidationError`, to exclude a field it has to be nullable.
!!!tip
Pk column cannot be excluded - it's always auto added even if explicitly excluded.
!!!note
All methods that do not return the rows explicitly returns a QueySet instance so you can chain them together
So operations like `filter()`, `select_related()`, `limit()` and `offset()` etc. can be chained.
Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()`
### order_by ### order_by
`order_by(columns: Union[List, str]) -> QuerySet` `order_by(columns: Union[List, str]) -> QuerySet`

View File

@ -0,0 +1,68 @@
import databases
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Company(ormar.Model):
class Meta:
tablename = "companies"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
founded: int = ormar.Integer(nullable=True)
class Car(ormar.Model):
class Meta:
tablename = "cars"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
manufacturer = ormar.ForeignKey(Company)
name: str = ormar.String(max_length=100)
year: int = ormar.Integer(nullable=True)
gearbox_type: str = ormar.String(max_length=20, nullable=True)
gears: int = ormar.Integer(nullable=True)
aircon_type: str = ormar.String(max_length=20, nullable=True)
# build some sample data
toyota = await Company.objects.create(name="Toyota", founded=1937)
await Car.objects.create(manufacturer=toyota, name="Corolla", year=2020, gearbox_type='Manual', gears=5,
aircon_type='Manual')
await Car.objects.create(manufacturer=toyota, name="Yaris", year=2019, gearbox_type='Manual', gears=5,
aircon_type='Manual')
await Car.objects.create(manufacturer=toyota, name="Supreme", year=2020, gearbox_type='Auto', gears=6,
aircon_type='Auto')
# select manufacturer but only name - to include related models use notation {model_name}__{column}
all_cars = await Car.objects.select_related('manufacturer').exclude_fields(
['year', 'gearbox_type', 'gears', 'aircon_type', 'company__founded']).all()
for car in all_cars:
# excluded columns will yield None
assert all(getattr(car, x) is None for x in ['year', 'gearbox_type', 'gears', 'aircon_type'])
# included column on related models will be available, pk column is always included
# even if you do not include it in fields list
assert car.manufacturer.name == 'Toyota'
# also in the nested related models - you cannot exclude pk - it's always auto added
assert car.manufacturer.founded is None
# fields() can be called several times, building up the columns to select
# models selected in select_related but with no columns in fields list implies all fields
all_cars = await Car.objects.select_related('manufacturer').exclude_fields('year').exclude_fields(
['gear', 'gearbox_type']).all()
# all fiels from company model are selected
assert all_cars[0].manufacturer.name == 'Toyota'
assert all_cars[0].manufacturer.founded == 1937
# cannot exclude mandatory model columns - company__name in this example
await Car.objects.select_related('manufacturer').exclude_fields(['company__name']).all()
# will raise pydantic ValidationError as company.name is required

View File

@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.4.3" __version__ = "0.4.4"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -48,6 +48,7 @@ class Model(NewBaseModel):
related_models: Any = None, related_models: Any = None,
previous_table: str = None, previous_table: str = None,
fields: List = None, fields: List = None,
exclude_fields: List = None,
) -> Optional[T]: ) -> Optional[T]:
item: Dict[str, Any] = {} item: Dict[str, Any] = {}
@ -74,10 +75,20 @@ class Model(NewBaseModel):
previous_table = cls.Meta.table.name previous_table = cls.Meta.table.name
item = cls.populate_nested_models_from_row( item = cls.populate_nested_models_from_row(
item, row, related_models, previous_table, fields item=item,
row=row,
related_models=related_models,
previous_table=previous_table,
fields=fields,
exclude_fields=exclude_fields,
) )
item = cls.extract_prefixed_table_columns( item = cls.extract_prefixed_table_columns(
item, row, table_prefix, fields, nested=table_prefix != "" item=item,
row=row,
table_prefix=table_prefix,
fields=fields,
exclude_fields=exclude_fields,
nested=table_prefix != "",
) )
instance: Optional[T] = cls(**item) if item.get( instance: Optional[T] = cls(**item) if item.get(
@ -86,13 +97,14 @@ class Model(NewBaseModel):
return instance return instance
@classmethod @classmethod
def populate_nested_models_from_row( def populate_nested_models_from_row( # noqa: CFQ002
cls, cls,
item: dict, item: dict,
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
related_models: Any, related_models: Any,
previous_table: sqlalchemy.Table, previous_table: sqlalchemy.Table,
fields: List = None, fields: List = None,
exclude_fields: List = None,
) -> dict: ) -> dict:
for related in related_models: for related in related_models:
if isinstance(related_models, dict) and related_models[related]: if isinstance(related_models, dict) and related_models[related]:
@ -103,12 +115,16 @@ class Model(NewBaseModel):
related_models=remainder, related_models=remainder,
previous_table=previous_table, previous_table=previous_table,
fields=fields, fields=fields,
exclude_fields=exclude_fields,
) )
item[model_cls.get_column_name_from_alias(first_part)] = child item[model_cls.get_column_name_from_alias(first_part)] = child
else: else:
model_cls = cls.Meta.model_fields[related].to model_cls = cls.Meta.model_fields[related].to
child = model_cls.from_row( child = model_cls.from_row(
row, previous_table=previous_table, fields=fields row,
previous_table=previous_table,
fields=fields,
exclude_fields=exclude_fields,
) )
item[model_cls.get_column_name_from_alias(related)] = child item[model_cls.get_column_name_from_alias(related)] = child
@ -121,6 +137,7 @@ class Model(NewBaseModel):
row: sqlalchemy.engine.result.ResultProxy, row: sqlalchemy.engine.result.ResultProxy,
table_prefix: str, table_prefix: str,
fields: List = None, fields: List = None,
exclude_fields: List = None,
nested: bool = False, nested: bool = False,
) -> dict: ) -> dict:
@ -128,8 +145,9 @@ class Model(NewBaseModel):
source = row._row if cls.db_backend_name() == "postgresql" else row source = row._row if cls.db_backend_name() == "postgresql" else row
selected_columns = cls.own_table_columns( selected_columns = cls.own_table_columns(
cls, fields or [], nested=nested, use_alias=True cls, fields or [], exclude_fields or [], nested=nested, use_alias=True
) )
for column in cls.Meta.table.columns: for column in cls.Meta.table.columns:
alias = cls.get_column_name_from_alias(column.name) alias = cls.get_column_name_from_alias(column.name)
if alias not in item and alias in selected_columns: if alias not in item and alias in selected_columns:

View File

@ -219,16 +219,29 @@ class ModelTableProxy:
def _get_not_nested_columns_from_fields( def _get_not_nested_columns_from_fields(
model: Type["Model"], model: Type["Model"],
fields: List, fields: List,
exclude_fields: List,
column_names: List[str], column_names: List[str],
use_alias: bool = False, use_alias: bool = False,
) -> List[str]: ) -> List[str]:
fields = [model.get_column_alias(k) if not use_alias else k for k in fields] fields = [model.get_column_alias(k) if not use_alias else k for k in fields]
columns = [name for name in fields if "__" not in name and name in column_names] fields = fields or column_names
exclude_fields = [
model.get_column_alias(k) if not use_alias else k for k in exclude_fields
]
columns = [
name
for name in fields
if "__" not in name and name in column_names and name not in exclude_fields
]
return columns return columns
@staticmethod @staticmethod
def _get_nested_columns_from_fields( def _get_nested_columns_from_fields(
model: Type["Model"], fields: List, use_alias: bool = False, model: Type["Model"],
fields: List,
exclude_fields: List,
column_names: List[str],
use_alias: bool = False,
) -> List[str]: ) -> List[str]:
model_name = f"{model.get_name()}__" model_name = f"{model.get_name()}__"
columns = [ columns = [
@ -236,37 +249,22 @@ class ModelTableProxy:
for name in fields for name in fields
if f"{model.get_name()}__" in name if f"{model.get_name()}__" in name
] ]
columns = columns or column_names
exclude_columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in exclude_fields
if f"{model.get_name()}__" in name
]
columns = [model.get_column_alias(k) if not use_alias else k for k in columns] columns = [model.get_column_alias(k) if not use_alias else k for k in columns]
return columns exclude_columns = [
model.get_column_alias(k) if not use_alias else k for k in exclude_columns
]
return [column for column in columns if column not in exclude_columns]
@staticmethod @staticmethod
def own_table_columns( def _populate_pk_column(
model: Type["Model"], model: Type["Model"], columns: List[str], use_alias: bool = False,
fields: List,
nested: bool = False,
use_alias: bool = False,
) -> List[str]: ) -> List[str]:
column_names = [
model.get_column_name_from_alias(col.name) if use_alias else col.name
for col in model.Meta.table.columns
]
if not fields:
return column_names
if not nested:
columns = ModelTableProxy._get_not_nested_columns_from_fields(
model, fields, column_names, use_alias
)
else:
columns = ModelTableProxy._get_nested_columns_from_fields(
model, fields, use_alias
)
# if the model is in select and no columns in fields, all implied
if not columns:
columns = column_names
# always has to return pk column
pk_alias = ( pk_alias = (
model.get_column_alias(model.Meta.pkname) model.get_column_alias(model.Meta.pkname)
if not use_alias if not use_alias
@ -275,3 +273,42 @@ class ModelTableProxy:
if pk_alias not in columns: if pk_alias not in columns:
columns.append(pk_alias) columns.append(pk_alias)
return columns return columns
@staticmethod
def own_table_columns(
model: Type["Model"],
fields: List,
exclude_fields: List,
nested: bool = False,
use_alias: bool = False,
) -> List[str]:
column_names = [
model.get_column_name_from_alias(col.name) if use_alias else col.name
for col in model.Meta.table.columns
]
if not fields and not exclude_fields:
return column_names
if not nested:
columns = ModelTableProxy._get_not_nested_columns_from_fields(
model=model,
fields=fields,
exclude_fields=exclude_fields,
column_names=column_names,
use_alias=use_alias,
)
else:
columns = ModelTableProxy._get_nested_columns_from_fields(
model=model,
fields=fields,
exclude_fields=exclude_fields,
column_names=column_names,
use_alias=use_alias,
)
# always has to return pk column
columns = ModelTableProxy._populate_pk_column(
model=model, columns=columns, use_alias=use_alias
)
return columns

View File

@ -25,6 +25,7 @@ class SqlJoin:
select_from: sqlalchemy.sql.select, select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column], columns: List[sqlalchemy.Column],
fields: List, fields: List,
exclude_fields: List,
order_columns: Optional[List], order_columns: Optional[List],
sorted_orders: OrderedDict, sorted_orders: OrderedDict,
) -> None: ) -> None:
@ -32,6 +33,7 @@ class SqlJoin:
self.select_from = select_from self.select_from = select_from
self.columns = columns self.columns = columns
self.fields = fields self.fields = fields
self.exclude_fields = exclude_fields
self.order_columns = order_columns self.order_columns = order_columns
self.sorted_orders = sorted_orders self.sorted_orders = sorted_orders
@ -121,7 +123,7 @@ class SqlJoin:
self.get_order_bys(alias, to_table, pkname_alias, part) self.get_order_bys(alias, to_table, pkname_alias, part)
self_related_fields = model_cls.own_table_columns( self_related_fields = model_cls.own_table_columns(
model_cls, self.fields, nested=True, model_cls, self.fields, self.exclude_fields, nested=True,
) )
self.columns.extend( self.columns.extend(
self.relation_manager(model_cls).prefixed_columns( self.relation_manager(model_cls).prefixed_columns(

View File

@ -22,6 +22,7 @@ class Query:
limit_count: Optional[int], limit_count: Optional[int],
offset: Optional[int], offset: Optional[int],
fields: Optional[List], fields: Optional[List],
exclude_fields: Optional[List],
order_bys: Optional[List], order_bys: Optional[List],
) -> None: ) -> None:
self.query_offset = offset self.query_offset = offset
@ -30,6 +31,7 @@ class Query:
self.filter_clauses = filter_clauses[:] self.filter_clauses = filter_clauses[:]
self.exclude_clauses = exclude_clauses[:] self.exclude_clauses = exclude_clauses[:]
self.fields = fields[:] if fields else [] self.fields = fields[:] if fields else []
self.exclude_fields = exclude_fields[:] if exclude_fields else []
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.Meta.table self.table = self.model_cls.Meta.table
@ -68,7 +70,7 @@ class Query:
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self_related_fields = self.model_cls.own_table_columns( self_related_fields = self.model_cls.own_table_columns(
self.model_cls, self.fields self.model_cls, self.fields, self.exclude_fields
) )
self.columns = self.model_cls.Meta.alias_manager.prefixed_columns( self.columns = self.model_cls.Meta.alias_manager.prefixed_columns(
"", self.table, self_related_fields "", self.table, self_related_fields
@ -88,6 +90,7 @@ class Query:
select_from=self.select_from, select_from=self.select_from,
columns=self.columns, columns=self.columns,
fields=self.fields, fields=self.fields,
exclude_fields=self.exclude_fields,
order_columns=self.order_columns, order_columns=self.order_columns,
sorted_orders=self.sorted_orders, sorted_orders=self.sorted_orders,
) )
@ -126,3 +129,4 @@ class Query:
self.columns = [] self.columns = []
self.used_aliases = [] self.used_aliases = []
self.fields = [] self.fields = []
self.exclude_fields = []

View File

@ -27,6 +27,7 @@ class QuerySet:
limit_count: int = None, limit_count: int = None,
offset: int = None, offset: int = None,
columns: List = None, columns: List = None,
exclude_columns: List = None,
order_bys: List = None, order_bys: List = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
@ -36,6 +37,7 @@ class QuerySet:
self.limit_count = limit_count self.limit_count = limit_count
self.query_offset = offset self.query_offset = offset
self._columns = columns or [] self._columns = columns or []
self._exclude_columns = exclude_columns or []
self.order_bys = order_bys or [] self.order_bys = order_bys or []
def __get__( def __get__(
@ -63,7 +65,10 @@ class QuerySet:
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
result_rows = [ result_rows = [
self.model.from_row( self.model.from_row(
row, select_related=self._select_related, fields=self._columns row=row,
select_related=self._select_related,
fields=self._columns,
exclude_fields=self._exclude_columns,
) )
for row in rows for row in rows
] ]
@ -111,6 +116,7 @@ class QuerySet:
offset=self.query_offset, offset=self.query_offset,
limit_count=self.limit_count, limit_count=self.limit_count,
fields=self._columns, fields=self._columns,
exclude_fields=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
) )
exp = qry.build_select_expression() exp = qry.build_select_expression()
@ -139,6 +145,7 @@ class QuerySet:
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
) )
@ -158,6 +165,24 @@ class QuerySet:
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys,
)
def exclude_fields(self, columns: Union[List, str]) -> "QuerySet":
if not isinstance(columns, list):
columns = [columns]
columns = list(set(list(self._exclude_columns) + columns))
return self.__class__(
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=self.query_offset,
columns=self._columns,
exclude_columns=columns,
order_bys=self.order_bys, order_bys=self.order_bys,
) )
@ -174,6 +199,7 @@ class QuerySet:
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
columns=columns, columns=columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
) )
@ -190,6 +216,7 @@ class QuerySet:
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=order_bys, order_bys=order_bys,
) )
@ -239,6 +266,7 @@ class QuerySet:
limit_count=limit_count, limit_count=limit_count,
offset=self.query_offset, offset=self.query_offset,
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
) )
@ -251,6 +279,7 @@ class QuerySet:
limit_count=self.limit_count, limit_count=self.limit_count,
offset=offset, offset=offset,
columns=self._columns, columns=self._columns,
exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
) )

View File

@ -0,0 +1,121 @@
from typing import Optional
import databases
import pydantic
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Company(ormar.Model):
class Meta:
tablename = "companies"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False)
founded: int = ormar.Integer(nullable=True)
class Car(ormar.Model):
class Meta:
tablename = "cars"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
manufacturer: Optional[Company] = ormar.ForeignKey(Company)
name: str = ormar.String(max_length=100)
year: int = ormar.Integer(nullable=True)
gearbox_type: str = ormar.String(max_length=20, nullable=True)
gears: int = ormar.Integer(nullable=True)
aircon_type: str = ormar.String(max_length=20, nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_selecting_subset():
async with database:
async with database.transaction(force_rollback=True):
toyota = await Company.objects.create(name="Toyota", founded=1937)
await Car.objects.create(
manufacturer=toyota,
name="Corolla",
year=2020,
gearbox_type="Manual",
gears=5,
aircon_type="Manual",
)
await Car.objects.create(
manufacturer=toyota,
name="Yaris",
year=2019,
gearbox_type="Manual",
gears=5,
aircon_type="Manual",
)
await Car.objects.create(
manufacturer=toyota,
name="Supreme",
year=2020,
gearbox_type="Auto",
gears=6,
aircon_type="Auto",
)
all_cars = (
await Car.objects.select_related("manufacturer")
.exclude_fields(["gearbox_type", "gears", "aircon_type", "year", "company__founded"])
.all()
)
for car in all_cars:
assert all(
getattr(car, x) is None
for x in ["year", "gearbox_type", "gears", "aircon_type"]
)
assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded is None
all_cars = (
await Car.objects.select_related("manufacturer")
.exclude_fields("year")
.exclude_fields(["gearbox_type", "gears"])
.exclude_fields("aircon_type")
.all()
)
for car in all_cars:
assert all(
getattr(car, x) is None
for x in ["year", "gearbox_type", "gears", "aircon_type"]
)
assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded == 1937
all_cars_check = await Car.objects.select_related("manufacturer").all()
for car in all_cars_check:
assert all(
getattr(car, x) is not None
for x in ["year", "gearbox_type", "gears", "aircon_type"]
)
assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded == 1937
with pytest.raises(pydantic.error_wrappers.ValidationError):
# cannot exclude mandatory model columns - company__name in this example
await Car.objects.select_related("manufacturer").exclude_fields(
["company__name"]
).all()