Skip to content

Commit

Permalink
Fix plotFift
Browse files Browse the repository at this point in the history
  • Loading branch information
agosiewska committed Apr 19, 2018
1 parent 979edf9 commit 6143c09
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 49 deletions.
34 changes: 7 additions & 27 deletions R/plotCGains.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#' @description Cumulative Gains Chartis a plot of the rate of positive prediction against true positive rate for the different thresholds.
#' It is useful for measuring and comparing the accuracy of the classificators.
#' @param object An object of class ModelAudit
#' @param newdata optionally, a data frame in which to look for variables with which to plot CGains curve. If omitted, the data used to build model will be used.
#' @param newy optionally, required if newdata used. Response vector for new data.
#' @param ... other modelAudit objects to be plotted together
#'
#' @return ggplot object
Expand All @@ -14,33 +12,20 @@
#' @import ggplot2
#' @import ROCR
#'
#' @examples
#' library(auditor)
#' library(mlbench)
#' library(randomForest)
#' data("PimaIndiansDiabetes")
#'
#' model_rf <- randomForest(diabetes~., data=PimaIndiansDiabetes)
#' au_rf <- audit(model_rf, label="rf")
#' plotCGains(au_rf)
#'
#' model_glm <- glm(diabetes~., family=binomial, data=PimaIndiansDiabetes)
#' au_glm <- audit(model_glm)
#' plotCGains(au_rf, au_glm)
#'
#' @export


plotCGains <- function(object, ..., newdata = NULL, newy){
plotCGains <- function(object, ...){
if(class(object)!="modelAudit") stop("plotCGains requires object class modelAudit.")
rpp <- tpr <- label <- NULL
df <- getCGainsDF(object, newdata, newy)
df <- getCGainsDF(object)

dfl <- list(...)
if (length(dfl) > 0) {
for (resp in dfl) {
if(class(resp)=="modelAudit"){
df <- rbind( df, getCGainsDF(resp, newdata, newy) )
df <- rbind( df, getCGainsDF(resp) )
}
}
}
Expand All @@ -52,15 +37,10 @@ plotCGains <- function(object, ..., newdata = NULL, newy){
theme_light()
}

getCGainsDF <- function(object, newdata, newy){
if (is.null(newdata)) {
predictions <- object$fitted.values
y <- object$y
} else {
if(is.null(newy)) stop("newy must be provided.")
predictions <- object$predict.function(object$model, newdata)
y <- newy
}
getCGainsDF <- function(object){

predictions <- object$fitted.values
y <- as.numeric(as.character(object$y))

pred <- prediction(predictions, y)
gain <- performance(pred, "tpr", "rpp")
Expand Down
2 changes: 1 addition & 1 deletion R/plotLift.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ plotLIFT <- function(object, ..., groups = 10, cumulative = TRUE){

getLIFTDF <- function(object, n.groups, cumulative = TRUE){
pred <- NULL
y = object$y
y = as.numeric(as.character(object$y))
df <- data.frame(pred=object$fitted.values, y=y)
df <- arrange(df, desc(pred))

Expand Down
2 changes: 1 addition & 1 deletion R/plotModelPCA.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ plotModelPCA <- function(object, ...){
}
}

res.pca <- prcomp(df, scale = TRUE)
res.pca <- prcomp(df, scale = FALSE)

fviz_pca_biplot(res.pca,
repel = TRUE,
Expand Down
21 changes: 1 addition & 20 deletions man/plotCGains.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6143c09

Please sign in to comment.