add fastapi tests with inheritance and relations, more docstrings in queryset
This commit is contained in:
@ -1,3 +1,6 @@
|
|||||||
|
"""
|
||||||
|
Contains QuerySet and different Query classes to allow for constructing of sql queries.
|
||||||
|
"""
|
||||||
from ormar.queryset.filter_query import FilterQuery
|
from ormar.queryset.filter_query import FilterQuery
|
||||||
from ormar.queryset.limit_query import LimitQuery
|
from ormar.queryset.limit_query import LimitQuery
|
||||||
from ormar.queryset.offset_query import OffsetQuery
|
from ormar.queryset.offset_query import OffsetQuery
|
||||||
|
|||||||
@ -29,6 +29,10 @@ ESCAPE_CHARACTERS = ["%", "_"]
|
|||||||
|
|
||||||
|
|
||||||
class QueryClause:
|
class QueryClause:
|
||||||
|
"""
|
||||||
|
Constructs where clauses from strings passed as arguments
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
|
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -42,7 +46,16 @@ class QueryClause:
|
|||||||
def filter( # noqa: A003
|
def filter( # noqa: A003
|
||||||
self, **kwargs: Any
|
self, **kwargs: Any
|
||||||
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
||||||
|
"""
|
||||||
|
Main external access point that processes the clauses into sqlalchemy text
|
||||||
|
clauses and updates select_related list with implicit related tables
|
||||||
|
mentioned in select_related strings but not included in select_related.
|
||||||
|
|
||||||
|
:param kwargs: key, value pair with column names and values
|
||||||
|
:type kwargs: Any
|
||||||
|
:return: Tuple with list of where clauses and updated select_related list
|
||||||
|
:rtype: Tuple[List[sqlalchemy.sql.elements.TextClause], List[str]]
|
||||||
|
"""
|
||||||
if kwargs.get("pk"):
|
if kwargs.get("pk"):
|
||||||
pk_name = self.model_cls.get_column_alias(self.model_cls.Meta.pkname)
|
pk_name = self.model_cls.get_column_alias(self.model_cls.Meta.pkname)
|
||||||
kwargs[pk_name] = kwargs.pop("pk")
|
kwargs[pk_name] = kwargs.pop("pk")
|
||||||
@ -54,6 +67,16 @@ class QueryClause:
|
|||||||
def _populate_filter_clauses(
|
def _populate_filter_clauses(
|
||||||
self, **kwargs: Any
|
self, **kwargs: Any
|
||||||
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
||||||
|
"""
|
||||||
|
Iterates all clauses and extracts used operator and field from related
|
||||||
|
models if needed. Based on the chain of related names the target table
|
||||||
|
is determined and the final clause is escaped if needed and compiled.
|
||||||
|
|
||||||
|
:param kwargs: key, value pair with column names and values
|
||||||
|
:type kwargs: Any
|
||||||
|
:return: Tuple with list of where clauses and updated select_related list
|
||||||
|
:rtype: Tuple[List[sqlalchemy.sql.elements.TextClause], List[str]]
|
||||||
|
"""
|
||||||
filter_clauses = self.filter_clauses
|
filter_clauses = self.filter_clauses
|
||||||
select_related = list(self._select_related)
|
select_related = list(self._select_related)
|
||||||
|
|
||||||
@ -100,6 +123,24 @@ class QueryClause:
|
|||||||
table: sqlalchemy.Table,
|
table: sqlalchemy.Table,
|
||||||
table_prefix: str,
|
table_prefix: str,
|
||||||
) -> sqlalchemy.sql.expression.TextClause:
|
) -> sqlalchemy.sql.expression.TextClause:
|
||||||
|
"""
|
||||||
|
Escapes characters if it's required.
|
||||||
|
Substitutes values of the models if value is a ormar Model with its pk value.
|
||||||
|
Compiles the clause.
|
||||||
|
|
||||||
|
:param value: value of the filter
|
||||||
|
:type value: Any
|
||||||
|
:param op: filter operator
|
||||||
|
:type op: str
|
||||||
|
:param column: column on which filter should be applied
|
||||||
|
:type column: sqlalchemy.sql.schema.Column
|
||||||
|
:param table: table on which filter should be applied
|
||||||
|
:type table: sqlalchemy.sql.schema.Table
|
||||||
|
:param table_prefix: prefix from AliasManager
|
||||||
|
:type table_prefix: str
|
||||||
|
:return: complied and escaped clause
|
||||||
|
:rtype: sqlalchemy.sql.elements.TextClause
|
||||||
|
"""
|
||||||
value, has_escaped_character = self._escape_characters_in_clause(op, value)
|
value, has_escaped_character = self._escape_characters_in_clause(op, value)
|
||||||
|
|
||||||
if isinstance(value, ormar.Model):
|
if isinstance(value, ormar.Model):
|
||||||
@ -119,7 +160,21 @@ class QueryClause:
|
|||||||
def _determine_filter_target_table(
|
def _determine_filter_target_table(
|
||||||
self, related_parts: List[str], select_related: List[str]
|
self, related_parts: List[str], select_related: List[str]
|
||||||
) -> Tuple[List[str], str, Type["Model"]]:
|
) -> Tuple[List[str], str, Type["Model"]]:
|
||||||
|
"""
|
||||||
|
Adds related strings to select_related list otherwise the clause would fail as
|
||||||
|
the required columns would not be present. That means that select_related
|
||||||
|
list is filled with missing values present in filters.
|
||||||
|
|
||||||
|
Walks the relation to retrieve the actual model on which the clause should be
|
||||||
|
constructed, extracts alias based on last relation leading to target model.
|
||||||
|
|
||||||
|
:param related_parts: list of split parts of related string
|
||||||
|
:type related_parts: List[str]
|
||||||
|
:param select_related: list of related models
|
||||||
|
:type select_related: List[str]
|
||||||
|
:return: list of related models, table_prefix, final model class
|
||||||
|
:rtype: Tuple[List[str], str, Type[Model]]
|
||||||
|
"""
|
||||||
table_prefix = ""
|
table_prefix = ""
|
||||||
model_cls = self.model_cls
|
model_cls = self.model_cls
|
||||||
select_related = [relation for relation in select_related]
|
select_related = [relation for relation in select_related]
|
||||||
@ -152,6 +207,23 @@ class QueryClause:
|
|||||||
table_prefix: str,
|
table_prefix: str,
|
||||||
modifiers: Dict,
|
modifiers: Dict,
|
||||||
) -> sqlalchemy.sql.expression.TextClause:
|
) -> sqlalchemy.sql.expression.TextClause:
|
||||||
|
"""
|
||||||
|
Compiles the clause to str using appropriate database dialect, replace columns
|
||||||
|
names with aliased names and converts it back to TextClause.
|
||||||
|
|
||||||
|
:param clause: original not compiled clause
|
||||||
|
:type clause: sqlalchemy.sql.elements.BinaryExpression
|
||||||
|
:param column: column on which filter should be applied
|
||||||
|
:type column: sqlalchemy.sql.schema.Column
|
||||||
|
:param table: table on which filter should be applied
|
||||||
|
:type table: sqlalchemy.sql.schema.Table
|
||||||
|
:param table_prefix: prefix from AliasManager
|
||||||
|
:type table_prefix: str
|
||||||
|
:param modifiers: sqlalchemy modifiers - used only to escape chars here
|
||||||
|
:type modifiers: Dict[str, NoneType]
|
||||||
|
:return: compiled and escaped clause
|
||||||
|
:rtype: sqlalchemy.sql.elements.TextClause
|
||||||
|
"""
|
||||||
for modifier, modifier_value in modifiers.items():
|
for modifier, modifier_value in modifiers.items():
|
||||||
clause.modifiers[modifier] = modifier_value
|
clause.modifiers[modifier] = modifier_value
|
||||||
|
|
||||||
@ -169,6 +241,19 @@ class QueryClause:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]:
|
def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]:
|
||||||
|
"""
|
||||||
|
Escapes the special characters ["%", "_"] if needed.
|
||||||
|
Adds `%` for `like` queries.
|
||||||
|
|
||||||
|
:raises: QueryDefinitionError if contains or icontains is used with
|
||||||
|
ormar model instance
|
||||||
|
:param op: operator used in query
|
||||||
|
:type op: str
|
||||||
|
:param value: value of the filter
|
||||||
|
:type value: Any
|
||||||
|
:return: escaped value and flag if escaping is needed
|
||||||
|
:rtype: Tuple[Any, bool]
|
||||||
|
"""
|
||||||
has_escaped_character = False
|
has_escaped_character = False
|
||||||
|
|
||||||
if op not in [
|
if op not in [
|
||||||
@ -202,6 +287,14 @@ class QueryClause:
|
|||||||
def _extract_operator_field_and_related(
|
def _extract_operator_field_and_related(
|
||||||
parts: List[str],
|
parts: List[str],
|
||||||
) -> Tuple[str, str, Optional[List]]:
|
) -> Tuple[str, str, Optional[List]]:
|
||||||
|
"""
|
||||||
|
Splits filter query key and extracts required parts.
|
||||||
|
|
||||||
|
:param parts: split filter query key
|
||||||
|
:type parts: List[str]
|
||||||
|
:return: operator, field_name, list of related parts
|
||||||
|
:rtype: Tuple[str, str, Optional[List]]
|
||||||
|
"""
|
||||||
if parts[-1] in FILTER_OPERATORS:
|
if parts[-1] in FILTER_OPERATORS:
|
||||||
op = parts[-1]
|
op = parts[-1]
|
||||||
field_name = parts[-2]
|
field_name = parts[-2]
|
||||||
|
|||||||
@ -22,6 +22,10 @@ if TYPE_CHECKING: # pragma no cover
|
|||||||
|
|
||||||
|
|
||||||
class JoinParameters(NamedTuple):
|
class JoinParameters(NamedTuple):
|
||||||
|
"""
|
||||||
|
Named tuple that holds set of parameters passed during join construction.
|
||||||
|
"""
|
||||||
|
|
||||||
prev_model: Type["Model"]
|
prev_model: Type["Model"]
|
||||||
previous_alias: str
|
previous_alias: str
|
||||||
from_table: str
|
from_table: str
|
||||||
@ -48,13 +52,36 @@ class SqlJoin:
|
|||||||
self.sorted_orders = sorted_orders
|
self.sorted_orders = sorted_orders
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def relation_manager(model_cls: Type["Model"]) -> AliasManager:
|
def alias_manager(model_cls: Type["Model"]) -> AliasManager:
|
||||||
|
"""
|
||||||
|
Shortcut for ormars model AliasManager stored on Meta.
|
||||||
|
|
||||||
|
:param model_cls: ormar Model class
|
||||||
|
:type model_cls: Type[Model]
|
||||||
|
:return: alias manager from model's Meta
|
||||||
|
:rtype: AliasManager
|
||||||
|
"""
|
||||||
return model_cls.Meta.alias_manager
|
return model_cls.Meta.alias_manager
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def on_clause(
|
def on_clause(
|
||||||
previous_alias: str, alias: str, from_clause: str, to_clause: str,
|
previous_alias: str, alias: str, from_clause: str, to_clause: str,
|
||||||
) -> text:
|
) -> text:
|
||||||
|
"""
|
||||||
|
Receives aliases and names of both ends of the join and combines them
|
||||||
|
into one text clause used in joins.
|
||||||
|
|
||||||
|
:param previous_alias: alias of previous table
|
||||||
|
:type previous_alias: str
|
||||||
|
:param alias: alias of current table
|
||||||
|
:type alias: str
|
||||||
|
:param from_clause: from table name
|
||||||
|
:type from_clause: str
|
||||||
|
:param to_clause: to table name
|
||||||
|
:type to_clause: str
|
||||||
|
:return: clause combining all strings
|
||||||
|
:rtype: sqlalchemy.text
|
||||||
|
"""
|
||||||
left_part = f"{alias}_{to_clause}"
|
left_part = f"{alias}_{to_clause}"
|
||||||
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
|
||||||
return text(f"{left_part}={right_part}")
|
return text(f"{left_part}={right_part}")
|
||||||
@ -66,6 +93,20 @@ class SqlJoin:
|
|||||||
exclude_fields: Optional[Union[Set, Dict]],
|
exclude_fields: Optional[Union[Set, Dict]],
|
||||||
nested_name: str,
|
nested_name: str,
|
||||||
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
|
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
|
||||||
|
"""
|
||||||
|
Extract nested fields and exclude_fields if applicable.
|
||||||
|
|
||||||
|
:param model_cls: ormar model class
|
||||||
|
:type model_cls: Type["Model"]
|
||||||
|
:param fields: fields to include
|
||||||
|
:type fields: Optional[Union[Set, Dict]]
|
||||||
|
:param exclude_fields: fields to exclude
|
||||||
|
:type exclude_fields: Optional[Union[Set, Dict]]
|
||||||
|
:param nested_name: name of the nested field
|
||||||
|
:type nested_name: str
|
||||||
|
:return: updated exclude and include fields from nested objects
|
||||||
|
:rtype: Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]
|
||||||
|
"""
|
||||||
fields = model_cls.get_included(fields, nested_name)
|
fields = model_cls.get_included(fields, nested_name)
|
||||||
exclude_fields = model_cls.get_excluded(exclude_fields, nested_name)
|
exclude_fields = model_cls.get_excluded(exclude_fields, nested_name)
|
||||||
return fields, exclude_fields
|
return fields, exclude_fields
|
||||||
@ -73,7 +114,19 @@ class SqlJoin:
|
|||||||
def build_join( # noqa: CCR001
|
def build_join( # noqa: CCR001
|
||||||
self, item: str, join_parameters: JoinParameters
|
self, item: str, join_parameters: JoinParameters
|
||||||
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
|
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
|
||||||
|
"""
|
||||||
|
Main external access point for building a join.
|
||||||
|
Splits the join definition, updates fields and exclude_fields if needed,
|
||||||
|
handles switching to through models for m2m relations, returns updated lists of
|
||||||
|
used_aliases and sort_orders.
|
||||||
|
|
||||||
|
:param item: string with join definition
|
||||||
|
:type item: str
|
||||||
|
:param join_parameters: parameters from previous/ current join
|
||||||
|
:type join_parameters: JoinParameters
|
||||||
|
:return: list of used aliases, select from, list of aliased columns, sort orders
|
||||||
|
:rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict]
|
||||||
|
"""
|
||||||
fields = self.fields
|
fields = self.fields
|
||||||
exclude_fields = self.exclude_fields
|
exclude_fields = self.exclude_fields
|
||||||
|
|
||||||
@ -129,6 +182,23 @@ class SqlJoin:
|
|||||||
exclude_fields: Optional[Union[Set, Dict]],
|
exclude_fields: Optional[Union[Set, Dict]],
|
||||||
is_multi: bool = False,
|
is_multi: bool = False,
|
||||||
) -> JoinParameters:
|
) -> JoinParameters:
|
||||||
|
"""
|
||||||
|
Updates used_aliases to not join multiple times to the same table.
|
||||||
|
Updates join parameters with new values.
|
||||||
|
|
||||||
|
:param part: part of the join str definition
|
||||||
|
:type part: str
|
||||||
|
:param join_params: parameters from previous/ current join
|
||||||
|
:type join_params: JoinParameters
|
||||||
|
:param fields: fields to include
|
||||||
|
:type fields: Optional[Union[Set, Dict]]
|
||||||
|
:param exclude_fields: fields to exclude
|
||||||
|
:type exclude_fields: Optional[Union[Set, Dict]]
|
||||||
|
:param is_multi: flag if the relation is m2m
|
||||||
|
:type is_multi: bool
|
||||||
|
:return: updated join parameters
|
||||||
|
:rtype: ormar.queryset.join.JoinParameters
|
||||||
|
"""
|
||||||
if is_multi:
|
if is_multi:
|
||||||
model_cls = join_params.model_cls.Meta.model_fields[part].through
|
model_cls = join_params.model_cls.Meta.model_fields[part].through
|
||||||
else:
|
else:
|
||||||
@ -164,6 +234,34 @@ class SqlJoin:
|
|||||||
fields: Optional[Union[Set, Dict]],
|
fields: Optional[Union[Set, Dict]],
|
||||||
exclude_fields: Optional[Union[Set, Dict]],
|
exclude_fields: Optional[Union[Set, Dict]],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Resolves to and from column names and table names.
|
||||||
|
|
||||||
|
Produces on_clause.
|
||||||
|
|
||||||
|
Performs actual join updating select_from parameter.
|
||||||
|
|
||||||
|
Adds aliases of required column to list of columns to include in query.
|
||||||
|
|
||||||
|
Updates the used aliases list directly.
|
||||||
|
|
||||||
|
Process order_by causes for non m2m relations.
|
||||||
|
|
||||||
|
:param join_params: parameters from previous/ current join
|
||||||
|
:type join_params: JoinParameters
|
||||||
|
:param is_multi: flag if it's m2m relation
|
||||||
|
:type is_multi: bool
|
||||||
|
:param model_cls:
|
||||||
|
:type model_cls: ormar.models.metaclass.ModelMetaclass
|
||||||
|
:param part: name of the field used in join
|
||||||
|
:type part: str
|
||||||
|
:param alias: alias of the current join
|
||||||
|
:type alias: str
|
||||||
|
:param fields: fields to include
|
||||||
|
:type fields: Optional[Union[Set, Dict]]
|
||||||
|
:param exclude_fields: fields to exclude
|
||||||
|
:type exclude_fields: Optional[Union[Set, Dict]]
|
||||||
|
"""
|
||||||
to_table = model_cls.Meta.table.name
|
to_table = model_cls.Meta.table.name
|
||||||
to_key, from_key = self.get_to_and_from_keys(
|
to_key, from_key = self.get_to_and_from_keys(
|
||||||
join_params, is_multi, model_cls, part
|
join_params, is_multi, model_cls, part
|
||||||
@ -175,7 +273,7 @@ class SqlJoin:
|
|||||||
from_clause=f"{join_params.from_table}.{from_key}",
|
from_clause=f"{join_params.from_table}.{from_key}",
|
||||||
to_clause=f"{to_table}.{to_key}",
|
to_clause=f"{to_table}.{to_key}",
|
||||||
)
|
)
|
||||||
target_table = self.relation_manager(model_cls).prefixed_table_name(
|
target_table = self.alias_manager(model_cls).prefixed_table_name(
|
||||||
alias, to_table
|
alias, to_table
|
||||||
)
|
)
|
||||||
self.select_from = sqlalchemy.sql.outerjoin(
|
self.select_from = sqlalchemy.sql.outerjoin(
|
||||||
@ -199,13 +297,21 @@ class SqlJoin:
|
|||||||
use_alias=True,
|
use_alias=True,
|
||||||
)
|
)
|
||||||
self.columns.extend(
|
self.columns.extend(
|
||||||
self.relation_manager(model_cls).prefixed_columns(
|
self.alias_manager(model_cls).prefixed_columns(
|
||||||
alias, model_cls.Meta.table, self_related_fields
|
alias, model_cls.Meta.table, self_related_fields
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.used_aliases.append(alias)
|
self.used_aliases.append(alias)
|
||||||
|
|
||||||
def _switch_many_to_many_order_columns(self, part: str, new_part: str) -> None:
|
def _switch_many_to_many_order_columns(self, part: str, new_part: str) -> None:
|
||||||
|
"""
|
||||||
|
Substitutes the name of the relation with actual model name in m2m order bys.
|
||||||
|
|
||||||
|
:param part: name of the field with relation
|
||||||
|
:type part: str
|
||||||
|
:param new_part: name of the target model
|
||||||
|
:type new_part: str
|
||||||
|
"""
|
||||||
if self.order_columns:
|
if self.order_columns:
|
||||||
split_order_columns = [
|
split_order_columns = [
|
||||||
x.split("__") for x in self.order_columns if "__" in x
|
x.split("__") for x in self.order_columns if "__" in x
|
||||||
@ -219,6 +325,16 @@ class SqlJoin:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_if_condition_apply(condition: List, part: str) -> bool:
|
def _check_if_condition_apply(condition: List, part: str) -> bool:
|
||||||
|
"""
|
||||||
|
Checks filter conditions to find if they apply to current join.
|
||||||
|
|
||||||
|
:param condition: list of parts of condition split by '__'
|
||||||
|
:type condition: List[str]
|
||||||
|
:param part: name of the current relation join.
|
||||||
|
:type part: str
|
||||||
|
:return: result of the check
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
return len(condition) >= 2 and (
|
return len(condition) >= 2 and (
|
||||||
condition[-2] == part or condition[-2][1:] == part
|
condition[-2] == part or condition[-2][1:] == part
|
||||||
)
|
)
|
||||||
@ -226,6 +342,19 @@ class SqlJoin:
|
|||||||
def set_aliased_order_by(
|
def set_aliased_order_by(
|
||||||
self, condition: List[str], alias: str, to_table: str, model_cls: Type["Model"],
|
self, condition: List[str], alias: str, to_table: str, model_cls: Type["Model"],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Substitute hyphens ('-') with descending order.
|
||||||
|
Construct actual sqlalchemy text clause using aliased table and column name.
|
||||||
|
|
||||||
|
:param condition: list of parts of a current condition split by '__'
|
||||||
|
:type condition: List[str]
|
||||||
|
:param alias: alias of the table in current join
|
||||||
|
:type alias: str
|
||||||
|
:param to_table: target table
|
||||||
|
:type to_table: sqlalchemy.sql.elements.quoted_name
|
||||||
|
:param model_cls: ormar model class
|
||||||
|
:type model_cls: ormar.models.metaclass.ModelMetaclass
|
||||||
|
"""
|
||||||
direction = f"{'desc' if condition[0][0] == '-' else ''}"
|
direction = f"{'desc' if condition[0][0] == '-' else ''}"
|
||||||
column_alias = model_cls.get_column_alias(condition[-1])
|
column_alias = model_cls.get_column_alias(condition[-1])
|
||||||
order = text(f"{alias}_{to_table}.{column_alias} {direction}")
|
order = text(f"{alias}_{to_table}.{column_alias} {direction}")
|
||||||
@ -239,6 +368,21 @@ class SqlJoin:
|
|||||||
part: str,
|
part: str,
|
||||||
model_cls: Type["Model"],
|
model_cls: Type["Model"],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Triggers construction of order bys if they are given.
|
||||||
|
Otherwise by default each table is sorted by a primary key column asc.
|
||||||
|
|
||||||
|
:param alias: alias of current table in join
|
||||||
|
:type alias: str
|
||||||
|
:param to_table: target table
|
||||||
|
:type to_table: sqlalchemy.sql.elements.quoted_name
|
||||||
|
:param pkname_alias: alias of the primary key column
|
||||||
|
:type pkname_alias: str
|
||||||
|
:param part: name of the current relation join
|
||||||
|
:type part: str
|
||||||
|
:param model_cls: ormar model class
|
||||||
|
:type model_cls: Type[Model]
|
||||||
|
"""
|
||||||
if self.order_columns:
|
if self.order_columns:
|
||||||
split_order_columns = [
|
split_order_columns = [
|
||||||
x.split("__") for x in self.order_columns if "__" in x
|
x.split("__") for x in self.order_columns if "__" in x
|
||||||
@ -262,6 +406,22 @@ class SqlJoin:
|
|||||||
model_cls: Type["Model"],
|
model_cls: Type["Model"],
|
||||||
part: str,
|
part: str,
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Based on the relation type, name of the relation and previous models and parts
|
||||||
|
stored in JoinParameters it resolves the current to and from keys, which are
|
||||||
|
different for ManyToMany relation, ForeignKey and reverse part of relations.
|
||||||
|
|
||||||
|
:param join_params: parameters from previous/ current join
|
||||||
|
:type join_params: JoinParameters
|
||||||
|
:param is_multi: flag if the relation is of m2m type
|
||||||
|
:type is_multi: bool
|
||||||
|
:param model_cls: ormar model class
|
||||||
|
:type model_cls: Type[Model]
|
||||||
|
:param part: name of the current relation join
|
||||||
|
:type part: str
|
||||||
|
:return: to key and from key
|
||||||
|
:rtype: Tuple[str, str]
|
||||||
|
"""
|
||||||
if is_multi:
|
if is_multi:
|
||||||
to_field = join_params.prev_model.get_name()
|
to_field = join_params.prev_model.get_name()
|
||||||
to_key = model_cls.get_column_alias(to_field)
|
to_key = model_cls.get_column_alias(to_field)
|
||||||
|
|||||||
@ -6,7 +6,17 @@ from fastapi import FastAPI
|
|||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from tests.settings import DATABASE_URL
|
from tests.settings import DATABASE_URL
|
||||||
from tests.test_inheritance_concrete import Category, Subject, metadata, db as database # type: ignore
|
from tests.test_inheritance_concrete import ( # type: ignore
|
||||||
|
Category,
|
||||||
|
Subject,
|
||||||
|
Person,
|
||||||
|
Bus,
|
||||||
|
Truck,
|
||||||
|
Bus2,
|
||||||
|
Truck2,
|
||||||
|
db as database,
|
||||||
|
metadata,
|
||||||
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.state.database = database
|
app.state.database = database
|
||||||
@ -37,6 +47,56 @@ async def create_category(category: Category):
|
|||||||
return category
|
return category
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/buses/", response_model=Bus)
|
||||||
|
async def create_bus(bus: Bus):
|
||||||
|
await bus.save()
|
||||||
|
return bus
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/buses/{item_id}", response_model=Bus)
|
||||||
|
async def get_bus(item_id: int):
|
||||||
|
bus = await Bus.objects.select_related(["owner", "co_owner"]).get(pk=item_id)
|
||||||
|
return bus
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/trucks/", response_model=Truck)
|
||||||
|
async def create_truck(truck: Truck):
|
||||||
|
await truck.save()
|
||||||
|
return truck
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/persons/", response_model=Person)
|
||||||
|
async def create_person(person: Person):
|
||||||
|
await person.save()
|
||||||
|
return person
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/buses2/", response_model=Bus2)
|
||||||
|
async def create_bus2(bus: Bus2):
|
||||||
|
await bus.save()
|
||||||
|
return bus
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/buses2/{item_id}/add_coowner/", response_model=Bus2)
|
||||||
|
async def add_bus_coowner(item_id: int, person: Person):
|
||||||
|
bus = await Bus2.objects.select_related(["owner", "co_owners"]).get(pk=item_id)
|
||||||
|
await bus.co_owners.add(person)
|
||||||
|
return bus
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/trucks2/", response_model=Truck2)
|
||||||
|
async def create_truck2(truck: Truck2):
|
||||||
|
await truck.save()
|
||||||
|
return truck
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/trucks2/{item_id}/add_coowner/", response_model=Truck2)
|
||||||
|
async def add_truck_coowner(item_id: int, person: Person):
|
||||||
|
truck = await Truck2.objects.select_related(["owner", "co_owners"]).get(pk=item_id)
|
||||||
|
await truck.co_owners.add(person)
|
||||||
|
return truck
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True, scope="module")
|
@pytest.fixture(autouse=True, scope="module")
|
||||||
def create_test_database():
|
def create_test_database():
|
||||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||||
@ -73,3 +133,87 @@ def test_read_main():
|
|||||||
assert sub.name == "Bar"
|
assert sub.name == "Bar"
|
||||||
assert sub.category.pk == cat.pk
|
assert sub.category.pk == cat.pk
|
||||||
assert isinstance(sub.updated_date, datetime.datetime)
|
assert isinstance(sub.updated_date, datetime.datetime)
|
||||||
|
|
||||||
|
|
||||||
|
def test_inheritance_with_relation():
|
||||||
|
client = TestClient(app)
|
||||||
|
with client as client:
|
||||||
|
sam = Person(**client.post("/persons/", json={"name": "Sam"}).json())
|
||||||
|
joe = Person(**client.post("/persons/", json={"name": "Joe"}).json())
|
||||||
|
|
||||||
|
truck_dict = dict(
|
||||||
|
name="Shelby wanna be",
|
||||||
|
max_capacity=1400,
|
||||||
|
owner=sam.dict(),
|
||||||
|
co_owner=joe.dict(),
|
||||||
|
)
|
||||||
|
bus_dict = dict(
|
||||||
|
name="Unicorn", max_persons=50, owner=sam.dict(), co_owner=joe.dict()
|
||||||
|
)
|
||||||
|
unicorn = Bus(**client.post("/buses/", json=bus_dict).json())
|
||||||
|
shelby = Truck(**client.post("/trucks/", json=truck_dict).json())
|
||||||
|
|
||||||
|
assert shelby.name == "Shelby wanna be"
|
||||||
|
assert shelby.owner.name == "Sam"
|
||||||
|
assert shelby.co_owner.name == "Joe"
|
||||||
|
assert shelby.co_owner == joe
|
||||||
|
assert shelby.max_capacity == 1400
|
||||||
|
|
||||||
|
assert unicorn.name == "Unicorn"
|
||||||
|
assert unicorn.owner == sam
|
||||||
|
assert unicorn.owner.name == "Sam"
|
||||||
|
assert unicorn.co_owner.name == "Joe"
|
||||||
|
assert unicorn.max_persons == 50
|
||||||
|
|
||||||
|
unicorn2 = Bus(**client.get(f"/buses/{unicorn.pk}").json())
|
||||||
|
assert unicorn2.name == "Unicorn"
|
||||||
|
assert unicorn2.owner == sam
|
||||||
|
assert unicorn2.owner.name == "Sam"
|
||||||
|
assert unicorn2.co_owner.name == "Joe"
|
||||||
|
assert unicorn2.max_persons == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_inheritance_with_m2m_relation():
|
||||||
|
client = TestClient(app)
|
||||||
|
with client as client:
|
||||||
|
sam = Person(**client.post("/persons/", json={"name": "Sam"}).json())
|
||||||
|
joe = Person(**client.post("/persons/", json={"name": "Joe"}).json())
|
||||||
|
alex = Person(**client.post("/persons/", json={"name": "Alex"}).json())
|
||||||
|
|
||||||
|
truck_dict = dict(name="Shelby wanna be", max_capacity=2000, owner=sam.dict())
|
||||||
|
bus_dict = dict(name="Unicorn", max_persons=80, owner=sam.dict())
|
||||||
|
|
||||||
|
unicorn = Bus2(**client.post("/buses2/", json=bus_dict).json())
|
||||||
|
shelby = Truck2(**client.post("/trucks2/", json=truck_dict).json())
|
||||||
|
|
||||||
|
unicorn = Bus2(
|
||||||
|
**client.post(f"/buses2/{unicorn.pk}/add_coowner/", json=joe.dict()).json()
|
||||||
|
)
|
||||||
|
unicorn = Bus2(
|
||||||
|
**client.post(f"/buses2/{unicorn.pk}/add_coowner/", json=alex.dict()).json()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert shelby.name == "Shelby wanna be"
|
||||||
|
assert shelby.owner.name == "Sam"
|
||||||
|
assert len(shelby.co_owners) == 0
|
||||||
|
assert shelby.max_capacity == 2000
|
||||||
|
|
||||||
|
assert unicorn.name == "Unicorn"
|
||||||
|
assert unicorn.owner == sam
|
||||||
|
assert unicorn.owner.name == "Sam"
|
||||||
|
assert unicorn.co_owners[0].name == "Joe"
|
||||||
|
assert unicorn.co_owners[1] == alex
|
||||||
|
assert unicorn.max_persons == 80
|
||||||
|
|
||||||
|
client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=alex.dict())
|
||||||
|
|
||||||
|
shelby = Truck2(
|
||||||
|
**client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=joe.dict()).json()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert shelby.name == "Shelby wanna be"
|
||||||
|
assert shelby.owner.name == "Sam"
|
||||||
|
assert len(shelby.co_owners) == 2
|
||||||
|
assert shelby.co_owners[0] == alex
|
||||||
|
assert shelby.co_owners[1] == joe
|
||||||
|
assert shelby.max_capacity == 2000
|
||||||
|
|||||||
Reference in New Issue
Block a user