Merge pull request #44 from collerek/allow_dict_in_fields
Allow dict in fields, optimizations and cleanup
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,3 +9,4 @@ test.db
|
|||||||
dist
|
dist
|
||||||
/ormar.egg-info/
|
/ormar.egg-info/
|
||||||
site
|
site
|
||||||
|
profile.py
|
||||||
|
|||||||
@ -163,8 +163,8 @@ assert len(tracks) == 1
|
|||||||
* `offset(offset: int) -> QuerySet`
|
* `offset(offset: int) -> QuerySet`
|
||||||
* `count() -> int`
|
* `count() -> int`
|
||||||
* `exists() -> bool`
|
* `exists() -> bool`
|
||||||
* `fields(columns: Union[List, str]) -> QuerySet`
|
* `fields(columns: Union[List, str, set, dict]) -> QuerySet`
|
||||||
* `exclude_fields(columns: Union[List, str]) -> QuerySet`
|
* `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet`
|
||||||
* `order_by(columns:Union[List, str]) -> QuerySet`
|
* `order_by(columns:Union[List, str]) -> QuerySet`
|
||||||
|
|
||||||
#### Relation types
|
#### Relation types
|
||||||
|
|||||||
@ -163,8 +163,8 @@ assert len(tracks) == 1
|
|||||||
* `offset(offset: int) -> QuerySet`
|
* `offset(offset: int) -> QuerySet`
|
||||||
* `count() -> int`
|
* `count() -> int`
|
||||||
* `exists() -> bool`
|
* `exists() -> bool`
|
||||||
* `fields(columns: Union[List, str]) -> QuerySet`
|
* `fields(columns: Union[List, str, set, dict]) -> QuerySet`
|
||||||
* `exclude_fields(columns: Union[List, str]) -> QuerySet`
|
* `exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet`
|
||||||
* `order_by(columns:Union[List, str]) -> QuerySet`
|
* `order_by(columns:Union[List, str]) -> QuerySet`
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -348,20 +348,73 @@ has_sample = await Book.objects.filter(title='Sample').exists()
|
|||||||
|
|
||||||
### fields
|
### fields
|
||||||
|
|
||||||
`fields(columns: Union[List, str]) -> QuerySet`
|
`fields(columns: Union[List, str, set, dict]) -> QuerySet`
|
||||||
|
|
||||||
With `fields()` you can select subset of model columns to limit the data load.
|
With `fields()` you can select subset of model columns to limit the data load.
|
||||||
|
|
||||||
```python hl_lines="47 59 60 66"
|
Given a sample data like following:
|
||||||
|
|
||||||
|
```python
|
||||||
--8<-- "../docs_src/queries/docs006.py"
|
--8<-- "../docs_src/queries/docs006.py"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can select specified fields by passing a `str, List[str], Set[str] or dict` with nested definition.
|
||||||
|
|
||||||
|
To include related models use notation `{related_name}__{column}[__{optional_next} etc.]`.
|
||||||
|
|
||||||
|
```python hl_lines="1"
|
||||||
|
all_cars = await Car.objects.select_related('manufacturer').fields(['id', 'name', 'manufacturer__name']).all()
|
||||||
|
for car in all_cars:
|
||||||
|
# excluded columns will yield None
|
||||||
|
assert all(getattr(car, x) is None for x in ['year', 'gearbox_type', 'gears', 'aircon_type'])
|
||||||
|
# included column on related models will be available, pk column is always included
|
||||||
|
# even if you do not include it in fields list
|
||||||
|
assert car.manufacturer.name == 'Toyota'
|
||||||
|
# also in the nested related models - you cannot exclude pk - it's always auto added
|
||||||
|
assert car.manufacturer.founded is None
|
||||||
|
```
|
||||||
|
|
||||||
|
`fields()` can be called several times, building up the columns to select.
|
||||||
|
|
||||||
|
If you include related models into `select_related()` call but you won't specify columns for those models in fields
|
||||||
|
- implies a list of all fields for those nested models.
|
||||||
|
|
||||||
|
```python hl_lines="1"
|
||||||
|
all_cars = await Car.objects.select_related('manufacturer').fields('id').fields(
|
||||||
|
['name']).all()
|
||||||
|
# all fiels from company model are selected
|
||||||
|
assert all_cars[0].manufacturer.name == 'Toyota'
|
||||||
|
assert all_cars[0].manufacturer.founded == 1937
|
||||||
|
```
|
||||||
|
|
||||||
!!!warning
|
!!!warning
|
||||||
Mandatory fields cannot be excluded as it will raise `ValidationError`, to exclude a field it has to be nullable.
|
Mandatory fields cannot be excluded as it will raise `ValidationError`, to exclude a field it has to be nullable.
|
||||||
|
|
||||||
|
You cannot exclude mandatory model columns - `manufacturer__name` in this example.
|
||||||
|
|
||||||
|
```python
|
||||||
|
await Car.objects.select_related('manufacturer').fields(['id', 'name', 'manufacturer__founded']).all()
|
||||||
|
# will raise pydantic ValidationError as company.name is required
|
||||||
|
```
|
||||||
|
|
||||||
!!!tip
|
!!!tip
|
||||||
Pk column cannot be excluded - it's always auto added even if not explicitly included.
|
Pk column cannot be excluded - it's always auto added even if not explicitly included.
|
||||||
|
|
||||||
|
You can also pass fields to include as dictionary or set.
|
||||||
|
|
||||||
|
To mark a field as included in a dictionary use it's name as key and ellipsis as value.
|
||||||
|
|
||||||
|
To traverse nested models use nested dictionaries.
|
||||||
|
|
||||||
|
To include fields at last level instead of nested dictionary a set can be used.
|
||||||
|
|
||||||
|
To include whole nested model specify model related field name and ellipsis.
|
||||||
|
|
||||||
|
Below you can see examples that are equivalent:
|
||||||
|
|
||||||
|
```python
|
||||||
|
--8<-- "../docs_src/queries/docs009.py"
|
||||||
|
```
|
||||||
|
|
||||||
!!!note
|
!!!note
|
||||||
All methods that do not return the rows explicitly returns a QueySet instance so you can chain them together
|
All methods that do not return the rows explicitly returns a QueySet instance so you can chain them together
|
||||||
@ -372,11 +425,15 @@ With `fields()` you can select subset of model columns to limit the data load.
|
|||||||
|
|
||||||
### exclude_fields
|
### exclude_fields
|
||||||
|
|
||||||
`fields(columns: Union[List, str]) -> QuerySet`
|
`exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet`
|
||||||
|
|
||||||
With `exclude_fields()` you can select subset of model columns that will be excluded to limit the data load.
|
With `exclude_fields()` you can select subset of model columns that will be excluded to limit the data load.
|
||||||
|
|
||||||
It's the oposite of `fields()` method.
|
It's the opposite of `fields()` method so check documentation above to see what options are available.
|
||||||
|
|
||||||
|
Especially check above how you can pass also nested dictionaries and sets as a mask to exclude fields from whole hierarchy.
|
||||||
|
|
||||||
|
Below you can find few simple examples:
|
||||||
|
|
||||||
```python hl_lines="47 48 60 61 67"
|
```python hl_lines="47 48 60 61 67"
|
||||||
--8<-- "../docs_src/queries/docs008.py"
|
--8<-- "../docs_src/queries/docs008.py"
|
||||||
|
|||||||
@ -43,25 +43,3 @@ await Car.objects.create(manufacturer=toyota, name="Yaris", year=2019, gearbox_t
|
|||||||
await Car.objects.create(manufacturer=toyota, name="Supreme", year=2020, gearbox_type='Auto', gears=6,
|
await Car.objects.create(manufacturer=toyota, name="Supreme", year=2020, gearbox_type='Auto', gears=6,
|
||||||
aircon_type='Auto')
|
aircon_type='Auto')
|
||||||
|
|
||||||
# select manufacturer but only name - to include related models use notation {model_name}__{column}
|
|
||||||
all_cars = await Car.objects.select_related('manufacturer').fields(['id', 'name', 'company__name']).all()
|
|
||||||
for car in all_cars:
|
|
||||||
# excluded columns will yield None
|
|
||||||
assert all(getattr(car, x) is None for x in ['year', 'gearbox_type', 'gears', 'aircon_type'])
|
|
||||||
# included column on related models will be available, pk column is always included
|
|
||||||
# even if you do not include it in fields list
|
|
||||||
assert car.manufacturer.name == 'Toyota'
|
|
||||||
# also in the nested related models - you cannot exclude pk - it's always auto added
|
|
||||||
assert car.manufacturer.founded is None
|
|
||||||
|
|
||||||
# fields() can be called several times, building up the columns to select
|
|
||||||
# models selected in select_related but with no columns in fields list implies all fields
|
|
||||||
all_cars = await Car.objects.select_related('manufacturer').fields('id').fields(
|
|
||||||
['name']).all()
|
|
||||||
# all fiels from company model are selected
|
|
||||||
assert all_cars[0].manufacturer.name == 'Toyota'
|
|
||||||
assert all_cars[0].manufacturer.founded == 1937
|
|
||||||
|
|
||||||
# cannot exclude mandatory model columns - company__name in this example
|
|
||||||
await Car.objects.select_related('manufacturer').fields(['id', 'name', 'company__founded']).all()
|
|
||||||
# will raise pydantic ValidationError as company.name is required
|
|
||||||
|
|||||||
@ -63,6 +63,6 @@ all_cars = await Car.objects.select_related('manufacturer').exclude_fields('year
|
|||||||
assert all_cars[0].manufacturer.name == 'Toyota'
|
assert all_cars[0].manufacturer.name == 'Toyota'
|
||||||
assert all_cars[0].manufacturer.founded == 1937
|
assert all_cars[0].manufacturer.founded == 1937
|
||||||
|
|
||||||
# cannot exclude mandatory model columns - company__name in this example
|
# cannot exclude mandatory model columns - company__name in this example - note usage of dict/set this time
|
||||||
await Car.objects.select_related('manufacturer').exclude_fields(['company__name']).all()
|
await Car.objects.select_related('manufacturer').exclude_fields([{'company': {'name'}}]).all()
|
||||||
# will raise pydantic ValidationError as company.name is required
|
# will raise pydantic ValidationError as company.name is required
|
||||||
|
|||||||
33
docs_src/queries/docs009.py
Normal file
33
docs_src/queries/docs009.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# 1. like in example above
|
||||||
|
await Car.objects.select_related('manufacturer').fields(['id', 'name', 'manufacturer__name']).all()
|
||||||
|
|
||||||
|
# 2. to mark a field as required use ellipsis
|
||||||
|
await Car.objects.select_related('manufacturer').fields({'id': ...,
|
||||||
|
'name': ...,
|
||||||
|
'manufacturer': {
|
||||||
|
'name': ...}
|
||||||
|
}).all()
|
||||||
|
|
||||||
|
# 3. to include whole nested model use ellipsis
|
||||||
|
await Car.objects.select_related('manufacturer').fields({'id': ...,
|
||||||
|
'name': ...,
|
||||||
|
'manufacturer': ...
|
||||||
|
}).all()
|
||||||
|
|
||||||
|
# 4. to specify fields at last nesting level you can also use set - equivalent to 2. above
|
||||||
|
await Car.objects.select_related('manufacturer').fields({'id': ...,
|
||||||
|
'name': ...,
|
||||||
|
'manufacturer': {'name'}
|
||||||
|
}).all()
|
||||||
|
|
||||||
|
# 5. of course set can have multiple fields
|
||||||
|
await Car.objects.select_related('manufacturer').fields({'id': ...,
|
||||||
|
'name': ...,
|
||||||
|
'manufacturer': {'name', 'founded'}
|
||||||
|
}).all()
|
||||||
|
|
||||||
|
# 6. you can include all nested fields but it will be equivalent of 3. above which is shorter
|
||||||
|
await Car.objects.select_related('manufacturer').fields({'id': ...,
|
||||||
|
'name': ...,
|
||||||
|
'manufacturer': {'id', 'name', 'founded'}
|
||||||
|
}).all()
|
||||||
@ -5,6 +5,7 @@ import sqlalchemy
|
|||||||
from pydantic import Field, typing
|
from pydantic import Field, typing
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
|
import ormar # noqa I101
|
||||||
from ormar import ModelDefinitionError # noqa I101
|
from ormar import ModelDefinitionError # noqa I101
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
@ -34,6 +35,10 @@ class BaseField(FieldInfo):
|
|||||||
default: Any
|
default: Any
|
||||||
server_default: Any
|
server_default: Any
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_valid_uni_relation(cls) -> bool:
|
||||||
|
return not issubclass(cls, ormar.fields.ManyToManyField) and not cls.virtual
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_alias(cls) -> str:
|
def get_alias(cls) -> str:
|
||||||
return cls.alias if cls.alias else cls.name
|
return cls.alias if cls.alias else cls.name
|
||||||
|
|||||||
43
ormar/models/excludable.py
Normal file
43
ormar/models/excludable.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from typing import Dict, Set, Union
|
||||||
|
|
||||||
|
|
||||||
|
class Excludable:
|
||||||
|
@staticmethod
|
||||||
|
def get_excluded(
|
||||||
|
exclude: Union[Set, Dict, None], key: str = None
|
||||||
|
) -> Union[Set, Dict, None]:
|
||||||
|
if isinstance(exclude, dict):
|
||||||
|
return exclude.get(key, {})
|
||||||
|
return exclude
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_included(
|
||||||
|
include: Union[Set, Dict, None], key: str = None
|
||||||
|
) -> Union[Set, Dict, None]:
|
||||||
|
return Excludable.get_excluded(exclude=include, key=key)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool:
|
||||||
|
if exclude is None:
|
||||||
|
return False
|
||||||
|
if exclude is Ellipsis: # pragma: nocover
|
||||||
|
return True
|
||||||
|
to_exclude = Excludable.get_excluded(exclude=exclude, key=key)
|
||||||
|
if isinstance(to_exclude, Set):
|
||||||
|
return key in to_exclude
|
||||||
|
elif to_exclude is ...:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_included(include: Union[Set, Dict, None], key: str = None) -> bool:
|
||||||
|
if include is None:
|
||||||
|
return True
|
||||||
|
if include is Ellipsis:
|
||||||
|
return True
|
||||||
|
to_include = Excludable.get_included(include=include, key=key)
|
||||||
|
if isinstance(to_include, Set):
|
||||||
|
return key in to_include
|
||||||
|
elif to_include is ...:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
@ -1,5 +1,5 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
|
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
@ -47,8 +47,8 @@ class Model(NewBaseModel):
|
|||||||
select_related: List = None,
|
select_related: List = None,
|
||||||
related_models: Any = None,
|
related_models: Any = None,
|
||||||
previous_table: str = None,
|
previous_table: str = None,
|
||||||
fields: List = None,
|
fields: Optional[Union[Dict, Set]] = None,
|
||||||
exclude_fields: List = None,
|
exclude_fields: Optional[Union[Dict, Set]] = None,
|
||||||
) -> Optional[T]:
|
) -> Optional[T]:
|
||||||
|
|
||||||
item: Dict[str, Any] = {}
|
item: Dict[str, Any] = {}
|
||||||
@ -88,7 +88,6 @@ class Model(NewBaseModel):
|
|||||||
table_prefix=table_prefix,
|
table_prefix=table_prefix,
|
||||||
fields=fields,
|
fields=fields,
|
||||||
exclude_fields=exclude_fields,
|
exclude_fields=exclude_fields,
|
||||||
nested=table_prefix != "",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
instance: Optional[T] = cls(**item) if item.get(
|
instance: Optional[T] = cls(**item) if item.get(
|
||||||
@ -103,13 +102,17 @@ class Model(NewBaseModel):
|
|||||||
row: sqlalchemy.engine.ResultProxy,
|
row: sqlalchemy.engine.ResultProxy,
|
||||||
related_models: Any,
|
related_models: Any,
|
||||||
previous_table: sqlalchemy.Table,
|
previous_table: sqlalchemy.Table,
|
||||||
fields: List = None,
|
fields: Optional[Union[Dict, Set]] = None,
|
||||||
exclude_fields: List = None,
|
exclude_fields: Optional[Union[Dict, Set]] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
for related in related_models:
|
for related in related_models:
|
||||||
if isinstance(related_models, dict) and related_models[related]:
|
if isinstance(related_models, dict) and related_models[related]:
|
||||||
first_part, remainder = related, related_models[related]
|
first_part, remainder = related, related_models[related]
|
||||||
model_cls = cls.Meta.model_fields[first_part].to
|
model_cls = cls.Meta.model_fields[first_part].to
|
||||||
|
|
||||||
|
fields = cls.get_included(fields, first_part)
|
||||||
|
exclude_fields = cls.get_excluded(exclude_fields, first_part)
|
||||||
|
|
||||||
child = model_cls.from_row(
|
child = model_cls.from_row(
|
||||||
row,
|
row,
|
||||||
related_models=remainder,
|
related_models=remainder,
|
||||||
@ -120,6 +123,8 @@ class Model(NewBaseModel):
|
|||||||
item[model_cls.get_column_name_from_alias(first_part)] = child
|
item[model_cls.get_column_name_from_alias(first_part)] = child
|
||||||
else:
|
else:
|
||||||
model_cls = cls.Meta.model_fields[related].to
|
model_cls = cls.Meta.model_fields[related].to
|
||||||
|
fields = cls.get_included(fields, related)
|
||||||
|
exclude_fields = cls.get_excluded(exclude_fields, related)
|
||||||
child = model_cls.from_row(
|
child = model_cls.from_row(
|
||||||
row,
|
row,
|
||||||
previous_table=previous_table,
|
previous_table=previous_table,
|
||||||
@ -136,16 +141,18 @@ class Model(NewBaseModel):
|
|||||||
item: dict,
|
item: dict,
|
||||||
row: sqlalchemy.engine.result.ResultProxy,
|
row: sqlalchemy.engine.result.ResultProxy,
|
||||||
table_prefix: str,
|
table_prefix: str,
|
||||||
fields: List = None,
|
fields: Optional[Union[Dict, Set]] = None,
|
||||||
exclude_fields: List = None,
|
exclude_fields: Optional[Union[Dict, Set]] = None,
|
||||||
nested: bool = False,
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
||||||
# databases does not keep aliases in Record for postgres, change to raw row
|
# databases does not keep aliases in Record for postgres, change to raw row
|
||||||
source = row._row if cls.db_backend_name() == "postgresql" else row
|
source = row._row if cls.db_backend_name() == "postgresql" else row
|
||||||
|
|
||||||
selected_columns = cls.own_table_columns(
|
selected_columns = cls.own_table_columns(
|
||||||
cls, fields or [], exclude_fields or [], nested=nested, use_alias=True
|
model=cls,
|
||||||
|
fields=fields or {},
|
||||||
|
exclude_fields=exclude_fields or {},
|
||||||
|
use_alias=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
for column in cls.Meta.table.columns:
|
for column in cls.Meta.table.columns:
|
||||||
|
|||||||
@ -1,12 +1,28 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Set,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import ormar
|
|
||||||
from ormar.exceptions import RelationshipInstanceError
|
from ormar.exceptions import RelationshipInstanceError
|
||||||
from ormar.fields import BaseField, ManyToManyField
|
|
||||||
|
try:
|
||||||
|
import orjson as json
|
||||||
|
except ImportError: # pragma: nocover
|
||||||
|
import json # type: ignore
|
||||||
|
|
||||||
|
import ormar # noqa: I100
|
||||||
|
from ormar.fields import BaseField
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
from ormar.models.metaclass import ModelMeta, expand_reverse_relationships
|
from ormar.models.metaclass import ModelMeta
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -20,6 +36,8 @@ Field = TypeVar("Field", bound=BaseField)
|
|||||||
class ModelTableProxy:
|
class ModelTableProxy:
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
Meta: ModelMeta
|
Meta: ModelMeta
|
||||||
|
_related_names: Set
|
||||||
|
_related_names_hash: Union[str, bytes]
|
||||||
|
|
||||||
def dict(self): # noqa A003
|
def dict(self): # noqa A003
|
||||||
raise NotImplementedError # pragma no cover
|
raise NotImplementedError # pragma no cover
|
||||||
@ -65,50 +83,50 @@ class ModelTableProxy:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_column_alias(cls, field_name: str) -> str:
|
def get_column_alias(cls, field_name: str) -> str:
|
||||||
field = cls.Meta.model_fields.get(field_name)
|
field = cls.Meta.model_fields.get(field_name)
|
||||||
if field and field.alias is not None:
|
if field is not None and field.alias is not None:
|
||||||
return field.alias
|
return field.alias
|
||||||
return field_name
|
return field_name
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_column_name_from_alias(cls, alias: str) -> str:
|
def get_column_name_from_alias(cls, alias: str) -> str:
|
||||||
for field_name, field in cls.Meta.model_fields.items():
|
for field_name, field in cls.Meta.model_fields.items():
|
||||||
if field and field.alias == alias:
|
if field is not None and field.alias == alias:
|
||||||
return field_name
|
return field_name
|
||||||
return alias # if not found it's not an alias but actual name
|
return alias # if not found it's not an alias but actual name
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_related_names(cls) -> Set:
|
def extract_related_names(cls) -> Set:
|
||||||
|
|
||||||
|
if isinstance(cls._related_names_hash, (str, bytes)):
|
||||||
|
return cls._related_names
|
||||||
|
|
||||||
related_names = set()
|
related_names = set()
|
||||||
for name, field in cls.Meta.model_fields.items():
|
for name, field in cls.Meta.model_fields.items():
|
||||||
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
|
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
|
||||||
related_names.add(name)
|
related_names.add(name)
|
||||||
|
cls._related_names_hash = json.dumps(list(cls.Meta.model_fields.keys()))
|
||||||
|
cls._related_names = related_names
|
||||||
|
|
||||||
return related_names
|
return related_names
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_db_related_names(cls) -> Set:
|
def _extract_db_related_names(cls) -> Set:
|
||||||
related_names = set()
|
related_names = cls.extract_related_names()
|
||||||
for name, field in cls.Meta.model_fields.items():
|
related_names = {
|
||||||
if (
|
name
|
||||||
inspect.isclass(field)
|
for name in related_names
|
||||||
and issubclass(field, ForeignKeyField)
|
if cls.Meta.model_fields[name].is_valid_uni_relation()
|
||||||
and not issubclass(field, ManyToManyField)
|
}
|
||||||
and not field.virtual
|
|
||||||
):
|
|
||||||
related_names.add(name)
|
|
||||||
return related_names
|
return related_names
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
|
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
|
||||||
if nested:
|
if nested:
|
||||||
return cls.extract_related_names()
|
return cls.extract_related_names()
|
||||||
related_names = set()
|
related_names = cls.extract_related_names()
|
||||||
for name, field in cls.Meta.model_fields.items():
|
related_names = {
|
||||||
if (
|
name for name in related_names if cls.Meta.model_fields[name].nullable
|
||||||
inspect.isclass(field)
|
}
|
||||||
and issubclass(field, ForeignKeyField)
|
|
||||||
and field.nullable
|
|
||||||
):
|
|
||||||
related_names.add(name)
|
|
||||||
return related_names
|
return related_names
|
||||||
|
|
||||||
def _extract_model_db_fields(self) -> Dict:
|
def _extract_model_db_fields(self) -> Dict:
|
||||||
@ -128,7 +146,6 @@ class ModelTableProxy:
|
|||||||
def resolve_relation_name( # noqa CCR001
|
def resolve_relation_name( # noqa CCR001
|
||||||
item: Union["NewBaseModel", Type["NewBaseModel"]],
|
item: Union["NewBaseModel", Type["NewBaseModel"]],
|
||||||
related: Union["NewBaseModel", Type["NewBaseModel"]],
|
related: Union["NewBaseModel", Type["NewBaseModel"]],
|
||||||
register_missing: bool = True,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
for name, field in item.Meta.model_fields.items():
|
for name, field in item.Meta.model_fields.items():
|
||||||
if issubclass(field, ForeignKeyField):
|
if issubclass(field, ForeignKeyField):
|
||||||
@ -137,12 +154,6 @@ class ModelTableProxy:
|
|||||||
# so we need to compare Meta too as this one is copied as is
|
# so we need to compare Meta too as this one is copied as is
|
||||||
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
||||||
return name
|
return name
|
||||||
# fallback for not registered relation
|
|
||||||
if register_missing: # pragma nocover
|
|
||||||
expand_reverse_relationships(related.__class__) # type: ignore
|
|
||||||
return ModelTableProxy.resolve_relation_name(
|
|
||||||
item, related, register_missing=False
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No relation between {item.get_name()} and {related.get_name()}"
|
f"No relation between {item.get_name()} and {related.get_name()}"
|
||||||
@ -151,7 +162,7 @@ class ModelTableProxy:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_relation_field(
|
def resolve_relation_field(
|
||||||
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
|
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
|
||||||
) -> Union[Type[BaseField], Type[ForeignKeyField]]:
|
) -> Type[BaseField]:
|
||||||
name = ModelTableProxy.resolve_relation_name(item, related)
|
name = ModelTableProxy.resolve_relation_name(item, related)
|
||||||
to_field = item.Meta.model_fields.get(name)
|
to_field = item.Meta.model_fields.get(name)
|
||||||
if not to_field: # pragma no cover
|
if not to_field: # pragma no cover
|
||||||
@ -211,59 +222,13 @@ class ModelTableProxy:
|
|||||||
)
|
)
|
||||||
return other
|
return other
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_not_nested_columns_from_fields(
|
|
||||||
model: Type["Model"],
|
|
||||||
fields: List,
|
|
||||||
exclude_fields: List,
|
|
||||||
column_names: List[str],
|
|
||||||
use_alias: bool = False,
|
|
||||||
) -> List[str]:
|
|
||||||
fields = [model.get_column_alias(k) if not use_alias else k for k in fields]
|
|
||||||
fields = fields or column_names
|
|
||||||
exclude_fields = [
|
|
||||||
model.get_column_alias(k) if not use_alias else k for k in exclude_fields
|
|
||||||
]
|
|
||||||
columns = [
|
|
||||||
name
|
|
||||||
for name in fields
|
|
||||||
if "__" not in name and name in column_names and name not in exclude_fields
|
|
||||||
]
|
|
||||||
return columns
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_nested_columns_from_fields(
|
|
||||||
model: Type["Model"],
|
|
||||||
fields: List,
|
|
||||||
exclude_fields: List,
|
|
||||||
column_names: List[str],
|
|
||||||
use_alias: bool = False,
|
|
||||||
) -> List[str]:
|
|
||||||
model_name = f"{model.get_name()}__"
|
|
||||||
columns = [
|
|
||||||
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
|
|
||||||
for name in fields
|
|
||||||
if f"{model.get_name()}__" in name
|
|
||||||
]
|
|
||||||
columns = columns or column_names
|
|
||||||
exclude_columns = [
|
|
||||||
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
|
|
||||||
for name in exclude_fields
|
|
||||||
if f"{model.get_name()}__" in name
|
|
||||||
]
|
|
||||||
columns = [model.get_column_alias(k) if not use_alias else k for k in columns]
|
|
||||||
exclude_columns = [
|
|
||||||
model.get_column_alias(k) if not use_alias else k for k in exclude_columns
|
|
||||||
]
|
|
||||||
return [column for column in columns if column not in exclude_columns]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _populate_pk_column(
|
def _populate_pk_column(
|
||||||
model: Type["Model"], columns: List[str], use_alias: bool = False,
|
model: Type["Model"], columns: List[str], use_alias: bool = False,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
pk_alias = (
|
pk_alias = (
|
||||||
model.get_column_alias(model.Meta.pkname)
|
model.get_column_alias(model.Meta.pkname)
|
||||||
if not use_alias
|
if use_alias
|
||||||
else model.Meta.pkname
|
else model.Meta.pkname
|
||||||
)
|
)
|
||||||
if pk_alias not in columns:
|
if pk_alias not in columns:
|
||||||
@ -273,34 +238,30 @@ class ModelTableProxy:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def own_table_columns(
|
def own_table_columns(
|
||||||
model: Type["Model"],
|
model: Type["Model"],
|
||||||
fields: List,
|
fields: Optional[Union[Set, Dict]],
|
||||||
exclude_fields: List,
|
exclude_fields: Optional[Union[Set, Dict]],
|
||||||
nested: bool = False,
|
|
||||||
use_alias: bool = False,
|
use_alias: bool = False,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
column_names = [
|
columns = [
|
||||||
model.get_column_name_from_alias(col.name) if use_alias else col.name
|
model.get_column_name_from_alias(col.name) if not use_alias else col.name
|
||||||
for col in model.Meta.table.columns
|
for col in model.Meta.table.columns
|
||||||
]
|
]
|
||||||
if not fields and not exclude_fields:
|
field_names = [
|
||||||
return column_names
|
model.get_column_name_from_alias(col.name)
|
||||||
|
for col in model.Meta.table.columns
|
||||||
if not nested:
|
]
|
||||||
columns = ModelTableProxy._get_not_nested_columns_from_fields(
|
if fields:
|
||||||
model=model,
|
columns = [
|
||||||
fields=fields,
|
col
|
||||||
exclude_fields=exclude_fields,
|
for col, name in zip(columns, field_names)
|
||||||
column_names=column_names,
|
if model.is_included(fields, name)
|
||||||
use_alias=use_alias,
|
]
|
||||||
)
|
if exclude_fields:
|
||||||
else:
|
columns = [
|
||||||
columns = ModelTableProxy._get_nested_columns_from_fields(
|
col
|
||||||
model=model,
|
for col, name in zip(columns, field_names)
|
||||||
fields=fields,
|
if not model.is_excluded(exclude_fields, name)
|
||||||
exclude_fields=exclude_fields,
|
]
|
||||||
column_names=column_names,
|
|
||||||
use_alias=use_alias,
|
|
||||||
)
|
|
||||||
|
|
||||||
# always has to return pk column
|
# always has to return pk column
|
||||||
columns = ModelTableProxy._populate_pk_column(
|
columns = ModelTableProxy._populate_pk_column(
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from typing import (
|
|||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
|
Set,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@ -23,6 +24,7 @@ from pydantic import BaseModel
|
|||||||
import ormar # noqa I100
|
import ormar # noqa I100
|
||||||
from ormar.fields import BaseField
|
from ormar.fields import BaseField
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
|
from ormar.models.excludable import Excludable
|
||||||
from ormar.models.metaclass import ModelMeta, ModelMetaclass
|
from ormar.models.metaclass import ModelMeta, ModelMetaclass
|
||||||
from ormar.models.modelproxy import ModelTableProxy
|
from ormar.models.modelproxy import ModelTableProxy
|
||||||
from ormar.relations.alias_manager import AliasManager
|
from ormar.relations.alias_manager import AliasManager
|
||||||
@ -39,8 +41,17 @@ if TYPE_CHECKING: # pragma no cover
|
|||||||
MappingIntStrAny = Mapping[IntStr, Any]
|
MappingIntStrAny = Mapping[IntStr, Any]
|
||||||
|
|
||||||
|
|
||||||
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
|
class NewBaseModel(
|
||||||
__slots__ = ("_orm_id", "_orm_saved", "_orm")
|
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
|
||||||
|
):
|
||||||
|
__slots__ = (
|
||||||
|
"_orm_id",
|
||||||
|
"_orm_saved",
|
||||||
|
"_orm",
|
||||||
|
"_related_names",
|
||||||
|
"_related_names_hash",
|
||||||
|
"_props",
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
__model_fields__: Dict[str, Type[BaseField]]
|
__model_fields__: Dict[str, Type[BaseField]]
|
||||||
@ -53,6 +64,10 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
__database__: databases.Database
|
__database__: databases.Database
|
||||||
_orm_relationship_manager: AliasManager
|
_orm_relationship_manager: AliasManager
|
||||||
_orm: RelationsManager
|
_orm: RelationsManager
|
||||||
|
_orm_saved: bool
|
||||||
|
_related_names: Set
|
||||||
|
_related_names_hash: str
|
||||||
|
_props: List[str]
|
||||||
Meta: ModelMeta
|
Meta: ModelMeta
|
||||||
|
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
@ -104,7 +119,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
|
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
|
||||||
if name in ("_orm_id", "_orm_saved", "_orm"):
|
if name in ("_orm_id", "_orm_saved", "_orm", "_related_names", "_props"):
|
||||||
object.__setattr__(self, name, value)
|
object.__setattr__(self, name, value)
|
||||||
elif name == "pk":
|
elif name == "pk":
|
||||||
object.__setattr__(self, self.Meta.pkname, value)
|
object.__setattr__(self, self.Meta.pkname, value)
|
||||||
@ -123,12 +138,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
def __getattribute__(self, item: str) -> Any:
|
def __getattribute__(self, item: str) -> Any:
|
||||||
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
|
if item in (
|
||||||
|
"_orm_id",
|
||||||
|
"_orm_saved",
|
||||||
|
"_orm",
|
||||||
|
"__fields__",
|
||||||
|
"_related_names",
|
||||||
|
"_props",
|
||||||
|
):
|
||||||
return object.__getattribute__(self, item)
|
return object.__getattribute__(self, item)
|
||||||
if item != "extract_related_names" and item in self.extract_related_names():
|
|
||||||
return self._extract_related_model_instead_of_field(item)
|
|
||||||
if item == "pk":
|
if item == "pk":
|
||||||
return self.__dict__.get(self.Meta.pkname, None)
|
return self.__dict__.get(self.Meta.pkname, None)
|
||||||
|
if item != "extract_related_names" and item in self.extract_related_names():
|
||||||
|
return self._extract_related_model_instead_of_field(item)
|
||||||
if item != "__fields__" and item in self.__fields__:
|
if item != "__fields__" and item in self.__fields__:
|
||||||
value = self.__dict__.get(item, None)
|
value = self.__dict__.get(item, None)
|
||||||
value = self._convert_json(item, value, "loads")
|
value = self._convert_json(item, value, "loads")
|
||||||
@ -183,12 +205,16 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
|
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
|
||||||
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
|
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
props = [
|
if isinstance(cls._props, list):
|
||||||
prop
|
props = cls._props
|
||||||
for prop in dir(cls)
|
else:
|
||||||
if isinstance(getattr(cls, prop), property)
|
props = [
|
||||||
and prop not in ("__values__", "__fields__", "fields", "pk_column")
|
prop
|
||||||
]
|
for prop in dir(cls)
|
||||||
|
if isinstance(getattr(cls, prop), property)
|
||||||
|
and prop not in ("__values__", "__fields__", "fields", "pk_column")
|
||||||
|
]
|
||||||
|
cls._props = props
|
||||||
if include:
|
if include:
|
||||||
props = [prop for prop in props if prop in include]
|
props = [prop for prop in props if prop in include]
|
||||||
if exclude:
|
if exclude:
|
||||||
|
|||||||
@ -1,5 +1,15 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Tuple, Type
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
NamedTuple,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
@ -24,8 +34,8 @@ class SqlJoin:
|
|||||||
used_aliases: List,
|
used_aliases: List,
|
||||||
select_from: sqlalchemy.sql.select,
|
select_from: sqlalchemy.sql.select,
|
||||||
columns: List[sqlalchemy.Column],
|
columns: List[sqlalchemy.Column],
|
||||||
fields: List,
|
fields: Optional[Union[Set, Dict]],
|
||||||
exclude_fields: List,
|
exclude_fields: Optional[Union[Set, Dict]],
|
||||||
order_columns: Optional[List],
|
order_columns: Optional[List],
|
||||||
sorted_orders: OrderedDict,
|
sorted_orders: OrderedDict,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -49,21 +59,60 @@ class SqlJoin:
|
|||||||
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
||||||
return text(f"{left_part}={right_part}")
|
return text(f"{left_part}={right_part}")
|
||||||
|
|
||||||
def build_join(
|
@staticmethod
|
||||||
|
def update_inclusions(
|
||||||
|
model_cls: Type["Model"],
|
||||||
|
fields: Optional[Union[Set, Dict]],
|
||||||
|
exclude_fields: Optional[Union[Set, Dict]],
|
||||||
|
nested_name: str,
|
||||||
|
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
|
||||||
|
fields = model_cls.get_included(fields, nested_name)
|
||||||
|
exclude_fields = model_cls.get_excluded(exclude_fields, nested_name)
|
||||||
|
return fields, exclude_fields
|
||||||
|
|
||||||
|
def build_join( # noqa: CCR001
|
||||||
self, item: str, join_parameters: JoinParameters
|
self, item: str, join_parameters: JoinParameters
|
||||||
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
|
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
|
||||||
for part in item.split("__"):
|
|
||||||
|
fields = self.fields
|
||||||
|
exclude_fields = self.exclude_fields
|
||||||
|
|
||||||
|
for index, part in enumerate(item.split("__")):
|
||||||
if issubclass(
|
if issubclass(
|
||||||
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
|
||||||
):
|
):
|
||||||
_fields = join_parameters.model_cls.Meta.model_fields
|
_fields = join_parameters.model_cls.Meta.model_fields
|
||||||
new_part = _fields[part].to.get_name()
|
new_part = _fields[part].to.get_name()
|
||||||
self._switch_many_to_many_order_columns(part, new_part)
|
self._switch_many_to_many_order_columns(part, new_part)
|
||||||
|
if index > 0: # nested joins
|
||||||
|
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||||
|
model_cls=join_parameters.model_cls,
|
||||||
|
fields=fields,
|
||||||
|
exclude_fields=exclude_fields,
|
||||||
|
nested_name=part,
|
||||||
|
)
|
||||||
|
|
||||||
join_parameters = self._build_join_parameters(
|
join_parameters = self._build_join_parameters(
|
||||||
part, join_parameters, is_multi=True
|
part=part,
|
||||||
|
join_params=join_parameters,
|
||||||
|
is_multi=True,
|
||||||
|
fields=fields,
|
||||||
|
exclude_fields=exclude_fields,
|
||||||
)
|
)
|
||||||
part = new_part
|
part = new_part
|
||||||
join_parameters = self._build_join_parameters(part, join_parameters)
|
if index > 0: # nested joins
|
||||||
|
fields, exclude_fields = SqlJoin.update_inclusions(
|
||||||
|
model_cls=join_parameters.model_cls,
|
||||||
|
fields=fields,
|
||||||
|
exclude_fields=exclude_fields,
|
||||||
|
nested_name=part,
|
||||||
|
)
|
||||||
|
join_parameters = self._build_join_parameters(
|
||||||
|
part=part,
|
||||||
|
join_params=join_parameters,
|
||||||
|
fields=fields,
|
||||||
|
exclude_fields=exclude_fields,
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
self.used_aliases,
|
self.used_aliases,
|
||||||
@ -73,7 +122,12 @@ class SqlJoin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _build_join_parameters(
|
def _build_join_parameters(
|
||||||
self, part: str, join_params: JoinParameters, is_multi: bool = False
|
self,
|
||||||
|
part: str,
|
||||||
|
join_params: JoinParameters,
|
||||||
|
fields: Optional[Union[Set, Dict]],
|
||||||
|
exclude_fields: Optional[Union[Set, Dict]],
|
||||||
|
is_multi: bool = False,
|
||||||
) -> JoinParameters:
|
) -> JoinParameters:
|
||||||
if is_multi:
|
if is_multi:
|
||||||
model_cls = join_params.model_cls.Meta.model_fields[part].through
|
model_cls = join_params.model_cls.Meta.model_fields[part].through
|
||||||
@ -85,20 +139,30 @@ class SqlJoin:
|
|||||||
join_params.from_table, to_table
|
join_params.from_table, to_table
|
||||||
)
|
)
|
||||||
if alias not in self.used_aliases:
|
if alias not in self.used_aliases:
|
||||||
self._process_join(join_params, is_multi, model_cls, part, alias)
|
self._process_join(
|
||||||
|
join_params=join_params,
|
||||||
|
is_multi=is_multi,
|
||||||
|
model_cls=model_cls,
|
||||||
|
part=part,
|
||||||
|
alias=alias,
|
||||||
|
fields=fields,
|
||||||
|
exclude_fields=exclude_fields,
|
||||||
|
)
|
||||||
|
|
||||||
previous_alias = alias
|
previous_alias = alias
|
||||||
from_table = to_table
|
from_table = to_table
|
||||||
prev_model = model_cls
|
prev_model = model_cls
|
||||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||||
|
|
||||||
def _process_join(
|
def _process_join( # noqa: CFQ002
|
||||||
self,
|
self,
|
||||||
join_params: JoinParameters,
|
join_params: JoinParameters,
|
||||||
is_multi: bool,
|
is_multi: bool,
|
||||||
model_cls: Type["Model"],
|
model_cls: Type["Model"],
|
||||||
part: str,
|
part: str,
|
||||||
alias: str,
|
alias: str,
|
||||||
|
fields: Optional[Union[Set, Dict]],
|
||||||
|
exclude_fields: Optional[Union[Set, Dict]],
|
||||||
) -> None:
|
) -> None:
|
||||||
to_table = model_cls.Meta.table.name
|
to_table = model_cls.Meta.table.name
|
||||||
to_key, from_key = self.get_to_and_from_keys(
|
to_key, from_key = self.get_to_and_from_keys(
|
||||||
@ -129,7 +193,10 @@ class SqlJoin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self_related_fields = model_cls.own_table_columns(
|
self_related_fields = model_cls.own_table_columns(
|
||||||
model_cls, self.fields, self.exclude_fields, nested=True,
|
model=model_cls,
|
||||||
|
fields=fields,
|
||||||
|
exclude_fields=exclude_fields,
|
||||||
|
use_alias=True,
|
||||||
)
|
)
|
||||||
self.columns.extend(
|
self.columns.extend(
|
||||||
self.relation_manager(model_cls).prefixed_columns(
|
self.relation_manager(model_cls).prefixed_columns(
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
|
import copy
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import List, Optional, TYPE_CHECKING, Tuple, Type
|
from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
@ -21,8 +22,8 @@ class Query:
|
|||||||
select_related: List,
|
select_related: List,
|
||||||
limit_count: Optional[int],
|
limit_count: Optional[int],
|
||||||
offset: Optional[int],
|
offset: Optional[int],
|
||||||
fields: Optional[List],
|
fields: Optional[Union[Dict, Set]],
|
||||||
exclude_fields: Optional[List],
|
exclude_fields: Optional[Union[Dict, Set]],
|
||||||
order_bys: Optional[List],
|
order_bys: Optional[List],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.query_offset = offset
|
self.query_offset = offset
|
||||||
@ -30,8 +31,8 @@ class Query:
|
|||||||
self._select_related = select_related[:]
|
self._select_related = select_related[:]
|
||||||
self.filter_clauses = filter_clauses[:]
|
self.filter_clauses = filter_clauses[:]
|
||||||
self.exclude_clauses = exclude_clauses[:]
|
self.exclude_clauses = exclude_clauses[:]
|
||||||
self.fields = fields[:] if fields else []
|
self.fields = copy.deepcopy(fields) if fields else {}
|
||||||
self.exclude_fields = exclude_fields[:] if exclude_fields else []
|
self.exclude_fields = copy.deepcopy(exclude_fields) if exclude_fields else {}
|
||||||
|
|
||||||
self.model_cls = model_cls
|
self.model_cls = model_cls
|
||||||
self.table = self.model_cls.Meta.table
|
self.table = self.model_cls.Meta.table
|
||||||
@ -73,7 +74,10 @@ class Query:
|
|||||||
|
|
||||||
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
|
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
|
||||||
self_related_fields = self.model_cls.own_table_columns(
|
self_related_fields = self.model_cls.own_table_columns(
|
||||||
self.model_cls, self.fields, self.exclude_fields
|
model=self.model_cls,
|
||||||
|
fields=self.fields,
|
||||||
|
exclude_fields=self.exclude_fields,
|
||||||
|
use_alias=True,
|
||||||
)
|
)
|
||||||
self.columns = self.model_cls.Meta.alias_manager.prefixed_columns(
|
self.columns = self.model_cls.Meta.alias_manager.prefixed_columns(
|
||||||
"", self.table, self_related_fields
|
"", self.table, self_related_fields
|
||||||
@ -87,13 +91,14 @@ class Query:
|
|||||||
join_parameters = JoinParameters(
|
join_parameters = JoinParameters(
|
||||||
self.model_cls, "", self.table.name, self.model_cls
|
self.model_cls, "", self.table.name, self.model_cls
|
||||||
)
|
)
|
||||||
|
fields = self.model_cls.get_included(self.fields, item)
|
||||||
|
exclude_fields = self.model_cls.get_excluded(self.exclude_fields, item)
|
||||||
sql_join = SqlJoin(
|
sql_join = SqlJoin(
|
||||||
used_aliases=self.used_aliases,
|
used_aliases=self.used_aliases,
|
||||||
select_from=self.select_from,
|
select_from=self.select_from,
|
||||||
columns=self.columns,
|
columns=self.columns,
|
||||||
fields=self.fields,
|
fields=fields,
|
||||||
exclude_fields=self.exclude_fields,
|
exclude_fields=exclude_fields,
|
||||||
order_columns=self.order_columns,
|
order_columns=self.order_columns,
|
||||||
sorted_orders=self.sorted_orders,
|
sorted_orders=self.sorted_orders,
|
||||||
)
|
)
|
||||||
@ -131,5 +136,5 @@ class Query:
|
|||||||
self.select_from = []
|
self.select_from = []
|
||||||
self.columns = []
|
self.columns = []
|
||||||
self.used_aliases = []
|
self.used_aliases = []
|
||||||
self.fields = []
|
self.fields = {}
|
||||||
self.exclude_fields = []
|
self.exclude_fields = {}
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union
|
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
@ -10,6 +10,7 @@ from ormar.exceptions import QueryDefinitionError
|
|||||||
from ormar.queryset import FilterQuery
|
from ormar.queryset import FilterQuery
|
||||||
from ormar.queryset.clause import QueryClause
|
from ormar.queryset.clause import QueryClause
|
||||||
from ormar.queryset.query import Query
|
from ormar.queryset.query import Query
|
||||||
|
from ormar.queryset.utils import update, update_dict_from_list
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -26,8 +27,8 @@ class QuerySet:
|
|||||||
select_related: List = None,
|
select_related: List = None,
|
||||||
limit_count: int = None,
|
limit_count: int = None,
|
||||||
offset: int = None,
|
offset: int = None,
|
||||||
columns: List = None,
|
columns: Dict = None,
|
||||||
exclude_columns: List = None,
|
exclude_columns: Dict = None,
|
||||||
order_bys: List = None,
|
order_bys: List = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_cls = model_cls
|
self.model_cls = model_cls
|
||||||
@ -36,8 +37,8 @@ class QuerySet:
|
|||||||
self._select_related = [] if select_related is None else select_related
|
self._select_related = [] if select_related is None else select_related
|
||||||
self.limit_count = limit_count
|
self.limit_count = limit_count
|
||||||
self.query_offset = offset
|
self.query_offset = offset
|
||||||
self._columns = columns or []
|
self._columns = columns or {}
|
||||||
self._exclude_columns = exclude_columns or []
|
self._exclude_columns = exclude_columns or {}
|
||||||
self.order_bys = order_bys or []
|
self.order_bys = order_bys or []
|
||||||
|
|
||||||
def __get__(
|
def __get__(
|
||||||
@ -169,11 +170,16 @@ class QuerySet:
|
|||||||
order_bys=self.order_bys,
|
order_bys=self.order_bys,
|
||||||
)
|
)
|
||||||
|
|
||||||
def exclude_fields(self, columns: Union[List, str]) -> "QuerySet":
|
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
|
||||||
if not isinstance(columns, list):
|
if isinstance(columns, str):
|
||||||
columns = [columns]
|
columns = [columns]
|
||||||
|
|
||||||
columns = list(set(list(self._exclude_columns) + columns))
|
current_excluded = self._exclude_columns
|
||||||
|
if not isinstance(columns, dict):
|
||||||
|
current_excluded = update_dict_from_list(current_excluded, columns)
|
||||||
|
else:
|
||||||
|
current_excluded = update(current_excluded, columns)
|
||||||
|
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model,
|
model_cls=self.model,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
@ -182,15 +188,20 @@ class QuerySet:
|
|||||||
limit_count=self.limit_count,
|
limit_count=self.limit_count,
|
||||||
offset=self.query_offset,
|
offset=self.query_offset,
|
||||||
columns=self._columns,
|
columns=self._columns,
|
||||||
exclude_columns=columns,
|
exclude_columns=current_excluded,
|
||||||
order_bys=self.order_bys,
|
order_bys=self.order_bys,
|
||||||
)
|
)
|
||||||
|
|
||||||
def fields(self, columns: Union[List, str]) -> "QuerySet":
|
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
|
||||||
if not isinstance(columns, list):
|
if isinstance(columns, str):
|
||||||
columns = [columns]
|
columns = [columns]
|
||||||
|
|
||||||
columns = list(set(list(self._columns) + columns))
|
current_included = self._columns
|
||||||
|
if not isinstance(columns, dict):
|
||||||
|
current_included = update_dict_from_list(current_included, columns)
|
||||||
|
else:
|
||||||
|
current_included = update(current_included, columns)
|
||||||
|
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model,
|
model_cls=self.model,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
@ -198,7 +209,7 @@ class QuerySet:
|
|||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
limit_count=self.limit_count,
|
limit_count=self.limit_count,
|
||||||
offset=self.query_offset,
|
offset=self.query_offset,
|
||||||
columns=columns,
|
columns=current_included,
|
||||||
exclude_columns=self._exclude_columns,
|
exclude_columns=self._exclude_columns,
|
||||||
order_bys=self.order_bys,
|
order_bys=self.order_bys,
|
||||||
)
|
)
|
||||||
|
|||||||
57
ormar/queryset/utils.py
Normal file
57
ormar/queryset/utils.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import collections.abc
|
||||||
|
import copy
|
||||||
|
from typing import Any, Dict, List, Set, Union
|
||||||
|
|
||||||
|
|
||||||
|
def check_node_not_dict_or_not_last_node(
|
||||||
|
part: str, parts: List, current_level: Any
|
||||||
|
) -> bool:
|
||||||
|
return (part not in current_level and part != parts[-1]) or (
|
||||||
|
part in current_level and not isinstance(current_level[part], dict)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def translate_list_to_dict(list_to_trans: Union[List, Set]) -> Dict: # noqa: CCR001
|
||||||
|
new_dict: Dict = dict()
|
||||||
|
for path in list_to_trans:
|
||||||
|
current_level = new_dict
|
||||||
|
parts = path.split("__")
|
||||||
|
for part in parts:
|
||||||
|
if check_node_not_dict_or_not_last_node(
|
||||||
|
part=part, parts=parts, current_level=current_level
|
||||||
|
):
|
||||||
|
current_level[part] = dict()
|
||||||
|
elif part not in current_level:
|
||||||
|
current_level[part] = ...
|
||||||
|
current_level = current_level[part]
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_set_to_required_dict(set_to_convert: set) -> Dict:
|
||||||
|
new_dict = dict()
|
||||||
|
for key in set_to_convert:
|
||||||
|
new_dict[key] = Ellipsis
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
|
def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
|
||||||
|
if current_dict is Ellipsis:
|
||||||
|
current_dict = dict()
|
||||||
|
for key, value in updating_dict.items():
|
||||||
|
if isinstance(value, collections.abc.Mapping):
|
||||||
|
old_key = current_dict.get(key, {})
|
||||||
|
if isinstance(old_key, set):
|
||||||
|
old_key = convert_set_to_required_dict(old_key)
|
||||||
|
current_dict[key] = update(old_key, value)
|
||||||
|
elif isinstance(value, set) and isinstance(current_dict.get(key), set):
|
||||||
|
current_dict[key] = current_dict.get(key).union(value)
|
||||||
|
else:
|
||||||
|
current_dict[key] = value
|
||||||
|
return current_dict
|
||||||
|
|
||||||
|
|
||||||
|
def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict:
|
||||||
|
updated_dict = copy.copy(curr_dict)
|
||||||
|
dict_to_update = translate_list_to_dict(list_to_update)
|
||||||
|
update(updated_dict, dict_to_update)
|
||||||
|
return updated_dict
|
||||||
@ -1,16 +1,12 @@
|
|||||||
from ormar.relations.alias_manager import AliasManager
|
from ormar.relations.alias_manager import AliasManager
|
||||||
from ormar.relations.relation import Relation, RelationType
|
from ormar.relations.relation import Relation, RelationType
|
||||||
from ormar.relations.relation_manager import RelationsManager
|
from ormar.relations.relation_manager import RelationsManager
|
||||||
from ormar.relations.utils import (
|
from ormar.relations.utils import get_relations_sides_and_names
|
||||||
get_relations_sides_and_names,
|
|
||||||
register_missing_relation,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AliasManager",
|
"AliasManager",
|
||||||
"Relation",
|
"Relation",
|
||||||
"RelationsManager",
|
"RelationsManager",
|
||||||
"RelationType",
|
"RelationType",
|
||||||
"register_missing_relation",
|
|
||||||
"get_relations_sides_and_names",
|
"get_relations_sides_and_names",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -5,10 +5,7 @@ from ormar.fields import BaseField
|
|||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
from ormar.fields.many_to_many import ManyToManyField
|
from ormar.fields.many_to_many import ManyToManyField
|
||||||
from ormar.relations.relation import Relation, RelationType
|
from ormar.relations.relation import Relation, RelationType
|
||||||
from ormar.relations.utils import (
|
from ormar.relations.utils import get_relations_sides_and_names
|
||||||
get_relations_sides_and_names,
|
|
||||||
register_missing_relation,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -42,8 +39,6 @@ class RelationsManager:
|
|||||||
to=field.to,
|
to=field.to,
|
||||||
through=getattr(field, "through", None),
|
through=getattr(field, "through", None),
|
||||||
)
|
)
|
||||||
if field.name not in self._related_names:
|
|
||||||
self._related_names.append(field.name)
|
|
||||||
|
|
||||||
def __contains__(self, item: str) -> bool:
|
def __contains__(self, item: str) -> bool:
|
||||||
return item in self._related_names
|
return item in self._related_names
|
||||||
@ -69,9 +64,10 @@ class RelationsManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
parent_relation = parent._orm._get(child_name)
|
parent_relation = parent._orm._get(child_name)
|
||||||
if not parent_relation:
|
if parent_relation:
|
||||||
parent_relation = register_missing_relation(parent, child, child_name)
|
# print('missing', child_name)
|
||||||
parent_relation.add(child) # type: ignore
|
# parent_relation = register_missing_relation(parent, child, child_name)
|
||||||
|
parent_relation.add(child) # type: ignore
|
||||||
|
|
||||||
child_relation = child._orm._get(to_name)
|
child_relation = child._orm._get(to_name)
|
||||||
if child_relation:
|
if child_relation:
|
||||||
|
|||||||
@ -72,6 +72,4 @@ class RelationProxy(list):
|
|||||||
if self.relation._type == ormar.RelationType.MULTIPLE:
|
if self.relation._type == ormar.RelationType.MULTIPLE:
|
||||||
await self.queryset_proxy.create_through_instance(item)
|
await self.queryset_proxy.create_through_instance(item)
|
||||||
rel_name = item.resolve_relation_name(item, self._owner)
|
rel_name = item.resolve_relation_name(item, self._owner)
|
||||||
if rel_name not in item._orm: # pragma nocover
|
|
||||||
item._orm._add_relation(item.Meta.model_fields[rel_name])
|
|
||||||
setattr(item, rel_name, self._owner)
|
setattr(item, rel_name, self._owner)
|
||||||
|
|||||||
@ -1,26 +1,13 @@
|
|||||||
from typing import Optional, TYPE_CHECKING, Tuple, Type
|
from typing import TYPE_CHECKING, Tuple, Type
|
||||||
from weakref import proxy
|
from weakref import proxy
|
||||||
|
|
||||||
import ormar
|
|
||||||
from ormar.fields import BaseField
|
from ormar.fields import BaseField
|
||||||
from ormar.fields.many_to_many import ManyToManyField
|
from ormar.fields.many_to_many import ManyToManyField
|
||||||
from ormar.relations import Relation
|
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
|
|
||||||
|
|
||||||
def register_missing_relation(
|
|
||||||
parent: "Model", child: "Model", child_name: str
|
|
||||||
) -> Optional[Relation]:
|
|
||||||
ormar.models.expand_reverse_relationships(child.__class__)
|
|
||||||
name = parent.resolve_relation_name(parent, child)
|
|
||||||
field = parent.Meta.model_fields[name]
|
|
||||||
parent._orm._add_relation(field)
|
|
||||||
parent_relation = parent._orm._get(child_name)
|
|
||||||
return parent_relation
|
|
||||||
|
|
||||||
|
|
||||||
def get_relations_sides_and_names(
|
def get_relations_sides_and_names(
|
||||||
to_field: Type[BaseField],
|
to_field: Type[BaseField],
|
||||||
parent: "Model",
|
parent: "Model",
|
||||||
|
|||||||
@ -4,6 +4,7 @@ databases[mysql]
|
|||||||
pydantic
|
pydantic
|
||||||
sqlalchemy
|
sqlalchemy
|
||||||
typing_extensions
|
typing_extensions
|
||||||
|
orjson
|
||||||
|
|
||||||
# Async database drivers
|
# Async database drivers
|
||||||
aiomysql
|
aiomysql
|
||||||
@ -34,3 +35,6 @@ flake8-variables-names
|
|||||||
flake8-cognitive-complexity
|
flake8-cognitive-complexity
|
||||||
flake8-functions
|
flake8-functions
|
||||||
flake8-expression-complexity
|
flake8-expression-complexity
|
||||||
|
|
||||||
|
# Performance testing
|
||||||
|
yappi
|
||||||
|
|||||||
3
setup.py
3
setup.py
@ -42,7 +42,7 @@ setup(
|
|||||||
version=get_version(PACKAGE),
|
version=get_version(PACKAGE),
|
||||||
url=URL,
|
url=URL,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
description="An simple async ORM with fastapi in mind and pydantic validation.",
|
description="A simple async ORM with fastapi in mind and pydantic validation.",
|
||||||
long_description=get_long_description(),
|
long_description=get_long_description(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
keywords=['orm', 'sqlalchemy', 'fastapi', 'pydantic', 'databases', 'async', 'alembic'],
|
keywords=['orm', 'sqlalchemy', 'fastapi', 'pydantic', 'databases', 'async', 'alembic'],
|
||||||
@ -56,6 +56,7 @@ setup(
|
|||||||
"postgresql": ["asyncpg", "psycopg2"],
|
"postgresql": ["asyncpg", "psycopg2"],
|
||||||
"mysql": ["aiomysql", "pymysql"],
|
"mysql": ["aiomysql", "pymysql"],
|
||||||
"sqlite": ["aiosqlite"],
|
"sqlite": ["aiosqlite"],
|
||||||
|
"orjson": ["orjson"]
|
||||||
},
|
},
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
|
|||||||
@ -117,8 +117,8 @@ async def test_working_with_aliases():
|
|||||||
"first_name",
|
"first_name",
|
||||||
"last_name",
|
"last_name",
|
||||||
"born_year",
|
"born_year",
|
||||||
"child__first_name",
|
"children__first_name",
|
||||||
"child__last_name",
|
"children__last_name",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
.get()
|
.get()
|
||||||
|
|||||||
@ -80,10 +80,53 @@ async def test_selecting_subset():
|
|||||||
all_cars = (
|
all_cars = (
|
||||||
await Car.objects.select_related("manufacturer")
|
await Car.objects.select_related("manufacturer")
|
||||||
.exclude_fields(
|
.exclude_fields(
|
||||||
["gearbox_type", "gears", "aircon_type", "year", "company__founded"]
|
[
|
||||||
|
"gearbox_type",
|
||||||
|
"gears",
|
||||||
|
"aircon_type",
|
||||||
|
"year",
|
||||||
|
"manufacturer__founded",
|
||||||
|
]
|
||||||
)
|
)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
for car in all_cars:
|
||||||
|
assert all(
|
||||||
|
getattr(car, x) is None
|
||||||
|
for x in ["year", "gearbox_type", "gears", "aircon_type"]
|
||||||
|
)
|
||||||
|
assert car.manufacturer.name == "Toyota"
|
||||||
|
assert car.manufacturer.founded is None
|
||||||
|
|
||||||
|
all_cars = (
|
||||||
|
await Car.objects.select_related("manufacturer")
|
||||||
|
.exclude_fields(
|
||||||
|
{
|
||||||
|
"gearbox_type": ...,
|
||||||
|
"gears": ...,
|
||||||
|
"aircon_type": ...,
|
||||||
|
"year": ...,
|
||||||
|
"manufacturer": {"founded": ...},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
all_cars2 = (
|
||||||
|
await Car.objects.select_related("manufacturer")
|
||||||
|
.exclude_fields(
|
||||||
|
{
|
||||||
|
"gearbox_type": ...,
|
||||||
|
"gears": ...,
|
||||||
|
"aircon_type": ...,
|
||||||
|
"year": ...,
|
||||||
|
"manufacturer": {"founded"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all_cars == all_cars2
|
||||||
|
|
||||||
for car in all_cars:
|
for car in all_cars:
|
||||||
assert all(
|
assert all(
|
||||||
getattr(car, x) is None
|
getattr(car, x) is None
|
||||||
@ -119,7 +162,7 @@ async def test_selecting_subset():
|
|||||||
all_cars_check2 = (
|
all_cars_check2 = (
|
||||||
await Car.objects.select_related("manufacturer")
|
await Car.objects.select_related("manufacturer")
|
||||||
.fields(["id", "name", "manufacturer"])
|
.fields(["id", "name", "manufacturer"])
|
||||||
.exclude_fields("company__founded")
|
.exclude_fields("manufacturer__founded")
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
for car in all_cars_check2:
|
for car in all_cars_check2:
|
||||||
@ -133,5 +176,5 @@ async def test_selecting_subset():
|
|||||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||||
# cannot exclude mandatory model columns - company__name in this example
|
# cannot exclude mandatory model columns - company__name in this example
|
||||||
await Car.objects.select_related("manufacturer").exclude_fields(
|
await Car.objects.select_related("manufacturer").exclude_fields(
|
||||||
["company__name"]
|
["manufacturer__name"]
|
||||||
).all()
|
).all()
|
||||||
|
|||||||
@ -1,374 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import databases
|
|
||||||
import pytest
|
|
||||||
import sqlalchemy
|
|
||||||
|
|
||||||
import ormar
|
|
||||||
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
|
|
||||||
from tests.settings import DATABASE_URL
|
|
||||||
|
|
||||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
|
||||||
metadata = sqlalchemy.MetaData()
|
|
||||||
|
|
||||||
|
|
||||||
class Album(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "albums"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
name: str = ormar.String(max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
class Track(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "tracks"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
album: Optional[Album] = ormar.ForeignKey(Album)
|
|
||||||
title: str = ormar.String(max_length=100)
|
|
||||||
position: int = ormar.Integer()
|
|
||||||
|
|
||||||
|
|
||||||
class Cover(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "covers"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures")
|
|
||||||
title: str = ormar.String(max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
class Organisation(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "org"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
|
|
||||||
|
|
||||||
|
|
||||||
class Team(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "teams"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
org: Optional[Organisation] = ormar.ForeignKey(Organisation)
|
|
||||||
name: str = ormar.String(max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
class Member(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "members"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
team: Optional[Team] = ormar.ForeignKey(Team)
|
|
||||||
email: str = ormar.String(max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True, scope="module")
|
|
||||||
def create_test_database():
|
|
||||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
|
||||||
metadata.drop_all(engine)
|
|
||||||
metadata.create_all(engine)
|
|
||||||
yield
|
|
||||||
metadata.drop_all(engine)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_wrong_query_foreign_key_type():
|
|
||||||
async with database:
|
|
||||||
with pytest.raises(RelationshipInstanceError):
|
|
||||||
Track(title="The Error", album="wrong_pk_type")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_setting_explicitly_empty_relation():
|
|
||||||
async with database:
|
|
||||||
track = Track(album=None, title="The Bird", position=1)
|
|
||||||
assert track.album is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_related_name():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
album = await Album.objects.create(name="Vanilla")
|
|
||||||
await Cover.objects.create(album=album, title="The cover file")
|
|
||||||
assert len(album.cover_pictures) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_model_crud():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
album = Album(name="Jamaica")
|
|
||||||
await album.save()
|
|
||||||
track1 = Track(album=album, title="The Bird", position=1)
|
|
||||||
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
|
|
||||||
track3 = Track(album=album, title="The Waters", position=3)
|
|
||||||
await track1.save()
|
|
||||||
await track2.save()
|
|
||||||
await track3.save()
|
|
||||||
|
|
||||||
track = await Track.objects.get(title="The Bird")
|
|
||||||
assert track.album.pk == album.pk
|
|
||||||
assert isinstance(track.album, ormar.Model)
|
|
||||||
assert track.album.name is None
|
|
||||||
await track.album.load()
|
|
||||||
assert track.album.name == "Jamaica"
|
|
||||||
|
|
||||||
assert len(album.tracks) == 3
|
|
||||||
assert album.tracks[1].title == "Heart don't stand a chance"
|
|
||||||
|
|
||||||
album1 = await Album.objects.get(name="Jamaica")
|
|
||||||
assert album1.pk == album.pk
|
|
||||||
assert album1.tracks == []
|
|
||||||
|
|
||||||
await Track.objects.create(
|
|
||||||
album={"id": track.album.pk}, title="The Bird2", position=4
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_select_related():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
album = Album(name="Malibu")
|
|
||||||
await album.save()
|
|
||||||
track1 = Track(album=album, title="The Bird", position=1)
|
|
||||||
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
|
|
||||||
track3 = Track(album=album, title="The Waters", position=3)
|
|
||||||
await track1.save()
|
|
||||||
await track2.save()
|
|
||||||
await track3.save()
|
|
||||||
|
|
||||||
fantasies = Album(name="Fantasies")
|
|
||||||
await fantasies.save()
|
|
||||||
track4 = Track(album=fantasies, title="Help I'm Alive", position=1)
|
|
||||||
track5 = Track(album=fantasies, title="Sick Muse", position=2)
|
|
||||||
track6 = Track(album=fantasies, title="Satellite Mind", position=3)
|
|
||||||
await track4.save()
|
|
||||||
await track5.save()
|
|
||||||
await track6.save()
|
|
||||||
|
|
||||||
track = await Track.objects.select_related("album").get(title="The Bird")
|
|
||||||
assert track.album.name == "Malibu"
|
|
||||||
|
|
||||||
tracks = await Track.objects.select_related("album").all()
|
|
||||||
assert len(tracks) == 6
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_model_removal_from_relations():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
album = Album(name="Chichi")
|
|
||||||
await album.save()
|
|
||||||
track1 = Track(album=album, title="The Birdman", position=1)
|
|
||||||
track2 = Track(album=album, title="Superman", position=2)
|
|
||||||
track3 = Track(album=album, title="Wonder Woman", position=3)
|
|
||||||
await track1.save()
|
|
||||||
await track2.save()
|
|
||||||
await track3.save()
|
|
||||||
|
|
||||||
assert len(album.tracks) == 3
|
|
||||||
await album.tracks.remove(track1)
|
|
||||||
assert len(album.tracks) == 2
|
|
||||||
assert track1.album is None
|
|
||||||
|
|
||||||
await track1.update()
|
|
||||||
track1 = await Track.objects.get(title="The Birdman")
|
|
||||||
assert track1.album is None
|
|
||||||
|
|
||||||
await album.tracks.add(track1)
|
|
||||||
assert len(album.tracks) == 3
|
|
||||||
assert track1.album == album
|
|
||||||
|
|
||||||
await track1.update()
|
|
||||||
track1 = await Track.objects.select_related("album__tracks").get(
|
|
||||||
title="The Birdman"
|
|
||||||
)
|
|
||||||
album = await Album.objects.select_related("tracks").get(name="Chichi")
|
|
||||||
assert track1.album == album
|
|
||||||
|
|
||||||
track1.remove(album)
|
|
||||||
assert track1.album is None
|
|
||||||
assert len(album.tracks) == 2
|
|
||||||
|
|
||||||
track2.remove(album)
|
|
||||||
assert track2.album is None
|
|
||||||
assert len(album.tracks) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_fk_filter():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
malibu = Album(name="Malibu%")
|
|
||||||
await malibu.save()
|
|
||||||
await Track.objects.create(album=malibu, title="The Bird", position=1)
|
|
||||||
await Track.objects.create(
|
|
||||||
album=malibu, title="Heart don't stand a chance", position=2
|
|
||||||
)
|
|
||||||
await Track.objects.create(album=malibu, title="The Waters", position=3)
|
|
||||||
|
|
||||||
fantasies = await Album.objects.create(name="Fantasies")
|
|
||||||
await Track.objects.create(
|
|
||||||
album=fantasies, title="Help I'm Alive", position=1
|
|
||||||
)
|
|
||||||
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
|
|
||||||
await Track.objects.create(
|
|
||||||
album=fantasies, title="Satellite Mind", position=3
|
|
||||||
)
|
|
||||||
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.select_related("album")
|
|
||||||
.filter(album__name="Fantasies")
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Fantasies"
|
|
||||||
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.select_related("album")
|
|
||||||
.filter(album__name__icontains="fan")
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Fantasies"
|
|
||||||
|
|
||||||
tracks = await Track.objects.filter(album__name__contains="Fan").all()
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Fantasies"
|
|
||||||
|
|
||||||
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
|
|
||||||
assert len(tracks) == 3
|
|
||||||
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.filter(album=malibu).select_related("album").all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Malibu%"
|
|
||||||
|
|
||||||
tracks = await Track.objects.select_related("album").all(album=malibu)
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Malibu%"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multiple_fk():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
acme = await Organisation.objects.create(ident="ACME Ltd")
|
|
||||||
red_team = await Team.objects.create(org=acme, name="Red Team")
|
|
||||||
blue_team = await Team.objects.create(org=acme, name="Blue Team")
|
|
||||||
await Member.objects.create(team=red_team, email="a@example.org")
|
|
||||||
await Member.objects.create(team=red_team, email="b@example.org")
|
|
||||||
await Member.objects.create(team=blue_team, email="c@example.org")
|
|
||||||
await Member.objects.create(team=blue_team, email="d@example.org")
|
|
||||||
|
|
||||||
other = await Organisation.objects.create(ident="Other ltd")
|
|
||||||
team = await Team.objects.create(org=other, name="Green Team")
|
|
||||||
await Member.objects.create(team=team, email="e@example.org")
|
|
||||||
|
|
||||||
members = (
|
|
||||||
await Member.objects.select_related("team__org")
|
|
||||||
.filter(team__org__ident="ACME Ltd")
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
assert len(members) == 4
|
|
||||||
for member in members:
|
|
||||||
assert member.team.org.ident == "ACME Ltd"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_wrong_choices():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await Organisation.objects.create(ident="Test 1")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pk_filter():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
fantasies = await Album.objects.create(name="Test")
|
|
||||||
track = await Track.objects.create(
|
|
||||||
album=fantasies, title="Test1", position=1
|
|
||||||
)
|
|
||||||
await Track.objects.create(album=fantasies, title="Test2", position=2)
|
|
||||||
await Track.objects.create(album=fantasies, title="Test3", position=3)
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.select_related("album").filter(pk=track.pk).all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 1
|
|
||||||
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.select_related("album")
|
|
||||||
.filter(position=2, album__name="Test")
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_limit_and_offset():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
fantasies = await Album.objects.create(name="Limitless")
|
|
||||||
await Track.objects.create(
|
|
||||||
id=None, album=fantasies, title="Sample", position=1
|
|
||||||
)
|
|
||||||
await Track.objects.create(album=fantasies, title="Sample2", position=2)
|
|
||||||
await Track.objects.create(album=fantasies, title="Sample3", position=3)
|
|
||||||
|
|
||||||
tracks = await Track.objects.limit(1).all()
|
|
||||||
assert len(tracks) == 1
|
|
||||||
assert tracks[0].title == "Sample"
|
|
||||||
|
|
||||||
tracks = await Track.objects.limit(1).offset(1).all()
|
|
||||||
assert len(tracks) == 1
|
|
||||||
assert tracks[0].title == "Sample2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_exceptions():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
fantasies = await Album.objects.create(name="Test")
|
|
||||||
|
|
||||||
with pytest.raises(NoMatch):
|
|
||||||
await Album.objects.get(name="Test2")
|
|
||||||
|
|
||||||
await Track.objects.create(album=fantasies, title="Test1", position=1)
|
|
||||||
await Track.objects.create(album=fantasies, title="Test2", position=2)
|
|
||||||
await Track.objects.create(album=fantasies, title="Test3", position=3)
|
|
||||||
with pytest.raises(MultipleMatches):
|
|
||||||
await Track.objects.select_related("album").get(album=fantasies)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_wrong_model_passed_as_fk():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
with pytest.raises(RelationshipInstanceError):
|
|
||||||
org = await Organisation.objects.create(ident="ACME Ltd")
|
|
||||||
await Track.objects.create(album=org, title="Test1", position=1)
|
|
||||||
98
tests/test_queryset_utils.py
Normal file
98
tests/test_queryset_utils.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from ormar.models.excludable import Excludable
|
||||||
|
from ormar.queryset.utils import translate_list_to_dict, update_dict_from_list, update
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_excludable():
|
||||||
|
assert Excludable.is_included(None, "key") # all fields included if empty
|
||||||
|
assert not Excludable.is_excluded(None, "key") # none field excluded if empty
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_to_dict_translation():
|
||||||
|
tet_list = ["aa", "bb", "cc__aa", "cc__bb", "cc__aa__xx", "cc__aa__yy"]
|
||||||
|
test = translate_list_to_dict(tet_list)
|
||||||
|
assert test == {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": Ellipsis,
|
||||||
|
"cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_updating_dict_with_list():
|
||||||
|
curr_dict = {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": Ellipsis,
|
||||||
|
"cc": {"aa": {"xx": Ellipsis, "yy": Ellipsis}, "bb": Ellipsis},
|
||||||
|
}
|
||||||
|
list_to_update = ["ee", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"]
|
||||||
|
test = update_dict_from_list(curr_dict, list_to_update)
|
||||||
|
assert test == {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": {"cc": Ellipsis},
|
||||||
|
"cc": {
|
||||||
|
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
|
||||||
|
"bb": Ellipsis,
|
||||||
|
},
|
||||||
|
"ee": Ellipsis,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_updating_dict_inc_set_with_list():
|
||||||
|
curr_dict = {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": Ellipsis,
|
||||||
|
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
|
||||||
|
}
|
||||||
|
list_to_update = ["uu", "bb__cc", "cc__aa__xx__oo", "cc__aa__oo"]
|
||||||
|
test = update_dict_from_list(curr_dict, list_to_update)
|
||||||
|
assert test == {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": {"cc": Ellipsis},
|
||||||
|
"cc": {
|
||||||
|
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
|
||||||
|
"bb": Ellipsis,
|
||||||
|
},
|
||||||
|
"uu": Ellipsis,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_updating_dict_inc_set_with_dict():
|
||||||
|
curr_dict = {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": Ellipsis,
|
||||||
|
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
|
||||||
|
}
|
||||||
|
dict_to_update = {
|
||||||
|
"uu": Ellipsis,
|
||||||
|
"bb": {"cc", "dd"},
|
||||||
|
"cc": {"aa": {"xx": {"oo": Ellipsis}, "oo": Ellipsis}},
|
||||||
|
}
|
||||||
|
test = update(curr_dict, dict_to_update)
|
||||||
|
assert test == {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": {"cc", "dd"},
|
||||||
|
"cc": {
|
||||||
|
"aa": {"xx": {"oo": Ellipsis}, "yy": Ellipsis, "oo": Ellipsis},
|
||||||
|
"bb": Ellipsis,
|
||||||
|
},
|
||||||
|
"uu": Ellipsis,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_updating_dict_inc_set_with_dict_inc_set():
|
||||||
|
curr_dict = {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": Ellipsis,
|
||||||
|
"cc": {"aa": {"xx", "yy"}, "bb": Ellipsis},
|
||||||
|
}
|
||||||
|
dict_to_update = {
|
||||||
|
"uu": Ellipsis,
|
||||||
|
"bb": {"cc", "dd"},
|
||||||
|
"cc": {"aa": {"xx", "oo", "zz", "ii"}},
|
||||||
|
}
|
||||||
|
test = update(curr_dict, dict_to_update)
|
||||||
|
assert test == {
|
||||||
|
"aa": Ellipsis,
|
||||||
|
"bb": {"cc", "dd"},
|
||||||
|
"cc": {"aa": {"xx", "yy", "oo", "zz", "ii"}, "bb": Ellipsis},
|
||||||
|
"uu": Ellipsis,
|
||||||
|
}
|
||||||
@ -1,4 +1,5 @@
|
|||||||
from typing import Optional
|
import itertools
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import pydantic
|
import pydantic
|
||||||
@ -12,6 +13,35 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
|
|||||||
metadata = sqlalchemy.MetaData()
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
class NickNames(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "nicks"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
|
||||||
|
is_lame: bool = ormar.Boolean(nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
class NicksHq(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "nicks_x_hq"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
|
||||||
|
class HQ(ormar.Model):
|
||||||
|
class Meta:
|
||||||
|
tablename = "hqs"
|
||||||
|
metadata = metadata
|
||||||
|
database = database
|
||||||
|
|
||||||
|
id: int = ormar.Integer(primary_key=True)
|
||||||
|
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
|
||||||
|
nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq)
|
||||||
|
|
||||||
|
|
||||||
class Company(ormar.Model):
|
class Company(ormar.Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
tablename = "companies"
|
tablename = "companies"
|
||||||
@ -21,6 +51,7 @@ class Company(ormar.Model):
|
|||||||
id: int = ormar.Integer(primary_key=True)
|
id: int = ormar.Integer(primary_key=True)
|
||||||
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
|
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
|
||||||
founded: int = ormar.Integer(nullable=True)
|
founded: int = ormar.Integer(nullable=True)
|
||||||
|
hq: HQ = ormar.ForeignKey(HQ)
|
||||||
|
|
||||||
|
|
||||||
class Car(ormar.Model):
|
class Car(ormar.Model):
|
||||||
@ -51,7 +82,14 @@ def create_test_database():
|
|||||||
async def test_selecting_subset():
|
async def test_selecting_subset():
|
||||||
async with database:
|
async with database:
|
||||||
async with database.transaction(force_rollback=True):
|
async with database.transaction(force_rollback=True):
|
||||||
toyota = await Company.objects.create(name="Toyota", founded=1937)
|
nick1 = await NickNames.objects.create(name="Nippon", is_lame=False)
|
||||||
|
nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True)
|
||||||
|
hq = await HQ.objects.create(name="Japan")
|
||||||
|
await hq.nicks.add(nick1)
|
||||||
|
await hq.nicks.add(nick2)
|
||||||
|
|
||||||
|
toyota = await Company.objects.create(name="Toyota", founded=1937, hq=hq)
|
||||||
|
|
||||||
await Car.objects.create(
|
await Car.objects.create(
|
||||||
manufacturer=toyota,
|
manufacturer=toyota,
|
||||||
name="Corolla",
|
name="Corolla",
|
||||||
@ -78,17 +116,66 @@ async def test_selecting_subset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
all_cars = (
|
all_cars = (
|
||||||
await Car.objects.select_related("manufacturer")
|
await Car.objects.select_related(
|
||||||
.fields(["id", "name", "company__name"])
|
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
|
||||||
|
)
|
||||||
|
.fields(
|
||||||
|
[
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"manufacturer__name",
|
||||||
|
"manufacturer__hq__name",
|
||||||
|
"manufacturer__hq__nicks__name",
|
||||||
|
]
|
||||||
|
)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
for car in all_cars:
|
|
||||||
|
all_cars2 = (
|
||||||
|
await Car.objects.select_related(
|
||||||
|
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
|
||||||
|
)
|
||||||
|
.fields(
|
||||||
|
{
|
||||||
|
"id": ...,
|
||||||
|
"name": ...,
|
||||||
|
"manufacturer": {
|
||||||
|
"name": ...,
|
||||||
|
"hq": {"name": ..., "nicks": {"name": ...}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
all_cars3 = (
|
||||||
|
await Car.objects.select_related(
|
||||||
|
["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
|
||||||
|
)
|
||||||
|
.fields(
|
||||||
|
{
|
||||||
|
"id": ...,
|
||||||
|
"name": ...,
|
||||||
|
"manufacturer": {
|
||||||
|
"name": ...,
|
||||||
|
"hq": {"name": ..., "nicks": {"name"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
assert all_cars3 == all_cars
|
||||||
|
|
||||||
|
for car in itertools.chain(all_cars, all_cars2):
|
||||||
assert all(
|
assert all(
|
||||||
getattr(car, x) is None
|
getattr(car, x) is None
|
||||||
for x in ["year", "gearbox_type", "gears", "aircon_type"]
|
for x in ["year", "gearbox_type", "gears", "aircon_type"]
|
||||||
)
|
)
|
||||||
assert car.manufacturer.name == "Toyota"
|
assert car.manufacturer.name == "Toyota"
|
||||||
assert car.manufacturer.founded is None
|
assert car.manufacturer.founded is None
|
||||||
|
assert car.manufacturer.hq.name == "Japan"
|
||||||
|
assert len(car.manufacturer.hq.nicks) == 2
|
||||||
|
assert car.manufacturer.hq.nicks[0].is_lame is None
|
||||||
|
|
||||||
all_cars = (
|
all_cars = (
|
||||||
await Car.objects.select_related("manufacturer")
|
await Car.objects.select_related("manufacturer")
|
||||||
@ -103,9 +190,16 @@ async def test_selecting_subset():
|
|||||||
)
|
)
|
||||||
assert car.manufacturer.name == "Toyota"
|
assert car.manufacturer.name == "Toyota"
|
||||||
assert car.manufacturer.founded == 1937
|
assert car.manufacturer.founded == 1937
|
||||||
|
assert car.manufacturer.hq.name is None
|
||||||
|
|
||||||
all_cars_check = await Car.objects.select_related("manufacturer").all()
|
all_cars_check = await Car.objects.select_related("manufacturer").all()
|
||||||
for car in all_cars_check:
|
all_cars_with_whole_nested = (
|
||||||
|
await Car.objects.select_related("manufacturer")
|
||||||
|
.fields(["id", "name", "year", "gearbox_type", "gears", "aircon_type"])
|
||||||
|
.fields({"manufacturer": ...})
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
for car in itertools.chain(all_cars_check, all_cars_with_whole_nested):
|
||||||
assert all(
|
assert all(
|
||||||
getattr(car, x) is not None
|
getattr(car, x) is not None
|
||||||
for x in ["year", "gearbox_type", "gears", "aircon_type"]
|
for x in ["year", "gearbox_type", "gears", "aircon_type"]
|
||||||
@ -113,8 +207,20 @@ async def test_selecting_subset():
|
|||||||
assert car.manufacturer.name == "Toyota"
|
assert car.manufacturer.name == "Toyota"
|
||||||
assert car.manufacturer.founded == 1937
|
assert car.manufacturer.founded == 1937
|
||||||
|
|
||||||
|
all_cars_dummy = (
|
||||||
|
await Car.objects.select_related("manufacturer")
|
||||||
|
.fields(["id", "name", "year", "gearbox_type", "gears", "aircon_type"])
|
||||||
|
.fields({"manufacturer": ...})
|
||||||
|
.exclude_fields({"manufacturer": ...})
|
||||||
|
.fields({"manufacturer": {"name"}})
|
||||||
|
.exclude_fields({"manufacturer__founded"})
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all_cars_dummy[0].manufacturer.founded is None
|
||||||
|
|
||||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||||
# cannot exclude mandatory model columns - company__name in this example
|
# cannot exclude mandatory model columns - company__name in this example
|
||||||
await Car.objects.select_related("manufacturer").fields(
|
await Car.objects.select_related("manufacturer").fields(
|
||||||
["id", "name", "company__founded"]
|
["id", "name", "manufacturer__founded"]
|
||||||
).all()
|
).all()
|
||||||
|
|||||||
Reference in New Issue
Block a user