Skip to content

Commit

Permalink
Merge pull request #70 from ropensci/vs-step-size
Browse files Browse the repository at this point in the history
allow steps of >1 in orsf_vs
  • Loading branch information
bcjaeger authored Feb 23, 2025
2 parents 6b13ee7 + 080680a commit be89df6
Show file tree
Hide file tree
Showing 18 changed files with 237 additions and 137 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- {os: windows-latest, r: 'release'}
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
- {os: ubuntu-latest, r: 'release'}
- {os: ubuntu-latest, r: 'oldrel-1'}
# - {os: ubuntu-latest, r: 'oldrel-1'}

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
Expand Down
23 changes: 0 additions & 23 deletions .github/workflows/draft-pdf.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.2
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# aorsf 0.1.5 (unreleased)
# aorsf 0.1.6 (unreleased)

* added `n_predictor_drop` to `orsf_vs()`. Dropping one predictor at a time makes `orsf_vs()` slow for data with hundreds of predictors. Using a larger value for `n_predictor_drop` helps speed this up. The default value of `n_predictor_drop` is 1 to maintain backward compatibility.

# aorsf 0.1.5

* fixed an issue where omitting NA values would cause an error in regression forests.

Expand Down
4 changes: 3 additions & 1 deletion R/coerce_nans.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#' S3 generic `coerce_nans`
#'
#' Dispatches to a class-specific `coerce_nans` method chosen from the
#' class of `x`.
#'
#' @param x object whose class selects the method.
#' @param to second argument, handed on unchanged to the selected method.
#' @noRd
coerce_nans <- function(x, to){
  UseMethod("coerce_nans", x)
}

#' List method for `coerce_nans`
#'
#' Calls `coerce_nans()` on each element of `x`, forwarding `to`, and
#' returns the per-element results as a list.
#'
#' @param x a list.
#' @param to forwarded to `coerce_nans()` for every element.
#' @noRd
coerce_nans.list <- function(x, to){
  lapply(x, function(element) coerce_nans(element, to = to))
}

