Skip to content

Commit

Permalink
[SPARK-13449] Naive Bayes wrapper in SparkR
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR continues the work in apache#11486 from yinxusen with some code refactoring. In R package e1071, `naiveBayes` supports both categorical (Bernoulli) and continuous features (Gaussian), while in MLlib we support Bernoulli and multinomial. This PR implements the common subset: Bernoulli.

I moved the implementation out from SparkRWrappers to NaiveBayesWrapper to make it easier to read. Argument names, default values, and summary now match e1071's naiveBayes.

I removed the preprocess part that omit NA values because we don't know which columns to process.

## How was this patch tested?

Test against output from R package e1071's naiveBayes.

cc: yanboliang yinxusen

Closes apache#11486

Author: Xusen Yin <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes apache#11890 from mengxr/SPARK-13449.
  • Loading branch information
yinxusen authored and mengxr committed Mar 22, 2016
1 parent b2b1ad7 commit d6dc12e
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 7 deletions.
3 changes: 2 additions & 1 deletion R/pkg/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ Depends:
R (>= 3.0),
methods,
Suggests:
testthat
testthat,
e1071
Description: R frontend for Spark
License: Apache License (== 2.0)
Collate:
Expand Down
3 changes: 2 additions & 1 deletion R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ exportMethods("glm",
"predict",
"summary",
"kmeans",
"fitted")
"fitted",
"naiveBayes")

# Job group lifecycle management methods
export("setJobGroup",
Expand Down
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -1175,3 +1175,7 @@ setGeneric("kmeans")
#' @rdname fitted
#' @export
setGeneric("fitted")

#' @rdname naiveBayes
#' @export
setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })
91 changes: 86 additions & 5 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
#' @export
setClass("PipelineModel", representation(model = "jobj"))

#' @title S4 class that represents a NaiveBayesModel
#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
#' @export
setClass("NaiveBayesModel", representation(jobj = "jobj"))

#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
Expand All @@ -42,7 +47,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' sc <- sparkR.init()
#' sqlContext <- sparkRSQL.init(sc)
#' data(iris)
Expand Down Expand Up @@ -71,7 +76,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
#' @rdname predict
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' model <- glm(y ~ x, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
Expand All @@ -81,6 +86,26 @@ setMethod("predict", signature(object = "PipelineModel"),
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})

#' Make predictions from a naive Bayes model
#'
#' Makes predictions from a model produced by naiveBayes(), similarly to R package e1071's predict.
#'
#' @param object A fitted naive Bayes model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted labels in a column named "prediction"
#' @rdname predict
#' @export
#' @examples
#' \dontrun{
#' model <- naiveBayes(y ~ x, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#'}
setMethod("predict", signature(object = "NaiveBayesModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})

#' Get the summary of a model
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
Expand All @@ -97,7 +122,7 @@ setMethod("predict", signature(object = "PipelineModel"),
#' @rdname summary
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
Expand Down Expand Up @@ -140,6 +165,35 @@ setMethod("summary", signature(object = "PipelineModel"),
}
})

#' Get the summary of a naive Bayes model
#'
#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
#'
#' @param object A fitted MLlib model
#' @return a list containing 'apriori', the label distribution, and 'tables', conditional
# probabilities given the target label
#' @rdname summary
#' @export
#' @examples
#' \dontrun{
#' model <- naiveBayes(y ~ x, trainingData)
#' summary(model)
#'}
setMethod("summary", signature(object = "NaiveBayesModel"),
function(object, ...) {
jobj <- object@jobj
features <- callJMethod(jobj, "features")
labels <- callJMethod(jobj, "labels")
apriori <- callJMethod(jobj, "apriori")
apriori <- t(as.matrix(unlist(apriori)))
colnames(apriori) <- unlist(labels)
tables <- callJMethod(jobj, "tables")
tables <- matrix(tables, nrow = length(labels))
rownames(tables) <- unlist(labels)
colnames(tables) <- unlist(features)
return(list(apriori = apriori, tables = tables))
})

#' Fit a k-means model
#'
#' Fit a k-means model, similarly to R's kmeans().
Expand All @@ -152,7 +206,7 @@ setMethod("summary", signature(object = "PipelineModel"),
#' @rdname kmeans
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' model <- kmeans(x, centers = 2, algorithm="random")
#'}
setMethod("kmeans", signature(x = "DataFrame"),
Expand All @@ -173,7 +227,7 @@ setMethod("kmeans", signature(x = "DataFrame"),
#' @rdname fitted
#' @export
#' @examples
#'\dontrun{
#' \dontrun{
#' model <- kmeans(trainingData, 2)
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
Expand All @@ -192,3 +246,30 @@ setMethod("fitted", signature(object = "PipelineModel"),
stop(paste("Unsupported model", modelName, sep = " "))
}
})

#' Fit a Bernoulli naive Bayes model
#'
#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only
#' categorical features are supported. The input should be a DataFrame of observations instead of a
#' contingency table.
#'
#' @param object A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param data DataFrame for training
#' @param laplace Smoothing parameter
#' @return a fitted naive Bayes model
#' @rdname naiveBayes
#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame(sqlContext, infert)
#' model <- naiveBayes(education ~ ., df, laplace = 0)
#'}
setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
function(formula, data, laplace = 0, ...) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
formula, data@sdf, laplace)
return(new("NaiveBayesModel", jobj = jobj))
})
59 changes: 59 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,62 @@ test_that("kmeans", {
cluster <- summary.model$cluster
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
})

