diff --git a/R/orsf_R6.R b/R/orsf_R6.R
index 8c832576..29979b46 100644
--- a/R/orsf_R6.R
+++ b/R/orsf_R6.R
@@ -3051,7 +3051,6 @@ ObliqueForest <- R6::R6Class(
         unique(variable[predictor %in% colnames(cpp_args$x)])
        )
 
-
       oob_data[n_predictors,
                `:=`(n_predictors = n_predictors,
                     stat_value = cpp_output$eval_oobag$stat_values[1,1],
@@ -3059,7 +3058,7 @@ ObliqueForest <- R6::R6Class(
                     predictors_included = colnames(cpp_args$x),
                     predictor_dropped = worst_predictor)]
 
-      cpp_args$x <- cpp_args$x[, -worst_index]
+      cpp_args$x <- cpp_args$x[, -worst_index, drop = FALSE]
 
       n_predictors <- n_predictors - 1
       current_progress <- current_progress + 1
@@ -3069,6 +3068,7 @@ ObliqueForest <- R6::R6Class(
       cat("Selecting variables: 100%\n")
      }
 
+     collapse::na_omit(oob_data)
 
    },
 
diff --git a/tests/testthat/test-orsf_vs.R b/tests/testthat/test-orsf_vs.R
index f5076b65..c1180140 100644
--- a/tests/testthat/test-orsf_vs.R
+++ b/tests/testthat/test-orsf_vs.R
@@ -91,6 +91,21 @@ test_that(
 
 
 
+test_that(
+  desc = "variable selection can go down to 1 predictor",
+  code = {
+
+    fit_cars <- orsf(mpg ~ ., data = mtcars, n_tree = n_tree_test)
+
+    vs <- orsf_vs(fit_cars, n_predictor_min = 1)
+    # assert that we eliminated 1 predictor at each step and got down to
+    # 1 remaining predictor
+    expect_equal(nrow(vs), ncol(mtcars) - 1)
+    expect_length(vs$variables_included[[1]], 1)
+    expect_length(vs$predictors_included[[1]], 1)
+
+  }
+)
 
 
 