#' @noRd
coerce_nans.factor <-
coerce_nans.integer <-
coerce_nans.double <-
Expand Down
36 changes: 27 additions & 9 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,9 @@ ObliqueForest <- R6::R6Class(

# Variable selection
# returns a data.table with variable selection info
select_variables = function(n_predictor_min, verbose_progress){
select_variables = function(n_predictor_min,
n_predictor_drop,
verbose_progress){

public_state <- list(verbose_progress = self$verbose_progress,
forest = self$forest,
Expand All @@ -712,7 +714,9 @@ ObliqueForest <- R6::R6Class(
object_trained <- self$trained

out <- try(
private$select_variables_internal(n_predictor_min, verbose_progress)
private$select_variables_internal(n_predictor_min,
n_predictor_drop,
verbose_progress)
)

private$restore_state(public_state, private_state = NULL)
Expand Down Expand Up @@ -2928,9 +2932,11 @@ ObliqueForest <- R6::R6Class(

},

select_variables_internal = function(n_predictor_min, verbose_progress){
select_variables_internal = function(n_predictor_min,
n_predictor_drop,
verbose_progress){

n_predictors <- length(private$data_names$x_original)
n_predictors <- length(private$data_names$x_ref_code)

# verbose progress on the forest should always be FALSE
# because for orsf_vs, verbosity is coordinated in R
Expand All @@ -2941,7 +2947,7 @@ ObliqueForest <- R6::R6Class(
stat_value = rep(NA_real_, n_predictors),
variables_included = vector(mode = 'list', length = n_predictors),
predictors_included = vector(mode = 'list', length = n_predictors),
predictor_dropped = rep(NA_character_, n_predictors)
predictor_dropped = vector(mode = 'list', length = n_predictors)
)

# if the forest was not trained prior to variable selection
Expand Down Expand Up @@ -3045,9 +3051,21 @@ ObliqueForest <- R6::R6Class(
cpp_args$mtry <- mtry_safe
cpp_output <- do.call(orsf_cpp, args = cpp_args)

worst_index <- which.min(cpp_output$importance)
worst_predictor <- colnames(cpp_args$x)[worst_index]
n_drop <- min(n_predictor_drop,
n_predictors - n_predictor_min)

if(n_drop > 0){

worst_index <- order(cpp_output$importance)[seq(n_drop)]

worst_predictor <- colnames(cpp_args$x)[worst_index]

} else {

worst_predictor <- NA_character_
n_drop <- 1

}

.variables_included <- with(
variable_key,
Expand All @@ -3062,8 +3080,8 @@ ObliqueForest <- R6::R6Class(
predictor_dropped = worst_predictor)]

cpp_args$x <- cpp_args$x[, -worst_index, drop = FALSE]
n_predictors <- n_predictors - 1
current_progress <- current_progress + 1
n_predictors <- n_predictors - n_drop
current_progress <- current_progress + n_drop

}

Expand Down
5 changes: 4 additions & 1 deletion R/orsf_data_prep.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@

#' S3 generic `orsf_data_prep`
#'
#' Dispatches to a class-specific `orsf_data_prep` method chosen from the
#' class of `data`.
#'
#' @param data object whose class selects the method.
#' @param ... further arguments, available to the selected method.
#' @noRd
orsf_data_prep <- function(data, ...){
  UseMethod("orsf_data_prep", data)
}

#' @noRd
orsf_data_prep.list <- function(data, ...){

lengths <- vapply(data, length, integer(1))
Expand Down Expand Up @@ -43,12 +44,14 @@ orsf_data_prep.list <- function(data, ...){

}

#' Recipe method for `orsf_data_prep`
#'
#' Returns the element of `data` named `template`.
#'
#' @param data object holding a `template` element.
#' @param ... unused.
#' @noRd
orsf_data_prep.recipe <- function(data, ...){
  template <- getElement(data, name = "template")
  template
}

#' Data frame method for `orsf_data_prep`
#'
#' Returns `data` unchanged.
#'
#' @param data a data frame.
#' @param ... unused.
#' @noRd
orsf_data_prep.data.frame <- function(data, ...) data
20 changes: 19 additions & 1 deletion R/orsf_vs.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#'
#' @inheritParams predict.ObliqueForest
#' @param n_predictor_min (*integer*) the minimum number of predictors allowed
#' @param n_predictor_drop (*integer*) the number of predictors dropped at each step. Fewer may be dropped at the last step so that at least `n_predictor_min` predictors remain.
#' @param verbose_progress (*logical*) not implemented yet. Should progress be printed to the console?
#'
#' @return a [data.table][data.table::data.table-package] with four columns:
Expand Down Expand Up @@ -38,8 +39,15 @@

orsf_vs <- function(object,
n_predictor_min = 3,
n_predictor_drop = 1,
verbose_progress = NULL){

# Check the class of `object` before reading any of its fields, so that a
# non-forest input is reported by the class check instead of failing on
# `object$importance_type`.
check_arg_is(arg_value = object,
             arg_name = 'object',
             expected_class = 'ObliqueForest')

# Variable selection orders predictors by their importance values, so a
# forest specified with importance 'none' cannot be used here.
# Note: stop() pastes its arguments with no separator, hence the trailing
# space at the end of the first piece of the message.
if(object$importance_type == 'none'){
  stop("object must be specified with importance ",
       "of 'anova', 'negate', or 'permute'",
       call. = FALSE)
}
Expand All @@ -55,6 +63,14 @@ orsf_vs <- function(object,
arg_name = 'n_predictor_min',
bound = 1)


# n_predictor_drop must be a whole number...
check_arg_type(arg_value = n_predictor_drop,
               arg_name = 'n_predictor_drop',
               expected_type = 'numeric')

check_arg_is_integer(arg_value = n_predictor_drop,
                     arg_name = 'n_predictor_drop')

# ...and at least 1. With a value of 0 or less no predictor is selected
# for removal in the selection loop, which then fails when it tries to
# drop columns, so reject such values here with a clear message.
if(any(n_predictor_drop < 1, na.rm = TRUE)){
  stop("n_predictor_drop should be >= 1",
       call. = FALSE)
}

check_arg_lt(arg_value = n_predictor_min,
arg_name = 'n_predictor_min',
bound = length(object$get_names_x()),
Expand All @@ -74,7 +90,9 @@ orsf_vs <- function(object,
arg_name = 'verbose_progress',
expected_length = 1)

object$select_variables(n_predictor_min, verbose_progress)
object$select_variables(n_predictor_min,
n_predictor_drop,
verbose_progress)

}

Expand Down
50 changes: 19 additions & 31 deletions man/orsf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_cph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_custom.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_fast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/orsf_control_net.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit be89df6

Please sign in to comment.