add update to queryset, add update_through_instance, start to update docs

This commit is contained in:
collerek
2021-02-26 11:28:44 +01:00
parent 503f589fa7
commit 7bf781098f
9 changed files with 339 additions and 14 deletions

View File

@ -1,3 +1,27 @@
# 0.9.5
## Features
* Add `update` method to `QuerysetProxy` so now it's possible to update related models directly from parent model
in `ManyToMany` relations and in reverse `ForeignKey` relations. Note that update like in `QuerySet` `update` returns number of
updated models and **does not update related models in place** on praent model. To get the refreshed data on parent model you need to refresh
the related models (i.e. `await model_instance.related.all()`)
* Added possibility to add more fields on `Through` model for `ManyToMany` relationships:
* name of the through model field is the lowercase name of the Through class
* you can pass additional fields when calling `add(child, **kwargs)` on relation (on `QuerysetProxy`)
* you can pass additional fields when calling `create(**kwargs)` on relation (on `QuerysetProxy`)
when one of the keyword arguments should be the through model name with a dict of values
* you can order by on through model fields
* you can filter on through model fields
* you can include and exclude fields on through models
* through models are attached only to related models (i.e. if you query from A to B -> only on B)
* check the updated docs for more information
# Other
* Updated docs and api docs
* Refactors and optimisations mainly related to filters and order bys
# 0.9.4
## Fixes

View File

@ -18,10 +18,10 @@ class PrefetchQueryMixin(RelationMixin):
@staticmethod
def get_clause_target_and_filter_column_name(
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
related: str,
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
related: str,
) -> Tuple[Type["Model"], str]:
"""
Returns Model on which query clause should be performed and name of the column.
@ -51,7 +51,7 @@ class PrefetchQueryMixin(RelationMixin):
@staticmethod
def get_column_name_for_id_extraction(
parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool,
parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool,
) -> str:
"""
Returns name of the column that should be used to extract ids from model.

View File

@ -52,6 +52,9 @@ class QuerySetProtocol(Protocol): # pragma: nocover
async def create(self, **kwargs: Any) -> "Model":
...
async def update(self, each: bool = False, **kwargs: Any) -> int:
...
async def get_or_create(self, **kwargs: Any) -> "Model":
...

View File

@ -142,6 +142,7 @@ class PrefetchQuery:
self.models: Dict = {}
self.select_dict = translate_list_to_dict(self._select_related)
self.orders_by = orders_by or []
# TODO: refactor OrderActions to use it instead of strings from it
self.order_dict = translate_list_to_dict(
[x.query_str for x in self.orders_by], is_order=True
)

View File

