diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 2481b9b1d43ec..aaf3e88c9bdb6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-// import scala.collection.immutable.Seq
+import scala.collection.immutable.Seq
 import scala.jdk.CollectionConverters.MapHasAsJava
 
 import org.apache.spark.SparkException
@@ -39,9 +39,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     Seq("utf8_binary", "utf8_binary_lcase", "unicode", "unicode_ci").foreach { collationName =>
       checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa"))
       val collationId = CollationFactory.collationNameToId(collationName)
-      assert(
-        sql(s"select 'aaa' collate $collationName").schema(0).dataType
-          == StringType(collationId))
+      assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType
+        == StringType(collationId))
     }
   }
 
@@ -49,17 +48,15 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     Seq("uTf8_BiNaRy", "uTf8_BiNaRy_Lcase", "uNicOde", "UNICODE_ci").foreach { collationName =>
       checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa"))
       val collationId = CollationFactory.collationNameToId(collationName)
-      assert(
-        sql(s"select 'aaa' collate $collationName").schema(0).dataType
-          == StringType(collationId))
+      assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType
+        == StringType(collationId))
     }
   }
 
   test("collation expression returns name of collation") {
     Seq("utf8_binary", "utf8_binary_lcase", "unicode", "unicode_ci").foreach { collationName =>
       checkAnswer(
-        sql(s"select collation('aaa' collate $collationName)"),
-        Row(collationName.toUpperCase()))
+        sql(s"select collation('aaa' collate $collationName)"), Row(collationName.toUpperCase()))
     }
   }
 
@@ -82,8 +79,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
           "expectedNum" -> "2",
           "actualNum" -> paramCount.toString,
           "docroot" -> "https://spark.apache.org/docs/latest"),
-        context =
-          ExpectedContext(fragment = s"collate($args)", start = 7, stop = 15 + args.length))
+        context = ExpectedContext(fragment = s"collate($args)", start = 7, stop = 15 + args.length)
+      )
     })
   }
 
@@ -98,26 +95,25 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
         "inputSql" -> "\"123\"",
         "inputType" -> "\"INT\"",
         "requiredType" -> "\"STRING\""),
-      context = ExpectedContext(fragment = s"collate('abc', 123)", start = 7, stop = 25))
+      context = ExpectedContext(fragment = s"collate('abc', 123)", start = 7, stop = 25)
+    )
   }
 
   test("NULL as collation name") {
     checkError(
       exception = intercept[AnalysisException] {
-        sql("select collate('abc', cast(null as string))")
-      },
+        sql("select collate('abc', cast(null as string))") },
      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
      sqlState = "42K09",
      Map("exprName" -> "`collation`", "sqlExpr" -> "\"CAST(NULL AS STRING)\""),
-      context =
-        ExpectedContext(fragment = s"collate('abc', cast(null as string))", start = 7, stop = 42))
+      context = ExpectedContext(
+        fragment = s"collate('abc', cast(null as string))", start = 7, stop = 42)
+    )
   }
 
   test("collate function invalid input data type") {
     checkError(
-      exception = intercept[ExtendedAnalysisException] {
-        sql(s"select collate(1, 'UTF8_BINARY')")
-      },
+      exception = intercept[ExtendedAnalysisException] { sql(s"select collate(1, 'UTF8_BINARY')") },
      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
      sqlState = "42K09",
      parameters = Map(
@@ -126,7 +122,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
         "inputSql" -> "\"1\"",
         "inputType" -> "\"INT\"",
         "requiredType" -> "\"STRING\""),
-      context = ExpectedContext(fragment = s"collate(1, 'UTF8_BINARY')", start = 7, stop = 31))
+      context = ExpectedContext(
+        fragment = s"collate(1, 'UTF8_BINARY')", start = 7, stop = 31))
   }
 
   test("collation expression returns default collation") {
@@ -145,12 +142,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     def createTable(bucketColumns: String*): Unit = {
       val tableName = "test_partition_tbl"
       withTable(tableName) {
-        sql(s"""
+        sql(
+          s"""
              |CREATE TABLE $tableName
             |(id INT, c1 STRING COLLATE UNICODE, c2 string)
             |USING parquet
             |CLUSTERED BY (${bucketColumns.mkString(",")})
-            |INTO 4 BUCKETS""".stripMargin)
+            |INTO 4 BUCKETS""".stripMargin
+        )
       }
     }
     // should work fine on default collated columns
@@ -164,7 +163,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
           createTable(bucketColumns: _*)
         },
         errorClass = "INVALID_BUCKET_COLUMN_DATA_TYPE",
