diff --git a/docs/releases.md b/docs/releases.md index 30e6fcb..9d98f9f 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,27 @@ +# 0.9.5 + +## Features +* Add `update` method to `QuerysetProxy` so now it's possible to update related models directly from parent model + in `ManyToMany` relations and in reverse `ForeignKey` relations. Note that update like in `QuerySet` `update` returns number of + updated models and **does not update related models in place** on praent model. To get the refreshed data on parent model you need to refresh + the related models (i.e. `await model_instance.related.all()`) +* Added possibility to add more fields on `Through` model for `ManyToMany` relationships: + * name of the through model field is the lowercase name of the Through class + * you can pass additional fields when calling `add(child, **kwargs)` on relation (on `QuerysetProxy`) + * you can pass additional fields when calling `create(**kwargs)` on relation (on `QuerysetProxy`) + when one of the keyword arguments should be the through model name with a dict of values + * you can order by on through model fields + * you can filter on through model fields + * you can include and exclude fields on through models + * through models are attached only to related models (i.e. if you query from A to B -> only on B) + * check the updated docs for more information + +# Other +* Updated docs and api docs +* Refactors and optimisations mainly related to filters and order bys + + + # 0.9.4 ## Fixes diff --git a/ormar/models/mixins/prefetch_mixin.py b/ormar/models/mixins/prefetch_mixin.py index 440052a..85faec2 100644 --- a/ormar/models/mixins/prefetch_mixin.py +++ b/ormar/models/mixins/prefetch_mixin.py @@ -18,10 +18,10 @@ class PrefetchQueryMixin(RelationMixin): @staticmethod def get_clause_target_and_filter_column_name( - parent_model: Type["Model"], - target_model: Type["Model"], - reverse: bool, - related: str, + parent_model: Type["Model"], + target_model: Type["Model"], + reverse: bool, + related: str, ) -> Tuple[Type["Model"], str]: """ Returns Model on which query clause should be performed and name of the column. @@ -51,7 +51,7 @@ class PrefetchQueryMixin(RelationMixin): @staticmethod def get_column_name_for_id_extraction( - parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool, + parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool, ) -> str: """ Returns name of the column that should be used to extract ids from model. diff --git a/ormar/protocols/queryset_protocol.py b/ormar/protocols/queryset_protocol.py index 7eb7092..397f58b 100644 --- a/ormar/protocols/queryset_protocol.py +++ b/ormar/protocols/queryset_protocol.py @@ -52,6 +52,9 @@ class QuerySetProtocol(Protocol): # pragma: nocover async def create(self, **kwargs: Any) -> "Model": ... + async def update(self, each: bool = False, **kwargs: Any) -> int: + ... + async def get_or_create(self, **kwargs: Any) -> "Model": ... diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 08d2675..533f92c 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -142,6 +142,7 @@ class PrefetchQuery: self.models: Dict = {} self.select_dict = translate_list_to_dict(self._select_related) self.orders_by = orders_by or [] + # TODO: refactor OrderActions to use it instead of strings from it self.order_dict = translate_list_to_dict( [x.query_str for x in self.orders_by], is_order=True ) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 2202abd..46d679a 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -573,17 +573,19 @@ class QuerySet(Generic[T]): :return: number of updated rows :rtype: int """ + if not each and not self.filter_clauses: + raise QueryDefinitionError( + "You cannot update without filtering the queryset first. " + "If you want to update all rows use update(each=True, **kwargs)" + ) + self_fields = self.model.extract_db_own_fields().union( self.model.extract_related_names() ) updates = {k: v for k, v in kwargs.items() if k in self_fields} updates = self.model.validate_choices(updates) updates = self.model.translate_columns_to_aliases(updates) - if not each and not self.filter_clauses: - raise QueryDefinitionError( - "You cannot update without filtering the queryset first. " - "If you want to update all rows use update(each=True, **kwargs)" - ) + expr = FilterQuery(filter_clauses=self.filter_clauses).apply( self.table.update().values(**updates) ) diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 85b7832..157e72c 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -14,7 +14,7 @@ from typing import ( ) import ormar -from ormar.exceptions import ModelPersistenceError +from ormar.exceptions import ModelPersistenceError, QueryDefinitionError if TYPE_CHECKING: # pragma no cover from ormar.relations import Relation @@ -132,6 +132,22 @@ class QuerysetProxy(Generic[T]): # print("\n", expr.compile(compile_kwargs={"literal_binds": True})) await model_cls.Meta.database.execute(expr) + async def update_through_instance(self, child: "T", **kwargs: Any) -> None: + """ + Updates a through model instance in the database for m2m relations. + + :param kwargs: dict of additional keyword arguments for through instance + :type kwargs: Any + :param child: child model instance + :type child: Model + """ + model_cls = self.relation.through + owner_column = self.related_field.default_target_field_name() # type: ignore + child_column = self.related_field.default_source_field_name() # type: ignore + rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk} + through_model = await model_cls.objects.get(**rel_kwargs) + await through_model.update(**kwargs) + async def delete_through_instance(self, child: "T") -> None: """ Removes through model instance from the database for m2m relations. @@ -290,6 +306,39 @@ class QuerysetProxy(Generic[T]): await self.create_through_instance(created, **through_kwargs) return created + async def update(self, each: bool = False, **kwargs: Any) -> int: + """ + Updates the model table after applying the filters from kwargs. + + You have to either pass a filter to narrow down a query or explicitly pass + each=True flag to affect whole table. + + :param each: flag if whole table should be affected if no filter is passed + :type each: bool + :param kwargs: fields names and proper value types + :type kwargs: Any + :return: number of updated rows + :rtype: int + """ + # queryset proxy always have one filter for pk of parent model + if not each and len(self.queryset.filter_clauses) == 1: + raise QueryDefinitionError( + "You cannot update without filtering the queryset first. " + "If you want to update all rows use update(each=True, **kwargs)" + ) + + through_kwargs = kwargs.pop(self.through_model_name, {}) + children = await self.queryset.all() + for child in children: + if child: + await child.update(**kwargs) + if self.type_ == ormar.RelationType.MULTIPLE and through_kwargs: + await self.update_through_instance( + child=child, # type: ignore + **through_kwargs, + ) + return len(children) + async def get_or_create(self, **kwargs: Any) -> "T": """ Combination of create and get methods. diff --git a/tests/test_m2m_through_fields.py b/tests/test_m2m_through_fields.py index 5c1b44f..898f103 100644 --- a/tests/test_m2m_through_fields.py +++ b/tests/test_m2m_through_fields.py @@ -235,6 +235,64 @@ async def test_ordering_by_through_model() -> Any: assert post3.categories[2].postcategory.param_name == "volume" +@pytest.mark.asyncio +async def test_update_through_models_from_queryset_on_through() -> Any: + async with database: + post = await Post(title="Test post").save() + await post.categories.create( + name="Test category1", + postcategory={"sort_order": 2, "param_name": "volume"}, + ) + await post.categories.create( + name="Test category2", postcategory={"sort_order": 1, "param_name": "area"} + ) + await post.categories.create( + name="Test category3", + postcategory={"sort_order": 3, "param_name": "velocity"}, + ) + + await PostCategory.objects.filter(param_name="volume", post=post.id).update( + sort_order=4 + ) + post2 = ( + await Post.objects.select_related("categories") + .order_by("-postcategory__sort_order") + .get() + ) + assert len(post2.categories) == 3 + assert post2.categories[0].postcategory.param_name == "volume" + assert post2.categories[2].postcategory.param_name == "area" + + +@pytest.mark.asyncio +async def test_update_through_from_related() -> Any: + async with database: + post = await Post(title="Test post").save() + await post.categories.create( + name="Test category1", + postcategory={"sort_order": 2, "param_name": "volume"}, + ) + await post.categories.create( + name="Test category2", postcategory={"sort_order": 1, "param_name": "area"} + ) + await post.categories.create( + name="Test category3", + postcategory={"sort_order": 3, "param_name": "velocity"}, + ) + + await post.categories.filter(name="Test category3").update( + postcategory={"sort_order": 4} + ) + + post2 = ( + await Post.objects.select_related("categories") + .order_by("postcategory__sort_order") + .get() + ) + assert len(post2.categories) == 3 + assert post2.categories[2].postcategory.sort_order == 4 + + # TODO: check/ modify following # add to fields with class lower name (V) @@ -245,11 +303,12 @@ async def test_ordering_by_through_model() -> Any: # accessing from instance (V) <- no both sides only nested one is relevant, fix one side # filtering in filter (through name normally) (V) < - table prefix from normal relation, # check if is_through needed, resolved side of relation -# ordering by in order_by +# ordering by in order_by (V) +# updating in query (V) +# updating from querysetproxy (V) +# modifying from instance (both sides?) (X) <= no, the loaded one doesn't have relations -# updating in query -# modifying from instance (both sides?) # including/excluding in fields? # allowing to change fk fields names in through model? # make through optional? auto-generated for cases other fields are missing? diff --git a/tests/test_queryproxy_on_m2m_models.py b/tests/test_queryproxy_on_m2m_models.py index d33aa5d..a91c4f8 100644 --- a/tests/test_queryproxy_on_m2m_models.py +++ b/tests/test_queryproxy_on_m2m_models.py @@ -6,6 +6,7 @@ import pytest import sqlalchemy import ormar +from ormar.exceptions import QueryDefinitionError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -180,3 +181,42 @@ async def test_queryset_methods(): assert len(categories) == 3 == len(post.categories) for cat in post.categories: assert cat.subject.name is not None + + +@pytest.mark.asyncio +async def test_queryset_update(): + async with database: + async with database.transaction(force_rollback=True): + guido = await Author.objects.create( + first_name="Guido", last_name="Van Rossum" + ) + subject = await Subject(name="Random").save() + post = await Post.objects.create(title="Hello, M2M", author=guido) + await post.categories.create(name="News", sort_order=1, subject=subject) + await post.categories.create(name="Breaking", sort_order=3, subject=subject) + + await post.categories.order_by("sort_order").all() + assert len(post.categories) == 2 + assert post.categories[0].sort_order == 1 + assert post.categories[0].name == "News" + assert post.categories[1].sort_order == 3 + assert post.categories[1].name == "Breaking" + + updated = await post.categories.update(each=True, name="Test") + assert updated == 2 + + await post.categories.order_by("sort_order").all() + assert len(post.categories) == 2 + assert post.categories[0].name == "Test" + assert post.categories[1].name == "Test" + + updated = await post.categories.filter(sort_order=3).update(name="Test 2") + assert updated == 1 + + await post.categories.order_by("sort_order").all() + assert len(post.categories) == 2 + assert post.categories[0].name == "Test" + assert post.categories[1].name == "Test 2" + + with pytest.raises(QueryDefinitionError): + await post.categories.update(name="Test WRONG") diff --git a/tests/test_wekref_exclusion.py b/tests/test_wekref_exclusion.py new file mode 100644 index 0000000..a1140f7 --- /dev/null +++ b/tests/test_wekref_exclusion.py @@ -0,0 +1,147 @@ +from typing import List, Optional +from uuid import UUID, uuid4 + +import databases +import pydantic +import pytest +import sqlalchemy +from fastapi import FastAPI +from starlette.testclient import TestClient + +import ormar +from tests.settings import DATABASE_URL + +app = FastAPI() + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + +app.state.database = database + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +class BaseMeta(ormar.ModelMeta): + database = database + metadata = metadata + + +class OtherThing(ormar.Model): + class Meta(BaseMeta): + tablename = "other_things" + + id: UUID = ormar.UUID(primary_key=True, default=uuid4) + name: str = ormar.Text(default="") + ot_contents: str = ormar.Text(default="") + + +class Thing(ormar.Model): + class Meta(BaseMeta): + tablename = "things" + + id: UUID = ormar.UUID(primary_key=True, default=uuid4) + name: str = ormar.Text(default="") + js: pydantic.Json = ormar.JSON(nullable=True) + other_thing: Optional[OtherThing] = ormar.ForeignKey(OtherThing, nullable=True) + + +@app.post("/test/1") +async def post_test_1(): + # don't split initialization and attribute assignment + ot = await OtherThing(ot_contents="otc").save() + await Thing(other_thing=ot, name="t1").save() + await Thing(other_thing=ot, name="t2").save() + await Thing(other_thing=ot, name="t3").save() + + # if you do not care about returned object you can even go with bulk_create + # all of them are created in one transaction + # things = [Thing(other_thing=ot, name='t1'), + # Thing(other_thing=ot, name="t2"), + # Thing(other_thing=ot, name="t3")] + # await Thing.objects.bulk_create(things) + + +@app.get("/test/2", response_model=List[Thing]) +async def get_test_2(): + # if you only query for one use get or first + ot = await OtherThing.objects.get() + ts = await ot.things.all() + # specifically null out the relation on things before return + for t in ts: + t.remove(ot, name="other_thing") + return ts + + +@app.get("/test/3", response_model=List[Thing]) +async def get_test_3(): + ot = await OtherThing.objects.select_related("things").get() + # exclude unwanted field while ot is still in scope + # in order not to pass it to fastapi + return [t.dict(exclude={"other_thing"}) for t in ot.things] + + +@app.get("/test/4", response_model=List[Thing], response_model_exclude={"other_thing"}) +async def get_test_4(): + ot = await OtherThing.objects.get() + # query from the active side + return await Thing.objects.all(other_thing=ot) + + +@app.get("/get_ot/", response_model=OtherThing) +async def get_ot(): + return await OtherThing.objects.get() + + +# more real life (usually) is not getting some random OT and get it's Things +# but query for a specific one by some kind of id +@app.get( + "/test/5/{thing_id}", + response_model=List[Thing], + response_model_exclude={"other_thing"}, +) +async def get_test_5(thing_id: UUID): + return await Thing.objects.all(other_thing__id=thing_id) + + +def test_endpoints(): + client = TestClient(app) + with client: + resp = client.post("/test/1") + assert resp.status_code == 200 + + resp2 = client.get("/test/2") + assert resp2.status_code == 200 + assert len(resp2.json()) == 3 + + resp3 = client.get("/test/3") + assert resp3.status_code == 200 + assert len(resp3.json()) == 3 + + resp4 = client.get("/test/4") + assert resp4.status_code == 200 + assert len(resp4.json()) == 3 + + ot = OtherThing(**client.get("/get_ot/").json()) + resp5 = client.get(f"/test/5/{ot.id}") + assert resp5.status_code == 200 + assert len(resp5.json()) == 3