Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline and DataFrames #2

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
package ml.dmlc.xgboost4j.scala.spark

import scala.collection.JavaConverters._

import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}

import ml.dmlc.xgboost4j.LabeledPoint
import org.apache.spark.sql.DataFrame
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

object DataUtils extends Serializable {
implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint])
Expand Down Expand Up @@ -58,4 +60,21 @@ object DataUtils extends Serializable {
}
}
}

/**
 * Converts a DataFrame into an RDD of MLlib labeled points by projecting the
 * label and feature columns.
 *
 * @param dataset        Input dataset; must contain the two named columns.
 * @param labelColumn    Name of the (double-valued) label column.
 * @param featuresColumn Name of the vector-valued features column.
 * @return RDD of LabeledPoint built from the selected columns.
 */
implicit def dataframeToLabledPoints(dataset: DataFrame, labelColumn: String = "label",
    featuresColumn: String = "features"): RDD[SparkLabeledPoint] = {
  // Project down to just the two columns so positional access below is stable.
  val projected = dataset.select(labelColumn, featuresColumn)
  projected.rdd.map { r =>
    new SparkLabeledPoint(r.getDouble(0), r.getAs[Vector](1))
  }
}

/**
 * Appends one value per row (taken from `values`) to `df` as a new column.
 *
 * The two RDDs are keyed by their element index and joined, so each original
 * row is paired with the value at the same position.
 * NOTE(review): `join` does not guarantee the resulting row order matches the
 * input DataFrame's order — confirm callers do not rely on ordering.
 *
 * @param df      Source DataFrame to extend.
 * @param colName Name of the new column.
 * @param colType Spark SQL type of the new column.
 * @param values  One value per row of `df`, in the same order.
 * @return A new DataFrame with the extra (nullable) column appended.
 */
def appendOutput(df: DataFrame, colName: String, colType: DataType,
    values: RDD[Array[Array[Float]]]): DataFrame = {
  // Key both sides by element index so rows can be matched with their values.
  val indexedRows = df.rdd.zipWithIndex().map { case (row, idx) => (idx, row) }
  val indexedValues = values.zipWithIndex().map { case (value, idx) => (idx, value) }
  val merged = indexedRows.join(indexedValues).map { case (_, (row, value)) =>
    Row.fromSeq(row.toSeq :+ value)
  }
  val schema = StructType(df.schema.fields :+ StructField(colName, colType, true))
  df.sqlContext.createDataFrame(merged, schema)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark.pipeline

import org.apache.spark.ml.Estimator
import org.apache.spark.ml.XGBoostParams
import org.apache.spark.sql.DataFrame
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import ml.dmlc.xgboost4j.scala.spark.DataUtils._
import org.apache.spark.sql.types.StructType
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.param.ParamMap

/**
* An XGBoost estimator to be fit on the training set. Returns the instance
* of XGBoostModel as a learned model.
*/
/**
 * An XGBoost estimator to be fit on a training set. Produces an
 * [[XGBoostModel]] as the learned model.
 *
 * NOTE(review): this class shadows the imported core `XGBoost` object's type
 * name; inside this class the term `XGBoost` still resolves to the imported
 * object, so `XGBoost.train` below calls the core trainer.
 *
 * @param uid Unique identifier of this estimator instance.
 */
class XGBoost(override val uid: String) extends Estimator[XGBoostModel]
  with XGBoostParams {

  def this() = this(Identifiable.randomUID("xgb"))

  /**
   * Fits the XGBoost model on the provided dataset.
   *
   * @param dataset The training set; must contain the configured label and
   *                features columns.
   * @return A fitted [[XGBoostModel]] with this estimator set as its parent.
   */
  override def fit(dataset: DataFrame): XGBoostModel = {
    transformSchema(dataset.schema, logging = true)
    // Convert the DataFrame into the RDD[LabeledPoint] form the core trainer expects.
    val labeledPoints = dataframeToLabledPoints(dataset, $(labelCol), $(featuresCol))
    val trained = XGBoost.train(labeledPoints, paramsMap, $(rounds), $(nWorkers),
      useExternalMemory = $(useExternalCache))
    copyValues(new XGBoostModel(trained).setParent(this))
  }

  /**
   * Creates a copy of this estimator with the same UID.
   * @param extra Additional parameters applied to the copy.
   */
  override def copy(extra: ParamMap): XGBoost = copyValues(new XGBoost(uid), extra)

  /**
   * Validates the input schema and derives the output schema.
   * @param schema Schema of the input dataset/dataframe.
   */
  override def transformSchema(schema: StructType): StructType =
    validateAndTransformSchema(schema, false)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark.pipeline

import org.apache.spark.ml.Model
import ml.dmlc.xgboost4j.scala.spark.{ XGBoostModel => XGBModel }
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.StructType
import org.apache.spark.ml.XGBoostParams
import org.apache.spark.sql.types.StructField
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
import ml.dmlc.xgboost4j.scala.spark.DataUtils._

/**
* [[http://xgboost.readthedocs.io/en/latest/model.html XGBoost]] model for classification
* and regression.
*
* @param model The ml.dmlc.xgboost4j.scala.spark.XGBoostModel instance to delegate the tasks to.
*/
/**
 * [[http://xgboost.readthedocs.io/en/latest/model.html XGBoost]] model for classification
 * and regression.
 *
 * @param uid   Unique identifier of this model instance.
 * @param model The ml.dmlc.xgboost4j.scala.spark.XGBoostModel instance to delegate the tasks to.
 */
class XGBoostModel(override val uid: String, model: XGBModel)
  extends Model[XGBoostModel] with XGBoostParams {

  def this(model: XGBModel) = this(Identifiable.randomUID("xgb"), model)

  /**
   * Do the transformation - this means make the predictions for the given dataset.
   *
   * @param dataset Dataset to get the predictions for; must contain the
   *                configured features column.
   * @return The input dataset with the prediction column appended.
   */
  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema)
    // Extract the feature vectors the underlying booster predicts on.
    val features: RDD[Vector] = dataset.select($(featuresCol))
      .rdd.map { row => row.getAs[Vector](0) }

    val output = model.predict(features, $(useExternalCache))
    appendOutput(dataset, $(predictionCol), new VectorUDT, output)
  }

  /**
   * Returns a copy of this XGBoostModel.
   *
   * Spark ML requires `copy` to preserve the instance UID, so the primary
   * constructor is used here (the auxiliary one would mint a new random UID).
   *
   * @param extra Additional parameters for the new model.
   */
  override def copy(extra: ParamMap): XGBoostModel =
    copyValues(new XGBoostModel(uid, model), extra)

  /**
   * Validate and Transform the input schema to output schema.
   * @param schema Schema for the input dataset/dataframe.
   */
  override def transformSchema(schema: StructType): StructType =
    validateAndTransformSchema(schema, false)
}
Loading