Add TimestampNTZType support #599

Status: Open · wants to merge 1 commit into master
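This PR maps Spark's TimestampNTZType onto Snowflake's TIMESTAMP_NTZ across the write path, read path, filter pushdown, DDL generation, and the staged Parquet/Avro schema. Below is a minimal end-to-end sketch of how the new type would be exercised; the connection options and table name are placeholders, and it assumes a Spark version whose implicits encode java.time.LocalDateTime as TimestampNTZType (3.3 or later):

```scala
import java.time.LocalDateTime
import org.apache.spark.sql.SparkSession

object TimestampNtzRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Placeholder connection parameters; fill in your own account details.
    val sfOptions = Map(
      "sfURL" -> "account.snowflakecomputing.com",
      "sfUser" -> "user",
      "sfPassword" -> "password",
      "sfDatabase" -> "db",
      "sfSchema" -> "public")

    // LocalDateTime values are encoded by Spark as TimestampNTZType.
    val df = Seq(LocalDateTime.parse("2015-07-01T00:00:00.001")).toDF("ts_ntz")

    // With this PR, the column is created as TIMESTAMP_NTZ in Snowflake.
    df.write
      .format("net.snowflake.spark.snowflake")
      .options(sfOptions)
      .option("dbtable", "NTZ_TEST")
      .mode("append")
      .save()
  }
}
```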
16 changes: 12 additions & 4 deletions src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
@@ -19,10 +19,9 @@ package net.snowflake.spark.snowflake

import java.sql.Timestamp
import java.text._
- import java.time.ZonedDateTime
+ import java.time.{LocalDateTime, ZonedDateTime}
import java.time.format.DateTimeFormatter
import java.util.{Date, TimeZone}

import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
@@ -126,6 +125,15 @@ private[snowflake] object Conversions {
TimeZone.getDefault.toZoneId))
}

+ def formatTimestamp(t: LocalDateTime): String = {
+   // For writing to Snowflake, a time zone needs to be included
+   // in the timestamp string. The Spark default time zone is used.
+   timestampWriteFormatter.format(
+     ZonedDateTime.of(
+       t,
+       TimeZone.getDefault.toZoneId))
+ }

// All strings are converted into double-quoted strings, with
// quote inside converted to double quotes
def formatString(s: String): String = {
@@ -176,7 +184,7 @@ private[snowflake] object Conversions {
case ShortType => data.toShort
case StringType =>
if (isIR) UTF8String.fromString(data) else data
- case TimestampType => parseTimestamp(data, isIR)
+ case TimestampType | TimestampNTZType => parseTimestamp(data, isIR)
case _ => data
}
}
@@ -276,7 +284,7 @@ private[snowflake] object Conversions {
case ShortType => data.shortValue()
case StringType =>
if (isIR) UTF8String.fromString(data.asText()) else data.asText()
- case TimestampType => parseTimestamp(data.asText(), isIR)
+ case TimestampType | TimestampNTZType => parseTimestamp(data.asText(), isIR)
case ArrayType(dt, _) =>
val result = new Array[Any](data.size())
(0 until data.size())
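For context, a self-contained sketch of what the new LocalDateTime overload does: it pins the wall-clock value to the JVM default zone before formatting, so the written string carries an explicit offset. The formatter pattern below is an assumption; the real timestampWriteFormatter is defined elsewhere in Conversions.scala and may differ.

```scala
import java.time.{LocalDateTime, ZonedDateTime}
import java.time.format.DateTimeFormatter
import java.util.TimeZone

object FormatNtzSketch {
  // Assumed pattern; the connector's actual timestampWriteFormatter may differ.
  private val writeFormatter =
    DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS xx")

  def main(args: Array[String]): Unit = {
    val ntz = LocalDateTime.parse("2015-07-01T00:00:00.001")
    // Attach the JVM default zone, as the new overload does.
    val zoned = ZonedDateTime.of(ntz, TimeZone.getDefault.toZoneId)
    // e.g. "2015-07-01 00:00:00.001 +0000" when the default zone is UTC
    println(writeFormatter.format(zoned))
  }
}
```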
@@ -69,7 +69,7 @@ private[snowflake] object FilterPushdown {
)) !
case DateType =>
StringVariable(Option(value).map(_.asInstanceOf[Date].toString)) + "::DATE"
- case TimestampType =>
+ case TimestampType | TimestampNTZType =>
StringVariable(Option(value).map(_.asInstanceOf[Timestamp].toString)) + "::TIMESTAMP(3)"
case _ =>
value match {
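With this hunk, a filter on an NTZ column is pushed down with the same ::TIMESTAMP(3) cast already used for TimestampType. A rough illustration of the predicate text produced; the real construction goes through the connector's SnowflakeSQLStatement builder, and the column name here is hypothetical:

```scala
import java.sql.Timestamp

object PushdownSketch {
  def main(args: Array[String]): Unit = {
    val value = Timestamp.valueOf("2015-07-01 00:00:00.001")
    // Hypothetical rendering of a pushed-down range filter on an NTZ column.
    val predicate = s"""("TS_NTZ" >= '$value'::TIMESTAMP(3))"""
    println(predicate) // ("TS_NTZ" >= '2015-07-01 00:00:00.001'::TIMESTAMP(3))
  }
}
```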
@@ -168,6 +168,7 @@ private[snowflake] class JDBCWrapper {
"BINARY"
}
case TimestampType => "TIMESTAMP"
+ case TimestampNTZType => "TIMESTAMP_NTZ"
case DateType => "DATE"
case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
case _: StructType | _: ArrayType | _: MapType => "VARIANT"
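The effect on generated DDL, sketched as a standalone mapping; the function name is hypothetical, and the real match arm lives in the JDBC wrapper's schema-to-SQL conversion:

```scala
import org.apache.spark.sql.types._

object TypeMappingSketch {
  // Hypothetical helper mirroring the extended match arm.
  def snowflakeTypeName(dt: DataType): String = dt match {
    case TimestampType    => "TIMESTAMP"
    case TimestampNTZType => "TIMESTAMP_NTZ"
    case DateType         => "DATE"
    case t: DecimalType   => s"DECIMAL(${t.precision},${t.scale})"
    case _: StructType | _: ArrayType | _: MapType => "VARIANT"
    case _                => "STRING"
  }

  def main(args: Array[String]): Unit = {
    println(snowflakeTypeName(TimestampNTZType)) // TIMESTAMP_NTZ
  }
}
```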
@@ -19,6 +19,7 @@ package net.snowflake.spark.snowflake

import java.sql.{Date, Timestamp}
import net.snowflake.client.jdbc.internal.apache.commons.codec.binary.Base64
+ import net.snowflake.spark.snowflake.Conversions.timestampWriteFormatter
import net.snowflake.spark.snowflake.DefaultJDBCWrapper.{snowflakeStyleSchema, snowflakeStyleString}
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import net.snowflake.spark.snowflake.io.SupportedFormat
@@ -27,6 +28,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql._

+ import java.time.{LocalDateTime, ZonedDateTime}
+ import java.util.TimeZone

/**
* Functions to write data to Snowflake.
*
@@ -285,6 +289,11 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
if (v == null) ""
else Conversions.formatTimestamp(v.asInstanceOf[Timestamp])
}
+ case TimestampNTZType =>
+   (v: Any) => {
+     if (v == null) ""
+     else Conversions.formatTimestamp(v.asInstanceOf[LocalDateTime])
+   }
case StringType =>
(v: Any) =>
{
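The writer builds one Any => String converter per column; the new arm handles the java.time.LocalDateTime values Spark supplies for TimestampNTZType columns, while zoned timestamps keep arriving as java.sql.Timestamp. A condensed, hypothetical mirror of that dispatch (the real arms call Conversions.formatTimestamp):

```scala
import java.sql.Timestamp
import java.time.LocalDateTime
import org.apache.spark.sql.types._

object FieldConverterSketch {
  // Simplified stand-ins for the connector's formatting calls.
  def fieldConverter(dt: DataType): Any => String = dt match {
    case TimestampType =>
      (v: Any) => if (v == null) "" else v.asInstanceOf[Timestamp].toString
    case TimestampNTZType =>
      (v: Any) => if (v == null) "" else v.asInstanceOf[LocalDateTime].toString
    case _ =>
      (v: Any) => if (v == null) "" else v.toString
  }
}
```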
@@ -93,7 +93,7 @@ object ParquetUtils {
builder.stringBuilder()
.prop("logicalType", "date")
.endString()
- case TimestampType =>
+ case TimestampType | TimestampNTZType =>
builder.stringBuilder()
.prop("logicalType", " timestamp-micros")
.endString()
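For reference, NTZ values now go through the same string-typed Avro field with a timestamp-micros logicalType prop as zoned timestamps. A small standalone check of what that builder call emits, assuming Avro's SchemaBuilder on the classpath:

```scala
import org.apache.avro.SchemaBuilder

object AvroLogicalTypeSketch {
  def main(args: Array[String]): Unit = {
    val schema = SchemaBuilder.builder()
      .stringBuilder()
      .prop("logicalType", "timestamp-micros")
      .endString()
    // {"type":"string","logicalType":"timestamp-micros"}
    println(schema.toString)
  }
}
```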
@@ -236,7 +236,7 @@ case class ResultIterator[T: ClassTag](
case FloatType => data.getFloat(index + 1)
case IntegerType => data.getInt(index + 1)
case LongType => data.getLong(index + 1)
- case TimestampType =>
+ case TimestampType | TimestampNTZType =>
if (isIR) {
DateTimeUtils.fromJavaTimestamp(data.getTimestamp(index + 1))
} else {
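On the read side, NTZ columns come back through the same JDBC getTimestamp path; in internal-row mode the value is converted to Spark's microseconds-since-epoch representation. A rough sketch with hypothetical names, assuming a java.sql.ResultSet in place of the connector's data wrapper:

```scala
import java.sql.ResultSet
import org.apache.spark.sql.catalyst.util.DateTimeUtils

object ReadPathSketch {
  // Read column `index` (0-based here, hence the +1) as a timestamp.
  def readTimestamp(rs: ResultSet, index: Int, isIR: Boolean): Any = {
    if (isIR) {
      // Internal rows store timestamps as microseconds since the epoch.
      DateTimeUtils.fromJavaTimestamp(rs.getTimestamp(index + 1))
    } else {
      rs.getTimestamp(index + 1)
    }
  }
}
```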
@@ -25,7 +25,7 @@ import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat
import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations
import net.snowflake.spark.snowflake.test.{TestHook, TestHookFlag}
import org.apache.spark.rdd.RDD
- import org.apache.spark.sql.types.{BinaryType, StructType, TimestampType}
+ import org.apache.spark.sql.types.{BinaryType, StructType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.{SQLContext, SaveMode}
import org.slf4j.LoggerFactory

@@ -925,7 +925,7 @@ private[io] object StageWriter {
val mappingFromString = getMappingFromString(mappingList, fromString)

val hasTimestampColumn: Boolean =
- schema.exists(field => field.dataType == TimestampType)
+ schema.exists(field => field.dataType == TimestampType || field.dataType == TimestampNTZType)

val timestampFormat: String =
if (params.getStringTimestampFormat.isEmpty) {
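This widened gate means NTZ columns now also trigger the stage writer's timestamp-format handling when the user has not set one explicitly. A minimal standalone check of the predicate itself:

```scala
import org.apache.spark.sql.types._

object TimestampFormatGate {
  // Mirror of the widened gate: NTZ columns now count as timestamp columns.
  def hasTimestampColumn(schema: StructType): Boolean =
    schema.exists(f => f.dataType == TimestampType || f.dataType == TimestampNTZType)

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("ts_ntz", TimestampNTZType)))
    println(hasTimestampColumn(schema)) // true
  }
}
```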
@@ -144,6 +144,7 @@ class ConversionsSuite extends FunSuite {
|"short":123,
|"string": "test string",
|"timestamp": "2015-07-01 00:00:00.001",
|"timestamp_ntz": "2015-07-01 00:00:00.001",
|"array":[1,2,3,4,5],
|"map":{"a":1,"b":2,"c":3},
|"structure":{"num":123,"str":"str1"}
@@ -163,6 +164,7 @@
StructField("short", ShortType, nullable = false),
StructField("string", StringType, nullable = false),
StructField("timestamp", TimestampType, nullable = false),
StructField("timestamp_ntz", TimestampNTZType, nullable = false),
StructField("array", ArrayType(IntegerType), nullable = false),
StructField("map", MapType(StringType, IntegerType), nullable = false),
StructField(
@@ -199,9 +201,10 @@
assert(result.getShort(8) == 123.toShort)
assert(result.getString(9) == "test string")
assert(result.getTimestamp(10) == Timestamp.valueOf("2015-07-01 00:00:00.001"))
- assert(result.getSeq(11) sameElements Array(1, 2, 3, 4, 5))
- assert(result.getMap(12) == Map("b" -> 2, "a" -> 1, "c" -> 3))
- assert(result.getStruct(13) == Row(123, "str1"))
+ assert(result.getTimestamp(11) == Timestamp.valueOf("2015-07-01 00:00:00.001"))
+ assert(result.getSeq(12) sameElements Array(1, 2, 3, 4, 5))
+ assert(result.getMap(13) == Map("b" -> 2, "a" -> 1, "c" -> 3))
+ assert(result.getStruct(14) == Row(123, "str1"))

}
}