Merge pull request #44 from collerek/allow_dict_in_fields

Allow dict in fields, optimizations and cleanup
This commit is contained in:
collerek
2020-11-16 19:10:53 +07:00
committed by GitHub
27 changed files with 714 additions and 608 deletions

1
.gitignore vendored
View File

@ -9,3 +9,4 @@ test.db
dist dist
/ormar.egg-info/ /ormar.egg-info/
site site
profile.py

View File

@ -163,8 +163,8 @@ assert len(tracks) == 1
* `offset(offset: int) -> QuerySet` * `offset(offset: int) -> QuerySet`
* `count() -> int` * `count() -> int`
* `exists() -> bool` * `exists() -> bool`
* `fields(columns: Union[List, str]) -> QuerySet` * `fields(columns: Union[List, str, set, dict]) -> QuerySet`
* `exclude_fields(columns: Union[List, str]) -> QuerySet` * `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet`
* `order_by(columns:Union[List, str]) -> QuerySet` * `order_by(columns:Union[List, str]) -> QuerySet`
#### Relation types #### Relation types

View File

@ -163,8 +163,8 @@ assert len(tracks) == 1
* `offset(offset: int) -> QuerySet` * `offset(offset: int) -> QuerySet`
* `count() -> int` * `count() -> int`
* `exists() -> bool` * `exists() -> bool`
* `fields(columns: Union[List, str]) -> QuerySet` * `fields(columns: Union[List, str, set, dict]) -> QuerySet`
* `exclude_fields(columns: Union[List, str]) -> QuerySet` * `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet`
* `order_by(columns:Union[List, str]) -> QuerySet` * `order_by(columns:Union[List, str]) -> QuerySet`

View File

@ -348,20 +348,73 @@ has_sample = await Book.objects.filter(title='Sample').exists()
### fields ### fields
`fields(columns: Union[List, str]) -> QuerySet` `fields(columns: Union[List, str, set, dict]) -> QuerySet`
With `fields()` you can select subset of model columns to limit the data load. With `fields()` you can select subset of model columns to limit the data load.
```python hl_lines="47 59 60 66" Given a sample data like following:
```python
--8<-- "../docs_src/queries/docs006.py" --8<-- "../docs_src/queries/docs006.py"
``` ```
You can select specified fields by passing a `str, List[str], Set[str] or dict` with nested definition.
To include related models use notation `{related_name}__{column}[__{optional_next} etc.]`.
```python hl_lines="1"
all_cars = await Car.objects.select_related('manufacturer').fields(['id', 'name', 'manufacturer__name']).all()
for car in all_cars:
# excluded columns will yield None
assert all(getattr(car, x) is None for x in ['year', 'gearbox_type', 'gears', 'aircon_type'])
# included column on related models will be available, pk column is always included
# even if you do not include it in fields list
assert car.manufacturer.name == 'Toyota'
# also in the nested related models - you cannot exclude pk - it's always auto added
assert car.manufacturer.founded is None
```
`fields()` can be called several times, building up the columns to select.
If you include related models into `select_related()` call but you won't specify columns for those models in fields
- implies a list of all fields for those nested models.
```python hl_lines="1"
all_cars = await Car.objects.select_related('manufacturer').fields('id').fields(
['name']).all()
# all fiels from company model are selected
assert all_cars[0].manufacturer.name == 'Toyota'
assert all_cars[0].manufacturer.founded == 1937
```
!!!warning !!!warning
Mandatory fields cannot be excluded as it will raise `ValidationError`, to exclude a field it has to be nullable. Mandatory fields cannot be excluded as it will raise `ValidationError`, to exclude a field it has to be nullable.
You cannot exclude mandatory model columns - `manufacturer__name` in this example.
```python
await Car.objects.select_related('manufacturer').fields(['id', 'name', 'manufacturer__founded']).all()
# will raise pydantic ValidationError as company.name is required
```
!!!tip !!!tip
Pk column cannot be excluded - it's always auto added even if not explicitly included. Pk column cannot be excluded - it's always auto added even if not explicitly included.
You can also pass fields to include as dictionary or set.
To mark a field as included in a dictionary use it's name as key and ellipsis as value.
To traverse nested models use nested dictionaries.
To include fields at last level instead of nested dictionary a set can be used.
To include whole nested model specify model related field name and ellipsis.
Below you can see examples that are equivalent:
```python
--8<-- "../docs_src/queries/docs009.py"
```
!!!note !!!note
All methods that do not return the rows explicitly returns a QueySet instance so you can chain them together All methods that do not return the rows explicitly returns a QueySet instance so you can chain them together
@ -372,11 +425,15 @@ With `fields()` you can select subset of model columns to limit the data load.
### exclude_fields ### exclude_fields
`fields(columns: Union[List, str]) -> QuerySet` `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet`
With `exclude_fields()` you can select subset of model columns that will be excluded to limit the data load. With `exclude_fields()` you can select subset of model columns that will be excluded to limit the data load.
It's the oposite of `fields()` method. It's the opposite of `fields()` method so check documentation above to see what options are available.
Especially check above how you can pass also nested dictionaries and sets as a mask to exclude fields from whole hierarchy.
Below you can find few simple examples:
```python hl_lines="47 48 60 61 67" ```python hl_lines="47 48 60 61 67"
--8<-- "../docs_src/queries/docs008.py" --8<-- "../docs_src/queries/docs008.py"

View File