-        parameters = Map("type" -> "\"STRING COLLATE UNICODE\""));
+        parameters = Map("type" -> "\"STRING COLLATE UNICODE\"")
+      );
     }
   }
 
@@ -179,7 +179,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       ("unicode", "aaa", "AAA", false),
       ("unicode_CI", "aaa", "aaa", true),
       ("unicode_CI", "aaa", "AAA", true),
-      ("unicode_CI", "aaa", "bbb", false)).foreach {
+      ("unicode_CI", "aaa", "bbb", false)
+    ).foreach {
       case (collationName, left, right, expected) =>
         checkAnswer(
           sql(s"select '$left' collate $collationName = '$right' collate $collationName"),
@@ -203,13 +204,15 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       ("unicode", "aaa", "BBB", true),
       ("unicode_CI", "aaa", "aaa", false),
       ("unicode_CI", "aaa", "AAA", false),
-      ("unicode_CI", "aaa", "bbb", true)).foreach { case (collationName, left, right, expected) =>
-      checkAnswer(
-        sql(s"select '$left' collate $collationName < '$right' collate $collationName"),
-        Row(expected))
-      checkAnswer(
-        sql(s"select collate('$left', '$collationName') < collate('$right', '$collationName')"),
-        Row(expected))
+      ("unicode_CI", "aaa", "bbb", true)
+    ).foreach {
+      case (collationName, left, right, expected) =>
+        checkAnswer(
+          sql(s"select '$left' collate $collationName < '$right' collate $collationName"),
+          Row(expected))
+        checkAnswer(
+          sql(s"select collate('$left', '$collationName') < collate('$right', '$collationName')"),
+          Row(expected))
     }
   }
 
@@ -222,63 +225,59 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     right = left.substring(1, 2);
     checkError(
       exception = intercept[ExtendedAnalysisException] {
-        spark.sql(
-          s"SELECT contains(collate('$left', '$leftCollationName')," +
-            s"collate('$right', '$rightCollationName'))")
+        spark.sql(s"SELECT contains(collate('$left', '$leftCollationName')," +
+          s"collate('$right', '$rightCollationName'))")
       },
      errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
      sqlState = "42K09",
      parameters = Map(
        "collationNameLeft" -> s"$leftCollationName",
        "collationNameRight" -> s"$rightCollationName",
-        "sqlExpr" -> "\"contains(collate(abc), collate(b))\""),
-      context = ExpectedContext(
-        fragment = s"contains(collate('abc', 'UNICODE_CI'),collate('b', 'UNICODE'))",
-        start = 7,
-        stop = 68))
+        "sqlExpr" -> "\"contains(collate(abc), collate(b))\""
+      ),
+      context = ExpectedContext(fragment =
+        s"contains(collate('abc', 'UNICODE_CI'),collate('b', 'UNICODE'))",
+        start = 7, stop = 68)
+    )
     // startsWith
     right = left.substring(0, 1);
     checkError(
       exception = intercept[ExtendedAnalysisException] {
-        spark.sql(
-          s"SELECT startsWith(collate('$left', '$leftCollationName')," +
-            s"collate('$right', '$rightCollationName'))")
+        spark.sql(s"SELECT startsWith(collate('$left', '$leftCollationName')," +
+          s"collate('$right', '$rightCollationName'))")
      },
      errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
      sqlState = "42K09",
      parameters = Map(
        "collationNameLeft" -> s"$leftCollationName",
        "collationNameRight" -> s"$rightCollationName",
-        "sqlExpr" -> "\"startswith(collate(abc), collate(a))\""),
-      context = ExpectedContext(
-        fragment = s"startsWith(collate('abc', 'UNICODE_CI'),collate('a', 'UNICODE'))",
-        start = 7,
-        stop = 70))
+        "sqlExpr" -> "\"startswith(collate(abc), collate(a))\""
+      ),
+      context = ExpectedContext(fragment =
+        s"startsWith(collate('abc', 'UNICODE_CI'),collate('a', 'UNICODE'))",
+        start = 7, stop = 70)
+    )
    // endsWith
    right = left.substring(2, 3);
    checkError(
      exception = intercept[ExtendedAnalysisException] {
-        spark.sql(
-          s"SELECT endsWith(collate('$left', '$leftCollationName')," +
-            s"collate('$right', '$rightCollationName'))")
+        spark.sql(s"SELECT endsWith(collate('$left', '$leftCollationName')," +
+          s"collate('$right', '$rightCollationName'))")
      },
      errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
      sqlState = "42K09",
      parameters = Map(
        "collationNameLeft" -> s"$leftCollationName",
        "collationNameRight" -> s"$rightCollationName",
-        "sqlExpr" -> "\"endswith(collate(abc), collate(c))\""),
-      context = ExpectedContext(
-        fragment = s"endsWith(collate('abc', 'UNICODE_CI'),collate('c', 'UNICODE'))",
-        start = 7,
-        stop = 68))
+        "sqlExpr" -> "\"endswith(collate(abc), collate(c))\""
+      ),
+      context = ExpectedContext(fragment =
+        s"endsWith(collate('abc', 'UNICODE_CI'),collate('c', 'UNICODE'))",
+        start = 7, stop = 68)
+    )
   }
 