test_that("naiveBayes", {
# R code to reproduce the result.
# We do not support instance weights yet. So we ignore the frequencies.
#
#' library(e1071)
#' t <- as.data.frame(Titanic)
#' t1 <- t[t$Freq > 0, -5]
#' m <- naiveBayes(Survived ~ ., data = t1)
#' m
#' predict(m, t1)
#
# -- output of 'm'
#
# A-priori probabilities:
# Y
# No Yes
# 0.4166667 0.5833333
#
# Conditional probabilities:
# Class
# Y 1st 2nd 3rd Crew
# No 0.2000000 0.2000000 0.4000000 0.2000000
# Yes 0.2857143 0.2857143 0.2857143 0.1428571
#
# Sex
# Y Male Female
# No 0.5 0.5
# Yes 0.5 0.5
#
# Age
# Y Child Adult
# No 0.2000000 0.8000000
# Yes 0.4285714 0.5714286
#
# -- output of 'predict(m, t1)'
#
# Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No
#

t <- as.data.frame(Titanic)
t1 <- t[t$Freq > 0, -5]
df <- suppressWarnings(createDataFrame(sqlContext, t1))
m <- naiveBayes(Survived ~ ., data = df)
s <- summary(m)
expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
expect_equal(sum(s$apriori), 1)
expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6)
p <- collect(select(predict(m, df), "prediction"))
expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No",
"Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
"Yes", "Yes", "No", "No"))

# Test e1071::naiveBayes
if (requireNamespace("e1071", quietly = TRUE)) {
expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
}
})
75 changes: 75 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.r

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.sql.DataFrame

private[r] class NaiveBayesWrapper private (
pipeline: PipelineModel,
val labels: Array[String],
val features: Array[String]) {

import NaiveBayesWrapper._

private val naiveBayesModel: NaiveBayesModel = pipeline.stages(1).asInstanceOf[NaiveBayesModel]

lazy val apriori: Array[Double] = naiveBayesModel.pi.toArray.map(math.exp)

lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp)

def transform(dataset: DataFrame): DataFrame = {
pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
}
}

private[r] object NaiveBayesWrapper {

val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
val PREDICTED_LABEL_COL = "prediction"

def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
.fit(data)
// get labels and feature names from output schema
val schema = rFormula.transform(data).schema
val labelAttr = Attribute.fromStructField(schema(rFormula.getLabelCol))
.asInstanceOf[NominalAttribute]
val labels = labelAttr.values.get
val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
.attributes.get
val features = featureAttrs.map(_.name.get)
// assemble and fit the pipeline
val naiveBayes = new NaiveBayes()
.setSmoothing(laplace)
.setModelType("bernoulli")
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
.setOutputCol(PREDICTED_LABEL_COL)
.setLabels(labels)
val pipeline = new Pipeline()
.setStages(Array(rFormula, naiveBayes, idxToStr))
.fit(data)
new NaiveBayesWrapper(pipeline, labels, features)
}
}

0 comments on commit d6dc12e

Please sign in to comment.