@ -43,25 +43,3 @@ await Car.objects.create(manufacturer=toyota, name="Yaris", year=2019, gearbox_t
await Car.objects.create(manufacturer=toyota, name="Supreme", year=2020, gearbox_type='Auto', gears=6, await Car.objects.create(manufacturer=toyota, name="Supreme", year=2020, gearbox_type='Auto', gears=6,
aircon_type='Auto') aircon_type='Auto')
# select manufacturer but only name - to include related models use notation {model_name}__{column}
all_cars = await Car.objects.select_related('manufacturer').fields(['id', 'name', 'company__name']).all()
for car in all_cars:
# excluded columns will yield None
assert all(getattr(car, x) is None for x in ['year', 'gearbox_type', 'gears', 'aircon_type'])
# included column on related models will be available, pk column is always included
# even if you do not include it in fields list
assert car.manufacturer.name == 'Toyota'
# also in the nested related models - you cannot exclude pk - it's always auto added
assert car.manufacturer.founded is None
# fields() can be called several times, building up the columns to select
# models selected in select_related but with no columns in fields list implies all fields
all_cars = await Car.objects.select_related('manufacturer').fields('id').fields(
['name']).all()
# all fiels from company model are selected
assert all_cars[0].manufacturer.name == 'Toyota'
assert all_cars[0].manufacturer.founded == 1937
# cannot exclude mandatory model columns - company__name in this example
await Car.objects.select_related('manufacturer').fields(['id', 'name', 'company__founded']).all()
# will raise pydantic ValidationError as company.name is required

View File

@ -63,6 +63,6 @@ all_cars = await Car.objects.select_related('manufacturer').exclude_fields('year
assert all_cars[0].manufacturer.name == 'Toyota' assert all_cars[0].manufacturer.name == 'Toyota'
assert all_cars[0].manufacturer.founded == 1937 assert all_cars[0].manufacturer.founded == 1937
# cannot exclude mandatory model columns - company__name in this example # cannot exclude mandatory model columns - company__name in this example - note usage of dict/set this time
await Car.objects.select_related('manufacturer').exclude_fields(['company__name']).all() await Car.objects.select_related('manufacturer').exclude_fields([{'company': {'name'}}]).all()
# will raise pydantic ValidationError as company.name is required # will raise pydantic ValidationError as company.name is required

View File

@ -0,0 +1,33 @@
# 1. like in example above
await Car.objects.select_related('manufacturer').fields(['id', 'name', 'manufacturer__name']).all()
# 2. to mark a field as required use ellipsis
await Car.objects.select_related('manufacturer').fields({'id': ...,
'name': ...,
'manufacturer': {
'name': ...}
}).all()
# 3. to include whole nested model use ellipsis
await Car.objects.select_related('manufacturer').fields({'id': ...,
'name': ...,
'manufacturer': ...
}).all()
# 4. to specify fields at last nesting level you can also use set - equivalent to 2. above
await Car.objects.select_related('manufacturer').fields({'id': ...,
'name': ...,
'manufacturer': {'name'}
}).all()
# 5. of course set can have multiple fields
await Car.objects.select_related('manufacturer').fields({'id': ...,
'name': ...,
'manufacturer': {'name', 'founded'}
}).all()
# 6. you can include all nested fields but it will be equivalent of 3. above which is shorter
await Car.objects.select_related('manufacturer').fields({'id': ...,
'name': ...,
'manufacturer': {'id', 'name', 'founded'}
}).all()

View File

@ -5,6 +5,7 @@ import sqlalchemy
from pydantic import Field, typing from pydantic import Field, typing
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
import ormar # noqa I101
from ormar import ModelDefinitionError # noqa I101 from ormar import ModelDefinitionError # noqa I101
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -34,6 +35,10 @@ class BaseField(FieldInfo):
default: Any default: Any
server_default: Any server_default: Any
@classmethod
def is_valid_uni_relation(cls) -> bool:
return not issubclass(cls, ormar.fields.ManyToManyField) and not cls.virtual
@classmethod @classmethod
def get_alias(cls) -> str: def get_alias(cls) -> str:
return cls.alias if cls.alias else cls.name return cls.alias if cls.alias else cls.name

View File

@ -0,0 +1,43 @@
from typing import Dict, Set, Union
class Excludable:
@staticmethod
def get_excluded(
exclude: Union[Set, Dict, None], key: str = None
) -> Union[Set, Dict, None]:
if isinstance(exclude, dict):
return exclude.get(key, {})
return exclude
@staticmethod
def get_included(
include: Union[Set, Dict, None], key: str = None
) -> Union[Set, Dict, None]:
return Excludable.get_excluded(exclude=include, key=key)
@staticmethod
def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool:
if exclude is None:
return False
if exclude is Ellipsis: # pragma: nocover
return True
to_exclude = Excludable.get_excluded(exclude=exclude, key=key)
if isinstance(to_exclude, Set):
return key in to_exclude
elif to_exclude is ...:
return True
return False
@staticmethod
def is_included(include: Union[Set, Dict, None], key: str = None) -> bool:
if include is None:
return True
if include is Ellipsis:
return True
to_include = Excludable.get_included(include=include, key=key)
if isinstance(to_include, Set):
return key in to_include
elif to_include is ...:
return True
return False

View File

@ -1,5 +1,5 @@
import itertools import itertools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
import sqlalchemy import sqlalchemy
@ -47,8 +47,8 @@ class Model(NewBaseModel):
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
previous_table: str = None, previous_table: str = None,
fields: List = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: List = None, exclude_fields: Optional[Union[Dict, Set]] = None,
) -> Optional[T]: ) -> Optional[T]:
item: Dict[str, Any] = {} item: Dict[str, Any] = {}
@ -88,7 +88,6 @@ class Model(NewBaseModel):
table_prefix=table_prefix, table_prefix=table_prefix,
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
nested=table_prefix != "",
) )
instance: Optional[T] = cls(**item) if item.get( instance: Optional[T] = cls(**item) if item.get(
@ -103,13 +102,17 @@ class Model(NewBaseModel):
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
related_models: Any, related_models: Any,
previous_table: sqlalchemy.Table, previous_table: sqlalchemy.Table,
fields: List = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: List = None, exclude_fields: Optional[Union[Dict, Set]] = None,
) -> dict: ) -> dict:
for related in related_models: for related in related_models:
if isinstance(related_models, dict) and related_models[related]: if isinstance(related_models, dict) and related_models[related]:
first_part, remainder = related, related_models[related] first_part, remainder = related, related_models[related]
model_cls = cls.Meta.model_fields[first_part].to model_cls = cls.Meta.model_fields[first_part].to
fields = cls.get_included(fields, first_part)
exclude_fields = cls.get_excluded(exclude_fields, first_part)
child = model_cls.from_row( child = model_cls.from_row(
row, row,
related_models=remainder, related_models=remainder,
@ -120,6 +123,8 @@ class Model(NewBaseModel):
item[model_cls.get_column_name_from_alias(first_part)] = child item[model_cls.get_column_name_from_alias(first_part)] = child
else: else:
model_cls = cls.Meta.model_fields[related].to model_cls = cls.Meta.model_fields[related].to
fields = cls.get_included(fields, related)
exclude_fields = cls.get_excluded(exclude_fields, related)
child = model_cls.from_row( child = model_cls.from_row(
row, row,
previous_table=previous_table, previous_table=previous_table,
@ -136,16 +141,18 @@ class Model(NewBaseModel):
item: dict, item: dict,
row: sqlalchemy.engine.result.ResultProxy, row: sqlalchemy.engine.result.ResultProxy,
table_prefix: str, table_prefix: str,
fields: List = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: List = None, exclude_fields: Optional[Union[Dict, Set]] = None,
nested: bool = False,
) -> dict: ) -> dict:
# databases does not keep aliases in Record for postgres, change to raw row # databases does not keep aliases in Record for postgres, change to raw row
source = row._row if cls.db_backend_name() == "postgresql" else row source = row._row if cls.db_backend_name() == "postgresql" else row
selected_columns = cls.own_table_columns( selected_columns = cls.own_table_columns(
cls, fields or [], exclude_fields or [], nested=nested, use_alias=True model=cls,
fields=fields or {},
exclude_fields=exclude_fields or {},
use_alias=False,
) )
for column in cls.Meta.table.columns: for column in cls.Meta.table.columns:

View File

@ -1,12 +1,28 @@
import inspect import inspect
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union from typing import (
Dict,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
import ormar
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
from ormar.fields import BaseField, ManyToManyField
try:
import orjson as json
except ImportError: # pragma: nocover
import json # type: ignore
import ormar # noqa: I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta, expand_reverse_relationships from ormar.models.metaclass import ModelMeta
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -20,6 +36,8 @@ Field = TypeVar("Field", bound=BaseField)
class ModelTableProxy: class ModelTableProxy:
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
Meta: ModelMeta Meta: ModelMeta
_related_names: Set
_related_names_hash: Union[str, bytes]
def dict(self): # noqa A003 def dict(self): # noqa A003
raise NotImplementedError # pragma no cover raise NotImplementedError # pragma no cover
@ -65,50 +83,50 @@ class ModelTableProxy:
@classmethod @classmethod
def get_column_alias(cls, field_name: str) -> str: def get_column_alias(cls, field_name: str) -> str:
field = cls.Meta.model_fields.get(field_name) field = cls.Meta.model_fields.get(field_name)
if field and field.alias is not None: if field is not None and field.alias is not None:
return field.alias return field.alias
return field_name return field_name
@classmethod @classmethod
def get_column_name_from_alias(cls, alias: str) -> str: def get_column_name_from_alias(cls, alias: str) -> str:
for field_name, field in cls.Meta.model_fields.items(): for field_name, field in cls.Meta.model_fields.items():
if field and field.alias == alias: if field is not None and field.alias == alias:
return field_name return field_name
return alias # if not found it's not an alias but actual name return alias # if not found it's not an alias but actual name
@classmethod @classmethod
def extract_related_names(cls) -> Set: def extract_related_names(cls) -> Set:
if isinstance(cls._related_names_hash, (str, bytes)):
return cls._related_names
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField): if inspect.isclass(field) and issubclass(field, ForeignKeyField):
related_names.add(name) related_names.add(name)
cls._related_names_hash = json.dumps(list(cls.Meta.model_fields.keys()))
cls._related_names = related_names
return related_names return related_names
@classmethod @classmethod
def _extract_db_related_names(cls) -> Set: def _extract_db_related_names(cls) -> Set:
related_names = set() related_names = cls.extract_related_names()
for name, field in cls.Meta.model_fields.items(): related_names = {
if ( name
inspect.isclass(field) for name in related_names
and issubclass(field, ForeignKeyField) if cls.Meta.model_fields[name].is_valid_uni_relation()
and not issubclass(field, ManyToManyField) }
and not field.virtual
):
related_names.add(name)
return related_names return related_names
@classmethod @classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested: if nested:
return cls.extract_related_names() return cls.extract_related_names()
related_names = set() related_names = cls.extract_related_names()
for name, field in cls.Meta.model_fields.items(): related_names = {
if ( name for name in related_names if cls.Meta.model_fields[name].nullable
inspect.isclass(field) }
and issubclass(field, ForeignKeyField)
and field.nullable
):
related_names.add(name)
return related_names return related_names
def _extract_model_db_fields(self) -> Dict: def _extract_model_db_fields(self) -> Dict:
@ -128,7 +146,6 @@ class ModelTableProxy:
def resolve_relation_name( # noqa CCR001 def resolve_relation_name( # noqa CCR001
item: Union["NewBaseModel", Type["NewBaseModel"]], item: Union["NewBaseModel", Type["NewBaseModel"]],
related: Union["NewBaseModel", Type["NewBaseModel"]], related: Union["NewBaseModel", Type["NewBaseModel"]],
register_missing: bool = True,
) -> str: ) -> str:
for name, field in item.Meta.model_fields.items(): for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField): if issubclass(field, ForeignKeyField):
@ -137,12 +154,6 @@ class ModelTableProxy:
# so we need to compare Meta too as this one is copied as is # so we need to compare Meta too as this one is copied as is
if field.to == related.__class__ or field.to.Meta == related.Meta: if field.to == related.__class__ or field.to.Meta == related.Meta:
return name return name
# fallback for not registered relation
if register_missing: # pragma nocover
expand_reverse_relationships(related.__class__) # type: ignore
return ModelTableProxy.resolve_relation_name(
item, related, register_missing=False
)
raise ValueError( raise ValueError(
f"No relation between {item.get_name()} and {related.get_name()}" f"No relation between {item.get_name()} and {related.get_name()}"
@ -151,7 +162,7 @@ class ModelTableProxy:
@staticmethod @staticmethod
def resolve_relation_field( def resolve_relation_field(
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
) -> Union[Type[BaseField], Type[ForeignKeyField]]: ) -> Type[BaseField]:
name = ModelTableProxy.resolve_relation_name(item, related) name = ModelTableProxy.resolve_relation_name(item, related)
to_field = item.Meta.model_fields.get(name) to_field = item.Meta.model_fields.get(name)
if not to_field: # pragma no cover if not to_field: # pragma no cover
@ -211,59 +222,13 @@ class ModelTableProxy:
) )
return other return other
@staticmethod
def _get_not_nested_columns_from_fields(
model: Type["Model"],
fields: List,
exclude_fields: List,
column_names: List[str],
use_alias: bool = False,
) -> List[str]:
fields = [model.get_column_alias(k) if not use_alias else k for k in fields]
fields = fields or column_names
exclude_fields = [
model.get_column_alias(k) if not use_alias else k for k in exclude_fields
]
columns = [
name
for name in fields
if "__" not in name and name in column_names and name not in exclude_fields
]
return columns
@staticmethod
def _get_nested_columns_from_fields(
model: Type["Model"],
fields: List,
exclude_fields: List,
column_names: List[str],
use_alias: bool = False,
) -> List[str]:
model_name = f"{model.get_name()}__"
columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in fields
if f"{model.get_name()}__" in name
]
columns = columns or column_names
exclude_columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in exclude_fields
if f"{model.get_name()}__" in name
]
columns = [model.get_column_alias(k) if not use_alias else k for k in columns]
exclude_columns = [
model.get_column_alias(k) if not use_alias else k for k in exclude_columns
]
return [column for column in columns if column not in exclude_columns]
@staticmethod @staticmethod
def _populate_pk_column( def _populate_pk_column(
model: Type["Model"], columns: List[str], use_alias: bool = False, model: Type["Model"], columns: List[str], use_alias: bool = False,
) -> List[str]: ) -> List[str]:
pk_alias = ( pk_alias = (
model.get_column_alias(model.Meta.pkname) model.get_column_alias(model.Meta.pkname)
if not use_alias if use_alias
else model.Meta.pkname else model.Meta.pkname
) )
if pk_alias not in columns: if pk_alias not in columns:
@ -273,34 +238,30 @@ class ModelTableProxy:
@staticmethod @staticmethod
def own_table_columns( def own_table_columns(
model: Type["Model"], model: Type["Model"],
fields: List, fields: Optional[Union[Set, Dict]],
exclude_fields: List, exclude_fields: Optional[Union[Set, Dict]],
nested: bool = False,
use_alias: bool = False, use_alias: bool = False,
) -> List[str]: ) -> List[str]:
column_names = [ columns = [
model.get_column_name_from_alias(col.name) if use_alias else col.name model.get_column_name_from_alias(col.name) if not use_alias else col.name
for col in model.Meta.table.columns for col in model.Meta.table.columns
] ]
if not fields and not exclude_fields: field_names = [
return column_names model.get_column_name_from_alias(col.name)
for col in model.Meta.table.columns
if not nested: ]
columns = ModelTableProxy._get_not_nested_columns_from_fields( if fields:
model=model, columns = [
fields=fields, col
exclude_fields=exclude_fields, for col, name in zip(columns, field_names)
column_names=column_names, if model.is_included(fields, name)
use_alias=use_alias, ]
) if exclude_fields:
else: columns = [
columns = ModelTableProxy._get_nested_columns_from_fields( col
model=model, for col, name in zip(columns, field_names)
fields=fields, if not model.is_excluded(exclude_fields, name)
exclude_fields=exclude_fields, ]
column_names=column_names,
use_alias=use_alias,
)
# always has to return pk column # always has to return pk column
columns = ModelTableProxy._populate_pk_column( columns = ModelTableProxy._populate_pk_column(

View File

@ -9,6 +9,7 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
Set,
TYPE_CHECKING, TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
@ -23,6 +24,7 @@ from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.excludable import Excludable
from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy from ormar.models.modelproxy import ModelTableProxy
from ormar.relations.alias_manager import AliasManager from ormar.relations.alias_manager import AliasManager
@ -39,8 +41,17 @@ if TYPE_CHECKING: # pragma no cover
MappingIntStrAny = Mapping[IntStr, Any] MappingIntStrAny = Mapping[IntStr, Any]
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass): class NewBaseModel(
__slots__ = ("_orm_id", "_orm_saved", "_orm") pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
):
__slots__ = (
"_orm_id",
"_orm_saved",
"_orm",
"_related_names",
"_related_names_hash",
"_props",
)
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, Type[BaseField]] __model_fields__: Dict[str, Type[BaseField]]
@ -53,6 +64,10 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
__database__: databases.Database __database__: databases.Database
_orm_relationship_manager: AliasManager _orm_relationship_manager: AliasManager
_orm: RelationsManager _orm: RelationsManager
_orm_saved: bool
_related_names: Set
_related_names_hash: str
_props: List[str]
Meta: ModelMeta Meta: ModelMeta
# noinspection PyMissingConstructor # noinspection PyMissingConstructor
@ -104,7 +119,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
) )
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
if name in ("_orm_id", "_orm_saved", "_orm"): if name in ("_orm_id", "_orm_saved", "_orm", "_related_names", "_props"):
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
elif name == "pk": elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
@ -123,12 +138,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
super().__setattr__(name, value) super().__setattr__(name, value)
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"): if item in (
"_orm_id",
"_orm_saved",
"_orm",
"__fields__",
"_related_names",
"_props",
):
return object.__getattribute__(self, item) return object.__getattribute__(self, item)
if item != "extract_related_names" and item in self.extract_related_names():
return self._extract_related_model_instead_of_field(item)
if item == "pk": if item == "pk":
return self.__dict__.get(self.Meta.pkname, None) return self.__dict__.get(self.Meta.pkname, None)
if item != "extract_related_names" and item in self.extract_related_names():
return self._extract_related_model_instead_of_field(item)
if item != "__fields__" and item in self.__fields__: if item != "__fields__" and item in self.__fields__:
value = self.__dict__.get(item, None) value = self.__dict__.get(item, None)
value = self._convert_json(item, value, "loads") value = self._convert_json(item, value, "loads")
@ -183,12 +205,16 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
) -> List[str]: ) -> List[str]:
if isinstance(cls._props, list):
props = cls._props
else:
props = [ props = [
prop prop
for prop in dir(cls) for prop in dir(cls)
if isinstance(getattr(cls, prop), property) if isinstance(getattr(cls, prop), property)
and prop not in ("__values__", "__fields__", "fields", "pk_column") and prop not in ("__values__", "__fields__", "fields", "pk_column")
] ]
cls._props = props
if include: if include:
props = [prop for prop in props if prop in include] props = [prop for prop in props if prop in include]
if exclude: if exclude:

