diff --git a/.coverage b/.coverage index 49a5456..b115947 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/fields/base.py b/ormar/fields/base.py index f3e83a4..18cdbfa 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -38,6 +38,11 @@ class BaseField: return Field(default=default) return None + @classmethod + def get_default(cls) -> Any: + if cls.has_default(): + return cls.default if cls.default is not None else cls.server_default + @classmethod def has_default(cls) -> bool: return cls.default is not None or cls.server_default is not None diff --git a/ormar/models/model.py b/ormar/models/model.py index fd61ba6..ae2a590 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -108,6 +108,7 @@ class Model(NewBaseModel): if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement: self_fields.pop(self.Meta.pkname, None) + self_fields = self.objects._populate_default_values(self_fields) expr = self.Meta.table.insert() expr = expr.values(**self_fields) item_id = await self.Meta.database.execute(expr) diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 336db05..fd70258 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -29,8 +29,8 @@ class QueryClause: self, model_cls: Type["Model"], filter_clauses: List, select_related: List, ) -> None: - self._select_related = select_related - self.filter_clauses = filter_clauses + self._select_related = select_related[:] + self.filter_clauses = filter_clauses[:] self.model_cls = model_cls self.table = self.model_cls.Meta.table diff --git a/ormar/queryset/filter_query.py b/ormar/queryset/filter_query.py index 8db8185..f55d4e0 100644 --- a/ormar/queryset/filter_query.py +++ b/ormar/queryset/filter_query.py @@ -4,7 +4,8 @@ import sqlalchemy class FilterQuery: - def __init__(self, filter_clauses: List) -> None: + def __init__(self, filter_clauses: List, exclude: bool = False) -> None: + self.exclude = exclude self.filter_clauses = filter_clauses def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: @@ -13,5 +14,6 @@ class FilterQuery: clause = self.filter_clauses[0] else: clause = sqlalchemy.sql.and_(*self.filter_clauses) + clause = sqlalchemy.sql.not_(clause) if self.exclude else clause expr = expr.where(clause) return expr diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 2d59f58..fa6ed74 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -3,7 +3,7 @@ from typing import List, NamedTuple, TYPE_CHECKING, Tuple, Type import sqlalchemy from sqlalchemy import text -from ormar.fields import ManyToManyField # noqa I100 +from ormar.fields import ManyToManyField # noqa I100 from ormar.relations import AliasManager if TYPE_CHECKING: # pragma no cover diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 6fdbdd4..b07612a 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -12,18 +12,20 @@ if TYPE_CHECKING: # pragma no cover class Query: - def __init__( + def __init__( # noqa CFQ002 self, model_cls: Type["Model"], filter_clauses: List, + exclude_clauses: List, select_related: List, limit_count: int, offset: int, ) -> None: self.query_offset = offset self.limit_count = limit_count - self._select_related = select_related - self.filter_clauses = filter_clauses + self._select_related = select_related[:] + self.filter_clauses = filter_clauses[:] + self.exclude_clauses = exclude_clauses[:] self.model_cls = model_cls self.table = self.model_cls.Meta.table @@ -78,6 +80,9 @@ class Query: self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: expr = FilterQuery(filter_clauses=self.filter_clauses).apply(expr) + expr = FilterQuery(filter_clauses=self.exclude_clauses, exclude=True).apply( + expr + ) expr = LimitQuery(limit_count=self.limit_count).apply(expr) expr = OffsetQuery(query_offset=self.query_offset).apply(expr) expr = OrderQuery(order_bys=self.order_bys).apply(expr) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index ef3d636..46edf65 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -14,16 +14,18 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: - def __init__( + def __init__( # noqa CFQ002 self, model_cls: Type["Model"] = None, filter_clauses: List = None, + exclude_clauses: List = None, select_related: List = None, limit_count: int = None, offset: int = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses + self.exclude_clauses = [] if exclude_clauses is None else exclude_clauses self._select_related = [] if select_related is None else select_related self.limit_count = limit_count self.query_offset = offset @@ -40,6 +42,12 @@ class QuerySet: rows = self.model_cls.merge_instances_list(result_rows) return rows + def _populate_default_values(self, new_kwargs: dict) -> dict: + for field_name, field in self.model_cls.Meta.model_fields.items(): + if field_name not in new_kwargs and field.has_default(): + new_kwargs[field_name] = field.get_default() + return new_kwargs + def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: pkname = self.model_cls.Meta.pkname pk = self.model_cls.Meta.model_fields[pkname] @@ -69,6 +77,7 @@ class QuerySet: model_cls=self.model_cls, select_related=self._select_related, filter_clauses=self.filter_clauses, + exclude_clauses=self.exclude_clauses, offset=self.query_offset, limit_count=self.limit_count, ) @@ -76,22 +85,32 @@ class QuerySet: # print(exp.compile(compile_kwargs={"literal_binds": True})) return exp - def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003 qryclause = QueryClause( model_cls=self.model_cls, select_related=self._select_related, filter_clauses=self.filter_clauses, ) filter_clauses, select_related = qryclause.filter(**kwargs) + if _exclude: + exclude_clauses = filter_clauses + filter_clauses = self.filter_clauses + else: + exclude_clauses = self.exclude_clauses + filter_clauses = filter_clauses return self.__class__( model_cls=self.model_cls, filter_clauses=filter_clauses, + exclude_clauses=exclude_clauses, select_related=select_related, limit_count=self.limit_count, offset=self.query_offset, ) + def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + return self.filter(_exclude=True, **kwargs) + def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": if not isinstance(related, (list, tuple)): related = [related] @@ -100,6 +119,7 @@ class QuerySet: return self.__class__( model_cls=self.model_cls, filter_clauses=self.filter_clauses, + exclude_clauses=self.exclude_clauses, select_related=related, limit_count=self.limit_count, offset=self.query_offset, @@ -127,6 +147,7 @@ class QuerySet: return self.__class__( model_cls=self.model_cls, filter_clauses=self.filter_clauses, + exclude_clauses=self.exclude_clauses, select_related=self._select_related, limit_count=limit_count, offset=self.query_offset, @@ -136,6 +157,7 @@ class QuerySet: return self.__class__( model_cls=self.model_cls, filter_clauses=self.filter_clauses, + exclude_clauses=self.exclude_clauses, select_related=self._select_related, limit_count=self.limit_count, offset=offset, @@ -177,6 +199,7 @@ class QuerySet: new_kwargs = dict(**kwargs) new_kwargs = self._remove_pk_from_kwargs(new_kwargs) new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) + new_kwargs = self._populate_default_values(new_kwargs) expr = self.table.insert() expr = expr.values(**new_kwargs) diff --git a/tests/test_models.py b/tests/test_models.py index 79381fe..49f8ed3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -165,6 +165,18 @@ async def test_model_filter(): products = await Product.objects.all(name__icontains="T") assert len(products) == 2 + products = await Product.objects.exclude(rating__gte=4).all() + assert len(products) == 1 + + products = await Product.objects.exclude(rating__gte=4, in_stock=True).all() + assert len(products) == 2 + + products = await Product.objects.exclude(in_stock=True).all() + assert len(products) == 1 + + products = await Product.objects.exclude(name__icontains="T").all() + assert len(products) == 1 + # Test escaping % character from icontains, contains, and iexact await Product.objects.create(name="100%-Cotton", rating=3) await Product.objects.create(name="Cotton-100%-Egyptian", rating=3)