From 71d22e9c9331ee159edb697f4146a05e6699b137 Mon Sep 17 00:00:00 2001
From: Yuyang Wang
Date: Thu, 26 Sep 2024 15:48:47 -0700
Subject: [PATCH] fix test

---
 .../spark/snowflake/ParquetSuite.scala        | 46 +++++++++++++++++++
 .../spark/snowflake/VariantTypeSuite.scala    |  4 +-
 .../snowflake/SnowflakeJDBCWrapper.scala      |  5 +-
 .../spark/snowflake/SnowflakeWriter.scala     | 16 +++++--
 .../spark/snowflake/io/StageWriter.scala      |  7 ++-
 5 files changed, 67 insertions(+), 11 deletions(-)

diff --git a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
index 066964ed..8dbad515 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
@@ -12,6 +12,7 @@ import scala.util.Random
 class ParquetSuite extends IntegrationSuiteBase {
   val test_parquet_table: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
   val test_parquet_column_map: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
+  val test_special_character: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
 
   override def afterAll(): Unit = {
     runSql(s"drop table if exists $test_parquet_table")
@@ -380,6 +381,51 @@ class ParquetSuite extends IntegrationSuiteBase {
     assert(newDf.schema.fieldNames.contains("\"timestamp.()col\""))
   }
 
+  test("test parquet with special character to existing table"){
+    jdbcUpdate(
+      s"""create or replace table $test_special_character
+         |("timestamp1.()col" timestamp, "date1.()col" date)""".stripMargin
+    )
+
+    val data: RDD[Row] = sc.makeRDD(
+      List(
+        Row(
+          Timestamp.valueOf("0001-12-30 10:15:30"),
+          Date.valueOf("0001-03-01")
+        )
+      )
+    )
+
+    val schema = StructType(List(
+      StructField("\"timestamp1.()col\"", TimestampType, true),
+      StructField("date1.()col", DateType, true)
+    ))
+
+    val df = sparkSession.createDataFrame(data, schema)
+    df.write
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
+      .option("dbtable", test_special_character)
+      .mode(SaveMode.Append)
+      .save()
+
+    val newDf = sparkSession.read
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option("dbtable", test_special_character)
+      .load()
+    newDf.show()
+
+    checkAnswer(newDf, List(
+      Row(
+        Timestamp.valueOf("0001-12-30 10:15:30"),
+        Date.valueOf("0001-03-01")
+      )
+    ))
+    assert(newDf.schema.fieldNames.contains("\"timestamp1.()col\""))
+  }
+
   test("Test columnMap with parquet") {
     jdbcUpdate(
       s"create or replace table $test_parquet_column_map (ONE int, TWO int, THREE int, Four int)"
diff --git a/src/it/scala/net/snowflake/spark/snowflake/VariantTypeSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/VariantTypeSuite.scala
index 3ca5af58..f2405992 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/VariantTypeSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/VariantTypeSuite.scala
@@ -255,8 +255,8 @@ class VariantTypeSuite extends IntegrationSuiteBase {
     val result = out.collect()
 
     assert(result.length == 3)
-    val bin = result(0).get(0).asInstanceOf[Array[Byte]]
-    assert(new String(bin).equals("binary1"))
+    val bin = new String(result(0).get(0).asInstanceOf[Array[Byte]])
+    assert(bin.equals("binary1"))
     assert(result(0).getList[Int](1).get(0) == 1)
     assert(result(1).getList[Int](1).get(1) == 5)
     assert(result(2).getList[Int](1).get(2) == 9)
diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
index 1c0412f0..2a7f02ae 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
@@ -413,10 +413,7 @@ private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper {
                     temporary: Boolean,
                     bindVariableEnabled: Boolean = true): Unit = {
     val columnNames = snowflakeStyleSchema(stagingTableSchema, params).fields
-      .map(field => {
-        val name: String = field.name
-        s"""$name"""
-      })
+      .map(_.name)
       .mkString(",")
 
     (ConstantString("create") + (if (overwrite) "or replace" else "") +
diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
index c1b92c5c..57abf8aa 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -92,7 +92,6 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
           val toSchema = Utils.removeQuote(
             jdbcWrapper.resolveTable(conn, params.table.get.name, params)
           )
-          params.setSnowflakeTableSchema(toSchema)
           params.columnMap match {
             case Some(map) =>
               map.values.foreach{
@@ -107,9 +106,18 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
       } finally conn.close()
     }
 
+    if (saveMode != SaveMode.Overwrite){
+      val conn = jdbcWrapper.getConnector(params)
+      try{
+        val toSchema = jdbcWrapper.resolveTable(conn, params.table.get.name, params)
+        params.setSnowflakeTableSchema(toSchema)
+      } finally conn.close()
+    }
+
+
     val output: DataFrame = removeUselessColumns(data, params)
-    val strRDDAndSchema = dataFrameToRDD(sqlContext, output, params, format)
-    io.writeRDD(sqlContext, params, strRDDAndSchema._1, strRDDAndSchema._2, saveMode, format)
+    val (strRDD, schema) = dataFrameToRDD(sqlContext, output, params, format)
+    io.writeRDD(sqlContext, params, strRDD, schema, saveMode, format)
   }
 
   /**
@@ -229,7 +237,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
         )
       })
       (spark.createDataFrame(newData, newSchema)
-        .toJSON.map(_.toString).rdd.asInstanceOf[RDD[Any]], newSchema)
+        .toJSON.map(_.toString).rdd.asInstanceOf[RDD[Any]], data.schema)
     }
   }
 
diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
index 13503c59..b1e5fe02 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -426,7 +426,7 @@ private[io] object StageWriter {
           overwrite = false, temporary = false)
       } else if (tableExists){
         conn.createTableSelectFrom(
-          targetTable.name,
+          tempTable.name,
           params.toFiltered(params.getSnowflakeTableSchema()),
           table.name,
           params.getSnowflakeTableSchema(),
@@ -434,7 +434,12 @@
           overwrite = true,
           temporary = false
         )
+      } else if (!tableExists){
+        conn.createTable(targetTable.name,
+          params.toFiltered(params.getSnowflakeTableSchema()), params,
+          overwrite = false, temporary = false)
       }
+
     } else {
       // purge tables when overwriting
       if (saveMode == SaveMode.Overwrite && tableExists) {