View File

@ -1,5 +1,15 @@
from collections import OrderedDict from collections import OrderedDict
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Tuple, Type from typing import (
Dict,
List,
NamedTuple,
Optional,
Set,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
@ -24,8 +34,8 @@ class SqlJoin:
used_aliases: List, used_aliases: List,
select_from: sqlalchemy.sql.select, select_from: sqlalchemy.sql.select,
columns: List[sqlalchemy.Column], columns: List[sqlalchemy.Column],
fields: List, fields: Optional[Union[Set, Dict]],
exclude_fields: List, exclude_fields: Optional[Union[Set, Dict]],
order_columns: Optional[List], order_columns: Optional[List],
sorted_orders: OrderedDict, sorted_orders: OrderedDict,
) -> None: ) -> None:
@ -49,21 +59,60 @@ class SqlJoin:
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}") return text(f"{left_part}={right_part}")
def build_join( @staticmethod
def update_inclusions(
model_cls: Type["Model"],
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
nested_name: str,
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
fields = model_cls.get_included(fields, nested_name)
exclude_fields = model_cls.get_excluded(exclude_fields, nested_name)
return fields, exclude_fields
def build_join( # noqa: CCR001
self, item: str, join_parameters: JoinParameters self, item: str, join_parameters: JoinParameters
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]: ) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
for part in item.split("__"):
fields = self.fields
exclude_fields = self.exclude_fields
for index, part in enumerate(item.split("__")):
if issubclass( if issubclass(
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
): ):
_fields = join_parameters.model_cls.Meta.model_fields _fields = join_parameters.model_cls.Meta.model_fields
new_part = _fields[part].to.get_name() new_part = _fields[part].to.get_name()
self._switch_many_to_many_order_columns(part, new_part) self._switch_many_to_many_order_columns(part, new_part)
if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters( join_parameters = self._build_join_parameters(
part, join_parameters, is_multi=True part=part,
join_params=join_parameters,
is_multi=True,
fields=fields,
exclude_fields=exclude_fields,
) )
part = new_part part = new_part
join_parameters = self._build_join_parameters(part, join_parameters) if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
model_cls=join_parameters.model_cls,
fields=fields,
exclude_fields=exclude_fields,
nested_name=part,
)
join_parameters = self._build_join_parameters(
part=part,
join_params=join_parameters,
fields=fields,
exclude_fields=exclude_fields,
)
return ( return (
self.used_aliases, self.used_aliases,
@ -73,7 +122,12 @@ class SqlJoin:
) )
def _build_join_parameters( def _build_join_parameters(
self, part: str, join_params: JoinParameters, is_multi: bool = False self,
part: str,
join_params: JoinParameters,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
is_multi: bool = False,
) -> JoinParameters: ) -> JoinParameters:
if is_multi: if is_multi:
model_cls = join_params.model_cls.Meta.model_fields[part].through model_cls = join_params.model_cls.Meta.model_fields[part].through
@ -85,20 +139,30 @@ class SqlJoin:
join_params.from_table, to_table join_params.from_table, to_table
) )
if alias not in self.used_aliases: if alias not in self.used_aliases:
self._process_join(join_params, is_multi, model_cls, part, alias) self._process_join(
join_params=join_params,
is_multi=is_multi,
model_cls=model_cls,
part=part,
alias=alias,
fields=fields,
exclude_fields=exclude_fields,
)
previous_alias = alias previous_alias = alias
from_table = to_table from_table = to_table
prev_model = model_cls prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls) return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _process_join( def _process_join( # noqa: CFQ002
self, self,
join_params: JoinParameters, join_params: JoinParameters,
is_multi: bool, is_multi: bool,
model_cls: Type["Model"], model_cls: Type["Model"],
part: str, part: str,
alias: str, alias: str,
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
) -> None: ) -> None:
to_table = model_cls.Meta.table.name to_table = model_cls.Meta.table.name
to_key, from_key = self.get_to_and_from_keys( to_key, from_key = self.get_to_and_from_keys(
@ -129,7 +193,10 @@ class SqlJoin:
) )
self_related_fields = model_cls.own_table_columns( self_related_fields = model_cls.own_table_columns(
model_cls, self.fields, self.exclude_fields, nested=True, model=model_cls,
fields=fields,
exclude_fields=exclude_fields,
use_alias=True,
) )
self.columns.extend( self.columns.extend(
self.relation_manager(model_cls).prefixed_columns( self.relation_manager(model_cls).prefixed_columns(

View File

@ -1,5 +1,6 @@
import copy
from collections import OrderedDict from collections import OrderedDict
from typing import List, Optional, TYPE_CHECKING, Tuple, Type from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
@ -21,8 +22,8 @@ class Query:
select_related: List, select_related: List,
limit_count: Optional[int], limit_count: Optional[int],
offset: Optional[int], offset: Optional[int],
fields: Optional[List], fields: Optional[Union[Dict, Set]],
exclude_fields: Optional[List], exclude_fields: Optional[Union[Dict, Set]],
order_bys: Optional[List], order_bys: Optional[List],
) -> None: ) -> None:
self.query_offset = offset self.query_offset = offset
@ -30,8 +31,8 @@ class Query:
self._select_related = select_related[:] self._select_related = select_related[:]
self.filter_clauses = filter_clauses[:] self.filter_clauses = filter_clauses[:]
self.exclude_clauses = exclude_clauses[:] self.exclude_clauses = exclude_clauses[:]
self.fields = fields[:] if fields else [] self.fields = copy.deepcopy(fields) if fields else {}
self.exclude_fields = exclude_fields[:] if exclude_fields else [] self.exclude_fields = copy.deepcopy(exclude_fields) if exclude_fields else {}
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.Meta.table self.table = self.model_cls.Meta.table
@ -73,7 +74,10 @@ class Query:
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self_related_fields = self.model_cls.own_table_columns( self_related_fields = self.model_cls.own_table_columns(
self.model_cls, self.fields, self.exclude_fields model=self.model_cls,
fields=self.fields,
exclude_fields=self.exclude_fields,
use_alias=True,
) )
self.columns = self.model_cls.Meta.alias_manager.prefixed_columns( self.columns = self.model_cls.Meta.alias_manager.prefixed_columns(
"", self.table, self_related_fields "", self.table, self_related_fields
@ -87,13 +91,14 @@ class Query:
join_parameters = JoinParameters( join_parameters = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls self.model_cls, "", self.table.name, self.model_cls
) )
fields = self.model_cls.get_included(self.fields, item)
exclude_fields = self.model_cls.get_excluded(self.exclude_fields, item)
sql_join = SqlJoin( sql_join = SqlJoin(
used_aliases=self.used_aliases, used_aliases=self.used_aliases,
select_from=self.select_from, select_from=self.select_from,
columns=self.columns, columns=self.columns,
fields=self.fields, fields=fields,
exclude_fields=self.exclude_fields, exclude_fields=exclude_fields,
order_columns=self.order_columns, order_columns=self.order_columns,
sorted_orders=self.sorted_orders, sorted_orders=self.sorted_orders,
) )
@ -131,5 +136,5 @@ class Query:
self.select_from = [] self.select_from = []
self.columns = [] self.columns = []
self.used_aliases = [] self.used_aliases = []
self.fields = [] self.fields = {}
self.exclude_fields = [] self.exclude_fields = {}

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union
import databases import databases
import sqlalchemy import sqlalchemy
@ -10,6 +10,7 @@ from ormar.exceptions import QueryDefinitionError
from ormar.queryset import FilterQuery from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.utils import update, update_dict_from_list
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -26,8 +27,8 @@ class QuerySet:
select_related: List = None, select_related: List = None,
limit_count: int = None, limit_count: int = None,
offset: int = None, offset: int = None,
columns: List = None, columns: Dict = None,
exclude_columns: List = None, exclude_columns: Dict = None,
order_bys: List = None, order_bys: List = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
@ -36,8 +37,8 @@ class QuerySet:
self._select_related = [] if select_related is None else select_related self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count self.limit_count = limit_count
self.query_offset = offset self.query_offset = offset
self._columns = columns or [] self._columns = columns or {}
self._exclude_columns = exclude_columns or [] self._exclude_columns = exclude_columns or {}
self.order_bys = order_bys or [] self.order_bys = order_bys or []
def __get__( def __get__(
@ -169,11 +170,16 @@ class QuerySet:
order_bys=self.order_bys, order_bys=self.order_bys,
) )
def exclude_fields(self, columns: Union[List, str]) -> "QuerySet": def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
if not isinstance(columns, list): if isinstance(columns, str):
columns = [columns] columns = [columns]
columns = list(set(list(self._exclude_columns) + columns)) current_excluded = self._exclude_columns
if not isinstance(columns, dict):
current_excluded = update_dict_from_list(current_excluded, columns)
else:
current_excluded = update(current_excluded, columns)
return self.__class__( return self.__class__(
model_cls=self.model, model_cls=self.model,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
@ -182,15 +188,20 @@ class QuerySet:
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
columns=self._columns, columns=self._columns,
exclude_columns=columns, exclude_columns=current_excluded,
order_bys=self.order_bys, order_bys=self.order_bys,
) )
def fields(self, columns: Union[List, str]) -> "QuerySet": def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
if not isinstance(columns, list): if isinstance(columns, str):
columns = [columns] columns = [columns]
columns = list(set(list(self._columns) + columns)) current_included = self._columns
if not isinstance(columns, dict):
current_included = update_dict_from_list(current_included, columns)
else:
current_included = update(current_included, columns)
return self.__class__( return self.__class__(
model_cls=self.model, model_cls=self.model,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
@ -198,7 +209,7 @@ class QuerySet:
select_related=self._select_related, select_related=self._select_related,
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
columns=columns, columns=current_included,
exclude_columns=self._exclude_columns, exclude_columns=self._exclude_columns,
order_bys=self.order_bys, order_bys=self.order_bys,
) )

57
ormar/queryset/utils.py Normal file
View File

@ -0,0 +1,57 @@
import collections.abc
import copy
from typing import Any, Dict, List, Set, Union
def check_node_not_dict_or_not_last_node(
part: str, parts: List, current_level: Any
) -> bool:
return (part not in current_level and part != parts[-1]) or (
part in current_level and not isinstance(current_level[part], dict)
)
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001
new_dict: Dict = dict()
for path in list_to_trans:
current_level = new_dict
parts = path.split("__")
for part in parts:
if check_node_not_dict_or_not_last_node(
part=part, parts=parts, current_level=current_level
):
current_level[part] = dict()
elif part not in current_level:
current_level[part] = ...
current_level = current_level[part]
return new_dict
def convert_set_to_required_dict(set_to_convert: set) -> Dict:
new_dict = dict()
for key in set_to_convert:
new_dict[key] = Ellipsis
return new_dict
def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
if current_dict is Ellipsis:
current_dict = dict()
for key, value in updating_dict.items():
if isinstance(value, collections.abc.Mapping):
old_key = current_dict.get(key, {})
if isinstance(old_key, set):
old_key = convert_set_to_required_dict(old_key)
current_dict[key] = update(old_key, value)
elif isinstance(value, set) and isinstance(current_dict.get(key), set):
current_dict[key] = current_dict.get(key).union(value)
else:
current_dict[key] = value
return current_dict
def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict:
updated_dict = copy.copy(curr_dict)
dict_to_update = translate_list_to_dict(list_to_update)
update(updated_dict, dict_to_update)
return updated_dict

View File

@ -1,16 +1,12 @@
from ormar.relations.alias_manager import AliasManager from ormar.relations.alias_manager import AliasManager
from ormar.relations.relation import Relation, RelationType from ormar.relations.relation import Relation, RelationType
from ormar.relations.relation_manager import RelationsManager from ormar.relations.relation_manager import RelationsManager
from ormar.relations.utils import ( from ormar.relations.utils import get_relations_sides_and_names
get_relations_sides_and_names,
register_missing_relation,
)
__all__ = [ __all__ = [
"AliasManager", "AliasManager",
"Relation", "Relation",
"RelationsManager", "RelationsManager",
"RelationType", "RelationType",
"register_missing_relation",
"get_relations_sides_and_names", "get_relations_sides_and_names",
] ]

View File

@ -5,10 +5,7 @@ from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.relation import Relation, RelationType from ormar.relations.relation import Relation, RelationType
from ormar.relations.utils import ( from ormar.relations.utils import get_relations_sides_and_names
get_relations_sides_and_names,
register_missing_relation,
)
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -42,8 +39,6 @@ class RelationsManager:
to=field.to, to=field.to,
through=getattr(field, "through", None), through=getattr(field, "through", None),
) )
if field.name not in self._related_names:
self._related_names.append(field.name)
def __contains__(self, item: str) -> bool: def __contains__(self, item: str) -> bool:
return item in self._related_names return item in self._related_names
@ -69,8 +64,9 @@ class RelationsManager:
) )
parent_relation = parent._orm._get(child_name) parent_relation = parent._orm._get(child_name)
if not parent_relation: if parent_relation:
parent_relation = register_missing_relation(parent, child, child_name) # print('missing', child_name)
# parent_relation = register_missing_relation(parent, child, child_name)
parent_relation.add(child) # type: ignore parent_relation.add(child) # type: ignore
child_relation = child._orm._get(to_name) child_relation = child._orm._get(to_name)

