add fields method limiting the models columns fetched from db

This commit is contained in:
collerek
2020-10-05 09:40:41 +02:00
parent 16b8e95975
commit 287b970a09
16 changed files with 209 additions and 38 deletions

View File

@ -34,6 +34,7 @@ class Model(NewBaseModel):
select_related: List = None,
related_models: Any = None,
previous_table: str = None,
fields: List = None,
) -> Optional["Model"]:
item: Dict[str, Any] = {}
@ -61,9 +62,11 @@ class Model(NewBaseModel):
previous_table = cls.Meta.table.name
item = cls.populate_nested_models_from_row(
item, row, related_models, previous_table
item, row, related_models, previous_table, fields
)
item = cls.extract_prefixed_table_columns(
item, row, table_prefix, fields, nested=table_prefix != ""
)
item = cls.extract_prefixed_table_columns(item, row, table_prefix)
instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None
return instance
@ -75,33 +78,47 @@ class Model(NewBaseModel):
row: sqlalchemy.engine.ResultProxy,
related_models: Any,
previous_table: sqlalchemy.Table,
fields: List = None,
) -> dict:
for related in related_models:
if isinstance(related_models, dict) and related_models[related]:
first_part, remainder = related, related_models[related]
model_cls = cls.Meta.model_fields[first_part].to
child = model_cls.from_row(
row, related_models=remainder, previous_table=previous_table
row,
related_models=remainder,
previous_table=previous_table,
fields=fields,
)
item[first_part] = child
else:
model_cls = cls.Meta.model_fields[related].to
child = model_cls.from_row(row, previous_table=previous_table)
child = model_cls.from_row(
row, previous_table=previous_table, fields=fields
)
item[related] = child
return item
@classmethod
def extract_prefixed_table_columns( # noqa CCR001
cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str
cls,
item: dict,
row: sqlalchemy.engine.result.ResultProxy,
table_prefix: str,
fields: List = None,
nested: bool = False,
) -> dict:
# databases does not keep aliases in Record for postgres, change to raw row
source = row._row if isinstance(row, Record) else row
selected_columns = cls.own_table_columns(cls, fields or [], nested=nested)
for column in cls.Meta.table.columns:
if column.name not in item:
if column.name not in item and column.name in selected_columns:
prefixed_name = (
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
)
# databases does not keep aliases in Record for postgres
source = row._row if isinstance(row, Record) else row
item[column.name] = source[prefixed_name]
return item

View File

@ -149,3 +149,32 @@ class ModelTableProxy:
cls.merge_two_instances(current_field, getattr(other, field)),
)
return other
@staticmethod
def own_table_columns(
model: Type["Model"], fields: List, nested: bool = False
) -> List[str]:
column_names = [col.name for col in model.Meta.table.columns]
if not fields:
return column_names
if not nested:
columns = [
name for name in fields if "__" not in name and name in column_names
]
else:
model_name = f"{model.get_name()}__"
columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in fields
if f"{model.get_name()}__" in name
]
# if the model is in select and no columns in fields, all implied
if not columns:
columns = column_names
# always has to return pk column
if model.Meta.pkname not in columns:
columns.append(model.Meta.pkname)
return columns

View File

@ -93,7 +93,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set)
# register the related models after initialization
# register the columns models after initialization
for related in self.extract_related_names():
self.Meta.model_fields[related].expand_relationship(
kwargs.get(related), self, to_register=True

View File

@ -24,11 +24,13 @@ class SqlJoin:
select_from: sqlalchemy.sql.select,
order_bys: List[sqlalchemy.sql.elements.TextClause],
columns: List[sqlalchemy.Column],
fields: List,
) -> None:
self.used_aliases = used_aliases
self.select_from = select_from
self.order_bys = order_bys
self.columns = columns
self.fields = fields
@staticmethod
def relation_manager(model_cls: Type["Model"]) -> AliasManager:
@ -105,9 +107,12 @@ class SqlJoin:
self.select_from, target_table, on_clause
)
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}"))
self_related_fields = model_cls.own_table_columns(
model_cls, self.fields, nested=True
)
self.columns.extend(
self.relation_manager(model_cls).prefixed_columns(
alias, model_cls.Meta.table
alias, model_cls.Meta.table, self_related_fields
)
)
self.used_aliases.append(alias)

