diff --git a/ClusterTest/build.sbt b/ClusterTest/build.sbt index de1b2859..f027b77b 100644 --- a/ClusterTest/build.sbt +++ b/ClusterTest/build.sbt @@ -36,7 +36,6 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", libraryDependencies ++= Seq( - "net.snowflake" % "snowflake-ingest-sdk" % "0.10.8", "net.snowflake" % "snowflake-jdbc" % "3.16.0", "org.apache.commons" % "commons-lang3" % "3.5" % "provided, runtime", "org.apache.spark" %% "spark-core" % testSparkVersion % "provided, runtime", diff --git a/build.sbt b/build.sbt index e3e7ce1d..1260f10c 100644 --- a/build.sbt +++ b/build.sbt @@ -59,7 +59,6 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", libraryDependencies ++= Seq( - "net.snowflake" % "snowflake-ingest-sdk" % "0.10.8", "net.snowflake" % "snowflake-jdbc" % "3.16.0", "org.scalatest" %% "scalatest" % "3.1.1" % Test, "org.mockito" % "mockito-core" % "1.10.19" % Test, diff --git a/src/it/scala/net/snowflake/spark/snowflake/streaming/SinkUtilsSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/streaming/SinkUtilsSuite.scala deleted file mode 100644 index 66e5e4bf..00000000 --- a/src/it/scala/net/snowflake/spark/snowflake/streaming/SinkUtilsSuite.scala +++ /dev/null @@ -1,37 +0,0 @@ -package net.snowflake.spark.snowflake.streaming - -import net.snowflake.spark.snowflake.{ConstantString, IntegrationSuiteBase} -import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations -import scala.language.postfixOps - -class SinkUtilsSuite extends IntegrationSuiteBase { - - val pipeName = s"test_pipe_$randomSuffix" - val stageName = s"test_stage_$randomSuffix" - val tableName = s"test_table_$randomSuffix" - - override def beforeAll(): Unit = { - super.beforeAll() - jdbcUpdate(s"create table $tableName(num int)") - jdbcUpdate(s"create stage $stageName") - } - - override def afterAll(): Unit = { - jdbcUpdate(s"drop pipe if exists $pipeName") - jdbcUpdate(s"drop stage if exists $stageName") - jdbcUpdate(s"drop table if exists $tableName") - super.afterAll() - } - - test("verifyPipe") { - - val copy = s"copy into $tableName from @$stageName" - val overwrite = true - conn.createPipe(pipeName, ConstantString(copy) !, overwrite) - - assert(verifyPipe(conn, pipeName, copy)) - - assert(!verifyPipe(conn, pipeName, copy + "something")) - } - -} diff --git a/src/it/scala/net/snowflake/spark/snowflake/streaming/StreamingSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/streaming/StreamingSuite.scala deleted file mode 100644 index a7c3520c..00000000 --- a/src/it/scala/net/snowflake/spark/snowflake/streaming/StreamingSuite.scala +++ /dev/null @@ -1,514 +0,0 @@ -package net.snowflake.spark.snowflake.streaming - -import java.io.{DataOutputStream, File} -import java.net.ServerSocket -import java.nio.charset.Charset -import java.util.concurrent.TimeUnit - -import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations -import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_NAME -import net.snowflake.spark.snowflake.io.{CloudStorage, CloudStorageOperations} -import net.snowflake.spark.snowflake.{DefaultJDBCWrapper, IntegrationSuiteBase, Utils} -import org.apache.spark.sql.Row -import org.apache.spark.sql.streaming.Trigger -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -import org.scalactic.source.Position -import 
org.scalatest.Tag - -import scala.util.Random - -class StreamingSuite extends IntegrationSuiteBase { - - // This test suite only run when env EXTRA_TEST_FOR_COVERAGE is set as true - override def test(testName: String, testTags: Tag*)( - testFun: => Any - )(implicit pos: Position): Unit = { - if (extraTestForCoverage) { - super.test(testName, testTags: _*)(testFun)(pos) - } else { - super.ignore(testName, testTags: _*)(testFun)(pos) - } - } - - val streamingTable = s"test_table_streaming_$randomSuffix" - val streamStagePrefix = "STREAMING_TEST_STAGE_" - val streamingStage = s"$streamStagePrefix$randomSuffix" - // Random port number [10001-60000] - val testServerPort = Random.nextInt(50000) + 10001 - - private class NetworkService(val port: Int, - val data: Seq[String], - val sleepBeforeAll: Int = 10, - val sleepAfterAll: Int = 1, - val sleepAfterEach: Int = 5) - extends Runnable { - val serverSocket = new ServerSocket(port) - - def run(): Unit = { - val socket = serverSocket.accept() - val output = new DataOutputStream(socket.getOutputStream) - Thread.sleep(sleepBeforeAll * 1000) - data.foreach(d => { - // scalastyle:off println - println(s"send message: $d") - s"$d\n".getBytes(Charset.forName("UTF-8")).foreach(x => output.write(x)) - // scalastyle:on println - Thread.sleep(sleepAfterEach * 1000) - }) - Thread.sleep(sleepAfterAll * 1000) - socket.close() - - } - } - - private class NetworkService2(val port: Int, - val times: Int = 1000, - val sleepBeforeAll: Int = 10, - val sleepAfterEach: Int = 5) - extends Runnable { - val serverSocket = new ServerSocket(port) - - def run(): Unit = { - val socket = serverSocket.accept() - val output = new DataOutputStream(socket.getOutputStream) - Thread.sleep(sleepBeforeAll * 1000) - - (0 until times).foreach(index => { - val word = Random.alphanumeric.take(10).mkString("") - // scalastyle:off println - println(s"send message: $index $word") - // scalastyle:on println - s"$word\n" - .getBytes(Charset.forName("UTF-8")) - .foreach(x => output.write(x)) - Thread.sleep(sleepAfterEach * 1000) - }) - - socket.close() - - } - } - - def removeDirectory(dir: File): Unit = { - if (dir.isDirectory) { - val files = dir.listFiles - files.foreach(removeDirectory) - dir.delete - } else dir.delete - } - - def checkTestTable(expectedAnswer: Seq[Row]): Unit = { - val loadedDf = sparkSession.read - .format(SNOWFLAKE_SOURCE_NAME) - .options(connectorOptionsNoTable) - .option("query", s"select * from $streamingTable order by value") - .load() - checkAnswer(loadedDf, expectedAnswer) - } - - override def afterAll(): Unit = { - conn.dropTable(streamingTable) - // NOTE: The stage name can't be dropped here. - // The Ingest is multiple-threads, if the stage is dropped in afterAll(). - // The Ingest thread may need access it. It will fail. - // dropOldStages() is used to drop old stages in the begin of test. 
- // conn.dropStage(streamingStage) - super.afterAll() - } - - // manual test only - ignore("test") { - val spark = sparkSession - import spark.implicits._ - - DefaultJDBCWrapper.executeQueryInterruptibly( - conn, - "create or replace table stream_test_table (value string)" - ) - - val lines = spark.readStream - .format("socket") - .option("host", "localhost") - .option("port", testServerPort) - .load() - val words = lines.as[String].flatMap(_.split(" ")) - - val checkpoint = "check" - removeDirectory(new File(checkpoint)) - - new Thread( - new NetworkService2(testServerPort, sleepBeforeAll = 5, sleepAfterEach = 5) - ).start() - - val query = words.writeStream - .outputMode("append") - .option("checkpointLocation", checkpoint) - .options(connectorOptionsNoTable) - .option("dbtable", "stream_test_table") - .option("streaming_stage", "streaming_test_stage") - .format(SNOWFLAKE_SOURCE_NAME) - .start() - - query.awaitTermination() - - } - - // Wait for spark streaming write done. - def waitForWriteDone(tableName: String, - expectedRowCount: Int, - maxWaitTimeInMs: Int, - intervalInMs: Int): Boolean = { - var result = false - Thread.sleep(intervalInMs) - var sleepTime = intervalInMs - while (sleepTime < maxWaitTimeInMs && !result) { - val rs = Utils.runQuery(connectorOptionsNoExternalStageNoTable, - query = s"select count(*) from $tableName") - rs.next() - val rowCount = rs.getLong(1) - // scalastyle:off println - println(s"Get row count: $rowCount : retry ${sleepTime/intervalInMs}") - // scalastyle:off println - if (rowCount >= expectedRowCount) { - // Sleep done. - result = true - } else { - sleepTime += intervalInMs - Thread.sleep(intervalInMs) - } - } - result - } - - test("Test streaming writer") { - // Drop stages which were create 48 hours ago - dropOldStages(streamStagePrefix, 48) - - val spark = sparkSession - import spark.implicits._ - - conn.createStage(name = streamingStage, overwrite = true) - - DefaultJDBCWrapper.executeQueryInterruptibly( - conn, - s"create or replace table $streamingTable (value string)" - ) - - val lines = spark.readStream - .format("socket") - .option("host", "localhost") - .option("port", testServerPort) - .load() - val words = lines.as[String].flatMap(_.split(" ")) - - val output = words.toDF("VALUE") - - val checkpoint = "check" - removeDirectory(new File(checkpoint)) - - new Thread( - new NetworkService( - testServerPort, - Seq( - "one two", - "three four five", - "six seven eight night ten", - "1 2 3 4 5", - "6 7 8 9 0" - ), - sleepBeforeAll = 10, - sleepAfterAll = 1, - sleepAfterEach = 1 - ) - ).start() - - val query = output.writeStream - .outputMode("append") - .option("checkpointLocation", checkpoint) - .options(connectorOptionsNoExternalStageNoTable) - .option("dbtable", streamingTable) - .option("streaming_stage", streamingStage) - .format(SNOWFLAKE_SOURCE_NAME) - .start() - - // Max wait time: 5 minutes, retry interval: 30 seconds. - if (!waitForWriteDone(streamingTable, 20, 300000, 30000)) { - // scalastyle:off println - println(s"Don't find expected row count in target table. " + - s"The root cause could be the snow pipe is too slow. 
" + - "It is a SC issue, only if it is reproduced consistently.") - // scalastyle:on println - } - query.stop() - query.awaitTermination(10000) - - checkTestTable( - Seq( - Row("0"), - Row("1"), - Row("2"), - Row("3"), - Row("4"), - Row("5"), - Row("6"), - Row("7"), - Row("8"), - Row("9"), - Row("eight"), - Row("five"), - Row("four"), - Row("night"), - Row("one"), - Row("seven"), - Row("six"), - Row("ten"), - Row("three"), - Row("two") - ) - ) - - } - - ignore("kafka") { - - val spark = sparkSession - - conn.createStage(name = streamingStage, overwrite = true) - - DefaultJDBCWrapper.executeQueryInterruptibly( - conn, - s"create or replace table $streamingTable (key string, value string)" - ) - - val checkpoint = "check" - removeDirectory(new File(checkpoint)) - - val df = spark.readStream - .format("kafka") - .option("kafka.bootstrap.servers", "localhost:9092") - .option("subscribe", "test") - .load() - - val query = df - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .writeStream - .outputMode("append") - .option("checkpointLocation", "check") - .options(connectorOptionsNoTable) - .option("dbtable", streamingTable) - .option("streaming_stage", streamingStage) - .format(SNOWFLAKE_SOURCE_NAME) - .start() - - query.awaitTermination() - - } - - ignore("kafka1") { - - val spark = sparkSession - val streamingTable = "streaming_test_table" - - conn.createStage(name = streamingStage, overwrite = true) - - DefaultJDBCWrapper.executeQueryInterruptibly( - conn, - s"create or replace table $streamingTable (key string, value string)" - ) - - val checkpoint = "check" - removeDirectory(new File(checkpoint)) - - val df = spark.readStream - .format("kafka") - .option("kafka.bootstrap.servers", "localhost:9092") - .option("subscribe", "test") - .load() - - // scalastyle:off println - println("-------------------------------------------------") - // scalastyle:on println - - var query = df - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .writeStream - .outputMode("append") - .option("checkpointLocation", "check") - .options(connectorOptionsNoTable) - .option("dbtable", streamingTable) - .option("streaming_stage", streamingStage) - .format(SNOWFLAKE_SOURCE_NAME) - .trigger(Trigger.ProcessingTime(1000)) - .start() - - Thread.sleep(20000) - - // scalastyle:off println - println("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^") - // scalastyle:on println - query.stop() - - query = df - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .writeStream - .outputMode("append") - .option("checkpointLocation", "check") - .options(connectorOptionsNoTable) - .option("dbtable", streamingTable) - .option("streaming_stage", streamingStage) - .format(SNOWFLAKE_SOURCE_NAME) - .trigger(Trigger.ProcessingTime(1000)) - .start() - - Thread.sleep(600000) - spark.close() - - } - - test("test context") { - val tempStage = false - val storage: CloudStorage = CloudStorageOperations - .createStorageClient( - params, - conn, - tempStage, - Some("spark_streaming_test_stage") - ) - ._1 - // val log = IngestLogManager.readIngestList(storage, conn) - - val failed = IngestContextManager.readFailedFileList(0, storage, conn) - failed.addFiles(List("a", "b", "c")) - - val failed1 = IngestContextManager.readFailedFileList(0, storage, conn) - - // scalastyle:off println - println(failed1.toString) - // scalastyle:on println - - } - ignore("kafka2") { - - import org.apache.spark.sql.functions._ - val KafkaLoggingTopic = sparkSession.readStream - .format("kafka") - .option("kafka.bootstrap.servers", 
"localhost:9092") - .option("subscribe", "test") - .load() - - val spark = sparkSession - import spark.implicits._ - - val KafkaLoggingTopicDF = - KafkaLoggingTopic.select($"value".cast("string")).as("event") - - val loggingSchema = new StructType() - .add("url", StringType) - .add("method", StringType) - .add("action", StringType) - .add("timestamp", StringType) - .add("esbTransactionId", IntegerType) - .add("esbTransactionGuid", StringType) - .add("executionTime", IntegerType) - .add("serverName", StringType) - .add("sourceIp", StringType) - .add("eventName", StringType) - .add("operationName", StringType) - .add( - "header", - new StructType() - .add("request", StringType) - .add("response", StringType) - ) - .add( - "body", - new StructType() - .add("request", StringType) - .add("response", StringType) - ) - .add( - "indexed", - new StructType() - .add("openTimeout", StringType) - .add("readTimeout", StringType) - .add("threadId", StringType) - ) - .add( - "details", - new StructType() - .add("soapAction", StringType) - .add("artifactType", StringType) - .add("reason", StringType) - .add("requestId", StringType) - .add("success", StringType) - .add("backendOwner", StringType) - ) - .add( - "syslog", - new StructType() - .add("appName", StringType) - .add("facility", StringType) - .add("host", StringType) - .add("priority", StringType) - .add("severity", StringType) - .add("timestamp", StringType) - ) - - val loggingSchemaDF = KafkaLoggingTopicDF - .select(from_json('value, loggingSchema) as 'event) - .select("event.*") - - loggingSchemaDF.writeStream - .outputMode("append") - .options(connectorOptionsNoTable) - .option("checkpointLocation", "check") - // .trigger(Trigger.ProcessingTime(1, TimeUnit.SECONDS)) - .option("dbtable", "streaming_test") - .option("streaming_stage", "streaming_test") - .format(SNOWFLAKE_SOURCE_NAME) - .start() - .awaitTermination() - } - - ignore("kafka 3") { - - import org.apache.spark.sql.functions._ - val KafkaLoggingTopic = sparkSession.readStream - .format("kafka") - .option("kafka.bootstrap.servers", "localhost:9092") - .option("subscribe", "test") - .load() - - val spark = sparkSession - import spark.implicits._ - - val KafkaLoggingTopicDF = - KafkaLoggingTopic.select($"value".cast("string")).as("event") - - val loggingSchema = new StructType() - .add("url", StringType) - .add( - "header", - new StructType() - .add("request", StringType) - .add("response", StringType) - ) - - val loggingSchemaDF = KafkaLoggingTopicDF - .select(from_json('value, loggingSchema) as 'event) - .select("event.*") - - loggingSchemaDF.writeStream - .outputMode("append") - .options(connectorOptionsNoTable) - .option("checkpointLocation", "check") - .trigger(Trigger.ProcessingTime(1, TimeUnit.SECONDS)) - .option("dbtable", "streaming_test") - .option("column_mapping", "name") - .option("streaming_stage", "streaming_test") - .format(SNOWFLAKE_SOURCE_NAME) - .start() - .awaitTermination() - } - -} diff --git a/src/main/scala/net/snowflake/spark/snowflake/DefaultSource.scala b/src/main/scala/net/snowflake/spark/snowflake/DefaultSource.scala index 0376cd11..f6b5e332 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/DefaultSource.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/DefaultSource.scala @@ -17,11 +17,8 @@ package net.snowflake.spark.snowflake -import net.snowflake.spark.snowflake.streaming.SnowflakeSink import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_SHORT_NAME -import org.apache.spark.sql.execution.streaming.Sink import 
org.apache.spark.sql.sources._ -import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import org.slf4j.LoggerFactory @@ -38,7 +35,6 @@ class DefaultSource(jdbcWrapper: JDBCWrapper) extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider - with StreamSinkProvider with DataSourceRegister { override def shortName(): String = SNOWFLAKE_SOURCE_SHORT_NAME @@ -145,10 +141,4 @@ class DefaultSource(jdbcWrapper: JDBCWrapper) createRelation(sqlContext, parameters) } - - override def createSink(sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = - new SnowflakeSink(sqlContext, parameters, partitionColumns, outputMode) } diff --git a/src/main/scala/net/snowflake/spark/snowflake/streaming/SnowflakeIngestConnector.scala b/src/main/scala/net/snowflake/spark/snowflake/streaming/SnowflakeIngestConnector.scala deleted file mode 100644 index 15669622..00000000 --- a/src/main/scala/net/snowflake/spark/snowflake/streaming/SnowflakeIngestConnector.scala +++ /dev/null @@ -1,121 +0,0 @@ -package net.snowflake.spark.snowflake.streaming - -import java.security.PrivateKey -import java.text.SimpleDateFormat -import java.util.{Date, TimeZone} - -import net.snowflake.ingest.SimpleIngestManager -import net.snowflake.ingest.connection.IngestStatus -import net.snowflake.ingest.utils.StagedFileWrapper -import net.snowflake.spark.snowflake.Parameters.MergedParameters - -import scala.collection.JavaConverters._ -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, TimeoutException} - -object SnowflakeIngestConnector { - /** - * list of (fileName, loading Succeed or not - */ - def createHistoryChecker( - ingestManager: SimpleIngestManager - ): () => List[(String, IngestStatus)] = { - var beginMark: String = null - () => - { - val response = ingestManager.getHistory(null, null, beginMark) - beginMark = - Option[String](response.getNextBeginMark).getOrElse(beginMark) - if (response != null && response.files != null) { - response.files.asScala.toList.flatMap(entry => { - if (entry.getPath != null && entry.isComplete) { - List((entry.getPath, entry.getStatus)) - } else Nil - }) - } else Nil - } - } - - def checkHistoryByRange(ingestManager: SimpleIngestManager, - start: Long, - end: Long): List[(String, IngestStatus)] = { - val response = ingestManager - .getHistoryRange(null, timestampToDate(start), timestampToDate(end)) - if (response != null && response.files != null) { - response.files.asScala.toList.flatMap(entry => { - if (entry.getPath != null && entry.isComplete) { - List((entry.getPath, entry.getStatus)) - } else Nil - }) - } else Nil - } - - /** - * timestamp to ISO-8601 Date - */ - private def timestampToDate(time: Long): String = { - val tz = TimeZone.getTimeZone("UTC") - val df = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'") - df.setTimeZone(tz) - df.format(new Date(time - 1000)) // 1 sec before - } - - def createIngestManager(account: String, - user: String, - pipe: String, - host: String, - privateKey: PrivateKey, - port: Int = 443, - scheme: String = "https"): SimpleIngestManager = - new SimpleIngestManager(account, user, pipe, privateKey, scheme, host, port) - - def ingestFiles( - files: List[String] - )(implicit manager: SimpleIngestManager): Unit = - manager.ingestFiles(files.map(new StagedFileWrapper(_)).asJava, null) - - def 
createIngestManager(param: MergedParameters, - pipeName: String): SimpleIngestManager = { - val urlPattern = "^(https?://)?([^:]+)(:\\d+)?$".r - val portPattern = ":(\\d+)".r - val accountPattern = "([^.]+).+".r - - param.sfURL.trim match { - case urlPattern(_, host, portStr) => - val scheme: String = if (param.isSslON) "https" else "http" - - val port: Int = - if (portStr != null) { - val portPattern(t) = portStr - t.toInt - } else if (param.isSslON) 443 - else 80 - - val accountPattern(account) = host - - require( - param.privateKey.isDefined, - "PEM Private key must be specified with 'pem_private_key' parameter" - ) - - val privateKey = param.privateKey.get - - val pipe = s"${param.sfDatabase}.${param.sfSchema}.$pipeName" - - createIngestManager( - account, - param.sfUser, - pipe, - host, - privateKey, - port, - scheme - ) - - case _ => - throw new IllegalArgumentException("incorrect url: " + param.sfURL) - } - } - -} diff --git a/src/main/scala/net/snowflake/spark/snowflake/streaming/SnowflakeIngestService.scala b/src/main/scala/net/snowflake/spark/snowflake/streaming/SnowflakeIngestService.scala deleted file mode 100644 index d4a81c22..00000000 --- a/src/main/scala/net/snowflake/spark/snowflake/streaming/SnowflakeIngestService.scala +++ /dev/null @@ -1,314 +0,0 @@ -package net.snowflake.spark.snowflake.streaming - -import java.nio.charset.Charset -import net.snowflake.client.jdbc.internal.apache.commons.logging.{Log, LogFactory} -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.node.ArrayNode -import net.snowflake.ingest.SimpleIngestManager -import net.snowflake.ingest.connection.IngestStatus -import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations -import net.snowflake.spark.snowflake.Parameters.MergedParameters -import net.snowflake.spark.snowflake.ServerConnection -import net.snowflake.spark.snowflake.io.CloudStorage - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration._ -import scala.concurrent.{Await, Future} -import scala.language.postfixOps - -class SnowflakeIngestService(param: MergedParameters, - pipeName: String, - storage: CloudStorage, - conn: ServerConnection) { - - val SLEEP_TIME: Long = 60 * 1000 // 1m - val HISTORY_CHECK_TIME: Long = 60 * 60 * 1000 // 1h - val WAITING_TIME_ON_TERMINATION: Int = 10 // 10m - - lazy implicit val ingestManager: SimpleIngestManager = - SnowflakeIngestConnector.createIngestManager(param, pipeName) - - private var notClosed: Boolean = true - - private val ingestedFileList: IngestedFileList = init() - - private lazy val checker = - SnowflakeIngestConnector.createHistoryChecker(ingestManager) - - private var pipeDropped = false - - // run clean function periodically - private val process = Future { - while (notClosed) { - Thread.sleep(SLEEP_TIME) - val time = System.currentTimeMillis() - ingestedFileList.checkResponseList(checker()) - if (ingestedFileList.getFirstTimeStamp.isDefined && - time - ingestedFileList.getFirstTimeStamp.get > HISTORY_CHECK_TIME) { - ingestedFileList - .checkResponseList( - SnowflakeIngestConnector - .checkHistoryByRange( - ingestManager, - ingestedFileList.getFirstTimeStamp.get, - time - ) - ) - } - } - cleanAll() - } - - def ingestFiles(list: List[String]): Unit = { - SnowflakeIngestConnector.ingestFiles(list) - ingestedFileList.addFiles(list) - } - - def cleanAll(): Unit = { - 
while (ingestedFileList.nonEmpty) { - Thread.sleep(SLEEP_TIME) - val time = System.currentTimeMillis() - if (time - ingestedFileList.getFirstTimeStamp.get > 10 * 60 * 1000) { - ingestedFileList - .checkResponseList( - SnowflakeIngestConnector - .checkHistoryByRange( - ingestManager, - ingestedFileList.getFirstTimeStamp.get, - time - ) - ) - } else ingestedFileList.checkResponseList(checker()) - } - conn.dropPipe(pipeName) - ingestedFileList.remove() - pipeDropped = true - - } - - def close(): Unit = { - val ct = System.currentTimeMillis() - IngestContextManager.logger.debug("closing ingest service") - notClosed = false - Await.result(process, WAITING_TIME_ON_TERMINATION minutes) - if (!pipeDropped) { - IngestContextManager.logger.error( - s"closing ingest service time out, please drop pipe: $pipeName manually" - ) - } - - IngestContextManager.logger.debug( - s"ingest service closed: ${(System.currentTimeMillis() - ct) / 1000.0}" - ) - } - - /** - * recover from context files or create new data - */ - private def init(): IngestedFileList = - IngestContextManager.readIngestList(storage, conn) - -} - -object IngestContextManager { - val CONTEXT_DIR = "context" - val INGEST_FILE_LIST_NAME = "ingested_file_list.json" - val FAILED_FILE_INDEX = "failed_file_index" - val LIST = "list" - val NAME = "name" - val TIME = "time" - val mapper = new ObjectMapper() - val logger: Log = LogFactory.getLog(getClass) - - def readIngestList(storage: CloudStorage, - conn: ServerConnection): IngestedFileList = { - val fileName = s"$CONTEXT_DIR/$INGEST_FILE_LIST_NAME" - if (storage.fileExists(fileName)) { - val inputStream = storage.download(fileName, compress = false) - val buffer = ArrayBuffer.empty[Byte] - var c: Int = inputStream.read() - while (c != -1) { - buffer.append(c.toByte) - c = inputStream.read() - } - try { - val node = - mapper.readTree(new String(buffer.toArray, Charset.forName("UTF-8"))) - val failedIndex: Int = node.get(FAILED_FILE_INDEX).asInt() - val failedList: FailedFileList = - readFailedFileList(failedIndex, storage, conn) - val arrNode = node.get(LIST).asInstanceOf[ArrayNode] - var list: List[(String, Long)] = Nil - (0 until arrNode.size()).foreach(i => { - list = arrNode.get(i).get(NAME).asText() -> arrNode - .get(i) - .get(TIME) - .asLong() :: list - }) - IngestedFileList(storage, conn, Some(failedList), Some(list)) - } catch { - case e: Exception => - throw new IllegalArgumentException( - s"context file: $fileName is broken: $e" - ) - } - } else IngestedFileList(storage, conn) - } - - def readFailedFileList(index: Int, - storage: CloudStorage, - conn: ServerConnection): FailedFileList = { - val fileName = s"$CONTEXT_DIR/failed_file_list_$index.json" - if (storage.fileExists(fileName)) { - val inputStream = storage.download(fileName, compress = false) - val buffer = ArrayBuffer.empty[Byte] - var c: Int = inputStream.read() - while (c != -1) { - buffer.append(c.toByte) - c = inputStream.read() - } - try { - val list = mapper - .readTree(new String(buffer.toArray, Charset.forName("UTF-8"))) - .asInstanceOf[ArrayNode] - var set = mutable.HashSet.empty[String] - (0 until list.size()).foreach(i => { - set += list.get(i).asText() - }) - FailedFileList(storage, conn, index, Some(set)) - } catch { - case e: Exception => - throw new IllegalArgumentException( - s"context file: $fileName is broken: $e" - ) - } - } else FailedFileList(storage, conn, index) - } - -} - -sealed trait IngestContext { - - val storage: CloudStorage - - val fileName: String - - val conn: ServerConnection - - def save(): 
Unit = { - IngestContextManager.logger.debug(s"$fileName:$toString") - val output = - storage.upload(fileName, Some(IngestContextManager.CONTEXT_DIR), compress = false) - output.write(toString.getBytes("UTF-8")) - output.close() - - } - -} - -case class FailedFileList(override val storage: CloudStorage, - override val conn: ServerConnection, - fileIndex: Int = 0, - files: Option[mutable.HashSet[String]] = None) - extends IngestContext { - val MAX_FILE_SIZE: Int = 1000 // how many file names - - private var fileSet: mutable.HashSet[String] = - files.getOrElse(mutable.HashSet.empty[String]) - - override lazy val fileName: String = s"failed_file_list_$fileIndex.json" - - def addFiles(names: List[String]): FailedFileList = { - val part1 = names.slice(0, MAX_FILE_SIZE - fileSet.size) - val part2 = names.slice(MAX_FILE_SIZE - fileSet.size, Int.MaxValue) - - fileSet ++= part1.toSet - save() - if (part2.isEmpty) this - else FailedFileList(storage, conn, fileIndex + 1).addFiles(part2) - } - - override def toString: String = { - val node = IngestContextManager.mapper.createArrayNode() - fileSet.foreach(node.add) - node.toString - } - -} - -case class IngestedFileList(override val storage: CloudStorage, - override val conn: ServerConnection, - failedFileList: Option[FailedFileList] = None, - ingestList: Option[List[(String, Long)]] = None) - extends IngestContext { - override val fileName: String = IngestContextManager.INGEST_FILE_LIST_NAME - - private var failedFiles: FailedFileList = - failedFileList.getOrElse(FailedFileList(storage, conn)) - - private var fileList: mutable.PriorityQueue[(String, Long)] = - mutable.PriorityQueue - .empty[(String, Long)](Ordering.by[(String, Long), Long](_._2).reverse) - - if (ingestList.isDefined) { - ingestList.get.foreach(fileList += _) - } - - def addFiles(names: List[String]): Unit = { - val time = System.currentTimeMillis() - names.foreach(fileList += _ -> time) - save() - } - - override def toString: String = { - val node = IngestContextManager.mapper.createObjectNode() - node.put(IngestContextManager.FAILED_FILE_INDEX, failedFiles.fileIndex) - - val arr = node.putArray(IngestContextManager.LIST) - fileList.foreach { - case (name, time) => - val n = IngestContextManager.mapper.createObjectNode() - n.put(IngestContextManager.NAME, name) - n.put(IngestContextManager.TIME, time) - arr.add(n) - } - - node.toString - } - - def checkResponseList(list: List[(String, IngestStatus)]): Unit = { - var toClean: List[String] = Nil - var failed: List[String] = Nil - - list.foreach { - case (name, status) => - if (fileList.exists(_._1 == name)) { - status match { - case IngestStatus.LOADED => - toClean = name :: toClean - fileList = fileList.filterNot(_._1 == name) - case IngestStatus.LOAD_FAILED | IngestStatus.PARTIALLY_LOADED => - failed = name :: failed - fileList = fileList.filterNot(_._1 == name) - case _ => // do nothing - } - } - } - if (toClean.nonEmpty) storage.deleteFiles(toClean) - if (failed.nonEmpty) failedFiles = failedFiles.addFiles(failed) - save() - } - - def getFirstTimeStamp: Option[Long] = - if (fileList.isEmpty) None else Some(fileList.head._2) - - def isEmpty: Boolean = fileList.isEmpty - - def nonEmpty: Boolean = fileList.nonEmpty - - def remove(): Unit = - storage.deleteFile(IngestContextManager.CONTEXT_DIR + "/" + fileName) - -} diff --git a/src/main/scala/net/snowflake/spark/snowflake/streaming/SnowflakeSink.scala b/src/main/scala/net/snowflake/spark/snowflake/streaming/SnowflakeSink.scala deleted file mode 100644 index 1f453ed9..00000000 --- 
a/src/main/scala/net/snowflake/spark/snowflake/streaming/SnowflakeSink.scala +++ /dev/null @@ -1,245 +0,0 @@ -package net.snowflake.spark.snowflake.streaming - -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.node.ObjectNode -import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations -import net.snowflake.spark.snowflake._ -import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat -import net.snowflake.spark.snowflake.io.{ - CloudStorage, - CloudStorageOperations, - SupportedFormat -} -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} -import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.snowflake.SparkStreamingFunctions.streamingToNonStreaming -import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryListener} -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.slf4j.LoggerFactory - -class SnowflakeSink(sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode) - extends Sink { - private val STREAMING_OBJECT_PREFIX = "TMP_SPARK" - private val PIPE_TOKEN = "PIPE" - - private val log = LoggerFactory.getLogger(getClass) - - private val param = Parameters.mergeParameters( - parameters + ("keep_column_case" -> "on") - ) // always keep column name case - - // discussion: Do we want to support overwrite mode? - // In Spark Streaming, there are only three mode append, complete, update - require( - outputMode == OutputMode.Append(), - "Snowflake streaming only supports append mode" - ) - - require( - param.table.isDefined, - "Snowflake table name must be specified with 'dbtable' parameter" - ) - - require( - param.privateKey.isDefined, - "Private key must be specified in Snowflake streaming" - ) - - require( - param.streamingStage.isDefined, - "Streaming stage name must be specified with 'streaming_stage' parameter" - ) - - require( - param.rootTempDir.isEmpty, - "Spark Streaming only supports internal stages, please unset tempDir parameter." 
- ) - - private implicit val conn: ServerConnection = DefaultJDBCWrapper.getConnector(param) - - private val stageName: String = { - val name = param.streamingStage.get - conn.createStage(name) - name - } - - private implicit val storage: CloudStorage = - CloudStorageOperations - .createStorageClient(param, conn, tempStage = false, Some(stageName)) - ._1 - - private val pipeName: String = - s"${STREAMING_OBJECT_PREFIX}_${PIPE_TOKEN}_$stageName" - - private lazy val format: SupportedFormat = - if (Utils.containVariant(cachedSchema.get)) SupportedFormat.JSON - else SupportedFormat.CSV - - private lazy val ingestService: SnowflakeIngestService = { - val service = - openIngestionService(param, pipeName, format, cachedSchema.get, storage, conn) - init() - service - } - - private val compress: Boolean = param.sfCompress - - private var cachedSchema: Option[StructType] = None - - private val streamingStartTime: Long = System.currentTimeMillis() - - // telemetry - private var lastMetricSendTime: Long = 0 - private val mapper = new ObjectMapper() - private val metric: ObjectNode = mapper.createObjectNode() - - private val APP_NAME = TelemetryFieldNames.APPLICATION_NAME - private val START_TIME = TelemetryFieldNames.START_TIME - private val END_TIME = TelemetryFieldNames.END_TIME - private val LOAD_RATE = TelemetryFieldNames.LOAD_RATE - private val DATA_BATCH = TelemetryFieldNames.DATA_BATCH - - private val telemetrySendTime: Long = 10 * 60 * 1000 // 10 min - - // streaming start event - sendStartTelemetry() - - /** - * Create pipe - */ - def init(): Unit = { - sqlContext.sparkContext.addSparkListener(new SparkListener { - override def onApplicationEnd( - applicationEnd: SparkListenerApplicationEnd - ): Unit = { - super.onApplicationEnd(applicationEnd) - closeAllIngestionService() - - // telemetry - val time = System.currentTimeMillis() - metric.put(END_TIME, time) - metric.get(LOAD_RATE).asInstanceOf[ObjectNode].put(END_TIME, time) - SnowflakeTelemetry.addCommonFields(metric) - - SnowflakeTelemetry.addLog( - ((TelemetryTypes.SPARK_STREAMING, metric), time) - ) - SnowflakeTelemetry.send(conn.getTelemetry) - - // streaming termination event - sendEndTelemetry() - } - }) - - sqlContext.sparkSession.streams.addListener(new StreamingQueryListener { - override def onQueryStarted( - event: StreamingQueryListener.QueryStartedEvent - ): Unit = {} - - override def onQueryProgress( - event: StreamingQueryListener.QueryProgressEvent - ): Unit = {} - - override def onQueryTerminated( - event: StreamingQueryListener.QueryTerminatedEvent - ): Unit = - closeIngestionService(pipeName) - }) - } - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - - def registerDataBatchToTelemetry(): Unit = { - val time = System.currentTimeMillis() - if (lastMetricSendTime == 0) { // init - metric.put( - APP_NAME, - (data.sparkSession.sparkContext.appName + streamingStartTime.toString).hashCode - ) // use hashcode, hide app name - metric.put(START_TIME, time) - lastMetricSendTime = time - val rate = metric.putObject(LOAD_RATE) - rate.put(START_TIME, time) - rate.put(DATA_BATCH, 0) - } - val rate = metric.get(LOAD_RATE).asInstanceOf[ObjectNode] - rate.put(DATA_BATCH, rate.get(DATA_BATCH).asInt() + 1) - - if (time - lastMetricSendTime > telemetrySendTime) { - rate.put(END_TIME, time) - SnowflakeTelemetry.addCommonFields(rate) - SnowflakeTelemetry.addLog( - ((TelemetryTypes.SPARK_STREAMING, metric.deepCopy()), time) - ) - SnowflakeTelemetry.send(conn.getTelemetry) - lastMetricSendTime = time - rate.put(START_TIME, 
time) - rate.put(DATA_BATCH, 0) - } - - } - - if (cachedSchema.isEmpty) cachedSchema = Some(data.schema) - // prepare data - val rdd = - DefaultSnowflakeWriter.dataFrameToRDD( - sqlContext, - streamingToNonStreaming(sqlContext, data), - param, - format - ) - if (!rdd.isEmpty()) { - // write to storage - val fileList = - CloudStorageOperations - .saveToStorage(rdd, format, Some(batchId.toString), compress) - ingestService.ingestFiles(fileList) - - registerDataBatchToTelemetry() - } - } - - private def sendStartTelemetry(): Unit = { - val message: ObjectNode = mapper.createObjectNode() - message.put( - APP_NAME, - (sqlContext.sparkSession.sparkContext.appName + streamingStartTime.toString).hashCode - ) - message.put(START_TIME, streamingStartTime) - SnowflakeTelemetry.addCommonFields(message) - - SnowflakeTelemetry.addLog( - (TelemetryTypes.SPARK_STREAMING_START, message), - streamingStartTime - ) - SnowflakeTelemetry.send(conn.getTelemetry) - - log.info("Streaming started") - - } - - private def sendEndTelemetry(): Unit = { - val endTime: Long = System.currentTimeMillis() - val message: ObjectNode = mapper.createObjectNode() - message.put( - APP_NAME, - (sqlContext.sparkSession.sparkContext.appName + streamingStartTime.toString).hashCode - ) - message.put(START_TIME, streamingStartTime) - message.put(END_TIME, endTime) - SnowflakeTelemetry.addCommonFields(message) - - SnowflakeTelemetry.addLog( - (TelemetryTypes.SPARK_STREAMING_END, message), - endTime - ) - SnowflakeTelemetry.send(conn.getTelemetry) - - log.info("Streaming stopped") - } - -} diff --git a/src/main/scala/net/snowflake/spark/snowflake/streaming/package.scala b/src/main/scala/net/snowflake/spark/snowflake/streaming/package.scala deleted file mode 100644 index 1830142f..00000000 --- a/src/main/scala/net/snowflake/spark/snowflake/streaming/package.scala +++ /dev/null @@ -1,207 +0,0 @@ -package net.snowflake.spark.snowflake - -import net.snowflake.spark.snowflake.Parameters.MergedParameters -import net.snowflake.spark.snowflake.io.{CloudStorage, SupportedFormat} -import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat -import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations -import org.apache.spark.sql.types.StructType -import org.slf4j.LoggerFactory - -import scala.collection.mutable -import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration._ -import scala.language.postfixOps - -package object streaming { - - private val LOGGER = LoggerFactory.getLogger(this.getClass.getName) - private val SLEEP_TIME = 5000 // 5 seconds - private val TIME_OUT = 5 // 5 minutes - - private val pipeList: mutable.HashMap[String, SnowflakeIngestService] = - new mutable.HashMap() - - private[streaming] def openIngestionService( - param: MergedParameters, - pipeName: String, - format: SupportedFormat, - schema: StructType, - storage: CloudStorage, - conn: ServerConnection - ): SnowflakeIngestService = { - LOGGER.debug(s"create new ingestion service, pipe name: $pipeName") - - var pipeDropped = false - val checkPrevious: Future[Boolean] = Future { - while (pipeList.contains(pipeName)) { - LOGGER.debug(s"waiting previous pipe dropped") - Thread.sleep(SLEEP_TIME) - } - LOGGER.debug(s"previous pipe dropped") - pipeDropped = true - pipeDropped - } - - Await.result(checkPrevious, TIME_OUT minutes) - - if (pipeDropped) { - conn.createTable(param.table.get.name, schema, param, overwrite = false, temporary = false) - - val copy = 
ConstantString(copySql(param, conn, format, schema)) ! - - if (verifyPipe(conn, pipeName, copy.toString)) { - LOGGER.info(s"reuse pipe: $pipeName") - } else conn.createPipe(pipeName, copy, overwrite = true) - - val ingestion = new SnowflakeIngestService(param, pipeName, storage, conn) - pipeList.put(pipeName, ingestion) - ingestion - } else { - LOGGER.error(s"waiting pipe dropped time out") - throw new IllegalStateException( - s"Waiting pipe dropped time out, pipe name: $pipeName" - ) - } - } - - private[streaming] def closeIngestionService(pipeName: String): Unit = { - LOGGER.debug(s"closing ingestion service, pipe name: $pipeName") - if (pipeList.contains(pipeName)) { - pipeList(pipeName).close() - pipeList.remove(pipeName) - LOGGER.debug(s"ingestion service closed, pipe name: $pipeName") - } else { - LOGGER.error(s"ingestion service not found, pipe name: $pipeName") - } - } - - private[streaming] def closeAllIngestionService(): Unit = { - LOGGER.debug(s"closing ingestion service") - pipeList.foreach(_._2.close()) - LOGGER.debug(s"all ingestion service closed") - } - - /** - * Generate the COPY SQL command for creating pipe only - */ - private def copySql(param: MergedParameters, - conn: ServerConnection, - format: SupportedFormat, - schema: StructType - ): String = { - - val tableName = param.table.get - val stageName = param.streamingStage.get - val tableSchema = - DefaultJDBCWrapper.resolveTable(conn, tableName.toString, param) - - def getMappingToString(list: Option[List[(Int, String)]]): String = - format match { - case SupportedFormat.JSON => - val schema = - DefaultJDBCWrapper.resolveTable(conn, tableName.name, param) - if (list.isEmpty || list.get.isEmpty) { - s"(${schema.fields.map(x => Utils.quotedNameIgnoreCase(x.name)).mkString(",")})" - } else { - s"(${list.get.map(x => - Utils.quotedNameIgnoreCase(tableSchema(x._1).name)).mkString(", ")})" - } - case SupportedFormat.CSV => - if (list.isEmpty || list.get.isEmpty) { - "" - } else { - s"(${list.get.map(x => Utils.quotedNameIgnoreCase(x._2)).mkString(", ")})" - } - } - - def getMappingFromString(list: Option[List[(Int, String)]], - from: String): String = - format match { - case SupportedFormat.JSON => - val columnPrefix = if (param.useParseJsonForWrite) "parse_json($1):" else "$1:" - if (list.isEmpty || list.get.isEmpty) { - val names = - tableSchema.fields - .map( - x => - columnPrefix.concat(Utils.quotedNameIgnoreCase(x.name)) - ) - .mkString(",") - s"from (select $names $from tmp)" - } else { - s"from (select ${list.get.map(x => columnPrefix.concat( - Utils.quotedNameIgnoreCase(x._2))).mkString(", ")} $from tmp)" - } - case SupportedFormat.CSV => - if (list.isEmpty || list.get.isEmpty) { - from - } else { - s"from (select ${list.get - .map(x => "tmp.$".concat(Utils.quotedNameIgnoreCase((x._1 + 1).toString))) - .mkString(", ")} $from tmp)" - } - } - - val fromString = s"FROM @$stageName" - - if (param.columnMap.isEmpty && param.columnMapping == "name") { - param.setColumnMap(Option(schema), Option(tableSchema)) - } - - val mappingList: Option[List[(Int, String)]] = param.columnMap match { - case Some(map) => - Some(map.toList.map { - case (key, value) => - try { - (tableSchema.fieldIndex(value), key) - } catch { - case e: Exception => - LOGGER.error("Error occurred while column mapping: " + e) - throw e - } - }) - - case None => None - } - - val mappingToString = getMappingToString(mappingList) - - val mappingFromString = getMappingFromString(mappingList, fromString) - - val formatString = - format match { - case 
SupportedFormat.CSV => - s""" - |FILE_FORMAT = ( - | TYPE=CSV - | FIELD_DELIMITER='|' - | NULL_IF=() - | FIELD_OPTIONALLY_ENCLOSED_BY='"' - | TIMESTAMP_FORMAT='TZHTZM YYYY-MM-DD HH24:MI:SS.FF3' - | ) - """.stripMargin - case SupportedFormat.JSON => - s""" - |FILE_FORMAT = ( - | TYPE = JSON - |) - """.stripMargin - } - - s""" - |COPY INTO $tableName $mappingToString - |$mappingFromString - |$formatString - """.stripMargin.trim - } - - private[streaming] def verifyPipe(conn: ServerConnection, - pipeName: String, - copyStatement: String): Boolean = - conn.pipeDefinition(pipeName) match { - case Some(str) => str.trim.equals(copyStatement.trim) - case _ => false - } - -}
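
Note for downstream users: with the snowflake-ingest-sdk dependency, SnowflakeSink, and the StreamSinkProvider implementation removed, format(SNOWFLAKE_SOURCE_NAME) can no longer be used directly as a writeStream sink. A common replacement is to route each micro-batch through the connector's batch writer, which DefaultSource still provides via CreatableRelationProvider. The sketch below is an illustration only, assuming Spark 3.x; connectorOptions, the socket source, and TARGET_TABLE are placeholders, not anything introduced by this change.

import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_NAME
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

object ForeachBatchWriteSketch {
  def start(spark: SparkSession, connectorOptions: Map[String, String]): Unit = {
    // Any streaming source works here; a socket source mirrors the deleted StreamingSuite.
    val lines = spark.readStream
      .format("socket")
      .option("host", "localhost")
      .option("port", 9999)
      .load()

    lines.writeStream
      .outputMode("append")
      .option("checkpointLocation", "check")
      .foreachBatch { (batch: DataFrame, batchId: Long) =>
        // Each micro-batch goes through the regular batch write path,
        // which is unaffected by the removal of the streaming sink.
        batch.write
          .format(SNOWFLAKE_SOURCE_NAME)
          .options(connectorOptions)
          .option("dbtable", "TARGET_TABLE")
          .mode(SaveMode.Append)
          .save()
      }
      .start()
      .awaitTermination()
  }
}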
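
The deleted SnowflakeIngestService asked operators to drop its pipe manually if shutdown timed out, and the deleted SnowflakeSink named its pipes with the TMP_SPARK_PIPE_ prefix, so deployments that ran the old sink may still carry leftover pipes. A hypothetical JDBC cleanup sketch (connection details are placeholders; the name prefix comes from the deleted code):

import java.sql.DriverManager
import java.util.Properties
import scala.collection.mutable.ListBuffer

object StreamingLeftoverCleanup {
  def dropLeftoverPipes(jdbcUrl: String, props: Properties): Unit = {
    val conn = DriverManager.getConnection(jdbcUrl, props)
    try {
      // SHOW PIPES returns one row per pipe, with the pipe name in its "name" column.
      val rs = conn.createStatement().executeQuery("show pipes like 'TMP_SPARK_PIPE_%'")
      val names = ListBuffer.empty[String]
      while (rs.next()) names += rs.getString("name")
      val stmt = conn.createStatement()
      names.foreach { name =>
        // Quote the identifier since SHOW returns names in their stored case.
        stmt.execute(s"""drop pipe if exists "$name" """)
      }
    } finally {
      conn.close()
    }
  }
}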