fix dumping to dict with include and exclude

This commit is contained in:
collerek
2020-11-27 18:51:40 +01:00
parent 9631f6d1d5
commit 0ed4ef4833
9 changed files with 634 additions and 140 deletions

View File

@ -2,19 +2,25 @@ from typing import Dict, Set, Union
class Excludable: class Excludable:
@staticmethod
def get_child(
items: Union[Set, Dict, None], key: str = None
) -> Union[Set, Dict, None]:
if isinstance(items, dict):
return items.get(key, {})
return items
@staticmethod @staticmethod
def get_excluded( def get_excluded(
exclude: Union[Set, Dict, None], key: str = None exclude: Union[Set, Dict, None], key: str = None
) -> Union[Set, Dict, None]: ) -> Union[Set, Dict, None]:
if isinstance(exclude, dict): return Excludable.get_child(items=exclude, key=key)
return exclude.get(key, {})
return exclude
@staticmethod @staticmethod
def get_included( def get_included(
include: Union[Set, Dict, None], key: str = None include: Union[Set, Dict, None], key: str = None
) -> Union[Set, Dict, None]: ) -> Union[Set, Dict, None]:
return Excludable.get_excluded(exclude=include, key=key) return Excludable.get_child(items=include, key=key)
@staticmethod @staticmethod
def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool: def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool:

View File