-  case class CollationTestCase[R](
-      left: String,
-      right: String,
-      collation: String,
-      expectedResult: R)
+  case class CollationTestCase[R](left: String, right: String, collation: String, expectedResult: R)
 
   test("Support contains string expression with Collation") {
     // Supported collations
@@ -318,13 +317,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       CollationTestCase("abcde", "bcd", "UNICODE_CI", true),
       CollationTestCase("abcde", "BCD", "UNICODE_CI", true),
       CollationTestCase("abcde", "fgh", "UNICODE_CI", false),
-      CollationTestCase("abcde", "FGH", "UNICODE_CI", false))
+      CollationTestCase("abcde", "FGH", "UNICODE_CI", false)
+    )
     checks.foreach(testCase => {
-      checkAnswer(
-        sql(
-          s"SELECT contains(collate('${testCase.left}', '${testCase.collation}')," +
-            s"collate('${testCase.right}', '${testCase.collation}'))"),
-        Row(testCase.expectedResult))
+      checkAnswer(sql(s"SELECT contains(collate('${testCase.left}', '${testCase.collation}')," +
+        s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
     })
   }
 
@@ -338,13 +335,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       CollationTestCase("abcde", "ABC", "UTF8_BINARY_LCASE", true),
       CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false),
       CollationTestCase("abcde", "ABC", "UNICODE_CI", true),
-      CollationTestCase("abcde", "bcd", "UNICODE_CI", false))
+      CollationTestCase("abcde", "bcd", "UNICODE_CI", false)
+    )
     checks.foreach(testCase => {
-      checkAnswer(
-        sql(
-          s"SELECT startswith(collate('${testCase.left}', '${testCase.collation}')," +
-            s"collate('${testCase.right}', '${testCase.collation}'))"),
-        Row(testCase.expectedResult))
+      checkAnswer(sql(s"SELECT startswith(collate('${testCase.left}', '${testCase.collation}')," +
+        s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
     })
   }
 
@@ -358,13 +353,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       CollationTestCase("abcde", "CDE", "UTF8_BINARY_LCASE", true),
       CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false),
       CollationTestCase("abcde", "CDE", "UNICODE_CI", true),
-      CollationTestCase("abcde", "bcd", "UNICODE_CI", false))
+      CollationTestCase("abcde", "bcd", "UNICODE_CI", false)
+    )
     checks.foreach(testCase => {
-      checkAnswer(
-        sql(
-          s"SELECT endswith(collate('${testCase.left}', '${testCase.collation}')," +
-            s"collate('${testCase.right}', '${testCase.collation}'))"),
-        Row(testCase.expectedResult))
+      checkAnswer(sql(s"SELECT endswith(collate('${testCase.left}', '${testCase.collation}')," +
+        s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
     })
   }
 
@@ -381,18 +374,18 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       ("unicode", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))),
       ("unicode_CI", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))),
       ("unicode_CI", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))),
