Add TimestampNTZType support
Resolves #538.
gtopper committed Jan 13, 2025
1 parent 961dd0c commit e3267f0
Showing 8 changed files with 33 additions and 12 deletions.
16 changes: 12 additions & 4 deletions src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
@@ -19,10 +19,9 @@ package net.snowflake.spark.snowflake

import java.sql.Timestamp
import java.text._
import java.time.ZonedDateTime
import java.time.{LocalDateTime, ZonedDateTime}
import java.time.format.DateTimeFormatter
import java.util.{Date, TimeZone}

import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
@@ -126,6 +125,15 @@ private[snowflake] object Conversions {
TimeZone.getDefault.toZoneId))
}

def formatTimestamp(t: LocalDateTime): String = {
// For writing to snowflake, time zone needs to be included
// in the timestamp string. The spark default timezone is used.
timestampWriteFormatter.format(
ZonedDateTime.of(
t,
TimeZone.getDefault.toZoneId))
}

// All strings are converted into double-quoted strings, with
// quote inside converted to double quotes
def formatString(s: String): String = {
@@ -176,7 +184,7 @@ private[snowflake] object Conversions {
case ShortType => data.toShort
case StringType =>
if (isIR) UTF8String.fromString(data) else data
case TimestampType => parseTimestamp(data, isIR)
case TimestampType | TimestampNTZType => parseTimestamp(data, isIR)
case _ => data
}
}
@@ -276,7 +284,7 @@ private[snowflake] object Conversions {
case ShortType => data.shortValue()
case StringType =>
if (isIR) UTF8String.fromString(data.asText()) else data.asText()
case TimestampType => parseTimestamp(data.asText(), isIR)
case TimestampType | TimestampNTZType => parseTimestamp(data.asText(), isIR)
case ArrayType(dt, _) =>
val result = new Array[Any](data.size())
(0 until data.size())
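For context, a minimal standalone sketch of what the new LocalDateTime overload does. The formatter pattern below is a stand-in, not necessarily the connector's actual timestampWriteFormatter:

    import java.time.{LocalDateTime, ZonedDateTime}
    import java.time.format.DateTimeFormatter
    import java.util.TimeZone

    object FormatTimestampSketch {
      // Stand-in pattern; the connector's timestampWriteFormatter may differ.
      private val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS XX")

      // An NTZ value carries no zone, so the Spark default zone is attached
      // before formatting, mirroring the overload added in this hunk.
      def formatTimestamp(t: LocalDateTime): String =
        formatter.format(ZonedDateTime.of(t, TimeZone.getDefault.toZoneId))

      def main(args: Array[String]): Unit =
        println(formatTimestamp(LocalDateTime.of(2015, 7, 1, 0, 0, 0, 1000000)))
    }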
@@ -69,7 +69,7 @@ private[snowflake] object FilterPushdown {
)) !
case DateType =>
StringVariable(Option(value).map(_.asInstanceOf[Date].toString)) + "::DATE"
case TimestampType =>
case TimestampType | TimestampNTZType =>
StringVariable(Option(value).map(_.asInstanceOf[Timestamp].toString)) + "::TIMESTAMP(3)"
case _ =>
value match {
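Illustratively, the pushdown now renders both timestamp flavors the same way. This sketch simplifies the connector's StringVariable wrapper to plain string interpolation:

    import java.sql.Timestamp

    object PushdownSketch {
      // Simplified rendering of the pushed-down literal; the real code goes
      // through the connector's StringVariable DSL, not direct interpolation.
      def timestampLiteral(value: Any): String =
        s"'${value.asInstanceOf[Timestamp]}'::TIMESTAMP(3)"

      def main(args: Array[String]): Unit =
        // prints: '2015-07-01 00:00:00.001'::TIMESTAMP(3)
        println(timestampLiteral(Timestamp.valueOf("2015-07-01 00:00:00.001")))
    }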
@@ -168,6 +168,7 @@ private[snowflake] class JDBCWrapper {
"BINARY"
}
case TimestampType => "TIMESTAMP"
case TimestampNTZType => "TIMESTAMP_NTZ"
case DateType => "DATE"
case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
case _: StructType | _: ArrayType | _: MapType => "VARIANT"
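The mapping this hunk extends can be pictured as a plain function from Spark types to Snowflake DDL type names. Only the branches visible in the diff are reproduced; the final fallback is an assumption, not the connector's:

    import org.apache.spark.sql.types._

    object TypeMappingSketch {
      // Branches as shown in the diff; TIMESTAMP_NTZ is the new arm.
      def snowflakeTypeName(dt: DataType): String = dt match {
        case TimestampType    => "TIMESTAMP"
        case TimestampNTZType => "TIMESTAMP_NTZ" // wall-clock time, no zone
        case DateType         => "DATE"
        case t: DecimalType   => s"DECIMAL(${t.precision},${t.scale})"
        case _: StructType | _: ArrayType | _: MapType => "VARIANT"
        case other => other.sql // assumption: placeholder fallback
      }
    }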
@@ -19,6 +19,7 @@ package net.snowflake.spark.snowflake

import java.sql.{Date, Timestamp}
import net.snowflake.client.jdbc.internal.apache.commons.codec.binary.Base64
import net.snowflake.spark.snowflake.Conversions.timestampWriteFormatter
import net.snowflake.spark.snowflake.DefaultJDBCWrapper.{snowflakeStyleSchema, snowflakeStyleString}
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import net.snowflake.spark.snowflake.io.SupportedFormat
@@ -27,6 +28,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql._

import java.time.{LocalDateTime, ZonedDateTime}
import java.util.TimeZone

/**
* Functions to write data to Snowflake.
*
@@ -285,6 +289,11 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
if (v == null) ""
else Conversions.formatTimestamp(v.asInstanceOf[Timestamp])
}
case TimestampNTZType =>
(v: Any) => {
if (v == null) ""
else Conversions.formatTimestamp(v.asInstanceOf[LocalDateTime])
}
case StringType =>
(v: Any) =>
{
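Put together, the writer's per-column converters for the two timestamp flavors look roughly like this. Conversions is package-private, so the sketch assumes the same package; TimestampType values arrive as java.sql.Timestamp and TimestampNTZType values as java.time.LocalDateTime:

    package net.snowflake.spark.snowflake

    import java.sql.Timestamp
    import java.time.LocalDateTime

    object WriterConverterSketch {
      // Rough shape of the converters this hunk builds: nulls become empty
      // CSV fields; everything else goes through the matching overload.
      val tzConverter: Any => String = {
        case null => ""
        case v    => Conversions.formatTimestamp(v.asInstanceOf[Timestamp])
      }
      val ntzConverter: Any => String = {
        case null => ""
        case v    => Conversions.formatTimestamp(v.asInstanceOf[LocalDateTime])
      }
    }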
@@ -93,7 +93,7 @@ object ParquetUtils {
builder.stringBuilder()
.prop("logicalType", "date")
.endString()
case TimestampType =>
case TimestampType | TimestampNTZType =>
builder.stringBuilder()
.prop("logicalType", " timestamp-micros")
.endString()
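For reference, the Avro schema fragment both timestamp flavors now map to can be built with Avro's SchemaBuilder like so (a sketch, not the connector's surrounding code):

    import org.apache.avro.SchemaBuilder

    object ParquetSchemaSketch {
      // Both TimestampType and TimestampNTZType serialize as an Avro string
      // carrying the timestamp-micros logical-type annotation.
      val timestampSchema: org.apache.avro.Schema =
        SchemaBuilder.builder()
          .stringBuilder()
          .prop("logicalType", "timestamp-micros")
          .endString()

      def main(args: Array[String]): Unit =
        println(timestampSchema.toString(true))
    }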
@@ -236,7 +236,7 @@ case class ResultIterator[T: ClassTag](
case FloatType => data.getFloat(index + 1)
case IntegerType => data.getInt(index + 1)
case LongType => data.getLong(index + 1)
case TimestampType =>
case TimestampType | TimestampNTZType =>
if (isIR) {
DateTimeUtils.fromJavaTimestamp(data.getTimestamp(index + 1))
} else {
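The read path in this hunk, reduced to its essentials: with an InternalRow target (isIR) the JDBC timestamp is converted to Catalyst microseconds, otherwise the java.sql.Timestamp passes through unchanged.

    import java.sql.ResultSet
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    object ReadTimestampSketch {
      // Same branch now serves TIMESTAMP and TIMESTAMP_NTZ columns; index is
      // 0-based on the Spark side, JDBC columns are 1-based, hence index + 1.
      def readTimestamp(data: ResultSet, index: Int, isIR: Boolean): Any =
        if (isIR) DateTimeUtils.fromJavaTimestamp(data.getTimestamp(index + 1))
        else data.getTimestamp(index + 1)
    }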
@@ -25,7 +25,7 @@ import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat
import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations
import net.snowflake.spark.snowflake.test.{TestHook, TestHookFlag}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{BinaryType, StructType, TimestampType}
import org.apache.spark.sql.types.{BinaryType, StructType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.{SQLContext, SaveMode}
import org.slf4j.LoggerFactory

@@ -925,7 +925,7 @@ private[io] object StageWriter {
val mappingFromString = getMappingFromString(mappingList, fromString)

val hasTimestampColumn: Boolean =
schema.exists(field => field.dataType == TimestampType)
schema.exists(field => field.dataType == TimestampType || field.dataType == TimestampNTZType)

val timestampFormat: String =
if (params.getStringTimestampFormat.isEmpty) {
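The widened check is just an exists over the schema fields; a self-contained equivalent:

    import org.apache.spark.sql.types._

    object TimestampColumnCheckSketch {
      // A timestamp format COPY option is only needed when some column is a
      // timestamp; the check now also counts NTZ columns.
      def hasTimestampColumn(schema: StructType): Boolean =
        schema.exists(f => f.dataType == TimestampType || f.dataType == TimestampNTZType)

      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(StructField("ts", TimestampNTZType)))
        println(hasTimestampColumn(schema)) // true
      }
    }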
@@ -144,6 +144,7 @@ class ConversionsSuite extends FunSuite {
|"short":123,
|"string": "test string",
|"timestamp": "2015-07-01 00:00:00.001",
|"timestamp_ntz": "2015-07-01 00:00:00.001",
|"array":[1,2,3,4,5],
|"map":{"a":1,"b":2,"c":3},
|"structure":{"num":123,"str":"str1"}
@@ -163,6 +164,7 @@
StructField("short", ShortType, nullable = false),
StructField("string", StringType, nullable = false),
StructField("timestamp", TimestampType, nullable = false),
StructField("timestamp_ntz", TimestampNTZType, nullable = false),
StructField("array", ArrayType(IntegerType), nullable = false),
StructField("map", MapType(StringType, IntegerType), nullable = false),
StructField(
@@ -199,9 +201,10 @@ class ConversionsSuite extends FunSuite {
assert(result.getShort(8) == 123.toShort)
assert(result.getString(9) == "test string")
assert(result.getTimestamp(10) == Timestamp.valueOf("2015-07-01 00:00:00.001"))
assert(result.getSeq(11) sameElements Array(1, 2, 3, 4, 5))
assert(result.getMap(12) == Map("b" -> 2, "a" -> 1, "c" -> 3))
assert(result.getStruct(13) == Row(123, "str1"))
assert(result.getTimestamp(11) == Timestamp.valueOf("2015-07-01 00:00:00.001"))
assert(result.getSeq(12) sameElements Array(1, 2, 3, 4, 5))
assert(result.getMap(13) == Map("b" -> 2, "a" -> 1, "c" -> 3))
assert(result.getStruct(14) == Row(123, "str1"))

}
}
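End to end, the commit lets a TimestampNTZType column round-trip to a Snowflake TIMESTAMP_NTZ column. A hedged usage sketch: connection options and the table name are placeholders, and the LocalDateTime-to-NTZ encoder assumes a recent Spark version.

    import java.time.LocalDateTime
    import org.apache.spark.sql.SparkSession

    object NtzRoundTripSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        import spark.implicits._

        // LocalDateTime maps to TimestampNTZType in recent Spark versions.
        val df = Seq(LocalDateTime.parse("2015-07-01T00:00:00.001")).toDF("ts_ntz")

        // Placeholder connection options; fill in real credentials.
        val sfOptions = Map(
          "sfURL"      -> "account.snowflakecomputing.com",
          "sfUser"     -> "user",
          "sfPassword" -> "password",
          "sfDatabase" -> "db",
          "sfSchema"   -> "public"
        )

        df.write
          .format("net.snowflake.spark.snowflake")
          .options(sfOptions)
          .option("dbtable", "T_NTZ") // example table; lands as TIMESTAMP_NTZ
          .mode("overwrite")
          .save()
      }
    }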