View File

@ -72,6 +72,4 @@ class RelationProxy(list):
if self.relation._type == ormar.RelationType.MULTIPLE: if self.relation._type == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item) await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner) rel_name = item.resolve_relation_name(item, self._owner)
if rel_name not in item._orm: # pragma nocover
item._orm._add_relation(item.Meta.model_fields[rel_name])
setattr(item, rel_name, self._owner) setattr(item, rel_name, self._owner)

View File

@ -1,26 +1,13 @@
from typing import Optional, TYPE_CHECKING, Tuple, Type from typing import TYPE_CHECKING, Tuple, Type
from weakref import proxy from weakref import proxy
import ormar
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.relations import Relation
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
def register_missing_relation(
parent: "Model", child: "Model", child_name: str
) -> Optional[Relation]:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
parent._orm._add_relation(field)
parent_relation = parent._orm._get(child_name)
return parent_relation
def get_relations_sides_and_names( def get_relations_sides_and_names(
to_field: Type[BaseField], to_field: Type[BaseField],
parent: "Model", parent: "Model",

View File

@ -4,6 +4,7 @@ databases[mysql]
pydantic pydantic
sqlalchemy sqlalchemy
typing_extensions typing_extensions
orjson
# Async database drivers # Async database drivers
aiomysql aiomysql
@ -34,3 +35,6 @@ flake8-variables-names
flake8-cognitive-complexity flake8-cognitive-complexity
flake8-functions flake8-functions
flake8-expression-complexity flake8-expression-complexity
# Performance testing
yappi

View File

@ -42,7 +42,7 @@ setup(
version=get_version(PACKAGE), version=get_version(PACKAGE),
url=URL, url=URL,
license="MIT", license="MIT",
description="An simple async ORM with fastapi in mind and pydantic validation.", description="A simple async ORM with fastapi in mind and pydantic validation.",
long_description=get_long_description(), long_description=get_long_description(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
keywords=['orm', 'sqlalchemy', 'fastapi', 'pydantic', 'databases', 'async', 'alembic'], keywords=['orm', 'sqlalchemy', 'fastapi', 'pydantic', 'databases', 'async', 'alembic'],
@ -56,6 +56,7 @@ setup(
"postgresql": ["asyncpg", "psycopg2"], "postgresql": ["asyncpg", "psycopg2"],
"mysql": ["aiomysql", "pymysql"], "mysql": ["aiomysql", "pymysql"],
"sqlite": ["aiosqlite"], "sqlite": ["aiosqlite"],
"orjson": ["orjson"]
}, },
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",

View File

@ -117,8 +117,8 @@ async def test_working_with_aliases():
"first_name", "first_name",
"last_name", "last_name",
"born_year", "born_year",
"child__first_name", "children__first_name",
"child__last_name", "children__last_name",
] ]
) )
.get() .get()

