From 961c3b213c6bf68e410e7601346ea7ee13a0808a Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 11 Oct 2023 15:44:41 +0200 Subject: [PATCH] feat: keep_results can be character vector of IDs This can be useful when wanting access to an object that is not an output node of the graph, i.e. we don't have to add a `PipeOpNOP` (or keep all results) to achieve this. --- NEWS.md | 3 +++ R/Graph.R | 5 +++-- R/PipeOp.R | 2 +- man/Graph.Rd | 3 ++- man/PipeOp.Rd | 2 +- man/mlr_pipeops_nmf.Rd | 2 +- tests/testthat/test_Graph.R | 11 +++++++++++ 7 files changed, 22 insertions(+), 6 deletions(-) diff --git a/NEWS.md b/NEWS.md index a1074398a..7b147489e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # mlr3pipelines 0.5.0-9000 +* Feature: The `Graph`'s `keep_results` can now also be a character vector +containing the IDs of the `PipeOp`s whose results are being stored. + # mlr3pipelines 0.5.0-1 * Bugfix: `PipeOpTuneThreshold` was not overloading the correct `.train` and `.predict` functions. diff --git a/R/Graph.R b/R/Graph.R index 8cc95a0ae..b9070ffc1 100644 --- a/R/Graph.R +++ b/R/Graph.R @@ -59,8 +59,9 @@ #' * `phash` :: `character(1)` \cr #' Stores a checksum calculated on the [`Graph`] configuration, which includes all [`PipeOp`] hashes #' *except* their `$param_set$values`, and a hash of `$edges`. -#' * `keep_results` :: `logical(1)` \cr +#' * `keep_results` :: `logical(1)` or `character()` \cr #' Whether to store intermediate results in the [`PipeOp`]'s `$.result` slot, mostly for debugging purposes. Default `FALSE`. +#' Can also be a character vector of IDs, in which case only the results of the selected `PipeOp`s are stored. #' * `man` :: `character(1)`\cr #' Identifying string of the help page that shows with `help()`. #' @@ -642,7 +643,7 @@ graph_reduce = function(self, input, fun, single_input) { lg$debug("Running PipeOp '%s$%s()'", id, fun, pipeop = op, input = input) output = op[[fun]](input) - if (self$keep_results) { + if (isTRUE(self$keep_results) || op$id %in% self$keep_results) { op$.result = output } diff --git a/R/PipeOp.R b/R/PipeOp.R index 24a12ab6c..c9b393e70 100644 --- a/R/PipeOp.R +++ b/R/PipeOp.R @@ -130,7 +130,7 @@ #' [`PipeOp`]'s functionality may change depending on more than these values, it should inherit the `$hash` active #' binding and calculate the hash as `digest(list(super$hash, ), algo = "xxhash64")`. #' * `.result` :: `list` \cr -#' If the [`Graph`]'s `$keep_results` flag is set to `TRUE`, then the intermediate Results of `$train()` and `$predict()` +#' If the [`Graph`]'s `$keep_results` flag is set to `TRUE` or contains the ID of this `PipeOp`, then the intermediate Results of `$train()` and `$predict()` #' are saved to this slot, exactly as they are returned by these functions. This is mainly for debugging purposes #' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`. #' * `man` :: `character(1)`\cr diff --git a/man/Graph.Rd b/man/Graph.Rd index 8da82dac3..fa51c234e 100644 --- a/man/Graph.Rd +++ b/man/Graph.Rd @@ -69,8 +69,9 @@ Stores a checksum calculated on the \code{\link{Graph}} configuration, which inc \item \code{phash} :: \code{character(1)} \cr Stores a checksum calculated on the \code{\link{Graph}} configuration, which includes all \code{\link{PipeOp}} hashes \emph{except} their \verb{$param_set$values}, and a hash of \verb{$edges}. -\item \code{keep_results} :: \code{logical(1)} \cr +\item \code{keep_results} :: \code{logical(1)} or \code{character()} \cr Whether to store intermediate results in the \code{\link{PipeOp}}'s \verb{$.result} slot, mostly for debugging purposes. Default \code{FALSE}. +Can also be a character vector of IDs, in which case only the results of the selected \code{PipeOp}s are stored. \item \code{man} :: \code{character(1)}\cr Identifying string of the help page that shows with \code{help()}. } diff --git a/man/PipeOp.Rd b/man/PipeOp.Rd index 4292943b1..553ca0301 100644 --- a/man/PipeOp.Rd +++ b/man/PipeOp.Rd @@ -137,7 +137,7 @@ Checksum calculated on the \code{\link{PipeOp}}, depending on the \code{\link{Pi \code{\link{PipeOp}}'s functionality may change depending on more than these values, it should inherit the \verb{$hash} active binding and calculate the hash as \verb{digest(list(super$hash, ), algo = "xxhash64")}. \item \code{.result} :: \code{list} \cr -If the \code{\link{Graph}}'s \verb{$keep_results} flag is set to \code{TRUE}, then the intermediate Results of \verb{$train()} and \verb{$predict()} +If the \code{\link{Graph}}'s \verb{$keep_results} flag is set to \code{TRUE} or contains the ID of this \code{PipeOp}, then the intermediate Results of \verb{$train()} and \verb{$predict()} are saved to this slot, exactly as they are returned by these functions. This is mainly for debugging purposes and done, if requested, by the \code{\link{Graph}} backend itself; it should \emph{not} be done explicitly by \code{private$.train()} or \code{private$.predict()}. \item \code{man} :: \code{character(1)}\cr diff --git a/man/mlr_pipeops_nmf.Rd b/man/mlr_pipeops_nmf.Rd index 5e967fab2..3c8a75c9a 100644 --- a/man/mlr_pipeops_nmf.Rd +++ b/man/mlr_pipeops_nmf.Rd @@ -96,7 +96,7 @@ See \code{\link[NMF:nmf]{nmf()}}. \section{Internals}{ -Uses the \code{\link[NMF:nmf]{nmf()}} function as well as \code{\link[NMF:basis-coef-methods]{basis()}}, \code{\link[NMF:basis-coef-methods]{coef()}} and +Uses the \code{\link[NMF:nmf]{nmf()}} function as well as \code{\link[NMF:basis]{basis()}}, \code{\link[NMF:coef]{coef()}} and \code{\link[MASS:ginv]{ginv()}}. } diff --git a/tests/testthat/test_Graph.R b/tests/testthat/test_Graph.R index ed6f900d8..af40c8f15 100644 --- a/tests/testthat/test_Graph.R +++ b/tests/testthat/test_Graph.R @@ -501,3 +501,14 @@ test_that("Same output into multiple channels does not cause a bug", { expect_true(res$po3.output1 == 2) expect_true(res$po4.output1 == 2) }) + +test_that("keep_results can be a character vector", { + graph = po("pca") %>>% po("ica") + + graph$keep_results = "pca" + + graph$train(tsk("iris")) + + expect_true(is.null(graph$pipeops$ica$.result)) + expect_class(graph$pipeops$pca$.result[[1L]], "Task") +})