add fastapi tests with inheritance and relations, more docstrings in queryset

This commit is contained in:
collerek
2021-01-04 12:43:00 +01:00
parent e4b4d9451d
commit a914be67e2
4 changed files with 404 additions and 4 deletions

View File

@ -1,3 +1,6 @@
"""
Contains QuerySet and different Query classes to allow for constructing of sql queries.
"""
from ormar.queryset.filter_query import FilterQuery
from ormar.queryset.limit_query import LimitQuery
from ormar.queryset.offset_query import OffsetQuery

View File

@ -29,6 +29,10 @@ ESCAPE_CHARACTERS = ["%", "_"]
class QueryClause:
"""
Constructs where clauses from strings passed as arguments
"""
def __init__(
self, model_cls: Type["Model"], filter_clauses: List, select_related: List,
) -> None:
@ -42,7 +46,16 @@ class QueryClause:
def filter( # noqa: A003
self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
"""
Main external access point that processes the clauses into sqlalchemy text
clauses and updates select_related list with implicit related tables
mentioned in select_related strings but not included in select_related.
:param kwargs: key, value pair with column names and values
:type kwargs: Any
:return: Tuple with list of where clauses and updated select_related list
:rtype: Tuple[List[sqlalchemy.sql.elements.TextClause], List[str]]
"""
if kwargs.get("pk"):
pk_name = self.model_cls.get_column_alias(self.model_cls.Meta.pkname)
kwargs[pk_name] = kwargs.pop("pk")
@ -54,6 +67,16 @@ class QueryClause:
def _populate_filter_clauses(
self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
"""
Iterates all clauses and extracts used operator and field from related
models if needed. Based on the chain of related names the target table
is determined and the final clause is escaped if needed and compiled.
:param kwargs: key, value pair with column names and values
:type kwargs: Any
:return: Tuple with list of where clauses and updated select_related list
:rtype: Tuple[List[sqlalchemy.sql.elements.TextClause], List[str]]
"""
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
@ -100,6 +123,24 @@ class QueryClause:
table: sqlalchemy.Table,
table_prefix: str,
) -> sqlalchemy.sql.expression.TextClause:
"""
Escapes characters if it's required.
Substitutes values of the models if value is a ormar Model with its pk value.
Compiles the clause.
:param value: value of the filter
:type value: Any
:param op: filter operator
:type op: str
:param column: column on which filter should be applied
:type column: sqlalchemy.sql.schema.Column
:param table: table on which filter should be applied
:type table: sqlalchemy.sql.schema.Table
:param table_prefix: prefix from AliasManager
:type table_prefix: str
:return: complied and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
"""
value, has_escaped_character = self._escape_characters_in_clause(op, value)
if isinstance(value, ormar.Model):
@ -119,7 +160,21 @@ class QueryClause:
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str]
) -> Tuple[List[str], str, Type["Model"]]:
"""
Adds related strings to select_related list otherwise the clause would fail as
the required columns would not be present. That means that select_related
list is filled with missing values present in filters.
Walks the relation to retrieve the actual model on which the clause should be
constructed, extracts alias based on last relation leading to target model.
:param related_parts: list of split parts of related string
:type related_parts: List[str]
:param select_related: list of related models
:type select_related: List[str]
:return: list of related models, table_prefix, final model class
:rtype: Tuple[List[str], str, Type[Model]]
"""
table_prefix = ""
model_cls = self.model_cls
select_related = [relation for relation in select_related]
@ -152,6 +207,23 @@ class QueryClause:
table_prefix: str,
modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause:
"""
Compiles the clause to str using appropriate database dialect, replace columns
names with aliased names and converts it back to TextClause.
:param clause: original not compiled clause
:type clause: sqlalchemy.sql.elements.BinaryExpression
:param column: column on which filter should be applied
:type column: sqlalchemy.sql.schema.Column
:param table: table on which filter should be applied
:type table: sqlalchemy.sql.schema.Table
:param table_prefix: prefix from AliasManager
:type table_prefix: str
:param modifiers: sqlalchemy modifiers - used only to escape chars here
:type modifiers: Dict[str, NoneType]
:return: compiled and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause
"""
for modifier, modifier_value in modifiers.items():
clause.modifiers[modifier] = modifier_value
@ -169,6 +241,19 @@ class QueryClause:
@staticmethod
def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]:
"""
Escapes the special characters ["%", "_"] if needed.
Adds `%` for `like` queries.
:raises: QueryDefinitionError if contains or icontains is used with
ormar model instance
:param op: operator used in query
:type op: str
:param value: value of the filter
:type value: Any
:return: escaped value and flag if escaping is needed
:rtype: Tuple[Any, bool]
"""
has_escaped_character = False
if op not in [
@ -202,6 +287,14 @@ class QueryClause:
def _extract_operator_field_and_related(
parts: List[str],
) -> Tuple[str, str, Optional[List]]:
"""
Splits filter query key and extracts required parts.
:param parts: split filter query key
:type parts: List[str]
:return: operator, field_name, list of related parts
:rtype: Tuple[str, str, Optional[List]]
"""
if parts[-1] in FILTER_OPERATORS:
op = parts[-1]
field_name = parts[-2]

View File

@ -22,6 +22,10 @@ if TYPE_CHECKING: # pragma no cover
class JoinParameters(NamedTuple):
"""
Named tuple that holds set of parameters passed during join construction.
"""
prev_model: Type["Model"]
previous_alias: str
from_table: str
@ -48,13 +52,36 @@ class SqlJoin:
self.sorted_orders = sorted_orders
@staticmethod
def relation_manager(model_cls: Type["Model"]) -> AliasManager:
def alias_manager(model_cls: Type["Model"]) -> AliasManager:
"""
Shortcut for ormars model AliasManager stored on Meta.
:param model_cls: ormar Model class
:type model_cls: Type[Model]
:return: alias manager from model's Meta
:rtype: AliasManager
"""
return model_cls.Meta.alias_manager
@staticmethod
def on_clause(
previous_alias: str, alias: str, from_clause: str, to_clause: str,
) -> text:
"""
Receives aliases and names of both ends of the join and combines them
into one text clause used in joins.
:param previous_alias: alias of previous table
:type previous_alias: str
:param alias: alias of current table
:type alias: str
:param from_clause: from table name
:type from_clause: str
:param to_clause: to table name
:type to_clause: str
:return: clause combining all strings
:rtype: sqlalchemy.text
"""
left_part = f"{alias}_{to_clause}"
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}")
@ -66,6 +93,20 @@ class SqlJoin:
exclude_fields: Optional[Union[Set, Dict]],
nested_name: str,
) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]:
"""
Extract nested fields and exclude_fields if applicable.
:param model_cls: ormar model class
:type model_cls: Type["Model"]
:param fields: fields to include
:type fields: Optional[Union[Set, Dict]]
:param exclude_fields: fields to exclude
:type exclude_fields: Optional[Union[Set, Dict]]
:param nested_name: name of the nested field
:type nested_name: str
:return: updated exclude and include fields from nested objects
:rtype: Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]
"""
fields = model_cls.get_included(fields, nested_name)
exclude_fields = model_cls.get_excluded(exclude_fields, nested_name)
return fields, exclude_fields
@ -73,7 +114,19 @@ class SqlJoin:
def build_join( # noqa: CCR001
self, item: str, join_parameters: JoinParameters
) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]:
"""
Main external access point for building a join.
Splits the join definition, updates fields and exclude_fields if needed,
handles switching to through models for m2m relations, returns updated lists of
used_aliases and sort_orders.
:param item: string with join definition
:type item: str
:param join_parameters: parameters from previous/ current join
:type join_parameters: JoinParameters
:return: list of used aliases, select from, list of aliased columns, sort orders
:rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict]
"""
fields = self.fields
exclude_fields = self.exclude_fields
@ -129,6 +182,23 @@ class SqlJoin:
exclude_fields: Optional[Union[Set, Dict]],
is_multi: bool = False,
) -> JoinParameters:
"""
Updates used_aliases to not join multiple times to the same table.
Updates join parameters with new values.
:param part: part of the join str definition
:type part: str
:param join_params: parameters from previous/ current join
:type join_params: JoinParameters
:param fields: fields to include
:type fields: Optional[Union[Set, Dict]]
:param exclude_fields: fields to exclude
:type exclude_fields: Optional[Union[Set, Dict]]
:param is_multi: flag if the relation is m2m
:type is_multi: bool
:return: updated join parameters
:rtype: ormar.queryset.join.JoinParameters
"""
if is_multi:
model_cls = join_params.model_cls.Meta.model_fields[part].through
else:
@ -164,6 +234,34 @@ class SqlJoin:
fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]],
) -> None:
"""
Resolves to and from column names and table names.
Produces on_clause.
Performs actual join updating select_from parameter.
Adds aliases of required column to list of columns to include in query.
Updates the used aliases list directly.
Process order_by causes for non m2m relations.
:param join_params: parameters from previous/ current join
:type join_params: JoinParameters
:param is_multi: flag if it's m2m relation
:type is_multi: bool
:param model_cls:
:type model_cls: ormar.models.metaclass.ModelMetaclass
:param part: name of the field used in join
:type part: str
:param alias: alias of the current join
:type alias: str
:param fields: fields to include
:type fields: Optional[Union[Set, Dict]]
:param exclude_fields: fields to exclude
:type exclude_fields: Optional[Union[Set, Dict]]
"""
to_table = model_cls.Meta.table.name
to_key, from_key = self.get_to_and_from_keys(
join_params, is_multi, model_cls, part
@ -175,7 +273,7 @@ class SqlJoin:
from_clause=f"{join_params.from_table}.{from_key}",
to_clause=f"{to_table}.{to_key}",
)
target_table = self.relation_manager(model_cls).prefixed_table_name(
target_table = self.alias_manager(model_cls).prefixed_table_name(
alias, to_table
)
self.select_from = sqlalchemy.sql.outerjoin(
@ -199,13 +297,21 @@ class SqlJoin:
use_alias=True,
)
self.columns.extend(
self.relation_manager(model_cls).prefixed_columns(
self.alias_manager(model_cls).prefixed_columns(
alias, model_cls.Meta.table, self_related_fields
)
)
self.used_aliases.append(alias)
def _switch_many_to_many_order_columns(self, part: str, new_part: str) -> None:
"""
Substitutes the name of the relation with actual model name in m2m order bys.
:param part: name of the field with relation
:type part: str
:param new_part: name of the target model
:type new_part: str
"""
if self.order_columns:
split_order_columns = [
x.split("__") for x in self.order_columns if "__" in x
@ -219,6 +325,16 @@ class SqlJoin:
@staticmethod
def _check_if_condition_apply(condition: List, part: str) -> bool:
"""
Checks filter conditions to find if they apply to current join.
:param condition: list of parts of condition split by '__'
:type condition: List[str]
:param part: name of the current relation join.
:type part: str
:return: result of the check
:rtype: bool
"""
return len(condition) >= 2 and (
condition[-2] == part or condition[-2][1:] == part
)
@ -226,6 +342,19 @@ class SqlJoin:
def set_aliased_order_by(
self, condition: List[str], alias: str, to_table: str, model_cls: Type["Model"],
) -> None:
"""
Substitute hyphens ('-') with descending order.
Construct actual sqlalchemy text clause using aliased table and column name.
:param condition: list of parts of a current condition split by '__'
:type condition: List[str]
:param alias: alias of the table in current join
:type alias: str
:param to_table: target table
:type to_table: sqlalchemy.sql.elements.quoted_name
:param model_cls: ormar model class
:type model_cls: ormar.models.metaclass.ModelMetaclass
"""
direction = f"{'desc' if condition[0][0] == '-' else ''}"
column_alias = model_cls.get_column_alias(condition[-1])
order = text(f"{alias}_{to_table}.{column_alias} {direction}")
@ -239,6 +368,21 @@ class SqlJoin:
part: str,
model_cls: Type["Model"],
) -> None:
"""
Triggers construction of order bys if they are given.
Otherwise by default each table is sorted by a primary key column asc.
:param alias: alias of current table in join
:type alias: str
:param to_table: target table
:type to_table: sqlalchemy.sql.elements.quoted_name
:param pkname_alias: alias of the primary key column
:type pkname_alias: str
:param part: name of the current relation join
:type part: str
:param model_cls: ormar model class
:type model_cls: Type[Model]
"""
if self.order_columns:
split_order_columns = [
x.split("__") for x in self.order_columns if "__" in x
@ -262,6 +406,22 @@ class SqlJoin:
model_cls: Type["Model"],
part: str,
) -> Tuple[str, str]:
"""
Based on the relation type, name of the relation and previous models and parts
stored in JoinParameters it resolves the current to and from keys, which are
different for ManyToMany relation, ForeignKey and reverse part of relations.
:param join_params: parameters from previous/ current join
:type join_params: JoinParameters
:param is_multi: flag if the relation is of m2m type
:type is_multi: bool
:param model_cls: ormar model class
:type model_cls: Type[Model]
:param part: name of the current relation join
:type part: str
:return: to key and from key
:rtype: Tuple[str, str]
"""
if is_multi:
to_field = join_params.prev_model.get_name()
to_key = model_cls.get_column_alias(to_field)

View File

@ -6,7 +6,17 @@ from fastapi import FastAPI
from starlette.testclient import TestClient
from tests.settings import DATABASE_URL
from tests.test_inheritance_concrete import Category, Subject, metadata, db as database # type: ignore
from tests.test_inheritance_concrete import ( # type: ignore
Category,
Subject,
Person,
Bus,
Truck,
Bus2,
Truck2,
db as database,
metadata,
)
app = FastAPI()
app.state.database = database
@ -37,6 +47,56 @@ async def create_category(category: Category):
return category
@app.post("/buses/", response_model=Bus)
async def create_bus(bus: Bus):
await bus.save()
return bus
@app.get("/buses/{item_id}", response_model=Bus)
async def get_bus(item_id: int):
bus = await Bus.objects.select_related(["owner", "co_owner"]).get(pk=item_id)
return bus
@app.post("/trucks/", response_model=Truck)
async def create_truck(truck: Truck):
await truck.save()
return truck
@app.post("/persons/", response_model=Person)
async def create_person(person: Person):
await person.save()
return person
@app.post("/buses2/", response_model=Bus2)
async def create_bus2(bus: Bus2):
await bus.save()
return bus
@app.post("/buses2/{item_id}/add_coowner/", response_model=Bus2)
async def add_bus_coowner(item_id: int, person: Person):
bus = await Bus2.objects.select_related(["owner", "co_owners"]).get(pk=item_id)
await bus.co_owners.add(person)
return bus
@app.post("/trucks2/", response_model=Truck2)
async def create_truck2(truck: Truck2):
await truck.save()
return truck
@app.post("/trucks2/{item_id}/add_coowner/", response_model=Truck2)
async def add_truck_coowner(item_id: int, person: Person):
truck = await Truck2.objects.select_related(["owner", "co_owners"]).get(pk=item_id)
await truck.co_owners.add(person)
return truck
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
@ -73,3 +133,87 @@ def test_read_main():
assert sub.name == "Bar"
assert sub.category.pk == cat.pk
assert isinstance(sub.updated_date, datetime.datetime)
def test_inheritance_with_relation():
client = TestClient(app)
with client as client:
sam = Person(**client.post("/persons/", json={"name": "Sam"}).json())
joe = Person(**client.post("/persons/", json={"name": "Joe"}).json())
truck_dict = dict(
name="Shelby wanna be",
max_capacity=1400,
owner=sam.dict(),
co_owner=joe.dict(),
)
bus_dict = dict(
name="Unicorn", max_persons=50, owner=sam.dict(), co_owner=joe.dict()
)
unicorn = Bus(**client.post("/buses/", json=bus_dict).json())
shelby = Truck(**client.post("/trucks/", json=truck_dict).json())
assert shelby.name == "Shelby wanna be"
assert shelby.owner.name == "Sam"
assert shelby.co_owner.name == "Joe"
assert shelby.co_owner == joe
assert shelby.max_capacity == 1400
assert unicorn.name == "Unicorn"
assert unicorn.owner == sam
assert unicorn.owner.name == "Sam"
assert unicorn.co_owner.name == "Joe"
assert unicorn.max_persons == 50
unicorn2 = Bus(**client.get(f"/buses/{unicorn.pk}").json())
assert unicorn2.name == "Unicorn"
assert unicorn2.owner == sam
assert unicorn2.owner.name == "Sam"
assert unicorn2.co_owner.name == "Joe"
assert unicorn2.max_persons == 50
def test_inheritance_with_m2m_relation():
client = TestClient(app)
with client as client:
sam = Person(**client.post("/persons/", json={"name": "Sam"}).json())
joe = Person(**client.post("/persons/", json={"name": "Joe"}).json())
alex = Person(**client.post("/persons/", json={"name": "Alex"}).json())
truck_dict = dict(name="Shelby wanna be", max_capacity=2000, owner=sam.dict())
bus_dict = dict(name="Unicorn", max_persons=80, owner=sam.dict())
unicorn = Bus2(**client.post("/buses2/", json=bus_dict).json())
shelby = Truck2(**client.post("/trucks2/", json=truck_dict).json())
unicorn = Bus2(
**client.post(f"/buses2/{unicorn.pk}/add_coowner/", json=joe.dict()).json()
)
unicorn = Bus2(
**client.post(f"/buses2/{unicorn.pk}/add_coowner/", json=alex.dict()).json()
)
assert shelby.name == "Shelby wanna be"
assert shelby.owner.name == "Sam"
assert len(shelby.co_owners) == 0
assert shelby.max_capacity == 2000
assert unicorn.name == "Unicorn"
assert unicorn.owner == sam
assert unicorn.owner.name == "Sam"
assert unicorn.co_owners[0].name == "Joe"
assert unicorn.co_owners[1] == alex
assert unicorn.max_persons == 80
client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=alex.dict())
shelby = Truck2(
**client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=joe.dict()).json()
)
assert shelby.name == "Shelby wanna be"
assert shelby.owner.name == "Sam"
assert len(shelby.co_owners) == 2
assert shelby.co_owners[0] == alex
assert shelby.co_owners[1] == joe
assert shelby.max_capacity == 2000