diff --git a/docs/releases.md b/docs/releases.md index cdbd6b0..d27647f 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,7 @@ +# 0.5.3 + +* Fixed bug in `Model.dict()` method that was ignoring exclude parameter and not include dictionary argument. + # 0.5.2 * Added `prefetch_related` method to load subsequent models in separate queries. diff --git a/ormar/__init__.py b/ormar/__init__.py index cd3712a..e910731 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.5.2" +__version__ = "0.5.3" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/models/excludable.py b/ormar/models/excludable.py index c86e0c3..11c57ab 100644 --- a/ormar/models/excludable.py +++ b/ormar/models/excludable.py @@ -2,19 +2,25 @@ from typing import Dict, Set, Union class Excludable: + @staticmethod + def get_child( + items: Union[Set, Dict, None], key: str = None + ) -> Union[Set, Dict, None]: + if isinstance(items, dict): + return items.get(key, {}) + return items + @staticmethod def get_excluded( exclude: Union[Set, Dict, None], key: str = None ) -> Union[Set, Dict, None]: - if isinstance(exclude, dict): - return exclude.get(key, {}) - return exclude + return Excludable.get_child(items=exclude, key=key) @staticmethod def get_included( include: Union[Set, Dict, None], key: str = None ) -> Union[Set, Dict, None]: - return Excludable.get_excluded(exclude=include, key=key) + return Excludable.get_child(items=include, key=key) @staticmethod def is_excluded(exclude: Union[Set, Dict, None], key: str = None) -> bool: @@ -25,7 +31,7 @@ class Excludable: to_exclude = Excludable.get_excluded(exclude=exclude, key=key) if isinstance(to_exclude, Set): return key in to_exclude - elif to_exclude is ...: + if to_exclude is ...: return True return False @@ -38,6 +44,6 @@ class Excludable: to_include = Excludable.get_included(include=include, key=key) if isinstance(to_include, Set): return key in to_include - elif to_include is ...: + if to_include is ...: return True return False diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 06b5060..b36e495 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -1,10 +1,12 @@ import inspect from collections import OrderedDict from typing import ( + AbstractSet, Any, Callable, Dict, List, + Mapping, Optional, Sequence, Set, @@ -16,6 +18,7 @@ from typing import ( ) from ormar.exceptions import ModelPersistenceError, RelationshipInstanceError +from ormar.queryset.utils import translate_list_to_dict, update try: import orjson as json @@ -32,6 +35,9 @@ if TYPE_CHECKING: # pragma no cover from ormar.models import NewBaseModel T = TypeVar("T", bound=Model) + IntStr = Union[int, str] + AbstractSetIntStr = AbstractSet[IntStr] + MappingIntStrAny = Mapping[IntStr, Any] Field = TypeVar("Field", bound=BaseField) @@ -203,6 +209,21 @@ class ModelTableProxy: } return related_names + @classmethod + def _update_excluded_with_related_not_required( + cls, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None], + nested: bool = False, + ) -> Union[Set, Dict]: + exclude = exclude or {} + related_set = cls._exclude_related_names_not_required(nested=nested) + if isinstance(exclude, set): + exclude.union(related_set) + else: + related_dict = translate_list_to_dict(related_set) + exclude = update(related_dict, exclude) + return exclude + def _extract_model_db_fields(self) -> Dict: self_fields = self._extract_own_model_fields() self_fields = { diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 5acf4ca..8a9d518 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -27,6 +27,7 @@ from ormar.fields.foreign_key import ForeignKeyField from ormar.models.excludable import Excludable from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.modelproxy import ModelTableProxy +from ormar.queryset.utils import translate_list_to_dict from ormar.relations.alias_manager import AliasManager from ormar.relations.relation_manager import RelationsManager @@ -213,9 +214,7 @@ class NewBaseModel( @classmethod def get_properties( - cls, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + cls, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None] ) -> List[str]: if isinstance(cls._props, list): props = cls._props @@ -234,11 +233,76 @@ class NewBaseModel( props = [prop for prop in props if prop not in exclude] return props - def dict( # noqa A003 + def _get_related_not_excluded_fields( + self, include: Optional[Dict], exclude: Optional[Dict], + ) -> List: + fields = [field for field in self.extract_related_names()] + if include: + fields = [field for field in fields if field in include] + if exclude: + fields = [ + field + for field in fields + if field not in exclude or exclude.get(field) is not Ellipsis + ] + return fields + + @staticmethod + def _extract_nested_models_from_list( + models: List, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None], + ) -> List: + result = [] + for model in models: + try: + result.append( + model.dict(nested=True, include=include, exclude=exclude,) + ) + except ReferenceError: # pragma no cover + continue + return result + + @staticmethod + def _skip_ellipsis( + items: Union[Set, Dict, None], key: str + ) -> Union[Set, Dict, None]: + result = Excludable.get_child(items, key) + return result if result is not Ellipsis else None + + def _extract_nested_models( # noqa: CCR001 + self, + nested: bool, + dict_instance: Dict, + include: Optional[Dict], + exclude: Optional[Dict], + ) -> Dict: + + fields = self._get_related_not_excluded_fields(include=include, exclude=exclude) + + for field in fields: + if self.Meta.model_fields[field].virtual and nested: + continue + nested_model = getattr(self, field) + if isinstance(nested_model, list): + dict_instance[field] = self._extract_nested_models_from_list( + models=nested_model, + include=self._skip_ellipsis(include, field), + exclude=self._skip_ellipsis(exclude, field), + ) + elif nested_model is not None: + dict_instance[field] = nested_model.dict( + nested=True, + include=self._skip_ellipsis(include, field), + exclude=self._skip_ellipsis(exclude, field), + ) + else: + dict_instance[field] = None + return dict_instance + + def dict( # type: ignore # noqa A003 self, *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + include: Union[Set, Dict] = None, + exclude: Union[Set, Dict] = None, by_alias: bool = False, skip_defaults: bool = None, exclude_unset: bool = False, @@ -248,30 +312,25 @@ class NewBaseModel( ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, - exclude=self._exclude_related_names_not_required(nested), + exclude=self._update_excluded_with_related_not_required(exclude, nested), by_alias=by_alias, skip_defaults=skip_defaults, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) - for field in self.extract_related_names(): - nested_model = getattr(self, field) - if self.Meta.model_fields[field].virtual and nested: - continue - if isinstance(nested_model, list): - result = [] - for model in nested_model: - try: - result.append(model.dict(nested=True)) - except ReferenceError: # pragma no cover - continue - dict_instance[field] = result - elif nested_model is not None: - dict_instance[field] = nested_model.dict(nested=True) - else: - dict_instance[field] = None + if include and isinstance(include, Set): + include = translate_list_to_dict(include) + if exclude and isinstance(exclude, Set): + exclude = translate_list_to_dict(exclude) + + dict_instance = self._extract_nested_models( + nested=nested, + dict_instance=dict_instance, + include=include, # type: ignore + exclude=exclude, # type: ignore + ) # include model properties as fields props = self.get_properties(include=include, exclude=exclude) diff --git a/ormar/queryset/utils.py b/ormar/queryset/utils.py index bed2e25..8cc7ec1 100644 --- a/ormar/queryset/utils.py +++ b/ormar/queryset/utils.py @@ -1,6 +1,15 @@ import collections.abc import copy -from typing import Any, Dict, List, Sequence, Set, TYPE_CHECKING, Type, Union +from typing import ( + Any, + Dict, + List, + Sequence, + Set, + TYPE_CHECKING, + Type, + Union, +) if TYPE_CHECKING: # pragma no cover from ormar import Model diff --git a/tests/test_dumping_model_to_dict.py b/tests/test_dumping_model_to_dict.py new file mode 100644 index 0000000..5fd9405 --- /dev/null +++ b/tests/test_dumping_model_to_dict.py @@ -0,0 +1,139 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL, force_rollback=True) + + +class User(ormar.Model): + class Meta: + tablename: str = "users" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + email: str = ormar.String(max_length=255, nullable=False) + password: str = ormar.String(max_length=255, nullable=True) + first_name: str = ormar.String(max_length=255, nullable=False) + + +class Tier(ormar.Model): + class Meta: + tablename = "tiers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + tier: Optional[Tier] = ormar.ForeignKey(Tier) + + +class Item(ormar.Model): + class Meta: + tablename = "items" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + category: Optional[Category] = ormar.ForeignKey(Category, nullable=True) + created_by: Optional[User] = ormar.ForeignKey(User) + + +@pytest.fixture(autouse=True, scope="module") +def sample_data(): + user = User(email="test@test.com", password="ijacids7^*&", first_name="Anna") + tier = Tier(name="Tier I") + category1 = Category(name="Toys", tier=tier) + category2 = Category(name="Weapons", tier=tier) + item1 = Item(name="Teddy Bear", category=category1, created_by=user) + item2 = Item(name="M16", category=category2, created_by=user) + return item1, item2 + + +def test_dumping_to_dict_no_exclusion(sample_data): + item1, item2 = sample_data + + dict1 = item1.dict() + assert dict1["name"] == "Teddy Bear" + assert dict1["category"]["name"] == "Toys" + assert dict1["category"]["tier"]['name'] == "Tier I" + assert dict1["created_by"]["email"] == "test@test.com" + + dict2 = item2.dict() + assert dict2["name"] == "M16" + assert dict2["category"]["name"] == "Weapons" + assert dict2["created_by"]["email"] == "test@test.com" + + +def test_dumping_to_dict_exclude_set(sample_data): + item1, item2 = sample_data + dict3 = item2.dict(exclude={"name"}) + assert "name" not in dict3 + assert dict3["category"]["name"] == "Weapons" + assert dict3["created_by"]["email"] == "test@test.com" + + dict4 = item2.dict(exclude={"category"}) + assert dict4["name"] == "M16" + assert "category" not in dict4 + assert dict4["created_by"]["email"] == "test@test.com" + + dict5 = item2.dict(exclude={"category", "name"}) + assert "name" not in dict5 + assert "category" not in dict5 + assert dict5["created_by"]["email"] == "test@test.com" + + +def test_dumping_to_dict_exclude_dict(sample_data): + item1, item2 = sample_data + dict6 = item2.dict(exclude={"category": {"name"}, "name": ...}) + assert "name" not in dict6 + assert "category" in dict6 + assert "name" not in dict6["category"] + assert dict6["created_by"]["email"] == "test@test.com" + + +def test_dumping_to_dict_exclude_nested_dict(sample_data): + item1, item2 = sample_data + dict1 = item2.dict(exclude={"category": {"tier": {"name"}}, "name": ...}) + assert "name" not in dict1 + assert "category" in dict1 + assert dict1["category"]['name'] == 'Weapons' + assert dict1["created_by"]["email"] == "test@test.com" + assert dict1["category"]["tier"].get('name') is None + + +def test_dumping_to_dict_exclude_and_include_nested_dict(sample_data): + item1, item2 = sample_data + dict1 = item2.dict(exclude={"category": {"tier": {"name"}}}, + include={'name', 'category'}) + assert dict1.get('name') == 'M16' + assert "category" in dict1 + assert dict1["category"]['name'] == 'Weapons' + assert "created_by" not in dict1 + assert dict1["category"]["tier"].get('name') is None + + dict2 = item1.dict(exclude={"id": ...}, + include={'name': ..., 'category': {'name': ..., 'tier': {'id'}}}) + assert dict2.get('name') == 'Teddy Bear' + assert dict2.get('id') is None # models not saved + assert dict2["category"]['name'] == 'Toys' + assert "created_by" not in dict1 + assert dict1["category"]["tier"].get('name') is None + assert dict1["category"]["tier"]['id'] is None diff --git a/tests/test_excluding_fields_in_fastapi.py b/tests/test_excluding_fields_in_fastapi.py new file mode 100644 index 0000000..56934f1 --- /dev/null +++ b/tests/test_excluding_fields_in_fastapi.py @@ -0,0 +1,178 @@ +import datetime +import string +import random + +import databases +import pydantic +import pytest +import sqlalchemy +from fastapi import FastAPI +from starlette.testclient import TestClient + +import ormar +from tests.settings import DATABASE_URL + +app = FastAPI() +metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL, force_rollback=True) +app.state.database = database + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +# note that you can set orm_mode here +# and in this case UserSchema become unnecessary +class UserBase(pydantic.BaseModel): + class Config: + orm_mode = True + + email: str + first_name: str + last_name: str + + +class UserCreateSchema(UserBase): + password: str + category: str + + +class UserSchema(UserBase): + class Config: + orm_mode = True + + +def gen_pass(): + choices = string.ascii_letters + string.digits + "!@#$%^&*()" + return "".join(random.choice(choices) for _ in range(20)) + + +class RandomModel(ormar.Model): + class Meta: + tablename: str = "random_users" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + password: str = ormar.String(max_length=255, default=gen_pass) + first_name: str = ormar.String(max_length=255, default='John') + last_name: str = ormar.String(max_length=255) + created_date: datetime.datetime = ormar.DateTime(server_default=sqlalchemy.func.now()) + + +class User(ormar.Model): + class Meta: + tablename: str = "users" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + email: str = ormar.String(max_length=255, nullable=False) + password: str = ormar.String(max_length=255, nullable=True) + first_name: str = ormar.String(max_length=255, nullable=False) + last_name: str = ormar.String(max_length=255, nullable=False) + category: str = ormar.String(max_length=255, nullable=True) + + +class User2(ormar.Model): + class Meta: + tablename: str = "users2" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + email: str = ormar.String(max_length=255, nullable=False) + password: str = ormar.String(max_length=255, nullable=False) + first_name: str = ormar.String(max_length=255, nullable=False) + last_name: str = ormar.String(max_length=255, nullable=False) + category: str = ormar.String(max_length=255, nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@app.post("/users/", response_model=User, response_model_exclude={"password"}) +async def create_user(user: User): + return await user.save() + + +@app.post("/users2/", response_model=User) +async def create_user2(user: User): + user = await user.save() + return user.dict(exclude={'password'}) + + +@app.post("/users3/", response_model=UserBase) +async def create_user3(user: User2): + return await user.save() + + +@app.post("/users4/") +async def create_user4(user: User2): + user = await user.save() + return user.dict(exclude={'password'}) + + +@app.post("/random/", response_model=RandomModel) +async def create_user5(user: RandomModel): + return await user.save() + + +def test_all_endpoints(): + client = TestClient(app) + with client as client: + user = { + "email": "test@domain.com", + "password": "^*^%A*DA*IAAA", + "first_name": "John", + "last_name": "Doe", + } + response = client.post("/users/", json=user) + created_user = User(**response.json()) + assert created_user.pk is not None + assert created_user.password is None + + user2 = { + "email": "test@domain.com", + "first_name": "John", + "last_name": "Doe", + } + + response = client.post("/users/", json=user2) + created_user = User(**response.json()) + assert created_user.pk is not None + assert created_user.password is None + + response = client.post("/users2/", json=user) + created_user2 = User(**response.json()) + assert created_user2.pk is not None + assert created_user2.password is None + + # response has only 3 fields from UserBase + response = client.post("/users3/", json=user) + assert list(response.json().keys()) == ['email', 'first_name', 'last_name'] + + response = client.post("/users4/", json=user) + assert list(response.json().keys()) == ['id', 'email', 'first_name', 'last_name', 'category'] + + user3 = { + 'last_name': 'Test' + } + response = client.post("/random/", json=user3) + assert list(response.json().keys()) == ['id', 'password', 'first_name', 'last_name', 'created_date'] diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 34fe974..c4b9a4c 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -237,8 +237,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -246,8 +246,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -292,8 +292,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -325,8 +325,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 @@ -408,9 +408,11 @@ async def test_bulk_update_model_with_children(): album=best_seller, title="t4", position=1, play_count=500 ) - tracks = await Track.objects.select_related("album").filter( - play_count__gt=10 - ).all() + tracks = ( + await Track.objects.select_related("album") + .filter(play_count__gt=10) + .all() + ) best_seller_albums = {} for track in tracks: album = track.album @@ -421,5 +423,7 @@ async def test_bulk_update_model_with_children(): await Album.objects.bulk_update( best_seller_albums.values(), columns=["is_best_seller"] ) - best_seller_albums_db = await Album.objects.filter(is_best_seller=True).all() + best_seller_albums_db = await Album.objects.filter( + is_best_seller=True + ).all() assert len(best_seller_albums_db) == 2 diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index bfc09d7..d7ccea9 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -17,7 +17,7 @@ class RandomSet(ormar.Model): metadata = metadata database = database - id: int = ormar.Integer(name='random_id', primary_key=True) + id: int = ormar.Integer(name="random_id", primary_key=True) name: str = ormar.String(max_length=100) @@ -28,7 +28,7 @@ class Tonation(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(name='tonation_name', max_length=100) + name: str = ormar.String(name="tonation_name", max_length=100) rand_set: Optional[RandomSet] = ormar.ForeignKey(RandomSet) @@ -38,7 +38,7 @@ class Division(ormar.Model): metadata = metadata database = database - id: int = ormar.Integer(name='division_id', primary_key=True) + id: int = ormar.Integer(name="division_id", primary_key=True) name: str = ormar.String(max_length=100, nullable=True) @@ -77,11 +77,11 @@ class Track(ormar.Model): metadata = metadata database = database - id: int = ormar.Integer(name='track_id', primary_key=True) + id: int = ormar.Integer(name="track_id", primary_key=True) album: Optional[Album] = ormar.ForeignKey(Album) title: str = ormar.String(max_length=100) position: int = ormar.Integer() - tonation: Optional[Tonation] = ormar.ForeignKey(Tonation, name='tonation_id') + tonation: Optional[Tonation] = ormar.ForeignKey(Tonation, name="tonation_id") class Cover(ormar.Model): @@ -91,7 +91,9 @@ class Cover(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures", name='album_id') + album: Optional[Album] = ormar.ForeignKey( + Album, related_name="cover_pictures", name="album_id" + ) title: str = ormar.String(max_length=100) artist: str = ormar.String(max_length=200, nullable=True) @@ -111,42 +113,71 @@ async def test_prefetch_related(): async with database.transaction(force_rollback=True): album = Album(name="Malibu") await album.save() - ton1 = await Tonation.objects.create(name='B-mol') - await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) - await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) - await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) - await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') - await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + ton1 = await Tonation.objects.create(name="B-mol") + await Track.objects.create( + album=album, title="The Bird", position=1, tonation=ton1 + ) + await Track.objects.create( + album=album, + title="Heart don't stand a chance", + position=2, + tonation=ton1, + ) + await Track.objects.create( + album=album, title="The Waters", position=3, tonation=ton1 + ) + await Cover.objects.create(title="Cover1", album=album, artist="Artist 1") + await Cover.objects.create(title="Cover2", album=album, artist="Artist 2") fantasies = Album(name="Fantasies") await fantasies.save() - await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create( + album=fantasies, title="Help I'm Alive", position=1 + ) await Track.objects.create(album=fantasies, title="Sick Muse", position=2) - await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) - await Cover.objects.create(title='Cover3', album=fantasies, artist='Artist 3') - await Cover.objects.create(title='Cover4', album=fantasies, artist='Artist 4') + await Track.objects.create( + album=fantasies, title="Satellite Mind", position=3 + ) + await Cover.objects.create( + title="Cover3", album=fantasies, artist="Artist 3" + ) + await Cover.objects.create( + title="Cover4", album=fantasies, artist="Artist 4" + ) - album = await Album.objects.filter(name='Malibu').prefetch_related( - ['tracks__tonation', 'cover_pictures']).get() + album = ( + await Album.objects.filter(name="Malibu") + .prefetch_related(["tracks__tonation", "cover_pictures"]) + .get() + ) assert len(album.tracks) == 3 - assert album.tracks[0].title == 'The Bird' + assert album.tracks[0].title == "The Bird" assert len(album.cover_pictures) == 2 - assert album.cover_pictures[0].title == 'Cover1' - assert album.tracks[0].tonation.name == album.tracks[2].tonation.name == 'B-mol' + assert album.cover_pictures[0].title == "Cover1" + assert ( + album.tracks[0].tonation.name + == album.tracks[2].tonation.name + == "B-mol" + ) - albums = await Album.objects.prefetch_related('tracks').all() + albums = await Album.objects.prefetch_related("tracks").all() assert len(albums[0].tracks) == 3 assert len(albums[1].tracks) == 3 assert albums[0].tracks[0].title == "The Bird" assert albums[1].tracks[0].title == "Help I'm Alive" - track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") + track = await Track.objects.prefetch_related(["album__cover_pictures"]).get( + title="The Bird" + ) assert track.album.name == "Malibu" assert len(track.album.cover_pictures) == 2 - assert track.album.cover_pictures[0].artist == 'Artist 1' + assert track.album.cover_pictures[0].artist == "Artist 1" - track = await Track.objects.prefetch_related(["album__cover_pictures"]).exclude_fields( - 'album__cover_pictures__artist').get(title="The Bird") + track = ( + await Track.objects.prefetch_related(["album__cover_pictures"]) + .exclude_fields("album__cover_pictures__artist") + .get(title="The Bird") + ) assert track.album.name == "Malibu" assert len(track.album.cover_pictures) == 2 assert track.album.cover_pictures[0].artist is None @@ -159,29 +190,32 @@ async def test_prefetch_related(): async def test_prefetch_related_with_many_to_many(): async with database: async with database.transaction(force_rollback=True): - div = await Division.objects.create(name='Div 1') - shop1 = await Shop.objects.create(name='Shop 1', division=div) - shop2 = await Shop.objects.create(name='Shop 2', division=div) + div = await Division.objects.create(name="Div 1") + shop1 = await Shop.objects.create(name="Shop 1", division=div) + shop2 = await Shop.objects.create(name="Shop 2", division=div) album = Album(name="Malibu") await album.save() await album.shops.add(shop1) await album.shops.add(shop2) await Track.objects.create(album=album, title="The Bird", position=1) - await Track.objects.create(album=album, title="Heart don't stand a chance", position=2) + await Track.objects.create( + album=album, title="Heart don't stand a chance", position=2 + ) await Track.objects.create(album=album, title="The Waters", position=3) - await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') - await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + await Cover.objects.create(title="Cover1", album=album, artist="Artist 1") + await Cover.objects.create(title="Cover2", album=album, artist="Artist 2") - track = await Track.objects.prefetch_related(["album__cover_pictures", "album__shops__division"]).get( - title="The Bird") + track = await Track.objects.prefetch_related( + ["album__cover_pictures", "album__shops__division"] + ).get(title="The Bird") assert track.album.name == "Malibu" assert len(track.album.cover_pictures) == 2 - assert track.album.cover_pictures[0].artist == 'Artist 1' + assert track.album.cover_pictures[0].artist == "Artist 1" assert len(track.album.shops) == 2 - assert track.album.shops[0].name == 'Shop 1' - assert track.album.shops[0].division.name == 'Div 1' + assert track.album.shops[0].name == "Shop 1" + assert track.album.shops[0].division.name == "Div 1" album2 = Album(name="Malibu 2") await album2.save() @@ -190,14 +224,14 @@ async def test_prefetch_related_with_many_to_many(): await Track.objects.create(album=album2, title="The Bird 2", position=1) tracks = await Track.objects.prefetch_related(["album__shops"]).all() - assert tracks[0].album.name == 'Malibu' + assert tracks[0].album.name == "Malibu" assert tracks[0].album.shops[0].name == "Shop 1" - assert tracks[3].album.name == 'Malibu 2' + assert tracks[3].album.name == "Malibu 2" assert tracks[3].album.shops[0].name == "Shop 1" assert tracks[0].album.shops[0] == tracks[3].album.shops[0] assert id(tracks[0].album.shops[0]) == id(tracks[3].album.shops[0]) - tracks[0].album.shops[0].name = 'Dummy' + tracks[0].album.shops[0].name = "Dummy" assert tracks[0].album.shops[0].name == tracks[3].album.shops[0].name @@ -206,8 +240,10 @@ async def test_prefetch_related_empty(): async with database: async with database.transaction(force_rollback=True): await Track.objects.create(title="The Bird", position=1) - track = await Track.objects.prefetch_related(["album__cover_pictures"]).get(title="The Bird") - assert track.title == 'The Bird' + track = await Track.objects.prefetch_related(["album__cover_pictures"]).get( + title="The Bird" + ) + assert track.title == "The Bird" assert track.album is None @@ -215,91 +251,133 @@ async def test_prefetch_related_empty(): async def test_prefetch_related_with_select_related(): async with database: async with database.transaction(force_rollback=True): - div = await Division.objects.create(name='Div 1') - shop1 = await Shop.objects.create(name='Shop 1', division=div) - shop2 = await Shop.objects.create(name='Shop 2', division=div) + div = await Division.objects.create(name="Div 1") + shop1 = await Shop.objects.create(name="Shop 1", division=div) + shop2 = await Shop.objects.create(name="Shop 2", division=div) album = Album(name="Malibu") await album.save() await album.shops.add(shop1) await album.shops.add(shop2) - await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') - await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') + await Cover.objects.create(title="Cover1", album=album, artist="Artist 1") + await Cover.objects.create(title="Cover2", album=album, artist="Artist 2") - album = await Album.objects.select_related(['tracks', 'shops']).filter(name='Malibu').prefetch_related( - ['cover_pictures', 'shops__division']).get() + album = ( + await Album.objects.select_related(["tracks", "shops"]) + .filter(name="Malibu") + .prefetch_related(["cover_pictures", "shops__division"]) + .get() + ) assert len(album.tracks) == 0 assert len(album.cover_pictures) == 2 - assert album.shops[0].division.name == 'Div 1' + assert album.shops[0].division.name == "Div 1" - rand_set = await RandomSet.objects.create(name='Rand 1') - ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) - await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) - await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) - await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + rand_set = await RandomSet.objects.create(name="Rand 1") + ton1 = await Tonation.objects.create(name="B-mol", rand_set=rand_set) + await Track.objects.create( + album=album, title="The Bird", position=1, tonation=ton1 + ) + await Track.objects.create( + album=album, + title="Heart don't stand a chance", + position=2, + tonation=ton1, + ) + await Track.objects.create( + album=album, title="The Waters", position=3, tonation=ton1 + ) - album = await Album.objects.select_related('tracks__tonation__rand_set').filter( - name='Malibu').prefetch_related( - ['cover_pictures', 'shops__division']).order_by( - ['-shops__name', '-cover_pictures__artist', 'shops__division__name']).get() + album = ( + await Album.objects.select_related("tracks__tonation__rand_set") + .filter(name="Malibu") + .prefetch_related(["cover_pictures", "shops__division"]) + .order_by( + ["-shops__name", "-cover_pictures__artist", "shops__division__name"] + ) + .get() + ) assert len(album.tracks) == 3 assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 assert len(album.cover_pictures) == 2 - assert album.cover_pictures[0].artist == 'Artist 2' + assert album.cover_pictures[0].artist == "Artist 2" assert len(album.shops) == 2 - assert album.shops[0].name == 'Shop 2' - assert album.shops[0].division.name == 'Div 1' + assert album.shops[0].name == "Shop 2" + assert album.shops[0].division.name == "Div 1" - track = await Track.objects.select_related('album').prefetch_related( - ["album__cover_pictures", "album__shops__division"]).get( - title="The Bird") + track = ( + await Track.objects.select_related("album") + .prefetch_related(["album__cover_pictures", "album__shops__division"]) + .get(title="The Bird") + ) assert track.album.name == "Malibu" assert len(track.album.cover_pictures) == 2 - assert track.album.cover_pictures[0].artist == 'Artist 1' + assert track.album.cover_pictures[0].artist == "Artist 1" assert len(track.album.shops) == 2 - assert track.album.shops[0].name == 'Shop 1' - assert track.album.shops[0].division.name == 'Div 1' + assert track.album.shops[0].name == "Shop 1" + assert track.album.shops[0].division.name == "Div 1" @pytest.mark.asyncio async def test_prefetch_related_with_select_related_and_fields(): async with database: async with database.transaction(force_rollback=True): - div = await Division.objects.create(name='Div 1') - shop1 = await Shop.objects.create(name='Shop 1', division=div) - shop2 = await Shop.objects.create(name='Shop 2', division=div) + div = await Division.objects.create(name="Div 1") + shop1 = await Shop.objects.create(name="Shop 1", division=div) + shop2 = await Shop.objects.create(name="Shop 2", division=div) album = Album(name="Malibu") await album.save() await album.shops.add(shop1) await album.shops.add(shop2) - await Cover.objects.create(title='Cover1', album=album, artist='Artist 1') - await Cover.objects.create(title='Cover2', album=album, artist='Artist 2') - rand_set = await RandomSet.objects.create(name='Rand 1') - ton1 = await Tonation.objects.create(name='B-mol', rand_set=rand_set) - await Track.objects.create(album=album, title="The Bird", position=1, tonation=ton1) - await Track.objects.create(album=album, title="Heart don't stand a chance", position=2, tonation=ton1) - await Track.objects.create(album=album, title="The Waters", position=3, tonation=ton1) + await Cover.objects.create(title="Cover1", album=album, artist="Artist 1") + await Cover.objects.create(title="Cover2", album=album, artist="Artist 2") + rand_set = await RandomSet.objects.create(name="Rand 1") + ton1 = await Tonation.objects.create(name="B-mol", rand_set=rand_set) + await Track.objects.create( + album=album, title="The Bird", position=1, tonation=ton1 + ) + await Track.objects.create( + album=album, + title="Heart don't stand a chance", + position=2, + tonation=ton1, + ) + await Track.objects.create( + album=album, title="The Waters", position=3, tonation=ton1 + ) - album = await Album.objects.select_related('tracks__tonation__rand_set').filter( - name='Malibu').prefetch_related( - ['cover_pictures', 'shops__division']).exclude_fields({'shops': {'division': {'name'}}}).get() + album = ( + await Album.objects.select_related("tracks__tonation__rand_set") + .filter(name="Malibu") + .prefetch_related(["cover_pictures", "shops__division"]) + .exclude_fields({"shops": {"division": {"name"}}}) + .get() + ) assert len(album.tracks) == 3 assert album.tracks[0].tonation == album.tracks[2].tonation == ton1 assert len(album.cover_pictures) == 2 - assert album.cover_pictures[0].artist == 'Artist 1' + assert album.cover_pictures[0].artist == "Artist 1" assert len(album.shops) == 2 - assert album.shops[0].name == 'Shop 1' + assert album.shops[0].name == "Shop 1" assert album.shops[0].division.name is None - album = await Album.objects.select_related('tracks').filter( - name='Malibu').prefetch_related( - ['cover_pictures', 'shops__division']).fields( - {'name': ..., 'shops': {'division'}, 'cover_pictures': {'id': ..., 'title': ...}} - ).exclude_fields({'shops': {'division': {'name'}}}).get() + album = ( + await Album.objects.select_related("tracks") + .filter(name="Malibu") + .prefetch_related(["cover_pictures", "shops__division"]) + .fields( + { + "name": ..., + "shops": {"division"}, + "cover_pictures": {"id": ..., "title": ...}, + } + ) + .exclude_fields({"shops": {"division": {"name"}}}) + .get() + ) assert len(album.tracks) == 3 assert len(album.cover_pictures) == 2 assert album.cover_pictures[0].artist is None diff --git a/tests/test_queryset_utils.py b/tests/test_queryset_utils.py index 97695b9..00492fd 100644 --- a/tests/test_queryset_utils.py +++ b/tests/test_queryset_utils.py @@ -121,24 +121,24 @@ class SortModel(ormar.Model): def test_sorting_models(): models = [ - SortModel(id=1, name='Alice', sort_order=0), - SortModel(id=2, name='Al', sort_order=1), - SortModel(id=3, name='Zake', sort_order=1), - SortModel(id=4, name='Will', sort_order=0), - SortModel(id=5, name='Al', sort_order=2), - SortModel(id=6, name='Alice', sort_order=2) + SortModel(id=1, name="Alice", sort_order=0), + SortModel(id=2, name="Al", sort_order=1), + SortModel(id=3, name="Zake", sort_order=1), + SortModel(id=4, name="Will", sort_order=0), + SortModel(id=5, name="Al", sort_order=2), + SortModel(id=6, name="Alice", sort_order=2), ] - orders_by = {'name': 'asc', 'none': {}, 'sort_order': 'desc'} + orders_by = {"name": "asc", "none": {}, "sort_order": "desc"} models = sort_models(models, orders_by) - assert models[5].name == 'Zake' - assert models[0].name == 'Al' - assert models[1].name == 'Al' + assert models[5].name == "Zake" + assert models[0].name == "Al" + assert models[1].name == "Al" assert [model.id for model in models] == [5, 2, 6, 1, 4, 3] - orders_by = {'name': 'asc', 'none': set('aa'), 'id': 'asc'} + orders_by = {"name": "asc", "none": set("aa"), "id": "asc"} models = sort_models(models, orders_by) assert [model.id for model in models] == [2, 5, 1, 6, 4, 3] - orders_by = {'sort_order': 'asc', 'none': ..., 'id': 'asc', 'uu': 2, 'aa': None} + orders_by = {"sort_order": "asc", "none": ..., "id": "asc", "uu": 2, "aa": None} models = sort_models(models, orders_by) assert [model.id for model in models] == [1, 4, 2, 3, 5, 6]