Skip to content

Commit

Permalink
[SPARK-48159][SQL] Extending support for collated strings on datetime…
Browse files Browse the repository at this point in the history
… expressions

### What changes were proposed in this pull request?
This PR introduces changes that will allow for collated strings to be passed to various datetime expressions or return value as collated string from those expressions.
Impacted datetime expressions:

- current_timezone
- to_unix_timestamp
- from_unixtime
- next_day
- from_utc_timestamp
- to_utc_timestamp
- to_date
- to_timestamp
- trunc
- date_trunc
- make_timestamp
- date_part
- convert_timezone

### Why are the changes needed?
This PR is part of ongoing effort to support collated strings on SparkSQL.

### Does this PR introduce _any_ user-facing change?
Yes, users will be able to use collated strings for datetime expressions.

### How was this patch tested?
Added corresponding tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46618 from nebojsa-db/SPARK-48159.

Authored-by: Nebojsa Savic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
nebojsa-db authored and cloud-fan committed May 28, 2024
1 parent 731a2cf commit e9a3ed8
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression {
since = "3.1.0")
case class CurrentTimeZone() extends LeafExpression with Unevaluable {
override def nullable: Boolean = false
override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType
override def prettyName: String = "current_timezone"
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
}
Expand Down Expand Up @@ -924,7 +924,7 @@ case class DayName(child: Expression) extends GetDateField {
override val funcName = "getDayName"

override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType
override protected def withNewChildInternal(newChild: Expression): DayName =
copy(child = newChild)
}
Expand Down Expand Up @@ -1262,7 +1262,8 @@ abstract class ToTimestamp
override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringType, DateType, TimestampType, TimestampNTZType), StringType)
Seq(TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType),
StringTypeAnyCollation)

override def dataType: DataType = LongType
override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true
Expand All @@ -1284,7 +1285,7 @@ abstract class ToTimestamp
daysToMicros(t.asInstanceOf[Int], zoneId) / downScaleFactor
case TimestampType | TimestampNTZType =>
t.asInstanceOf[Long] / downScaleFactor
case StringType =>
case _: StringType =>
val fmt = right.eval(input)
if (fmt == null) {
null
Expand Down Expand Up @@ -1327,7 +1328,7 @@ abstract class ToTimestamp
}

left.dataType match {
case StringType => formatterOption.map { fmt =>
case _: StringType => formatterOption.map { fmt =>
val df = classOf[TimestampFormatter].getName
val formatterName = ctx.addReferenceObj("formatter", fmt, df)
nullSafeCodeGen(ctx, ev, (datetimeStr, _) =>
Expand Down Expand Up @@ -1430,10 +1431,10 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
this(unix, Literal(TimestampFormatter.defaultPattern()))
}

override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType)
override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringTypeAnyCollation)

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
Expand Down Expand Up @@ -1541,7 +1542,7 @@ case class NextDay(

def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation)

override def dataType: DataType = DateType
override def nullable: Boolean = true
Expand Down Expand Up @@ -1752,7 +1753,7 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes w
val func: (Long, String) => Long
val funcName: String

override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation)
override def dataType: DataType = TimestampType

override def nullSafeEval(time: Any, timezone: Any): Any = {
Expand Down Expand Up @@ -2092,8 +2093,8 @@ case class ParseToDate(
override def inputTypes: Seq[AbstractDataType] = {
// Note: ideally this function should only take string input, but we allow more types here to
// be backward compatible.
TypeCollection(StringType, DateType, TimestampType, TimestampNTZType) +:
format.map(_ => StringType).toSeq
TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) +:
format.map(_ => StringTypeAnyCollation).toSeq
}

override protected def withNewChildrenInternal(
Expand Down Expand Up @@ -2164,10 +2165,10 @@ case class ParseToTimestamp(
override def inputTypes: Seq[AbstractDataType] = {
// Note: ideally this function should only take string input, but we allow more types here to
// be backward compatible.
val types = Seq(StringType, DateType, TimestampType, TimestampNTZType)
val types = Seq(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType)
TypeCollection(
(if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _*
) +: format.map(_ => StringType).toSeq
) +: format.map(_ => StringTypeAnyCollation).toSeq
}

override protected def withNewChildrenInternal(
Expand Down Expand Up @@ -2297,7 +2298,7 @@ case class TruncDate(date: Expression, format: Expression)
override def left: Expression = date
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation)
override def dataType: DataType = DateType
override def prettyName: String = "trunc"
override val instant = date
Expand Down Expand Up @@ -2366,7 +2367,7 @@ case class TruncTimestamp(
override def left: Expression = format
override def right: Expression = timestamp

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType)
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, TimestampType)
override def dataType: TimestampType = TimestampType
override def prettyName: String = "date_trunc"
override val instant = timestamp
Expand Down Expand Up @@ -2667,7 +2668,7 @@ case class MakeTimestamp(
// casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0).
override def inputTypes: Seq[AbstractDataType] =
Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++
timezone.map(_ => StringType)
timezone.map(_ => StringTypeAnyCollation)
override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
Expand Down Expand Up @@ -2939,7 +2940,7 @@ case class Extract(field: Expression, source: Expression, replacement: Expressio
object Extract {
def createExpr(funcName: String, field: Expression, source: Expression): Expression = {
// both string and null literals are allowed.
if ((field.dataType == StringType || field.dataType == NullType) && field.foldable) {
if ((field.dataType.isInstanceOf[StringType] || field.dataType == NullType) && field.foldable) {
val fieldStr = field.eval().asInstanceOf[UTF8String]
if (fieldStr == null) {
Literal(null, DoubleType)
Expand Down Expand Up @@ -3114,7 +3115,8 @@ case class ConvertTimezone(
override def second: Expression = targetTz
override def third: Expression = sourceTs

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, TimestampNTZType)
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation,
StringTypeAnyCollation, TimestampNTZType)
override def dataType: DataType = TimestampNTZType

override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = {
Expand Down
Loading

0 comments on commit e9a3ed8

Please sign in to comment.