working basic many to many relationships

This commit is contained in:
collerek
2020-09-14 17:13:27 +02:00
parent 58c3627be7
commit 4674f625df
18 changed files with 791 additions and 244 deletions

View File

@ -5,6 +5,7 @@ from sqlalchemy import text
import ormar # noqa I100
from ormar.exceptions import QueryDefinitionError
from ormar.fields.many_to_many import ManyToManyField
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -128,6 +129,10 @@ class QueryClause:
# against which the comparison is being made.
previous_table = model_cls.Meta.tablename
for part in related_parts:
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
previous_table = model_cls.Meta.model_fields[
part
].through.Meta.tablename
current_table = model_cls.Meta.model_fields[part].to.Meta.tablename
manager = model_cls.Meta.alias_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table)

View File

@ -4,8 +4,10 @@ import sqlalchemy
from sqlalchemy import text
import ormar # noqa I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.relations import AliasManager
from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.alias_manager import AliasManager
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -63,6 +65,15 @@ class Query:
)
for part in item.split("__"):
if issubclass(
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
):
_fields = join_parameters.model_cls.Meta.model_fields
new_part = _fields[part].to.get_name()
join_parameters = self._build_join_parameters(
part, join_parameters, is_multi=True
)
part = new_part
join_parameters = self._build_join_parameters(part, join_parameters)
expr = sqlalchemy.sql.select(self.columns)
@ -83,23 +94,30 @@ class Query:
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}")
def _is_target_relation_key(
self, field: BaseField, target_model: Type["Model"]
) -> bool:
return issubclass(field, ForeignKeyField) and field.to.Meta == target_model.Meta
def _build_join_parameters(
self, part: str, join_params: JoinParameters
self, part: str, join_params: JoinParameters, is_multi: bool = False
) -> JoinParameters:
model_cls = join_params.model_cls.Meta.model_fields[part].to
if is_multi:
model_cls = join_params.model_cls.Meta.model_fields[part].through
else:
model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.Meta.table.name
alias = model_cls.Meta.alias_manager.resolve_relation_join(
join_params.from_table, to_table
)
if alias not in self.used_aliases:
if join_params.prev_model.Meta.model_fields[part].virtual:
if join_params.prev_model.Meta.model_fields[part].virtual or is_multi:
to_key = next(
(
v
for k, v in model_cls.Meta.model_fields.items()
if issubclass(v, ForeignKeyField)
and v.to == join_params.prev_model
if self._is_target_relation_key(v, join_params.prev_model)
),
None,
).name
@ -129,16 +147,19 @@ class Query:
prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select:
def filter(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: # noqa A003
if self.filter_clauses:
if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0]
else:
clause = sqlalchemy.sql.and_(*self.filter_clauses)
expr = expr.where(clause)
return expr
def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select:
expr = self.filter(expr)
if self.limit_count:
expr = expr.limit(self.limit_count)

View File

@ -48,6 +48,7 @@ class QuerySet:
limit_count=self.limit_count,
)
exp = qry.build_select_expression()
# print(exp.compile(compile_kwargs={"literal_binds": True}))
return exp
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -70,7 +71,7 @@ class QuerySet:
if not isinstance(related, (list, tuple)):
related = [related]
related = list(self._select_related) + related
related = list(set(list(self._select_related) + related))
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
@ -82,13 +83,28 @@ class QuerySet:
async def exists(self) -> bool:
expr = self.build_select_expression()
expr = sqlalchemy.exists(expr).select()
# print(expr.compile(compile_kwargs={"literal_binds": True}))
return await self.database.fetch_val(expr)
async def count(self) -> int:
expr = self.build_select_expression().alias("subquery_for_count")
expr = sqlalchemy.func.count().select().select_from(expr)
# print(expr.compile(compile_kwargs={"literal_binds": True}))
return await self.database.fetch_val(expr)
async def delete(self, **kwargs: Any) -> int:
if kwargs:
return await self.filter(**kwargs).delete()
qry = Query(
model_cls=self.model_cls,
select_related=self._select_related,
filter_clauses=self.filter_clauses,
offset=self.query_offset,
limit_count=self.limit_count,
)
expr = qry.filter(self.table.delete())
return await self.database.execute(expr)
def limit(self, limit_count: int) -> "QuerySet":
return self.__class__(
model_cls=self.model_cls,
@ -143,6 +159,7 @@ class QuerySet:
return await self.filter(**kwargs).all()
expr = self.build_select_expression()
# breakpoint()
rows = await self.database.fetch_all(expr)
result_rows = [
self.model_cls.from_row(row, select_related=self._select_related)