add update to queryset, add update_through_instance, start to update docs
This commit is contained in:
@ -1,3 +1,27 @@
|
|||||||
|
# 0.9.5
|
||||||
|
|
||||||
|
## Features
|
||||||
|
* Add `update` method to `QuerysetProxy` so now it's possible to update related models directly from parent model
|
||||||
|
in `ManyToMany` relations and in reverse `ForeignKey` relations. Note that update like in `QuerySet` `update` returns number of
|
||||||
|
updated models and **does not update related models in place** on praent model. To get the refreshed data on parent model you need to refresh
|
||||||
|
the related models (i.e. `await model_instance.related.all()`)
|
||||||
|
* Added possibility to add more fields on `Through` model for `ManyToMany` relationships:
|
||||||
|
* name of the through model field is the lowercase name of the Through class
|
||||||
|
* you can pass additional fields when calling `add(child, **kwargs)` on relation (on `QuerysetProxy`)
|
||||||
|
* you can pass additional fields when calling `create(**kwargs)` on relation (on `QuerysetProxy`)
|
||||||
|
when one of the keyword arguments should be the through model name with a dict of values
|
||||||
|
* you can order by on through model fields
|
||||||
|
* you can filter on through model fields
|
||||||
|
* you can include and exclude fields on through models
|
||||||
|
* through models are attached only to related models (i.e. if you query from A to B -> only on B)
|
||||||
|
* check the updated docs for more information
|
||||||
|
|
||||||
|
# Other
|
||||||
|
* Updated docs and api docs
|
||||||
|
* Refactors and optimisations mainly related to filters and order bys
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 0.9.4
|
# 0.9.4
|
||||||
|
|
||||||
## Fixes
|
## Fixes
|
||||||
|
|||||||
@ -18,10 +18,10 @@ class PrefetchQueryMixin(RelationMixin):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_clause_target_and_filter_column_name(
|
def get_clause_target_and_filter_column_name(
|
||||||
parent_model: Type["Model"],
|
parent_model: Type["Model"],
|
||||||
target_model: Type["Model"],
|
target_model: Type["Model"],
|
||||||
reverse: bool,
|
reverse: bool,
|
||||||
related: str,
|
related: str,
|
||||||
) -> Tuple[Type["Model"], str]:
|
) -> Tuple[Type["Model"], str]:
|
||||||
"""
|
"""
|
||||||
Returns Model on which query clause should be performed and name of the column.
|
Returns Model on which query clause should be performed and name of the column.
|
||||||
@ -51,7 +51,7 @@ class PrefetchQueryMixin(RelationMixin):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_column_name_for_id_extraction(
|
def get_column_name_for_id_extraction(
|
||||||
parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool,
|
parent_model: Type["Model"], reverse: bool, related: str, use_raw: bool,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Returns name of the column that should be used to extract ids from model.
|
Returns name of the column that should be used to extract ids from model.
|
||||||
|
|||||||
@ -52,6 +52,9 @@ class QuerySetProtocol(Protocol): # pragma: nocover
|
|||||||
async def create(self, **kwargs: Any) -> "Model":
|
async def create(self, **kwargs: Any) -> "Model":
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def update(self, each: bool = False, **kwargs: Any) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
async def get_or_create(self, **kwargs: Any) -> "Model":
|
async def get_or_create(self, **kwargs: Any) -> "Model":
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|||||||
@ -142,6 +142,7 @@ class PrefetchQuery:
|
|||||||
self.models: Dict = {}
|
self.models: Dict = {}
|
||||||
self.select_dict = translate_list_to_dict(self._select_related)
|
self.select_dict = translate_list_to_dict(self._select_related)
|
||||||
self.orders_by = orders_by or []
|
self.orders_by = orders_by or []
|
||||||
|
# TODO: refactor OrderActions to use it instead of strings from it
|
||||||
self.order_dict = translate_list_to_dict(
|
self.order_dict = translate_list_to_dict(
|
||||||
[x.query_str for x in self.orders_by], is_order=True
|
[x.query_str for x in self.orders_by], is_order=True
|
||||||
)
|
)
|
||||||
|
|||||||
@ -573,17 +573,19 @@ class QuerySet(Generic[T]):
|
|||||||
:return: number of updated rows
|
:return: number of updated rows
|
||||||
:rtype: int
|
:rtype: int
|
||||||
"""
|
"""
|
||||||
|
if not each and not self.filter_clauses:
|
||||||
|
raise QueryDefinitionError(
|
||||||
|
"You cannot update without filtering the queryset first. "
|
||||||
|
"If you want to update all rows use update(each=True, **kwargs)"
|
||||||
|
)
|
||||||
|
|
||||||
self_fields = self.model.extract_db_own_fields().union(
|
self_fields = self.model.extract_db_own_fields().union(
|
||||||
self.model.extract_related_names()
|
self.model.extract_related_names()
|
||||||
)
|
)
|
||||||
updates = {k: v for k, v in kwargs.items() if k in self_fields}
|
updates = {k: v for k, v in kwargs.items() if k in self_fields}
|
||||||
updates = self.model.validate_choices(updates)
|
updates = self.model.validate_choices(updates)
|
||||||
updates = self.model.translate_columns_to_aliases(updates)
|
updates = self.model.translate_columns_to_aliases(updates)
|
||||||
if not each and not self.filter_clauses:
|
|
||||||
raise QueryDefinitionError(
|
|
||||||
"You cannot update without filtering the queryset first. "
|
|
||||||
"If you want to update all rows use update(each=True, **kwargs)"
|
|
||||||
)
|
|
||||||
expr = FilterQuery(filter_clauses=self.filter_clauses).apply(
|
expr = FilterQuery(filter_clauses=self.filter_clauses).apply(
|
||||||
self.table.update().values(**updates)
|
self.table.update().values(**updates)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar.exceptions import ModelPersistenceError
|
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar.relations import Relation
|
from ormar.relations import Relation
|
||||||
@ -132,6 +132,22 @@ class QuerysetProxy(Generic[T]):
|
|||||||
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
|
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
|
||||||
await model_cls.Meta.database.execute(expr)
|
await model_cls.Meta.database.execute(expr)
|
||||||
|
|
||||||
|
async def update_through_instance(self, child: "T", **kwargs: Any) -> None:
|
||||||
|
"""
|
||||||
|
Updates a through model instance in the database for m2m relations.
|
||||||
|
|
||||||
|
:param kwargs: dict of additional keyword arguments for through instance
|
||||||
|
:type kwargs: Any
|
||||||
|
:param child: child model instance
|
||||||
|
:type child: Model
|
||||||
|
"""
|
||||||
|
model_cls = self.relation.through
|
||||||
|
owner_column = self.related_field.default_target_field_name() # type: ignore
|
||||||
|
child_column = self.related_field.default_source_field_name() # type: ignore
|
||||||
|
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
|
||||||
|
through_model = await model_cls.objects.get(**rel_kwargs)
|
||||||
|
await through_model.update(**kwargs)
|
||||||
|
|
||||||
async def delete_through_instance(self, child: "T") -> None:
|
async def delete_through_instance(self, child: "T") -> None:
|
||||||
"""
|
"""
|
||||||
Removes through model instance from the database for m2m relations.
|
Removes through model instance from the database for m2m relations.
|
||||||
@ -290,6 +306,39 @@ class QuerysetProxy(Generic[T]):
|
|||||||
await self.create_through_instance(created, **through_kwargs)
|
await self.create_through_instance(created, **through_kwargs)
|
||||||
return created
|
return created
|
||||||
|
|
||||||
|
async def update(self, each: bool = False, **kwargs: Any) -> int:
|
||||||
|
"""
|
||||||
|
Updates the model table after applying the filters from kwargs.
|
||||||
|
|
||||||
|
You have to either pass a filter to narrow down a query or explicitly pass
|
||||||
|
each=True flag to affect whole table.
|
||||||
|
|
||||||
|
:param each: flag if whole table should be affected if no filter is passed
|
||||||
|
:type each: bool
|
||||||
|
:param kwargs: fields names and proper value types
|
||||||
|
:type kwargs: Any
|
||||||
|
:return: number of updated rows
|
||||||
|
:rtype: int
|
||||||
|
"""
|
||||||
|
# queryset proxy always have one filter for pk of parent model
|
||||||
|
if not each and len(self.queryset.filter_clauses) == 1:
|
||||||
|
raise QueryDefinitionError(
|
||||||
|
"You cannot update without filtering the queryset first. "
|
||||||
|
"If you want to update all rows use update(each=True, **kwargs)"
|
||||||
|
)
|
||||||
|
|
||||||
|
through_kwargs = kwargs.pop(self.through_model_name, {})
|
||||||
|
children = await self.queryset.all()
|
||||||
|
for child in children:
|
||||||
|
if child:
|
||||||
|
await child.update(**kwargs)
|
||||||
|
if self.type_ == ormar.RelationType.MULTIPLE and through_kwargs:
|
||||||
|
await self.update_through_instance(
|
||||||
|
child=child, # type: ignore
|
||||||
|
**through_kwargs,
|
||||||
|
)
|
||||||
|
return len(children)
|
||||||
|
|
||||||
async def get_or_create(self, **kwargs: Any) -> "T":
|
async def get_or_create(self, **kwargs: Any) -> "T":
|
||||||
"""
|
"""
|
||||||
Combination of create and get methods.
|
Combination of create and get methods.
|
||||||
|
|||||||
@ -235,6 +235,64 @@ async def test_ordering_by_through_model() -> Any:
|
|||||||
assert post3.categories[2].postcategory.param_name == "volume"
|
assert post3.categories[2].postcategory.param_name == "volume"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_through_models_from_queryset_on_through() -> Any:
|
||||||
|
async with database:
|
||||||
|
post = await Post(title="Test post").save()
|
||||||
|
await post.categories.create(
|
||||||
|
name="Test category1",
|
||||||
|
postcategory={"sort_order": 2, "param_name": "volume"},
|
||||||
|
)
|
||||||
|
await post.categories.create(
|
||||||
|
name="Test category2", postcategory={"sort_order": 1, "param_name": "area"}
|
||||||
|
)
|
||||||
|
await post.categories.create(
|
||||||
|
name="Test category3",
|
||||||
|
postcategory={"sort_order": 3, "param_name": "velocity"},
|
||||||
|
)
|
||||||
|
|
||||||
|
await PostCategory.objects.filter(param_name="volume", post=post.id).update(
|
||||||
|
sort_order=4
|
||||||
|
)
|
||||||
|
post2 = (
|
||||||
|
await Post.objects.select_related("categories")
|
||||||
|
.order_by("-postcategory__sort_order")
|
||||||
|
.get()
|
||||||
|
)
|
||||||
|
assert len(post2.categories) == 3
|
||||||
|
assert post2.categories[0].postcategory.param_name == "volume"
|
||||||
|
assert post2.categories[2].postcategory.param_name == "area"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_through_from_related() -> Any:
|
||||||
|
async with database:
|
||||||
|
post = await Post(title="Test post").save()
|
||||||
|
await post.categories.create(
|
||||||
|
name="Test category1",
|
||||||
|
postcategory={"sort_order": 2, "param_name": "volume"},
|
||||||
|
)
|
||||||
|
await post.categories.create(
|
||||||
|
name="Test category2", postcategory={"sort_order": 1, "param_name": "area"}
|
||||||
|
)
|
||||||
|
await post.categories.create(
|
||||||
|
name="Test category3",
|
||||||
|
postcategory={"sort_order": 3, "param_name": "velocity"},
|
||||||
|
)
|
||||||
|
|
||||||
|
await post.categories.filter(name="Test category3").update(
|
||||||
|
postcategory={"sort_order": 4}
|
||||||
|
)
|
||||||
|
|
||||||
|
post2 = (
|
||||||
|
await Post.objects.select_related("categories")
|
||||||
|
.order_by("postcategory__sort_order")
|
||||||
|
.get()
|
||||||
|
)
|
||||||
|
assert len(post2.categories) == 3
|
||||||
|
assert post2.categories[2].postcategory.sort_order == 4
|
||||||
|
|
||||||
|
|
||||||
# TODO: check/ modify following
|
# TODO: check/ modify following
|
||||||
|
|
||||||
# add to fields with class lower name (V)
|
# add to fields with class lower name (V)
|
||||||
@ -245,11 +303,12 @@ async def test_ordering_by_through_model() -> Any:
|
|||||||
# accessing from instance (V) <- no both sides only nested one is relevant, fix one side
|
# accessing from instance (V) <- no both sides only nested one is relevant, fix one side
|
||||||
# filtering in filter (through name normally) (V) < - table prefix from normal relation,
|
# filtering in filter (through name normally) (V) < - table prefix from normal relation,
|
||||||
# check if is_through needed, resolved side of relation
|
# check if is_through needed, resolved side of relation
|
||||||
# ordering by in order_by
|
# ordering by in order_by (V)
|
||||||
|
# updating in query (V)
|
||||||
|
# updating from querysetproxy (V)
|
||||||
|
|
||||||
|
# modifying from instance (both sides?) (X) <= no, the loaded one doesn't have relations
|
||||||
|
|
||||||
# updating in query
|
|
||||||
# modifying from instance (both sides?)
|
|
||||||
# including/excluding in fields?
|
# including/excluding in fields?
|
||||||
# allowing to change fk fields names in through model?
|
# allowing to change fk fields names in through model?
|
||||||
# make through optional? auto-generated for cases other fields are missing?
|
# make through optional? auto-generated for cases other fields are missing?
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import pytest
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
|
from ormar.exceptions import QueryDefinitionError
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
@ -180,3 +181,42 @@ async def test_queryset_methods():
|
|||||||
assert len(categories) == 3 == len(post.categories)
|
assert len(categories) == 3 == len(post.categories)
|
||||||
for cat in post.categories:
|
for cat in post.categories:
|
||||||
assert cat.subject.name is not None
|
assert cat.subject.name is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_queryset_update():
|
||||||
|
async with database:
|
||||||
|
async with database.transaction(force_rollback=True):
|
||||||
|
guido = await Author.objects.create(
|
||||||
|
first_name="Guido", last_name="Van Rossum"
|
||||||
|
)
|
||||||
|
subject = await Subject(name="Random").save()
|
||||||
|
post = await Post.objects.create(title="Hello, M2M", author=guido)
|
||||||
|
await post.categories.create(name="News", sort_order=1, subject=subject)
|
||||||
|
await post.categories.create(name="Breaking", sort_order=3, subject=subject)
|
||||||
|
|
||||||
|
await post.categories.order_by("sort_order").all()
|
||||||
|
assert len(post.categories) == 2
|
||||||
|
assert post.categories[0].sort_order == 1
|
||||||
|
assert post.categories[0].name == "News"
|
||||||
|
assert post.categories[1].sort_order == 3
|
||||||
|
assert post.categories[1].name == "Breaking"
|
||||||
|
|
||||||
|
updated = await post.categories.update(each=True, name="Test")
|
||||||
|
assert updated == 2
|
||||||
|
|
||||||
|
await post.categories.order_by("sort_order").all()
|
||||||
|
assert len(post.categories) == 2
|
||||||
|
assert post.categories[0].name == "Test"
|
||||||
|
assert post.categories[1].name == "Test"
|
||||||
|
|
||||||
|
updated = await post.categories.filter(sort_order=3).update(name="Test 2")
|
||||||
|
assert updated == 1
|
||||||
|
|
||||||
|
await post.categories.order_by("sort_order").all()
|
||||||
|
assert len(post.categories) == 2
|
||||||
|
assert post.categories[0].name == "Test"
|
||||||
|
assert post.categories[1].name == "Test 2"
|
||||||
|
|
||||||
|
with pytest.raises(QueryDefinitionError):
|
||||||
|
await post.categories.update(name="Test WRONG")
|
||||||
|
|||||||
147
tests/test_wekref_exclusion.py
Normal file
147
tests/test_wekref_exclusion.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import databases
|
||||||
|
import pydantic
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
import ormar
|
||||||
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
app.state.database = database
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup() -> None:
|
||||||
|
database_ = app.state.database
|
||||||
|
if not database_.is_connected:
|
||||||
|
await database_.connect()
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown() -> None:
|
||||||
|
database_ = app.state.database
|
||||||
|
if database_.is_connected:
|
||||||
|
await database_.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
|
def create_test_database():
|
||||||
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||||
|
metadata.create_all(engine)
|
||||||
|
yield
|
||||||
|
metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMeta(ormar.ModelMeta):
|
||||||
|
database = database
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
|
||||||
|
class OtherThing(ormar.Model):
|
||||||
|
class Meta(BaseMeta):
|
||||||
|
tablename = "other_things"
|
||||||
|
|
||||||
|
id: UUID = ormar.UUID(primary_key=True, default=uuid4)
|
||||||
|
name: str = ormar.Text(default="")
|
||||||
|
ot_contents: str = ormar.Text(default="")
|
||||||
|
|
||||||
|
|
||||||
|
class Thing(ormar.Model):
|
||||||
|
class Meta(BaseMeta):
|
||||||
|
tablename = "things"
|
||||||
|
|
||||||
|
id: UUID = ormar.UUID(primary_key=True, default=uuid4)
|
||||||
|
name: str = ormar.Text(default="")
|
||||||
|
js: pydantic.Json = ormar.JSON(nullable=True)
|
||||||
|
other_thing: Optional[OtherThing] = ormar.ForeignKey(OtherThing, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/test/1")
|
||||||
|
async def post_test_1():
|
||||||
|
# don't split initialization and attribute assignment
|
||||||
|
ot = await OtherThing(ot_contents="otc").save()
|
||||||
|
await Thing(other_thing=ot, name="t1").save()
|
||||||
|
await Thing(other_thing=ot, name="t2").save()
|
||||||
|
await Thing(other_thing=ot, name="t3").save()
|
||||||
|
|
||||||
|
# if you do not care about returned object you can even go with bulk_create
|
||||||
|
# all of them are created in one transaction
|
||||||
|
# things = [Thing(other_thing=ot, name='t1'),
|
||||||
|
# Thing(other_thing=ot, name="t2"),
|
||||||
|
# Thing(other_thing=ot, name="t3")]
|
||||||
|
# await Thing.objects.bulk_create(things)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/test/2", response_model=List[Thing])
|
||||||
|
async def get_test_2():
|
||||||
|
# if you only query for one use get or first
|
||||||
|
ot = await OtherThing.objects.get()
|
||||||
|
ts = await ot.things.all()
|
||||||
|
# specifically null out the relation on things before return
|
||||||
|
for t in ts:
|
||||||
|
t.remove(ot, name="other_thing")
|
||||||
|
return ts
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/test/3", response_model=List[Thing])
|
||||||
|
async def get_test_3():
|
||||||
|
ot = await OtherThing.objects.select_related("things").get()
|
||||||
|
# exclude unwanted field while ot is still in scope
|
||||||
|
# in order not to pass it to fastapi
|
||||||
|
return [t.dict(exclude={"other_thing"}) for t in ot.things]
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/test/4", response_model=List[Thing], response_model_exclude={"other_thing"})
|
||||||
|
async def get_test_4():
|
||||||
|
ot = await OtherThing.objects.get()
|
||||||
|
# query from the active side
|
||||||
|
return await Thing.objects.all(other_thing=ot)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_ot/", response_model=OtherThing)
|
||||||
|
async def get_ot():
|
||||||
|
return await OtherThing.objects.get()
|
||||||
|
|
||||||
|
|
||||||
|
# more real life (usually) is not getting some random OT and get it's Things
|
||||||
|
# but query for a specific one by some kind of id
|
||||||
|
@app.get(
|
||||||
|
"/test/5/{thing_id}",
|
||||||
|
response_model=List[Thing],
|
||||||
|
response_model_exclude={"other_thing"},
|
||||||
|
)
|
||||||
|
async def get_test_5(thing_id: UUID):
|
||||||
|
return await Thing.objects.all(other_thing__id=thing_id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_endpoints():
|
||||||
|
client = TestClient(app)
|
||||||
|
with client:
|
||||||
|
resp = client.post("/test/1")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
resp2 = client.get("/test/2")
|
||||||
|
assert resp2.status_code == 200
|
||||||
|
assert len(resp2.json()) == 3
|
||||||
|
|
||||||
|
resp3 = client.get("/test/3")
|
||||||
|
assert resp3.status_code == 200
|
||||||
|
assert len(resp3.json()) == 3
|
||||||
|
|
||||||
|
resp4 = client.get("/test/4")
|
||||||
|
assert resp4.status_code == 200
|
||||||
|
assert len(resp4.json()) == 3
|
||||||
|
|
||||||
|
ot = OtherThing(**client.get("/get_ot/").json())
|
||||||
|
resp5 = client.get(f"/test/5/{ot.id}")
|
||||||
|
assert resp5.status_code == 200
|
||||||
|
assert len(resp5.json()) == 3
|
||||||
Reference in New Issue
Block a user