Skip to content

Commit

Permalink
Revert "[SPARK-48172][SQL] Fix escaping issues in JDBC Dialects"
Browse files Browse the repository at this point in the history
This reverts commit 47006a4.
  • Loading branch information
yaooqinn committed May 15, 2024
1 parent 3ae78c4 commit 4ff5ca8
Show file tree
Hide file tree
Showing 12 changed files with 12 additions and 291 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,6 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
connection.prepareStatement(
"CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)")
.executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
""".stripMargin
).executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,6 @@ abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite {
.executeUpdate()
connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)")
.executeUpdate()

connection.prepareStatement(
s"""
|INSERT INTO pattern_testing_table VALUES
|('special_character_quote\\'_present'),
|('special_character_quote_not_present'),
|('special_character_percent%_present'),
|('special_character_percent_not_present'),
|('special_character_underscore_present'),
|('special_character_underscorenot_present')
""".stripMargin).executeUpdate()
}

def tablePreparation(connection: Connection): Unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,6 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
connection.prepareStatement(
"CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)")
.executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
""".stripMargin
).executeUpdate()
}

override def notSupportsTableComment: Boolean = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
connection.prepareStatement(
"CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," +
" bonus DOUBLE)").executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
""".stripMargin
).executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,6 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
connection.prepareStatement(
"CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," +
" bonus BINARY_DOUBLE)").executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
""".stripMargin
).executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,6 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
connection.prepareStatement(
"CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," +
" bonus double precision)").executeUpdate()
connection.prepareStatement(
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
""".stripMargin
).executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,235 +359,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(scan.schema.names.sameElements(Seq(col)))
}

test("SPARK-48172: Test CONTAINS") {
  // Verify CONTAINS pushdown against pattern_testing_table: special characters
  // (quote, percent, underscore) must survive the dialect's LIKE-pattern escaping.
  val df1 = spark.sql(
    s"""
       |SELECT * FROM $catalogName.pattern_testing_table
       |WHERE contains(pattern_testing_col, 'quote\\'')""".stripMargin)
  df1.explain("formatted")
  val rows1 = df1.collect()
  assert(rows1.length === 1)
  assert(rows1(0).getString(0) === "special_character_quote'_present")

  val df2 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE contains(pattern_testing_col, 'percent%')""".stripMargin)
  val rows2 = df2.collect()
  assert(rows2.length === 1)
  assert(rows2(0).getString(0) === "special_character_percent%_present")

  val df3 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE contains(pattern_testing_col, 'underscore_')""".stripMargin)
  val rows3 = df3.collect()
  assert(rows3.length === 1)
  assert(rows3(0).getString(0) === "special_character_underscore_present")

  val df4 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE contains(pattern_testing_col, 'character')
         |ORDER BY pattern_testing_col""".stripMargin)
  val rows4 = df4.collect()
  // All six seeded rows contain 'character'; the previous `=== 1` contradicted
  // the six positional assertions below.
  assert(rows4.length === 6)
  assert(rows4(0).getString(0) === "special_character_percent%_present")
  assert(rows4(1).getString(0) === "special_character_percent_not_present")
  assert(rows4(2).getString(0) === "special_character_quote'_present")
  assert(rows4(3).getString(0) === "special_character_quote_not_present")
  assert(rows4(4).getString(0) === "special_character_underscore_present")
  assert(rows4(5).getString(0) === "special_character_underscorenot_present")
}

test("SPARK-48172: Test ENDSWITH") {
  // Verify ENDSWITH pushdown: suffixes containing quote/percent/underscore must
  // be escaped correctly by the dialect before being embedded in a LIKE pattern.
  val df1 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE endswith(pattern_testing_col, 'quote\\'_present')""".stripMargin)
  val rows1 = df1.collect()
  assert(rows1.length === 1)
  assert(rows1(0).getString(0) === "special_character_quote'_present")

  val df2 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE endswith(pattern_testing_col, 'percent%_present')""".stripMargin)
  val rows2 = df2.collect()
  assert(rows2.length === 1)
  assert(rows2(0).getString(0) === "special_character_percent%_present")

  val df3 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE endswith(pattern_testing_col, 'underscore_present')""".stripMargin)
  val rows3 = df3.collect()
  assert(rows3.length === 1)
  assert(rows3(0).getString(0) === "special_character_underscore_present")

  val df4 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE endswith(pattern_testing_col, 'present')
         |ORDER BY pattern_testing_col""".stripMargin)
  val rows4 = df4.collect()
  // All six seeded rows end with 'present'; the previous `=== 1` contradicted
  // the six positional assertions below.
  assert(rows4.length === 6)
  assert(rows4(0).getString(0) === "special_character_percent%_present")
  assert(rows4(1).getString(0) === "special_character_percent_not_present")
  assert(rows4(2).getString(0) === "special_character_quote'_present")
  assert(rows4(3).getString(0) === "special_character_quote_not_present")
  assert(rows4(4).getString(0) === "special_character_underscore_present")
  assert(rows4(5).getString(0) === "special_character_underscorenot_present")
}

test("SPARK-48172: Test STARTSWITH") {
  // Verify STARTSWITH pushdown: prefixes containing quote/percent/underscore
  // must be escaped correctly by the dialect before being embedded in a LIKE pattern.
  val df1 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE startswith(pattern_testing_col, 'special_character_quote\\'')""".stripMargin)
  val rows1 = df1.collect()
  assert(rows1.length === 1)
  assert(rows1(0).getString(0) === "special_character_quote'_present")

  val df2 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE startswith(pattern_testing_col, 'special_character_percent%')""".stripMargin)
  val rows2 = df2.collect()
  assert(rows2.length === 1)
  assert(rows2(0).getString(0) === "special_character_percent%_present")

  val df3 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE startswith(pattern_testing_col, 'special_character_underscore_')""".stripMargin)
  val rows3 = df3.collect()
  assert(rows3.length === 1)
  assert(rows3(0).getString(0) === "special_character_underscore_present")

  val df4 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE startswith(pattern_testing_col, 'special_character')
         |ORDER BY pattern_testing_col""".stripMargin)
  val rows4 = df4.collect()
  // All six seeded rows start with 'special_character'; the previous `=== 1`
  // contradicted the six positional assertions below.
  assert(rows4.length === 6)
  assert(rows4(0).getString(0) === "special_character_percent%_present")
  assert(rows4(1).getString(0) === "special_character_percent_not_present")
  assert(rows4(2).getString(0) === "special_character_quote'_present")
  assert(rows4(3).getString(0) === "special_character_quote_not_present")
  assert(rows4(4).getString(0) === "special_character_underscore_present")
  assert(rows4(5).getString(0) === "special_character_underscorenot_present")
}

test("SPARK-48172: Test LIKE") {
  // LIKE patterns map to different pushed-down predicates depending on where
  // the wildcards sit: '%x%' -> contains, 'x%' -> startsWith, '%x' -> endsWith.
  // Special characters in the literal portion must be escaped per dialect.

  // '%...%' patterns: map to contains
  val df1 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE pattern_testing_col LIKE '%quote\\'%'""".stripMargin)
  val rows1 = df1.collect()
  assert(rows1.length === 1)
  assert(rows1(0).getString(0) === "special_character_quote'_present")

  val df2 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE pattern_testing_col LIKE '%percent\\%%'""".stripMargin)
  val rows2 = df2.collect()
  assert(rows2.length === 1)
  assert(rows2(0).getString(0) === "special_character_percent%_present")

  val df3 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE pattern_testing_col LIKE '%underscore\\_%'""".stripMargin)
  val rows3 = df3.collect()
  assert(rows3.length === 1)
  assert(rows3(0).getString(0) === "special_character_underscore_present")

  val df4 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE pattern_testing_col LIKE '%character%'
         |ORDER BY pattern_testing_col""".stripMargin)
  val rows4 = df4.collect()
  // All six seeded rows match '%character%'; the previous `=== 1` contradicted
  // the six positional assertions below.
  assert(rows4.length === 6)
  assert(rows4(0).getString(0) === "special_character_percent%_present")
  assert(rows4(1).getString(0) === "special_character_percent_not_present")
  assert(rows4(2).getString(0) === "special_character_quote'_present")
  assert(rows4(3).getString(0) === "special_character_quote_not_present")
  assert(rows4(4).getString(0) === "special_character_underscore_present")
  assert(rows4(5).getString(0) === "special_character_underscorenot_present")

  // 'x%' patterns: map to startsWith
  val df5 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE pattern_testing_col LIKE 'special_character_quote\\'%'""".stripMargin)
  val rows5 = df5.collect()
  assert(rows5.length === 1)
  assert(rows5(0).getString(0) === "special_character_quote'_present")

  val df6 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE pattern_testing_col LIKE 'special_character_percent\\%%'""".stripMargin)
  val rows6 = df6.collect()
  assert(rows6.length === 1)
  assert(rows6(0).getString(0) === "special_character_percent%_present")

  val df7 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE pattern_testing_col LIKE 'special_character_underscore\\_%'""".stripMargin)
  val rows7 = df7.collect()
  assert(rows7.length === 1)
  assert(rows7(0).getString(0) === "special_character_underscore_present")

  val df8 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE pattern_testing_col LIKE 'special_character%'
         |ORDER BY pattern_testing_col""".stripMargin)
  val rows8 = df8.collect()
  // All six seeded rows match 'special_character%'; see note on rows4 above.
  assert(rows8.length === 6)
  assert(rows8(0).getString(0) === "special_character_percent%_present")
  assert(rows8(1).getString(0) === "special_character_percent_not_present")
  assert(rows8(2).getString(0) === "special_character_quote'_present")
  assert(rows8(3).getString(0) === "special_character_quote_not_present")
  assert(rows8(4).getString(0) === "special_character_underscore_present")
  assert(rows8(5).getString(0) === "special_character_underscorenot_present")

  // '%x' patterns: map to endsWith
  val df9 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE pattern_testing_col LIKE '%quote\\'_present'""".stripMargin)
  val rows9 = df9.collect()
  assert(rows9.length === 1)
  assert(rows9(0).getString(0) === "special_character_quote'_present")

  val df10 = spark.sql(
    s"""SELECT * FROM $catalogName.pattern_testing_table
       |WHERE pattern_testing_col LIKE '%percent\\%_present'""".stripMargin)
  val rows10 = df10.collect()
  assert(rows10.length === 1)
  assert(rows10(0).getString(0) === "special_character_percent%_present")

  val df11 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE pattern_testing_col LIKE '%underscore\\_present'""".stripMargin)
  val rows11 = df11.collect()
  assert(rows11.length === 1)
  assert(rows11(0).getString(0) === "special_character_underscore_present")

  val df12 = spark.
    sql(
      s"""SELECT * FROM $catalogName.pattern_testing_table
         |WHERE pattern_testing_col LIKE '%present' ORDER BY pattern_testing_col""".stripMargin)
  val rows12 = df12.collect()
  // All six seeded rows match '%present'; see note on rows4 above.
  assert(rows12.length === 6)
  assert(rows12(0).getString(0) === "special_character_percent%_present")
  assert(rows12(1).getString(0) === "special_character_percent_not_present")
  assert(rows12(2).getString(0) === "special_character_quote'_present")
  assert(rows12(3).getString(0) === "special_character_quote_not_present")
  assert(rows12(4).getString(0) === "special_character_underscore_present")
  assert(rows12(5).getString(0) === "special_character_underscorenot_present")
}

test("SPARK-37038: Test TABLESAMPLE") {
if (supportsTableSample) {
withTable(s"$catalogName.new_table") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ protected String escapeSpecialCharsForLikePattern(String str) {
switch (c) {
case '_' -> builder.append("\\_");
case '%' -> builder.append("\\%");
case '\'' -> builder.append("\\\'");
default -> builder.append(c);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.connector.expressions

import org.apache.commons.lang3.StringUtils

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
Expand Down Expand Up @@ -390,7 +388,7 @@ private[sql] object HoursTransform {
private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
override def toString: String = {
if (dataType.isInstanceOf[StringType]) {
s"'${StringUtils.replace(s"$value", "'", "''")}'"
s"'$value'"
} else {
s"$value"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,13 @@ private[sql] case class H2Dialect() extends JdbcDialect {
}

class H2SQLBuilder extends JDBCSQLBuilder {
/**
 * Escapes the SQL LIKE wildcard characters (`_` and `%`) in `str` by
 * prefixing each occurrence with a backslash; all other characters pass
 * through unchanged.
 */
override def escapeSpecialCharsForLikePattern(str: String): String = {
  val escaped = new StringBuilder(str.length)
  str.foreach {
    case '_'   => escaped.append("\\_")
    case '%'   => escaped.append("\\%")
    case other => escaped.append(other)
  }
  escaped.toString
}

override def visitAggregateFunction(
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,6 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
}
}

/**
 * Builds a MySQL `LIKE 'prefix%' ESCAPE '\\'` predicate for a pushed-down
 * startsWith. `r` arrives as a quoted SQL string literal, so the surrounding
 * quotes are stripped before the wildcard characters inside are escaped.
 */
override def visitStartsWith(l: String, r: String): String = {
  val unquoted = r.substring(1, r.length() - 1)
  val pattern = escapeSpecialCharsForLikePattern(unquoted)
  s"$l LIKE '$pattern%' ESCAPE '\\\\'"
}

/**
 * Builds a MySQL `LIKE '%suffix' ESCAPE '\\'` predicate for a pushed-down
 * endsWith. `r` arrives as a quoted SQL string literal, so the surrounding
 * quotes are stripped before the wildcard characters inside are escaped.
 */
override def visitEndsWith(l: String, r: String): String = {
  val unquoted = r.substring(1, r.length() - 1)
  val pattern = escapeSpecialCharsForLikePattern(unquoted)
  s"$l LIKE '%$pattern' ESCAPE '\\\\'"
}

/**
 * Builds a MySQL `LIKE '%infix%' ESCAPE '\\'` predicate for a pushed-down
 * contains. `r` arrives as a quoted SQL string literal, so the surrounding
 * quotes are stripped before the wildcard characters inside are escaped.
 */
override def visitContains(l: String, r: String): String = {
  val unquoted = r.substring(1, r.length() - 1)
  val pattern = escapeSpecialCharsForLikePattern(unquoted)
  s"$l LIKE '%$pattern%' ESCAPE '\\\\'"
}

override def visitAggregateFunction(
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
Expand Down
Loading

0 comments on commit 4ff5ca8

Please sign in to comment.