-      ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb")))).foreach {
+      ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb")))
+    ).foreach {
       case (collationName: String, input: Seq[String], expected: Seq[Row]) =>
-        checkAnswer(
-          sql(s"""
+        checkAnswer(sql(
+          s"""
       with t as (
       select collate(col1, '$collationName') as c
       from
       values ${input.map(s => s"('$s')").mkString(", ")}
       ) SELECT COUNT(*), c FROM t GROUP BY c
-      """),
-          expected)
+      """), expected)
     }
   }
 
@@ -401,8 +394,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     val tableNameBinary = "T_BINARY"
     withTable(tableNameNonBinary) {
       withTable(tableNameBinary) {
-        sql(
-          s"CREATE TABLE $tableNameNonBinary (c STRING COLLATE UTF8_BINARY_LCASE) USING PARQUET")
+        sql(s"CREATE TABLE $tableNameNonBinary (c STRING COLLATE UTF8_BINARY_LCASE) USING PARQUET")
         sql(s"INSERT INTO $tableNameNonBinary VALUES ('aaa')")
         sql(s"CREATE TABLE $tableNameBinary (c STRING COLLATE UTF8_BINARY) USING PARQUET")
         sql(s"INSERT INTO $tableNameBinary VALUES ('aaa')")
@@ -421,10 +413,12 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
   }
 
   test("text writing to parquet with collation enclosed with backticks") {
-    withTempPath { path =>
+    withTempPath{ path =>
       sql(s"select 'a' COLLATE `UNICODE`").write.parquet(path.getAbsolutePath)
 
-      checkAnswer(spark.read.parquet(path.getAbsolutePath), Row("a"))
+      checkAnswer(
+        spark.read.parquet(path.getAbsolutePath),
+        Row("a"))
     }
   }
 
@@ -434,7 +428,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     val collationId = CollationFactory.collationNameToId(collationName)
 
     withTable(tableName) {
-      sql(s"""
+      sql(
+        s"""
            |CREATE TABLE $tableName (c1 STRING COLLATE $collationName)
           |USING PARQUET
           |""".stripMargin)
@@ -453,7 +448,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     val collationId = CollationFactory.collationNameToId(collationName)
 
     withTable(tableName) {
-      sql(s"""
+      sql(
+        s"""
           |CREATE TABLE $tableName
          |(c1 STRUCT<name: STRING COLLATE $collationName, id: INT>)
         |USING PARQUET
@@ -462,11 +458,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'aaa', 'id', 1))")
       sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'AAA', 'id', 2))")
 
-      checkAnswer(
-        sql(s"SELECT DISTINCT collation(c1.name) FROM $tableName"),
+      checkAnswer(sql(s"SELECT DISTINCT collation(c1.name) FROM $tableName"),
        Seq(Row(collationName)))
-      assert(
-        sql(s"SELECT c1.name FROM $tableName").schema.head.dataType == StringType(collationId))
+      assert(sql(s"SELECT c1.name FROM $tableName").schema.head.dataType == StringType(collationId))
     }
   }
 
@@ -477,7 +471,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     val collationId = CollationFactory.collationNameToId(collationName)
 
     withTable(tableName) {
-      sql(s"""
+      sql(
+        s"""
           |CREATE TABLE $tableName (c1 STRING)
         |USING PARQUET
        |""".stripMargin)
@@ -485,11 +480,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       sql(s"INSERT INTO $tableName VALUES ('aaa')")
       sql(s"INSERT INTO $tableName VALUES ('AAA')")
 
-      checkAnswer(
-        sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
        Seq(Row(defaultCollation)))
 
-      sql(s"""
+      sql(
+        s"""
          |ALTER TABLE $tableName
         |ADD COLUMN c2 STRING COLLATE $collationName
        |""".stripMargin)
@@ -497,7 +492,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')")
       sql(s"INSERT INTO $tableName VALUES ('AAA', 'AAA')")
 
-      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM $tableName"), Seq(Row(collationName)))
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM $tableName"),
+        Seq(Row(collationName)))
       assert(sql(s"select c2 FROM $tableName").schema.head.dataType == StringType(collationId))
     }
   }
@@ -508,7 +504,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     val collationId = CollationFactory.collationNameToId(collationName)
 
    withTable(tableName) {
-      sql(s"""
+      sql(
+        s"""
          |CREATE TABLE $tableName (c1 string COLLATE $collationName)
         |USING $v2Source
        |""".stripMargin)
@@ -526,7 +523,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
 
       sql(s"INSERT INTO $tableName VALUES ('a'), ('A')")
 
-      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), Seq(Row(collationName)))
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
+        Seq(Row(collationName)))
       assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId))
     }
   }
