diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 551b3bb72c6ff..2481b9b1d43ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -39,8 +39,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     Seq("utf8_binary", "utf8_binary_lcase", "unicode", "unicode_ci").foreach { collationName =>
       checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa"))
       val collationId = CollationFactory.collationNameToId(collationName)
-      assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType
-        == StringType(collationId))
+      assert(
+        sql(s"select 'aaa' collate $collationName").schema(0).dataType
+          == StringType(collationId))
     }
   }
 
@@ -48,15 +49,17 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     Seq("uTf8_BiNaRy", "uTf8_BiNaRy_Lcase", "uNicOde", "UNICODE_ci").foreach { collationName =>
       checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa"))
       val collationId = CollationFactory.collationNameToId(collationName)
-      assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType
-        == StringType(collationId))
+      assert(
+        sql(s"select 'aaa' collate $collationName").schema(0).dataType
+          == StringType(collationId))
     }
   }
 
   test("collation expression returns name of collation") {
     Seq("utf8_binary", "utf8_binary_lcase", "unicode", "unicode_ci").foreach { collationName =>
       checkAnswer(
-        sql(s"select collation('aaa' collate $collationName)"), Row(collationName.toUpperCase()))
+        sql(s"select collation('aaa' collate $collationName)"),
+        Row(collationName.toUpperCase()))
     }
   }
 
@@ -79,8 +82,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
           "expectedNum" -> "2",
           "actualNum" -> paramCount.toString,
           "docroot" -> "https://spark.apache.org/docs/latest"),
-        context = ExpectedContext(fragment = s"collate($args)", start = 7, stop = 15 + args.length)
-      )
+        context =
+          ExpectedContext(fragment = s"collate($args)", start = 7, stop = 15 + args.length))
     })
   }
 
@@ -95,25 +98,26 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
         "inputSql" -> "\"123\"",
         "inputType" -> "\"INT\"",
         "requiredType" -> "\"STRING\""),
-      context = ExpectedContext(fragment = s"collate('abc', 123)", start = 7, stop = 25)
-    )
+      context = ExpectedContext(fragment = s"collate('abc', 123)", start = 7, stop = 25))
   }
 
   test("NULL as collation name") {
     checkError(
       exception = intercept[AnalysisException] {
-        sql("select collate('abc', cast(null as string))") },
+        sql("select collate('abc', cast(null as string))")
+      },
       errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
       sqlState = "42K09",
       Map("exprName" -> "`collation`", "sqlExpr" -> "\"CAST(NULL AS STRING)\""),
-      context = ExpectedContext(
-        fragment = s"collate('abc', cast(null as string))", start = 7, stop = 42)
-      )
+      context =
+        ExpectedContext(fragment = s"collate('abc', cast(null as string))", start = 7, stop = 42))
   }
 
   test("collate function invalid input data type") {
     checkError(
-      exception = intercept[ExtendedAnalysisException] { sql(s"select collate(1, 'UTF8_BINARY')") },
+      exception = intercept[ExtendedAnalysisException] {
+        sql(s"select collate(1, 'UTF8_BINARY')")
+      },
       errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
       sqlState = "42K09",
       parameters = Map(
@@ -122,8 +126,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
         "inputSql" -> "\"1\"",
         "inputType" -> "\"INT\"",
         "requiredType" -> "\"STRING\""),
-      context = ExpectedContext(
-        fragment = s"collate(1, 'UTF8_BINARY')", start = 7, stop = 31))
+      context = ExpectedContext(fragment = s"collate(1, 'UTF8_BINARY')", start = 7, stop = 31))
   }
 
   test("collation expression returns default collation") {
@@ -142,14 +145,12 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     def createTable(bucketColumns: String*): Unit = {
       val tableName = "test_partition_tbl"
       withTable(tableName) {
-        sql(
-          s"""
+        sql(s"""
            |CREATE TABLE $tableName
            |(id INT, c1 STRING COLLATE UNICODE, c2 string)
            |USING parquet
            |CLUSTERED BY (${bucketColumns.mkString(",")})
-           |INTO 4 BUCKETS""".stripMargin
-        )
+           |INTO 4 BUCKETS""".stripMargin)
       }
     }
     // should work fine on default collated columns
@@ -163,8 +164,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
           createTable(bucketColumns: _*)
         },
         errorClass = "INVALID_BUCKET_COLUMN_DATA_TYPE",
-        parameters = Map("type" -> "\"STRING COLLATE UNICODE\"")
-      );
+        parameters = Map("type" -> "\"STRING COLLATE UNICODE\""));
     }
   }
 
@@ -179,8 +179,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       ("unicode", "aaa", "AAA", false),
       ("unicode_CI", "aaa", "aaa", true),
       ("unicode_CI", "aaa", "AAA", true),
-      ("unicode_CI", "aaa", "bbb", false)
-    ).foreach {
+      ("unicode_CI", "aaa", "bbb", false)).foreach {
       case (collationName, left, right, expected) =>
         checkAnswer(
           sql(s"select '$left' collate $collationName = '$right' collate $collationName"),
           Row(expected))
@@ -204,15 +203,13 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       ("unicode", "aaa", "BBB", true),
       ("unicode_CI", "aaa", "aaa", false),
       ("unicode_CI", "aaa", "AAA", false),
-      ("unicode_CI", "aaa", "bbb", true)
-    ).foreach {
-      case (collationName, left, right, expected) =>
-        checkAnswer(
-          sql(s"select '$left' collate $collationName < '$right' collate $collationName"),
-          Row(expected))
-        checkAnswer(
-          sql(s"select collate('$left', '$collationName') < collate('$right', '$collationName')"),
-          Row(expected))
+      ("unicode_CI", "aaa", "bbb", true)).foreach { case (collationName, left, right, expected) =>
+      checkAnswer(
+        sql(s"select '$left' collate $collationName < '$right' collate $collationName"),
+        Row(expected))
+      checkAnswer(
+        sql(s"select collate('$left', '$collationName') < collate('$right', '$collationName')"),
+        Row(expected))
     }
   }
 
@@ -225,59 +222,63 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     right = left.substring(1, 2);
     checkError(
       exception = intercept[ExtendedAnalysisException] {
-        spark.sql(s"SELECT contains(collate('$left', '$leftCollationName')," +
-          s"collate('$right', '$rightCollationName'))")
+        spark.sql(
+          s"SELECT contains(collate('$left', '$leftCollationName')," +
+            s"collate('$right', '$rightCollationName'))")
       },
       errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
       sqlState = "42K09",
       parameters = Map(
         "collationNameLeft" -> s"$leftCollationName",
         "collationNameRight" -> s"$rightCollationName",
-        "sqlExpr" -> "\"contains(collate(abc), collate(b))\""
-      ),
-      context = ExpectedContext(fragment =
-        s"contains(collate('abc', 'UNICODE_CI'),collate('b', 'UNICODE'))",
-        start = 7, stop = 68)
-    )
+        "sqlExpr" -> "\"contains(collate(abc), collate(b))\""),
+      context = ExpectedContext(
+        fragment = s"contains(collate('abc', 'UNICODE_CI'),collate('b', 'UNICODE'))",
+        start = 7,
+        stop = 68))
     // startsWith
     right = left.substring(0, 1);
     checkError(
       exception = intercept[ExtendedAnalysisException] {
-        spark.sql(s"SELECT startsWith(collate('$left', '$leftCollationName')," +
-          s"collate('$right', '$rightCollationName'))")
+        spark.sql(
+          s"SELECT startsWith(collate('$left', '$leftCollationName')," +
+            s"collate('$right', '$rightCollationName'))")
       },
       errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
       sqlState = "42K09",
       parameters = Map(
         "collationNameLeft" -> s"$leftCollationName",
         "collationNameRight" -> s"$rightCollationName",
-        "sqlExpr" -> "\"startswith(collate(abc), collate(a))\""
-      ),
-      context = ExpectedContext(fragment =
-        s"startsWith(collate('abc', 'UNICODE_CI'),collate('a', 'UNICODE'))",
-        start = 7, stop = 70)
-    )
+        "sqlExpr" -> "\"startswith(collate(abc), collate(a))\""),
+      context = ExpectedContext(
+        fragment = s"startsWith(collate('abc', 'UNICODE_CI'),collate('a', 'UNICODE'))",
+        start = 7,
+        stop = 70))
     // endsWith
     right = left.substring(2, 3);
     checkError(
       exception = intercept[ExtendedAnalysisException] {
-        spark.sql(s"SELECT endsWith(collate('$left', '$leftCollationName')," +
-          s"collate('$right', '$rightCollationName'))")
+        spark.sql(
+          s"SELECT endsWith(collate('$left', '$leftCollationName')," +
+            s"collate('$right', '$rightCollationName'))")
       },
       errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
       sqlState = "42K09",
       parameters = Map(
         "collationNameLeft" -> s"$leftCollationName",
         "collationNameRight" -> s"$rightCollationName",
-        "sqlExpr" -> "\"endswith(collate(abc), collate(c))\""
-      ),
-      context = ExpectedContext(fragment =
-        s"endsWith(collate('abc', 'UNICODE_CI'),collate('c', 'UNICODE'))",
-        start = 7, stop = 68)
-    )
+        "sqlExpr" -> "\"endswith(collate(abc), collate(c))\""),
+      context = ExpectedContext(
+        fragment = s"endsWith(collate('abc', 'UNICODE_CI'),collate('c', 'UNICODE'))",
+        start = 7,
+        stop = 68))
   }
 
-  case class CollationTestCase[R](left: String, right: String, collation: String, expectedResult: R)
+  case class CollationTestCase[R](
+      left: String,
+      right: String,
+      collation: String,
+      expectedResult: R)
 
   test("Support contains string expression with Collation") {
     // Supported collations
@@ -317,11 +318,13 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       CollationTestCase("abcde", "bcd", "UNICODE_CI", true),
       CollationTestCase("abcde", "BCD", "UNICODE_CI", true),
       CollationTestCase("abcde", "fgh", "UNICODE_CI", false),
-      CollationTestCase("abcde", "FGH", "UNICODE_CI", false)
-    )
+      CollationTestCase("abcde", "FGH", "UNICODE_CI", false))
     checks.foreach(testCase => {
-      checkAnswer(sql(s"SELECT contains(collate('${testCase.left}', '${testCase.collation}')," +
-        s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
+      checkAnswer(
+        sql(
+          s"SELECT contains(collate('${testCase.left}', '${testCase.collation}')," +
+            s"collate('${testCase.right}', '${testCase.collation}'))"),
+        Row(testCase.expectedResult))
     })
   }
 
@@ -335,11 +338,13 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       CollationTestCase("abcde", "ABC", "UTF8_BINARY_LCASE", true),
       CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false),
       CollationTestCase("abcde", "ABC", "UNICODE_CI", true),
-      CollationTestCase("abcde", "bcd", "UNICODE_CI", false)
-    )
+      CollationTestCase("abcde", "bcd", "UNICODE_CI", false))
     checks.foreach(testCase => {
-      checkAnswer(sql(s"SELECT startswith(collate('${testCase.left}', '${testCase.collation}')," +
-        s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
+      checkAnswer(
+        sql(
+          s"SELECT startswith(collate('${testCase.left}', '${testCase.collation}')," +
+            s"collate('${testCase.right}', '${testCase.collation}'))"),
+        Row(testCase.expectedResult))
     })
   }
 
@@ -353,11 +358,13 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       CollationTestCase("abcde", "CDE", "UTF8_BINARY_LCASE", true),
       CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false),
       CollationTestCase("abcde", "CDE", "UNICODE_CI", true),
-      CollationTestCase("abcde", "bcd", "UNICODE_CI", false)
-    )
+      CollationTestCase("abcde", "bcd", "UNICODE_CI", false))
     checks.foreach(testCase => {
-      checkAnswer(sql(s"SELECT endswith(collate('${testCase.left}', '${testCase.collation}')," +
-        s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
+      checkAnswer(
+        sql(
+          s"SELECT endswith(collate('${testCase.left}', '${testCase.collation}')," +
+            s"collate('${testCase.right}', '${testCase.collation}'))"),
+        Row(testCase.expectedResult))
     })
   }
 
@@ -374,18 +381,18 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       ("unicode", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))),
       ("unicode_CI", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))),
       ("unicode_CI", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))),
-      ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb")))
-    ).foreach {
+      ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb")))).foreach {
       case (collationName: String, input: Seq[String], expected: Seq[Row]) =>
-        checkAnswer(sql(
-          s"""
+        checkAnswer(
+          sql(s"""
       with t as (
       select collate(col1, '$collationName') as c
       from
       values ${input.map(s => s"('$s')").mkString(", ")}
     ) SELECT COUNT(*), c FROM t GROUP BY c
-      """), expected)
+      """),
+          expected)
     }
   }
 
@@ -394,7 +401,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     val tableNameBinary = "T_BINARY"
     withTable(tableNameNonBinary) {
       withTable(tableNameBinary) {
-        sql(s"CREATE TABLE $tableNameNonBinary (c STRING COLLATE UTF8_BINARY_LCASE) USING PARQUET")
+        sql(
+          s"CREATE TABLE $tableNameNonBinary (c STRING COLLATE UTF8_BINARY_LCASE) USING PARQUET")
         sql(s"INSERT INTO $tableNameNonBinary VALUES ('aaa')")
         sql(s"CREATE TABLE $tableNameBinary (c STRING COLLATE UTF8_BINARY) USING PARQUET")
         sql(s"INSERT INTO $tableNameBinary VALUES ('aaa')")
@@ -413,12 +421,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
   }
 
   test("text writing to parquet with collation enclosed with backticks") {
-    withTempPath{ path =>
+    withTempPath { path =>
       sql(s"select 'a' COLLATE `UNICODE`").write.parquet(path.getAbsolutePath)
 
-      checkAnswer(
-        spark.read.parquet(path.getAbsolutePath),
-        Row("a"))
+      checkAnswer(spark.read.parquet(path.getAbsolutePath), Row("a"))
     }
   }
 
@@ -428,8 +434,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
    val collationId = CollationFactory.collationNameToId(collationName)
 
    withTable(tableName) {
-      sql(
-        s"""
+      sql(s"""
            |CREATE TABLE $tableName (c1 STRING COLLATE $collationName)
            |USING PARQUET
            |""".stripMargin)
@@ -448,8 +453,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
    val collationId = CollationFactory.collationNameToId(collationName)
 
    withTable(tableName) {
-      sql(
-        s"""
+      sql(s"""
            |CREATE TABLE $tableName
            |(c1 STRUCT)
            |USING PARQUET
@@ -458,9 +462,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'aaa', 'id', 1))")
       sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'AAA', 'id', 2))")
 
-      checkAnswer(sql(s"SELECT DISTINCT collation(c1.name) FROM $tableName"),
+      checkAnswer(
+        sql(s"SELECT DISTINCT collation(c1.name) FROM $tableName"),
         Seq(Row(collationName)))
-      assert(sql(s"SELECT c1.name FROM $tableName").schema.head.dataType == StringType(collationId))
+      assert(
+        sql(s"SELECT c1.name FROM $tableName").schema.head.dataType == StringType(collationId))
     }
   }
 
@@ -471,8 +477,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
    val collationId = CollationFactory.collationNameToId(collationName)
 
    withTable(tableName) {
-      sql(
-        s"""
+      sql(s"""
            |CREATE TABLE $tableName (c1 STRING)
            |USING PARQUET
            |""".stripMargin)
@@ -480,11 +485,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       sql(s"INSERT INTO $tableName VALUES ('aaa')")
       sql(s"INSERT INTO $tableName VALUES ('AAA')")
 
-      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
+      checkAnswer(
+        sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
         Seq(Row(defaultCollation)))
 
-      sql(
-        s"""
+      sql(s"""
            |ALTER TABLE $tableName
            |ADD COLUMN c2 STRING COLLATE $collationName
            |""".stripMargin)
@@ -492,8 +497,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')")
       sql(s"INSERT INTO $tableName VALUES ('AAA', 'AAA')")
 
-      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM $tableName"),
-        Seq(Row(collationName)))
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM $tableName"), Seq(Row(collationName)))
       assert(sql(s"select c2 FROM $tableName").schema.head.dataType == StringType(collationId))
     }
   }
@@ -504,8 +508,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
    val collationId = CollationFactory.collationNameToId(collationName)
 
    withTable(tableName) {
-      sql(
-        s"""
+      sql(s"""
            |CREATE TABLE $tableName (c1 string COLLATE $collationName)
            |USING $v2Source
            |""".stripMargin)
@@ -523,8 +526,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
 
       sql(s"INSERT INTO $tableName VALUES ('a'), ('A')")
 
-      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
-        Seq(Row(collationName)))
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), Seq(Row(collationName)))
       assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId))
     }
   }
@@ -533,8 +535,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     def createTable(partitionColumns: String*): Unit = {
       val tableName = "test_partition_tbl"
       withTable(tableName) {
-        sql(
-          s"""
+        sql(s"""
            |CREATE TABLE $tableName
            |(id INT, c1 STRING COLLATE UNICODE, c2 string)
            |USING parquet
@@ -554,66 +555,78 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
           createTable(partitionColumns: _*)
         },
        errorClass = "INVALID_PARTITION_COLUMN_DATA_TYPE",
-        parameters = Map("type" -> "\"STRING COLLATE UNICODE\"")
-      );
+        parameters = Map("type" -> "\"STRING COLLATE UNICODE\""));
     }
   }
 
   test("shuffle respects collation") {
     val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(Row.apply(_))
 
-    val schema = StructType(StructField(
-      "col",
-      StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) :: Nil)
+    val schema = StructType(
+      StructField(
+        "col",
+        StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) :: Nil)
     val df = spark.createDataFrame(sparkContext.parallelize(in), schema)
 
-    df.repartition(10, df.col("col")).foreachPartition(
-      (rowIterator: Iterator[Row]) => {
+    df.repartition(10, df.col("col"))
+      .foreachPartition((rowIterator: Iterator[Row]) => {
         val partitionData = rowIterator.map(r => r.getString(0)).toArray
         partitionData.foreach(s => {
          // assert that both lower and upper case of the string are present in the same partition.
          assert(partitionData.contains(s.toLowerCase()))
          assert(partitionData.contains(s.toUpperCase()))
        })
-    })
+      })
   }
 
   test("hash based joins not allowed for non-binary collated strings") {
     val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(e => Row.apply(e, e))
 
-    val schema = StructType(StructField(
-      "col_non_binary",
-      StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
-      StructField("col_binary", StringType) :: Nil)
+    val schema = StructType(
+      StructField(
+        "col_non_binary",
+        StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
+        StructField("col_binary", StringType) :: Nil)
     val df1 = spark.createDataFrame(sparkContext.parallelize(in), schema)
 
     // Binary collations are allowed to use hash join.
-    assert(collectFirst(
-      df1.hint("broadcast").join(df1, df1("col_binary") === df1("col_binary"))
-        .queryExecution.executedPlan) {
-      case _: BroadcastHashJoinExec => ()
-    }.nonEmpty)
+    assert(
+      collectFirst(
+        df1
+          .hint("broadcast")
+          .join(df1, df1("col_binary") === df1("col_binary"))
+          .queryExecution
+          .executedPlan) { case _: BroadcastHashJoinExec =>
+        ()
+      }.nonEmpty)
 
     // Even with hint broadcast, hash join is not used for non-binary collated strings.
-    assert(collectFirst(
-      df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
-        .queryExecution.executedPlan) {
-      case _: BroadcastHashJoinExec => ()
-    }.isEmpty)
+    assert(
+      collectFirst(
+        df1
+          .hint("broadcast")
+          .join(df1, df1("col_non_binary") === df1("col_non_binary"))
+          .queryExecution
+          .executedPlan) { case _: BroadcastHashJoinExec =>
+        ()
+      }.isEmpty)
 
     // Instead they will default to sort merge join.
-    assert(collectFirst(
-      df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
-        .queryExecution.executedPlan) {
-      case _: SortMergeJoinExec => ()
-    }.nonEmpty)
+    assert(
+      collectFirst(
+        df1
+          .hint("broadcast")
+          .join(df1, df1("col_non_binary") === df1("col_non_binary"))
+          .queryExecution
+          .executedPlan) { case _: SortMergeJoinExec =>
+        ()
+      }.nonEmpty)
   }
 
   test("Generated column expressions using collations - errors out") {
     checkError(
       exception = intercept[AnalysisException] {
-        sql(
-          s"""
+        sql(s"""
            |CREATE TABLE testcat.test_table(
            |  c1 STRING COLLATE UNICODE,
            |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (SUBSTRING(c1, 0, 1))
@@ -629,8 +642,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
 
     checkError(
       exception = intercept[AnalysisException] {
-        sql(
-          s"""
+        sql(s"""
            |CREATE TABLE testcat.test_table(
            |  c1 STRING COLLATE UNICODE,
            |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (c1 || 'a' COLLATE UNICODE)
@@ -646,8 +658,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
 
     checkError(
      exception = intercept[AnalysisException] {
-        sql(
-          s"""
+        sql(s"""
            |CREATE TABLE testcat.test_table(
            |  struct1 STRUCT,
            |  c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (SUBSTRING(struct1.a, 0, 1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala
index 1815398703182..e75d0d3616861 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala
@@ -16,17 +16,15 @@
  */
 package org.apache.spark.sql.execution.benchmark
 
-// scalastyle:off
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.util.CollationFactory
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
- * Benchmark to measure performance for joins.
- * To run this benchmark:
+ * Benchmark to measure performance for joins. To run this benchmark:
  * {{{
  *   1. without sbt:
  *      bin/spark-submit --class
@@ -47,54 +45,77 @@ object CollationBenchmark extends SqlBasedBenchmark {
   }
 
   def benchmarkFilter(collationTypes: Seq[String], utf8Strings: Seq[UTF8String]): Unit = {
-    val benchmark = collationTypes.foldLeft(new Benchmark(s"filter collation types",
-      utf8Strings.size, output = output)) { (b, collationType) =>
-      b.addCase(s"filter - $collationType") { _ =>
-        val collation = CollationFactory.fetchCollation(collationType)
-        utf8Strings.filter(s => collation.equalsFunction(s, UTF8String.fromString("500")).booleanValue())
-      }
-      b
+    val benchmark = collationTypes.foldLeft(
+      new Benchmark(s"filter collation types", utf8Strings.size, output = output)) {
+      (b, collationType) =>
+        b.addCase(s"filter - $collationType") { _ =>
+          val collation = CollationFactory.fetchCollation(collationType)
+          utf8Strings.filter(s =>
+            collation.equalsFunction(s, UTF8String.fromString("500")).booleanValue())
+        }
+        b
     }
     benchmark.run()
   }
 
   def benchmarkHashFunction(collationTypes: Seq[String], utf8Strings: Seq[UTF8String]): Unit = {
-    val benchmark = collationTypes.foldLeft(new Benchmark(s"hashFunction collation types",
-      utf8Strings.size, output = output)) { (b, collationType) =>
-      b.addCase(s"hashFunction - $collationType") { _ =>
-        val collation = CollationFactory.fetchCollation(collationType)
-        utf8Strings.map(s => collation.hashFunction.applyAsLong(s))
-      }
-      b
+    val benchmark = collationTypes.foldLeft(
+      new Benchmark(s"hashFunction collation types", utf8Strings.size, output = output)) {
+      (b, collationType) =>
+        b.addCase(s"hashFunction - $collationType") { _ =>
+          val collation = CollationFactory.fetchCollation(collationType)
+          utf8Strings.map(s => collation.hashFunction.applyAsLong(s))
+        }
+        b
     }
    benchmark.run()
   }
 
-  def collationBenchmarkFilterEqual(collationTypes: Seq[String], utf8Strings: Seq[UTF8String]): Unit = {
+  def collationBenchmarkFilterEqual(
+      collationTypes: Seq[String],
+      utf8Strings: Seq[UTF8String]): Unit = {
     val N = 5 << 20
-    val benchmark = collationTypes.foldLeft(new Benchmark(s"filter df column with collation",
-      utf8Strings.size, output = output)) { (b, collationType) =>
-      b.addCase(s"filter df column with collation - $collationType") { _ =>
-        val df = spark.range(N).withColumn("id_s", expr("cast(id as string)"))
-          .selectExpr((Seq("id_s") ++ collationTypes.map(t => s"collate(id_s, '$collationType') as k_$t")): _*)
-          .withColumn("k_lower", expr("lower(id_s)"))
-          .withColumn("k_upper", expr("upper(id_s)"))
-        import scala.jdk.CollectionConverters._
-        val schema = StructType(Seq(StructField("my_strings", StringType(CollationFactory.collationNameToId(collationType)))))
-        val odf = spark.createDataFrame(List(
-          Row(utf8Strings.head), Row("BBB"), Row("CCC"), Row("DDD"), Row("EEE"), Row("FFF"), Row("ggG"), Row("hhhhhhh"), Row("III"), Row("JJJ")
-        ).asJava, schema)
-        val dff = df.crossJoin(odf)
+    val benchmark = collationTypes.foldLeft(
+      new Benchmark(s"filter df column with collation", utf8Strings.size, output = output)) {
+      (b, collationType) =>
+        b.addCase(s"filter df column with collation - $collationType") { _ =>
+          val df = spark
+            .range(N)
+            .withColumn("id_s", expr("cast(id as string)"))
+            .selectExpr((Seq("id_s") ++ collationTypes.map(t =>
+              s"collate(id_s, '$collationType') as k_$t")): _*)
+            .withColumn("k_lower", expr("lower(id_s)"))
+            .withColumn("k_upper", expr("upper(id_s)"))
+          import scala.jdk.CollectionConverters._
+          val schema = StructType(
+            Seq(
+              StructField(
+                "my_strings",
+                StringType(CollationFactory.collationNameToId(collationType)))))
+          val odf = spark.createDataFrame(
+            List(
+              Row(utf8Strings.head),
+              Row("BBB"),
+              Row("CCC"),
+              Row("DDD"),
+              Row("EEE"),
+              Row("FFF"),
+              Row("ggG"),
+              Row("hhhhhhh"),
+              Row("III"),
+              Row("JJJ")).asJava,
+            schema)
+          val dff = df.crossJoin(odf)
 
-        val dfff = dff.where(col(s"k_$collationType") === col("my_strings") ) //col(s"k_$collationType") === expr(s"collate('AAA', '$collationType')") ||
-        dfff.queryExecution.executedPlan.executeCollect()
-      }
-      b
+          val dfff = dff.where(
+            col(s"k_$collationType") === col("my_strings")
+          ) // col(s"k_$collationType") === expr(s"collate('AAA', '$collationType')") ||
+          dfff.queryExecution.executedPlan.executeCollect()
+        }
+        b
     }
     benchmark.run()
   }
 
-  def codeGenNonCollationBenchmarkFilterILike(collationType: String, utf8Strings: Seq[UTF8String]): Unit = {
-  }
 
   override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
     val utf8Strings = generateUTF8Strings(1000) // Adjust the size as needed
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
index db8bdc361c7a5..a6bda01fa2500 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala
@@ -31,8 +31,7 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType}
 
 /**
- * Benchmark to measure read performance with Filter pushdown.
- * To run this benchmark:
+ * Benchmark to measure read performance with Filter pushdown. To run this benchmark:
  * {{{
  *   1. without sbt: bin/spark-submit --class
  *        --jars ,
@@ -50,7 +49,8 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
       .set("spark.master", "local[1]")
       .setIfMissing("spark.driver.memory", "3g")
      .setIfMissing("spark.executor.memory", "3g")
-      .setIfMissing("spark.sql.parquet.compression.codec",
+      .setIfMissing(
+        "spark.sql.parquet.compression.codec",
         ParquetCompressionCodec.SNAPPY.lowerCaseName())
 
     SparkSession.builder().config(conf).getOrCreate()
@@ -63,11 +63,15 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
   private val blockSize = org.apache.parquet.hadoop.ParquetWriter.DEFAULT_PAGE_SIZE
 
   def withTempTable(tableNames: String*)(f: => Unit): Unit = {
-    try f finally tableNames.foreach(spark.catalog.dropTempView)
+    try f
+    finally tableNames.foreach(spark.catalog.dropTempView)
   }
 
   private def prepareTable(
-      dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = {
+      dir: File,
+      numRows: Int,
+      width: Int,
+      useStringForValue: Boolean): Unit = {
     import spark.implicits._
     val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i")
     val valueCol = if (useStringForValue) {
@@ -75,7 +79,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     } else {
       monotonically_increasing_id()
     }
-    val df = spark.range(numRows).map(_ => Random.nextLong()).selectExpr(selectExpr: _*)
+    val df = spark
+      .range(numRows)
+      .map(_ => Random.nextLong())
+      .selectExpr(selectExpr: _*)
       .withColumn("value", valueCol)
       .sort("value")
 
@@ -83,7 +90,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
   }
 
   private def prepareStringDictTable(
-      dir: File, numRows: Int, numDistinctValues: Int, width: Int): Unit = {
+      dir: File,
+      numRows: Int,
+      numDistinctValues: Int,
+      width: Int): Unit = {
     val selectExpr = (0 to width).map {
       case 0 => s"CAST(id % $numDistinctValues AS STRING) AS value"
      case i => s"CAST(rand() AS STRING) c$i"
@@ -97,14 +107,18 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     val orcPath = dir.getCanonicalPath + "/orc"
     val parquetPath = dir.getCanonicalPath + "/parquet"
 
-    df.write.mode("overwrite")
+    df.write
+      .mode("overwrite")
       .option("orc.dictionary.key.threshold", if (useDictionary) 1.0 else 0.8)
       .option("orc.compress.size", blockSize)
-      .option("orc.stripe.size", blockSize).orc(orcPath)
+      .option("orc.stripe.size", blockSize)
+      .orc(orcPath)
     spark.read.orc(orcPath).createOrReplaceTempView("orcTable")
 
-    df.write.mode("overwrite")
-      .option("parquet.block.size", blockSize).parquet(parquetPath)
+    df.write
+      .mode("overwrite")
+      .option("parquet.block.size", blockSize)
+      .parquet(parquetPath)
     spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable")
   }
 
@@ -146,8 +160,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         s"value = $mid",
         s"value <=> $mid",
         s"$mid <= value AND value <= $mid",
-        s"${mid - 1} < value AND value < ${mid + 1}"
-      ).foreach { whereExpr =>
+        s"${mid - 1} < value AND value < ${mid + 1}").foreach { whereExpr =>
         val title = s"Select 1 int row ($whereExpr)".replace("value AND value", "value")
         filterPushDownBenchmark(numRows, title, whereExpr)
       }
@@ -159,32 +172,29 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         numRows,
         s"Select $percent% int rows (value < ${numRows * percent / 100})",
         s"value < ${numRows * percent / 100}",
-        selectExpr
-      )
+        selectExpr)
     }
 
     Seq("value IS NOT NULL", "value > -1", "value != -1").foreach { whereExpr =>
-      filterPushDownBenchmark(
-        numRows,
-        s"Select all int rows ($whereExpr)",
-        whereExpr,
-        selectExpr)
+      filterPushDownBenchmark(numRows, s"Select all int rows ($whereExpr)", whereExpr, selectExpr)
     }
   }
 
   private def runStringBenchmark(
-      numRows: Int, width: Int, searchValue: Int, colType: String): Unit = {
+      numRows: Int,
+      width: Int,
+      searchValue: Int,
+      colType: String): Unit = {
     Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'")
-      .foreach { whereExpr =>
-        val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value")
-        filterPushDownBenchmark(numRows, title, whereExpr)
-      }
+      .foreach { whereExpr =>
+        val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value")
+        filterPushDownBenchmark(numRows, title, whereExpr)
+      }
 
     Seq(
       s"value = '$searchValue'",
       s"value <=> '$searchValue'",
-      s"'$searchValue' <= value AND value <= '$searchValue'"
-    ).foreach { whereExpr =>
+      s"'$searchValue' <= value AND value <= '$searchValue'").foreach { whereExpr =>
       val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value")
       filterPushDownBenchmark(numRows, title, whereExpr)
     }
@@ -200,8 +210,11 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     }
   }
 
-  private def runStringBenchmarkCollation(numRows: Int, width: Int,
-      searchValue: Int, colType: String): Unit = {
+  private def runStringBenchmarkCollation(
+      numRows: Int,
+      width: Int,
+      searchValue: Int,
+      colType: String): Unit = {
     Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'")
       .foreach { whereExpr =>
         val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value")
@@ -211,8 +224,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     Seq(
       s"value = '$searchValue'",
       s"value <=> '$searchValue'",
-      s"'$searchValue' <= value AND value <= '$searchValue'"
-    ).foreach { whereExpr =>
+      s"'$searchValue' <= value AND value <= '$searchValue'").foreach { whereExpr =>
       val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value")
       filterPushDownBenchmark(numRows, title, whereExpr)
     }
@@ -263,10 +275,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         Seq(
           "value like '10%'",
           "value like '1000%'",
-          s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'"
-        ).foreach { whereExpr =>
-          val title = s"StringStartsWith filter: ($whereExpr)"
-          filterPushDownBenchmark(numRows, title, whereExpr)
+          s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'").foreach {
+          whereExpr =>
+            val title = s"StringStartsWith filter: ($whereExpr)"
+            filterPushDownBenchmark(numRows, title, whereExpr)
         }
       }
     }
@@ -279,10 +291,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         Seq(
           "value like '%10'",
           "value like '%1000'",
-          s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}'"
-        ).foreach { whereExpr =>
-          val title = s"StringEndsWith filter: ($whereExpr)"
-          filterPushDownBenchmark(numRows, title, whereExpr)
+          s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}'").foreach {
+          whereExpr =>
+            val title = s"StringEndsWith filter: ($whereExpr)"
+            filterPushDownBenchmark(numRows, title, whereExpr)
         }
       }
     }
@@ -295,10 +307,10 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
         Seq(
           "value like '%10%'",
           "value like '%1000%'",
-          s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}%'"
-        ).foreach { whereExpr =>
-          val title = s"StringContains filter: ($whereExpr)"
-          filterPushDownBenchmark(numRows, title, whereExpr)
+          s"value like '%${mid.toString.substring(0, mid.toString.length - 1)}%'").foreach {
+          whereExpr =>
+            val title = s"StringContains filter: ($whereExpr)"
+            filterPushDownBenchmark(numRows, title, whereExpr)
         }
       }
     }
@@ -309,16 +321,17 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     Seq(
       s"decimal(${Decimal.MAX_INT_DIGITS}, 2)",
       s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)",
-      s"decimal(${DecimalType.MAX_PRECISION}, 2)"
-    ).foreach { dt =>
+      s"decimal(${DecimalType.MAX_PRECISION}, 2)").foreach { dt =>
       val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
       val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) {
         monotonically_increasing_id() % 9999999
       } else {
         monotonically_increasing_id()
       }
-      val df = spark.range(numRows)
-        .selectExpr(columns: _*).withColumn("value", valueCol.cast(dt))
+      val df = spark
+        .range(numRows)
+        .selectExpr(columns: _*)
+        .withColumn("value", valueCol.cast(dt))
       withTempTable("orcTable", "parquetTable") {
         saveAsTable(df, dir)
@@ -333,8 +346,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
             numRows,
             s"Select $percent% $dt rows (value < ${numRows * percent / 100})",
            s"value < ${numRows * percent / 100}",
-            selectExpr
-          )
+            selectExpr)
         }
       }
     }
@@ -350,7 +362,8 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
           val filter =
             Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100))
           val whereExpr = s"value in(${filter.mkString(",")})"
-          val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)"
+          val title =
+            s"InSet -> InFilters (values count: $count, distribution: $distribution)"
           filterPushDownBenchmark(numRows, title, whereExpr)
         }
       }
@@ -361,7 +374,9 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
     runBenchmark(s"Pushdown benchmark for ${ByteType.simpleString}") {
       withTempPath { dir =>
         val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
-        val df = spark.range(numRows).selectExpr(columns: _*)
+        val df = spark
+          .range(numRows)
+          .selectExpr(columns: _*)
           .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType))
           .orderBy("value")
         withTempTable("orcTable", "parquetTable") {
@@ -381,8 +396,7 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
               s"Select $percent% ${ByteType.simpleString} rows " +
                 s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))",
              s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})",
-              selectExpr
-            )
+              selectExpr)
           }
         }
       }
@@ -394,7 +408,9 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
       ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType =>
         withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) {
           val columns = (1 to width).map(i => s"CAST(id AS string) c$i")
-          val df = spark.range(numRows).selectExpr(columns: _*)
+          val df = spark
+            .range(numRows)
+            .selectExpr(columns: _*)
             .withColumn("value", timestamp_seconds(monotonically_increasing_id()))
           withTempTable("orcTable", "parquetTable") {
             saveAsTable(df, dir)
@@ -406,15 +422,15 @@ object FilterPushdownBenchmark extends SqlBasedBenchmark {
             }
 
             val selectExpr = (1 to width)
-              .map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)")
+              .map(i => s"MAX(c$i)")
+              .mkString("", ",", ", MAX(value)")
 
             Seq(10, 50, 90).foreach { percent =>
               filterPushDownBenchmark(
                 numRows,
                 s"Select $percent% timestamp stored as $fileType rows " +
                   s"(value < timestamp_seconds(${numRows * percent / 100}))",
                 s"value < timestamp_seconds(${numRows * percent / 100})",
-                selectExpr
-              )
+                selectExpr)
             }
           }
         }
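
The patch above is largely a formatting pass over existing tests and benchmarks. As a purely illustrative sketch (not part of the patch), the snippet below shows the CollationFactory entry points those files exercise — collationNameToId, fetchCollation, and equalsFunction, used the same way as in the diff. The object name CollationSketch and the printed message are invented for the example, and exact behavior depends on the Spark build.

import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.unsafe.types.UTF8String

object CollationSketch {
  def main(args: Array[String]): Unit = {
    // Resolve a collation id by name, as CollationSuite does when asserting result data types.
    val collationId = CollationFactory.collationNameToId("UNICODE_CI")
    // Fetch the collation and compare two strings with its equality function,
    // mirroring the filter case in CollationBenchmark.
    val collation = CollationFactory.fetchCollation("UNICODE_CI")
    val eq = collation.equalsFunction(UTF8String.fromString("aaa"), UTF8String.fromString("AAA"))
    // A case-insensitive collation is expected to treat "aaa" and "AAA" as equal.
    println(s"collationId=$collationId, 'aaa' equals 'AAA' under UNICODE_CI: ${eq.booleanValue()}")
  }
}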