forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-13449] Naive Bayes wrapper in SparkR
## What changes were proposed in this pull request? This PR continues the work in apache#11486 from yinxusen with some code refactoring. In R package e1071, `naiveBayes` supports both categorical (Bernoulli) and continuous features (Gaussian), while in MLlib we support Bernoulli and multinomial. This PR implements the common subset: Bernoulli. I moved the implementation out from SparkRWrappers to NaiveBayesWrapper to make it easier to read. Argument names, default values, and summary now match e1071's naiveBayes. I removed the preprocess part that omit NA values because we don't know which columns to process. ## How was this patch tested? Test against output from R package e1071's naiveBayes. cc: yanboliang yinxusen Closes apache#11486 Author: Xusen Yin <[email protected]> Author: Xiangrui Meng <[email protected]> Closes apache#11890 from mengxr/SPARK-13449.
- Loading branch information
Showing 6 changed files with 228 additions and 7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala — 75 additions, 0 deletions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.r | ||
|
||
import org.apache.spark.ml.{Pipeline, PipelineModel} | ||
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} | ||
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} | ||
import org.apache.spark.ml.feature.{IndexToString, RFormula} | ||
import org.apache.spark.sql.DataFrame | ||
|
||
/**
 * Wraps a fitted Bernoulli naive Bayes [[PipelineModel]] so SparkR can expose
 * its parameters under the same names used by R's e1071::naiveBayes.
 *
 * @param pipeline fitted pipeline: RFormula model, NaiveBayesModel, IndexToString
 * @param labels   string values of the label column, in index order
 * @param features names of the feature columns produced by the R formula
 */
private[r] class NaiveBayesWrapper private (
    pipeline: PipelineModel,
    val labels: Array[String],
    val features: Array[String]) {

  import NaiveBayesWrapper._

  // Stage 1 of the pipeline is the fitted NaiveBayesModel (stage 0 is the
  // RFormula transformer, stage 2 the IndexToString post-processor).
  private val nbModel: NaiveBayesModel =
    pipeline.stages(1).asInstanceOf[NaiveBayesModel]

  // Class prior probabilities. MLlib stores pi in log space, so exponentiate
  // to match e1071's "apriori" output.
  lazy val apriori: Array[Double] = nbModel.pi.toArray.map(math.exp)

  // Conditional probability tables, likewise mapped back from log space.
  lazy val tables: Array[Double] = nbModel.theta.toArray.map(math.exp)

  /** Runs the full pipeline and hides the intermediate label-index column. */
  def transform(dataset: DataFrame): DataFrame =
    pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
}
private[r] object NaiveBayesWrapper {

  /** Temporary column holding the predicted label index; dropped before results reach R. */
  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
  /** Column holding the decoded predicted label string. */
  val PREDICTED_LABEL_COL = "prediction"

  /**
   * Fits a Bernoulli naive Bayes model described by an R model formula.
   *
   * @param formula R formula string, e.g. "y ~ x1 + x2"
   * @param data    training data
   * @param laplace additive (Laplace) smoothing parameter
   * @return wrapper exposing the fitted pipeline plus label/feature names
   */
  def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = {
    val rFormulaModel = new RFormula()
      .setFormula(formula)
      .fit(data)

    // Recover the label values and feature names from the ML attribute
    // metadata attached to the transformed schema.
    val outputSchema = rFormulaModel.transform(data).schema
    val labels = Attribute
      .fromStructField(outputSchema(rFormulaModel.getLabelCol))
      .asInstanceOf[NominalAttribute]
      .values.get
    val features = AttributeGroup
      .fromStructField(outputSchema(rFormulaModel.getFeaturesCol))
      .attributes.get
      .map(_.name.get)

    // Assemble the estimator stages and fit them end to end.
    val estimator = new NaiveBayes()
      .setSmoothing(laplace)
      .setModelType("bernoulli")
      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
    val labelDecoder = new IndexToString()
      .setInputCol(PREDICTED_LABEL_INDEX_COL)
      .setOutputCol(PREDICTED_LABEL_COL)
      .setLabels(labels)
    val fittedPipeline = new Pipeline()
      .setStages(Array(rFormulaModel, estimator, labelDecoder))
      .fit(data)

    new NaiveBayesWrapper(fittedPipeline, labels, features)
  }
}