diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 3835b41..1ae257e 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -71,7 +71,12 @@ class ModelTableProxy: ) model_dict[field] = pk_value elif field_value: # nested dict - model_dict[field] = field_value.get(target_pkname) + if isinstance(field_value, list): + model_dict[field] = [ + target.get(target_pkname) for target in field_value + ] + else: + model_dict[field] = field_value.get(target_pkname) else: model_dict.pop(field, None) return model_dict @@ -231,7 +236,9 @@ class ModelTableProxy: @staticmethod def _populate_pk_column( - model: Type["Model"], columns: List[str], use_alias: bool = False, + model: Type["Model"], + columns: List[str], + use_alias: bool = False, ) -> List[str]: pk_alias = ( model.get_column_alias(model.Meta.pkname) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 7015f7a..34fe974 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -20,6 +20,7 @@ class Album(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) + is_best_seller: bool = ormar.Boolean(default=False) class Track(ormar.Model): @@ -32,6 +33,7 @@ class Track(ormar.Model): album: Optional[Album] = ormar.ForeignKey(Album) title: str = ormar.String(max_length=100) position: int = ormar.Integer() + play_count: int = ormar.Integer(nullable=True, default=0) class Cover(ormar.Model): @@ -372,3 +374,52 @@ async def test_wrong_model_passed_as_fk(): with pytest.raises(RelationshipInstanceError): org = await Organisation.objects.create(ident="ACME Ltd") await Track.objects.create(album=org, title="Test1", position=1) + + +@pytest.mark.asyncio +async def test_bulk_update_model_with_no_children(): + async with database: + async with database.transaction(force_rollback=True): + album = await Album.objects.create(name="Test") + album.name = "Test2" + await Album.objects.bulk_update([album], columns=["name"]) + + updated_album = await Album.objects.get(id=album.id) + assert updated_album.name == "Test2" + + +@pytest.mark.asyncio +async def test_bulk_update_model_with_children(): + async with database: + async with database.transaction(force_rollback=True): + best_seller = await Album.objects.create(name="to_be_best_seller") + best_seller2 = await Album.objects.create(name="to_be_best_seller2") + not_best_seller = await Album.objects.create(name="unpopular") + await Track.objects.create( + album=best_seller, title="t1", position=1, play_count=100 + ) + await Track.objects.create( + album=best_seller2, title="t2", position=1, play_count=100 + ) + await Track.objects.create( + album=not_best_seller, title="t3", position=1, play_count=3 + ) + await Track.objects.create( + album=best_seller, title="t4", position=1, play_count=500 + ) + + tracks = await Track.objects.select_related("album").filter( + play_count__gt=10 + ).all() + best_seller_albums = {} + for track in tracks: + album = track.album + if album.id in best_seller_albums: + continue + album.is_best_seller = True + best_seller_albums[album.id] = album + await Album.objects.bulk_update( + best_seller_albums.values(), columns=["is_best_seller"] + ) + best_seller_albums_db = await Album.objects.filter(is_best_seller=True).all() + assert len(best_seller_albums_db) == 2