query level delete and update

This commit is contained in:
collerek
2020-09-24 13:56:13 +02:00
parent ae34d21767
commit da05063e8d
5 changed files with 121 additions and 24 deletions

BIN
.coverage

Binary file not shown.

View File

@ -25,6 +25,12 @@ class ModelTableProxy:
self_fields = {k: v for k, v in self.dict().items() if k not in related_names} self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields return self_fields
@classmethod
def extract_db_own_fields(cls) -> set:
related_names = cls._extract_related_names()
self_fields = {name for name in cls.Meta.model_fields.keys() if name not in related_names}
return self_fields
@classmethod @classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict: def substitute_models_with_pks(cls, model_dict: dict) -> dict:
for field in cls._extract_related_names(): for field in cls._extract_related_names():
@ -51,9 +57,9 @@ class ModelTableProxy:
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if ( if (
inspect.isclass(field) inspect.isclass(field)
and issubclass(field, ForeignKeyField) and issubclass(field, ForeignKeyField)
and not field.virtual and not field.virtual
): ):
related_names.add(name) related_names.add(name)
return related_names return related_names
@ -65,9 +71,9 @@ class ModelTableProxy:
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if ( if (
inspect.isclass(field) inspect.isclass(field)
and issubclass(field, ForeignKeyField) and issubclass(field, ForeignKeyField)
and field.nullable and field.nullable
): ):
related_names.add(name) related_names.add(name)
return related_names return related_names
@ -95,7 +101,7 @@ class ModelTableProxy:
@staticmethod @staticmethod
def resolve_relation_field( def resolve_relation_field(
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
) -> Type[Field]: ) -> Type[Field]:
name = ModelTableProxy.resolve_relation_name(item, related) name = ModelTableProxy.resolve_relation_name(item, related)
to_field = item.Meta.model_fields.get(name) to_field = item.Meta.model_fields.get(name)
@ -121,12 +127,12 @@ class ModelTableProxy:
for field in one.Meta.model_fields.keys(): for field in one.Meta.model_fields.keys():
current_field = getattr(one, field) current_field = getattr(one, field)
if isinstance(current_field, list) and not isinstance( if isinstance(current_field, list) and not isinstance(
current_field, ormar.Model current_field, ormar.Model
): ):
setattr(other, field, current_field + getattr(other, field)) setattr(other, field, current_field + getattr(other, field))
elif ( elif (
isinstance(current_field, ormar.Model) isinstance(current_field, ormar.Model)
and current_field.pk == getattr(other, field).pk and current_field.pk == getattr(other, field).pk
): ):
setattr( setattr(
other, other,

View File

@ -5,6 +5,7 @@ import sqlalchemy
import ormar # noqa I100 import ormar # noqa I100
from ormar import MultipleMatches, NoMatch from ormar import MultipleMatches, NoMatch
from ormar.exceptions import QueryDefinitionError
from ormar.queryset import FilterQuery from ormar.queryset import FilterQuery
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
@ -15,13 +16,13 @@ if TYPE_CHECKING: # pragma no cover
class QuerySet: class QuerySet:
def __init__( # noqa CFQ002 def __init__( # noqa CFQ002
self, self,
model_cls: Type["Model"] = None, model_cls: Type["Model"] = None,
filter_clauses: List = None, filter_clauses: List = None,
exclude_clauses: List = None, exclude_clauses: List = None,
select_related: List = None, select_related: List = None,
limit_count: int = None, limit_count: int = None,
offset: int = None, offset: int = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses self.filter_clauses = [] if filter_clauses is None else filter_clauses
@ -52,7 +53,7 @@ class QuerySet:
pkname = self.model_cls.Meta.pkname pkname = self.model_cls.Meta.pkname
pk = self.model_cls.Meta.model_fields[pkname] pk = self.model_cls.Meta.model_fields[pkname]
if new_kwargs.get(pkname, ormar.Undefined) is None and ( if new_kwargs.get(pkname, ormar.Undefined) is None and (
pk.nullable or pk.autoincrement pk.nullable or pk.autoincrement
): ):
del new_kwargs[pkname] del new_kwargs[pkname]
return new_kwargs return new_kwargs
@ -135,10 +136,25 @@ class QuerySet:
expr = sqlalchemy.func.count().select().select_from(expr) expr = sqlalchemy.func.count().select().select_from(expr)
return await self.database.fetch_val(expr) return await self.database.fetch_val(expr)
async def delete(self, **kwargs: Any) -> int: async def update(self, each: bool = False, **kwargs: Any) -> int:
self_fields = self.model_cls.extract_db_own_fields()
updates = {k: v for k, v in kwargs.items() if k in self_fields}
if not each and not self.filter_clauses:
raise QueryDefinitionError('You cannot update without filtering the queryset first. '
'If you want to update all rows use update(each=True, **kwargs)')
expr = FilterQuery(filter_clauses=self.filter_clauses).apply(
self.table.update().values(**updates)
)
# print(expr.compile(compile_kwargs={"literal_binds": True}))
return await self.database.execute(expr)
async def delete(self, each: bool = False, **kwargs: Any) -> int:
if kwargs: if kwargs:
return await self.filter(**kwargs).delete() return await self.filter(**kwargs).delete()
expr = FilterQuery(filter_clauses=self.filter_clauses,).apply( if not each and not self.filter_clauses:
raise QueryDefinitionError('You cannot delete without filtering the queryset first. '
'If you want to delete all rows use delete(each=True)')
expr = FilterQuery(filter_clauses=self.filter_clauses).apply(
self.table.delete() self.table.delete()
) )
return await self.database.execute(expr) return await self.database.execute(expr)

View File

@ -71,10 +71,10 @@ async def create_test_database():
async def cleanup(): async def cleanup():
yield yield
async with database: async with database:
await PostCategory.objects.delete() await PostCategory.objects.delete(each=True)
await Post.objects.delete() await Post.objects.delete(each=True)
await Category.objects.delete() await Category.objects.delete(each=True)
await Author.objects.delete() await Author.objects.delete(each=True)
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -0,0 +1,75 @@
import databases
import pytest
import sqlalchemy
import ormar
from ormar.exceptions import QueryDefinitionError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Book(ormar.Model):
class Meta:
tablename = "books"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
title: ormar.String(max_length=200)
author: ormar.String(max_length=100)
genre: ormar.String(max_length=100, default='Fiction', choices=['Fiction', 'Adventure', 'Historic', 'Fantasy'])
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_delete_and_update():
async with database:
async with database.transaction(force_rollback=True):
await Book.objects.create(title='Tom Sawyer', author="Twain, Mark", genre='Adventure')
await Book.objects.create(title='War and Peace', author="Tolstoy, Leo", genre='Fiction')
await Book.objects.create(title='Anna Karenina', author="Tolstoy, Leo", genre='Fiction')
await Book.objects.create(title='Harry Potter', author="Rowling, J.K.", genre='Fantasy')
await Book.objects.create(title='Lord of the Rings', author="Tolkien, J.R.", genre='Fantasy')
all_books = await Book.objects.all()
assert len(all_books) == 5
await Book.objects.filter(author="Tolstoy, Leo").update(author="Lenin, Vladimir")
all_books = await Book.objects.filter(author="Lenin, Vladimir").all()
assert len(all_books) == 2
historic_books = await Book.objects.filter(genre='Historic').all()
assert len(historic_books) == 0
with pytest.raises(QueryDefinitionError):
await Book.objects.update(genre='Historic')
await Book.objects.filter(author="Lenin, Vladimir").update(genre='Historic')
historic_books = await Book.objects.filter(genre='Historic').all()
assert len(historic_books) == 2
await Book.objects.delete(genre='Fantasy')
all_books = await Book.objects.all()
assert len(all_books) == 3
await Book.objects.update(each=True, genre='Fiction')
all_books = await Book.objects.filter(genre='Fiction').all()
assert len(all_books) == 3
with pytest.raises(QueryDefinitionError):
await Book.objects.delete()
await Book.objects.delete(each=True)
all_books = await Book.objects.all()
assert len(all_books) == 0