Skip to content

Commit

Permalink
Use parent forest num.threads in auxiliary forest (#1437)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikcs authored Aug 30, 2024
1 parent 49aac3d commit f3d858d
Show file tree
Hide file tree
Showing 11 changed files with 16 additions and 3 deletions.
1 change: 1 addition & 0 deletions r-package/grf/R/causal_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ causal_forest <- function(X, Y, W,
forest <- do.call.rcpp(causal_train, c(data, args))
class(forest) <- c("causal_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/causal_survival_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ causal_survival_forest <- function(X, Y, W, D,
forest <- do.call.rcpp(causal_survival_train, c(data, args))
class(forest) <- c("causal_survival_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["_psi"]] <- psi
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
9 changes: 6 additions & 3 deletions r-package/grf/R/get_scores.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ get_scores.causal_forest <- function(forest,
sample.weights = forest$sample.weights,
num.trees = num.trees.for.weights,
ci.group.size = 1,
seed = forest$seed)
seed = forest$seed,
num.threads = forest$num.threads)
V.hat <- predict(variance_forest)$predictions
debiasing.weights.all <- (forest$W.orig - forest$W.hat) / V.hat
debiasing.weights <- debiasing.weights.all[subset]
Expand Down Expand Up @@ -178,7 +179,8 @@ get_scores.instrumental_forest <- function(forest,
sample.weights = forest$sample.weights,
clusters = clusters,
num.trees = num.trees.for.weights,
seed = forest$seed)
seed = forest$seed,
num.threads = forest$num.threads)
compliance.score <- predict(compliance.forest)$predictions
compliance.score <- compliance.score[subset]
} else if (length(compliance.score) == length(forest$Y.orig)) {
Expand Down Expand Up @@ -342,7 +344,8 @@ get_scores.causal_survival_forest <- function(forest,
sample.weights = forest$sample.weights,
num.trees = num.trees.for.weights,
ci.group.size = 1,
seed = forest$seed)
seed = forest$seed,
num.threads = forest$num.threads)
V.hat <- predict(variance_forest)$predictions[subset]
}

Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/instrumental_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ instrumental_forest <- function(X, Y, W, Z,
forest <- do.call.rcpp(instrumental_train, c(data, args))
class(forest) <- c("instrumental_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/ll_regression_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ ll_regression_forest <- function(X, Y,

class(forest) <- c("ll_regression_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/multi_arm_causal_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ multi_arm_causal_forest <- function(X, Y, W,
forest <- do.call.rcpp(multi_causal_train, c(data, args))
class(forest) <- c("multi_arm_causal_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/multi_regression_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ multi_regression_forest <- function(X, Y,
forest <- do.call.rcpp(multi_regression_train, c(data, args))
class(forest) <- c("multi_regression_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
forest[["sample.weights"]] <- sample.weights
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/probability_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ probability_forest <- function(X, Y,
forest <- do.call.rcpp(probability_train, c(data, args))
class(forest) <- c("probability_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
forest[["Y.relabeled"]] <- Y.relabeled
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/quantile_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ quantile_forest <- function(X, Y,
forest <- do.call.rcpp(quantile_train, c(data, args))
class(forest) <- c("quantile_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
forest[["quantiles.orig"]] <- quantiles
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/regression_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ regression_forest <- function(X, Y,
forest <- do.call.rcpp(regression_train, c(data, args))
class(forest) <- c("regression_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["ci.group.size"]] <- ci.group.size
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/R/survival_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ survival_forest <- function(X, Y, D,
forest <- do.call.rcpp(survival_train, c(data, args))
class(forest) <- c("survival_forest", "grf")
forest[["seed"]] <- seed
forest[["num.threads"]] <- num.threads
forest[["X.orig"]] <- X
forest[["Y.orig"]] <- Y
forest[["Y.relabeled"]] <- Y.relabeled
Expand Down

0 comments on commit f3d858d

Please sign in to comment.