
Commit

fix test
sfc-gh-yuwang committed Sep 26, 2024
1 parent 3db1b37 commit 71d22e9
Showing 5 changed files with 67 additions and 11 deletions.
46 changes: 46 additions & 0 deletions src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
@@ -12,6 +12,7 @@ import scala.util.Random
class ParquetSuite extends IntegrationSuiteBase {
val test_parquet_table: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_parquet_column_map: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_special_character: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString

override def afterAll(): Unit = {
runSql(s"drop table if exists $test_parquet_table")
@@ -380,6 +381,51 @@ class ParquetSuite extends IntegrationSuiteBase {
assert(newDf.schema.fieldNames.contains("\"timestamp.()col\""))
}

test("test parquet with special character to existing table"){
jdbcUpdate(
s"""create or replace table $test_special_character
|("timestamp1.()col" timestamp, "date1.()col" date)""".stripMargin
)

val data: RDD[Row] = sc.makeRDD(
List(
Row(
Timestamp.valueOf("0001-12-30 10:15:30"),
Date.valueOf("0001-03-01")
)
)
)

val schema = StructType(List(
StructField("\"timestamp1.()col\"", TimestampType, true),
StructField("date1.()col", DateType, true)
))

val df = sparkSession.createDataFrame(data, schema)
df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_special_character)
.mode(SaveMode.Append)
.save()

val newDf = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_special_character)
.load()
newDf.show()

checkAnswer(newDf, List(
Row(
Timestamp.valueOf("0001-12-30 10:15:30"),
Date.valueOf("0001-03-01")
)
))
assert(newDf.schema.fieldNames.contains("\"timestamp1.()col\""))
}

test("Test columnMap with parquet") {
jdbcUpdate(
s"create or replace table $test_parquet_column_map (ONE int, TWO int, THREE int, Four int)"
@@ -255,8 +255,8 @@ class VariantTypeSuite extends IntegrationSuiteBase {
val result = out.collect()
assert(result.length == 3)

val bin = result(0).get(0).asInstanceOf[Array[Byte]]
assert(new String(bin).equals("binary1"))
val bin = new String(result(0).get(0).asInstanceOf[Array[Byte]])
assert(bin.equals("binary1"))
assert(result(0).getList[Int](1).get(0) == 1)
assert(result(1).getList[Int](1).get(1) == 5)
assert(result(2).getList[Int](1).get(2) == 9)
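
Illustrative sketch (not part of the commit): the assertions above decode a BINARY column with new String(...) and read array columns through Row.getList. A standalone version of the same access pattern on a locally built Row, assuming UTF-8 content:

import org.apache.spark.sql.Row

// BINARY values come back as Array[Byte]; decode once, then compare.
val row = Row("binary1".getBytes("UTF-8"), Seq(1, 2, 3))
val bin = new String(row.getAs[Array[Byte]](0), "UTF-8")
assert(bin == "binary1")
// Array-typed columns can be read as a java.util.List via Row.getList.
assert(row.getList[Int](1).get(0) == 1)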
@@ -413,10 +413,7 @@ private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper {
temporary: Boolean,
bindVariableEnabled: Boolean = true): Unit = {
val columnNames = snowflakeStyleSchema(stagingTableSchema, params).fields
.map(field => {
val name: String = field.name
s"""$name"""
})
.map(_.name)
.mkString(",")
(ConstantString("create") +
(if (overwrite) "or replace" else "") +
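
Illustrative sketch (not part of the commit): the column list used in the CREATE statement is simply the schema's field names joined with commas. A standalone example with Spark's StructType, reusing the quoted identifiers from the Parquet test above:

import org.apache.spark.sql.types.{DateType, StructField, StructType, TimestampType}

val schema = StructType(Seq(
  StructField("\"timestamp1.()col\"", TimestampType),
  StructField("\"date1.()col\"", DateType)
))
// Yields: "timestamp1.()col","date1.()col"
val columnNames = schema.fields.map(_.name).mkString(",")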
16 changes: 12 additions & 4 deletions src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -92,7 +92,6 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
val toSchema = Utils.removeQuote(
jdbcWrapper.resolveTable(conn, params.table.get.name, params)
)
params.setSnowflakeTableSchema(toSchema)
params.columnMap match {
case Some(map) =>
map.values.foreach{
@@ -107,9 +106,18 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
} finally conn.close()
}

if (saveMode != SaveMode.Overwrite){
val conn = jdbcWrapper.getConnector(params)
try{
val toSchema = jdbcWrapper.resolveTable(conn, params.table.get.name, params)
params.setSnowflakeTableSchema(toSchema)
} finally conn.close()
}


val output: DataFrame = removeUselessColumns(data, params)
val strRDDAndSchema = dataFrameToRDD(sqlContext, output, params, format)
io.writeRDD(sqlContext, params, strRDDAndSchema._1, strRDDAndSchema._2, saveMode, format)
val (strRDD, schema) = dataFrameToRDD(sqlContext, output, params, format)
io.writeRDD(sqlContext, params, strRDD, schema, saveMode, format)
}

/**
@@ -229,7 +237,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
)
})
(spark.createDataFrame(newData, newSchema)
.toJSON.map(_.toString).rdd.asInstanceOf[RDD[Any]], newSchema)
.toJSON.map(_.toString).rdd.asInstanceOf[RDD[Any]], data.schema)
}
}

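
Illustrative sketch (not part of the commit): the block added in SnowflakeWriter opens a connection, resolves the table schema, and closes the connection in a finally clause so it is released even if resolution throws. A minimal generic form of that open/try/finally pattern — withConnection and the JDBC URL are illustrative, not connector APIs:

import java.sql.{Connection, DriverManager}

// Run `body` against a fresh connection and always close it afterwards.
def withConnection[T](url: String)(body: Connection => T): T = {
  val conn = DriverManager.getConnection(url)
  try body(conn)
  finally conn.close()
}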
@@ -426,15 +426,20 @@ private[io] object StageWriter {
overwrite = false, temporary = false)
} else if (tableExists){
conn.createTableSelectFrom(
targetTable.name,
tempTable.name,
params.toFiltered(params.getSnowflakeTableSchema()),
table.name,
params.getSnowflakeTableSchema(),
params,
overwrite = true,
temporary = false
)
} else if (!tableExists){
conn.createTable(targetTable.name,
params.toFiltered(params.getSnowflakeTableSchema()), params,
overwrite = false, temporary = false)
}

} else {
// purge tables when overwriting
if (saveMode == SaveMode.Overwrite && tableExists) {
