diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 8caf8c5d48c2b..808ad54f8ecad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -105,7 +105,7 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression { since = "3.1.0") case class CurrentTimeZone() extends LeafExpression with Unevaluable { override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "current_timezone" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) } @@ -924,7 +924,7 @@ case class DayName(child: Expression) extends GetDateField { override val funcName = "getDayName" override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override protected def withNewChildInternal(newChild: Expression): DayName = copy(child = newChild) } @@ -1262,7 +1262,8 @@ abstract class ToTimestamp override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringType, DateType, TimestampType, TimestampNTZType), StringType) + Seq(TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType), + StringTypeAnyCollation) override def dataType: DataType = LongType override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true @@ -1284,7 +1285,7 @@ abstract class ToTimestamp daysToMicros(t.asInstanceOf[Int], zoneId) / downScaleFactor case TimestampType | TimestampNTZType => t.asInstanceOf[Long] / downScaleFactor - case StringType => + case _: StringType => val fmt = right.eval(input) if (fmt == null) { null @@ -1327,7 +1328,7 @@ abstract class ToTimestamp } left.dataType match { - case StringType => formatterOption.map { fmt => + case _: StringType => formatterOption.map { fmt => val df = classOf[TimestampFormatter].getName val formatterName = ctx.addReferenceObj("formatter", fmt, df) nullSafeCodeGen(ctx, ev, (datetimeStr, _) => @@ -1430,10 +1431,10 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ this(unix, Literal(TimestampFormatter.defaultPattern())) } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringTypeAnyCollation) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1541,7 +1542,7 @@ case class NextDay( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) override def dataType: DataType = DateType override def nullable: Boolean = true @@ -1752,7 +1753,7 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes w val func: (Long, String) => Long val funcName: String - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation) override def dataType: DataType = TimestampType override def nullSafeEval(time: Any, timezone: Any): Any = { @@ -2092,8 +2093,8 @@ case class ParseToDate( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. - TypeCollection(StringType, DateType, TimestampType, TimestampNTZType) +: - format.map(_ => StringType).toSeq + TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) +: + format.map(_ => StringTypeAnyCollation).toSeq } override protected def withNewChildrenInternal( @@ -2164,10 +2165,10 @@ case class ParseToTimestamp( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. - val types = Seq(StringType, DateType, TimestampType, TimestampNTZType) + val types = Seq(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) TypeCollection( (if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _* - ) +: format.map(_ => StringType).toSeq + ) +: format.map(_ => StringTypeAnyCollation).toSeq } override protected def withNewChildrenInternal( @@ -2297,7 +2298,7 @@ case class TruncDate(date: Expression, format: Expression) override def left: Expression = date override def right: Expression = format - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) override def dataType: DataType = DateType override def prettyName: String = "trunc" override val instant = date @@ -2366,7 +2367,7 @@ case class TruncTimestamp( override def left: Expression = format override def right: Expression = timestamp - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, TimestampType) override def dataType: TimestampType = TimestampType override def prettyName: String = "date_trunc" override val instant = timestamp @@ -2667,7 +2668,7 @@ case class MakeTimestamp( // casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0). override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++ - timezone.map(_ => StringType) + timezone.map(_ => StringTypeAnyCollation) override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = @@ -2939,7 +2940,7 @@ case class Extract(field: Expression, source: Expression, replacement: Expressio object Extract { def createExpr(funcName: String, field: Expression, source: Expression): Expression = { // both string and null literals are allowed. - if ((field.dataType == StringType || field.dataType == NullType) && field.foldable) { + if ((field.dataType.isInstanceOf[StringType] || field.dataType == NullType) && field.foldable) { val fieldStr = field.eval().asInstanceOf[UTF8String] if (fieldStr == null) { Literal(null, DoubleType) @@ -3114,7 +3115,8 @@ case class ConvertTimezone( override def second: Expression = targetTz override def third: Expression = sourceTs - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, TimestampNTZType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, + StringTypeAnyCollation, TimestampNTZType) override def dataType: DataType = TimestampNTZType override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index f3d07ba47b715..525ef02f949a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import scala.collection.immutable.Seq @@ -31,6 +32,8 @@ class CollationSQLExpressionsSuite extends QueryTest with SharedSparkSession { + private val testSuppCollations = Seq("UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI") + test("Support Md5 hash expression with collation") { case class Md5TestCase( input: String, @@ -1632,6 +1635,237 @@ class CollationSQLExpressionsSuite } } + test("CurrentTimeZone expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = "select current_timezone()" + // Data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = StringType(collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("DayName expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = "select dayname(current_date())" + // Data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = StringType(collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("ToUnixTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_unix_timestamp(collate('2021-01-01 00:00:00', '${collationName}'), + |collate('yyyy-MM-dd HH:mm:ss', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = LongType + val expectedResult = 1609488000L + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + }) + } + + test("FromUnixTime expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select from_unixtime(1609488000, collate('yyyy-MM-dd HH:mm:ss', '${collationName}')) + |""".stripMargin + // Result & data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = StringType(collationName) + val expectedResult = "2021-01-01 00:00:00" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + } + }) + } + + test("NextDay expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select next_day('2015-01-14', collate('TU', '${collationName}')) + |""".stripMargin + // Result & data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2015-01-20" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + } + }) + } + + test("FromUTCTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select from_utc_timestamp(collate('2016-08-31', '${collationName}'), + |collate('Asia/Seoul', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2016-08-31 09:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("ToUTCTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_utc_timestamp(collate('2016-08-31 09:00:00', '${collationName}'), + |collate('Asia/Seoul', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2016-08-31 00:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("ParseToDate expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_date(collate('2016-12-31', '${collationName}'), + |collate('yyyy-MM-dd', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2016-12-31" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + }) + } + + test("ParseToTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_timestamp(collate('2016-12-31 23:59:59', '${collationName}'), + |collate('yyyy-MM-dd HH:mm:ss', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2016-12-31 23:59:59.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("TruncDate expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select trunc(collate('2016-12-31 23:59:59', '${collationName}'), 'MM') + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2016-12-01" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + }) + } + + test("TruncTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select date_trunc(collate('HOUR', '${collationName}'), + |collate('2015-03-05T09:32:05.359', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2015-03-05 09:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("MakeTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select make_timestamp(2014, 12, 28, 6, 30, 45.887, collate('CET', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2014-12-27 21:30:45.887" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("DatePart expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select date_part(collate('Week', '${collationName}'), + |collate('2019-08-12 01:00:00.123456', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = IntegerType + val expectedResult = 33 + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + }) + } + + test("ConvertTimezone expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select date_format(convert_timezone(collate('America/Los_Angeles', '${collationName}'), + |collate('UTC', '${collationName}'), collate('2021-12-06 00:00:00', '${collationName}')), + |'yyyy-MM-dd HH:mm:ss.S') + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = StringType + val expectedResult = "2021-12-06 08:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + }) + } + // TODO: Add more tests for other SQL expressions }