allow passing a dict and set to fields and exclude_fields, store it as dict
This commit is contained in:
39
ormar/models/excludable.py
Normal file
39
ormar/models/excludable.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import Dict, Set, Union
|
||||
|
||||
|
||||
class Excludable:
|
||||
@staticmethod
|
||||
def get_excluded(
|
||||
exclude: Union[Set, Dict, None], key: str = None
|
||||
) -> Union[Set, Dict, None]:
|
||||
if isinstance(exclude, dict):
|
||||
return exclude.get(key, {})
|
||||
return exclude
|
||||
|
||||
@staticmethod
|
||||
def get_included(
|
||||
include: Union[Set, Dict, None], key: str = None
|
||||
) -> Union[Set, Dict, None]:
|
||||
return Excludable.get_excluded(exclude=include, key=key)
|
||||
|
||||
@staticmethod
|
||||
def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool:
|
||||
if exclude is None:
|
||||
return False
|
||||
to_exclude = Excludable.get_excluded(exclude=exclude, key=key)
|
||||
if isinstance(to_exclude, Set):
|
||||
return key in to_exclude
|
||||
elif to_exclude is ...:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_included(include: Union[Set, Dict, None], key: str = None) -> bool:
|
||||
if include is None:
|
||||
return True
|
||||
to_include = Excludable.get_included(include=include, key=key)
|
||||
if isinstance(to_include, Set):
|
||||
return key in to_include
|
||||
elif to_include is ...:
|
||||
return True
|
||||
return False
|
||||
@ -1,5 +1,5 @@
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
@ -47,8 +47,8 @@ class Model(NewBaseModel):
|
||||
select_related: List = None,
|
||||
related_models: Any = None,
|
||||
previous_table: str = None,
|
||||
fields: List = None,
|
||||
exclude_fields: List = None,
|
||||
fields: Optional[Union[Dict, Set]] = None,
|
||||
exclude_fields: Optional[Union[Dict, Set]] = None,
|
||||
) -> Optional[T]:
|
||||
|
||||
item: Dict[str, Any] = {}
|
||||
@ -88,7 +88,6 @@ class Model(NewBaseModel):
|
||||
table_prefix=table_prefix,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
nested=table_prefix != "",
|
||||
)
|
||||
|
||||
instance: Optional[T] = cls(**item) if item.get(
|
||||
@ -103,13 +102,17 @@ class Model(NewBaseModel):
|
||||
row: sqlalchemy.engine.ResultProxy,
|
||||
related_models: Any,
|
||||
previous_table: sqlalchemy.Table,
|
||||
fields: List = None,
|
||||
exclude_fields: List = None,
|
||||
fields: Optional[Union[Dict, Set]] = None,
|
||||
exclude_fields: Optional[Union[Dict, Set]] = None,
|
||||
) -> dict:
|
||||
for related in related_models:
|
||||
if isinstance(related_models, dict) and related_models[related]:
|
||||
first_part, remainder = related, related_models[related]
|
||||
model_cls = cls.Meta.model_fields[first_part].to
|
||||
|
||||
fields = cls.get_included(fields, first_part)
|
||||
exclude_fields = cls.get_excluded(exclude_fields, first_part)
|
||||
|
||||
child = model_cls.from_row(
|
||||
row,
|
||||
related_models=remainder,
|
||||
@ -120,6 +123,8 @@ class Model(NewBaseModel):
|
||||
item[model_cls.get_column_name_from_alias(first_part)] = child
|
||||
else:
|
||||
model_cls = cls.Meta.model_fields[related].to
|
||||
fields = cls.get_included(fields, related)
|
||||
exclude_fields = cls.get_excluded(exclude_fields, related)
|
||||
child = model_cls.from_row(
|
||||
row,
|
||||
previous_table=previous_table,
|
||||
@ -136,16 +141,18 @@ class Model(NewBaseModel):
|
||||
item: dict,
|
||||
row: sqlalchemy.engine.result.ResultProxy,
|
||||
table_prefix: str,
|
||||
fields: List = None,
|
||||
exclude_fields: List = None,
|
||||
nested: bool = False,
|
||||
fields: Optional[Union[Dict, Set]] = None,
|
||||
exclude_fields: Optional[Union[Dict, Set]] = None,
|
||||
) -> dict:
|
||||
|
||||
# databases does not keep aliases in Record for postgres, change to raw row
|
||||
source = row._row if cls.db_backend_name() == "postgresql" else row
|
||||
|
||||
selected_columns = cls.own_table_columns(
|
||||
cls, fields or [], exclude_fields or [], nested=nested, use_alias=True
|
||||
model=cls,
|
||||
fields=fields or {},
|
||||
exclude_fields=exclude_fields or {},
|
||||
use_alias=False,
|
||||
)
|
||||
|
||||
for column in cls.Meta.table.columns:
|
||||
|
||||
@ -1,6 +1,16 @@
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import RelationshipInstanceError
|
||||
@ -65,14 +75,14 @@ class ModelTableProxy:
|
||||
@classmethod
|
||||
def get_column_alias(cls, field_name: str) -> str:
|
||||
field = cls.Meta.model_fields.get(field_name)
|
||||
if field and field.alias is not None:
|
||||
if field is not None and field.alias is not None:
|
||||
return field.alias
|
||||
return field_name
|
||||
|
||||
@classmethod
|
||||
def get_column_name_from_alias(cls, alias: str) -> str:
|
||||
for field_name, field in cls.Meta.model_fields.items():
|
||||
if field and field.alias == alias:
|
||||
if field is not None and field.alias == alias:
|
||||
return field_name
|
||||
return alias # if not found it's not an alias but actual name
|
||||
|
||||
@ -211,59 +221,13 @@ class ModelTableProxy:
|
||||
)
|
||||
return other
|
||||
|
||||
@staticmethod
|
||||
def _get_not_nested_columns_from_fields(
|
||||
model: Type["Model"],
|
||||
fields: List,
|
||||
exclude_fields: List,
|
||||
column_names: List[str],
|
||||
use_alias: bool = False,
|
||||
) -> List[str]:
|
||||
fields = [model.get_column_alias(k) if not use_alias else k for k in fields]
|
||||
fields = fields or column_names
|
||||
exclude_fields = [
|
||||
model.get_column_alias(k) if not use_alias else k for k in exclude_fields
|
||||
]
|
||||
columns = [
|
||||
name
|
||||
for name in fields
|
||||
if "__" not in name and name in column_names and name not in exclude_fields
|
||||
]
|
||||
return columns
|
||||
|
||||
@staticmethod
|
||||
def _get_nested_columns_from_fields(
|
||||
model: Type["Model"],
|
||||
fields: List,
|
||||
exclude_fields: List,
|
||||
column_names: List[str],
|
||||
use_alias: bool = False,
|
||||
) -> List[str]:
|
||||
model_name = f"{model.get_name()}__"
|
||||
columns = [
|
||||
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
|
||||
for name in fields
|
||||
if f"{model.get_name()}__" in name
|
||||
]
|
||||
columns = columns or column_names
|
||||
exclude_columns = [
|
||||
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
|
||||
for name in exclude_fields
|
||||
if f"{model.get_name()}__" in name
|
||||
]
|
||||
columns = [model.get_column_alias(k) if not use_alias else k for k in columns]
|
||||
exclude_columns = [
|
||||
model.get_column_alias(k) if not use_alias else k for k in exclude_columns
|
||||
]
|
||||
return [column for column in columns if column not in exclude_columns]
|
||||
|
||||
@staticmethod
|
||||
def _populate_pk_column(
|
||||
model: Type["Model"], columns: List[str], use_alias: bool = False,
|
||||
) -> List[str]:
|
||||
pk_alias = (
|
||||
model.get_column_alias(model.Meta.pkname)
|
||||
if not use_alias
|
||||
if use_alias
|
||||
else model.Meta.pkname
|
||||
)
|
||||
if pk_alias not in columns:
|
||||
@ -273,34 +237,30 @@ class ModelTableProxy:
|
||||
@staticmethod
|
||||
def own_table_columns(
|
||||
model: Type["Model"],
|
||||
fields: List,
|
||||
exclude_fields: List,
|
||||
nested: bool = False,
|
||||
fields: Optional[Union[Set, Dict]],
|
||||
exclude_fields: Optional[Union[Set, Dict]],
|
||||
use_alias: bool = False,
|
||||
) -> List[str]:
|
||||
column_names = [
|
||||
model.get_column_name_from_alias(col.name) if use_alias else col.name
|
||||
columns = [
|
||||
model.get_column_name_from_alias(col.name) if not use_alias else col.name
|
||||
for col in model.Meta.table.columns
|
||||
]
|
||||
if not fields and not exclude_fields:
|
||||
return column_names
|
||||
|
||||
if not nested:
|
||||
columns = ModelTableProxy._get_not_nested_columns_from_fields(
|
||||
model=model,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
column_names=column_names,
|
||||
use_alias=use_alias,
|
||||
)
|
||||
else:
|
||||
columns = ModelTableProxy._get_nested_columns_from_fields(
|
||||
model=model,
|
||||
fields=fields,
|
||||
exclude_fields=exclude_fields,
|
||||
column_names=column_names,
|
||||
use_alias=use_alias,
|
||||
)
|
||||
field_names = [
|
||||
model.get_column_name_from_alias(col.name)
|
||||
for col in model.Meta.table.columns
|
||||
]
|
||||
if fields:
|
||||
columns = [
|
||||
col
|
||||
for col, name in zip(columns, field_names)
|
||||
if model.is_included(fields, name)
|
||||
]
|
||||
if exclude_fields:
|
||||
columns = [
|
||||
col
|
||||
for col, name in zip(columns, field_names)
|
||||
if not model.is_excluded(exclude_fields, name)
|
||||
]
|
||||
|
||||
# always has to return pk column
|
||||
columns = ModelTableProxy._populate_pk_column(
|
||||
|
||||
@ -23,6 +23,7 @@ from pydantic import BaseModel
|
||||
import ormar # noqa I100
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
from ormar.models.excludable import Excludable
|
||||
from ormar.models.metaclass import ModelMeta, ModelMetaclass
|
||||
from ormar.models.modelproxy import ModelTableProxy
|
||||
from ormar.relations.alias_manager import AliasManager
|
||||
@ -39,7 +40,9 @@ if TYPE_CHECKING: # pragma no cover
|
||||
MappingIntStrAny = Mapping[IntStr, Any]
|
||||
|
||||
|
||||
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
|
||||
class NewBaseModel(
|
||||
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
|
||||
):
|
||||
__slots__ = ("_orm_id", "_orm_saved", "_orm")
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
|
||||
Reference in New Issue
Block a user