Split dl4j out into a separate module to avoid an oversized jar
allwefantasy committed Apr 25, 2018
1 parent ce4d22f commit ee31da1
Showing 18 changed files with 217 additions and 132 deletions.
3 changes: 2 additions & 1 deletion pom.xml
@@ -13,7 +13,8 @@
<module>streamingpro-api</module>
<module>streamingpro-hbase</module>
<module>streamingpro-redis</module>
<module>streamingpro-model</module>
<module>streamingpro-dl4j</module>
<module>streamingpro-tensorflow</module>
</modules>

<properties>
@@ -1,6 +1,5 @@
package streaming.dsl.mmlib


import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction

@@ -1,7 +1,8 @@
package streaming.common

import java.io.{BufferedReader, InputStreamReader}
import java.io.{BufferedReader, File, InputStreamReader}

import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataOutputStream, FileStatus, FileSystem, Path}

@@ -97,6 +98,23 @@ object HDFSOperator {

}

def copyToHDFS(tempModelLocalPath: String, path: String, clean: Boolean) = {
  val fs = FileSystem.get(new Configuration())
  fs.delete(new Path(path), true)
  fs.copyFromLocalFile(new Path(tempModelLocalPath), new Path(path))
  // delete the local temp copy only when the caller requests cleanup
  if (clean) {
    FileUtils.forceDelete(new File(tempModelLocalPath))
  }
}

def createTempModelLocalPath(path: String, autoCreateParentDir: Boolean = true) = {
val dir = "/tmp/train/" + Md5.md5Hash(path)
if (autoCreateParentDir) {
FileUtils.forceMkdir(new File(dir))
}
dir
}


def main(args: Array[String]): Unit = {
println(readFile("file:///Users/allwefantasy/streamingpro/flink.json"))
}
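Taken together, the two new helpers give a trainer a deterministic local scratch directory keyed by the target path, plus a one-call publish step. A minimal sketch of the intended flow (the target path below is illustrative, not from this commit):

    val path = "/model/fc"  // hypothetical HDFS target
    val tempModelLocalPath = HDFSOperator.createTempModelLocalPath(path)
    // ... write model files under tempModelLocalPath ...
    HDFSOperator.copyToHDFS(tempModelLocalPath, path, clean = true)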
12 changes: 12 additions & 0 deletions streamingpro-commons/src/main/java/streaming/common/Md5.scala
@@ -0,0 +1,12 @@
package streaming.common

/**
* Created by allwefantasy on 25/4/2018.
*/
object Md5 {
def md5Hash(text: String): String =
  java.security.MessageDigest.getInstance("MD5")
    .digest(text.getBytes())
    .map(0xFF & _)
    .map("%02x".format(_))
    .foldLeft("")(_ + _)
}
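As a quick usage check of the new helper, the MD5 digest of "hello" is the well-known value shown in the comment:

    println(Md5.md5Hash("hello")) // 5d41402abc4b2a76b9719d911017c592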
31 changes: 18 additions & 13 deletions streamingpro-model/pom.xml → streamingpro-dl4j/pom.xml
@@ -9,22 +9,19 @@
</parent>
<modelVersion>4.0.0</modelVersion>

<properties>
<scala.binary.version>2.11</scala.binary.version>
</properties>
<artifactId>streamingpro-model</artifactId>
<artifactId>streamingpro-dl4j</artifactId>

<dependencies>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId>
<version>1.5.0-rc1</version>
<groupId>streaming.king</groupId>
<artifactId>streamingpro-api</artifactId>
<version>${parent.version}</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId>
<version>1.5.0-rc1</version>
<groupId>streaming.king</groupId>
<artifactId>streamingpro-common</artifactId>
<version>${parent.version}</version>
</dependency>

<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark_${scala.binary.version}</artifactId>
@@ -63,13 +60,15 @@
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>0.9.1</version>

<!--<classifier>macosx-x86_64</classifier>-->
</dependency>

<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-kryo_${scala.binary.version}</artifactId>
<version>0.9.1</version>

<exclusions>
<exclusion>
<groupId>com.google.guava</groupId>
@@ -81,14 +80,20 @@
<groupId>org.bytedeco</groupId>
<artifactId>javacv-platform</artifactId>
<version>1.4</version>

</dependency>

<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-7.5</artifactId>
<version>0.9.1</version>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<scope>${scope}</scope>
</dependency>
</dependencies>

</project>
92 changes: 92 additions & 0 deletions streamingpro-dl4j/src/main/java/streaming/dl4j/Dl4jFunctions.scala
@@ -0,0 +1,92 @@
package streaming.dl4j

import streaming.common.HDFSOperator
import java.io.File
import java.util.Collections

import net.csdn.common.logging.Loggers
import streaming.dsl.mmlib.algs.SQLDL4J
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql._
import org.deeplearning4j.eval.Evaluation
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.optimize.api.IterationListener
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer
import org.deeplearning4j.util.ModelSerializer
import org.nd4j.linalg.factory.Nd4j


/**
* Created by allwefantasy on 25/4/2018.
*/
trait Dl4jFunctions {
val logger = Loggers.getLogger(getClass)

def dl4jClassificationTrain(df: DataFrame, path: String, params: Map[String, String], multiLayerConfiguration: () => MultiLayerConfiguration): Unit = {
require(params.contains("featureSize"), "featureSize is required")

val labelSize = params.getOrElse("labelSize", "-1").toInt
val batchSize = params.getOrElse("batchSize", "32").toInt

val epochs = params.getOrElse("epochs", "1").toInt
val validateTable = params.getOrElse("validateTable", "")

val tm = SQLDL4J.init2(df.sparkSession.sparkContext.isLocal, batchSizePerWorker = batchSize)

val netConf = multiLayerConfiguration()

val sparkNetwork = new SparkDl4jMultiLayer(df.sparkSession.sparkContext, netConf, tm)
sparkNetwork.setCollectTrainingStats(false)
sparkNetwork.setListeners(Collections.singletonList[IterationListener](new ScoreIterationListener(1)))

val labelFieldName = params.getOrElse("outputCol", "label")
val newDataSetRDD = if (df.schema.fieldNames.contains(labelFieldName)) {

require(params.contains("labelSize"), "labelSize is required")

df.select(params.getOrElse("inputCol", "features"), params.getOrElse("outputCol", "label")).rdd.map { row =>
val features = row.getAs[Vector](0)
val label = row.getAs[Vector](1)
new org.nd4j.linalg.dataset.DataSet(Nd4j.create(features.toArray), Nd4j.create(label.toArray))
}.toJavaRDD()
} else {
df.select(params.getOrElse("inputCol", "features")).rdd.map { row =>
val features = row.getAs[Vector](0)
new org.nd4j.linalg.dataset.DataSet(Nd4j.create(features.toArray), Nd4j.zeros(0))
}.toJavaRDD()
}


(0 until epochs).foreach { i =>
sparkNetwork.fit(newDataSetRDD)
}

val tempModelLocalPath = createTempModelLocalPath(path)
ModelSerializer.writeModel(sparkNetwork.getNetwork, new File(tempModelLocalPath, "dl4j.model"), true)
copyToHDFS(tempModelLocalPath + "/dl4j.model", path + "/dl4j.model", true)

if (!validateTable.isEmpty) {

val testDataSetRDD = df.sparkSession.table(validateTable).select(params.getOrElse("inputCol", "features"), params.getOrElse("outputCol", "label")).rdd.map { row =>
val features = row.getAs[Vector](0)
val label = row.getAs[Vector](1)
new org.nd4j.linalg.dataset.DataSet(Nd4j.create(features.toArray), Nd4j.create(label.toArray))
}.toJavaRDD()

val evaluation = sparkNetwork.doEvaluation(testDataSetRDD, batchSize, new Evaluation(labelSize))(0) // work-around for a 0.9.1 bug: see https://deeplearning4j.org/releasenotes
logger.info("***** Evaluation *****")
logger.info(evaluation.stats())
logger.info("***** Example Complete *****")
}
tm.deleteTempFiles(df.sparkSession.sparkContext)
}

def copyToHDFS(tempModelLocalPath: String, path: String, clean: Boolean) = {
HDFSOperator.copyToHDFS(tempModelLocalPath, path, clean)
}

def createTempModelLocalPath(path: String, autoCreateParentDir: Boolean = true) = {
HDFSOperator.createTempModelLocalPath(path, autoCreateParentDir)
}
}
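To see how the trait is meant to be consumed, here is a minimal sketch (not part of this commit; the object name, layer sizes, and hyperparameters are illustrative) of a caller that mixes in Dl4jFunctions and supplies only the MultiLayerConfiguration factory, leaving the Spark plumbing, model serialization, and evaluation to dl4jClassificationTrain:

    import org.apache.spark.sql.DataFrame
    import org.deeplearning4j.nn.api.OptimizationAlgorithm
    import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer}
    import org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater}
    import org.deeplearning4j.nn.weights.WeightInit
    import org.nd4j.linalg.activations.Activation
    import org.nd4j.linalg.lossfunctions.LossFunctions
    import streaming.dl4j.Dl4jFunctions

    object FCExample extends Dl4jFunctions {
      def run(df: DataFrame, path: String, params: Map[String, String]): Unit = {
        val featureSize = params("featureSize").toInt
        val labelSize = params("labelSize").toInt
        dl4jClassificationTrain(df, path, params, () => {
          // dl4j 0.9.1-style builder chain; a single hidden layer for illustration
          new NeuralNetConfiguration.Builder()
            .seed(42)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.ADAM)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder()
              .nIn(featureSize).nOut(64).activation(Activation.RELU).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
              .nIn(64).nOut(labelSize).activation(Activation.SOFTMAX).build())
            .pretrain(false).backprop(true)
            .build()
        })
      }
    }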
@@ -2,23 +2,23 @@ package streaming.dsl.mmlib.algs

import java.util.Random

import streaming.dl4j.{DL4JModelLoader, DL4JModelPredictor}
import org.apache.spark.ml.linalg.SQLDataTypes._
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.spark.api.RDDTrainingApproach
import org.deeplearning4j.spark.impl.paramavg.{ParameterAveragingTrainingMaster, ParameterAveragingTrainingWorker}
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration
import streaming.dl4j.{DL4JModelLoader, DL4JModelPredictor, Dl4jFunctions}
import streaming.dsl.mmlib.SQLAlg


/**
* Created by allwefantasy on 15/1/2018.
*/
class SQLDL4J extends SQLAlg with Functions {
class SQLDL4J extends SQLAlg with Dl4jFunctions {
override def train(df: DataFrame, path: String, params: Map[String, String]): Unit = {

require(params.contains("featureSize"), "featureSize is required")
@@ -1,26 +1,23 @@
package streaming.dsl.mmlib.algs.dl4j

import streaming.dl4j.{DL4JModelLoader, DL4JModelPredictor}
import java.util.Random

import org.apache.spark.ml.linalg.SQLDataTypes._
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.Functions
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer}
import org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.lossfunctions.LossFunctions
import streaming.dl4j.Dl4jFunctions
import streaming.dsl.mmlib.SQLAlg


/**
* Created by allwefantasy on 23/2/2018.
*/
class FCClassify extends SQLAlg with Functions {
class FCClassify extends SQLAlg with Dl4jFunctions {
def train(df: DataFrame, path: String, params: Map[String, String]): Unit = {
dl4jClassificationTrain(df, path, params, () => {

@@ -1,21 +1,12 @@
package streaming.dsl.mmlib.algs.dl4j

import java.util.{Random}
import java.util.Random

import streaming.dl4j.{DL4JModelLoader, DL4JModelPredictor}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.conf.layers.variational.{BernoulliReconstructionDistribution, VariationalAutoencoder}
import org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.Functions
import org.apache.spark.ml.linalg.SQLDataTypes._
import org.deeplearning4j.nn.conf.layers.variational.{BernoulliReconstructionDistribution, VariationalAutoencoder}

/**
* Created by allwefantasy on 24/2/2018.
@@ -2,16 +2,12 @@ package streaming.dsl.mmlib.algs.dl4j

import java.util.Random

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.layers.{DenseLayer, GravesLSTM, OutputLayer}
import org.deeplearning4j.nn.conf.layers.{GravesLSTM, OutputLayer}
import org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater}
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.lossfunctions.LossFunctions
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.Functions

/**
* Created by allwefantasy on 26/2/2018.
12 changes: 11 additions & 1 deletion streamingpro-spark-2.0/pom.xml
@@ -124,6 +124,16 @@
</dependencies>
</profile>

<profile>
<id>dl4j</id>
<dependencies>
<dependency>
<groupId>streaming.king</groupId>
<artifactId>streamingpro-dl4j</artifactId>
<version>${parent.version}</version>
</dependency>
</dependencies>
</profile>
<profile>
<id>online</id>
<properties>
@@ -195,7 +205,7 @@

<dependency>
<groupId>streaming.king</groupId>
<artifactId>streamingpro-model</artifactId>
<artifactId>streamingpro-tensorflow</artifactId>
<version>${parent.version}</version>
</dependency>

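The new dl4j profile above keeps the heavy dependency stack out of the default build, pulling it in only on demand. A typical invocation (assuming a standard Maven setup) would be:

    mvn -Pdl4j -DskipTests clean package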
@@ -1,13 +1,11 @@
package org.apache.spark.ps.cluster

import streaming.tensorflow.TFModelLoader
import java.util.{Locale}

import org.apache.spark.internal.Logging
import org.apache.spark.{SparkEnv}
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.util.ThreadUtils
import streaming.tensorflow.TFModelLoader

import scala.util.{Failure, Success}


@@ -1,5 +1,6 @@
package org.apache.spark.ps.local

import streaming.tensorflow.TFModelLoader
import java.io.File
import java.net.URL
import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
Expand All @@ -8,7 +9,7 @@ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.StopExecutor
import streaming.tensorflow.TFModelLoader


private case class TensorFlowModelClean(modelPath: String)