@ -573,17 +573,19 @@ class QuerySet(Generic[T]):
:return: number of updated rows
:rtype: int
"""
if not each and not self.filter_clauses:
raise QueryDefinitionError(
"You cannot update without filtering the queryset first. "
"If you want to update all rows use update(each=True, **kwargs)"
)
self_fields = self.model.extract_db_own_fields().union(
self.model.extract_related_names()
)
updates = {k: v for k, v in kwargs.items() if k in self_fields}
updates = self.model.validate_choices(updates)
updates = self.model.translate_columns_to_aliases(updates)
if not each and not self.filter_clauses:
raise QueryDefinitionError(
"You cannot update without filtering the queryset first. "
"If you want to update all rows use update(each=True, **kwargs)"
)
expr = FilterQuery(filter_clauses=self.filter_clauses).apply(
self.table.update().values(**updates)
)

View File

@ -14,7 +14,7 @@ from typing import (
)
import ormar
from ormar.exceptions import ModelPersistenceError
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation
@ -132,6 +132,22 @@ class QuerysetProxy(Generic[T]):
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
await model_cls.Meta.database.execute(expr)
async def update_through_instance(self, child: "T", **kwargs: Any) -> None:
"""
Updates a through model instance in the database for m2m relations.
:param kwargs: dict of additional keyword arguments for through instance
:type kwargs: Any
:param child: child model instance
:type child: Model
"""
model_cls = self.relation.through
owner_column = self.related_field.default_target_field_name() # type: ignore
child_column = self.related_field.default_source_field_name() # type: ignore
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
through_model = await model_cls.objects.get(**rel_kwargs)
await through_model.update(**kwargs)
async def delete_through_instance(self, child: "T") -> None:
"""
Removes through model instance from the database for m2m relations.
@ -290,6 +306,39 @@ class QuerysetProxy(Generic[T]):
await self.create_through_instance(created, **through_kwargs)
return created
async def update(self, each: bool = False, **kwargs: Any) -> int:
"""
Updates the model table after applying the filters from kwargs.
You have to either pass a filter to narrow down a query or explicitly pass
each=True flag to affect whole table.
:param each: flag if whole table should be affected if no filter is passed
:type each: bool
:param kwargs: fields names and proper value types
:type kwargs: Any
:return: number of updated rows
:rtype: int
"""
# queryset proxy always have one filter for pk of parent model
if not each and len(self.queryset.filter_clauses) == 1:
raise QueryDefinitionError(
"You cannot update without filtering the queryset first. "
"If you want to update all rows use update(each=True, **kwargs)"
)
through_kwargs = kwargs.pop(self.through_model_name, {})
children = await self.queryset.all()
for child in children:
if child:
await child.update(**kwargs)
if self.type_ == ormar.RelationType.MULTIPLE and through_kwargs:
await self.update_through_instance(
child=child, # type: ignore
**through_kwargs,
)
return len(children)
async def get_or_create(self, **kwargs: Any) -> "T":
"""
Combination of create and get methods.

View File

@ -235,6 +235,64 @@ async def test_ordering_by_through_model() -> Any:
assert post3.categories[2].postcategory.param_name == "volume"
@pytest.mark.asyncio
async def test_update_through_models_from_queryset_on_through() -> Any:
async with database:
post = await Post(title="Test post").save()
await post.categories.create(
name="Test category1",
postcategory={"sort_order": 2, "param_name": "volume"},
)
await post.categories.create(
name="Test category2", postcategory={"sort_order": 1, "param_name": "area"}
)
await post.categories.create(
name="Test category3",
postcategory={"sort_order": 3, "param_name": "velocity"},
)
await PostCategory.objects.filter(param_name="volume", post=post.id).update(
sort_order=4
)
post2 = (
await Post.objects.select_related("categories")
.order_by("-postcategory__sort_order")
.get()
)
assert len(post2.categories) == 3
assert post2.categories[0].postcategory.param_name == "volume"
assert post2.categories[2].postcategory.param_name == "area"
@pytest.mark.asyncio
async def test_update_through_from_related() -> Any:
async with database:
post = await Post(title="Test post").save()
await post.categories.create(
name="Test category1",
postcategory={"sort_order": 2, "param_name": "volume"},
)
await post.categories.create(
name="Test category2", postcategory={"sort_order": 1, "param_name": "area"}
)
await post.categories.create(
name="Test category3",
postcategory={"sort_order": 3, "param_name": "velocity"},
)
await post.categories.filter(name="Test category3").update(
postcategory={"sort_order": 4}
)
post2 = (
await Post.objects.select_related("categories")
.order_by("postcategory__sort_order")
.get()
)
assert len(post2.categories) == 3
assert post2.categories[2].postcategory.sort_order == 4
# TODO: check/ modify following
# add to fields with class lower name (V)
@ -245,11 +303,12 @@ async def test_ordering_by_through_model() -> Any:
# accessing from instance (V) <- no both sides only nested one is relevant, fix one side
# filtering in filter (through name normally) (V) < - table prefix from normal relation,
# check if is_through needed, resolved side of relation
# ordering by in order_by
# ordering by in order_by (V)
# updating in query (V)
# updating from querysetproxy (V)
# modifying from instance (both sides?) (X) <= no, the loaded one doesn't have relations
# updating in query
# modifying from instance (both sides?)
# including/excluding in fields?
# allowing to change fk fields names in through model?
# make through optional? auto-generated for cases other fields are missing?

View File

@ -6,6 +6,7 @@ import pytest
import sqlalchemy
import ormar
from ormar.exceptions import QueryDefinitionError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
@ -180,3 +181,42 @@ async def test_queryset_methods():
assert len(categories) == 3 == len(post.categories)
for cat in post.categories:
assert cat.subject.name is not None
@pytest.mark.asyncio
async def test_queryset_update():
async with database:
async with database.transaction(force_rollback=True):
guido = await Author.objects.create(
first_name="Guido", last_name="Van Rossum"
)
subject = await Subject(name="Random").save()
post = await Post.objects.create(title="Hello, M2M", author=guido)
await post.categories.create(name="News", sort_order=1, subject=subject)
await post.categories.create(name="Breaking", sort_order=3, subject=subject)
await post.categories.order_by("sort_order").all()
assert len(post.categories) == 2
assert post.categories[0].sort_order == 1
assert post.categories[0].name == "News"
assert post.categories[1].sort_order == 3
assert post.categories[1].name == "Breaking"
updated = await post.categories.update(each=True, name="Test")
assert updated == 2
await post.categories.order_by("sort_order").all()
assert len(post.categories) == 2
assert post.categories[0].name == "Test"
assert post.categories[1].name == "Test"
updated = await post.categories.filter(sort_order=3).update(name="Test 2")
assert updated == 1
await post.categories.order_by("sort_order").all()
assert len(post.categories) == 2
assert post.categories[0].name == "Test"
assert post.categories[1].name == "Test 2"
with pytest.raises(QueryDefinitionError):
await post.categories.update(name="Test WRONG")

View File

@ -0,0 +1,147 @@
from typing import List, Optional
from uuid import UUID, uuid4
import databases
import pydantic
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
class BaseMeta(ormar.ModelMeta):
database = database
metadata = metadata
class OtherThing(ormar.Model):
class Meta(BaseMeta):
tablename = "other_things"
id: UUID = ormar.UUID(primary_key=True, default=uuid4)
name: str = ormar.Text(default="")
ot_contents: str = ormar.Text(default="")
class Thing(ormar.Model):
class Meta(BaseMeta):
tablename = "things"
id: UUID = ormar.UUID(primary_key=True, default=uuid4)
name: str = ormar.Text(default="")
js: pydantic.Json = ormar.JSON(nullable=True)
other_thing: Optional[OtherThing] = ormar.ForeignKey(OtherThing, nullable=True)
@app.post("/test/1")
async def post_test_1():
# don't split initialization and attribute assignment
ot = await OtherThing(ot_contents="otc").save()
await Thing(other_thing=ot, name="t1").save()
await Thing(other_thing=ot, name="t2").save()
await Thing(other_thing=ot, name="t3").save()
# if you do not care about returned object you can even go with bulk_create
# all of them are created in one transaction
# things = [Thing(other_thing=ot, name='t1'),
# Thing(other_thing=ot, name="t2"),
# Thing(other_thing=ot, name="t3")]
# await Thing.objects.bulk_create(things)
@app.get("/test/2", response_model=List[Thing])
async def get_test_2():
# if you only query for one use get or first
ot = await OtherThing.objects.get()
ts = await ot.things.all()
# specifically null out the relation on things before return
for t in ts:
t.remove(ot, name="other_thing")
return ts
@app.get("/test/3", response_model=List[Thing])
async def get_test_3():
ot = await OtherThing.objects.select_related("things").get()
# exclude unwanted field while ot is still in scope
# in order not to pass it to fastapi
return [t.dict(exclude={"other_thing"}) for t in ot.things]
@app.get("/test/4", response_model=List[Thing], response_model_exclude={"other_thing"})
async def get_test_4():
ot = await OtherThing.objects.get()
# query from the active side
return await Thing.objects.all(other_thing=ot)
@app.get("/get_ot/", response_model=OtherThing)
async def get_ot():
return await OtherThing.objects.get()
# more real life (usually) is not getting some random OT and get it's Things
# but query for a specific one by some kind of id
@app.get(
"/test/5/{thing_id}",
response_model=List[Thing],
response_model_exclude={"other_thing"},
)
async def get_test_5(thing_id: UUID):
return await Thing.objects.all(other_thing__id=thing_id)
def test_endpoints():
client = TestClient(app)
with client:
resp = client.post("/test/1")
assert resp.status_code == 200
resp2 = client.get("/test/2")
assert resp2.status_code == 200
assert len(resp2.json()) == 3
resp3 = client.get("/test/3")
assert resp3.status_code == 200
assert len(resp3.json()) == 3
resp4 = client.get("/test/4")
assert resp4.status_code == 200
assert len(resp4.json()) == 3
ot = OtherThing(**client.get("/get_ot/").json())
resp5 = client.get(f"/test/5/{ot.id}")
assert resp5.status_code == 200
assert len(resp5.json()) == 3