View File

@ -20,12 +20,14 @@ class Query:
select_related: List,
limit_count: Optional[int],
offset: Optional[int],
fields: Optional[List],
) -> None:
self.query_offset = offset
self.limit_count = limit_count
self._select_related = select_related[:]
self.filter_clauses = filter_clauses[:]
self.exclude_clauses = exclude_clauses[:]
self.fields = fields[:] if fields else []
self.model_cls = model_cls
self.table = self.model_cls.Meta.table
@ -41,7 +43,12 @@ class Query:
return f"{self.table.name}.{self.model_cls.Meta.pkname}"
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self.columns = list(self.table.columns)
self_related_fields = self.model_cls.own_table_columns(
self.model_cls, self.fields
)
self.columns = self.model_cls.Meta.alias_manager.prefixed_columns(
"", self.table, self_related_fields
)
self.order_bys = [text(self.prefixed_pk_name)]
self.select_from = self.table
@ -57,6 +64,7 @@ class Query:
select_from=self.select_from,
columns=self.columns,
order_bys=self.order_bys,
fields=self.fields,
)
(
@ -93,3 +101,4 @@ class Query:
self.columns = []
self.order_bys = []
self.used_aliases = []
self.fields = []

View File

@ -26,6 +26,7 @@ class QuerySet:
select_related: List = None,
limit_count: int = None,
offset: int = None,
columns: List = None,
) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
@ -33,6 +34,7 @@ class QuerySet:
self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count
self.query_offset = offset
self._columns = columns or []
self.order_bys = None
def __get__(
@ -59,7 +61,9 @@ class QuerySet:
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]:
result_rows = [
self.model.from_row(row, select_related=self._select_related)
self.model.from_row(
row, select_related=self._select_related, fields=self._columns
)
for row in rows
]
if result_rows:
@ -104,6 +108,7 @@ class QuerySet:
exclude_clauses=self.exclude_clauses,
offset=self.query_offset,
limit_count=self.limit_count,
fields=self._columns,
)
exp = qry.build_select_expression()
# print(exp.compile(compile_kwargs={"literal_binds": True}))
@ -130,6 +135,7 @@ class QuerySet:
select_related=select_related,
limit_count=self.limit_count,
offset=self.query_offset,
columns=self._columns,
)
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -147,6 +153,22 @@ class QuerySet:
select_related=related,
limit_count=self.limit_count,
offset=self.query_offset,
columns=self._columns,
)
def fields(self, columns: Union[List, str]) -> "QuerySet":
if not isinstance(columns, list):
columns = [columns]
columns = list(set(list(self._columns) + columns))
return self.__class__(
model_cls=self.model,
filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=self.query_offset,
columns=columns,
)
async def exists(self) -> bool:
@ -193,6 +215,7 @@ class QuerySet:
select_related=self._select_related,
limit_count=limit_count,
offset=self.query_offset,
columns=self._columns,
)
def offset(self, offset: int) -> "QuerySet":
@ -203,6 +226,7 @@ class QuerySet:
select_related=self._select_related,
limit_count=self.limit_count,
offset=offset,
columns=self._columns,
)
async def first(self, **kwargs: Any) -> "Model":

View File

@ -17,10 +17,18 @@ class AliasManager:
self._aliases: Dict[str, str] = dict()
@staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
def prefixed_columns(
alias: str, table: sqlalchemy.Table, fields: List = None
) -> List[text]:
alias = f"{alias}_" if alias else ""
all_columns = (
table.columns
if not fields
else [col for col in table.columns if col.name in fields]
)
return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
text(f"{alias}{table.name}.{column.name} as {alias}{column.name}")
for column in all_columns
]
@staticmethod