diff --git a/.coverage b/.coverage index d262878..acd8ef5 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/__init__.py b/ormar/__init__.py index 90db22e..351f6b1 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -18,6 +18,14 @@ from ormar.models import Model from ormar.queryset import QuerySet from ormar.relations import RelationType + +class UndefinedType: # pragma no cover + def __repr__(self) -> str: + return "OrmarUndefined" + + +Undefined = UndefinedType() + __version__ = "0.3.0" __all__ = [ "Integer", @@ -40,4 +48,5 @@ __all__ = [ "ForeignKey", "QuerySet", "RelationType", + "Undefined", ] diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 6f45032..5cb507c 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -112,41 +112,64 @@ class Query: join_params.from_table, to_table ) if alias not in self.used_aliases: - if join_params.prev_model.Meta.model_fields[part].virtual or is_multi: - to_key = next( - ( - v - for k, v in model_cls.Meta.model_fields.items() - if self._is_target_relation_key(v, join_params.prev_model) - ), - None, - ).name - from_key = model_cls.Meta.pkname - else: - to_key = model_cls.Meta.pkname - from_key = part - - on_clause = self.on_clause( - previous_alias=join_params.previous_alias, - alias=alias, - from_clause=f"{join_params.from_table}.{from_key}", - to_clause=f"{to_table}.{to_key}", - ) - target_table = self.relation_manager.prefixed_table_name(alias, to_table) - self.select_from = sqlalchemy.sql.outerjoin( - self.select_from, target_table, on_clause - ) - self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) - self.columns.extend( - self.relation_manager.prefixed_columns(alias, model_cls.Meta.table) - ) - self.used_aliases.append(alias) + self._process_join(join_params, is_multi, model_cls, part, alias) previous_alias = alias from_table = to_table prev_model = model_cls return JoinParameters(prev_model, previous_alias, from_table, model_cls) + def _process_join( + self, + join_params: JoinParameters, + is_multi: bool, + model_cls: Type["Model"], + part: str, + alias: str, + ) -> None: + to_table = model_cls.Meta.table.name + to_key, from_key = self._get_to_and_from_keys( + join_params, is_multi, model_cls, part + ) + + on_clause = self.on_clause( + previous_alias=join_params.previous_alias, + alias=alias, + from_clause=f"{join_params.from_table}.{from_key}", + to_clause=f"{to_table}.{to_key}", + ) + target_table = self.relation_manager.prefixed_table_name(alias, to_table) + self.select_from = sqlalchemy.sql.outerjoin( + self.select_from, target_table, on_clause + ) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) + self.columns.extend( + self.relation_manager.prefixed_columns(alias, model_cls.Meta.table) + ) + self.used_aliases.append(alias) + + def _get_to_and_from_keys( + self, + join_params: JoinParameters, + is_multi: bool, + model_cls: Type["Model"], + part: str, + ) -> Tuple[str, str]: + if join_params.prev_model.Meta.model_fields[part].virtual or is_multi: + to_key = next( + ( + v + for k, v in model_cls.Meta.model_fields.items() + if self._is_target_relation_key(v, join_params.prev_model) + ), + None, + ).name + from_key = model_cls.Meta.pkname + else: + to_key = model_cls.Meta.pkname + from_key = part + return to_key, from_key + def filter(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: # noqa A003 if self.filter_clauses: if len(self.filter_clauses) == 1: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 4bc5e7f..a4c48d3 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1,4 +1,4 @@ -from typing import Any, List, TYPE_CHECKING, Tuple, Type, Union +from typing import Any, List, Mapping, TYPE_CHECKING, Tuple, Type, Union import databases import sqlalchemy @@ -31,6 +31,30 @@ class QuerySet: def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": return self.__class__(model_cls=owner) + def _process_query_result_rows(self, rows: List[Mapping]) -> List["Model"]: + result_rows = [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + rows = self.model_cls.merge_instances_list(result_rows) + return rows + + def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: + pkname = self.model_cls.Meta.pkname + pk = self.model_cls.Meta.model_fields[pkname] + if new_kwargs.get(pkname, ormar.Undefined) is None and ( + pk.nullable or pk.autoincrement + ): + del new_kwargs[pkname] + return new_kwargs + + @staticmethod + def check_single_result_rows_count(rows: List["Model"]) -> None: + if not rows: + raise NoMatch() + if len(rows) > 1: + raise MultipleMatches() + @property def database(self) -> databases.Database: return self.model_cls.Meta.database @@ -128,30 +152,20 @@ class QuerySet: return await self.filter(**kwargs).first() rows = await self.limit(1).all() - if rows: - return rows[0] + self.check_single_result_rows_count(rows) + return rows[0] async def get(self, **kwargs: Any) -> "Model": if kwargs: return await self.filter(**kwargs).get() + expr = self.build_select_expression() if not self.filter_clauses: - expr = self.build_select_expression().limit(2) - else: - expr = self.build_select_expression() + expr = expr.limit(2) rows = await self.database.fetch_all(expr) - - result_rows = [ - self.model_cls.from_row(row, select_related=self._select_related) - for row in rows - ] - rows = self.model_cls.merge_instances_list(result_rows) - - if not rows: - raise NoMatch() - if len(rows) > 1: - raise MultipleMatches() + rows = self._process_query_result_rows(rows) + self.check_single_result_rows_count(rows) return rows[0] async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 @@ -159,31 +173,17 @@ class QuerySet: return await self.filter(**kwargs).all() expr = self.build_select_expression() - # breakpoint() rows = await self.database.fetch_all(expr) - result_rows = [ - self.model_cls.from_row(row, select_related=self._select_related) - for row in rows - ] - result_rows = self.model_cls.merge_instances_list(result_rows) + result_rows = self._process_query_result_rows(rows) return result_rows async def create(self, **kwargs: Any) -> "Model": new_kwargs = dict(**kwargs) - - # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.Meta.pkname - pk = self.model_cls.Meta.model_fields[pkname] - if ( - pkname in new_kwargs - and new_kwargs.get(pkname) is None - and (pk.nullable or pk.autoincrement) - ): - del new_kwargs[pkname] - + new_kwargs = self._remove_pk_from_kwargs(new_kwargs) new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) + expr = self.table.insert() expr = expr.values(**new_kwargs) diff --git a/tests/test_models.py b/tests/test_models.py index 1c00ef3..758c1bd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,7 +4,7 @@ import pytest import sqlalchemy import ormar -from ormar.exceptions import QueryDefinitionError +from ormar.exceptions import QueryDefinitionError, NoMatch from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -230,4 +230,5 @@ async def test_model_first(): assert await User.objects.first() == tom assert await User.objects.first(name="Jane") == jane assert await User.objects.filter(name="Jane").first() == jane - assert await User.objects.filter(name="Lucy").first() is None + with pytest.raises(NoMatch): + await User.objects.filter(name="Lucy").first()