From 47d1bda0a1a9ecc7618d310de222060f1b7c35cc Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Thu, 1 Feb 2024 16:30:07 -0800
Subject: [PATCH 1/5] skip retry after application closed

---
 .../spark/snowflake/SnowflakeRelation.scala      | 15 ++++++++++++---
 .../spark/snowflake/SparkConnectorContext.scala  |  7 +++++++
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala
index 833ed528..bb7d89fc 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala
@@ -153,10 +153,19 @@ private[snowflake] case class SnowflakeRelation(
   // without first executing it.
   private def getRDD[T: ClassTag](statement: SnowflakeSQLStatement,
                                   resultSchema: StructType): RDD[T] = {
-    if (params.useCopyUnload) {
-      getSnowflakeRDD(statement, resultSchema)
+    val appId = sqlContext.sparkContext.applicationId
+    if (SparkConnectorContext.closedApplicationIDs.contains(appId)) {
+      // Don't execute any Snowflake queries if the Spark application has been closed.
+      // Spark triggers the `onApplicationEnd` listener before it stops tasks, and the
+      // connector cancels all running SQL queries in that listener, so Spark would
+      // otherwise re-run the canceled Snowflake SQL queries in task retries.
+      throw new IllegalStateException(s"Spark Application ($appId) was closed")
     } else {
-      getSnowflakeResultSetRDD(statement, resultSchema)
+      if (params.useCopyUnload) {
+        getSnowflakeRDD(statement, resultSchema)
+      } else {
+        getSnowflakeResultSetRDD(statement, resultSchema)
+      }
     }
   }
 
diff --git a/src/main/scala/net/snowflake/spark/snowflake/SparkConnectorContext.scala b/src/main/scala/net/snowflake/spark/snowflake/SparkConnectorContext.scala
index 79517c4a..25593a3d 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SparkConnectorContext.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SparkConnectorContext.scala
@@ -32,6 +32,9 @@ object SparkConnectorContext {
   // The key is the application ID, the value is the set of running queries.
   private val runningQueries = mutable.Map[String, mutable.Set[RunningQuery]]()
 
+  // Save the IDs of all closed applications so Spark's retries can be skipped after the application is closed.
+  private[snowflake] val closedApplicationIDs = mutable.HashSet.empty[String]
+
   private[snowflake] def getRunningQueries = runningQueries
 
   // Register spark listener to cancel any running queries if application fails.
@@ -44,6 +47,10 @@ object SparkConnectorContext {
         runningQueries.put(appId, mutable.Set.empty)
         sparkContext.addSparkListener(new SparkListener {
           override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
+            // Add the application ID to the block list. If Spark retries tasks
+            // for this closed application, the connector will skip those
+            // queries instead of re-running them.
+            closedApplicationIDs.add(appId)
             try {
               cancelRunningQueries(appId)
               // Close all cached connections

From e9386c73ecbeea06a85d1a4e4b8b11d131b761b9 Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Fri, 2 Feb 2024 12:28:10 -0800
Subject: [PATCH 2/5] add test

---
 .../SparkConnectorContextSuite.scala             | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
index a3094316..54d85f64 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
@@ -277,4 +277,28 @@ class SparkConnectorContextSuite extends IntegrationSuiteBase {
     }
   }
 
+  test("Disable retry after application closed") {
+
+    val newSparkSession = createDefaultSparkSession
+    val sc = newSparkSession.sparkContext
+    val appId = sc.applicationId
+
+    import scala.concurrent.ExecutionContext.Implicits.global
+    Future {
+      newSparkSession.read.format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptionsNoTable)
+        .option("query", "select count(*) from table(generator(timelimit=>100))")
+        .load().show()
+    }
+    Thread.sleep(10000)
+    var queries = SparkConnectorContext.getRunningQueries.get(appId)
+    assert(queries.isDefined)
+    assert(queries.get.size == 1)
+    newSparkSession.stop()
+    Thread.sleep(10000)
+    queries = SparkConnectorContext.getRunningQueries.get(appId)
+    assert(SparkConnectorContext.closedApplicationIDs.contains(appId))
+    assert(queries.isEmpty)
+  }
+
 }

From 589333840534e75c970c7c240bcc262136d2a0a0 Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Fri, 2 Feb 2024 13:18:55 -0800
Subject: [PATCH 3/5] fix test

---
 .../snowflake/spark/snowflake/SparkConnectorContextSuite.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
index 54d85f64..d0695431 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
@@ -287,6 +287,7 @@ class SparkConnectorContextSuite extends IntegrationSuiteBase {
     Future {
       newSparkSession.read.format(SNOWFLAKE_SOURCE_NAME)
         .options(connectorOptionsNoTable)
+        .option(Parameters.PARAM_SUPPORT_SHARE_CONNECTION, "false")
         .option("query", "select count(*) from table(generator(timelimit=>100))")
         .load().show()
     }

From ee7719fd6a7f0823cf85d55c61182d0cf216e6ff Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Fri, 2 Feb 2024 14:11:05 -0800
Subject: [PATCH 4/5] fix test

---
 .../SparkConnectorContextSuite.scala             | 35 +++++++++++++------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
index d0695431..c211b396 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
@@ -278,28 +278,41 @@ class SparkConnectorContextSuite extends IntegrationSuiteBase {
   }
 
   test("Disable retry after application closed") {
-
-    val newSparkSession = createDefaultSparkSession
-    val sc = newSparkSession.sparkContext
+    val sc = sparkSession.sparkContext
     val appId = sc.applicationId
 
+    val df = sparkSession.read.format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option(Parameters.PARAM_SUPPORT_SHARE_CONNECTION, "false")
+      .option("query", "select count(*) from table(generator(timelimit=>100))")
+      .load()
+
     import scala.concurrent.ExecutionContext.Implicits.global
-    Future {
-      newSparkSession.read.format(SNOWFLAKE_SOURCE_NAME)
-        .options(connectorOptionsNoTable)
-        .option(Parameters.PARAM_SUPPORT_SHARE_CONNECTION, "false")
-        .option("query", "select count(*) from table(generator(timelimit=>100))")
-        .load().show()
+    val f = Future {
+      df.collect()
     }
     Thread.sleep(10000)
     var queries = SparkConnectorContext.getRunningQueries.get(appId)
     assert(queries.isDefined)
     assert(queries.get.size == 1)
-    newSparkSession.stop()
+    sparkSession.stop()
     Thread.sleep(10000)
     queries = SparkConnectorContext.getRunningQueries.get(appId)
     assert(SparkConnectorContext.closedApplicationIDs.contains(appId))
     assert(queries.isEmpty)
-  }
 
+    // Wait for the child thread to finish so it does not affect other test cases.
+    Await.ready(f, Duration.Inf)
+
+    // Recreate the Spark session so that following test cases are not affected.
+    sparkSession = SparkSession.builder
+      .master("local")
+      .appName("SnowflakeSourceSuite")
+      .config("spark.sql.shuffle.partitions", "6")
+      // "spark.sql.legacy.timeParserPolicy = LEGACY" is added to allow
+      // spark 3.0 to support legacy conversion for unix_timestamp().
+      // It may not be necessary for spark 2.X.
+      .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
+      .getOrCreate()
+  }
 }

From f15815afe49a248ca4b9a2d8a85876b44b54a733 Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Fri, 2 Feb 2024 15:01:50 -0800
Subject: [PATCH 5/5] fix test

---
 .../SparkConnectorContextSuite.scala             | 45 +++----------------
 1 file changed, 5 insertions(+), 40 deletions(-)

diff --git a/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
index c211b396..8b75f366 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/SparkConnectorContextSuite.scala
@@ -243,7 +243,11 @@ class SparkConnectorContextSuite extends IntegrationSuiteBase {
 
     // Stop the application, it will trigger the Application End event.
     sparkSession.stop()
-    Thread.sleep(5000)
+    Thread.sleep(10000)
+
+    // No query can be retried after the session is closed.
+    assert(SparkConnectorContext.closedApplicationIDs.contains(appId))
+    assert(!SparkConnectorContext.getRunningQueries.contains(appId))
 
     var (message, queryText) = getQueryMessage(conn, queryID, sessionID)
     var tryCount: Int = 0
@@ -276,43 +280,4 @@ class SparkConnectorContextSuite extends IntegrationSuiteBase {
       Await.ready(f2, Duration.Inf)
     }
   }
-
-  test("Disable retry after application closed") {
-    val sc = sparkSession.sparkContext
-    val appId = sc.applicationId
-
-    val df = sparkSession.read.format(SNOWFLAKE_SOURCE_NAME)
-      .options(connectorOptionsNoTable)
-      .option(Parameters.PARAM_SUPPORT_SHARE_CONNECTION, "false")
-      .option("query", "select count(*) from table(generator(timelimit=>100))")
-      .load()
-
-    import scala.concurrent.ExecutionContext.Implicits.global
-    val f = Future {
-      df.collect()
-    }
-    Thread.sleep(10000)
-    var queries = SparkConnectorContext.getRunningQueries.get(appId)
-    assert(queries.isDefined)
-    assert(queries.get.size == 1)
-    sparkSession.stop()
-    Thread.sleep(10000)
-    queries = SparkConnectorContext.getRunningQueries.get(appId)
-    assert(SparkConnectorContext.closedApplicationIDs.contains(appId))
-    assert(queries.isEmpty)
-
-    // Wait for the child thread to finish so it does not affect other test cases.
-    Await.ready(f, Duration.Inf)
-
-    // Recreate the Spark session so that following test cases are not affected.
-    sparkSession = SparkSession.builder
-      .master("local")
-      .appName("SnowflakeSourceSuite")
-      .config("spark.sql.shuffle.partitions", "6")
-      // "spark.sql.legacy.timeParserPolicy = LEGACY" is added to allow
-      // spark 3.0 to support legacy conversion for unix_timestamp().
-      // It may not be necessary for spark 2.X.
-      .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
-      .getOrCreate()
-  }
 }
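
Editor's note on the overall approach: the series amounts to a small "closed-application block list". The `onApplicationEnd` listener records the application ID in `SparkConnectorContext.closedApplicationIDs`, and `getRDD` refuses to issue Snowflake queries for a recorded ID, so Spark task retries after shutdown fail fast instead of re-running canceled SQL. The standalone Scala sketch below illustrates only that pattern under simplified assumptions; `ClosedAppRegistry`, `markClosed`, `assertOpen`, and the demo application ID are illustrative names, not part of the connector.

import scala.collection.mutable

// Minimal sketch of the block-list pattern from PATCH 1 (illustrative only).
object ClosedAppRegistry {
  // IDs of applications that have already ended.
  private val closedAppIds = mutable.HashSet.empty[String]

  // Called from the application-end hook (the connector does this in onApplicationEnd).
  def markClosed(appId: String): Unit = synchronized { closedAppIds += appId }

  // Called before issuing a query (the connector does this at the top of getRDD).
  def assertOpen(appId: String): Unit = synchronized {
    if (closedAppIds.contains(appId)) {
      throw new IllegalStateException(s"Spark Application ($appId) was closed")
    }
  }
}

object ClosedAppRegistryDemo extends App {
  val appId = "app-20240201-0001"       // hypothetical application ID
  ClosedAppRegistry.assertOpen(appId)   // passes: application still open
  ClosedAppRegistry.markClosed(appId)   // simulate onApplicationEnd
  try {
    ClosedAppRegistry.assertOpen(appId) // now throws, so the retried query is skipped
  } catch {
    case e: IllegalStateException => println(s"skipped retry: ${e.getMessage}")
  }
}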