some refactors and cleanup
This commit is contained in:
@ -18,6 +18,14 @@ from ormar.models import Model
|
|||||||
from ormar.queryset import QuerySet
|
from ormar.queryset import QuerySet
|
||||||
from ormar.relations import RelationType
|
from ormar.relations import RelationType
|
||||||
|
|
||||||
|
|
||||||
|
class UndefinedType: # pragma no cover
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "OrmarUndefined"
|
||||||
|
|
||||||
|
|
||||||
|
Undefined = UndefinedType()
|
||||||
|
|
||||||
__version__ = "0.3.0"
|
__version__ = "0.3.0"
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Integer",
|
"Integer",
|
||||||
@ -40,4 +48,5 @@ __all__ = [
|
|||||||
"ForeignKey",
|
"ForeignKey",
|
||||||
"QuerySet",
|
"QuerySet",
|
||||||
"RelationType",
|
"RelationType",
|
||||||
|
"Undefined",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -112,19 +112,25 @@ class Query:
|
|||||||
join_params.from_table, to_table
|
join_params.from_table, to_table
|
||||||
)
|
)
|
||||||
if alias not in self.used_aliases:
|
if alias not in self.used_aliases:
|
||||||
if join_params.prev_model.Meta.model_fields[part].virtual or is_multi:
|
self._process_join(join_params, is_multi, model_cls, part, alias)
|
||||||
to_key = next(
|
|
||||||
(
|
previous_alias = alias
|
||||||
v
|
from_table = to_table
|
||||||
for k, v in model_cls.Meta.model_fields.items()
|
prev_model = model_cls
|
||||||
if self._is_target_relation_key(v, join_params.prev_model)
|
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
||||||
),
|
|
||||||
None,
|
def _process_join(
|
||||||
).name
|
self,
|
||||||
from_key = model_cls.Meta.pkname
|
join_params: JoinParameters,
|
||||||
else:
|
is_multi: bool,
|
||||||
to_key = model_cls.Meta.pkname
|
model_cls: Type["Model"],
|
||||||
from_key = part
|
part: str,
|
||||||
|
alias: str,
|
||||||
|
) -> None:
|
||||||
|
to_table = model_cls.Meta.table.name
|
||||||
|
to_key, from_key = self._get_to_and_from_keys(
|
||||||
|
join_params, is_multi, model_cls, part
|
||||||
|
)
|
||||||
|
|
||||||
on_clause = self.on_clause(
|
on_clause = self.on_clause(
|
||||||
previous_alias=join_params.previous_alias,
|
previous_alias=join_params.previous_alias,
|
||||||
@ -142,10 +148,27 @@ class Query:
|
|||||||
)
|
)
|
||||||
self.used_aliases.append(alias)
|
self.used_aliases.append(alias)
|
||||||
|
|
||||||
previous_alias = alias
|
def _get_to_and_from_keys(
|
||||||
from_table = to_table
|
self,
|
||||||
prev_model = model_cls
|
join_params: JoinParameters,
|
||||||
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
|
is_multi: bool,
|
||||||
|
model_cls: Type["Model"],
|
||||||
|
part: str,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
if join_params.prev_model.Meta.model_fields[part].virtual or is_multi:
|
||||||
|
to_key = next(
|
||||||
|
(
|
||||||
|
v
|
||||||
|
for k, v in model_cls.Meta.model_fields.items()
|
||||||
|
if self._is_target_relation_key(v, join_params.prev_model)
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
).name
|
||||||
|
from_key = model_cls.Meta.pkname
|
||||||
|
else:
|
||||||
|
to_key = model_cls.Meta.pkname
|
||||||
|
from_key = part
|
||||||
|
return to_key, from_key
|
||||||
|
|
||||||
def filter(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: # noqa A003
|
def filter(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: # noqa A003
|
||||||
if self.filter_clauses:
|
if self.filter_clauses:
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, TYPE_CHECKING, Tuple, Type, Union
|
from typing import Any, List, Mapping, TYPE_CHECKING, Tuple, Type, Union
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
@ -31,6 +31,30 @@ class QuerySet:
|
|||||||
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
|
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
|
||||||
return self.__class__(model_cls=owner)
|
return self.__class__(model_cls=owner)
|
||||||
|
|
||||||
|
def _process_query_result_rows(self, rows: List[Mapping]) -> List["Model"]:
|
||||||
|
result_rows = [
|
||||||
|
self.model_cls.from_row(row, select_related=self._select_related)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
rows = self.model_cls.merge_instances_list(result_rows)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
|
||||||
|
pkname = self.model_cls.Meta.pkname
|
||||||
|
pk = self.model_cls.Meta.model_fields[pkname]
|
||||||
|
if new_kwargs.get(pkname, ormar.Undefined) is None and (
|
||||||
|
pk.nullable or pk.autoincrement
|
||||||
|
):
|
||||||
|
del new_kwargs[pkname]
|
||||||
|
return new_kwargs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_single_result_rows_count(rows: List["Model"]) -> None:
|
||||||
|
if not rows:
|
||||||
|
raise NoMatch()
|
||||||
|
if len(rows) > 1:
|
||||||
|
raise MultipleMatches()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def database(self) -> databases.Database:
|
def database(self) -> databases.Database:
|
||||||
return self.model_cls.Meta.database
|
return self.model_cls.Meta.database
|
||||||
@ -128,30 +152,20 @@ class QuerySet:
|
|||||||
return await self.filter(**kwargs).first()
|
return await self.filter(**kwargs).first()
|
||||||
|
|
||||||
rows = await self.limit(1).all()
|
rows = await self.limit(1).all()
|
||||||
if rows:
|
self.check_single_result_rows_count(rows)
|
||||||
return rows[0]
|
return rows[0]
|
||||||
|
|
||||||
async def get(self, **kwargs: Any) -> "Model":
|
async def get(self, **kwargs: Any) -> "Model":
|
||||||
if kwargs:
|
if kwargs:
|
||||||
return await self.filter(**kwargs).get()
|
return await self.filter(**kwargs).get()
|
||||||
|
|
||||||
if not self.filter_clauses:
|
|
||||||
expr = self.build_select_expression().limit(2)
|
|
||||||
else:
|
|
||||||
expr = self.build_select_expression()
|
expr = self.build_select_expression()
|
||||||
|
if not self.filter_clauses:
|
||||||
|
expr = expr.limit(2)
|
||||||
|
|
||||||
rows = await self.database.fetch_all(expr)
|
rows = await self.database.fetch_all(expr)
|
||||||
|
rows = self._process_query_result_rows(rows)
|
||||||
result_rows = [
|
self.check_single_result_rows_count(rows)
|
||||||
self.model_cls.from_row(row, select_related=self._select_related)
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
rows = self.model_cls.merge_instances_list(result_rows)
|
|
||||||
|
|
||||||
if not rows:
|
|
||||||
raise NoMatch()
|
|
||||||
if len(rows) > 1:
|
|
||||||
raise MultipleMatches()
|
|
||||||
return rows[0]
|
return rows[0]
|
||||||
|
|
||||||
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
|
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
|
||||||
@ -159,31 +173,17 @@ class QuerySet:
|
|||||||
return await self.filter(**kwargs).all()
|
return await self.filter(**kwargs).all()
|
||||||
|
|
||||||
expr = self.build_select_expression()
|
expr = self.build_select_expression()
|
||||||
# breakpoint()
|
|
||||||
rows = await self.database.fetch_all(expr)
|
rows = await self.database.fetch_all(expr)
|
||||||
result_rows = [
|
result_rows = self._process_query_result_rows(rows)
|
||||||
self.model_cls.from_row(row, select_related=self._select_related)
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
result_rows = self.model_cls.merge_instances_list(result_rows)
|
|
||||||
|
|
||||||
return result_rows
|
return result_rows
|
||||||
|
|
||||||
async def create(self, **kwargs: Any) -> "Model":
|
async def create(self, **kwargs: Any) -> "Model":
|
||||||
|
|
||||||
new_kwargs = dict(**kwargs)
|
new_kwargs = dict(**kwargs)
|
||||||
|
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
|
||||||
# Remove primary key when None to prevent not null constraint in postgresql.
|
|
||||||
pkname = self.model_cls.Meta.pkname
|
|
||||||
pk = self.model_cls.Meta.model_fields[pkname]
|
|
||||||
if (
|
|
||||||
pkname in new_kwargs
|
|
||||||
and new_kwargs.get(pkname) is None
|
|
||||||
and (pk.nullable or pk.autoincrement)
|
|
||||||
):
|
|
||||||
del new_kwargs[pkname]
|
|
||||||
|
|
||||||
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
|
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
|
||||||
|
|
||||||
expr = self.table.insert()
|
expr = self.table.insert()
|
||||||
expr = expr.values(**new_kwargs)
|
expr = expr.values(**new_kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import pytest
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar.exceptions import QueryDefinitionError
|
from ormar.exceptions import QueryDefinitionError, NoMatch
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
|
|
||||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||||
@ -230,4 +230,5 @@ async def test_model_first():
|
|||||||
assert await User.objects.first() == tom
|
assert await User.objects.first() == tom
|
||||||
assert await User.objects.first(name="Jane") == jane
|
assert await User.objects.first(name="Jane") == jane
|
||||||
assert await User.objects.filter(name="Jane").first() == jane
|
assert await User.objects.filter(name="Jane").first() == jane
|
||||||
assert await User.objects.filter(name="Lucy").first() is None
|
with pytest.raises(NoMatch):
|
||||||
|
await User.objects.filter(name="Lucy").first()
|
||||||
|
|||||||
Reference in New Issue
Block a user