@ -1,10 +1,12 @@
import inspect import inspect
from collections import OrderedDict from collections import OrderedDict
from typing import ( from typing import (
AbstractSet,
Any, Any,
Callable, Callable,
Dict, Dict,
List, List,
Mapping,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -16,6 +18,7 @@ from typing import (
) )
from ormar.exceptions import ModelPersistenceError, RelationshipInstanceError from ormar.exceptions import ModelPersistenceError, RelationshipInstanceError
from ormar.queryset.utils import translate_list_to_dict, update
try: try:
import orjson as json import orjson as json
@ -32,6 +35,9 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import NewBaseModel from ormar.models import NewBaseModel
T = TypeVar("T", bound=Model) T = TypeVar("T", bound=Model)
IntStr = Union[int, str]
AbstractSetIntStr = AbstractSet[IntStr]
MappingIntStrAny = Mapping[IntStr, Any]
Field = TypeVar("Field", bound=BaseField) Field = TypeVar("Field", bound=BaseField)
@ -203,6 +209,21 @@ class ModelTableProxy:
} }
return related_names return related_names
@classmethod
def _update_excluded_with_related_not_required(
cls,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None],
nested: bool = False,
) -> Union[Set, Dict]:
exclude = exclude or {}
related_set = cls._exclude_related_names_not_required(nested=nested)
if isinstance(exclude, set):
exclude.union(related_set)
else:
related_dict = translate_list_to_dict(related_set)
exclude = update(related_dict, exclude)
return exclude
def _extract_model_db_fields(self) -> Dict: def _extract_model_db_fields(self) -> Dict:
self_fields = self._extract_own_model_fields() self_fields = self._extract_own_model_fields()
self_fields = { self_fields = {

View File

@ -27,6 +27,7 @@ from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.excludable import Excludable from ormar.models.excludable import Excludable
from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy from ormar.models.modelproxy import ModelTableProxy
from ormar.queryset.utils import translate_list_to_dict
from ormar.relations.alias_manager import AliasManager from ormar.relations.alias_manager import AliasManager
from ormar.relations.relation_manager import RelationsManager from ormar.relations.relation_manager import RelationsManager
@ -213,9 +214,7 @@ class NewBaseModel(
@classmethod @classmethod
def get_properties( def get_properties(
cls, cls, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None]
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
) -> List[str]: ) -> List[str]:
if isinstance(cls._props, list): if isinstance(cls._props, list):
props = cls._props props = cls._props
@ -234,11 +233,76 @@ class NewBaseModel(
props = [prop for prop in props if prop not in exclude] props = [prop for prop in props if prop not in exclude]
return props return props
def dict( # noqa A003 def _get_related_not_excluded_fields(
self, include: Optional[Dict], exclude: Optional[Dict],
) -> List:
fields = [field for field in self.extract_related_names()]
if include:
fields = [field for field in fields if field in include]
if exclude:
fields = [
field
for field in fields
if field not in exclude or exclude.get(field) is not Ellipsis
]
return fields
@staticmethod
def _extract_nested_models_from_list(
models: List, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None],
) -> List:
result = []
for model in models:
try:
result.append(
model.dict(nested=True, include=include, exclude=exclude,)
)
except ReferenceError: # pragma no cover
continue
return result
@staticmethod
def _skip_ellipsis(
items: Union[Set, Dict, None], key: str
) -> Union[Set, Dict, None]:
result = Excludable.get_child(items, key)
return result if result is not Ellipsis else None
def _extract_nested_models( # noqa: CCR001
self,
nested: bool,
dict_instance: Dict,
include: Optional[Dict],
exclude: Optional[Dict],
) -> Dict:
fields = self._get_related_not_excluded_fields(include=include, exclude=exclude)
for field in fields:
if self.Meta.model_fields[field].virtual and nested:
continue
nested_model = getattr(self, field)
if isinstance(nested_model, list):
dict_instance[field] = self._extract_nested_models_from_list(
models=nested_model,
include=self._skip_ellipsis(include, field),
exclude=self._skip_ellipsis(exclude, field),
)
elif nested_model is not None:
dict_instance[field] = nested_model.dict(
nested=True,
include=self._skip_ellipsis(include, field),
exclude=self._skip_ellipsis(exclude, field),
)
else:
dict_instance[field] = None
return dict_instance
def dict( # type: ignore # noqa A003
self, self,
*, *,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, include: Union[Set, Dict] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, exclude: Union[Set, Dict] = None,
by_alias: bool = False, by_alias: bool = False,
skip_defaults: bool = None, skip_defaults: bool = None,
exclude_unset: bool = False, exclude_unset: bool = False,
@ -248,30 +312,25 @@ class NewBaseModel(
) -> "DictStrAny": # noqa: A003' ) -> "DictStrAny": # noqa: A003'
dict_instance = super().dict( dict_instance = super().dict(
include=include, include=include,
exclude=self._exclude_related_names_not_required(nested), exclude=self._update_excluded_with_related_not_required(exclude, nested),
by_alias=by_alias, by_alias=by_alias,
skip_defaults=skip_defaults, skip_defaults=skip_defaults,
exclude_unset=exclude_unset, exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults, exclude_defaults=exclude_defaults,
exclude_none=exclude_none, exclude_none=exclude_none,
) )
for field in self.extract_related_names():
nested_model = getattr(self, field)
if self.Meta.model_fields[field].virtual and nested: if include and isinstance(include, Set):
continue include = translate_list_to_dict(include)
if isinstance(nested_model, list): if exclude and isinstance(exclude, Set):
result = [] exclude = translate_list_to_dict(exclude)
for model in nested_model:
try: dict_instance = self._extract_nested_models(
result.append(model.dict(nested=True)) nested=nested,
except ReferenceError: # pragma no cover dict_instance=dict_instance,
continue include=include, # type: ignore
dict_instance[field] = result exclude=exclude, # type: ignore
elif nested_model is not None: )
dict_instance[field] = nested_model.dict(nested=True)
else:
dict_instance[field] = None
# include model properties as fields # include model properties as fields
props = self.get_properties(include=include, exclude=exclude) props = self.get_properties(include=include, exclude=exclude)

View File

@ -1,6 +1,15 @@
import collections.abc import collections.abc
import copy import copy
from typing import Any, Dict, List, Sequence, Set, TYPE_CHECKING, Type, Union from typing import (
Any,
Dict,
List,
Sequence,
Set,
TYPE_CHECKING,
Type,
Union,
)
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model

View File

@ -0,0 +1,139 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
class User(ormar.Model):
class Meta:
tablename: str = "users"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
email: str = ormar.String(max_length=255, nullable=False)
password: str = ormar.String(max_length=255, nullable=True)
first_name: str = ormar.String(max_length=255, nullable=False)
class Tier(ormar.Model):
class Meta:
tablename = "tiers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Category(ormar.Model):
class Meta:
tablename = "categories"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
tier: Optional[Tier] = ormar.ForeignKey(Tier)
class Item(ormar.Model):
class Meta:
tablename = "items"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
category: Optional[Category] = ormar.ForeignKey(Category, nullable=True)
created_by: Optional[User] = ormar.ForeignKey(User)
@pytest.fixture(autouse=True, scope="module")
def sample_data():
user = User(email="test@test.com", password="ijacids7^*&", first_name="Anna")
tier = Tier(name="Tier I")
category1 = Category(name="Toys", tier=tier)
category2 = Category(name="Weapons", tier=tier)
item1 = Item(name="Teddy Bear", category=category1, created_by=user)
item2 = Item(name="M16", category=category2, created_by=user)
return item1, item2
def test_dumping_to_dict_no_exclusion(sample_data):
item1, item2 = sample_data
dict1 = item1.dict()
assert dict1["name"] == "Teddy Bear"
assert dict1["category"]["name"] == "Toys"
assert dict1["category"]["tier"]['name'] == "Tier I"
assert dict1["created_by"]["email"] == "test@test.com"
dict2 = item2.dict()
assert dict2["name"] == "M16"
assert dict2["category"]["name"] == "Weapons"
assert dict2["created_by"]["email"] == "test@test.com"
def test_dumping_to_dict_exclude_set(sample_data):
item1, item2 = sample_data
dict3 = item2.dict(exclude={"name"})
assert "name" not in dict3
assert dict3["category"]["name"] == "Weapons"
assert dict3["created_by"]["email"] == "test@test.com"
dict4 = item2.dict(exclude={"category"})
assert dict4["name"] == "M16"
assert "category" not in dict4
assert dict4["created_by"]["email"] == "test@test.com"
dict5 = item2.dict(exclude={"category", "name"})
assert "name" not in dict5
assert "category" not in dict5
assert dict5["created_by"]["email"] == "test@test.com"
def test_dumping_to_dict_exclude_dict(sample_data):
item1, item2 = sample_data
dict6 = item2.dict(exclude={"category": {"name"}, "name": ...})
assert "name" not in dict6
assert "category" in dict6
assert "name" not in dict6["category"]
assert dict6["created_by"]["email"] == "test@test.com"
def test_dumping_to_dict_exclude_nested_dict(sample_data):
item1, item2 = sample_data
dict1 = item2.dict(exclude={"category": {"tier": {"name"}}, "name": ...})
assert "name" not in dict1
assert "category" in dict1
assert dict1["category"]['name'] == 'Weapons'
assert dict1["created_by"]["email"] == "test@test.com"
assert dict1["category"]["tier"].get('name') is None
def test_dumping_to_dict_exclude_and_include_nested_dict(sample_data):
item1, item2 = sample_data
dict1 = item2.dict(exclude={"category": {"tier": {"name"}}},
include={'name', 'category'})
assert dict1.get('name') == 'M16'
assert "category" in dict1
assert dict1["category"]['name'] == 'Weapons'
assert "created_by" not in dict1
assert dict1["category"]["tier"].get('name') is None
dict2 = item1.dict(exclude={"id": ...},
include={'name': ..., 'category': {'name': ..., 'tier': {'id'}}})
assert dict2.get('name') == 'Teddy Bear'
assert dict2.get('id') is None # models not saved
assert dict2["category"]['name'] == 'Toys'
assert "created_by" not in dict1
assert dict1["category"]["tier"].get('name') is None
assert dict1["category"]["tier"]['id'] is None

View File

@ -0,0 +1,178 @@
import datetime
import string
import random
import databases
import pydantic
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
# note that you can set orm_mode here
# and in this case UserSchema become unnecessary
class UserBase(pydantic.BaseModel):
class Config:
orm_mode = True
email: str
first_name: str
last_name: str
class UserCreateSchema(UserBase):
password: str
category: str
class UserSchema(UserBase):
class Config:
orm_mode = True
def gen_pass():
choices = string.ascii_letters + string.digits + "!@#$%^&*()"
return "".join(random.choice(choices) for _ in range(20))
class RandomModel(ormar.Model):
class Meta:
tablename: str = "random_users"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
password: str = ormar.String(max_length=255, default=gen_pass)
first_name: str = ormar.String(max_length=255, default='John')
last_name: str = ormar.String(max_length=255)
created_date: datetime.datetime = ormar.DateTime(server_default=sqlalchemy.func.now())
class User(ormar.Model):
class Meta:
tablename: str = "users"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
email: str = ormar.String(max_length=255, nullable=False)
password: str = ormar.String(max_length=255, nullable=True)
first_name: str = ormar.String(max_length=255, nullable=False)
last_name: str = ormar.String(max_length=255, nullable=False)
category: str = ormar.String(max_length=255, nullable=True)
class User2(ormar.Model):
class Meta:
tablename: str = "users2"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
email: str = ormar.String(max_length=255, nullable=False)
password: str = ormar.String(max_length=255, nullable=False)
first_name: str = ormar.String(max_length=255, nullable=False)
last_name: str = ormar.String(max_length=255, nullable=False)
category: str = ormar.String(max_length=255, nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@app.post("/users/", response_model=User, response_model_exclude={"password"})
async def create_user(user: User):
return await user.save()
@app.post("/users2/", response_model=User)
async def create_user2(user: User):
user = await user.save()
return user.dict(exclude={'password'})
@app.post("/users3/", response_model=UserBase)
async def create_user3(user: User2):
return await user.save()
@app.post("/users4/")
async def create_user4(user: User2):
user = await user.save()
return user.dict(exclude={'password'})
@app.post("/random/", response_model=RandomModel)
async def create_user5(user: RandomModel):
return await user.save()
def test_all_endpoints():
client = TestClient(app)
with client as client:
user = {
"email": "test@domain.com",
"password": "^*^%A*DA*IAAA",
"first_name": "John",
"last_name": "Doe",
}
response = client.post("/users/", json=user)
created_user = User(**response.json())
assert created_user.pk is not None
assert created_user.password is None
user2 = {
"email": "test@domain.com",
"first_name": "John",
"last_name": "Doe",
}
response = client.post("/users/", json=user2)
created_user = User(**response.json())
assert created_user.pk is not None
assert created_user.password is None
response = client.post("/users2/", json=user)
created_user2 = User(**response.json())
assert created_user2.pk is not None
assert created_user2.password is None
# response has only 3 fields from UserBase
response = client.post("/users3/", json=user)
assert list(response.json().keys()) == ['email', 'first_name', 'last_name']
response = client.post("/users4/", json=user)
assert list(response.json().keys()) == ['id', 'email', 'first_name', 'last_name', 'category']
user3 = {
'last_name': 'Test'
}
response = client.post("/random/", json=user3)
assert list(response.json().keys()) == ['id', 'password', 'first_name', 'last_name', 'created_date']

View File

@ -408,9 +408,11 @@ async def test_bulk_update_model_with_children():
album=best_seller, title="t4", position=1, play_count=500 album=best_seller, title="t4", position=1, play_count=500
) )
tracks = await Track.objects.select_related("album").filter( tracks = (
play_count__gt=10 await Track.objects.select_related("album")
).all() .filter(play_count__gt=10)
.all()
)
best_seller_albums = {} best_seller_albums = {}
for track in tracks: for track in tracks:
album = track.album album = track.album
@ -421,5 +423,7 @@ async def test_bulk_update_model_with_children():
await Album.objects.bulk_update( await Album.objects.bulk_update(
best_seller_albums.values(), columns=["is_best_seller"] best_seller_albums.values(), columns=["is_best_seller"]
) )
best_seller_albums_db = await Album.objects.filter(is_best_seller=True).all() best_seller_albums_db = await Album.objects.filter(
is_best_seller=True
).all()
assert len(best_seller_albums_db) == 2 assert len(best_seller_albums_db) == 2

View File

@ -17,7 +17,7 @@ class RandomSet(ormar.Model):
metadata = metadata metadata = metadata
database = database database = database
id: int = ormar.Integer(name='random_id', primary_key=True) id: int = ormar.Integer(name="random_id", primary_key=True)
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
@ -28,7 +28,7 @@ class Tonation(ormar.Model):
database = database database = database
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(name='tonation_name', max_length=100) name: str = ormar.String(name="tonation_name", max_length=100)
rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet) rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet)
@ -38,7 +38,7 @@ class Division(ormar.Model):
metadata = metadata metadata = metadata
database = database database = database
id: int = ormar.Integer(name='division_id', primary_key=True) id: int = ormar.Integer(name="division_id", primary_key=True)
name: str = ormar.String(max_length=100, nullable=True) name: str = ormar.String(max_length=100, nullable=True)
@ -77,11 +77,11 @@ class Track(ormar.Model):
metadata = metadata metadata = metadata
database = database database = database
id: int = ormar.Integer(name='track_id', primary_key=True) id: int = ormar.Integer(name="track_id", primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album) album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100) title: str = ormar.String(max_length=100)
position: int = ormar.Integer() position: int = ormar.Integer()
tonation: Optional[Tonation] = ormar.ForeignKey(Tonation, name='tonation_id') tonation: Optional[Tonation] = ormar.ForeignKey(Tonation, name="tonation_id")
class Cover(ormar.Model): class Cover(ormar.Model):
@ -91,7 +91,9 @@ class Cover(ormar.Model):
database = database database = database
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures", name='album_id') album: Optional[Album] = ormar.ForeignKey(
Album, related_name="cover_pictures", name="album_id"
)
title: str = ormar.String(max_length=100) title: str = ormar.String(max_length=100)
artist: str = ormar.String(max_length=200, nullable=True) artist: str = ormar.String(max_length=200, nullable=True)
@ -111,42 +113,71 @@ async def test_prefetch_related():
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
album = Album(name="Malibu") album = Album(name="Malibu")
await album.save() await album.save()
ton1 = await Tonation.objects.create(name='B-mol') ton1 = await Tonation.objects.create(name="B-mol")
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) await Track.objects.create(
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) album=album, title="The Bird", position=1, tonation=ton1
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) )
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') await Track.objects.create(
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') album=album,
title="Heart don't stand a chance",
position=2,
tonation=ton1,
)
await Track.objects.create(
album=album, title="The Waters", position=3, tonation=ton1
)
await Cover.objects.create(title="Cover1", album=album, artist="Artist 1")
await Cover.objects.create(title="Cover2", album=album, artist="Artist 2")
fantasies = Album(name="Fantasies") fantasies = Album(name="Fantasies")
await fantasies.save() await fantasies.save()
await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) await Track.objects.create(
album=fantasies, title="Help I'm Alive", position=1
)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2) await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) await Track.objects.create(
await Cover.objects.create(title='Cover3', album=fantasies, artist='Artist 3') album=fantasies, title="Satellite Mind", position=3
await Cover.objects.create(title='Cover4', album=fantasies, artist='Artist 4') )
await Cover.objects.create(
title="Cover3", album=fantasies, artist="Artist 3"
)
await Cover.objects.create(
title="Cover4", album=fantasies, artist="Artist 4"
)
album = await Album.objects.filter(name='Malibu').prefetch_related( album = (
['tracks__tonation', 'cover_pictures']).get() await Album.objects.filter(name="Malibu")
.prefetch_related(["tracks__tonation", "cover_pictures"])
.get()
)
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert album.tracks[0].title == 'The Bird' assert album.tracks[0].title == "The Bird"
assert len(album.cover_pictures) == 2 assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].title == 'Cover1' assert album.cover_pictures[0].title == "Cover1"
assert album.tracks[0].tonation.name == album.tracks[2].tonation.name == 'B-mol' assert (
album.tracks[0].tonation.name
== album.tracks[2].tonation.name
== "B-mol"
)
albums = await Album.objects.prefetch_related('tracks').all() albums = await Album.objects.prefetch_related("tracks").all()
assert len(albums[0].tracks) == 3 assert len(albums[0].tracks) == 3
assert len(albums[1].tracks) == 3 assert len(albums[1].tracks) == 3
assert albums[0].tracks[0].title == "The Bird" assert albums[0].tracks[0].title == "The Bird"
assert albums[1].tracks[0].title == "Help I'm Alive" assert albums[1].tracks[0].title == "Help I'm Alive"
track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(
title="The Bird"
)
assert track.album.name == "Malibu" assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2 assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1' assert track.album.cover_pictures[0].artist == "Artist 1"
track = await Track.objects.prefetch_related(["album__cover_pictures"]).exclude_fields( track = (
'album__cover_pictures__artist').get(title="The Bird") await Track.objects.prefetch_related(["album__cover_pictures"])
.exclude_fields("album__cover_pictures__artist")
.get(title="The Bird")
)
assert track.album.name == "Malibu" assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2 assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist is None assert track.album.cover_pictures[0].artist is None
@ -159,29 +190,32 @@ async def test_prefetch_related():
async def test_prefetch_related_with_many_to_many(): async def test_prefetch_related_with_many_to_many():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
div = await Division.objects.create(name='Div 1') div = await Division.objects.create(name="Div 1")
shop1 = await Shop.objects.create(name='Shop 1', division=div) shop1 = await Shop.objects.create(name="Shop 1", division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div) shop2 = await Shop.objects.create(name="Shop 2", division=div)
album = Album(name="Malibu") album = Album(name="Malibu")
await album.save() await album.save()
await album.shops.add(shop1) await album.shops.add(shop1)
await album.shops.add(shop2) await album.shops.add(shop2)
await Track.objects.create(album=album, title="The Bird", position=1) await Track.objects.create(album=album, title="The Bird", position=1)
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2) await Track.objects.create(
album=album, title="Heart don't stand a chance", position=2
)
await Track.objects.create(album=album, title="The Waters", position=3) await Track.objects.create(album=album, title="The Waters", position=3)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') await Cover.objects.create(title="Cover1", album=album, artist="Artist 1")
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') await Cover.objects.create(title="Cover2", album=album, artist="Artist 2")
track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get( track = await Track.objects.prefetch_related(
title="The Bird") ["album__cover_pictures", "album__shops__division"]
).get(title="The Bird")
assert track.album.name == "Malibu" assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2 assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1' assert track.album.cover_pictures[0].artist == "Artist 1"
assert len(track.album.shops) == 2 assert len(track.album.shops) == 2
assert track.album.shops[0].name == 'Shop 1' assert track.album.shops[0].name == "Shop 1"
assert track.album.shops[0].division.name == 'Div 1' assert track.album.shops[0].division.name == "Div 1"
album2 = Album(name="Malibu 2") album2 = Album(name="Malibu 2")
await album2.save() await album2.save()
@ -190,14 +224,14 @@ async def test_prefetch_related_with_many_to_many():
await Track.objects.create(album=album2, title="The Bird 2", position=1) await Track.objects.create(album=album2, title="The Bird 2", position=1)
tracks = await Track.objects.prefetch_related(["album__shops"]).all() tracks = await Track.objects.prefetch_related(["album__shops"]).all()
assert tracks[0].album.name == 'Malibu' assert tracks[0].album.name == "Malibu"
assert tracks[0].album.shops[0].name == "Shop 1" assert tracks[0].album.shops[0].name == "Shop 1"
assert tracks[3].album.name == 'Malibu 2' assert tracks[3].album.name == "Malibu 2"
assert tracks[3].album.shops[0].name == "Shop 1" assert tracks[3].album.shops[0].name == "Shop 1"
assert tracks[0].album.shops[0] == tracks[3].album.shops[0] assert tracks[0].album.shops[0] == tracks[3].album.shops[0]
assert id(tracks[0].album.shops[0]) == id(tracks[3].album.shops[0]) assert id(tracks[0].album.shops[0]) == id(tracks[3].album.shops[0])
tracks[0].album.shops[0].name = 'Dummy' tracks[0].album.shops[0].name = "Dummy"
assert tracks[0].album.shops[0].name == tracks[3].album.shops[0].name assert tracks[0].album.shops[0].name == tracks[3].album.shops[0].name
@ -206,8 +240,10 @@ async def test_prefetch_related_empty():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
await Track.objects.create(title="The Bird", position=1) await Track.objects.create(title="The Bird", position=1)
track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(
assert track.title == 'The Bird' title="The Bird"
)
assert track.title == "The Bird"
assert track.album is None assert track.album is None
@ -215,91 +251,133 @@ async def test_prefetch_related_empty():
async def test_prefetch_related_with_select_related(): async def test_prefetch_related_with_select_related():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
div = await Division.objects.create(name='Div 1') div = await Division.objects.create(name="Div 1")
shop1 = await Shop.objects.create(name='Shop 1', division=div) shop1 = await Shop.objects.create(name="Shop 1", division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div) shop2 = await Shop.objects.create(name="Shop 2", division=div)
album = Album(name="Malibu") album = Album(name="Malibu")
await album.save() await album.save()
await album.shops.add(shop1) await album.shops.add(shop1)
await album.shops.add(shop2) await album.shops.add(shop2)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') await Cover.objects.create(title="Cover1", album=album, artist="Artist 1")
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') await Cover.objects.create(title="Cover2", album=album, artist="Artist 2")
album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related( album = (
['cover_pictures', 'shops__division']).get() await Album.objects.select_related(["tracks", "shops"])
.filter(name="Malibu")
.prefetch_related(["cover_pictures", "shops__division"])
.get()
)
assert len(album.tracks) == 0 assert len(album.tracks) == 0
assert len(album.cover_pictures) == 2 assert len(album.cover_pictures) == 2
assert album.shops[0].division.name == 'Div 1' assert album.shops[0].division.name == "Div 1"
rand_set = await RandomSet.objects.create(name='Rand 1') rand_set = await RandomSet.objects.create(name="Rand 1")
ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) ton1 = await Tonation.objects.create(name="B-mol", rand_set=rand_set)
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) await Track.objects.create(
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) album=album, title="The Bird", position=1, tonation=ton1
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) )
await Track.objects.create(
album=album,
title="Heart don't stand a chance",
position=2,
tonation=ton1,
)
await Track.objects.create(
album=album, title="The Waters", position=3, tonation=ton1
)
album = await Album.objects.select_related('tracks__tonation__rand_set').filter( album = (
name='Malibu').prefetch_related( await Album.objects.select_related("tracks__tonation__rand_set")
['cover_pictures', 'shops__division']).order_by( .filter(name="Malibu")
['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get() .prefetch_related(["cover_pictures", "shops__division"])
.order_by(
["-shops__name", "-cover_pictures__artist", "shops__division__name"]
)
.get()
)
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
assert len(album.cover_pictures) == 2 assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].artist == 'Artist 2' assert album.cover_pictures[0].artist == "Artist 2"
assert len(album.shops) == 2 assert len(album.shops) == 2
assert album.shops[0].name == 'Shop 2' assert album.shops[0].name == "Shop 2"
assert album.shops[0].division.name == 'Div 1' assert album.shops[0].division.name == "Div 1"
track = await Track.objects.select_related('album').prefetch_related( track = (
["album__cover_pictures", "album__shops__division"]).get( await Track.objects.select_related("album")
title="The Bird") .prefetch_related(["album__cover_pictures", "album__shops__division"])
.get(title="The Bird")
)
assert track.album.name == "Malibu" assert track.album.name == "Malibu"
assert len(track.album.cover_pictures) == 2 assert len(track.album.cover_pictures) == 2
assert track.album.cover_pictures[0].artist == 'Artist 1' assert track.album.cover_pictures[0].artist == "Artist 1"
assert len(track.album.shops) == 2 assert len(track.album.shops) == 2
assert track.album.shops[0].name == 'Shop 1' assert track.album.shops[0].name == "Shop 1"
assert track.album.shops[0].division.name == 'Div 1' assert track.album.shops[0].division.name == "Div 1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_prefetch_related_with_select_related_and_fields(): async def test_prefetch_related_with_select_related_and_fields():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
div = await Division.objects.create(name='Div 1') div = await Division.objects.create(name="Div 1")
shop1 = await Shop.objects.create(name='Shop 1', division=div) shop1 = await Shop.objects.create(name="Shop 1", division=div)
shop2 = await Shop.objects.create(name='Shop 2', division=div) shop2 = await Shop.objects.create(name="Shop 2", division=div)
album = Album(name="Malibu") album = Album(name="Malibu")
await album.save() await album.save()
await album.shops.add(shop1) await album.shops.add(shop1)
await album.shops.add(shop2) await album.shops.add(shop2)
await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') await Cover.objects.create(title="Cover1", album=album, artist="Artist 1")
await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') await Cover.objects.create(title="Cover2", album=album, artist="Artist 2")
rand_set = await RandomSet.objects.create(name='Rand 1') rand_set = await RandomSet.objects.create(name="Rand 1")
ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) ton1 = await Tonation.objects.create(name="B-mol", rand_set=rand_set)
await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) await Track.objects.create(
await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) album=album, title="The Bird", position=1, tonation=ton1
await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) )
await Track.objects.create(
album=album,
title="Heart don't stand a chance",
position=2,
tonation=ton1,
)
await Track.objects.create(
album=album, title="The Waters", position=3, tonation=ton1
)
album = await Album.objects.select_related('tracks__tonation__rand_set').filter( album = (
name='Malibu').prefetch_related( await Album.objects.select_related("tracks__tonation__rand_set")
['cover_pictures', 'shops__division']).exclude_fields({'shops': {'division': {'name'}}}).get() .filter(name="Malibu")
.prefetch_related(["cover_pictures", "shops__division"])
.exclude_fields({"shops": {"division": {"name"}}})
.get()
)
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 assert album.tracks[0].tonation == album.tracks[2].tonation == ton1
assert len(album.cover_pictures) == 2 assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].artist == 'Artist 1' assert album.cover_pictures[0].artist == "Artist 1"
assert len(album.shops) == 2 assert len(album.shops) == 2
assert album.shops[0].name == 'Shop 1' assert album.shops[0].name == "Shop 1"
assert album.shops[0].division.name is None assert album.shops[0].division.name is None
album = await Album.objects.select_related('tracks').filter( album = (
name='Malibu').prefetch_related( await Album.objects.select_related("tracks")
['cover_pictures', 'shops__division']).fields( .filter(name="Malibu")
{'name': ..., 'shops': {'division'}, 'cover_pictures': {'id': ..., 'title': ...}} .prefetch_related(["cover_pictures", "shops__division"])
).exclude_fields({'shops': {'division': {'name'}}}).get() .fields(
{
"name": ...,
"shops": {"division"},
"cover_pictures": {"id": ..., "title": ...},
}
)
.exclude_fields({"shops": {"division": {"name"}}})
.get()
)
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert len(album.cover_pictures) == 2 assert len(album.cover_pictures) == 2
assert album.cover_pictures[0].artist is None assert album.cover_pictures[0].artist is None

