From a914be67e258c1ccfc4467a5a465e795ccbca154 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 4 Jan 2021 12:43:00 +0100 Subject: [PATCH] add fastapi tests with inheritance and relations, more docstrings in queryset --- ormar/queryset/__init__.py | 3 + ormar/queryset/clause.py | 93 ++++++++++++ ormar/queryset/join.py | 166 ++++++++++++++++++++- tests/test_inheritance_concrete_fastapi.py | 146 +++++++++++++++++- 4 files changed, 404 insertions(+), 4 deletions(-) diff --git a/ormar/queryset/__init__.py b/ormar/queryset/__init__.py index 2bc0a6d..8528b05 100644 --- a/ormar/queryset/__init__.py +++ b/ormar/queryset/__init__.py @@ -1,3 +1,6 @@ +""" +Contains QuerySet and different Query classes to allow for constructing of sql queries. +""" from ormar.queryset.filter_query import FilterQuery from ormar.queryset.limit_query import LimitQuery from ormar.queryset.offset_query import OffsetQuery diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index c55eed8..d0963cf 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -29,6 +29,10 @@ ESCAPE_CHARACTERS = ["%", "_"] class QueryClause: + """ + Constructs where clauses from strings passed as arguments + """ + def __init__( self, model_cls: Type["Model"], filter_clauses: List, select_related: List, ) -> None: @@ -42,7 +46,16 @@ class QueryClause: def filter( # noqa: A003 self, **kwargs: Any ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: + """ + Main external access point that processes the clauses into sqlalchemy text + clauses and updates select_related list with implicit related tables + mentioned in select_related strings but not included in select_related. + :param kwargs: key, value pair with column names and values + :type kwargs: Any + :return: Tuple with list of where clauses and updated select_related list + :rtype: Tuple[List[sqlalchemy.sql.elements.TextClause], List[str]] + """ if kwargs.get("pk"): pk_name = self.model_cls.get_column_alias(self.model_cls.Meta.pkname) kwargs[pk_name] = kwargs.pop("pk") @@ -54,6 +67,16 @@ class QueryClause: def _populate_filter_clauses( self, **kwargs: Any ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: + """ + Iterates all clauses and extracts used operator and field from related + models if needed. Based on the chain of related names the target table + is determined and the final clause is escaped if needed and compiled. + + :param kwargs: key, value pair with column names and values + :type kwargs: Any + :return: Tuple with list of where clauses and updated select_related list + :rtype: Tuple[List[sqlalchemy.sql.elements.TextClause], List[str]] + """ filter_clauses = self.filter_clauses select_related = list(self._select_related) @@ -100,6 +123,24 @@ class QueryClause: table: sqlalchemy.Table, table_prefix: str, ) -> sqlalchemy.sql.expression.TextClause: + """ + Escapes characters if it's required. + Substitutes values of the models if value is a ormar Model with its pk value. + Compiles the clause. + + :param value: value of the filter + :type value: Any + :param op: filter operator + :type op: str + :param column: column on which filter should be applied + :type column: sqlalchemy.sql.schema.Column + :param table: table on which filter should be applied + :type table: sqlalchemy.sql.schema.Table + :param table_prefix: prefix from AliasManager + :type table_prefix: str + :return: complied and escaped clause + :rtype: sqlalchemy.sql.elements.TextClause + """ value, has_escaped_character = self._escape_characters_in_clause(op, value) if isinstance(value, ormar.Model): @@ -119,7 +160,21 @@ class QueryClause: def _determine_filter_target_table( self, related_parts: List[str], select_related: List[str] ) -> Tuple[List[str], str, Type["Model"]]: + """ + Adds related strings to select_related list otherwise the clause would fail as + the required columns would not be present. That means that select_related + list is filled with missing values present in filters. + Walks the relation to retrieve the actual model on which the clause should be + constructed, extracts alias based on last relation leading to target model. + + :param related_parts: list of split parts of related string + :type related_parts: List[str] + :param select_related: list of related models + :type select_related: List[str] + :return: list of related models, table_prefix, final model class + :rtype: Tuple[List[str], str, Type[Model]] + """ table_prefix = "" model_cls = self.model_cls select_related = [relation for relation in select_related] @@ -152,6 +207,23 @@ class QueryClause: table_prefix: str, modifiers: Dict, ) -> sqlalchemy.sql.expression.TextClause: + """ + Compiles the clause to str using appropriate database dialect, replace columns + names with aliased names and converts it back to TextClause. + + :param clause: original not compiled clause + :type clause: sqlalchemy.sql.elements.BinaryExpression + :param column: column on which filter should be applied + :type column: sqlalchemy.sql.schema.Column + :param table: table on which filter should be applied + :type table: sqlalchemy.sql.schema.Table + :param table_prefix: prefix from AliasManager + :type table_prefix: str + :param modifiers: sqlalchemy modifiers - used only to escape chars here + :type modifiers: Dict[str, NoneType] + :return: compiled and escaped clause + :rtype: sqlalchemy.sql.elements.TextClause + """ for modifier, modifier_value in modifiers.items(): clause.modifiers[modifier] = modifier_value @@ -169,6 +241,19 @@ class QueryClause: @staticmethod def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]: + """ + Escapes the special characters ["%", "_"] if needed. + Adds `%` for `like` queries. + + :raises: QueryDefinitionError if contains or icontains is used with + ormar model instance + :param op: operator used in query + :type op: str + :param value: value of the filter + :type value: Any + :return: escaped value and flag if escaping is needed + :rtype: Tuple[Any, bool] + """ has_escaped_character = False if op not in [ @@ -202,6 +287,14 @@ class QueryClause: def _extract_operator_field_and_related( parts: List[str], ) -> Tuple[str, str, Optional[List]]: + """ + Splits filter query key and extracts required parts. + + :param parts: split filter query key + :type parts: List[str] + :return: operator, field_name, list of related parts + :rtype: Tuple[str, str, Optional[List]] + """ if parts[-1] in FILTER_OPERATORS: op = parts[-1] field_name = parts[-2] diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 7726a8c..d255f1c 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -22,6 +22,10 @@ if TYPE_CHECKING: # pragma no cover class JoinParameters(NamedTuple): + """ + Named tuple that holds set of parameters passed during join construction. + """ + prev_model: Type["Model"] previous_alias: str from_table: str @@ -48,13 +52,36 @@ class SqlJoin: self.sorted_orders = sorted_orders @staticmethod - def relation_manager(model_cls: Type["Model"]) -> AliasManager: + def alias_manager(model_cls: Type["Model"]) -> AliasManager: + """ + Shortcut for ormars model AliasManager stored on Meta. + + :param model_cls: ormar Model class + :type model_cls: Type[Model] + :return: alias manager from model's Meta + :rtype: AliasManager + """ return model_cls.Meta.alias_manager @staticmethod def on_clause( previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: + """ + Receives aliases and names of both ends of the join and combines them + into one text clause used in joins. + + :param previous_alias: alias of previous table + :type previous_alias: str + :param alias: alias of current table + :type alias: str + :param from_clause: from table name + :type from_clause: str + :param to_clause: to table name + :type to_clause: str + :return: clause combining all strings + :rtype: sqlalchemy.text + """ left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") @@ -66,6 +93,20 @@ class SqlJoin: exclude_fields: Optional[Union[Set, Dict]], nested_name: str, ) -> Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]]: + """ + Extract nested fields and exclude_fields if applicable. + + :param model_cls: ormar model class + :type model_cls: Type["Model"] + :param fields: fields to include + :type fields: Optional[Union[Set, Dict]] + :param exclude_fields: fields to exclude + :type exclude_fields: Optional[Union[Set, Dict]] + :param nested_name: name of the nested field + :type nested_name: str + :return: updated exclude and include fields from nested objects + :rtype: Tuple[Optional[Union[Dict, Set]], Optional[Union[Dict, Set]]] + """ fields = model_cls.get_included(fields, nested_name) exclude_fields = model_cls.get_excluded(exclude_fields, nested_name) return fields, exclude_fields @@ -73,7 +114,19 @@ class SqlJoin: def build_join( # noqa: CCR001 self, item: str, join_parameters: JoinParameters ) -> Tuple[List, sqlalchemy.sql.select, List, OrderedDict]: + """ + Main external access point for building a join. + Splits the join definition, updates fields and exclude_fields if needed, + handles switching to through models for m2m relations, returns updated lists of + used_aliases and sort_orders. + :param item: string with join definition + :type item: str + :param join_parameters: parameters from previous/ current join + :type join_parameters: JoinParameters + :return: list of used aliases, select from, list of aliased columns, sort orders + :rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict] + """ fields = self.fields exclude_fields = self.exclude_fields @@ -129,6 +182,23 @@ class SqlJoin: exclude_fields: Optional[Union[Set, Dict]], is_multi: bool = False, ) -> JoinParameters: + """ + Updates used_aliases to not join multiple times to the same table. + Updates join parameters with new values. + + :param part: part of the join str definition + :type part: str + :param join_params: parameters from previous/ current join + :type join_params: JoinParameters + :param fields: fields to include + :type fields: Optional[Union[Set, Dict]] + :param exclude_fields: fields to exclude + :type exclude_fields: Optional[Union[Set, Dict]] + :param is_multi: flag if the relation is m2m + :type is_multi: bool + :return: updated join parameters + :rtype: ormar.queryset.join.JoinParameters + """ if is_multi: model_cls = join_params.model_cls.Meta.model_fields[part].through else: @@ -164,6 +234,34 @@ class SqlJoin: fields: Optional[Union[Set, Dict]], exclude_fields: Optional[Union[Set, Dict]], ) -> None: + """ + Resolves to and from column names and table names. + + Produces on_clause. + + Performs actual join updating select_from parameter. + + Adds aliases of required column to list of columns to include in query. + + Updates the used aliases list directly. + + Process order_by causes for non m2m relations. + + :param join_params: parameters from previous/ current join + :type join_params: JoinParameters + :param is_multi: flag if it's m2m relation + :type is_multi: bool + :param model_cls: + :type model_cls: ormar.models.metaclass.ModelMetaclass + :param part: name of the field used in join + :type part: str + :param alias: alias of the current join + :type alias: str + :param fields: fields to include + :type fields: Optional[Union[Set, Dict]] + :param exclude_fields: fields to exclude + :type exclude_fields: Optional[Union[Set, Dict]] + """ to_table = model_cls.Meta.table.name to_key, from_key = self.get_to_and_from_keys( join_params, is_multi, model_cls, part @@ -175,7 +273,7 @@ class SqlJoin: from_clause=f"{join_params.from_table}.{from_key}", to_clause=f"{to_table}.{to_key}", ) - target_table = self.relation_manager(model_cls).prefixed_table_name( + target_table = self.alias_manager(model_cls).prefixed_table_name( alias, to_table ) self.select_from = sqlalchemy.sql.outerjoin( @@ -199,13 +297,21 @@ class SqlJoin: use_alias=True, ) self.columns.extend( - self.relation_manager(model_cls).prefixed_columns( + self.alias_manager(model_cls).prefixed_columns( alias, model_cls.Meta.table, self_related_fields ) ) self.used_aliases.append(alias) def _switch_many_to_many_order_columns(self, part: str, new_part: str) -> None: + """ + Substitutes the name of the relation with actual model name in m2m order bys. + + :param part: name of the field with relation + :type part: str + :param new_part: name of the target model + :type new_part: str + """ if self.order_columns: split_order_columns = [ x.split("__") for x in self.order_columns if "__" in x @@ -219,6 +325,16 @@ class SqlJoin: @staticmethod def _check_if_condition_apply(condition: List, part: str) -> bool: + """ + Checks filter conditions to find if they apply to current join. + + :param condition: list of parts of condition split by '__' + :type condition: List[str] + :param part: name of the current relation join. + :type part: str + :return: result of the check + :rtype: bool + """ return len(condition) >= 2 and ( condition[-2] == part or condition[-2][1:] == part ) @@ -226,6 +342,19 @@ class SqlJoin: def set_aliased_order_by( self, condition: List[str], alias: str, to_table: str, model_cls: Type["Model"], ) -> None: + """ + Substitute hyphens ('-') with descending order. + Construct actual sqlalchemy text clause using aliased table and column name. + + :param condition: list of parts of a current condition split by '__' + :type condition: List[str] + :param alias: alias of the table in current join + :type alias: str + :param to_table: target table + :type to_table: sqlalchemy.sql.elements.quoted_name + :param model_cls: ormar model class + :type model_cls: ormar.models.metaclass.ModelMetaclass + """ direction = f"{'desc' if condition[0][0] == '-' else ''}" column_alias = model_cls.get_column_alias(condition[-1]) order = text(f"{alias}_{to_table}.{column_alias} {direction}") @@ -239,6 +368,21 @@ class SqlJoin: part: str, model_cls: Type["Model"], ) -> None: + """ + Triggers construction of order bys if they are given. + Otherwise by default each table is sorted by a primary key column asc. + + :param alias: alias of current table in join + :type alias: str + :param to_table: target table + :type to_table: sqlalchemy.sql.elements.quoted_name + :param pkname_alias: alias of the primary key column + :type pkname_alias: str + :param part: name of the current relation join + :type part: str + :param model_cls: ormar model class + :type model_cls: Type[Model] + """ if self.order_columns: split_order_columns = [ x.split("__") for x in self.order_columns if "__" in x @@ -262,6 +406,22 @@ class SqlJoin: model_cls: Type["Model"], part: str, ) -> Tuple[str, str]: + """ + Based on the relation type, name of the relation and previous models and parts + stored in JoinParameters it resolves the current to and from keys, which are + different for ManyToMany relation, ForeignKey and reverse part of relations. + + :param join_params: parameters from previous/ current join + :type join_params: JoinParameters + :param is_multi: flag if the relation is of m2m type + :type is_multi: bool + :param model_cls: ormar model class + :type model_cls: Type[Model] + :param part: name of the current relation join + :type part: str + :return: to key and from key + :rtype: Tuple[str, str] + """ if is_multi: to_field = join_params.prev_model.get_name() to_key = model_cls.get_column_alias(to_field) diff --git a/tests/test_inheritance_concrete_fastapi.py b/tests/test_inheritance_concrete_fastapi.py index f3e03f6..217fe3c 100644 --- a/tests/test_inheritance_concrete_fastapi.py +++ b/tests/test_inheritance_concrete_fastapi.py @@ -6,7 +6,17 @@ from fastapi import FastAPI from starlette.testclient import TestClient from tests.settings import DATABASE_URL -from tests.test_inheritance_concrete import Category, Subject, metadata, db as database # type: ignore +from tests.test_inheritance_concrete import ( # type: ignore + Category, + Subject, + Person, + Bus, + Truck, + Bus2, + Truck2, + db as database, + metadata, +) app = FastAPI() app.state.database = database @@ -37,6 +47,56 @@ async def create_category(category: Category): return category +@app.post("/buses/", response_model=Bus) +async def create_bus(bus: Bus): + await bus.save() + return bus + + +@app.get("/buses/{item_id}", response_model=Bus) +async def get_bus(item_id: int): + bus = await Bus.objects.select_related(["owner", "co_owner"]).get(pk=item_id) + return bus + + +@app.post("/trucks/", response_model=Truck) +async def create_truck(truck: Truck): + await truck.save() + return truck + + +@app.post("/persons/", response_model=Person) +async def create_person(person: Person): + await person.save() + return person + + +@app.post("/buses2/", response_model=Bus2) +async def create_bus2(bus: Bus2): + await bus.save() + return bus + + +@app.post("/buses2/{item_id}/add_coowner/", response_model=Bus2) +async def add_bus_coowner(item_id: int, person: Person): + bus = await Bus2.objects.select_related(["owner", "co_owners"]).get(pk=item_id) + await bus.co_owners.add(person) + return bus + + +@app.post("/trucks2/", response_model=Truck2) +async def create_truck2(truck: Truck2): + await truck.save() + return truck + + +@app.post("/trucks2/{item_id}/add_coowner/", response_model=Truck2) +async def add_truck_coowner(item_id: int, person: Person): + truck = await Truck2.objects.select_related(["owner", "co_owners"]).get(pk=item_id) + await truck.co_owners.add(person) + return truck + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -73,3 +133,87 @@ def test_read_main(): assert sub.name == "Bar" assert sub.category.pk == cat.pk assert isinstance(sub.updated_date, datetime.datetime) + + +def test_inheritance_with_relation(): + client = TestClient(app) + with client as client: + sam = Person(**client.post("/persons/", json={"name": "Sam"}).json()) + joe = Person(**client.post("/persons/", json={"name": "Joe"}).json()) + + truck_dict = dict( + name="Shelby wanna be", + max_capacity=1400, + owner=sam.dict(), + co_owner=joe.dict(), + ) + bus_dict = dict( + name="Unicorn", max_persons=50, owner=sam.dict(), co_owner=joe.dict() + ) + unicorn = Bus(**client.post("/buses/", json=bus_dict).json()) + shelby = Truck(**client.post("/trucks/", json=truck_dict).json()) + + assert shelby.name == "Shelby wanna be" + assert shelby.owner.name == "Sam" + assert shelby.co_owner.name == "Joe" + assert shelby.co_owner == joe + assert shelby.max_capacity == 1400 + + assert unicorn.name == "Unicorn" + assert unicorn.owner == sam + assert unicorn.owner.name == "Sam" + assert unicorn.co_owner.name == "Joe" + assert unicorn.max_persons == 50 + + unicorn2 = Bus(**client.get(f"/buses/{unicorn.pk}").json()) + assert unicorn2.name == "Unicorn" + assert unicorn2.owner == sam + assert unicorn2.owner.name == "Sam" + assert unicorn2.co_owner.name == "Joe" + assert unicorn2.max_persons == 50 + + +def test_inheritance_with_m2m_relation(): + client = TestClient(app) + with client as client: + sam = Person(**client.post("/persons/", json={"name": "Sam"}).json()) + joe = Person(**client.post("/persons/", json={"name": "Joe"}).json()) + alex = Person(**client.post("/persons/", json={"name": "Alex"}).json()) + + truck_dict = dict(name="Shelby wanna be", max_capacity=2000, owner=sam.dict()) + bus_dict = dict(name="Unicorn", max_persons=80, owner=sam.dict()) + + unicorn = Bus2(**client.post("/buses2/", json=bus_dict).json()) + shelby = Truck2(**client.post("/trucks2/", json=truck_dict).json()) + + unicorn = Bus2( + **client.post(f"/buses2/{unicorn.pk}/add_coowner/", json=joe.dict()).json() + ) + unicorn = Bus2( + **client.post(f"/buses2/{unicorn.pk}/add_coowner/", json=alex.dict()).json() + ) + + assert shelby.name == "Shelby wanna be" + assert shelby.owner.name == "Sam" + assert len(shelby.co_owners) == 0 + assert shelby.max_capacity == 2000 + + assert unicorn.name == "Unicorn" + assert unicorn.owner == sam + assert unicorn.owner.name == "Sam" + assert unicorn.co_owners[0].name == "Joe" + assert unicorn.co_owners[1] == alex + assert unicorn.max_persons == 80 + + client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=alex.dict()) + + shelby = Truck2( + **client.post(f"/trucks2/{shelby.pk}/add_coowner/", json=joe.dict()).json() + ) + + assert shelby.name == "Shelby wanna be" + assert shelby.owner.name == "Sam" + assert len(shelby.co_owners) == 2 + assert shelby.co_owners[0] == alex + assert shelby.co_owners[1] == joe + assert shelby.max_capacity == 2000