add exclude method on QuerySet and fix missing default values on creation

This commit is contained in:
collerek
2020-09-17 18:03:29 +02:00
parent 48819f1023
commit 1a4be03131
9 changed files with 57 additions and 9 deletions

BIN
.coverage

Binary file not shown.

View File

@ -38,6 +38,11 @@ class BaseField:
return Field(default=default) return Field(default=default)
return None return None
@classmethod
def get_default(cls) -> Any:
if cls.has_default():
return cls.default if cls.default is not None else cls.server_default
@classmethod @classmethod
def has_default(cls) -> bool: def has_default(cls) -> bool:
return cls.default is not None or cls.server_default is not None return cls.default is not None or cls.server_default is not None

View File

@ -108,6 +108,7 @@ class Model(NewBaseModel):
if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement: if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
self_fields.pop(self.Meta.pkname, None) self_fields.pop(self.Meta.pkname, None)
self_fields = self.objects._populate_default_values(self_fields)
expr = self.Meta.table.insert() expr = self.Meta.table.insert()
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.Meta.database.execute(expr) item_id = await self.Meta.database.execute(expr)

View File

@ -29,8 +29,8 @@ class QueryClause:
self, model_cls: Type["Model"], filter_clauses: List, select_related: List, self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
) -> None: ) -> None:
self._select_related = select_related self._select_related = select_related[:]
self.filter_clauses = filter_clauses self.filter_clauses = filter_clauses[:]
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.Meta.table self.table = self.model_cls.Meta.table

View File

@ -4,7 +4,8 @@ import sqlalchemy
class FilterQuery: class FilterQuery:
def __init__(self, filter_clauses: List) -> None: def __init__(self, filter_clauses: List, exclude: bool = False) -> None:
self.exclude = exclude
self.filter_clauses = filter_clauses self.filter_clauses = filter_clauses
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:
@ -13,5 +14,6 @@ class FilterQuery:
clause = self.filter_clauses[0] clause = self.filter_clauses[0]
else: else:
clause = sqlalchemy.sql.and_(*self.filter_clauses) clause = sqlalchemy.sql.and_(*self.filter_clauses)
clause = sqlalchemy.sql.not_(clause) if self.exclude else clause
expr = expr.where(clause) expr = expr.where(clause)
return expr return expr

View File

@ -3,7 +3,7 @@ from typing import List, NamedTuple, TYPE_CHECKING, Tuple, Type
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
from ormar.fields import ManyToManyField # noqa I100 from ormar.fields import ManyToManyField # noqa I100
from ormar.relations import AliasManager from ormar.relations import AliasManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover

View File

@ -12,18 +12,20 @@ if TYPE_CHECKING: # pragma no cover
class Query: class Query:
def __init__( def __init__( # noqa CFQ002
self, self,
model_cls: Type["Model"], model_cls: Type["Model"],
filter_clauses: List, filter_clauses: List,
exclude_clauses: List,
select_related: List, select_related: List,
limit_count: int, limit_count: int,
offset: int, offset: int,
) -> None: ) -> None:
self.query_offset = offset self.query_offset = offset
self.limit_count = limit_count self.limit_count = limit_count
self._select_related = select_related self._select_related = select_related[:]
self.filter_clauses = filter_clauses self.filter_clauses = filter_clauses[:]
self.exclude_clauses = exclude_clauses[:]
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.Meta.table self.table = self.model_cls.Meta.table
@ -78,6 +80,9 @@ class Query:
self, expr: sqlalchemy.sql.select self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select: ) -> sqlalchemy.sql.select:
expr = FilterQuery(filter_clauses=self.filter_clauses).apply(expr) expr = FilterQuery(filter_clauses=self.filter_clauses).apply(expr)
expr = FilterQuery(filter_clauses=self.exclude_clauses, exclude=True).apply(
expr
)
expr = LimitQuery(limit_count=self.limit_count).apply(expr) expr = LimitQuery(limit_count=self.limit_count).apply(expr)
expr = OffsetQuery(query_offset=self.query_offset).apply(expr) expr = OffsetQuery(query_offset=self.query_offset).apply(expr)
expr = OrderQuery(order_bys=self.order_bys).apply(expr) expr = OrderQuery(order_bys=self.order_bys).apply(expr)

View File