View File

@ -121,24 +121,24 @@ class SortModel(ormar.Model):
def test_sorting_models(): def test_sorting_models():
models = [ models = [
SortModel(id=1, name='Alice', sort_order=0), SortModel(id=1, name="Alice", sort_order=0),
SortModel(id=2, name='Al', sort_order=1), SortModel(id=2, name="Al", sort_order=1),
SortModel(id=3, name='Zake', sort_order=1), SortModel(id=3, name="Zake", sort_order=1),
SortModel(id=4, name='Will', sort_order=0), SortModel(id=4, name="Will", sort_order=0),
SortModel(id=5, name='Al', sort_order=2), SortModel(id=5, name="Al", sort_order=2),
SortModel(id=6, name='Alice', sort_order=2) SortModel(id=6, name="Alice", sort_order=2),
] ]
orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'} orders_by = {"name": "asc", "none": {}, "sort_order": "desc"}
models = sort_models(models, orders_by) models = sort_models(models, orders_by)
assert models[5].name == 'Zake' assert models[5].name == "Zake"
assert models[0].name == 'Al' assert models[0].name == "Al"
assert models[1].name == 'Al' assert models[1].name == "Al"
assert [model.id for model in models] == [5, 2, 6, 1, 4, 3] assert [model.id for model in models] == [5, 2, 6, 1, 4, 3]
orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'} orders_by = {"name": "asc", "none": set("aa"), "id": "asc"}
models = sort_models(models, orders_by) models = sort_models(models, orders_by)
assert [model.id for model in models] == [2, 5, 1, 6, 4, 3] assert [model.id for model in models] == [2, 5, 1, 6, 4, 3]
orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None} orders_by = {"sort_order": "asc", "none": ..., "id": "asc", "uu": 2, "aa": None}
models = sort_models(models, orders_by) models = sort_models(models, orders_by)
assert [model.id for model in models] == [1, 4, 2, 3, 5, 6] assert [model.id for model in models] == [1, 4, 2, 3, 5, 6]