fix: Table names and columns not being quoted properly (#789) (#1174)

* Split _on_clause() from_clause parameter into separate table and column strings

* Fix quoting for left side and order action, add test

* Improve join quoting, add more quoting tests
This commit is contained in:
Peter DeVita
2023-08-30 11:48:32 -04:00
committed by GitHub
parent a2dfac6c21
commit 242114ee91
3 changed files with 124 additions and 17 deletions

View File

@ -78,14 +78,17 @@ class OrderAction(QueryAction):
:return: complied and escaped clause :return: complied and escaped clause
:rtype: sqlalchemy.sql.elements.TextClause :rtype: sqlalchemy.sql.elements.TextClause
""" """
dialect = self.target_model.Meta.database._backend._dialect
quoter = dialect.identifier_preparer.quote
prefix = f"{self.table_prefix}_" if self.table_prefix else "" prefix = f"{self.table_prefix}_" if self.table_prefix else ""
table_name = self.table.name table_name = self.table.name
field_name = self.field_alias field_name = self.field_alias
if not prefix: if not prefix:
dialect = self.target_model.Meta.database._backend._dialect table_name = quoter(table_name)
table_name = dialect.identifier_preparer.quote(table_name) else:
field_name = dialect.identifier_preparer.quote(field_name) table_name = quoter(f"{prefix}{table_name}")
return text(f"{prefix}{table_name}" f".{field_name} {self.direction}") field_name = quoter(field_name)
return text(f"{table_name}.{field_name} {self.direction}")
def _split_value_into_parts(self, order_str: str) -> None: def _split_value_into_parts(self, order_str: str) -> None:
if order_str.startswith("-"): if order_str.startswith("-"):

View File

@ -92,28 +92,31 @@ class SqlJoin:
""" """
return self.next_model.Meta.table return self.next_model.Meta.table
def _on_clause(self, previous_alias: str, from_clause: str, to_clause: str) -> text: def _on_clause(self, previous_alias: str, from_table_name:str, from_column_name: str, to_table_name: str, to_column_name: str) -> text:
""" """
Receives aliases and names of both ends of the join and combines them Receives aliases and names of both ends of the join and combines them
into one text clause used in joins. into one text clause used in joins.
:param previous_alias: alias of previous table :param previous_alias: alias of previous table
:type previous_alias: str :type previous_alias: str
:param from_clause: from table name :param from_table_name: from table name
:type from_clause: str :type from_table_name: str
:param to_clause: to table name :param from_column_name: from column name
:type to_clause: str :type from_column_name: str
:param to_table_name: to table name
:type to_table_name: str
:param to_column_name: to column name
:type to_column_name: str
:return: clause combining all strings :return: clause combining all strings
:rtype: sqlalchemy.text :rtype: sqlalchemy.text
""" """
left_part = f"{self.next_alias}_{to_clause}" dialect = self.main_model.Meta.database._backend._dialect
quoter = dialect.identifier_preparer.quote
left_part = f"{quoter(f'{self.next_alias}_{to_table_name}')}.{quoter(to_column_name)}"
if not previous_alias: if not previous_alias:
dialect = self.main_model.Meta.database._backend._dialect right_part = f"{quoter(from_table_name)}.{quoter(from_column_name)}"
table, column = from_clause.split(".")
quotter = dialect.identifier_preparer.quote
right_part = f"{quotter(table)}.{quotter(column)}"
else: else:
right_part = f"{previous_alias}_{from_clause}" right_part = f"{quoter(f'{previous_alias}_{from_table_name}')}.{from_column_name}"
return text(f"{left_part}={right_part}") return text(f"{left_part}={right_part}")
@ -278,8 +281,10 @@ class SqlJoin:
on_clause = self._on_clause( on_clause = self._on_clause(
previous_alias=self.own_alias, previous_alias=self.own_alias,
from_clause=f"{self.target_field.owner.Meta.tablename}.{from_key}", from_table_name=self.target_field.owner.Meta.tablename,
to_clause=f"{self.to_table.name}.{to_key}", from_column_name=from_key,
to_table_name=self.to_table.name,
to_column_name=to_key,
) )
target_table = self.alias_manager.prefixed_table_name( target_table = self.alias_manager.prefixed_table_name(
self.next_alias, self.to_table self.next_alias, self.to_table

View File

@ -0,0 +1,99 @@
import asyncio
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class SchoolClass(ormar.Model):
class Meta:
tablename = "app.schoolclasses"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Category(ormar.Model):
class Meta:
tablename = "app.categories"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Student(ormar.Model):
class Meta:
tablename = "app.students"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
gpa: float = ormar.Float()
schoolclass: Optional[SchoolClass] = ormar.ForeignKey(SchoolClass, related_name="students")
category: Optional[Category] = ormar.ForeignKey(Category, nullable=True, related_name="students")
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
async def create_data():
class1 = await SchoolClass.objects.create(name="Math")
class2 = await SchoolClass.objects.create(name="Logic")
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1, gpa=3.2)
await Student.objects.create(name="Judy", category=category2, schoolclass=class1, gpa=2.6)
await Student.objects.create(name="Jack", category=category2, schoolclass=class2, gpa=3.8)
@pytest.mark.asyncio
async def test_quotes_left_join():
async with database:
async with database.transaction(force_rollback=True):
await create_data()
students = await Student.objects.filter(
(Student.schoolclass.name == "Math") | (Student.category.name == "Foreign")
).all()
for student in students:
assert student.schoolclass.name == "Math" or student.category.name == "Foreign"
@pytest.mark.asyncio
async def test_quotes_reverse_join():
async with database:
async with database.transaction(force_rollback=True):
await create_data()
schoolclasses = await SchoolClass.objects.filter(students__gpa__gt=3).all()
for schoolclass in schoolclasses:
for student in schoolclass.students:
assert student.gpa > 3
@pytest.mark.asyncio
async def test_quotes_deep_join():
async with database:
async with database.transaction(force_rollback=True):
await create_data()
schoolclasses = await SchoolClass.objects.filter(students__category__name="Domestic").all()
for schoolclass in schoolclasses:
for student in schoolclass.students:
assert student.category.name == "Domestic"