@ -14,16 +14,18 @@ if TYPE_CHECKING: # pragma no cover
class QuerySet: class QuerySet:
def __init__( def __init__( # noqa CFQ002
self, self,
model_cls: Type["Model"] = None, model_cls: Type["Model"] = None,
filter_clauses: List = None, filter_clauses: List = None,
exclude_clauses: List = None,
select_related: List = None, select_related: List = None,
limit_count: int = None, limit_count: int = None,
offset: int = None, offset: int = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses self.filter_clauses = [] if filter_clauses is None else filter_clauses
self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses
self._select_related = [] if select_related is None else select_related self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count self.limit_count = limit_count
self.query_offset = offset self.query_offset = offset
@ -40,6 +42,12 @@ class QuerySet:
rows = self.model_cls.merge_instances_list(result_rows) rows = self.model_cls.merge_instances_list(result_rows)
return rows return rows
def _populate_default_values(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_cls.Meta.model_fields.items():
if field_name not in new_kwargs and field.has_default():
new_kwargs[field_name] = field.get_default()
return new_kwargs
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
pkname = self.model_cls.Meta.pkname pkname = self.model_cls.Meta.pkname
pk = self.model_cls.Meta.model_fields[pkname] pk = self.model_cls.Meta.model_fields[pkname]
@ -69,6 +77,7 @@ class QuerySet:
model_cls=self.model_cls, model_cls=self.model_cls,
select_related=self._select_related, select_related=self._select_related,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
offset=self.query_offset, offset=self.query_offset,
limit_count=self.limit_count, limit_count=self.limit_count,
) )
@ -76,22 +85,32 @@ class QuerySet:
# print(exp.compile(compile_kwargs={"literal_binds": True})) # print(exp.compile(compile_kwargs={"literal_binds": True}))
return exp return exp
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
qryclause = QueryClause( qryclause = QueryClause(
model_cls=self.model_cls, model_cls=self.model_cls,
select_related=self._select_related, select_related=self._select_related,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
) )
filter_clauses, select_related = qryclause.filter(**kwargs) filter_clauses, select_related = qryclause.filter(**kwargs)
if _exclude:
exclude_clauses = filter_clauses
filter_clauses = self.filter_clauses
else:
exclude_clauses = self.exclude_clauses
filter_clauses = filter_clauses
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model_cls,
filter_clauses=filter_clauses, filter_clauses=filter_clauses,
exclude_clauses=exclude_clauses,
select_related=select_related, select_related=select_related,
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
) )
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
return self.filter(_exclude=True, **kwargs)
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
if not isinstance(related, (list, tuple)): if not isinstance(related, (list, tuple)):
related = [related] related = [related]
@ -100,6 +119,7 @@ class QuerySet:
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model_cls,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=related, select_related=related,
limit_count=self.limit_count, limit_count=self.limit_count,
offset=self.query_offset, offset=self.query_offset,
@ -127,6 +147,7 @@ class QuerySet:
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model_cls,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related, select_related=self._select_related,
limit_count=limit_count, limit_count=limit_count,
offset=self.query_offset, offset=self.query_offset,
@ -136,6 +157,7 @@ class QuerySet:
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model_cls,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses,
select_related=self._select_related, select_related=self._select_related,
limit_count=self.limit_count, limit_count=self.limit_count,
offset=offset, offset=offset,
@ -177,6 +199,7 @@ class QuerySet:
new_kwargs = dict(**kwargs) new_kwargs = dict(**kwargs)
new_kwargs = self._remove_pk_from_kwargs(new_kwargs) new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
expr = self.table.insert() expr = self.table.insert()
expr = expr.values(**new_kwargs) expr = expr.values(**new_kwargs)

View File

@ -165,6 +165,18 @@ async def test_model_filter():
products = await Product.objects.all(name__icontains="T") products = await Product.objects.all(name__icontains="T")
assert len(products) == 2 assert len(products) == 2
products = await Product.objects.exclude(rating__gte=4).all()
assert len(products) == 1
products = await Product.objects.exclude(rating__gte=4, in_stock=True).all()
assert len(products) == 2
products = await Product.objects.exclude(in_stock=True).all()
assert len(products) == 1
products = await Product.objects.exclude(name__icontains="T").all()
assert len(products) == 1
# Test escaping % character from icontains, contains, and iexact # Test escaping % character from icontains, contains, and iexact
await Product.objects.create(name="100%-Cotton", rating=3) await Product.objects.create(name="100%-Cotton", rating=3)
await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) await Product.objects.create(name="Cotton-100%-Egyptian", rating=3)