diff --git a/pom.xml b/pom.xml
index 465b09092..ac465508f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -13,7 +13,8 @@
streamingpro-api
streamingpro-hbase
streamingpro-redis
- streamingpro-model
+ streamingpro-dl4j
+ streamingpro-tensorflow
diff --git a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/SQLAlg.scala b/streamingpro-api/src/main/java/streaming/dsl/mmlib/SQLAlg.scala
similarity index 99%
rename from streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/SQLAlg.scala
rename to streamingpro-api/src/main/java/streaming/dsl/mmlib/SQLAlg.scala
index f1908d94b..065e10bee 100644
--- a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/SQLAlg.scala
+++ b/streamingpro-api/src/main/java/streaming/dsl/mmlib/SQLAlg.scala
@@ -1,6 +1,5 @@
package streaming.dsl.mmlib
-
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
diff --git a/streamingpro-commons/src/main/java/streaming/common/HDFSOperator.scala b/streamingpro-commons/src/main/java/streaming/common/HDFSOperator.scala
index 1a4ad128f..bd1363c8a 100644
--- a/streamingpro-commons/src/main/java/streaming/common/HDFSOperator.scala
+++ b/streamingpro-commons/src/main/java/streaming/common/HDFSOperator.scala
@@ -1,7 +1,8 @@
package streaming.common
-import java.io.{BufferedReader, InputStreamReader}
+import java.io.{BufferedReader, File, InputStreamReader}
+import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataOutputStream, FileStatus, FileSystem, Path}
@@ -97,6 +98,23 @@ object HDFSOperator {
}
+ def copyToHDFS(tempModelLocalPath: String, path: String, clean: Boolean) = {
+ val fs = FileSystem.get(new Configuration())
+ fs.delete(new Path(path), true)
+ fs.copyFromLocalFile(new Path(tempModelLocalPath),
+ new Path(path))
+ FileUtils.forceDelete(new File(tempModelLocalPath))
+ }
+
+ def createTempModelLocalPath(path: String, autoCreateParentDir: Boolean = true) = {
+ val dir = "/tmp/train/" + Md5.md5Hash(path)
+ if (autoCreateParentDir) {
+ FileUtils.forceMkdir(new File(dir))
+ }
+ dir
+ }
+
+
def main(args: Array[String]): Unit = {
println(readFile("file:///Users/allwefantasy/streamingpro/flink.json"))
}
diff --git a/streamingpro-commons/src/main/java/streaming/common/Md5.scala b/streamingpro-commons/src/main/java/streaming/common/Md5.scala
new file mode 100644
index 000000000..720ebe48d
--- /dev/null
+++ b/streamingpro-commons/src/main/java/streaming/common/Md5.scala
@@ -0,0 +1,12 @@
+package streaming.common
+
+/**
+ * Created by allwefantasy on 25/4/2018.
+ */
+object Md5 {
+ def md5Hash(text: String): String = java.security.MessageDigest.getInstance("MD5").digest(text.getBytes()).map(0xFF & _).map {
+ "%02x".format(_)
+ }.foldLeft("") {
+ _ + _
+ }
+}
diff --git a/streamingpro-model/pom.xml b/streamingpro-dl4j/pom.xml
similarity index 83%
rename from streamingpro-model/pom.xml
rename to streamingpro-dl4j/pom.xml
index 4482bac4b..150e5a666 100644
--- a/streamingpro-model/pom.xml
+++ b/streamingpro-dl4j/pom.xml
@@ -9,22 +9,19 @@
4.0.0
-
- 2.11
-
- streamingpro-model
+ streamingpro-dl4j
+
- org.tensorflow
- libtensorflow
- 1.5.0-rc1
+ streaming.king
+ streamingpro-api
+ ${parent.version}
- org.tensorflow
- libtensorflow_jni_gpu
- 1.5.0-rc1
+ streaming.king
+ streamingpro-common
+ ${parent.version}
-
org.deeplearning4j
dl4j-spark_${scala.binary.version}
@@ -63,6 +60,7 @@
org.nd4j
nd4j-native-platform
0.9.1
+
@@ -70,6 +68,7 @@
org.nd4j
nd4j-kryo_${scala.binary.version}
0.9.1
+
com.google.guava
@@ -81,14 +80,20 @@
org.bytedeco
javacv-platform
1.4
+
org.nd4j
nd4j-cuda-7.5
0.9.1
-
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark.version}
+ ${scope}
+
-
\ No newline at end of file
diff --git a/streamingpro-model/src/main/java/streaming/dl4j/DL4JModelLoader.scala b/streamingpro-dl4j/src/main/java/streaming/dl4j/DL4JModelLoader.scala
similarity index 100%
rename from streamingpro-model/src/main/java/streaming/dl4j/DL4JModelLoader.scala
rename to streamingpro-dl4j/src/main/java/streaming/dl4j/DL4JModelLoader.scala
diff --git a/streamingpro-dl4j/src/main/java/streaming/dl4j/Dl4jFunctions.scala b/streamingpro-dl4j/src/main/java/streaming/dl4j/Dl4jFunctions.scala
new file mode 100644
index 000000000..ad885f217
--- /dev/null
+++ b/streamingpro-dl4j/src/main/java/streaming/dl4j/Dl4jFunctions.scala
@@ -0,0 +1,92 @@
+package streaming.dl4j
+
+import streaming.common.HDFSOperator
+import java.io.File
+import java.util.Collections
+
+import net.csdn.common.logging.Loggers
+import streaming.dsl.mmlib.algs.SQLDL4J
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql._
+import org.deeplearning4j.eval.Evaluation
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration
+import org.deeplearning4j.optimize.api.IterationListener
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener
+import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer
+import org.deeplearning4j.util.ModelSerializer
+import org.nd4j.linalg.factory.Nd4j
+
+
+/**
+ * Created by allwefantasy on 25/4/2018.
+ */
+trait Dl4jFunctions {
+ val logger = Loggers.getLogger(getClass)
+
+ def dl4jClassificationTrain(df: DataFrame, path: String, params: Map[String, String], multiLayerConfiguration: () => MultiLayerConfiguration): Unit = {
+ require(params.contains("featureSize"), "featureSize is required")
+
+ val labelSize = params.getOrElse("labelSize", "-1").toInt
+ val batchSize = params.getOrElse("batchSize", "32").toInt
+
+ val epochs = params.getOrElse("epochs", "1").toInt
+ val validateTable = params.getOrElse("validateTable", "")
+
+ val tm = SQLDL4J.init2(df.sparkSession.sparkContext.isLocal, batchSizePerWorker = batchSize)
+
+ val netConf = multiLayerConfiguration()
+
+ val sparkNetwork = new SparkDl4jMultiLayer(df.sparkSession.sparkContext, netConf, tm)
+ sparkNetwork.setCollectTrainingStats(false)
+ sparkNetwork.setListeners(Collections.singletonList[IterationListener](new ScoreIterationListener(1)))
+
+ val labelFieldName = params.getOrElse("outputCol", "label")
+ val newDataSetRDD = if (df.schema.fieldNames.contains(labelFieldName)) {
+
+ require(params.contains("labelSize"), "labelSize is required")
+
+ df.select(params.getOrElse("inputCol", "features"), params.getOrElse("outputCol", "label")).rdd.map { row =>
+ val features = row.getAs[Vector](0)
+ val label = row.getAs[Vector](1)
+ new org.nd4j.linalg.dataset.DataSet(Nd4j.create(features.toArray), Nd4j.create(label.toArray))
+ }.toJavaRDD()
+ } else {
+ df.select(params.getOrElse("inputCol", "features")).rdd.map { row =>
+ val features = row.getAs[Vector](0)
+ new org.nd4j.linalg.dataset.DataSet(Nd4j.create(features.toArray), Nd4j.zeros(0))
+ }.toJavaRDD()
+ }
+
+
+ (0 until epochs).foreach { i =>
+ sparkNetwork.fit(newDataSetRDD)
+ }
+
+ val tempModelLocalPath = createTempModelLocalPath(path)
+ ModelSerializer.writeModel(sparkNetwork.getNetwork, new File(tempModelLocalPath, "dl4j.model"), true)
+ copyToHDFS(tempModelLocalPath + "/dl4j.model", path + "/dl4j.model", true)
+
+ if (!validateTable.isEmpty) {
+
+ val testDataSetRDD = df.sparkSession.table(validateTable).select(params.getOrElse("inputCol", "features"), params.getOrElse("outputCol", "label")).rdd.map { row =>
+ val features = row.getAs[Vector](0)
+ val label = row.getAs[Vector](1)
+ new org.nd4j.linalg.dataset.DataSet(Nd4j.create(features.toArray), Nd4j.create(label.toArray))
+ }.toJavaRDD()
+
+ val evaluation = sparkNetwork.doEvaluation(testDataSetRDD, batchSize, new Evaluation(labelSize))(0); //Work-around for 0.9.1 bug: see https://deeplearning4j.org/releasenotes
+ logger.info("***** Evaluation *****")
+ logger.info(evaluation.stats())
+ logger.info("***** Example Complete *****")
+ }
+ tm.deleteTempFiles(df.sparkSession.sparkContext)
+ }
+
+ def copyToHDFS(tempModelLocalPath: String, path: String, clean: Boolean) = {
+ HDFSOperator.copyToHDFS(tempModelLocalPath, path, clean)
+ }
+
+ def createTempModelLocalPath(path: String, autoCreateParentDir: Boolean = true) = {
+ HDFSOperator.createTempModelLocalPath(path, autoCreateParentDir)
+ }
+}
diff --git a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/SQLDL4J.scala b/streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/SQLDL4J.scala
similarity index 94%
rename from streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/SQLDL4J.scala
rename to streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/SQLDL4J.scala
index a4ff423db..47abce9a4 100644
--- a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/SQLDL4J.scala
+++ b/streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/SQLDL4J.scala
@@ -2,23 +2,23 @@ package streaming.dsl.mmlib.algs
import java.util.Random
-import streaming.dl4j.{DL4JModelLoader, DL4JModelPredictor}
import org.apache.spark.ml.linalg.SQLDataTypes._
-import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.spark.api.RDDTrainingApproach
-import org.deeplearning4j.spark.impl.paramavg.{ParameterAveragingTrainingMaster, ParameterAveragingTrainingWorker}
+import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration
+import streaming.dl4j.{DL4JModelLoader, DL4JModelPredictor, Dl4jFunctions}
import streaming.dsl.mmlib.SQLAlg
/**
* Created by allwefantasy on 15/1/2018.
*/
-class SQLDL4J extends SQLAlg with Functions {
+class SQLDL4J extends SQLAlg with Dl4jFunctions {
override def train(df: DataFrame, path: String, params: Map[String, String]): Unit = {
require(params.contains("featureSize"), "featureSize is required")
diff --git a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/dl4j/FCClassify.scala b/streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/dl4j/FCClassify.scala
similarity index 88%
rename from streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/dl4j/FCClassify.scala
rename to streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/dl4j/FCClassify.scala
index 432d9d4ca..315c57813 100644
--- a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/dl4j/FCClassify.scala
+++ b/streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/dl4j/FCClassify.scala
@@ -1,26 +1,23 @@
package streaming.dsl.mmlib.algs.dl4j
-import streaming.dl4j.{DL4JModelLoader, DL4JModelPredictor}
import java.util.Random
-import org.apache.spark.ml.linalg.SQLDataTypes._
-import streaming.dsl.mmlib.SQLAlg
-import streaming.dsl.mmlib.algs.Functions
-import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer}
import org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater}
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.lossfunctions.LossFunctions
+import streaming.dl4j.Dl4jFunctions
+import streaming.dsl.mmlib.SQLAlg
/**
* Created by allwefantasy on 23/2/2018.
*/
-class FCClassify extends SQLAlg with Functions {
+class FCClassify extends SQLAlg with Dl4jFunctions {
def train(df: DataFrame, path: String, params: Map[String, String]): Unit = {
dl4jClassificationTrain(df, path, params, () => {
diff --git a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/dl4j/SDAutoencoder.scala b/streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/dl4j/SDAutoencoder.scala
similarity index 83%
rename from streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/dl4j/SDAutoencoder.scala
rename to streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/dl4j/SDAutoencoder.scala
index e4eee86d2..b687e7065 100644
--- a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/dl4j/SDAutoencoder.scala
+++ b/streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/dl4j/SDAutoencoder.scala
@@ -1,21 +1,12 @@
package streaming.dsl.mmlib.algs.dl4j
-import java.util.{Random}
+import java.util.Random
-import streaming.dl4j.{DL4JModelLoader, DL4JModelPredictor}
-import org.apache.spark.sql.{DataFrame, SparkSession}
-import org.apache.spark.sql.expressions.UserDefinedFunction
import org.deeplearning4j.nn.api.OptimizationAlgorithm
-import org.deeplearning4j.nn.conf.layers.OutputLayer
+import org.deeplearning4j.nn.conf.layers.variational.{BernoulliReconstructionDistribution, VariationalAutoencoder}
import org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater}
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
-import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
-import streaming.dsl.mmlib.SQLAlg
-import streaming.dsl.mmlib.algs.Functions
-import org.apache.spark.ml.linalg.SQLDataTypes._
-import org.deeplearning4j.nn.conf.layers.variational.{BernoulliReconstructionDistribution, VariationalAutoencoder}
/**
* Created by allwefantasy on 24/2/2018.
diff --git a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/dl4j/VanillaLSTMClassify.scala b/streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/dl4j/VanillaLSTMClassify.scala
similarity index 89%
rename from streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/dl4j/VanillaLSTMClassify.scala
rename to streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/dl4j/VanillaLSTMClassify.scala
index 01015b4c5..9b193c61f 100644
--- a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/dl4j/VanillaLSTMClassify.scala
+++ b/streamingpro-dl4j/src/main/java/streaming/dsl/mmlib/algs/dl4j/VanillaLSTMClassify.scala
@@ -2,16 +2,12 @@ package streaming.dsl.mmlib.algs.dl4j
import java.util.Random
-import org.apache.spark.sql.{DataFrame, SparkSession}
-import org.apache.spark.sql.expressions.UserDefinedFunction
import org.deeplearning4j.nn.api.OptimizationAlgorithm
-import org.deeplearning4j.nn.conf.layers.{DenseLayer, GravesLSTM, OutputLayer}
+import org.deeplearning4j.nn.conf.layers.{GravesLSTM, OutputLayer}
import org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater}
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.lossfunctions.LossFunctions
-import streaming.dsl.mmlib.SQLAlg
-import streaming.dsl.mmlib.algs.Functions
/**
* Created by allwefantasy on 26/2/2018.
diff --git a/streamingpro-spark-2.0/pom.xml b/streamingpro-spark-2.0/pom.xml
index ed92c948d..e889ef221 100644
--- a/streamingpro-spark-2.0/pom.xml
+++ b/streamingpro-spark-2.0/pom.xml
@@ -124,6 +124,16 @@
+
+ dl4j
+
+
+ streaming.king
+ streamingpro-dl4j
+ ${parent.version}
+
+
+
online
@@ -195,7 +205,7 @@
streaming.king
- streamingpro-model
+ streamingpro-tensorflow
${parent.version}
diff --git a/streamingpro-spark-2.0/src/main/java/org/apache/spark/ps/cluster/PSExecutorBackend.scala b/streamingpro-spark-2.0/src/main/java/org/apache/spark/ps/cluster/PSExecutorBackend.scala
index 8448042d5..4f390b7e3 100644
--- a/streamingpro-spark-2.0/src/main/java/org/apache/spark/ps/cluster/PSExecutorBackend.scala
+++ b/streamingpro-spark-2.0/src/main/java/org/apache/spark/ps/cluster/PSExecutorBackend.scala
@@ -1,13 +1,11 @@
package org.apache.spark.ps.cluster
+import streaming.tensorflow.TFModelLoader
import java.util.{Locale}
-
import org.apache.spark.internal.Logging
import org.apache.spark.{SparkEnv}
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.util.ThreadUtils
-import streaming.tensorflow.TFModelLoader
-
import scala.util.{Failure, Success}
diff --git a/streamingpro-spark-2.0/src/main/java/org/apache/spark/ps/local/LocalPSSchedulerBackend.scala b/streamingpro-spark-2.0/src/main/java/org/apache/spark/ps/local/LocalPSSchedulerBackend.scala
index e9b7f74a1..b6845c173 100644
--- a/streamingpro-spark-2.0/src/main/java/org/apache/spark/ps/local/LocalPSSchedulerBackend.scala
+++ b/streamingpro-spark-2.0/src/main/java/org/apache/spark/ps/local/LocalPSSchedulerBackend.scala
@@ -1,5 +1,6 @@
package org.apache.spark.ps.local
+import streaming.tensorflow.TFModelLoader
import java.io.File
import java.net.URL
import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
@@ -8,7 +9,7 @@ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.StopExecutor
-import streaming.tensorflow.TFModelLoader
+
private case class TensorFlowModelClean(modelPath: String)
diff --git a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/Functions.scala b/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/Functions.scala
index 2abf9c833..8979aa469 100644
--- a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/Functions.scala
+++ b/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/Functions.scala
@@ -1,12 +1,10 @@
package streaming.dsl.mmlib.algs
import java.io.{ByteArrayOutputStream, File}
-import java.util.{Collections, Properties, Random}
+import java.util.Properties
import net.csdn.common.logging.Loggers
import org.apache.commons.io.FileUtils
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord}
import org.apache.spark.Partitioner
import org.apache.spark.ml.linalg.SQLDataTypes._
@@ -17,16 +15,6 @@ import org.apache.spark.ml.util.{MLReadable, MLWritable}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, functions => F}
import org.apache.spark.util.{ExternalCommandRunner, ObjPickle, WowMD5, WowXORShiftRandom}
-import org.deeplearning4j.eval.Evaluation
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration
-import org.deeplearning4j.optimize.api.IterationListener
-import org.deeplearning4j.optimize.listeners.ScoreIterationListener
-import org.deeplearning4j.spark.api.RDDTrainingApproach
-import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer
-import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster
-import org.deeplearning4j.util.ModelSerializer
-import org.nd4j.linalg.factory.Nd4j
-import org.nd4j.parameterserver.distributed.conf.VoidConfiguration
import streaming.common.HDFSOperator
import scala.collection.mutable.ArrayBuffer
@@ -228,78 +216,11 @@ trait Functions {
}
def copyToHDFS(tempModelLocalPath: String, path: String, clean: Boolean) = {
- val fs = FileSystem.get(new Configuration())
- fs.delete(new Path(path), true)
- fs.copyFromLocalFile(new Path(tempModelLocalPath),
- new Path(path))
- FileUtils.forceDelete(new File(tempModelLocalPath))
+ HDFSOperator.copyToHDFS(tempModelLocalPath, path, clean)
}
def createTempModelLocalPath(path: String, autoCreateParentDir: Boolean = true) = {
- val dir = "/tmp/train/" + WowMD5.md5Hash(path)
- if (autoCreateParentDir) {
- FileUtils.forceMkdir(new File(dir))
- }
- dir
- }
-
- def dl4jClassificationTrain(df: DataFrame, path: String, params: Map[String, String], multiLayerConfiguration: () => MultiLayerConfiguration): Unit = {
- require(params.contains("featureSize"), "featureSize is required")
-
- val labelSize = params.getOrElse("labelSize", "-1").toInt
- val batchSize = params.getOrElse("batchSize", "32").toInt
-
- val epochs = params.getOrElse("epochs", "1").toInt
- val validateTable = params.getOrElse("validateTable", "")
-
- val tm = SQLDL4J.init2(df.sparkSession.sparkContext.isLocal, batchSizePerWorker = batchSize)
-
- val netConf = multiLayerConfiguration()
-
- val sparkNetwork = new SparkDl4jMultiLayer(df.sparkSession.sparkContext, netConf, tm)
- sparkNetwork.setCollectTrainingStats(false)
- sparkNetwork.setListeners(Collections.singletonList[IterationListener](new ScoreIterationListener(1)))
-
- val labelFieldName = params.getOrElse("outputCol", "label")
- val newDataSetRDD = if (df.schema.fieldNames.contains(labelFieldName)) {
-
- require(params.contains("labelSize"), "labelSize is required")
-
- df.select(params.getOrElse("inputCol", "features"), params.getOrElse("outputCol", "label")).rdd.map { row =>
- val features = row.getAs[Vector](0)
- val label = row.getAs[Vector](1)
- new org.nd4j.linalg.dataset.DataSet(Nd4j.create(features.toArray), Nd4j.create(label.toArray))
- }.toJavaRDD()
- } else {
- df.select(params.getOrElse("inputCol", "features")).rdd.map { row =>
- val features = row.getAs[Vector](0)
- new org.nd4j.linalg.dataset.DataSet(Nd4j.create(features.toArray), Nd4j.zeros(0))
- }.toJavaRDD()
- }
-
-
- (0 until epochs).foreach { i =>
- sparkNetwork.fit(newDataSetRDD)
- }
-
- val tempModelLocalPath = createTempModelLocalPath(path)
- ModelSerializer.writeModel(sparkNetwork.getNetwork, new File(tempModelLocalPath, "dl4j.model"), true)
- copyToHDFS(tempModelLocalPath + "/dl4j.model", path + "/dl4j.model", true)
-
- if (!validateTable.isEmpty) {
-
- val testDataSetRDD = df.sparkSession.table(validateTable).select(params.getOrElse("inputCol", "features"), params.getOrElse("outputCol", "label")).rdd.map { row =>
- val features = row.getAs[Vector](0)
- val label = row.getAs[Vector](1)
- new org.nd4j.linalg.dataset.DataSet(Nd4j.create(features.toArray), Nd4j.create(label.toArray))
- }.toJavaRDD()
-
- val evaluation = sparkNetwork.doEvaluation(testDataSetRDD, batchSize, new Evaluation(labelSize))(0); //Work-around for 0.9.1 bug: see https://deeplearning4j.org/releasenotes
- logger.info("***** Evaluation *****")
- logger.info(evaluation.stats())
- logger.info("***** Example Complete *****")
- }
- tm.deleteTempFiles(df.sparkSession.sparkContext)
+ HDFSOperator.createTempModelLocalPath(path, autoCreateParentDir)
}
}
diff --git a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/SQLTensorFlow.scala b/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/SQLTensorFlow.scala
index 0a63f89bd..6b49afedd 100644
--- a/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/SQLTensorFlow.scala
+++ b/streamingpro-spark-2.0/src/main/java/streaming/dsl/mmlib/algs/SQLTensorFlow.scala
@@ -4,8 +4,6 @@ import streaming.tensorflow.TFModelLoader
import streaming.tensorflow.TFModelPredictor
import java.io.{ByteArrayOutputStream, File}
import java.util
-import java.util.Properties
-
import com.hortonworks.spark.sql.kafka08.KafkaOperator
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
diff --git a/streamingpro-tensorflow/pom.xml b/streamingpro-tensorflow/pom.xml
new file mode 100644
index 000000000..f61333761
--- /dev/null
+++ b/streamingpro-tensorflow/pom.xml
@@ -0,0 +1,46 @@
+
+
+
+ streamingpro
+ streaming.king
+ 1.1.0
+
+ 4.0.0
+
+
+ 2.11
+
+ streamingpro-tensorflow
+
+
+ streaming.king
+ streamingpro-api
+ ${parent.version}
+
+
+ streaming.king
+ streamingpro-common
+ ${parent.version}
+
+
+ org.tensorflow
+ libtensorflow
+ 1.5.0-rc1
+
+
+ org.tensorflow
+ libtensorflow_jni_gpu
+ 1.5.0-rc1
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark.version}
+ ${scope}
+
+
+
+
+
\ No newline at end of file
diff --git a/streamingpro-model/src/main/java/streaming/tensorflow/TFModelLoader.scala b/streamingpro-tensorflow/src/main/java/streaming/tensorflow/TFModelLoader.scala
similarity index 100%
rename from streamingpro-model/src/main/java/streaming/tensorflow/TFModelLoader.scala
rename to streamingpro-tensorflow/src/main/java/streaming/tensorflow/TFModelLoader.scala