add update to queryset, add update_through_instance, start to update docs
This commit is contained in:
@ -1,3 +1,27 @@
|
||||
# 0.9.5
|
||||
|
||||
## Features
|
||||
* Add `update` method to `QuerysetProxy` so now it's possible to update related models directly from parent model
|
||||
in `ManyToMany` relations and in reverse `ForeignKey` relations. Note that update like in `QuerySet` `update` returns number of
|
||||
updated models and **does not update related models in place** on praent model. To get the refreshed data on parent model you need to refresh
|
||||
the related models (i.e. `await model_instance.related.all()`)
|
||||
* Added possibility to add more fields on `Through` model for `ManyToMany` relationships:
|
||||
* name of the through model field is the lowercase name of the Through class
|
||||
* you can pass additional fields when calling `add(child, **kwargs)` on relation (on `QuerysetProxy`)
|
||||
* you can pass additional fields when calling `create(**kwargs)` on relation (on `QuerysetProxy`)
|
||||
when one of the keyword arguments should be the through model name with a dict of values
|
||||
* you can order by on through model fields
|
||||
* you can filter on through model fields
|
||||
* you can include and exclude fields on through models
|
||||
* through models are attached only to related models (i.e. if you query from A to B -> only on B)
|
||||
* check the updated docs for more information
|
||||
|
||||
# Other
|
||||
* Updated docs and api docs
|
||||
* Refactors and optimisations mainly related to filters and order bys
|
||||
|
||||
|
||||
|
||||
# 0.9.4
|
||||
|
||||
## Fixes
|
||||
|
||||
@ -52,6 +52,9 @@ class QuerySetProtocol(Protocol): # pragma: nocover
|
||||
async def create(self, **kwargs: Any) -> "Model":
|
||||
...
|
||||
|
||||
async def update(self, each: bool = False, **kwargs: Any) -> int:
|
||||
...
|
||||
|
||||
async def get_or_create(self, **kwargs: Any) -> "Model":
|
||||
...
|
||||
|
||||
|
||||
@ -142,6 +142,7 @@ class PrefetchQuery:
|
||||
self.models: Dict = {}
|
||||
self.select_dict = translate_list_to_dict(self._select_related)
|
||||
self.orders_by = orders_by or []
|
||||
# TODO: refactor OrderActions to use it instead of strings from it
|
||||
self.order_dict = translate_list_to_dict(
|
||||
[x.query_str for x in self.orders_by], is_order=True
|
||||
)
|
||||
|
||||
@ -573,17 +573,19 @@ class QuerySet(Generic[T]):
|
||||
:return: number of updated rows
|
||||
:rtype: int
|
||||
"""
|
||||
if not each and not self.filter_clauses:
|
||||
raise QueryDefinitionError(
|
||||
"You cannot update without filtering the queryset first. "
|
||||
"If you want to update all rows use update(each=True, **kwargs)"
|
||||
)
|
||||
|
||||
self_fields = self.model.extract_db_own_fields().union(
|
||||
self.model.extract_related_names()
|
||||
)
|
||||
updates = {k: v for k, v in kwargs.items() if k in self_fields}
|
||||
updates = self.model.validate_choices(updates)
|
||||
updates = self.model.translate_columns_to_aliases(updates)
|
||||
if not each and not self.filter_clauses:
|
||||
raise QueryDefinitionError(
|
||||
"You cannot update without filtering the queryset first. "
|
||||
"If you want to update all rows use update(each=True, **kwargs)"
|
||||
)
|
||||
|
||||
expr = FilterQuery(filter_clauses=self.filter_clauses).apply(
|
||||
self.table.update().values(**updates)
|
||||
)
|
||||
|
||||
@ -14,7 +14,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import ModelPersistenceError
|
||||
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.relations import Relation
|
||||
@ -132,6 +132,22 @@ class QuerysetProxy(Generic[T]):
|
||||
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
|
||||
await model_cls.Meta.database.execute(expr)
|
||||
|
||||
async def update_through_instance(self, child: "T", **kwargs: Any) -> None:
|
||||
"""
|
||||
Updates a through model instance in the database for m2m relations.
|
||||
|
||||
:param kwargs: dict of additional keyword arguments for through instance
|
||||
:type kwargs: Any
|
||||
:param child: child model instance
|
||||
:type child: Model
|
||||
"""
|
||||
model_cls = self.relation.through
|
||||
owner_column = self.related_field.default_target_field_name() # type: ignore
|
||||
child_column = self.related_field.default_source_field_name() # type: ignore
|
||||
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
||||
through_model = await model_cls.objects.get(**rel_kwargs)
|
||||
await through_model.update(**kwargs)
|
||||
|
||||
async def delete_through_instance(self, child: "T") -> None:
|
||||
"""
|
||||
Removes through model instance from the database for m2m relations.
|
||||
@ -290,6 +306,39 @@ class QuerysetProxy(Generic[T]):
|
||||
await self.create_through_instance(created, **through_kwargs)
|
||||
return created
|
||||
|
||||
async def update(self, each: bool = False, **kwargs: Any) -> int:
|
||||
"""
|
||||
Updates the model table after applying the filters from kwargs.
|
||||
|
||||
You have to either pass a filter to narrow down a query or explicitly pass
|
||||
each=True flag to affect whole table.
|
||||
|
||||
:param each: flag if whole table should be affected if no filter is passed
|
||||
:type each: bool
|
||||
:param kwargs: fields names and proper value types
|
||||
:type kwargs: Any
|
||||
:return: number of updated rows
|
||||
:rtype: int
|
||||
"""
|
||||
# queryset proxy always have one filter for pk of parent model
|
||||
if not each and len(self.queryset.filter_clauses) == 1:
|
||||
raise QueryDefinitionError(
|
||||
"You cannot update without filtering the queryset first. "
|
||||
"If you want to update all rows use update(each=True, **kwargs)"
|
||||
)
|
||||
|
||||
through_kwargs = kwargs.pop(self.through_model_name, {})
|
||||
children = await self.queryset.all()
|
||||
for child in children:
|
||||
if child:
|
||||
await child.update(**kwargs)
|
||||
if self.type_ == ormar.RelationType.MULTIPLE and through_kwargs:
|
||||
await self.update_through_instance(
|
||||
child=child, # type: ignore
|
||||
**through_kwargs,
|
||||
)
|
||||
return len(children)
|
||||
|
||||
async def get_or_create(self, **kwargs: Any) -> "T":
|
||||
"""
|
||||
Combination of create and get methods.
|
||||
|
||||
@ -235,6 +235,64 @@ async def test_ordering_by_through_model() -> Any:
|
||||
assert post3.categories[2].postcategory.param_name == "volume"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_through_models_from_queryset_on_through() -> Any:
|
||||
async with database:
|
||||
post = await Post(title="Test post").save()
|
||||
await post.categories.create(
|
||||
name="Test category1",
|
||||
postcategory={"sort_order": 2, "param_name": "volume"},
|
||||
)
|
||||
await post.categories.create(
|
||||
name="Test category2", postcategory={"sort_order": 1, "param_name": "area"}
|
||||
)
|
||||
await post.categories.create(
|
||||
name="Test category3",
|
||||
postcategory={"sort_order": 3, "param_name": "velocity"},
|
||||
)
|
||||
|
||||
await PostCategory.objects.filter(param_name="volume", post=post.id).update(
|
||||
sort_order=4
|
||||
)
|
||||
post2 = (
|
||||
await Post.objects.select_related("categories")
|
||||
.order_by("-postcategory__sort_order")
|
||||
.get()
|
||||
)
|
||||
assert len(post2.categories) == 3
|
||||
assert post2.categories[0].postcategory.param_name == "volume"
|
||||
assert post2.categories[2].postcategory.param_name == "area"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_through_from_related() -> Any:
|
||||
async with database:
|
||||
post = await Post(title="Test post").save()
|
||||
await post.categories.create(
|
||||
name="Test category1",
|
||||
postcategory={"sort_order": 2, "param_name": "volume"},
|
||||
)
|
||||
await post.categories.create(
|
||||
name="Test category2", postcategory={"sort_order": 1, "param_name": "area"}
|
||||
)
|
||||
await post.categories.create(
|
||||
name="Test category3",
|
||||
postcategory={"sort_order": 3, "param_name": "velocity"},
|
||||
)
|
||||
|
||||
await post.categories.filter(name="Test category3").update(
|
||||
postcategory={"sort_order": 4}
|
||||
)
|
||||
|
||||
post2 = (
|
||||
await Post.objects.select_related("categories")
|
||||
.order_by("postcategory__sort_order")
|
||||
.get()
|
||||
)
|
||||
assert len(post2.categories) == 3
|
||||
assert post2.categories[2].postcategory.sort_order == 4
|
||||
|
||||
|
||||
# TODO: check/ modify following
|
||||
|
||||
# add to fields with class lower name (V)
|
||||
@ -245,11 +303,12 @@ async def test_ordering_by_through_model() -> Any:
|
||||
# accessing from instance (V) <- no both sides only nested one is relevant, fix one side
|
||||
# filtering in filter (through name normally) (V) < - table prefix from normal relation,
|
||||
# check if is_through needed, resolved side of relation
|
||||
# ordering by in order_by
|
||||
# ordering by in order_by (V)
|
||||
# updating in query (V)
|
||||
# updating from querysetproxy (V)
|
||||
|
||||
# modifying from instance (both sides?) (X) <= no, the loaded one doesn't have relations
|
||||
|
||||
# updating in query
|
||||
# modifying from instance (both sides?)
|
||||
# including/excluding in fields?
|
||||
# allowing to change fk fields names in through model?
|
||||
# make through optional? auto-generated for cases other fields are missing?
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import QueryDefinitionError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
@ -180,3 +181,42 @@ async def test_queryset_methods():
|
||||
assert len(categories) == 3 == len(post.categories)
|
||||
for cat in post.categories:
|
||||
assert cat.subject.name is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queryset_update():
|
||||
async with database:
|
||||
async with database.transaction(force_rollback=True):
|
||||
guido = await Author.objects.create(
|
||||
first_name="Guido", last_name="Van Rossum"
|
||||
)
|
||||
subject = await Subject(name="Random").save()
|
||||
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||
await post.categories.create(name="News", sort_order=1, subject=subject)
|
||||
await post.categories.create(name="Breaking", sort_order=3, subject=subject)
|
||||
|
||||
await post.categories.order_by("sort_order").all()
|
||||
assert len(post.categories) == 2
|
||||
assert post.categories[0].sort_order == 1
|
||||
assert post.categories[0].name == "News"
|
||||
assert post.categories[1].sort_order == 3
|
||||
assert post.categories[1].name == "Breaking"
|
||||
|
||||
updated = await post.categories.update(each=True, name="Test")
|
||||
assert updated == 2
|
||||
|
||||
await post.categories.order_by("sort_order").all()
|
||||
assert len(post.categories) == 2
|
||||
assert post.categories[0].name == "Test"
|
||||
assert post.categories[1].name == "Test"
|
||||
|
||||
updated = await post.categories.filter(sort_order=3).update(name="Test 2")
|
||||
assert updated == 1
|
||||
|
||||
await post.categories.order_by("sort_order").all()
|
||||
assert len(post.categories) == 2
|
||||
assert post.categories[0].name == "Test"
|
||||
assert post.categories[1].name == "Test 2"
|
||||
|
||||
with pytest.raises(QueryDefinitionError):
|
||||
await post.categories.update(name="Test WRONG")
|
||||
|
||||
147
tests/test_wekref_exclusion.py
Normal file
147
tests/test_wekref_exclusion.py
Normal file
@ -0,0 +1,147 @@
|
||||
from typing import List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import databases
|
||||
import pydantic
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from fastapi import FastAPI
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
import ormar
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
app.state.database = database
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup() -> None:
|
||||
database_ = app.state.database
|
||||
if not database_.is_connected:
|
||||
await database_.connect()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown() -> None:
|
||||
database_ = app.state.database
|
||||
if database_.is_connected:
|
||||
await database_.disconnect()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
class BaseMeta(ormar.ModelMeta):
|
||||
database = database
|
||||
metadata = metadata
|
||||
|
||||
|
||||
class OtherThing(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "other_things"
|
||||
|
||||
id: UUID = ormar.UUID(primary_key=True, default=uuid4)
|
||||
name: str = ormar.Text(default="")
|
||||
ot_contents: str = ormar.Text(default="")
|
||||
|
||||
|
||||
class Thing(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "things"
|
||||
|
||||
id: UUID = ormar.UUID(primary_key=True, default=uuid4)
|
||||
name: str = ormar.Text(default="")
|
||||
js: pydantic.Json = ormar.JSON(nullable=True)
|
||||
other_thing: Optional[OtherThing] = ormar.ForeignKey(OtherThing, nullable=True)
|
||||
|
||||
|
||||
@app.post("/test/1")
|
||||
async def post_test_1():
|
||||
# don't split initialization and attribute assignment
|
||||
ot = await OtherThing(ot_contents="otc").save()
|
||||
await Thing(other_thing=ot, name="t1").save()
|
||||
await Thing(other_thing=ot, name="t2").save()
|
||||
await Thing(other_thing=ot, name="t3").save()
|
||||
|
||||
# if you do not care about returned object you can even go with bulk_create
|
||||
# all of them are created in one transaction
|
||||
# things = [Thing(other_thing=ot, name='t1'),
|
||||
# Thing(other_thing=ot, name="t2"),
|
||||
# Thing(other_thing=ot, name="t3")]
|
||||
# await Thing.objects.bulk_create(things)
|
||||
|
||||
|
||||
@app.get("/test/2", response_model=List[Thing])
|
||||
async def get_test_2():
|
||||
# if you only query for one use get or first
|
||||
ot = await OtherThing.objects.get()
|
||||
ts = await ot.things.all()
|
||||
# specifically null out the relation on things before return
|
||||
for t in ts:
|
||||
t.remove(ot, name="other_thing")
|
||||
return ts
|
||||
|
||||
|
||||
@app.get("/test/3", response_model=List[Thing])
|
||||
async def get_test_3():
|
||||
ot = await OtherThing.objects.select_related("things").get()
|
||||
# exclude unwanted field while ot is still in scope
|
||||
# in order not to pass it to fastapi
|
||||
return [t.dict(exclude={"other_thing"}) for t in ot.things]
|
||||
|
||||
|
||||
@app.get("/test/4", response_model=List[Thing], response_model_exclude={"other_thing"})
|
||||
async def get_test_4():
|
||||
ot = await OtherThing.objects.get()
|
||||
# query from the active side
|
||||
return await Thing.objects.all(other_thing=ot)
|
||||
|
||||
|
||||
@app.get("/get_ot/", response_model=OtherThing)
|
||||
async def get_ot():
|
||||
return await OtherThing.objects.get()
|
||||
|
||||
|
||||
# more real life (usually) is not getting some random OT and get it's Things
|
||||
# but query for a specific one by some kind of id
|
||||
@app.get(
|
||||
"/test/5/{thing_id}",
|
||||
response_model=List[Thing],
|
||||
response_model_exclude={"other_thing"},
|
||||
)
|
||||
async def get_test_5(thing_id: UUID):
|
||||
return await Thing.objects.all(other_thing__id=thing_id)
|
||||
|
||||
|
||||
def test_endpoints():
|
||||
client = TestClient(app)
|
||||
with client:
|
||||
resp = client.post("/test/1")
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp2 = client.get("/test/2")
|
||||
assert resp2.status_code == 200
|
||||
assert len(resp2.json()) == 3
|
||||
|
||||
resp3 = client.get("/test/3")
|
||||
assert resp3.status_code == 200
|
||||
assert len(resp3.json()) == 3
|
||||
|
||||
resp4 = client.get("/test/4")
|
||||
assert resp4.status_code == 200
|
||||
assert len(resp4.json()) == 3
|
||||
|
||||
ot = OtherThing(**client.get("/get_ot/").json())
|
||||
resp5 = client.get(f"/test/5/{ot.id}")
|
||||
assert resp5.status_code == 200
|
||||
assert len(resp5.json()) == 3
|
||||
Reference in New Issue
Block a user