diff --git a/ormar/queryset/actions/order_action.py b/ormar/queryset/actions/order_action.py index 7330d72..05114bf 100644 --- a/ormar/queryset/actions/order_action.py +++ b/ormar/queryset/actions/order_action.py @@ -78,14 +78,17 @@ class OrderAction(QueryAction): :return: complied and escaped clause :rtype: sqlalchemy.sql.elements.TextClause """ + dialect = self.target_model.Meta.database._backend._dialect + quoter = dialect.identifier_preparer.quote prefix = f"{self.table_prefix}_" if self.table_prefix else "" table_name = self.table.name field_name = self.field_alias if not prefix: - dialect = self.target_model.Meta.database._backend._dialect - table_name = dialect.identifier_preparer.quote(table_name) - field_name = dialect.identifier_preparer.quote(field_name) - return text(f"{prefix}{table_name}" f".{field_name} {self.direction}") + table_name = quoter(table_name) + else: + table_name = quoter(f"{prefix}{table_name}") + field_name = quoter(field_name) + return text(f"{table_name}.{field_name} {self.direction}") def _split_value_into_parts(self, order_str: str) -> None: if order_str.startswith("-"): diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 005eeb2..29d0614 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -92,28 +92,31 @@ class SqlJoin: """ return self.next_model.Meta.table - def _on_clause(self, previous_alias: str, from_clause: str, to_clause: str) -> text: + def _on_clause(self, previous_alias: str, from_table_name:str, from_column_name: str, to_table_name: str, to_column_name: str) -> text: """ Receives aliases and names of both ends of the join and combines them into one text clause used in joins. :param previous_alias: alias of previous table :type previous_alias: str - :param from_clause: from table name - :type from_clause: str - :param to_clause: to table name - :type to_clause: str + :param from_table_name: from table name + :type from_table_name: str + :param from_column_name: from column name + :type from_column_name: str + :param to_table_name: to table name + :type to_table_name: str + :param to_column_name: to column name + :type to_column_name: str :return: clause combining all strings :rtype: sqlalchemy.text """ - left_part = f"{self.next_alias}_{to_clause}" + dialect = self.main_model.Meta.database._backend._dialect + quoter = dialect.identifier_preparer.quote + left_part = f"{quoter(f'{self.next_alias}_{to_table_name}')}.{quoter(to_column_name)}" if not previous_alias: - dialect = self.main_model.Meta.database._backend._dialect - table, column = from_clause.split(".") - quotter = dialect.identifier_preparer.quote - right_part = f"{quotter(table)}.{quotter(column)}" + right_part = f"{quoter(from_table_name)}.{quoter(from_column_name)}" else: - right_part = f"{previous_alias}_{from_clause}" + right_part = f"{quoter(f'{previous_alias}_{from_table_name}')}.{from_column_name}" return text(f"{left_part}={right_part}") @@ -278,8 +281,10 @@ class SqlJoin: on_clause = self._on_clause( previous_alias=self.own_alias, - from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}", - to_clause=f"{self.to_table.name}.{to_key}", + from_table_name=self.target_field.owner.Meta.tablename, + from_column_name=from_key, + to_table_name=self.to_table.name, + to_column_name=to_key, ) target_table = self.alias_manager.prefixed_table_name( self.next_alias, self.to_table diff --git a/tests/test_model_definition/test_field_quoting.py b/tests/test_model_definition/test_field_quoting.py new file mode 100644 index 0000000..4cf6568 --- /dev/null +++ b/tests/test_model_definition/test_field_quoting.py @@ -0,0 +1,99 @@ +import asyncio +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class SchoolClass(ormar.Model): + class Meta: + tablename = "app.schoolclasses" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Category(ormar.Model): + class Meta: + tablename = "app.categories" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Student(ormar.Model): + class Meta: + tablename = "app.students" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + gpa: float = ormar.Float() + schoolclass: Optional[SchoolClass] = ormar.ForeignKey(SchoolClass, related_name="students") + category: Optional[Category] = ormar.ForeignKey(Category, nullable=True, related_name="students") + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +async def create_data(): + class1 = await SchoolClass.objects.create(name="Math") + class2 = await SchoolClass.objects.create(name="Logic") + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1, gpa=3.2) + await Student.objects.create(name="Judy", category=category2, schoolclass=class1, gpa=2.6) + await Student.objects.create(name="Jack", category=category2, schoolclass=class2, gpa=3.8) + + +@pytest.mark.asyncio +async def test_quotes_left_join(): + async with database: + async with database.transaction(force_rollback=True): + await create_data() + students = await Student.objects.filter( + (Student.schoolclass.name == "Math") | (Student.category.name == "Foreign") + ).all() + for student in students: + assert student.schoolclass.name == "Math" or student.category.name == "Foreign" + + +@pytest.mark.asyncio +async def test_quotes_reverse_join(): + async with database: + async with database.transaction(force_rollback=True): + await create_data() + schoolclasses = await SchoolClass.objects.filter(students__gpa__gt=3).all() + for schoolclass in schoolclasses: + for student in schoolclass.students: + assert student.gpa > 3 + + +@pytest.mark.asyncio +async def test_quotes_deep_join(): + async with database: + async with database.transaction(force_rollback=True): + await create_data() + schoolclasses = await SchoolClass.objects.filter(students__category__name="Domestic").all() + for schoolclass in schoolclasses: + for student in schoolclass.students: + assert student.category.name == "Domestic" +