@@ -535,7 +533,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     def createTable(partitionColumns: String*): Unit = {
       val tableName = "test_partition_tbl"
       withTable(tableName) {
-        sql(s"""
+        sql(
+          s"""
             |CREATE TABLE $tableName
           |(id INT, c1 STRING COLLATE UNICODE, c2 string)
          |USING parquet
@@ -555,78 +554,66 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
           createTable(partitionColumns: _*)
         },
         errorClass = "INVALID_PARTITION_COLUMN_DATA_TYPE",
-        parameters = Map("type" -> "\"STRING COLLATE UNICODE\""));
+        parameters = Map("type" -> "\"STRING COLLATE UNICODE\"")
+      );
     }
   }
 
   test("shuffle respects collation") {
     val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(Row.apply(_))
 
-    val schema = StructType(
-      StructField(
-        "col",
-        StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) :: Nil)
+    val schema = StructType(StructField(
+      "col",
+      StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) :: Nil)
     val df = spark.createDataFrame(sparkContext.parallelize(in), schema)
 
-    df.repartition(10, df.col("col"))
-      .foreachPartition((rowIterator: Iterator[Row]) => {
+    df.repartition(10, df.col("col")).foreachPartition(
+      (rowIterator: Iterator[Row]) => {
        val partitionData = rowIterator.map(r => r.getString(0)).toArray
        partitionData.foreach(s => {
          // assert that both lower and upper case of the string are present in the same partition.
          assert(partitionData.contains(s.toLowerCase()))
          assert(partitionData.contains(s.toUpperCase()))
        })
-      })
+    })
   }
 
   test("hash based joins not allowed for non-binary collated strings") {
     val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(e => Row.apply(e, e))
 
-    val schema = StructType(
-      StructField(
-        "col_non_binary",
-        StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
-        StructField("col_binary", StringType) :: Nil)
+    val schema = StructType(StructField(
+      "col_non_binary",
+      StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
+      StructField("col_binary", StringType) :: Nil)
     val df1 = spark.createDataFrame(sparkContext.parallelize(in), schema)
 
     // Binary collations are allowed to use hash join.
-    assert(
-      collectFirst(
-        df1
-          .hint("broadcast")
-          .join(df1, df1("col_binary") === df1("col_binary"))
-          .queryExecution
-          .executedPlan) { case _: BroadcastHashJoinExec =>
-        ()
-      }.nonEmpty)
+    assert(collectFirst(
+      df1.hint("broadcast").join(df1, df1("col_binary") === df1("col_binary"))
+        .queryExecution.executedPlan) {
+      case _: BroadcastHashJoinExec => ()
+    }.nonEmpty)
 
    // Even with hint broadcast, hash join is not used for non-binary collated strings.
-    assert(
-      collectFirst(
-        df1
-          .hint("broadcast")
-          .join(df1, df1("col_non_binary") === df1("col_non_binary"))
-          .queryExecution
-          .executedPlan) { case _: BroadcastHashJoinExec =>
-        ()
-      }.isEmpty)
+    assert(collectFirst(
+      df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
+        .queryExecution.executedPlan) {
+      case _: BroadcastHashJoinExec => ()
+    }.isEmpty)
 
    // Instead they will default to sort merge join.
-    assert(
-      collectFirst(
-        df1
-          .hint("broadcast")
-          .join(df1, df1("col_non_binary") === df1("col_non_binary"))
-          .queryExecution
-          .executedPlan) { case _: SortMergeJoinExec =>
-        ()
-      }.nonEmpty)
+    assert(collectFirst(
+      df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
+        .queryExecution.executedPlan) {
+      case _: SortMergeJoinExec => ()
+    }.nonEmpty)
   }
 
   test("Generated column expressions using collations - errors out") {
     checkError(
       exception = intercept[AnalysisException] {
-        sql(s"""
+        sql(
+          s"""
            |CREATE TABLE testcat.test_table(
          |  c1 STRING COLLATE UNICODE,
         |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (SUBSTRING(c1, 0, 1))
@@ -642,7 +629,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
 
     checkError(
       exception = intercept[AnalysisException] {
-        sql(s"""
+        sql(
+          s"""
           |CREATE TABLE testcat.test_table(
         |  c1 STRING COLLATE UNICODE,
        |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (c1 || 'a' COLLATE UNICODE)
@@ -658,7 +646,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
 
     checkError(
      exception = intercept[AnalysisException] {
-        sql(s"""
+        sql(
+          s"""
          |CREATE TABLE testcat.test_table(
        |  struct1 STRUCT<a: STRING COLLATE UNICODE>,
       |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (SUBSTRING(struct1.a, 0, 1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
index a6bda01fa2500..0b6b5db556af7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType}
 
 /**
- * Benchmark to measure read performance with Filter pushdown. To run this benchmark:
+ * Benchmark to measure read performance with Filter pushdown.
+ * To run this benchmark:
  * {{{
  *   1. without sbt: bin/spark-submit --class <this class>
  *        --jars <spark core test jar>,<spark catalyst test jar>
@@ -49,8 +50,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
       .set("spark.master", "local[1]")
       .setIfMissing("spark.driver.memory", "3g")
       .setIfMissing("spark.executor.memory", "3g")
-      .setIfMissing(
-        "spark.sql.parquet.compression.codec",
+      .setIfMissing("spark.sql.parquet.compression.codec",
        ParquetCompressionCodec.SNAPPY.lowerCaseName())
 
     SparkSession.builder().config(conf).getOrCreate()
@@ -63,15 +63,11 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
   private val blockSize = org.apache.parquet.hadoop.ParquetWriter.DEFAULT_PAGE_SIZE
 
   def withTempTable(tableNames: String*)(f: => Unit): Unit = {
-    try f
-    finally tableNames.foreach(spark.catalog.dropTempView)
+    try f finally tableNames.foreach(spark.catalog.dropTempView)
   }
 
   private def prepareTable(
-      dir: File,
-      numRows: Int,
-      width: Int,
-      useStringForValue: Boolean): Unit = {
+      dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = {
     import spark.implicits._
     val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i")
     val valueCol = if (useStringForValue) {
@@ -79,10 +75,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     } else {
       monotonically_increasing_id()
     }
-    val df = spark
-      .range(numRows)
-      .map(_ => Random.nextLong())
-      .selectExpr(selectExpr: _*)
+    val df = spark.range(numRows).map(_ => Random.nextLong()).selectExpr(selectExpr: _*)
       .withColumn("value", valueCol)
       .sort("value")
 
@@ -90,10 +83,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
   }
 
   private def prepareStringDictTable(
-      dir: File,
-      numRows: Int,
-      numDistinctValues: Int,
-      width: Int): Unit = {
+      dir: File, numRows: Int, numDistinctValues: Int, width: Int): Unit = {
     val selectExpr = (0 to width).map {
       case 0 => s"CAST(id % $numDistinctValues AS STRING) AS value"
       case i => s"CAST(rand() AS STRING) c$i"
@@ -107,18 +97,14 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     val orcPath = dir.getCanonicalPath + "/orc"
     val parquetPath = dir.getCanonicalPath + "/parquet"
 
-    df.write
-      .mode("overwrite")
+    df.write.mode("overwrite")
       .option("orc.dictionary.key.threshold", if (useDictionary) 1.0 else 0.8)
       .option("orc.compress.size", blockSize)
-      .option("orc.stripe.size", blockSize)
-      .orc(orcPath)
+      .option("orc.stripe.size", blockSize).orc(orcPath)
     spark.read.orc(orcPath).createOrReplaceTempView("orcTable")
 
-    df.write
-      .mode("overwrite")
-      .option("parquet.block.size", blockSize)
-      .parquet(parquetPath)
+    df.write.mode("overwrite")
+      .option("parquet.block.size", blockSize).parquet(parquetPath)
     spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable")
   }
@@ -160,7 +146,8 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
       s"value = $mid",
       s"value <=> $mid",
       s"$mid <= value AND value <= $mid",
-      s"${mid - 1} < value AND value < ${mid + 1}").foreach { whereExpr =>
+      s"${mid - 1} < value AND value < ${mid + 1}"
+    ).foreach { whereExpr =>
       val title = s"Select 1 int row ($whereExpr)".replace("value AND value", "value")
       filterPushDownBenchmark(numRows, title, whereExpr)
     }
@@ -172,66 +159,39 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         numRows,
         s"Select $percent% int rows (value < ${numRows * percent / 100})",
         s"value < ${numRows * percent / 100}",
-        selectExpr)
+        selectExpr
+      )
     }
 
     Seq("value IS NOT NULL", "value > -1", "value != -1").foreach { whereExpr =>
-      filterPushDownBenchmark(numRows, s"Select all int rows ($whereExpr)", whereExpr, selectExpr)
-    }
-  }
-
-  private def runStringBenchmark(
-      numRows: Int,
-      width: Int,
-      searchValue: Int,
-      colType: String): Unit = {
-    Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'")
-      .foreach { whereExpr =>
-        val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value")
-        filterPushDownBenchmark(numRows, title, whereExpr)
-      }
-
-    Seq(
-      s"value = '$searchValue'",
-      s"value <=> '$searchValue'",
-      s"'$searchValue' <= value AND value <= '$searchValue'").foreach { whereExpr =>
-      val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value")
-      filterPushDownBenchmark(numRows, title, whereExpr)
-    }
-
-    val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)")
-
-    Seq("value IS NOT NULL").foreach { whereExpr =>
       filterPushDownBenchmark(
         numRows,
-        s"Select all $colType rows ($whereExpr)",
+        s"Select all int rows ($whereExpr)",
        whereExpr,
        selectExpr)
     }
   }
 
-  private def runStringBenchmarkCollation(
-      numRows: Int,
-      width: Int,
-      searchValue: Int,
-      colType: String): Unit = {
+  private def runStringBenchmark(
+      numRows: Int, width: Int, searchValue: Int, colType: String): Unit = {
     Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'")
-      .foreach { whereExpr =>
-        val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value")
-        filterPushDownBenchmark(numRows, title, whereExpr)
-      }
+        .foreach { whereExpr =>
+      val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value")
+      filterPushDownBenchmark(numRows, title, whereExpr)
+    }
 
     Seq(
       s"value = '$searchValue'",
       s"value <=> '$searchValue'",
-      s"'$searchValue' <= value AND value <= '$searchValue'").foreach { whereExpr =>
+      s"'$searchValue' <= value AND value <= '$searchValue'"
+    ).foreach { whereExpr =>
       val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value")
       filterPushDownBenchmark(numRows, title, whereExpr)
     }
 
     val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)")
 
-    Seq("collate(value, 'UNICODE_CI') = collate(value, 'UNICODE_CI')").foreach { whereExpr =>
+    Seq("value IS NOT NULL").foreach { whereExpr =>
       filterPushDownBenchmark(
         numRows,
         s"Select all $colType rows ($whereExpr)",
@@ -248,7 +208,6 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         prepareTable(dir, numRows, width, useStringForValue)
         if (useStringForValue) {
           runStringBenchmark(numRows, width, mid, "string")
-          runStringBenchmarkCollation(numRows, width, mid, "string")
         } else {
           runIntBenchmark(numRows, width, mid)
         }
@@ -275,10 +234,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         Seq(
           "value like '10%'",
           "value like '1000%'",
-          s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'").foreach {
-          whereExpr =>
-            val title = s"StringStartsWith filter: ($whereExpr)"
-            filterPushDownBenchmark(numRows, title, whereExpr)
+          s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'"
+        ).foreach { whereExpr =>
+          val title = s"StringStartsWith filter: ($whereExpr)"
+          filterPushDownBenchmark(numRows, title, whereExpr)
         }
       }
     }
@@ -291,10 +250,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         Seq(
           "value like '%10'",
           "value like '%1000'",
-          s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}'").foreach {
-          whereExpr =>
-            val title = s"StringEndsWith filter: ($whereExpr)"
-            filterPushDownBenchmark(numRows, title, whereExpr)
+          s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}'"
+        ).foreach { whereExpr =>
+          val title = s"StringEndsWith filter: ($whereExpr)"
+          filterPushDownBenchmark(numRows, title, whereExpr)
         }
       }
     }
@@ -307,10 +266,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         Seq(
           "value like '%10%'",
           "value like '%1000%'",
-          s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}%'").foreach {
-          whereExpr =>
-            val title = s"StringContains filter: ($whereExpr)"
-            filterPushDownBenchmark(numRows, title, whereExpr)
+          s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}%'"
+        ).foreach { whereExpr =>
+          val title = s"StringContains filter: ($whereExpr)"
+          filterPushDownBenchmark(numRows, title, whereExpr)
         }
       }
     }
@@ -321,17 +280,16 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     Seq(
       s"decimal(${Decimal.MAX_INT_DIGITS}, 2)",
       s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)",
-      s"decimal(${DecimalType.MAX_PRECISION}, 2)").foreach { dt =>
+      s"decimal(${DecimalType.MAX_PRECISION}, 2)"
+    ).foreach { dt =>
       val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
       val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) {
         monotonically_increasing_id() % 9999999
       } else {
         monotonically_increasing_id()
       }
-      val df = spark
-        .range(numRows)
-        .selectExpr(columns: _*)
-        .withColumn("value", valueCol.cast(dt))
+      val df = spark.range(numRows)
+        .selectExpr(columns: _*).withColumn("value", valueCol.cast(dt))
 
       withTempTable("orcTable", "parquetTable") {
         saveAsTable(df, dir)
@@ -346,7 +304,8 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
             numRows,
             s"Select $percent% $dt rows (value < ${numRows * percent / 100})",
             s"value < ${numRows * percent / 100}",
-            selectExpr)
+            selectExpr
+          )
         }
       }
     }
@@ -362,8 +321,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
          val filter =
            Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100))
          val whereExpr = s"value in(${filter.mkString(",")})"
-          val title =
-            s"InSet -> InFilters (values count: $count, distribution: $distribution)"
+          val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)"
          filterPushDownBenchmark(numRows, title, whereExpr)
        }
      }
@@ -374,9 +332,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     runBenchmark(s"Pushdown benchmark for ${ByteType.simpleString}") {
       withTempPath { dir =>
         val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
-        val df = spark
-          .range(numRows)
-          .selectExpr(columns: _*)
+        val df = spark.range(numRows).selectExpr(columns: _*)
           .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType))
           .orderBy("value")
         withTempTable("orcTable", "parquetTable") {
@@ -396,7 +352,8 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
              s"Select $percent% ${ByteType.simpleString} rows " +
                s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))",
              s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})",
-              selectExpr)
+              selectExpr
+            )
           }
         }
       }
@@ -408,9 +365,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
       ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType =>
         withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) {
           val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
-          val df = spark
-            .range(numRows)
-            .selectExpr(columns: _*)
+          val df = spark.range(numRows).selectExpr(columns: _*)
             .withColumn("value", timestamp_seconds(monotonically_increasing_id()))
           withTempTable("orcTable", "parquetTable") {
             saveAsTable(df, dir)
@@ -422,15 +377,15 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
           }
 
           val selectExpr = (1 to width)
-            .map(i => s"MAX(c$i)")
-            .mkString("", ",", ", MAX(value)")
+            .map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)")
           Seq(10, 50, 90).foreach { percent =>
             filterPushDownBenchmark(
               numRows,
               s"Select $percent% timestamp stored as $fileType rows " +
                 s"(value < timestamp_seconds(${numRows * percent / 100}))",
              s"value < timestamp_seconds(${numRows * percent / 100})",
-              selectExpr)
+              selectExpr
+            )
          }
        }
      }