From 957bd6796ceda54eaaf96fd6a4ce2f8211722eda Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Wed, 14 Feb 2024 14:48:42 -0800
Subject: [PATCH 1/4] add new trim_space parameter

---
 .../spark/snowflake/IssueSuite.scala      | 79 +++++++++++++++++++
 .../spark/snowflake/Parameters.scala      | 11 ++-
 .../spark/snowflake/SnowflakeWriter.scala | 22 ++++--
 .../spark/snowflake/io/StageWriter.scala  |  1 +
 4 files changed, 106 insertions(+), 7 deletions(-)

diff --git a/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
index 1be7e111..b0399332 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
@@ -20,6 +20,85 @@ class IssueSuite extends IntegrationSuiteBase {
     super.beforeEach()
   }
 
+  test("trim space - csv") {
+    val st1 = new StructType(
+      Array(StructField("str", StringType, nullable = false))
+    )
+    val tt: String = s"tt_$randomSuffix"
+    try {
+      sparkSession
+        .createDataFrame(
+          sparkSession.sparkContext.parallelize(
+            Seq(
+              Row("ab c"),
+              Row(" a bc"),
+              Row("abdc  ")
+            )
+          ),
+          st1
+        )
+        .write
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .option(Parameters.PARAM_TRIM_SPACE, "true")
+        .mode(SaveMode.Overwrite)
+        .save()
+
+      val loadDf = sparkSession.read
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .load()
+
+      assert(loadDf.collect().forall(row => row.toSeq.head.toString.length == 4))
+
+    } finally {
+      jdbcUpdate(s"drop table if exists $tt")
+    }
+  }
+
+  test("trim space - json") {
+    val st1 = new StructType(
+      Array(
+        StructField("str", StringType, nullable = false),
+        StructField("arr", ArrayType(IntegerType), nullable = false)
+      )
+    )
+    val tt: String = s"tt_$randomSuffix"
+    try {
+      sparkSession
+        .createDataFrame(
+          sparkSession.sparkContext.parallelize(
+            Seq(
+              Row("ab c", Array(1, 2, 3)),
+              Row(" a bc", Array(2, 2, 3)),
+              Row("abdc  ", Array(3, 2, 3))
+            )
+          ),
+          st1
+        )
+        .write
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .option(Parameters.PARAM_TRIM_SPACE, "true")
+        .mode(SaveMode.Overwrite)
+        .save()
+
+      val loadDf = sparkSession.read
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .load()
+
+      assert(loadDf.select("str").collect().forall(row => row.toSeq.head.toString.length == 4))
+
+    } finally {
+      jdbcUpdate(s"drop table if exists $tt")
+    }
+  }
+
   test("csv delimiter character should not break rows") {
     val st1 = new StructType(
       Array(
diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
index fcf2e19e..3d7f1fd9 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -85,6 +85,7 @@ object Parameters {
   val PARAM_COLUMN_MAP: String = knownParam("columnmap")
   val PARAM_TRUNCATE_COLUMNS: String = knownParam("truncate_columns")
   val PARAM_PURGE: String = knownParam("purge")
+  val PARAM_TRIM_SPACE: String = knownParam("trim_space")
 
   val PARAM_TRUNCATE_TABLE: String = knownParam("truncate_table")
   val PARAM_CONTINUE_ON_ERROR: String = knownParam("continue_on_error")
@@ -288,7 +289,8 @@ object Parameters {
     PARAM_USE_AWS_MULTIPLE_PARTS_UPLOAD -> "true",
     PARAM_TIMESTAMP_NTZ_OUTPUT_FORMAT -> "YYYY-MM-DD HH24:MI:SS.FF3",
     PARAM_TIMESTAMP_LTZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
-    PARAM_TIMESTAMP_TZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3"
+    PARAM_TIMESTAMP_TZ_OUTPUT_FORMAT -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF3",
+    PARAM_TRIM_SPACE -> "false"
   )
 
   /**
@@ -837,6 +839,13 @@ object Parameters {
   def useStagingTable: Boolean =
     isTrue(parameters(PARAM_USE_STAGING_TABLE))
 
+  /**
+    * Boolean that specifies whether to trim leading and trailing white
+    * space from String fields. Defaults to false.
+    */
+  def trimSpace: Boolean =
+    isTrue(parameters(PARAM_TRIM_SPACE))
+
   /**
     * Extra options to append to the Snowflake COPY command (e.g. "MAXERROR 100").
    */
diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
index e35c8bf2..2864858e 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -18,9 +18,8 @@
 package net.snowflake.spark.snowflake
 
 import java.sql.{Date, Timestamp}
-
 import net.snowflake.client.jdbc.internal.apache.commons.codec.binary.Base64
-import net.snowflake.spark.snowflake.Parameters.MergedParameters
+import net.snowflake.spark.snowflake.Parameters.{MergedParameters, mergeParameters}
 import net.snowflake.spark.snowflake.io.SupportedFormat
 import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat
 import org.apache.spark.rdd.RDD
@@ -83,7 +82,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
 
     format match {
       case SupportedFormat.CSV =>
-        val conversionFunction = genConversionFunctions(data.schema)
+        val conversionFunction = genConversionFunctions(data.schema, params)
         data.rdd.map(row => {
           row.toSeq
             .zip(conversionFunction)
@@ -95,7 +94,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
       case SupportedFormat.JSON =>
         // convert binary (Array of Byte) to encoded base64 String before COPY
         val newSchema: StructType = prepareSchemaForJson(data.schema)
-        val conversionsFunction = genConversionFunctionsForJson(data.schema)
+        val conversionsFunction = genConversionFunctionsForJson(data.schema, params)
         val newData: RDD[Row] = data.rdd.map(row => {
           Row.fromSeq(
             row.toSeq
@@ -118,9 +117,15 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
     })
 
-  private def genConversionFunctionsForJson(schema: StructType): Array[Any => Any] =
+  private def genConversionFunctionsForJson(schema: StructType,
+                                            params: MergedParameters): Array[Any => Any] =
     schema.fields.map(field =>
      field.dataType match {
+        case StringType =>
+          (v: Any) =>
+            if (params.trimSpace) {
+              v.toString.trim
+            } else v
         case BinaryType =>
           (v: Any) =>
             v match {
@@ -157,9 +162,14 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
   }
 
   // Prepare a set of conversion functions, based on the schema
-  def genConversionFunctions(schema: StructType): Array[Any => Any] =
+  def genConversionFunctions(schema: StructType, params: MergedParameters): Array[Any => Any] =
     schema.fields.map { field =>
       field.dataType match {
+        case StringType =>
+          (v: Any) =>
+            if (params.trimSpace) {
+              v.toString.trim
+            } else v
         case DateType =>
           (v: Any) =>
             v match {
diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
index fbd19b22..1ebec63c 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -908,6 +908,7 @@ private[io] object StageWriter {
       params.getStringTimestampFormat.get
     }
 
+    val trimSpace: String = if (params.trimSpace) "TRUE" else "FALSE"
 
     val formatString =
       format match {

From 351b077c203531866de39f4ea4238d8cfbb03122 Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Wed, 14 Feb 2024 15:06:41 -0800
Subject: [PATCH 2/4] fix

---
 .../scala/net/snowflake/spark/snowflake/io/StageWriter.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
index 1ebec63c..aa5a4795 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -908,8 +908,6 @@ private[io] object StageWriter {
       params.getStringTimestampFormat.get
     }
 
-    val trimSpace: String = if (params.trimSpace) "TRUE" else "FALSE"
-
     val formatString =
       format match {

From 48eea60d5f852d742ebddbc931a06e47906a3aab Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Thu, 15 Feb 2024 10:48:10 -0800
Subject: [PATCH 3/4] add test

---
 .../spark/snowflake/IssueSuite.scala | 49 ++++++++++++++++---
 1 file changed, 43 insertions(+), 6 deletions(-)

diff --git a/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
index b0399332..97edd56c 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
@@ -26,7 +26,7 @@ class IssueSuite extends IntegrationSuiteBase {
     )
     val tt: String = s"tt_$randomSuffix"
     try {
-      sparkSession
+      val df = sparkSession
         .createDataFrame(
           sparkSession.sparkContext.parallelize(
             Seq(
@@ -37,7 +37,7 @@ class IssueSuite extends IntegrationSuiteBase {
           ),
           st1
         )
-        .write
+      df.write
         .format(SNOWFLAKE_SOURCE_NAME)
         .options(connectorOptions)
         .option("dbtable", tt)
@@ -45,7 +45,7 @@ class IssueSuite extends IntegrationSuiteBase {
         .mode(SaveMode.Overwrite)
         .save()
 
-      val loadDf = sparkSession.read
+      var loadDf = sparkSession.read
         .format(SNOWFLAKE_SOURCE_NAME)
         .options(connectorOptions)
         .option("dbtable", tt)
@@ -53,6 +53,25 @@ class IssueSuite extends IntegrationSuiteBase {
         .load()
 
       assert(loadDf.collect().forall(row => row.toSeq.head.toString.length == 4))
 
+      // disabled by default
+      df.write
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .mode(SaveMode.Overwrite)
+        .save()
+
+      loadDf = sparkSession.read
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .load()
+      val result = loadDf.collect()
+      assert(result.head.toSeq.head.toString.length == 4)
+      assert(result(1).toSeq.head.toString.length == 5)
+      assert(result(2).toSeq.head.toString.length == 6)
+
+
     } finally {
       jdbcUpdate(s"drop table if exists $tt")
     }
@@ -67,7 +86,7 @@ class IssueSuite extends IntegrationSuiteBase {
     )
     val tt: String = s"tt_$randomSuffix"
     try {
-      sparkSession
+      val df = sparkSession
         .createDataFrame(
           sparkSession.sparkContext.parallelize(
             Seq(
@@ -78,7 +97,7 @@ class IssueSuite extends IntegrationSuiteBase {
           ),
           st1
         )
-        .write
+      df.write
         .format(SNOWFLAKE_SOURCE_NAME)
         .options(connectorOptions)
         .option("dbtable", tt)
@@ -86,7 +105,7 @@ class IssueSuite extends IntegrationSuiteBase {
         .mode(SaveMode.Overwrite)
         .save()
 
-      val loadDf = sparkSession.read
+      var loadDf = sparkSession.read
         .format(SNOWFLAKE_SOURCE_NAME)
         .options(connectorOptions)
         .option("dbtable", tt)
@@ -94,6 +113,24 @@ class IssueSuite extends IntegrationSuiteBase {
         .load()
 
       assert(loadDf.select("str").collect().forall(row => row.toSeq.head.toString.length == 4))
 
+      // disabled by default
+      df.write
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .mode(SaveMode.Overwrite)
+        .save()
+
+      loadDf = sparkSession.read
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptions)
+        .option("dbtable", tt)
+        .load()
+      val result = loadDf.select("str").collect()
+      assert(result.head.toSeq.head.toString.length == 4)
+      assert(result(1).toSeq.head.toString.length == 5)
+      assert(result(2).toSeq.head.toString.length == 6)
+
     } finally {
       jdbcUpdate(s"drop table if exists $tt")
     }

From bb5e653f04a72ac6bc997dc011a0d7b1a0867cee Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Thu, 15 Feb 2024 11:12:18 -0800
Subject: [PATCH 4/4] fix test

---
 .../snowflake/spark/snowflake/SnowflakeWriter.scala | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
index 2864858e..52c47472 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -165,11 +165,6 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
   def genConversionFunctions(schema: StructType, params: MergedParameters): Array[Any => Any] =
     schema.fields.map { field =>
       field.dataType match {
-        case StringType =>
-          (v: Any) =>
-            if (params.trimSpace) {
-              v.toString.trim
-            } else v
         case DateType =>
           (v: Any) =>
             v match {
@@ -187,7 +182,12 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
           (v: Any) => {
             if (v == null) ""
-            else Conversions.formatString(v.asInstanceOf[String])
+            else {
+              val trimmed = if (params.trimSpace) {
+                v.toString.trim
+              } else v
+              Conversions.formatString(trimmed.asInstanceOf[String])
+            }
           }
         case BinaryType =>
           (v: Any) =>
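
Note (not part of the patches): a minimal caller-side sketch of the trim_space
option this series introduces. The option is read on the write path, so enabling
it on df.write strips leading and trailing white space from String fields before
the data is staged and copied into Snowflake. The DataFrame df, the
connection-option map sfOptions, and the target table name below are placeholder
names, as in the tests above; "net.snowflake.spark.snowflake" is the connector's
source name (SNOWFLAKE_SOURCE_NAME).

    import org.apache.spark.sql.SaveMode

    // df: any DataFrame with String columns; sfOptions: Map[String, String]
    // of Snowflake connection options (both placeholders).
    df.write
      .format("net.snowflake.spark.snowflake")
      .options(sfOptions)
      .option("dbtable", "target_table")
      // New parameter; defaults to "false" (no trimming).
      .option("trim_space", "true")
      .mode(SaveMode.Overwrite)
      .save()

With the default trim_space = "false", string values round-trip unchanged,
which is what the "disabled by default" assertions added in PATCH 3 verify.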