SPARK-46840 added org.apache.spark.sql.execution.benchmark.CollationBenchmark [wip].
GideonPotok committed Mar 18, 2024
1 parent eeebaa6 commit 7aa6ce0
Showing 2 changed files with 184 additions and 240 deletions.
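The commit title references a new org.apache.spark.sql.execution.benchmark.CollationBenchmark, but that file is not among the two diffs shown below. As a rough, hypothetical sketch of what a collation microbenchmark in that package could look like, following the SqlBasedBenchmark pattern already used by FilterPushdownBenchmark in this commit (the object name, row count, and query shape are illustrative assumptions, not the WIP file's contents):

package org.apache.spark.sql.execution.benchmark

import org.apache.spark.benchmark.Benchmark

/**
 * Hypothetical sketch only. Measures how aggregation over a collated string
 * column behaves across the collations exercised by CollationSuite below.
 */
object CollationBenchmarkSketch extends SqlBasedBenchmark {

  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
    val numRows = 1 << 20
    runBenchmark("collation-aware aggregation") {
      val benchmark = new Benchmark("GROUP BY collated column", numRows, output = output)
      // Collation names taken from the test cases in CollationSuite in this commit.
      Seq("UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI").foreach { collation =>
        benchmark.addCase(collation) { _ =>
          spark.range(numRows)
            // Repeat 1000 distinct values so the group-by exercises collation-aware equality.
            .selectExpr(s"collate(cast(id % 1000 as string), '$collation') AS c")
            .groupBy("c")
            .count()
            // Force full execution without materializing results on the driver.
            .write.format("noop").mode("overwrite").save()
        }
      }
      benchmark.run()
    }
  }
}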
277 changes: 133 additions & 144 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql

// import scala.collection.immutable.Seq
import scala.collection.immutable.Seq
import scala.jdk.CollectionConverters.MapHasAsJava

import org.apache.spark.SparkException
@@ -39,27 +39,24 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
Seq("utf8_binary", "utf8_binary_lcase", "unicode", "unicode_ci").foreach { collationName =>
checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa"))
val collationId = CollationFactory.collationNameToId(collationName)
assert(
sql(s"select 'aaa' collate $collationName").schema(0).dataType
== StringType(collationId))
assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType
== StringType(collationId))
}
}

test("collation name is case insensitive") {
Seq("uTf8_BiNaRy", "uTf8_BiNaRy_Lcase", "uNicOde", "UNICODE_ci").foreach { collationName =>
checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa"))
val collationId = CollationFactory.collationNameToId(collationName)
assert(
sql(s"select 'aaa' collate $collationName").schema(0).dataType
== StringType(collationId))
assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType
== StringType(collationId))
}
}

test("collation expression returns name of collation") {
Seq("utf8_binary", "utf8_binary_lcase", "unicode", "unicode_ci").foreach { collationName =>
checkAnswer(
sql(s"select collation('aaa' collate $collationName)"),
Row(collationName.toUpperCase()))
sql(s"select collation('aaa' collate $collationName)"), Row(collationName.toUpperCase()))
}
}

@@ -82,8 +79,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
"expectedNum" -> "2",
"actualNum" -> paramCount.toString,
"docroot" -> "https://spark.apache.org/docs/latest"),
context =
ExpectedContext(fragment = s"collate($args)", start = 7, stop = 15 + args.length))
context = ExpectedContext(fragment = s"collate($args)", start = 7, stop = 15 + args.length)
)
})
}

@@ -98,26 +95,25 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
"inputSql" -> "\"123\"",
"inputType" -> "\"INT\"",
"requiredType" -> "\"STRING\""),
context = ExpectedContext(fragment = s"collate('abc', 123)", start = 7, stop = 25))
context = ExpectedContext(fragment = s"collate('abc', 123)", start = 7, stop = 25)
)
}

test("NULL as collation name") {
checkError(
exception = intercept[AnalysisException] {
sql("select collate('abc', cast(null as string))")
},
sql("select collate('abc', cast(null as string))") },
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
sqlState = "42K09",
Map("exprName" -> "`collation`", "sqlExpr" -> "\"CAST(NULL AS STRING)\""),
context =
ExpectedContext(fragment = s"collate('abc', cast(null as string))", start = 7, stop = 42))
context = ExpectedContext(
fragment = s"collate('abc', cast(null as string))", start = 7, stop = 42)
)
}

test("collate function invalid input data type") {
checkError(
exception = intercept[ExtendedAnalysisException] {
sql(s"select collate(1, 'UTF8_BINARY')")
},
exception = intercept[ExtendedAnalysisException] { sql(s"select collate(1, 'UTF8_BINARY')") },
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
sqlState = "42K09",
parameters = Map(
@@ -126,7 +122,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
"inputSql" -> "\"1\"",
"inputType" -> "\"INT\"",
"requiredType" -> "\"STRING\""),
context = ExpectedContext(fragment = s"collate(1, 'UTF8_BINARY')", start = 7, stop = 31))
context = ExpectedContext(
fragment = s"collate(1, 'UTF8_BINARY')", start = 7, stop = 31))
}

test("collation expression returns default collation") {
@@ -145,12 +142,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
def createTable(bucketColumns: String*): Unit = {
val tableName = "test_partition_tbl"
withTable(tableName) {
sql(s"""
sql(
s"""
|CREATE TABLE $tableName
|(id INT, c1 STRING COLLATE UNICODE, c2 string)
|USING parquet
|CLUSTERED BY (${bucketColumns.mkString(",")})
|INTO 4 BUCKETS""".stripMargin)
|INTO 4 BUCKETS""".stripMargin
)
}
}
// should work fine on default collated columns
@@ -164,7 +163,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
createTable(bucketColumns: _*)
},
errorClass = "INVALID_BUCKET_COLUMN_DATA_TYPE",
parameters = Map("type" -> "\"STRING COLLATE UNICODE\""));
parameters = Map("type" -> "\"STRING COLLATE UNICODE\"")
);
}
}

@@ -179,7 +179,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
("unicode", "aaa", "AAA", false),
("unicode_CI", "aaa", "aaa", true),
("unicode_CI", "aaa", "AAA", true),
("unicode_CI", "aaa", "bbb", false)).foreach {
("unicode_CI", "aaa", "bbb", false)
).foreach {
case (collationName, left, right, expected) =>
checkAnswer(
sql(s"select '$left' collate $collationName = '$right' collate $collationName"),
@@ -203,13 +204,15 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
("unicode", "aaa", "BBB", true),
("unicode_CI", "aaa", "aaa", false),
("unicode_CI", "aaa", "AAA", false),
("unicode_CI", "aaa", "bbb", true)).foreach { case (collationName, left, right, expected) =>
checkAnswer(
sql(s"select '$left' collate $collationName < '$right' collate $collationName"),
Row(expected))
checkAnswer(
sql(s"select collate('$left', '$collationName') < collate('$right', '$collationName')"),
Row(expected))
("unicode_CI", "aaa", "bbb", true)
).foreach {
case (collationName, left, right, expected) =>
checkAnswer(
sql(s"select '$left' collate $collationName < '$right' collate $collationName"),
Row(expected))
checkAnswer(
sql(s"select collate('$left', '$collationName') < collate('$right', '$collationName')"),
Row(expected))
}
}

@@ -222,63 +225,59 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
right = left.substring(1, 2);
checkError(
exception = intercept[ExtendedAnalysisException] {
spark.sql(
s"SELECT contains(collate('$left', '$leftCollationName')," +
s"collate('$right', '$rightCollationName'))")
spark.sql(s"SELECT contains(collate('$left', '$leftCollationName')," +
s"collate('$right', '$rightCollationName'))")
},
errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
sqlState = "42K09",
parameters = Map(
"collationNameLeft" -> s"$leftCollationName",
"collationNameRight" -> s"$rightCollationName",
"sqlExpr" -> "\"contains(collate(abc), collate(b))\""),
context = ExpectedContext(
fragment = s"contains(collate('abc', 'UNICODE_CI'),collate('b', 'UNICODE'))",
start = 7,
stop = 68))
"sqlExpr" -> "\"contains(collate(abc), collate(b))\""
),
context = ExpectedContext(fragment =
s"contains(collate('abc', 'UNICODE_CI'),collate('b', 'UNICODE'))",
start = 7, stop = 68)
)
// startsWith
right = left.substring(0, 1);
checkError(
exception = intercept[ExtendedAnalysisException] {
spark.sql(
s"SELECT startsWith(collate('$left', '$leftCollationName')," +
s"collate('$right', '$rightCollationName'))")
spark.sql(s"SELECT startsWith(collate('$left', '$leftCollationName')," +
s"collate('$right', '$rightCollationName'))")
},
errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
sqlState = "42K09",
parameters = Map(
"collationNameLeft" -> s"$leftCollationName",
"collationNameRight" -> s"$rightCollationName",
"sqlExpr" -> "\"startswith(collate(abc), collate(a))\""),
context = ExpectedContext(
fragment = s"startsWith(collate('abc', 'UNICODE_CI'),collate('a', 'UNICODE'))",
start = 7,
stop = 70))
"sqlExpr" -> "\"startswith(collate(abc), collate(a))\""
),
context = ExpectedContext(fragment =
s"startsWith(collate('abc', 'UNICODE_CI'),collate('a', 'UNICODE'))",
start = 7, stop = 70)
)
// endsWith
right = left.substring(2, 3);
checkError(
exception = intercept[ExtendedAnalysisException] {
spark.sql(
s"SELECT endsWith(collate('$left', '$leftCollationName')," +
s"collate('$right', '$rightCollationName'))")
spark.sql(s"SELECT endsWith(collate('$left', '$leftCollationName')," +
s"collate('$right', '$rightCollationName'))")
},
errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
sqlState = "42K09",
parameters = Map(
"collationNameLeft" -> s"$leftCollationName",
"collationNameRight" -> s"$rightCollationName",
"sqlExpr" -> "\"endswith(collate(abc), collate(c))\""),
context = ExpectedContext(
fragment = s"endsWith(collate('abc', 'UNICODE_CI'),collate('c', 'UNICODE'))",
start = 7,
stop = 68))
"sqlExpr" -> "\"endswith(collate(abc), collate(c))\""
),
context = ExpectedContext(fragment =
s"endsWith(collate('abc', 'UNICODE_CI'),collate('c', 'UNICODE'))",
start = 7, stop = 68)
)
}

case class CollationTestCase[R](
left: String,
right: String,
collation: String,
expectedResult: R)
case class CollationTestCase[R](left: String, right: String, collation: String, expectedResult: R)

test("Support contains string expression with Collation") {
// Supported collations
@@ -318,13 +317,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
CollationTestCase("abcde", "bcd", "UNICODE_CI", true),
CollationTestCase("abcde", "BCD", "UNICODE_CI", true),
CollationTestCase("abcde", "fgh", "UNICODE_CI", false),
CollationTestCase("abcde", "FGH", "UNICODE_CI", false))
CollationTestCase("abcde", "FGH", "UNICODE_CI", false)
)
checks.foreach(testCase => {
checkAnswer(
sql(
s"SELECT contains(collate('${testCase.left}', '${testCase.collation}')," +
s"collate('${testCase.right}', '${testCase.collation}'))"),
Row(testCase.expectedResult))
checkAnswer(sql(s"SELECT contains(collate('${testCase.left}', '${testCase.collation}')," +
s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
})
}

@@ -338,13 +335,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
CollationTestCase("abcde", "ABC", "UTF8_BINARY_LCASE", true),
CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false),
CollationTestCase("abcde", "ABC", "UNICODE_CI", true),
CollationTestCase("abcde", "bcd", "UNICODE_CI", false))
CollationTestCase("abcde", "bcd", "UNICODE_CI", false)
)
checks.foreach(testCase => {
checkAnswer(
sql(
s"SELECT startswith(collate('${testCase.left}', '${testCase.collation}')," +
s"collate('${testCase.right}', '${testCase.collation}'))"),
Row(testCase.expectedResult))
checkAnswer(sql(s"SELECT startswith(collate('${testCase.left}', '${testCase.collation}')," +
s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
})
}

@@ -358,13 +353,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
CollationTestCase("abcde", "CDE", "UTF8_BINARY_LCASE", true),
CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false),
CollationTestCase("abcde", "CDE", "UNICODE_CI", true),
CollationTestCase("abcde", "bcd", "UNICODE_CI", false))
CollationTestCase("abcde", "bcd", "UNICODE_CI", false)
)
checks.foreach(testCase => {
checkAnswer(
sql(
s"SELECT endswith(collate('${testCase.left}', '${testCase.collation}')," +
s"collate('${testCase.right}', '${testCase.collation}'))"),
Row(testCase.expectedResult))
checkAnswer(sql(s"SELECT endswith(collate('${testCase.left}', '${testCase.collation}')," +
s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
})
}

@@ -381,18 +374,18 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
("unicode", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))),
("unicode_CI", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))),
("unicode_CI", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))),
("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb")))).foreach {
("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb")))
).foreach {
case (collationName: String, input: Seq[String], expected: Seq[Row]) =>
checkAnswer(
sql(s"""
checkAnswer(sql(
s"""
with t as (
select collate(col1, '$collationName') as c
from
values ${input.map(s => s"('$s')").mkString(", ")}
)
SELECT COUNT(*), c FROM t GROUP BY c
"""),
expected)
"""), expected)
}
}

@@ -401,8 +394,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
val tableNameBinary = "T_BINARY"
withTable(tableNameNonBinary) {
withTable(tableNameBinary) {
sql(
s"CREATE TABLE $tableNameNonBinary (c STRING COLLATE UTF8_BINARY_LCASE) USING PARQUET")
sql(s"CREATE TABLE $tableNameNonBinary (c STRING COLLATE UTF8_BINARY_LCASE) USING PARQUET")
sql(s"INSERT INTO $tableNameNonBinary VALUES ('aaa')")
sql(s"CREATE TABLE $tableNameBinary (c STRING COLLATE UTF8_BINARY) USING PARQUET")
sql(s"INSERT INTO $tableNameBinary VALUES ('aaa')")
@@ -421,10 +413,12 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}

test("text writing to parquet with collation enclosed with backticks") {
withTempPath { path =>
withTempPath{ path =>
sql(s"select 'a' COLLATE `UNICODE`").write.parquet(path.getAbsolutePath)

checkAnswer(spark.read.parquet(path.getAbsolutePath), Row("a"))
checkAnswer(
spark.read.parquet(path.getAbsolutePath),
Row("a"))
}
}

@@ -434,7 +428,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
val collationId = CollationFactory.collationNameToId(collationName)

withTable(tableName) {
sql(s"""
sql(
s"""
|CREATE TABLE $tableName (c1 STRING COLLATE $collationName)
|USING PARQUET
|""".stripMargin)
@@ -453,7 +448,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
val collationId = CollationFactory.collationNameToId(collationName)

withTable(tableName) {
sql(s"""
sql(
s"""
|CREATE TABLE $tableName
|(c1 STRUCT<name: STRING COLLATE $collationName, age: INT>)
|USING PARQUET
@@ -462,11 +458,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'aaa', 'id', 1))")
sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'AAA', 'id', 2))")

checkAnswer(
sql(s"SELECT DISTINCT collation(c1.name) FROM $tableName"),
checkAnswer(sql(s"SELECT DISTINCT collation(c1.name) FROM $tableName"),
Seq(Row(collationName)))
assert(
sql(s"SELECT c1.name FROM $tableName").schema.head.dataType == StringType(collationId))
assert(sql(s"SELECT c1.name FROM $tableName").schema.head.dataType == StringType(collationId))
}
}

@@ -477,27 +471,29 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
val collationId = CollationFactory.collationNameToId(collationName)

withTable(tableName) {
sql(s"""
sql(
s"""
|CREATE TABLE $tableName (c1 STRING)
|USING PARQUET
|""".stripMargin)

sql(s"INSERT INTO $tableName VALUES ('aaa')")
sql(s"INSERT INTO $tableName VALUES ('AAA')")

checkAnswer(
sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
Seq(Row(defaultCollation)))

sql(s"""
sql(
s"""
|ALTER TABLE $tableName
|ADD COLUMN c2 STRING COLLATE $collationName
|""".stripMargin)

sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')")
sql(s"INSERT INTO $tableName VALUES ('AAA', 'AAA')")

checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM $tableName"), Seq(Row(collationName)))
checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM $tableName"),
Seq(Row(collationName)))
assert(sql(s"select c2 FROM $tableName").schema.head.dataType == StringType(collationId))
}
}
@@ -508,7 +504,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
val collationId = CollationFactory.collationNameToId(collationName)

withTable(tableName) {
sql(s"""
sql(
s"""
|CREATE TABLE $tableName (c1 string COLLATE $collationName)
|USING $v2Source
|""".stripMargin)
@@ -526,7 +523,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {

sql(s"INSERT INTO $tableName VALUES ('a'), ('A')")

checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), Seq(Row(collationName)))
checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
Seq(Row(collationName)))
assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId))
}
}
@@ -535,7 +533,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
def createTable(partitionColumns: String*): Unit = {
val tableName = "test_partition_tbl"
withTable(tableName) {
sql(s"""
sql(
s"""
|CREATE TABLE $tableName
|(id INT, c1 STRING COLLATE UNICODE, c2 string)
|USING parquet
@@ -555,78 +554,66 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
createTable(partitionColumns: _*)
},
errorClass = "INVALID_PARTITION_COLUMN_DATA_TYPE",
parameters = Map("type" -> "\"STRING COLLATE UNICODE\""));
parameters = Map("type" -> "\"STRING COLLATE UNICODE\"")
);
}
}

test("shuffle respects collation") {
val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(Row.apply(_))

val schema = StructType(
StructField(
"col",
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) :: Nil)
val schema = StructType(StructField(
"col",
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) :: Nil)
val df = spark.createDataFrame(sparkContext.parallelize(in), schema)

df.repartition(10, df.col("col"))
.foreachPartition((rowIterator: Iterator[Row]) => {
df.repartition(10, df.col("col")).foreachPartition(
(rowIterator: Iterator[Row]) => {
val partitionData = rowIterator.map(r => r.getString(0)).toArray
partitionData.foreach(s => {
// assert that both lower and upper case of the string are present in the same partition.
assert(partitionData.contains(s.toLowerCase()))
assert(partitionData.contains(s.toUpperCase()))
})
})
})
}

test("hash based joins not allowed for non-binary collated strings") {
val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(e => Row.apply(e, e))

val schema = StructType(
StructField(
"col_non_binary",
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
StructField("col_binary", StringType) :: Nil)
val schema = StructType(StructField(
"col_non_binary",
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
StructField("col_binary", StringType) :: Nil)
val df1 = spark.createDataFrame(sparkContext.parallelize(in), schema)

// Binary collations are allowed to use hash join.
assert(
collectFirst(
df1
.hint("broadcast")
.join(df1, df1("col_binary") === df1("col_binary"))
.queryExecution
.executedPlan) { case _: BroadcastHashJoinExec =>
()
}.nonEmpty)
assert(collectFirst(
df1.hint("broadcast").join(df1, df1("col_binary") === df1("col_binary"))
.queryExecution.executedPlan) {
case _: BroadcastHashJoinExec => ()
}.nonEmpty)

// Even with hint broadcast, hash join is not used for non-binary collated strings.
assert(
collectFirst(
df1
.hint("broadcast")
.join(df1, df1("col_non_binary") === df1("col_non_binary"))
.queryExecution
.executedPlan) { case _: BroadcastHashJoinExec =>
()
}.isEmpty)
assert(collectFirst(
df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
.queryExecution.executedPlan) {
case _: BroadcastHashJoinExec => ()
}.isEmpty)

// Instead they will default to sort merge join.
assert(
collectFirst(
df1
.hint("broadcast")
.join(df1, df1("col_non_binary") === df1("col_non_binary"))
.queryExecution
.executedPlan) { case _: SortMergeJoinExec =>
()
}.nonEmpty)
assert(collectFirst(
df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
.queryExecution.executedPlan) {
case _: SortMergeJoinExec => ()
}.nonEmpty)
}

test("Generated column expressions using collations - errors out") {
checkError(
exception = intercept[AnalysisException] {
sql(s"""
sql(
s"""
|CREATE TABLE testcat.test_table(
| c1 STRING COLLATE UNICODE,
| c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (SUBSTRING(c1, 0, 1))
@@ -642,7 +629,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {

checkError(
exception = intercept[AnalysisException] {
sql(s"""
sql(
s"""
|CREATE TABLE testcat.test_table(
| c1 STRING COLLATE UNICODE,
| c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (c1 || 'a' COLLATE UNICODE)
@@ -658,7 +646,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {

checkError(
exception = intercept[AnalysisException] {
sql(s"""
sql(
s"""
|CREATE TABLE testcat.test_table(
| struct1 STRUCT<a: STRING COLLATE UNICODE>,
| c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (SUBSTRING(struct1.a, 0, 1))
147 changes: 51 additions & 96 deletions sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType}

/**
* Benchmark to measure read performance with Filter pushdown. To run this benchmark:
* Benchmark to measure read performance with Filter pushdown.
* To run this benchmark:
* {{{
* 1. without sbt: bin/spark-submit --class <this class>
* --jars <spark core test jar>,<spark catalyst test jar> <spark sql test jar>
@@ -49,8 +50,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
.set("spark.master", "local[1]")
.setIfMissing("spark.driver.memory", "3g")
.setIfMissing("spark.executor.memory", "3g")
.setIfMissing(
"spark.sql.parquet.compression.codec",
.setIfMissing("spark.sql.parquet.compression.codec",
ParquetCompressionCodec.SNAPPY.lowerCaseName())

SparkSession.builder().config(conf).getOrCreate()
@@ -63,37 +63,27 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
private val blockSize = org.apache.parquet.hadoop.ParquetWriter.DEFAULT_PAGE_SIZE

def withTempTable(tableNames: String*)(f: => Unit): Unit = {
try f
finally tableNames.foreach(spark.catalog.dropTempView)
try f finally tableNames.foreach(spark.catalog.dropTempView)
}

private def prepareTable(
dir: File,
numRows: Int,
width: Int,
useStringForValue: Boolean): Unit = {
dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = {
import spark.implicits._
val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i")
val valueCol = if (useStringForValue) {
monotonically_increasing_id().cast("string")
} else {
monotonically_increasing_id()
}
val df = spark
.range(numRows)
.map(_ => Random.nextLong())
.selectExpr(selectExpr: _*)
val df = spark.range(numRows).map(_ => Random.nextLong()).selectExpr(selectExpr: _*)
.withColumn("value", valueCol)
.sort("value")

saveAsTable(df, dir)
}

private def prepareStringDictTable(
dir: File,
numRows: Int,
numDistinctValues: Int,
width: Int): Unit = {
dir: File, numRows: Int, numDistinctValues: Int, width: Int): Unit = {
val selectExpr = (0 to width).map {
case 0 => s"CAST(id % $numDistinctValues AS STRING) AS value"
case i => s"CAST(rand() AS STRING) c$i"
@@ -107,18 +97,14 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
val orcPath = dir.getCanonicalPath + "/orc"
val parquetPath = dir.getCanonicalPath + "/parquet"

df.write
.mode("overwrite")
df.write.mode("overwrite")
.option("orc.dictionary.key.threshold", if (useDictionary) 1.0 else 0.8)
.option("orc.compress.size", blockSize)
.option("orc.stripe.size", blockSize)
.orc(orcPath)
.option("orc.stripe.size", blockSize).orc(orcPath)
spark.read.orc(orcPath).createOrReplaceTempView("orcTable")

df.write
.mode("overwrite")
.option("parquet.block.size", blockSize)
.parquet(parquetPath)
df.write.mode("overwrite")
.option("parquet.block.size", blockSize).parquet(parquetPath)
spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable")
}

@@ -160,7 +146,8 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
s"value = $mid",
s"value <=> $mid",
s"$mid <= value AND value <= $mid",
s"${mid - 1} < value AND value < ${mid + 1}").foreach { whereExpr =>
s"${mid - 1} < value AND value < ${mid + 1}"
).foreach { whereExpr =>
val title = s"Select 1 int row ($whereExpr)".replace("value AND value", "value")
filterPushDownBenchmark(numRows, title, whereExpr)
}
@@ -172,66 +159,39 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
numRows,
s"Select $percent% int rows (value < ${numRows * percent / 100})",
s"value < ${numRows * percent / 100}",
selectExpr)
selectExpr
)
}

Seq("value IS NOT NULL", "value > -1", "value != -1").foreach { whereExpr =>
filterPushDownBenchmark(numRows, s"Select all int rows ($whereExpr)", whereExpr, selectExpr)
}
}

private def runStringBenchmark(
numRows: Int,
width: Int,
searchValue: Int,
colType: String): Unit = {
Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'")
.foreach { whereExpr =>
val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value")
filterPushDownBenchmark(numRows, title, whereExpr)
}

Seq(
s"value = '$searchValue'",
s"value <=> '$searchValue'",
s"'$searchValue' <= value AND value <= '$searchValue'").foreach { whereExpr =>
val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value")
filterPushDownBenchmark(numRows, title, whereExpr)
}

val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)")

Seq("value IS NOT NULL").foreach { whereExpr =>
filterPushDownBenchmark(
numRows,
s"Select all $colType rows ($whereExpr)",
s"Select all int rows ($whereExpr)",
whereExpr,
selectExpr)
}
}

private def runStringBenchmarkCollation(
numRows: Int,
width: Int,
searchValue: Int,
colType: String): Unit = {
private def runStringBenchmark(
numRows: Int, width: Int, searchValue: Int, colType: String): Unit = {
Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'")
.foreach { whereExpr =>
val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value")
filterPushDownBenchmark(numRows, title, whereExpr)
}
.foreach { whereExpr =>
val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value")
filterPushDownBenchmark(numRows, title, whereExpr)
}

Seq(
s"value = '$searchValue'",
s"value <=> '$searchValue'",
s"'$searchValue' <= value AND value <= '$searchValue'").foreach { whereExpr =>
s"'$searchValue' <= value AND value <= '$searchValue'"
).foreach { whereExpr =>
val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value")
filterPushDownBenchmark(numRows, title, whereExpr)
}

val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)")

Seq("collate(value, 'UNICODE_CI') = collate(value, 'UNICODE_CI')").foreach { whereExpr =>
Seq("value IS NOT NULL").foreach { whereExpr =>
filterPushDownBenchmark(
numRows,
s"Select all $colType rows ($whereExpr)",
@@ -248,7 +208,6 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
prepareTable(dir, numRows, width, useStringForValue)
if (useStringForValue) {
runStringBenchmark(numRows, width, mid, "string")
runStringBenchmarkCollation(numRows, width, mid, "string")
} else {
runIntBenchmark(numRows, width, mid)
}
@@ -275,10 +234,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
Seq(
"value like '10%'",
"value like '1000%'",
s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'").foreach {
whereExpr =>
val title = s"StringStartsWith filter: ($whereExpr)"
filterPushDownBenchmark(numRows, title, whereExpr)
s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'"
).foreach { whereExpr =>
val title = s"StringStartsWith filter: ($whereExpr)"
filterPushDownBenchmark(numRows, title, whereExpr)
}
}
}
@@ -291,10 +250,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
Seq(
"value like '%10'",
"value like '%1000'",
s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}'").foreach {
whereExpr =>
val title = s"StringEndsWith filter: ($whereExpr)"
filterPushDownBenchmark(numRows, title, whereExpr)
s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}'"
).foreach { whereExpr =>
val title = s"StringEndsWith filter: ($whereExpr)"
filterPushDownBenchmark(numRows, title, whereExpr)
}
}
}
@@ -307,10 +266,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
Seq(
"value like '%10%'",
"value like '%1000%'",
s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}%'").foreach {
whereExpr =>
val title = s"StringContains filter: ($whereExpr)"
filterPushDownBenchmark(numRows, title, whereExpr)
s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}%'"
).foreach { whereExpr =>
val title = s"StringContains filter: ($whereExpr)"
filterPushDownBenchmark(numRows, title, whereExpr)
}
}
}
@@ -321,17 +280,16 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
Seq(
s"decimal(${Decimal.MAX_INT_DIGITS}, 2)",
s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)",
s"decimal(${DecimalType.MAX_PRECISION}, 2)").foreach { dt =>
s"decimal(${DecimalType.MAX_PRECISION}, 2)"
).foreach { dt =>
val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) {
monotonically_increasing_id() % 9999999
} else {
monotonically_increasing_id()
}
val df = spark
.range(numRows)
.selectExpr(columns: _*)
.withColumn("value", valueCol.cast(dt))
val df = spark.range(numRows)
.selectExpr(columns: _*).withColumn("value", valueCol.cast(dt))
withTempTable("orcTable", "parquetTable") {
saveAsTable(df, dir)

@@ -346,7 +304,8 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
numRows,
s"Select $percent% $dt rows (value < ${numRows * percent / 100})",
s"value < ${numRows * percent / 100}",
selectExpr)
selectExpr
)
}
}
}
@@ -362,8 +321,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
val filter =
Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100))
val whereExpr = s"value in(${filter.mkString(",")})"
val title =
s"InSet -> InFilters (values count: $count, distribution: $distribution)"
val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)"
filterPushDownBenchmark(numRows, title, whereExpr)
}
}
@@ -374,9 +332,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
runBenchmark(s"Pushdown benchmark for ${ByteType.simpleString}") {
withTempPath { dir =>
val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
val df = spark
.range(numRows)
.selectExpr(columns: _*)
val df = spark.range(numRows).selectExpr(columns: _*)
.withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType))
.orderBy("value")
withTempTable("orcTable", "parquetTable") {
@@ -396,7 +352,8 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
s"Select $percent% ${ByteType.simpleString} rows " +
s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))",
s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})",
selectExpr)
selectExpr
)
}
}
}
@@ -408,9 +365,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType =>
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) {
val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
val df = spark
.range(numRows)
.selectExpr(columns: _*)
val df = spark.range(numRows).selectExpr(columns: _*)
.withColumn("value", timestamp_seconds(monotonically_increasing_id()))
withTempTable("orcTable", "parquetTable") {
saveAsTable(df, dir)
@@ -422,15 +377,15 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
}

val selectExpr = (1 to width)
.map(i => s"MAX(c$i)")
.mkString("", ",", ", MAX(value)")
.map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)")
Seq(10, 50, 90).foreach { percent =>
filterPushDownBenchmark(
numRows,
s"Select $percent% timestamp stored as $fileType rows " +
s"(value < timestamp_seconds(${numRows * percent / 100}))",
s"value < timestamp_seconds(${numRows * percent / 100})",
selectExpr)
selectExpr
)
}
}
}
