From e9a3ed857954c21ca639f07f2621ac8cebc30ad3 Mon Sep 17 00:00:00 2001 From: Nebojsa Savic Date: Tue, 28 May 2024 10:01:42 -0700 Subject: [PATCH] [SPARK-48159][SQL] Extending support for collated strings on datetime expressions ### What changes were proposed in this pull request? This PR introduces changes that will allow for collated strings to be passed to various datetime expressions or return value as collated string from those expressions. Impacted datetime expressions: - current_timezone - to_unix_timestamp - from_unixtime - next_day - from_utc_timestamp - to_utc_timestamp - to_date - to_timestamp - trunc - date_trunc - make_timestamp - date_part - convert_timezone ### Why are the changes needed? This PR is part of ongoing effort to support collated strings on SparkSQL. ### Does this PR introduce _any_ user-facing change? Yes, users will be able to use collated strings for datetime expressions. ### How was this patch tested? Added corresponding tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46618 from nebojsa-db/SPARK-48159. Authored-by: Nebojsa Savic Signed-off-by: Wenchen Fan --- .../expressions/datetimeExpressions.scala | 38 +-- .../sql/CollationSQLExpressionsSuite.scala | 234 ++++++++++++++++++ 2 files changed, 254 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 8caf8c5d48c2b..808ad54f8ecad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -105,7 +105,7 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression { since = "3.1.0") case class CurrentTimeZone() extends LeafExpression with Unevaluable { override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "current_timezone" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) } @@ -924,7 +924,7 @@ case class DayName(child: Expression) extends GetDateField { override val funcName = "getDayName" override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override protected def withNewChildInternal(newChild: Expression): DayName = copy(child = newChild) } @@ -1262,7 +1262,8 @@ abstract class ToTimestamp override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringType, DateType, TimestampType, TimestampNTZType), StringType) + Seq(TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType), + StringTypeAnyCollation) override def dataType: DataType = LongType override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true @@ -1284,7 +1285,7 @@ abstract class ToTimestamp daysToMicros(t.asInstanceOf[Int], zoneId) / downScaleFactor case TimestampType | TimestampNTZType => t.asInstanceOf[Long] / downScaleFactor - case StringType => + case _: StringType => val fmt = right.eval(input) if (fmt == null) { null @@ -1327,7 +1328,7 @@ abstract class ToTimestamp } left.dataType match { - case StringType => formatterOption.map { fmt => + case _: StringType => formatterOption.map { fmt => val df = classOf[TimestampFormatter].getName val formatterName = ctx.addReferenceObj("formatter", fmt, df) nullSafeCodeGen(ctx, ev, (datetimeStr, _) => @@ -1430,10 +1431,10 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ this(unix, Literal(TimestampFormatter.defaultPattern())) } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringTypeAnyCollation) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1541,7 +1542,7 @@ case class NextDay( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) override def dataType: DataType = DateType override def nullable: Boolean = true @@ -1752,7 +1753,7 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes w val func: (Long, String) => Long val funcName: String - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation) override def dataType: DataType = TimestampType override def nullSafeEval(time: Any, timezone: Any): Any = { @@ -2092,8 +2093,8 @@ case class ParseToDate( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. - TypeCollection(StringType, DateType, TimestampType, TimestampNTZType) +: - format.map(_ => StringType).toSeq + TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) +: + format.map(_ => StringTypeAnyCollation).toSeq } override protected def withNewChildrenInternal( @@ -2164,10 +2165,10 @@ case class ParseToTimestamp( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. - val types = Seq(StringType, DateType, TimestampType, TimestampNTZType) + val types = Seq(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) TypeCollection( (if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _* - ) +: format.map(_ => StringType).toSeq + ) +: format.map(_ => StringTypeAnyCollation).toSeq } override protected def withNewChildrenInternal( @@ -2297,7 +2298,7 @@ case class TruncDate(date: Expression, format: Expression) override def left: Expression = date override def right: Expression = format - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) override def dataType: DataType = DateType override def prettyName: String = "trunc" override val instant = date @@ -2366,7 +2367,7 @@ case class TruncTimestamp( override def left: Expression = format override def right: Expression = timestamp - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, TimestampType) override def dataType: TimestampType = TimestampType override def prettyName: String = "date_trunc" override val instant = timestamp @@ -2667,7 +2668,7 @@ case class MakeTimestamp( // casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0). override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++ - timezone.map(_ => StringType) + timezone.map(_ => StringTypeAnyCollation) override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = @@ -2939,7 +2940,7 @@ case class Extract(field: Expression, source: Expression, replacement: Expressio object Extract { def createExpr(funcName: String, field: Expression, source: Expression): Expression = { // both string and null literals are allowed. - if ((field.dataType == StringType || field.dataType == NullType) && field.foldable) { + if ((field.dataType.isInstanceOf[StringType] || field.dataType == NullType) && field.foldable) { val fieldStr = field.eval().asInstanceOf[UTF8String] if (fieldStr == null) { Literal(null, DoubleType) @@ -3114,7 +3115,8 @@ case class ConvertTimezone( override def second: Expression = targetTz override def third: Expression = sourceTs - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, TimestampNTZType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, + StringTypeAnyCollation, TimestampNTZType) override def dataType: DataType = TimestampNTZType override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index f3d07ba47b715..525ef02f949a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import scala.collection.immutable.Seq @@ -31,6 +32,8 @@ class CollationSQLExpressionsSuite extends QueryTest with SharedSparkSession { + private val testSuppCollations = Seq("UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI") + test("Support Md5 hash expression with collation") { case class Md5TestCase( input: String, @@ -1632,6 +1635,237 @@ class CollationSQLExpressionsSuite } } + test("CurrentTimeZone expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = "select current_timezone()" + // Data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = StringType(collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("DayName expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = "select dayname(current_date())" + // Data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = StringType(collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("ToUnixTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_unix_timestamp(collate('2021-01-01 00:00:00', '${collationName}'), + |collate('yyyy-MM-dd HH:mm:ss', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = LongType + val expectedResult = 1609488000L + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + }) + } + + test("FromUnixTime expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select from_unixtime(1609488000, collate('yyyy-MM-dd HH:mm:ss', '${collationName}')) + |""".stripMargin + // Result & data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = StringType(collationName) + val expectedResult = "2021-01-01 00:00:00" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + } + }) + } + + test("NextDay expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select next_day('2015-01-14', collate('TU', '${collationName}')) + |""".stripMargin + // Result & data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2015-01-20" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + } + }) + } + + test("FromUTCTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select from_utc_timestamp(collate('2016-08-31', '${collationName}'), + |collate('Asia/Seoul', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2016-08-31 09:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("ToUTCTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_utc_timestamp(collate('2016-08-31 09:00:00', '${collationName}'), + |collate('Asia/Seoul', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2016-08-31 00:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("ParseToDate expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_date(collate('2016-12-31', '${collationName}'), + |collate('yyyy-MM-dd', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2016-12-31" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + }) + } + + test("ParseToTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_timestamp(collate('2016-12-31 23:59:59', '${collationName}'), + |collate('yyyy-MM-dd HH:mm:ss', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2016-12-31 23:59:59.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("TruncDate expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select trunc(collate('2016-12-31 23:59:59', '${collationName}'), 'MM') + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2016-12-01" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + }) + } + + test("TruncTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select date_trunc(collate('HOUR', '${collationName}'), + |collate('2015-03-05T09:32:05.359', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2015-03-05 09:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("MakeTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select make_timestamp(2014, 12, 28, 6, 30, 45.887, collate('CET', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2014-12-27 21:30:45.887" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("DatePart expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select date_part(collate('Week', '${collationName}'), + |collate('2019-08-12 01:00:00.123456', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = IntegerType + val expectedResult = 33 + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + }) + } + + test("ConvertTimezone expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select date_format(convert_timezone(collate('America/Los_Angeles', '${collationName}'), + |collate('UTC', '${collationName}'), collate('2021-12-06 00:00:00', '${collationName}')), + |'yyyy-MM-dd HH:mm:ss.S') + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = StringType + val expectedResult = "2021-12-06 08:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + }) + } + // TODO: Add more tests for other SQL expressions }