From fae39b659e8e4a71f244490366753d72d2d5b720 Mon Sep 17 00:00:00 2001
From: Jalpan Randeri
Date: Thu, 30 Nov 2023 23:00:54 -0800
Subject: [PATCH] Support for AQE mode for delayed query pushdown + Enhanced debugging

This commit handles the scenario where Apache Spark is running in AQE
mode by having the Snowflake connector delay query pushdown. This allows
Spark to generate a more optimal query plan, which results in improved
performance. Furthermore, it logs the pushdown query in the Spark plan,
which allows easy debugging from the Spark History Server and UIs.
---
 .../snowflake/SnowflakeJDBCWrapper.scala      | 10 ++-
 .../net/snowflake/spark/snowflake/Utils.scala |  6 +-
 .../pushdowns/SnowflakeScanExec.scala         | 75 +++++++++++++++++++
 .../pushdowns/SnowflakeStrategy.scala         | 19 +++--
 .../querygeneration/QueryBuilder.scala        | 11 ++-
 .../spark/snowflake/SparkQuerySuite.scala     | 62 +++++++++++++++
 6 files changed, 172 insertions(+), 11 deletions(-)
 create mode 100644 src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeScanExec.scala
 create mode 100644 src/test/scala/net/snowflake/spark/snowflake/SparkQuerySuite.scala

diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
index bdeaeedc..a2fe045d 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
@@ -39,11 +39,13 @@ import scala.util.Try
  * Shim which exposes some JDBC helper functions. Most of this code is copied from Spark SQL, with
  * minor modifications for Snowflake-specific features and limitations.
  */
-private[snowflake] class JDBCWrapper {
+private[snowflake] class JDBCWrapper extends Serializable {
 
   private val log = LoggerFactory.getLogger(getClass)
 
-  private val ec: ExecutionContext = {
+  // Note: marking this field `implicit transient lazy` allows Spark to
+  // recreate it upon de-serialization
+  @transient implicit private lazy val ec: ExecutionContext = {
     log.debug("Creating a new ExecutionContext")
     val threadFactory: ThreadFactory = new ThreadFactory {
       private[this] val count = new AtomicInteger()
@@ -353,7 +355,7 @@ private[snowflake] class JDBCWrapper {
     TelemetryClient.createTelemetry(conn.jdbcConnection)
 }
 
-private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper {
+private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper with Serializable {
 
   private val LOGGER = LoggerFactory.getLogger(getClass.getName)
 
@@ -588,7 +590,7 @@ private[snowflake] object DefaultJDBCWrapper {
 private[snowflake] class SnowflakeSQLStatement(
     val numOfVar: Int = 0,
     val list: List[StatementElement] = Nil
-) {
+) extends Serializable {
 
   private val log = LoggerFactory.getLogger(getClass)
 
diff --git a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala
index 7eaed10a..00a29b5c 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala
@@ -20,7 +20,6 @@ package net.snowflake.spark.snowflake
 import java.net.URI
 import java.sql.{Connection, ResultSet}
 import java.util.{Properties, UUID}
-
 import net.snowflake.client.jdbc.{SnowflakeDriver, SnowflakeResultSet, SnowflakeResultSetSerializable}
 import net.snowflake.spark.snowflake.Parameters.MergedParameters
 import org.apache.spark.{SPARK_VERSION, SparkContext, SparkEnv}
@@ -37,6 +36,7 @@ import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.node.Object
 import net.snowflake.spark.snowflake.FSType.FSType
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.slf4j.LoggerFactory
 
@@ -77,6 +77,10 @@
   } else {
     ""
   }
+  private[snowflake] lazy val lazyMode = SparkSession.active
+    .conf
+    .get("spark.snowflakedb.lazyModeForAQE", "true")
+    .toBoolean
   private[snowflake] lazy val scalaVersion =
     util.Properties.versionNumberString
   private[snowflake] lazy val javaVersion =
diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeScanExec.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeScanExec.scala
new file mode 100644
index 00000000..afee8109
--- /dev/null
+++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeScanExec.scala
@@ -0,0 +1,75 @@
+package net.snowflake.spark.snowflake.pushdowns
+
+import net.snowflake.spark.snowflake.{SnowflakeRelation, SnowflakeSQLStatement}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import org.apache.spark.sql.execution.LeafExecNode
+
+import java.util.concurrent.{Callable, ExecutorService, Executors, Future}
+
+/**
+ * Snowflake Scan Plan for pushing query fragment to snowflake endpoint
+ *
+ * @param projection projected columns
+ * @param snowflakeSQL SQL query that is pushed to snowflake for evaluation
+ * @param relation Snowflake datasource
+ */
+case class SnowflakeScanExec(projection: Seq[Attribute],
+                             snowflakeSQL: SnowflakeSQLStatement,
+                             relation: SnowflakeRelation) extends LeafExecNode {
+  // result holder
+  @transient implicit private var data: Future[PushDownResult] = _
+  @transient implicit private val service: ExecutorService = Executors.newCachedThreadPool()
+
+  override protected def doPrepare(): Unit = {
+    logInfo(s"Preparing query to push down - $snowflakeSQL")
+
+    val work = new Callable[PushDownResult]() {
+      override def call(): PushDownResult = {
+        val result = {
+          try {
+            val data = relation.buildScanFromSQL[InternalRow](snowflakeSQL, Some(schema))
+            PushDownResult(data = Some(data))
+          } catch {
+            case e: Exception =>
+              logError("Failure in query execution", e)
+              PushDownResult(failure = Some(e))
+          }
+        }
+        result
+      }
+    }
+    data = service.submit(work)
+    logInfo("submitted query asynchronously")
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    if (data.get().failure.nonEmpty) {
+      // raise original exception
+      throw data.get().failure.get
+    }
+
+    data.get().data.get.mapPartitions { iter =>
+      val project = UnsafeProjection.create(schema)
+      iter.map(project)
+    }
+  }
+
+  override def cleanupResources(): Unit = {
+    logDebug(s"shutting down service to clean up resources")
+    service.shutdown()
+  }
+
+  override def output: Seq[Attribute] = projection
+}
+
+/**
+ * Result holder
+ *
+ * @param data RDD that holds the data
+ * @param failure failure information if we are unable to push down
+ */
+private case class PushDownResult(data: Option[RDD[InternalRow]] = None,
+                                  failure: Option[Exception] = None)
+  extends Serializable
\ No newline at end of file
diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeStrategy.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeStrategy.scala
index ddb0b577..bc2ce849 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeStrategy.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeStrategy.scala
@@ -1,6 +1,6 @@
 package net.snowflake.spark.snowflake.pushdowns
 
-import net.snowflake.spark.snowflake.SnowflakeConnectorFeatureNotSupportException
+import net.snowflake.spark.snowflake.{SnowflakeConnectorFeatureNotSupportException, Utils}
 import net.snowflake.spark.snowflake.pushdowns.querygeneration.QueryBuilder
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -38,9 +38,18 @@ class SnowflakeStrategy extends Strategy {
    * @return An Option of Seq[SnowflakePlan] that contains the PhysicalPlan if
    *         query generation was successful, None if not.
    */
-  private def buildQueryRDD(plan: LogicalPlan): Option[Seq[SnowflakePlan]] =
-    QueryBuilder.getRDDFromPlan(plan).map {
-      case (output: Seq[Attribute], rdd: RDD[InternalRow]) =>
-        Seq(SnowflakePlan(output, rdd))
+  private def buildQueryRDD(plan: LogicalPlan): Option[Seq[SparkPlan]] = {
+    if (Utils.lazyMode) {
+      logInfo("Using lazy mode for push down")
+      QueryBuilder.getSnowflakeScanExecPlan(plan).map {
+        case (projection, snowflakeSQL, relation) =>
+          Seq(SnowflakeScanExec(projection, snowflakeSQL, relation))
+      }
+    } else {
+      QueryBuilder.getRDDFromPlan(plan).map {
+        case (output: Seq[Attribute], rdd: RDD[InternalRow]) =>
+          Seq(SnowflakePlan(output, rdd))
+      }
     }
+  }
 }
diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala
index 1a1b847a..86c473c5 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala
@@ -1,7 +1,7 @@
 package net.snowflake.spark.snowflake.pushdowns.querygeneration
 
+
 import java.io.{PrintWriter, StringWriter}
-import java.util.NoSuchElementException
 
 import net.snowflake.spark.snowflake.{
   ConnectionCacheKey,
@@ -307,4 +307,13 @@ private[snowflake] object QueryBuilder {
       (executedBuilder.getOutput, executedBuilder.rdd)
     }
   }
+
+  def getSnowflakeScanExecPlan(plan: LogicalPlan):
+  Option[(Seq[Attribute], SnowflakeSQLStatement, SnowflakeRelation)] = {
+    val qb = new QueryBuilder(plan)
+
+    qb.tryBuild.map { executedBuilder =>
+      (executedBuilder.getOutput, executedBuilder.statement, executedBuilder.source.relation)
+    }
+  }
 }
diff --git a/src/test/scala/net/snowflake/spark/snowflake/SparkQuerySuite.scala b/src/test/scala/net/snowflake/spark/snowflake/SparkQuerySuite.scala
new file mode 100644
index 00000000..be34c46c
--- /dev/null
+++ b/src/test/scala/net/snowflake/spark/snowflake/SparkQuerySuite.scala
@@ -0,0 +1,62 @@
+package net.snowflake.spark.snowflake
+
+import net.snowflake.spark.snowflake.pushdowns.SnowflakeScanExec
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.{ExplainMode, FormattedMode}
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+class SparkQuerySuite extends FunSuite with BeforeAndAfter {
+  private var spark: SparkSession = _
+
+  before {
+    spark = SparkSession
+      .builder()
+      .master("local[2]")
+      .getOrCreate()
+  }
+
+  after {
+    spark.stop()
+  }
+
+  test("pushdown scan to snowflake") {
+    spark.sql(
+      """
+      CREATE TABLE student(name string)
+      USING net.snowflake.spark.snowflake
+      OPTIONS (dbtable 'default.student',
+               sfdatabase 'sf-db',
+               tempdir '/tmp/dir',
+               sfurl 'accountname.snowflakecomputing.com:443',
+               sfuser 'alice',
+               sfpassword 'hello-snowflake')
+      """).show()
+
+    val df = spark.sql(
+      """
+        |SELECT *
+        |  FROM student
+        |""".stripMargin)
+    val plan = df.queryExecution.executedPlan
+
+    assert(plan.isInstanceOf[SnowflakeScanExec])
+    val sfPlan = plan.asInstanceOf[SnowflakeScanExec]
+    assert(sfPlan.snowflakeSQL.toString ==
+      """SELECT * FROM ( default.student ) AS "SF_CONNECTOR_QUERY_ALIAS""""
+        .stripMargin)
+
+    // explain plan
+    val planString = df.queryExecution.explainString(FormattedMode)
+    val expectedString =
+      """== Physical Plan ==
+        |SnowflakeScan (1)
+        |
+        |
+        |(1) SnowflakeScan
+        |Output [1]: [name#1]
+        |Arguments: [name#1], SELECT * FROM ( default.student ) AS "SF_CONNECTOR_QUERY_ALIAS", SnowflakeRelation
+      """.stripMargin
+    assert(planString.trim == expectedString.trim)
+  }
+
+}
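
Usage sketch (illustrative, not applied by the patch): the snippet below shows how the new lazy AQE pushdown path could be exercised, assuming the connector built from this patch is on the classpath and that a `student` table has already been registered against the Snowflake data source (as in SparkQuerySuite above). The configuration key spark.snowflakedb.lazyModeForAQE and the SnowflakeScanExec node come from the changes above; `LazyPushdownExample` itself is a hypothetical driver, everything else is standard Spark API.

import net.snowflake.spark.snowflake.pushdowns.SnowflakeScanExec
import org.apache.spark.sql.SparkSession

// Hypothetical driver, only for illustrating the lazy AQE pushdown path.
object LazyPushdownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()

    // Lazy AQE pushdown defaults to "true"; set it to "false" to fall back
    // to the eager SnowflakePlan path.
    spark.conf.set("spark.snowflakedb.lazyModeForAQE", "true")

    // Assumes `student` is already registered against the Snowflake data
    // source, e.g. via the CREATE TABLE ... USING net.snowflake.spark.snowflake
    // statement shown in SparkQuerySuite.
    val df = spark.sql("SELECT * FROM student")

    // With lazy mode on, the physical plan should contain SnowflakeScanExec,
    // and the pushed-down SQL is visible in the formatted explain output.
    df.explain("formatted")
    df.queryExecution.executedPlan.collectFirst {
      case scan: SnowflakeScanExec => scan.snowflakeSQL.toString
    }.foreach(sql => println(s"Pushed-down SQL: $sql"))
  }
}

Because Utils.lazyMode is a lazy val read from the active session, the flag only takes effect if it is set before the first Snowflake query is planned.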