add fields method limiting the models columns fetched from db
This commit is contained in:
@ -241,19 +241,19 @@ await post.categories.add(news)
|
||||
# or from the other end:
|
||||
await news.posts.add(post)
|
||||
|
||||
# Creating related object from instance:
|
||||
# Creating columns object from instance:
|
||||
await post.categories.create(name="Tips")
|
||||
assert len(await post.categories.all()) == 2
|
||||
|
||||
# Many to many relation exposes a list of related models
|
||||
# Many to many relation exposes a list of columns models
|
||||
# and an API of the Queryset:
|
||||
assert news == await post.categories.get(name="News")
|
||||
|
||||
# with all Queryset methods - filtering, selecting related, counting etc.
|
||||
# with all Queryset methods - filtering, selecting columns, counting etc.
|
||||
await news.posts.filter(title__contains="M2M").all()
|
||||
await Category.objects.filter(posts__author=guido).get()
|
||||
|
||||
# related models of many to many relation can be prefetched
|
||||
# columns models of many to many relation can be prefetched
|
||||
news_posts = await news.posts.select_related("author").all()
|
||||
assert news_posts[0].author == guido
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ track = Track.objects.filter(name="The Bird").get()
|
||||
# will return a track with name equal to 'The Bird'
|
||||
|
||||
tracks = Track.objects.filter(album__name="Fantasies").all()
|
||||
# will return all tracks where the related album name = 'Fantasies'
|
||||
# will return all tracks where the columns album name = 'Fantasies'
|
||||
```
|
||||
|
||||
You can use special filter suffix to change the filter operands:
|
||||
@ -106,7 +106,7 @@ To chain related `Models` relation use double underscore.
|
||||
|
||||
```python
|
||||
album = await Album.objects.select_related("tracks").all()
|
||||
# will return album will all related tracks
|
||||
# will return album will all columns tracks
|
||||
```
|
||||
|
||||
You can provide a string or a list of strings
|
||||
|
||||
@ -29,7 +29,7 @@ class Track(ormar.Model):
|
||||
album = await Album.objects.create(name="Brooklyn")
|
||||
await Track.objects.create(album=album, title="The Bird", position=1)
|
||||
|
||||
# explicit preload of related Album Model
|
||||
# explicit preload of columns Album Model
|
||||
track = await Track.objects.select_related("album").get(title="The Bird")
|
||||
assert track.album.name == 'Brooklyn'
|
||||
# Will produce: True
|
||||
|
||||
@ -33,7 +33,7 @@ print('name' in course.__dict__)
|
||||
print(course.name)
|
||||
# Math <- value returned from underlying pydantic model
|
||||
print('department' in course.__dict__)
|
||||
# False <- related model is not stored on Course instance
|
||||
# False <- columns model is not stored on Course instance
|
||||
print(course.department)
|
||||
# Department(id=None, name='Science') <- Department model
|
||||
# returned from AliasManager
|
||||
|
||||
@ -34,6 +34,7 @@ class Model(NewBaseModel):
|
||||
select_related: List = None,
|
||||
related_models: Any = None,
|
||||
previous_table: str = None,
|
||||
fields: List = None,
|
||||
) -> Optional["Model"]:
|
||||
|
||||
item: Dict[str, Any] = {}
|
||||
@ -61,9 +62,11 @@ class Model(NewBaseModel):
|
||||
previous_table = cls.Meta.table.name
|
||||
|
||||
item = cls.populate_nested_models_from_row(
|
||||
item, row, related_models, previous_table
|
||||
item, row, related_models, previous_table, fields
|
||||
)
|
||||
item = cls.extract_prefixed_table_columns(
|
||||
item, row, table_prefix, fields, nested=table_prefix != ""
|
||||
)
|
||||
item = cls.extract_prefixed_table_columns(item, row, table_prefix)
|
||||
|
||||
instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None
|
||||
return instance
|
||||
@ -75,33 +78,47 @@ class Model(NewBaseModel):
|
||||
row: sqlalchemy.engine.ResultProxy,
|
||||
related_models: Any,
|
||||
previous_table: sqlalchemy.Table,
|
||||
fields: List = None,
|
||||
) -> dict:
|
||||
for related in related_models:
|
||||
if isinstance(related_models, dict) and related_models[related]:
|
||||
first_part, remainder = related, related_models[related]
|
||||
model_cls = cls.Meta.model_fields[first_part].to
|
||||
child = model_cls.from_row(
|
||||
row, related_models=remainder, previous_table=previous_table
|
||||
row,
|
||||
related_models=remainder,
|
||||
previous_table=previous_table,
|
||||
fields=fields,
|
||||
)
|
||||
item[first_part] = child
|
||||
else:
|
||||
model_cls = cls.Meta.model_fields[related].to
|
||||
child = model_cls.from_row(row, previous_table=previous_table)
|
||||
child = model_cls.from_row(
|
||||
row, previous_table=previous_table, fields=fields
|
||||
)
|
||||
item[related] = child
|
||||
|
||||
return item
|
||||
|
||||
@classmethod
|
||||
def extract_prefixed_table_columns( # noqa CCR001
|
||||
cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str
|
||||
cls,
|
||||
item: dict,
|
||||
row: sqlalchemy.engine.result.ResultProxy,
|
||||
table_prefix: str,
|
||||
fields: List = None,
|
||||
nested: bool = False,
|
||||
) -> dict:
|
||||
|
||||
# databases does not keep aliases in Record for postgres, change to raw row
|
||||
source = row._row if isinstance(row, Record) else row
|
||||
|
||||
selected_columns = cls.own_table_columns(cls, fields or [], nested=nested)
|
||||
for column in cls.Meta.table.columns:
|
||||
if column.name not in item:
|
||||
if column.name not in item and column.name in selected_columns:
|
||||
prefixed_name = (
|
||||
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
|
||||
)
|
||||
# databases does not keep aliases in Record for postgres
|
||||
source = row._row if isinstance(row, Record) else row
|
||||
item[column.name] = source[prefixed_name]
|
||||
|
||||
return item
|
||||
|
||||
@ -149,3 +149,32 @@ class ModelTableProxy:
|
||||
cls.merge_two_instances(current_field, getattr(other, field)),
|
||||
)
|
||||
return other
|
||||
|
||||
@staticmethod
|
||||
def own_table_columns(
|
||||
model: Type["Model"], fields: List, nested: bool = False
|
||||
) -> List[str]:
|
||||
column_names = [col.name for col in model.Meta.table.columns]
|
||||
if not fields:
|
||||
return column_names
|
||||
|
||||
if not nested:
|
||||
columns = [
|
||||
name for name in fields if "__" not in name and name in column_names
|
||||
]
|
||||
else:
|
||||
model_name = f"{model.get_name()}__"
|
||||
columns = [
|
||||
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
|
||||
for name in fields
|
||||
if f"{model.get_name()}__" in name
|
||||
]
|
||||
|
||||
# if the model is in select and no columns in fields, all implied
|
||||
if not columns:
|
||||
columns = column_names
|
||||
|
||||
# always has to return pk column
|
||||
if model.Meta.pkname not in columns:
|
||||
columns.append(model.Meta.pkname)
|
||||
return columns
|
||||
|
||||
@ -93,7 +93,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
object.__setattr__(self, "__dict__", values)
|
||||
object.__setattr__(self, "__fields_set__", fields_set)
|
||||
|
||||
# register the related models after initialization
|
||||
# register the columns models after initialization
|
||||
for related in self.extract_related_names():
|
||||
self.Meta.model_fields[related].expand_relationship(
|
||||
kwargs.get(related), self, to_register=True
|
||||
|
||||
@ -24,11 +24,13 @@ class SqlJoin:
|
||||
select_from: sqlalchemy.sql.select,
|
||||
order_bys: List[sqlalchemy.sql.elements.TextClause],
|
||||
columns: List[sqlalchemy.Column],
|
||||
fields: List,
|
||||
) -> None:
|
||||
self.used_aliases = used_aliases
|
||||
self.select_from = select_from
|
||||
self.order_bys = order_bys
|
||||
self.columns = columns
|
||||
self.fields = fields
|
||||
|
||||
@staticmethod
|
||||
def relation_manager(model_cls: Type["Model"]) -> AliasManager:
|
||||
@ -105,9 +107,12 @@ class SqlJoin:
|
||||
self.select_from, target_table, on_clause
|
||||
)
|
||||
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}"))
|
||||
self_related_fields = model_cls.own_table_columns(
|
||||
model_cls, self.fields, nested=True
|
||||
)
|
||||
self.columns.extend(
|
||||
self.relation_manager(model_cls).prefixed_columns(
|
||||
alias, model_cls.Meta.table
|
||||
alias, model_cls.Meta.table, self_related_fields
|
||||
)
|
||||
)
|
||||
self.used_aliases.append(alias)
|
||||
|
||||
@ -20,12 +20,14 @@ class Query:
|
||||
select_related: List,
|
||||
limit_count: Optional[int],
|
||||
offset: Optional[int],
|
||||
fields: Optional[List],
|
||||
) -> None:
|
||||
self.query_offset = offset
|
||||
self.limit_count = limit_count
|
||||
self._select_related = select_related[:]
|
||||
self.filter_clauses = filter_clauses[:]
|
||||
self.exclude_clauses = exclude_clauses[:]
|
||||
self.fields = fields[:] if fields else []
|
||||
|
||||
self.model_cls = model_cls
|
||||
self.table = self.model_cls.Meta.table
|
||||
@ -41,7 +43,12 @@ class Query:
|
||||
return f"{self.table.name}.{self.model_cls.Meta.pkname}"
|
||||
|
||||
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
|
||||
self.columns = list(self.table.columns)
|
||||
self_related_fields = self.model_cls.own_table_columns(
|
||||
self.model_cls, self.fields
|
||||
)
|
||||
self.columns = self.model_cls.Meta.alias_manager.prefixed_columns(
|
||||
"", self.table, self_related_fields
|
||||
)
|
||||
self.order_bys = [text(self.prefixed_pk_name)]
|
||||
self.select_from = self.table
|
||||
|
||||
@ -57,6 +64,7 @@ class Query:
|
||||
select_from=self.select_from,
|
||||
columns=self.columns,
|
||||
order_bys=self.order_bys,
|
||||
fields=self.fields,
|
||||
)
|
||||
|
||||
(
|
||||
@ -93,3 +101,4 @@ class Query:
|
||||
self.columns = []
|
||||
self.order_bys = []
|
||||
self.used_aliases = []
|
||||
self.fields = []
|
||||
|
||||
@ -26,6 +26,7 @@ class QuerySet:
|
||||
select_related: List = None,
|
||||
limit_count: int = None,
|
||||
offset: int = None,
|
||||
columns: List = None,
|
||||
) -> None:
|
||||
self.model_cls = model_cls
|
||||
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
||||
@ -33,6 +34,7 @@ class QuerySet:
|
||||
self._select_related = [] if select_related is None else select_related
|
||||
self.limit_count = limit_count
|
||||
self.query_offset = offset
|
||||
self._columns = columns or []
|
||||
self.order_bys = None
|
||||
|
||||
def __get__(
|
||||
@ -59,7 +61,9 @@ class QuerySet:
|
||||
|
||||
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]:
|
||||
result_rows = [
|
||||
self.model.from_row(row, select_related=self._select_related)
|
||||
self.model.from_row(
|
||||
row, select_related=self._select_related, fields=self._columns
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
if result_rows:
|
||||
@ -104,6 +108,7 @@ class QuerySet:
|
||||
exclude_clauses=self.exclude_clauses,
|
||||
offset=self.query_offset,
|
||||
limit_count=self.limit_count,
|
||||
fields=self._columns,
|
||||
)
|
||||
exp = qry.build_select_expression()
|
||||
# print(exp.compile(compile_kwargs={"literal_binds": True}))
|
||||
@ -130,6 +135,7 @@ class QuerySet:
|
||||
select_related=select_related,
|
||||
limit_count=self.limit_count,
|
||||
offset=self.query_offset,
|
||||
columns=self._columns,
|
||||
)
|
||||
|
||||
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||
@ -147,6 +153,22 @@ class QuerySet:
|
||||
select_related=related,
|
||||
limit_count=self.limit_count,
|
||||
offset=self.query_offset,
|
||||
columns=self._columns,
|
||||
)
|
||||
|
||||
def fields(self, columns: Union[List, str]) -> "QuerySet":
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
|
||||
columns = list(set(list(self._columns) + columns))
|
||||
return self.__class__(
|
||||
model_cls=self.model,
|
||||
filter_clauses=self.filter_clauses,
|
||||
exclude_clauses=self.exclude_clauses,
|
||||
select_related=self._select_related,
|
||||
limit_count=self.limit_count,
|
||||
offset=self.query_offset,
|
||||
columns=columns,
|
||||
)
|
||||
|
||||
async def exists(self) -> bool:
|
||||
@ -193,6 +215,7 @@ class QuerySet:
|
||||
select_related=self._select_related,
|
||||
limit_count=limit_count,
|
||||
offset=self.query_offset,
|
||||
columns=self._columns,
|
||||
)
|
||||
|
||||
def offset(self, offset: int) -> "QuerySet":
|
||||
@ -203,6 +226,7 @@ class QuerySet:
|
||||
select_related=self._select_related,
|
||||
limit_count=self.limit_count,
|
||||
offset=offset,
|
||||
columns=self._columns,
|
||||
)
|
||||
|
||||
async def first(self, **kwargs: Any) -> "Model":
|
||||
|
||||
@ -17,10 +17,18 @@ class AliasManager:
|
||||
self._aliases: Dict[str, str] = dict()
|
||||
|
||||
@staticmethod
|
||||
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||
def prefixed_columns(
|
||||
alias: str, table: sqlalchemy.Table, fields: List = None
|
||||
) -> List[text]:
|
||||
alias = f"{alias}_" if alias else ""
|
||||
all_columns = (
|
||||
table.columns
|
||||
if not fields
|
||||
else [col for col in table.columns if col.name in fields]
|
||||
)
|
||||
return [
|
||||
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
|
||||
for column in table.columns
|
||||
text(f"{alias}{table.name}.{column.name} as {alias}{column.name}")
|
||||
for column in all_columns
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
|
||||
4
setup.py
4
setup.py
@ -42,10 +42,10 @@ setup(
|
||||
version=get_version(PACKAGE),
|
||||
url=URL,
|
||||
license="MIT",
|
||||
description="An simple async ORM with Fastapi in mind.",
|
||||
description="An simple async ORM with fastapi in mind and pydantic validation.",
|
||||
long_description=get_long_description(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords=['ORM', 'sqlalchemy', 'fastapi', 'pydantic', 'databases'],
|
||||
keywords=['orm', 'sqlalchemy', 'fastapi', 'pydantic', 'databases', 'async', 'alembic'],
|
||||
author="collerek",
|
||||
author_email="collerek@gmail.com",
|
||||
packages=get_packages(PACKAGE),
|
||||
|
||||
@ -1,9 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from pydantic import root_validator, validator
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
|
||||
@ -236,8 +233,8 @@ async def test_fk_filter():
|
||||
|
||||
tracks = (
|
||||
await Track.objects.select_related("album")
|
||||
.filter(album__name="Fantasies")
|
||||
.all()
|
||||
.filter(album__name="Fantasies")
|
||||
.all()
|
||||
)
|
||||
assert len(tracks) == 3
|
||||
for track in tracks:
|
||||
@ -245,8 +242,8 @@ async def test_fk_filter():
|
||||
|
||||
tracks = (
|
||||
await Track.objects.select_related("album")
|
||||
.filter(album__name__icontains="fan")
|
||||
.all()
|
||||
.filter(album__name__icontains="fan")
|
||||
.all()
|
||||
)
|
||||
assert len(tracks) == 3
|
||||
for track in tracks:
|
||||
@ -291,8 +288,8 @@ async def test_multiple_fk():
|
||||
|
||||
members = (
|
||||
await Member.objects.select_related("team__org")
|
||||
.filter(team__org__ident="ACME Ltd")
|
||||
.all()
|
||||
.filter(team__org__ident="ACME Ltd")
|
||||
.all()
|
||||
)
|
||||
assert len(members) == 4
|
||||
for member in members:
|
||||
@ -324,8 +321,8 @@ async def test_pk_filter():
|
||||
|
||||
tracks = (
|
||||
await Track.objects.select_related("album")
|
||||
.filter(position=2, album__name="Test")
|
||||
.all()
|
||||
.filter(position=2, album__name="Test")
|
||||
.all()
|
||||
)
|
||||
assert len(tracks) == 1
|
||||
|
||||
|
||||
@ -89,7 +89,7 @@ async def test_assigning_related_objects(cleanup):
|
||||
# or from the other end:
|
||||
await news.posts.add(post)
|
||||
|
||||
# Creating related object from instance:
|
||||
# Creating columns object from instance:
|
||||
await post.categories.create(name="Tips")
|
||||
assert len(post.categories) == 2
|
||||
|
||||
@ -148,7 +148,7 @@ async def test_removal_of_the_relations(cleanup):
|
||||
await news.posts.remove(post)
|
||||
assert len(await news.posts.all()) == 0
|
||||
|
||||
# Remove all related objects:
|
||||
# Remove all columns objects:
|
||||
await post.categories.add(news)
|
||||
await post.categories.clear()
|
||||
assert len(await post.categories.all()) == 0
|
||||
|
||||
82
tests/test_selecting_subset_of_columns.py
Normal file
82
tests/test_selecting_subset_of_columns.py
Normal file
@ -0,0 +1,82 @@
|
||||
import databases
|
||||
import pydantic
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class Company(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "companies"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: ormar.Integer(primary_key=True)
|
||||
name: ormar.String(max_length=100)
|
||||
founded: ormar.Integer(nullable=True)
|
||||
|
||||
|
||||
class Car(ormar.Model):
|
||||
class Meta:
|
||||
tablename = "cars"
|
||||
metadata = metadata
|
||||
database = database
|
||||
|
||||
id: ormar.Integer(primary_key=True)
|
||||
manufacturer: ormar.ForeignKey(Company)
|
||||
name: ormar.String(max_length=100)
|
||||
year: ormar.Integer(nullable=True)
|
||||
gearbox_type: ormar.String(max_length=20, nullable=True)
|
||||
gears: ormar.Integer(nullable=True)
|
||||
aircon_type: ormar.String(max_length=20, nullable=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.drop_all(engine)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selecting_subset():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
toyota = await Company.objects.create(name="Toyota", founded=1937)
|
||||
await Car.objects.create(manufacturer=toyota, name="Corolla", year=2020, gearbox_type='Manual', gears=5,
|
||||
aircon_type='Manual')
|
||||
await Car.objects.create(manufacturer=toyota, name="Yaris", year=2019, gearbox_type='Manual', gears=5,
|
||||
aircon_type='Manual')
|
||||
await Car.objects.create(manufacturer=toyota, name="Supreme", year=2020, gearbox_type='Auto', gears=6,
|
||||
aircon_type='Auto')
|
||||
|
||||
all_cars = await Car.objects.select_related('manufacturer').fields(['id', 'name', 'company__name']).all()
|
||||
for car in all_cars:
|
||||
assert all(getattr(car, x) is None for x in ['year', 'gearbox_type', 'gears', 'aircon_type'])
|
||||
assert car.manufacturer.name == 'Toyota'
|
||||
assert car.manufacturer.founded is None
|
||||
|
||||
all_cars = await Car.objects.select_related('manufacturer').fields('id').fields(
|
||||
['name', 'manufacturer']).all()
|
||||
for car in all_cars:
|
||||
assert all(getattr(car, x) is None for x in ['year', 'gearbox_type', 'gears', 'aircon_type'])
|
||||
assert car.manufacturer.name == 'Toyota'
|
||||
assert car.manufacturer.founded == 1937
|
||||
|
||||
all_cars_check = await Car.objects.select_related('manufacturer').all()
|
||||
for car in all_cars_check:
|
||||
assert all(getattr(car, x) is not None for x in ['year', 'gearbox_type', 'gears', 'aircon_type'])
|
||||
assert car.manufacturer.name == 'Toyota'
|
||||
assert car.manufacturer.founded == 1937
|
||||
|
||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||
# cannot exclude mandatory model columns - company__name in this example
|
||||
await Car.objects.select_related('manufacturer').fields(
|
||||
['id', 'name', 'company__founded']).all()
|
||||
Reference in New Issue
Block a user