View File

@ -80,10 +80,53 @@ async def test_selecting_subset():
all_cars = ( all_cars = (
await Car.objects.select_related("manufacturer") await Car.objects.select_related("manufacturer")
.exclude_fields( .exclude_fields(
["gearbox_type", "gears", "aircon_type", "year", "company__founded"] [
"gearbox_type",
"gears",
"aircon_type",
"year",
"manufacturer__founded",
]
) )
.all() .all()
) )
for car in all_cars:
assert all(
getattr(car, x) is None
for x in ["year", "gearbox_type", "gears", "aircon_type"]
)
assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded is None
all_cars = (
await Car.objects.select_related("manufacturer")
.exclude_fields(
{
"gearbox_type": ...,
"gears": ...,
"aircon_type": ...,
"year": ...,
"manufacturer": {"founded": ...},
}
)
.all()
)
all_cars2 = (
await Car.objects.select_related("manufacturer")
.exclude_fields(
{
"gearbox_type": ...,
"gears": ...,
"aircon_type": ...,
"year": ...,
"manufacturer": {"founded"},
}
)
.all()
)
assert all_cars == all_cars2
for car in all_cars: for car in all_cars:
assert all( assert all(
getattr(car, x) is None getattr(car, x) is None
@ -119,7 +162,7 @@ async def test_selecting_subset():
all_cars_check2 = ( all_cars_check2 = (
await Car.objects.select_related("manufacturer") await Car.objects.select_related("manufacturer")
.fields(["id", "name", "manufacturer"]) .fields(["id", "name", "manufacturer"])
.exclude_fields("company__founded") .exclude_fields("manufacturer__founded")
.all() .all()
) )
for car in all_cars_check2: for car in all_cars_check2:
@ -133,5 +176,5 @@ async def test_selecting_subset():
with pytest.raises(pydantic.error_wrappers.ValidationError): with pytest.raises(pydantic.error_wrappers.ValidationError):
# cannot exclude mandatory model columns - company__name in this example # cannot exclude mandatory model columns - company__name in this example
await Car.objects.select_related("manufacturer").exclude_fields( await Car.objects.select_related("manufacturer").exclude_fields(
["company__name"] ["manufacturer__name"]
).all() ).all()

View File

@ -1,374 +0,0 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Track(ormar.Model):
class Meta:
tablename = "tracks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100)
position: int = ormar.Integer()
class Cover(ormar.Model):
class Meta:
tablename = "covers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures")
title: str = ormar.String(max_length=100)
class Organisation(ormar.Model):
class Meta:
tablename = "org"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
class Team(ormar.Model):
class Meta:
tablename = "teams"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
org: Optional[Organisation] = ormar.ForeignKey(Organisation)
name: str = ormar.String(max_length=100)
class Member(ormar.Model):
class Meta:
tablename = "members"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
team: Optional[Team] = ormar.ForeignKey(Team)
email: str = ormar.String(max_length=100)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_wrong_query_foreign_key_type():
async with database:
with pytest.raises(RelationshipInstanceError):
Track(title="The Error", album="wrong_pk_type")
@pytest.mark.asyncio
async def test_setting_explicitly_empty_relation():
async with database:
track = Track(album=None, title="The Bird", position=1)
assert track.album is None
@pytest.mark.asyncio
async def test_related_name():
async with database:
async with database.transaction(force_rollback=True):
album = await Album.objects.create(name="Vanilla")
await Cover.objects.create(album=album, title="The cover file")
assert len(album.cover_pictures) == 1
@pytest.mark.asyncio
async def test_model_crud():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Jamaica")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk
assert isinstance(track.album, ormar.Model)
assert track.album.name is None
await track.album.load()
assert track.album.name == "Jamaica"
assert len(album.tracks) == 3
assert album.tracks[1].title == "Heart don't stand a chance"
album1 = await Album.objects.get(name="Jamaica")
assert album1.pk == album.pk
assert album1.tracks == []
await Track.objects.create(
album={"id": track.album.pk}, title="The Bird2", position=4
)
@pytest.mark.asyncio
async def test_select_related():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Malibu")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
fantasies = Album(name="Fantasies")
await fantasies.save()
track4 = Track(album=fantasies, title="Help I'm Alive", position=1)
track5 = Track(album=fantasies, title="Sick Muse", position=2)
track6 = Track(album=fantasies, title="Satellite Mind", position=3)
await track4.save()
await track5.save()
await track6.save()
track = await Track.objects.select_related("album").get(title="The Bird")
assert track.album.name == "Malibu"
tracks = await Track.objects.select_related("album").all()
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_model_removal_from_relations():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Chichi")
await album.save()
track1 = Track(album=album, title="The Birdman", position=1)
track2 = Track(album=album, title="Superman", position=2)
track3 = Track(album=album, title="Wonder Woman", position=3)
await track1.save()
await track2.save()
await track3.save()
assert len(album.tracks) == 3
await album.tracks.remove(track1)
assert len(album.tracks) == 2
assert track1.album is None
await track1.update()
track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None
await album.tracks.add(track1)
assert len(album.tracks) == 3
assert track1.album == album
await track1.update()
track1 = await Track.objects.select_related("album__tracks").get(
title="The Birdman"
)
album = await Album.objects.select_related("tracks").get(name="Chichi")
assert track1.album == album
track1.remove(album)
assert track1.album is None
assert len(album.tracks) == 2
track2.remove(album)
assert track2.album is None
assert len(album.tracks) == 1
@pytest.mark.asyncio
async def test_fk_filter():
async with database:
async with database.transaction(force_rollback=True):
malibu = Album(name="Malibu%")
await malibu.save()
await Track.objects.create(album=malibu, title="The Bird", position=1)
await Track.objects.create(
album=malibu, title="Heart don't stand a chance", position=2
)
await Track.objects.create(album=malibu, title="The Waters", position=3)
fantasies = await Album.objects.create(name="Fantasies")
await Track.objects.create(
album=fantasies, title="Help I'm Alive", position=1
)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(
album=fantasies, title="Satellite Mind", position=3
)
tracks = (
await Track.objects.select_related("album")
.filter(album__name="Fantasies")
.all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = (
await Track.objects.select_related("album")
.filter(album__name__icontains="fan")
.all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Fan").all()
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
assert len(tracks) == 3
tracks = (
await Track.objects.filter(album=malibu).select_related("album").all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
tracks = await Track.objects.select_related("album").all(album=malibu)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
@pytest.mark.asyncio
async def test_multiple_fk():
async with database:
async with database.transaction(force_rollback=True):
acme = await Organisation.objects.create(ident="ACME Ltd")
red_team = await Team.objects.create(org=acme, name="Red Team")
blue_team = await Team.objects.create(org=acme, name="Blue Team")
await Member.objects.create(team=red_team, email="a@example.org")
await Member.objects.create(team=red_team, email="b@example.org")
await Member.objects.create(team=blue_team, email="c@example.org")
await Member.objects.create(team=blue_team, email="d@example.org")
other = await Organisation.objects.create(ident="Other ltd")
team = await Team.objects.create(org=other, name="Green Team")
await Member.objects.create(team=team, email="e@example.org")
members = (
await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd")
.all()
)
assert len(members) == 4
for member in members:
assert member.team.org.ident == "ACME Ltd"
@pytest.mark.asyncio
async def test_wrong_choices():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(ValueError):
await Organisation.objects.create(ident="Test 1")
@pytest.mark.asyncio
async def test_pk_filter():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test")
track = await Track.objects.create(
album=fantasies, title="Test1", position=1
)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
tracks = (
await Track.objects.select_related("album").filter(pk=track.pk).all()
)
assert len(tracks) == 1
tracks = (
await Track.objects.select_related("album")
.filter(position=2, album__name="Test")
.all()
)
assert len(tracks) == 1
@pytest.mark.asyncio
async def test_limit_and_offset():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Limitless")
await Track.objects.create(
id=None, album=fantasies, title="Sample", position=1
)
await Track.objects.create(album=fantasies, title="Sample2", position=2)
await Track.objects.create(album=fantasies, title="Sample3", position=3)
tracks = await Track.objects.limit(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample"
tracks = await Track.objects.limit(1).offset(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample2"
@pytest.mark.asyncio
async def test_get_exceptions():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test")
with pytest.raises(NoMatch):
await Album.objects.get(name="Test2")
await Track.objects.create(album=fantasies, title="Test1", position=1)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
with pytest.raises(MultipleMatches):
await Track.objects.select_related("album").get(album=fantasies)
@pytest.mark.asyncio
async def test_wrong_model_passed_as_fk():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(RelationshipInstanceError):
org = await Organisation.objects.create(ident="ACME Ltd")
await Track.objects.create(album=org, title="Test1", position=1)

View File

@ -0,0 +1,98 @@
from ormar.models.excludable import Excludable
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
def test_empty_excludable():
assert Excludable.is_included(None, "key") # all fields included if empty
assert not Excludable.is_excluded(None, "key") # none field excluded if empty
def test_list_to_dict_translation():
tet_list = ["aa", "bb", "cc__aa", "cc__bb", "cc__aa__xx", "cc__aa__yy"]
test = translate_list_to_dict(tet_list)
assert test == {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis},
}
def test_updating_dict_with_list():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis},
}
list_to_update = ["ee", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"]
test = update_dict_from_list(curr_dict, list_to_update)
assert test == {
"aa": Ellipsis,
"bb": {"cc": Ellipsis},
"cc": {
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
"bb": Ellipsis,
},
"ee": Ellipsis,
}
def test_updating_dict_inc_set_with_list():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
}
list_to_update = ["uu", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"]
test = update_dict_from_list(curr_dict, list_to_update)
assert test == {
"aa": Ellipsis,
"bb": {"cc": Ellipsis},
"cc": {
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
"bb": Ellipsis,
},
"uu": Ellipsis,
}
def test_updating_dict_inc_set_with_dict():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
}
dict_to_update = {
"uu": Ellipsis,
"bb": {"cc", "dd"},
"cc": {"aa": {"xx": {"oo": Ellipsis}, "oo": Ellipsis}},
}
test = update(curr_dict, dict_to_update)
assert test == {
"aa": Ellipsis,
"bb": {"cc", "dd"},
"cc": {
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
"bb": Ellipsis,
},
"uu": Ellipsis,
}
def test_updating_dict_inc_set_with_dict_inc_set():
curr_dict = {
"aa": Ellipsis,
"bb": Ellipsis,
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
}
dict_to_update = {
"uu": Ellipsis,
"bb": {"cc", "dd"},
"cc": {"aa": {"xx", "oo", "zz", "ii"}},
}
test = update(curr_dict, dict_to_update)
assert test == {
"aa": Ellipsis,
"bb": {"cc", "dd"},
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
"uu": Ellipsis,
}

View File

@ -1,4 +1,5 @@
from typing import Optional import itertools
from typing import Optional, List
import databases import databases
import pydantic import pydantic
@ -12,6 +13,35 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
class NickNames(ormar.Model):
class Meta:
tablename = "nicks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
is_lame: bool = ormar.Boolean(nullable=True)
class NicksHq(ormar.Model):
class Meta:
tablename = "nicks_x_hq"
metadata = metadata
database = database
class HQ(ormar.Model):
class Meta:
tablename = "hqs"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq)
class Company(ormar.Model): class Company(ormar.Model):
class Meta: class Meta:
tablename = "companies" tablename = "companies"
@ -21,6 +51,7 @@ class Company(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="company_name") name: str = ormar.String(max_length=100, nullable=False, name="company_name")
founded: int = ormar.Integer(nullable=True) founded: int = ormar.Integer(nullable=True)
hq: HQ = ormar.ForeignKey(HQ)
class Car(ormar.Model): class Car(ormar.Model):
@ -51,7 +82,14 @@ def create_test_database():
async def test_selecting_subset(): async def test_selecting_subset():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
toyota = await Company.objects.create(name="Toyota", founded=1937) nick1 = await NickNames.objects.create(name="Nippon", is_lame=False)
nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True)
hq = await HQ.objects.create(name="Japan")
await hq.nicks.add(nick1)
await hq.nicks.add(nick2)
toyota = await Company.objects.create(name="Toyota", founded=1937, hq=hq)
await Car.objects.create( await Car.objects.create(
manufacturer=toyota, manufacturer=toyota,
name="Corolla", name="Corolla",
@ -78,17 +116,66 @@ async def test_selecting_subset():
) )
all_cars = ( all_cars = (
await Car.objects.select_related("manufacturer") await Car.objects.select_related(
.fields(["id", "name", "company__name"]) ["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
)
.fields(
[
"id",
"name",
"manufacturer__name",
"manufacturer__hq__name",
"manufacturer__hq__nicks__name",
]
)
.all() .all()
) )
for car in all_cars:
all_cars2 = (
await Car.objects.select_related(
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
)
.fields(
{
"id": ...,
"name": ...,
"manufacturer": {
"name": ...,
"hq": {"name": ..., "nicks": {"name": ...}},
},
}
)
.all()
)
all_cars3 = (
await Car.objects.select_related(
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
)
.fields(
{
"id": ...,
"name": ...,
"manufacturer": {
"name": ...,
"hq": {"name": ..., "nicks": {"name"}},
},
}
)
.all()
)
assert all_cars3 == all_cars
for car in itertools.chain(all_cars, all_cars2):
assert all( assert all(
getattr(car, x) is None getattr(car, x) is None
for x in ["year", "gearbox_type", "gears", "aircon_type"] for x in ["year", "gearbox_type", "gears", "aircon_type"]
) )
assert car.manufacturer.name == "Toyota" assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded is None assert car.manufacturer.founded is None
assert car.manufacturer.hq.name == "Japan"
assert len(car.manufacturer.hq.nicks) == 2
assert car.manufacturer.hq.nicks[0].is_lame is None
all_cars = ( all_cars = (
await Car.objects.select_related("manufacturer") await Car.objects.select_related("manufacturer")
@ -103,9 +190,16 @@ async def test_selecting_subset():
) )
assert car.manufacturer.name == "Toyota" assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded == 1937 assert car.manufacturer.founded == 1937
assert car.manufacturer.hq.name is None
all_cars_check = await Car.objects.select_related("manufacturer").all() all_cars_check = await Car.objects.select_related("manufacturer").all()
for car in all_cars_check: all_cars_with_whole_nested = (
await Car.objects.select_related("manufacturer")
.fields(["id", "name", "year", "gearbox_type", "gears", "aircon_type"])
.fields({"manufacturer": ...})
.all()
)
for car in itertools.chain(all_cars_check, all_cars_with_whole_nested):
assert all( assert all(
getattr(car, x) is not None getattr(car, x) is not None
for x in ["year", "gearbox_type", "gears", "aircon_type"] for x in ["year", "gearbox_type", "gears", "aircon_type"]
@ -113,8 +207,20 @@ async def test_selecting_subset():
assert car.manufacturer.name == "Toyota" assert car.manufacturer.name == "Toyota"
assert car.manufacturer.founded == 1937 assert car.manufacturer.founded == 1937
all_cars_dummy = (
await Car.objects.select_related("manufacturer")
.fields(["id", "name", "year", "gearbox_type", "gears", "aircon_type"])
.fields({"manufacturer": ...})
.exclude_fields({"manufacturer": ...})
.fields({"manufacturer": {"name"}})
.exclude_fields({"manufacturer__founded"})
.all()
)
assert all_cars_dummy[0].manufacturer.founded is None
with pytest.raises(pydantic.error_wrappers.ValidationError): with pytest.raises(pydantic.error_wrappers.ValidationError):
# cannot exclude mandatory model columns - company__name in this example # cannot exclude mandatory model columns - company__name in this example
await Car.objects.select_related("manufacturer").fields( await Car.objects.select_related("manufacturer").fields(
["id", "name", "company__founded"] ["id", "name", "manufacturer__founded"]
).all() ).all()