Fix Parquet Issues (#591)
* fix parameter

* remove timestamp NTZ to be compatible with Spark 3.3-

* refactor cloud operations

* support null value in array

* refactor cloud operations

* move parquet writer

* fix file size

* fix gcs
sfc-gh-bli authored Nov 4, 2024
1 parent 3e45067 commit 452d89b
Showing 8 changed files with 192 additions and 347 deletions.
43 changes: 39 additions & 4 deletions src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
@@ -502,7 +502,7 @@ class ParquetSuite extends IntegrationSuiteBase {
// throw exception because only support SaveMode.Append
assertThrows[UnsupportedOperationException] {
df.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_parquet)
@@ -514,7 +514,7 @@ class ParquetSuite extends IntegrationSuiteBase {
// throw exception because "aaa" is not a column name of DF
assertThrows[IllegalArgumentException] {
df.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_parquet)
@@ -526,7 +526,7 @@ class ParquetSuite extends IntegrationSuiteBase {
// throw exception because "AAA" is not a column name of table in snowflake database
assertThrows[IllegalArgumentException] {
df.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_parquet)
@@ -536,6 +536,41 @@ class ParquetSuite extends IntegrationSuiteBase {
}
}

test("null value in array") {
val data: RDD[Row] = sc.makeRDD(
List(
Row(
Array(null, "one", "two", "three"),
),
Row(
Array("one", null, "two", "three"),
)
)
)

val schema = StructType(List(
StructField("ARRAY_STRING_FIELD",
ArrayType(StringType, containsNull = true), nullable = true)))
val df = sparkSession.createDataFrame(data, schema)
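// write the DataFrame to Snowflake via the parquet write path (use_parquet_in_write = true)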
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_array_map)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.mode(SaveMode.Overwrite)
.save()
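// read the table back with the same schema to check how the null elements were stored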


val res = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_array_map)
.schema(schema)
.load().collect()
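// null elements come back as the string "null" rather than SQL NULL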
assert(res.head.getSeq(0) == Seq("null", "one", "two", "three"))
assert(res(1).getSeq(0) == Seq("one", "null", "two", "three"))
}

test("test error when column map does not match") {
jdbcUpdate(s"create or replace table $test_column_map_not_match (num int, str string)")
// auto map
@@ -547,7 +582,7 @@ class ParquetSuite extends IntegrationSuiteBase {

assertThrows[SQLException]{
df1.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_not_match)
15 changes: 3 additions & 12 deletions src/it/scala/net/snowflake/spark/snowflake/io/StageSuite.scala
@@ -311,16 +311,13 @@ class StageSuite extends IntegrationSuiteBase {
try {
// The credential for the external stage is fake.
val azureExternalStage = ExternalAzureStorage(
param,
containerName = "test_fake_container",
azureAccount = "test_fake_account",
azureEndpoint = "blob.core.windows.net",
azureSAS =
"?sig=test_test_test_test_test_test_test_test_test_test_test_test" +
"_test_test_test_test_test_fak&spr=https&sp=rwdl&sr=c",
param.proxyInfo,
param.maxRetryCount,
param.sfURL,
param.useExponentialBackoff,
param.expectedPartitionCount,
pref = "test_dir",
connection = connection
@@ -367,13 +364,10 @@ class StageSuite extends IntegrationSuiteBase {
try {
// The credential for the external stage is fake.
val s3ExternalStage = ExternalS3Storage(
param,
bucketName = "test_fake_bucket",
awsId = "TEST_TEST_TEST_TEST1",
awsKey = "TEST_TEST_TEST_TEST_TEST_TEST_TEST_TEST2",
param.proxyInfo,
param.maxRetryCount,
param.sfURL,
param.useExponentialBackoff,
param.expectedPartitionCount,
pref = "test_dir",
connection = connection,
@@ -487,13 +481,10 @@ class StageSuite extends IntegrationSuiteBase {
try {
// The credential for the external stage is fake.
val s3ExternalStage = ExternalS3Storage(
param,
bucketName = "test_fake_bucket",
awsId = "TEST_TEST_TEST_TEST1",
awsKey = "TEST_TEST_TEST_TEST_TEST_TEST_TEST_TEST2",
param.proxyInfo,
param.maxRetryCount,
param.sfURL,
param.useExponentialBackoff,
param.expectedPartitionCount,
pref = "test_dir",
connection = connection,
src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -261,7 +261,7 @@ object Parameters {
Set("off", "no", "false", "0", "disabled")

// enable parquet format
val PARAM_USE_PARQUET_IN_WRITE: String = knownParam("use_parquet_in_write ")
val PARAM_USE_PARQUET_IN_WRITE: String = knownParam("use_parquet_in_write")

/**
* Helper method to check if a given string represents some form
51 changes: 2 additions & 49 deletions src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -17,27 +17,15 @@

package net.snowflake.spark.snowflake

import scala.collection.JavaConverters._
import java.sql.{Date, Timestamp}
import net.snowflake.client.jdbc.internal.apache.commons.codec.binary.Base64
import net.snowflake.spark.snowflake.DefaultJDBCWrapper.{snowflakeStyleSchema, snowflakeStyleString}
import net.snowflake.spark.snowflake.Parameters.{MergedParameters, mergeParameters}
import net.snowflake.spark.snowflake.SparkConnectorContext.getClass
import net.snowflake.spark.snowflake.Utils.ensureUnquoted
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import net.snowflake.spark.snowflake.io.SupportedFormat
import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.slf4j.LoggerFactory

import java.nio.ByteBuffer
import java.time.{LocalDate, ZoneId, ZoneOffset}
import java.util.concurrent.TimeUnit
import scala.collection.mutable

/**
* Functions to write data to Snowflake.
@@ -198,42 +186,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
format match {
case SupportedFormat.PARQUET =>
val snowflakeStyleSchema = mapColumn(data.schema, params, snowflakeStyle = true)
val schema = io.ParquetUtils.convertStructToAvro(snowflakeStyleSchema)
(data.rdd.map (row => {
def rowToAvroRecord(row: Row,
schema: Schema,
snowflakeStyleSchema: StructType,
params: MergedParameters): GenericData.Record = {
val record = new GenericData.Record(schema)
row.toSeq.zip(snowflakeStyleSchema.names).foreach {
case (row: Row, name) =>
record.put(name,
rowToAvroRecord(
row,
schema.getField(name).schema().getTypes.get(0),
snowflakeStyleSchema(name).dataType.asInstanceOf[StructType],
params
))
case (map: scala.collection.immutable.Map[Any, Any], name) =>
record.put(name, map.asJava)
case (str: String, name) =>
record.put(name, if (params.trimSpace) str.trim else str)
case (arr: mutable.WrappedArray[Any], name) =>
record.put(name, arr.toArray)
case (decimal: java.math.BigDecimal, name) =>
record.put(name, ByteBuffer.wrap(decimal.unscaledValue().toByteArray))
case (timestamp: java.sql.Timestamp, name) =>
record.put(name, timestamp.toString)
case (date: java.sql.Date, name) =>
record.put(name, date.toString)
case (date: java.time.LocalDateTime, name) =>
record.put(name, date.toEpochSecond(ZoneOffset.UTC))
case (value, name) => record.put(name, value)
}
record
}
rowToAvroRecord(row, schema, snowflakeStyleSchema, params)
}), snowflakeStyleSchema)
(data.rdd.asInstanceOf[RDD[Any]], snowflakeStyleSchema)
case SupportedFormat.CSV =>
val conversionFunction = genConversionFunctions(data.schema, params)
(data.rdd.map(row => {

