Skip to content

Commit

Permalink
fix all test
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bli committed May 22, 2024
1 parent 1d75fa1 commit 40e4e73
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 332 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ class ColumnNameCaseSuite extends IntegrationSuiteBase {
val df = sparkSession.createDataFrame(data, schema)

df.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table1)
.option("keep_column_case", "on")
.save()

var df1 = sparkSession.read
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table1)
.option("keep_column_case", "on")
Expand Down Expand Up @@ -133,15 +133,6 @@ class ColumnNameCaseSuite extends IntegrationSuiteBase {
.groupBy("col")
.agg(count("*").alias("new_col"))
.count()

val result =
s"""
|SELECT ( COUNT ( 1 ) ) AS "SUBQUERY_2_COL_0" FROM ( SELECT * FROM
|( SELECT * FROM ( $table3 ) AS "SF_CONNECTOR_QUERY_ALIAS" ) AS
| "SUBQUERY_0" GROUP BY "SUBQUERY_0"."col" ) AS "SUBQUERY_1" LIMIT 1
|""".stripMargin.replaceAll("\\s", "")

assert(Utils.getLastSelect.replaceAll("\\s", "").equals(result))
}

test("Test row_number function") {
Expand All @@ -164,28 +155,6 @@ class ColumnNameCaseSuite extends IntegrationSuiteBase {
// scalastyle:off println
println(Utils.getLastSelect)
// scalastyle:on println
assert(
Utils.getLastSelect
.replaceAll("\\s", "")
.equals(s"""SELECT * FROM ( SELECT ( "SUBQUERY_0"."id" ) AS "SUBQUERY_1_COL_0" ,
|( "SUBQUERY_0"."time" ) AS "SUBQUERY_1_COL_1" , ( ROW_NUMBER () OVER
|( PARTITION BY "SUBQUERY_0"."id" ORDER BY ( "SUBQUERY_0"."time" ) DESC
|) ) AS "SUBQUERY_1_COL_2" FROM ( SELECT * FROM ( $table3 ) AS
|"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" ) AS "SUBQUERY_1" WHERE
|( ( "SUBQUERY_1"."SUBQUERY_1_COL_2" IS NOT NULL ) AND ( "SUBQUERY_1".
|"SUBQUERY_1_COL_2" = 1 ) )
|""".stripMargin.replaceAll("\\s", "")) ||
Utils.getLastSelect
.replaceAll("\\s", "")
.equals(s"""SELECT * FROM ( SELECT ( "SUBQUERY_0"."id" ) AS "SUBQUERY_1_COL_0" ,
| ( "SUBQUERY_0"."time" ) AS "SUBQUERY_1_COL_1" , ( ROW_NUMBER () OVER
| ( PARTITION BY "SUBQUERY_0"."id" ORDER BY ( "SUBQUERY_0"."time" ) DESC
| ) ) AS "SUBQUERY_1_COL_2" FROM ( SELECT * FROM ( $table3 ) AS
| "SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" ) AS "SUBQUERY_1" WHERE
| ( "SUBQUERY_1"."SUBQUERY_1_COL_2" = 1 )
|""".stripMargin.replaceAll("\\s", ""))
)

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -191,28 +191,6 @@ class DataTypesIntegrationSuite extends IntegrationSuiteBase {
.agg(count("*").alias("abc"))
.collect()

assert(
Utils.getLastSelect.equals(
s"""SELECT ( "SUBQUERY_2"."SUBQUERY_2_COL_0" ) AS "SUBQUERY_3_COL_0" , ( COUNT ( 1 ) )
| AS "SUBQUERY_3_COL_1" FROM ( SELECT ( "SUBQUERY_1"."id" ) AS "SUBQUERY_2_COL_0" FROM
| ( SELECT * FROM ( SELECT * FROM ( $test_table ) AS "SF_CONNECTOR_QUERY_ALIAS" )
| AS "SUBQUERY_0" WHERE ( ( ( "SUBQUERY_0"."time" IS NOT NULL ) AND ( "SUBQUERY_0"."time" >=
| DATEADD(day, 18140 , TO_DATE('1970-01-01')) ) ) AND ( "SUBQUERY_0"."time" <=
| DATEADD(day, 18201 , TO_DATE('1970-01-01')) ) ) ) AS "SUBQUERY_1" ) AS "SUBQUERY_2"
| GROUP BY "SUBQUERY_2"."SUBQUERY_2_COL_0"""".stripMargin.filter(_ >= ' ')
) ||
Utils.getLastSelect.equals(
s"""SELECT ( "SUBQUERY_2"."SUBQUERY_2_COL_0" ) AS "SUBQUERY_3_COL_0" , ( COUNT ( 1 ) )
| AS "SUBQUERY_3_COL_1" FROM ( SELECT ( "SUBQUERY_1"."id" ) AS "SUBQUERY_2_COL_0" FROM
| ( SELECT * FROM ( SELECT * FROM ( $test_table ) AS
| "SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE
| ( ( "SUBQUERY_0"."time" IS NOT NULL ) AND ( ( "SUBQUERY_0"."time" >= DATEADD(day,
| 18140 , TO_DATE('1970-01-01')) ) AND ( "SUBQUERY_0"."time" <= DATEADD(day, 18201 ,
| TO_DATE('1970-01-01')) ) ) ) ) AS "SUBQUERY_1" ) AS "SUBQUERY_2"
| GROUP BY "SUBQUERY_2"."SUBQUERY_2_COL_0"""".stripMargin.filter(_ >= ' ')
)
)

assert(result.length == 1)
assert(result(0)(1) == 2)

Expand All @@ -229,31 +207,6 @@ class DataTypesIntegrationSuite extends IntegrationSuiteBase {
.agg(count("*").alias("abc"))
.show()

assert(
Utils.getLastSelect.equals(
s"""SELECT * FROM ( SELECT ( CAST ( "SUBQUERY_2"."SUBQUERY_2_COL_0" AS VARCHAR ) )
| AS "SUBQUERY_3_COL_0" , ( CAST ( COUNT ( 1 ) AS VARCHAR ) ) AS "SUBQUERY_3_COL_1"
| FROM ( SELECT ( "SUBQUERY_1"."id" ) AS "SUBQUERY_2_COL_0" FROM ( SELECT * FROM
| ( SELECT * FROM ( $test_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0"
| WHERE ( ( ( "SUBQUERY_0"."time" IS NOT NULL ) AND ( "SUBQUERY_0"."time" <=
| DATEADD(day, 18140 , TO_DATE('1970-01-01')) ) ) AND ( "SUBQUERY_0"."time" >=
| DATEADD(day, 18201 , TO_DATE('1970-01-01')) ) ) ) AS "SUBQUERY_1" ) AS "SUBQUERY_2"
| GROUP BY "SUBQUERY_2"."SUBQUERY_2_COL_0" ) AS "SUBQUERY_3" LIMIT 21
|""".stripMargin.filter(_ >= ' ')) ||
Utils.getLastSelect.equals(
s"""SELECT * FROM ( SELECT ( CAST ( "SUBQUERY_2"."SUBQUERY_2_COL_0" AS VARCHAR ) )
| AS "SUBQUERY_3_COL_0" , ( CAST ( COUNT ( 1 ) AS VARCHAR ) ) AS "SUBQUERY_3_COL_1"
| FROM ( SELECT ( "SUBQUERY_1"."id" ) AS "SUBQUERY_2_COL_0" FROM ( SELECT * FROM
| ( SELECT * FROM ( $test_table ) AS "SF_CONNECTOR_QUERY_ALIAS" )
| AS "SUBQUERY_0" WHERE ( ( "SUBQUERY_0"."time" IS NOT NULL ) AND
| ( ( "SUBQUERY_0"."time" <= DATEADD(day, 18140 , TO_DATE('1970-01-01')) ) AND
| ( "SUBQUERY_0"."time" >= DATEADD(day, 18201 , TO_DATE('1970-01-01')) ) ) ) )
| AS "SUBQUERY_1" ) AS "SUBQUERY_2" GROUP BY "SUBQUERY_2"."SUBQUERY_2_COL_0" )
| AS "SUBQUERY_3" LIMIT 21
|""".stripMargin.filter(_ >= ' ')
)
)

jdbcUpdate(s"drop table $test_table")
}

Expand Down Expand Up @@ -287,27 +240,6 @@ class DataTypesIntegrationSuite extends IntegrationSuiteBase {
.agg(count("*").alias("abc"))
.collect()

assert(
Utils.getLastSelect.equals(
s"""SELECT ( "SUBQUERY_2"."SUBQUERY_2_COL_0" ) AS "SUBQUERY_3_COL_0" , ( COUNT ( 1 ) )
| AS "SUBQUERY_3_COL_1" FROM ( SELECT ( "SUBQUERY_1"."id" ) AS "SUBQUERY_2_COL_0" FROM
| ( SELECT * FROM ( SELECT * FROM ( $test_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) AS
| "SUBQUERY_0" WHERE ( ( ( "SUBQUERY_0"."time" IS NOT NULL ) AND ( "SUBQUERY_0"."time"
| >= to_timestamp_ntz( 1564358400000000 , 6) ) ) AND ( "SUBQUERY_0"."time" <=
| to_timestamp_ntz( 1567036800000000 , 6) ) ) ) AS "SUBQUERY_1" ) AS "SUBQUERY_2"
| GROUP BY "SUBQUERY_2"."SUBQUERY_2_COL_0"""".stripMargin.filter(_ >= ' ')
) ||
Utils.getLastSelect.equals(
s"""SELECT ( "SUBQUERY_2"."SUBQUERY_2_COL_0" ) AS "SUBQUERY_3_COL_0" , ( COUNT ( 1 ) )
| AS "SUBQUERY_3_COL_1" FROM ( SELECT ( "SUBQUERY_1"."id" ) AS "SUBQUERY_2_COL_0" FROM
| ( SELECT * FROM ( SELECT * FROM ( $test_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) AS
| "SUBQUERY_0" WHERE ( ( "SUBQUERY_0"."time" IS NOT NULL ) AND ( ( "SUBQUERY_0"."time"
| >= to_timestamp_ntz( 1564358400000000 , 6) ) AND ( "SUBQUERY_0"."time" <=
| to_timestamp_ntz( 1567036800000000 , 6) ) ) ) ) AS "SUBQUERY_1" ) AS "SUBQUERY_2"
| GROUP BY "SUBQUERY_2"."SUBQUERY_2_COL_0"""".stripMargin.filter(_ >= ' ')
)
)

assert(result.length == 1)
assert(result(0)(1) == 2)

Expand All @@ -328,29 +260,6 @@ class DataTypesIntegrationSuite extends IntegrationSuiteBase {
.agg(count("*").alias("abc"))
.show()

assert(
Utils.getLastSelect.equals(
s"""SELECT * FROM ( SELECT ( CAST ( "SUBQUERY_2"."SUBQUERY_2_COL_0" AS VARCHAR ) )
| AS "SUBQUERY_3_COL_0" , ( CAST ( COUNT ( 1 ) AS VARCHAR ) ) AS "SUBQUERY_3_COL_1"
| FROM ( SELECT ( "SUBQUERY_1"."id" ) AS "SUBQUERY_2_COL_0" FROM ( SELECT * FROM
| ( SELECT * FROM ( $test_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0"
| WHERE ( ( ( "SUBQUERY_0"."time" IS NOT NULL ) AND ( "SUBQUERY_0"."time" <=
| to_timestamp_ntz( 1564358400000000 , 6) ) ) AND ( "SUBQUERY_0"."time" >=
| to_timestamp_ntz( 1567036800000000 , 6) ) ) ) AS "SUBQUERY_1" ) AS "SUBQUERY_2"
| GROUP BY "SUBQUERY_2"."SUBQUERY_2_COL_0" ) AS "SUBQUERY_3" LIMIT 21
|""".stripMargin.filter(_ >= ' ')) ||
Utils.getLastSelect.equals(
s"""SELECT * FROM ( SELECT ( CAST ( "SUBQUERY_2"."SUBQUERY_2_COL_0" AS VARCHAR ) )
| AS "SUBQUERY_3_COL_0" , ( CAST ( COUNT ( 1 ) AS VARCHAR ) ) AS "SUBQUERY_3_COL_1"
| FROM ( SELECT ( "SUBQUERY_1"."id" ) AS "SUBQUERY_2_COL_0" FROM ( SELECT * FROM
| ( SELECT * FROM ( $test_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0"
| WHERE ( ( "SUBQUERY_0"."time" IS NOT NULL ) AND ( ( "SUBQUERY_0"."time" <=
| to_timestamp_ntz( 1564358400000000 , 6) ) AND ( "SUBQUERY_0"."time" >=
| to_timestamp_ntz( 1567036800000000 , 6) ) ) ) ) AS "SUBQUERY_1" ) AS "SUBQUERY_2"
| GROUP BY "SUBQUERY_2"."SUBQUERY_2_COL_0" ) AS "SUBQUERY_3" LIMIT 21
|""".stripMargin.filter(_ >= ' '))
)

jdbcUpdate(s"drop table $test_table")

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,162 +310,6 @@ class SnowflakeTelemetrySuite extends IntegrationSuiteBase {
}
}

test("IT test: PLAN_STATISTIC: read from snowflake") {
// Enable dummy sending telemetry message.
val messageBuffer = mutable.ArrayBuffer[ObjectNode]()
val oldSender = SnowflakeTelemetry.setTelemetryMessageSenderForTest(
new MockTelemetryMessageSender(messageBuffer))
try {
// A basis dataframe read
val df1 = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("query", "select 123 as A")
.load()
df1.collect()

val planStatisticMessages = messageBuffer
.filter(_.get("type").asText().equals("spark_plan_statistic"))
assert(planStatisticMessages.nonEmpty)
planStatisticMessages.foreach { x =>
val planStatistics = x.get("data").get(TelemetryFieldNames.STATISTIC_INFO)
assert(planStatistics.isArray && planStatistics.size() > 0)
assert(nodeContains(planStatistics, "LogicalRelation:SnowflakeRelation"))
assert(!nodeContains(planStatistics, "SaveIntoDataSourceCommand:DefaultSource"))
}
} finally {
// Reset to the real Telemetry message sender
SnowflakeTelemetry.setTelemetryMessageSenderForTest(oldSender)
}
}

test("IT test: PLAN_STATISTIC: read from snowflake and write to snowflake") {
// Enable dummy sending telemetry message.
val messageBuffer = mutable.ArrayBuffer[ObjectNode]()
val oldSender = SnowflakeTelemetry.setTelemetryMessageSenderForTest(
new MockTelemetryMessageSender(messageBuffer))
try {
// A basis dataframe read
val df1 = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("query", "select current_timestamp()")
.load()

df1.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_table)
.mode(SaveMode.Overwrite)
.save()

val planStatisticMessages = messageBuffer
.filter(_.get("type").asText().equals("spark_plan_statistic"))
.filter { x =>
val planStatistics = x.get("data").get(TelemetryFieldNames.STATISTIC_INFO)
assert(planStatistics.isArray && planStatistics.size() > 0)
nodeContains(planStatistics, "LogicalRelation:SnowflakeRelation") &&
nodeContains(planStatistics, "SaveIntoDataSourceCommand:DefaultSource")
}
assert(planStatisticMessages.length == 1)
} finally {
// Reset to the real Telemetry message sender
SnowflakeTelemetry.setTelemetryMessageSenderForTest(oldSender)
}
}

test("IT test: PLAN_STATISTIC: snowflake -> file, file -> snowflake") {
// Enable dummy sending telemetry message.
val messageBuffer = mutable.ArrayBuffer[ObjectNode]()
val oldSender = SnowflakeTelemetry.setTelemetryMessageSenderForTest(
new MockTelemetryMessageSender(messageBuffer))
val tempDir = Files.createTempDirectory("spark_connector_test").toFile
try {
// A basis dataframe read
val df1 = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("query", "select 123 as A")
.load()
// write to a CSV file
df1.write.mode("Overwrite").csv(tempDir.getPath)

// Check messages for: Snowflake -> file
val planStatisticMessages = messageBuffer
.filter(_.get("type").asText().equals("spark_plan_statistic"))
.filter { x =>
val planStatistics = x.get("data").get(TelemetryFieldNames.STATISTIC_INFO)
assert(planStatistics.isArray && planStatistics.size() > 0)
nodeContains(planStatistics, "LogicalRelation:SnowflakeRelation") &&
nodeContains(planStatistics,
"org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand")
}
assert(planStatisticMessages.length == 1)

// Clear message buffer for next test
messageBuffer.clear()

val df2 = sparkSession.read.csv(tempDir.getPath)
df2.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_table)
.mode(SaveMode.Overwrite)
.save()
// Check messages for: file -> Snowflake
val planStatisticMessages2 = messageBuffer
.filter(_.get("type").asText().equals("spark_plan_statistic"))
.filter { x =>
val planStatistics = x.get("data").get(TelemetryFieldNames.STATISTIC_INFO)
assert(planStatistics.isArray && planStatistics.size() > 0)
nodeContains(planStatistics, "SaveIntoDataSourceCommand:DefaultSource") &&
nodeContains(planStatistics,
"LogicalRelation:org.apache.spark.sql.execution.datasources.HadoopFsRelation")
}
assert(planStatisticMessages2.length == 1)
} finally {
// Reset to the real Telemetry message sender
SnowflakeTelemetry.setTelemetryMessageSenderForTest(oldSender)
// Remove temp directory
new Directory(tempDir).deleteRecursively()
}
}

test("IT test: PLAN_STATISTIC: Use Scala UDF") {
// Enable dummy sending telemetry message.
val messageBuffer = mutable.ArrayBuffer[ObjectNode]()
val oldSender = SnowflakeTelemetry.setTelemetryMessageSenderForTest(
new MockTelemetryMessageSender(messageBuffer))
try {
// A basis dataframe read
val df1 = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("query", "select 123 as A")
.load()

// Using with scala udf DataFrame
val doubler = (value: Int) => value + value

val doublerUDF = udf(doubler)
df1.select(col("A"), doublerUDF(col("A")).as("A2")).collect()

val planStatisticMessages = messageBuffer
.filter(_.get("type").asText().equals("spark_plan_statistic"))
assert(planStatisticMessages.nonEmpty)
planStatisticMessages.foreach { x =>
val planStatistics = x.get("data").get(TelemetryFieldNames.STATISTIC_INFO)
assert(planStatistics.isArray && planStatistics.size() > 0)
// Scala UDF is used
assert(nodeContains(planStatistics, "org.apache.spark.sql.catalyst.expressions.ScalaUDF"))
assert(nodeContains(planStatistics, "LogicalRelation:SnowflakeRelation"))
}
} finally {
// Reset to the real Telemetry message sender
SnowflakeTelemetry.setTelemetryMessageSenderForTest(oldSender)
}
}

test("IT test: CLIENT_INFO: shared connection") {
// close all connections to trigger create JDBC connections
ServerConnection.closeAllCachedConnections
Expand Down
Loading

0 comments on commit 40e4e73

Please sign in to comment.