Skip to content

Commit

Permalink
feat: add the DiseasyModel class
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusSkytte committed Oct 9, 2023
1 parent 5693142 commit bd8b8fe
Show file tree
Hide file tree
Showing 5 changed files with 398 additions and 20 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
export("%.%")
export(DiseasyActivity)
export(DiseasyBaseModule)
export(DiseasyModel)
export(DiseasyObservables)
export(DiseasySeason)
export(diseasyoption)
import(R6)
import(diseasystore)
import(lgr)
importFrom(digest,digest)
importFrom(diseasystore,`%.%`)
importFrom(dplyr,as_label)
importFrom(lubridate,today)
importFrom(pracma,logseq)
Expand Down
63 changes: 46 additions & 17 deletions R/0_documentation.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
rd_aggregation <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("aggregation (`list`(`quosures`))\\cr",
paste("(`list`(`quosures`))\\cr",
"Default NULL.",
"If given, expressions in aggregation evaluated to give the aggregation level.",
ifelse(type == "field", " Read only.", ""))
Expand All @@ -9,15 +9,15 @@ rd_aggregation <- function(type = "param") {

rd_case_definition <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("case_definition (`character`)\\cr",
paste("(`character`)\\cr",
"A character string that controls which feature store to get data from.",
ifelse(type == "field", " Read only.", ""))
}


rd_observable <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("observable (`character`)\\cr",
paste("(`character`)\\cr",
"The observable to provide prediction for.",
"Must match observable in `DiseasyObservables` [R6][R6::R6Class] class.",
ifelse(type == "field", " Read only.", ""))
Expand All @@ -26,7 +26,7 @@ rd_observable <- function(type = "param") {

rd_prediction_length <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("prediction_length (`numeric`)\\cr",
paste("(`numeric`)\\cr",
"The number of days to predict.",
"The prediction start is defined by `last_queryable_date` of the `DiseasyObservables` [R6][R6::R6Class] class.",
ifelse(type == "field", " Read only.", ""))
Expand All @@ -35,18 +35,18 @@ rd_prediction_length <- function(type = "param") {

rd_quantiles <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("quantiles (`list`(`numeric`))\\cr",
paste("(`list`(`numeric`))\\cr",
"Default NULL.",
"If given, results are returned at the quantiles given",
"If given, results are returned at the quantiles given.",
ifelse(type == "field", " Read only.", ""))
}


rd_scale <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("scale (`numeric`)\\cr",
paste("(`numeric`)\\cr",
"Sets the scale of the season model.",
"The scale is the percent wise difference between most active and least active period",
"The scale is the percent wise difference between most active and least active period.",
ifelse(type == "field", " Read only.", ""))
}

Expand All @@ -62,15 +62,15 @@ rd_source_conn <- function(type = "param") {

rd_target_conn <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("target_conn (`DBIConnection`)\\cr",
paste("(`DBIConnection`)\\cr",
"A database connection to store the computed features in.",
ifelse(type == "field", " Read only.", ""))
}


rd_target_schema <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("target_schema (`character`)\\cr",
paste("(`character`)\\cr",
"The schema to place the feature store in.",
ifelse(type == "field", " Read only.", ""),
"If the database backend does not support schema, the tables will be prefixed with target_schema.")
Expand All @@ -79,31 +79,31 @@ rd_target_schema <- function(type = "param") {

rd_training_length <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("training_length (`numeric`)\\cr",
paste("(`numeric`)\\cr",
"The number of days that should be included in the training of the model.",
ifelse(type == "field", " Read only.", ""))
}


rd_start_date <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("start_date (`Date`)\\cr",
paste("(`Date`)\\cr",
"Study period start.",
ifelse(type == "field", " Read only.", ""))
}


rd_end_date <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("end_date (`Date`)\\cr",
paste("(`Date`)\\cr",
"Study period end.",
ifelse(type == "field", " Read only.", ""))
}


rd_slice_ts <- function(type = "param") {
checkmate::assert_choice(type, c("param", "field"))
paste("slice_ts (`Date` or `character`)\\cr",
paste("(`Date` or `character`)\\cr",
"Date to slice the database on (used if source_conn is a database).",
ifelse(type == "field", " Read only.", ""))
}
Expand All @@ -112,23 +112,52 @@ rd_slice_ts <- function(type = "param") {
rd_.data <- function(type = "param") { # nolint: object_name_linter
checkmate::assert_choice(type, c("param", "field"))
paste(".data\\cr",
"The data object to perform the operation on",
"The data object to perform the operation on.",
ifelse(type == "field", " Read only.", ""))
}


rd_describe <- "Prints a human readable report of the internal state of the module."

rd_get_results_description <- paste(
"The primary method used to request model results of a given observable at a given aggregation"
"The primary method used to request model results of a given observable at a given aggregation."
)

rd_get_results_return <- paste(
"A `tibble` [tibble::tibble] with predictions at the level specified by aggregation level.",
"In addition to aggregation columns, the output has the columns:\\cr",
" date (`Date`) specifying the date of the prediction\\cr",
" realization_id (`character`) giving a unique id for each realization in the ensemble\\cr",
" model (`character`) the name (classname) of the model used to provide the prediction"
" model (`character`) the name (classname) of the model used to provide the prediction."
)

rd_get_results_seealso <- "[diseasy::DiseasyObservables]"


# Plain-string documentation templates.
# Each value is a character string built with `paste()` that is meant to be spliced into a roxygen tag through
# inline R code, e.g. `#' @param observable `r rd_observable``.
#
# NOTE(review): `rd_aggregation`, `rd_observable`, `rd_prediction_length`, `rd_quantiles` and `rd_training_length`
# are also defined earlier in this file as functions taking a `type` argument. When the file is evaluated top to
# bottom, the assignments below replace those function bindings -- confirm that no remaining documentation still
# calls the function form (e.g. `rd_aggregation(type = "field")`).

# Type and description of the `aggregation` argument.
rd_aggregation <- paste(
"(`list`(`quosures`))\\cr",
"Default NULL.",
"If given, expressions in aggregation evaluated to give the aggregation level."
)

# Type and description of the `observable` argument.
rd_observable <- paste(
"(`character`)\\cr",
"The observable to provide prediction for. Must match observable in `DiseasyObservables` [R6][R6::R6Class] class."
)

# Type and description of the `prediction_length` argument.
rd_prediction_length <- paste(
"(`numeric`)\\cr",
"The number of days to predict.",
"The prediction start is defined by `last_queryable_date` of the `DiseasyObservables` [R6][R6::R6Class] class."
)

# Type and description of the `quantiles` argument.
rd_quantiles <- paste(
"(`list`(`numeric`))\\cr",
"Default NULL.",
"If given, results are returned at the quantiles given."
)

# Type and description of the `training_length` argument.
rd_training_length <- paste(
"(`numeric`)\\cr",
"The number of days that should be included in the training of the model."
)
179 changes: 179 additions & 0 deletions R/DiseasyModel.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#' @title Meta module for the models
#'
#' @description
#'   The `DiseasyModel` module is the main template that the individual models should inherit from.
#'   It defines the set of methods the framework expects from each model and provides the main interface with the
#'   other modules of the framework.
#' @export
DiseasyModel <- R6::R6Class( # nolint: object_name_linter
  classname = "DiseasyModel",
  inherit = DiseasyBaseModule,

  public = list(

    #' @description
    #'   Creates a new instance of the `DiseasyModel` [R6][R6::R6Class] class.
    #'   This module is typically not constructed directly but rather through `DiseasyModel*` classes.
    #' @param activity,observables,season (`boolean` or `R6::R6Class instance`)\cr
    #'   If a boolean is given, it dictates whether to load a new instance module of this class.\cr
    #'   If an instance of the module is provided instead, this instance is cloned to the new `DiseasyModel`
    #'   instance.\cr
    #'   Default is FALSE.
    #' @param label (`character`)\cr
    #'   A human readable label for the model instance.
    #' @param ...
    #'   Parameters sent to `DiseasyBaseModule` [R6][R6::R6Class] constructor.
    #' @details
    #'   The `DiseasyModel` is the main template that the individual models should inherit from since this defines
    #'   the set of methods the later framework expects from each model. In addition, it provides the main interface
    #'   with the other modules of the framework.
    #' @return
    #'   A new instance of the `DiseasyModel` [R6][R6::R6Class] class.
    initialize = function(activity = FALSE,
                          observables = FALSE,
                          season = FALSE,
                          label = NULL,
                          ...) {

      # Each module argument must be either a logical or an instance of the matching module class
      assertions <- checkmate::makeAssertCollection()
      checkmate::assert(
        checkmate::check_logical(activity, null.ok = TRUE),
        checkmate::check_class(activity, "DiseasyActivity", null.ok = TRUE),
        add = assertions
      )
      checkmate::assert(
        checkmate::check_logical(observables, null.ok = TRUE),
        checkmate::check_class(observables, "DiseasyObservables", null.ok = TRUE),
        add = assertions
      )
      checkmate::assert(
        checkmate::check_logical(season, null.ok = TRUE),
        checkmate::check_class(season, "DiseasySeason", null.ok = TRUE),
        add = assertions
      )
      checkmate::assert_character(label, len = 1, any.missing = FALSE, null.ok = TRUE, add = assertions)
      checkmate::reportAssertions(assertions)

      # Remaining arguments belong to the DiseasyBaseModule initializer
      super$initialize(...)

      # Load the requested modules: a supplied instance is loaded as given, TRUE loads a fresh default instance,
      # and anything else (FALSE / NULL) loads nothing
      if (inherits(observables, "DiseasyObservables")) {
        self$load_module(observables)
      } else if (isTRUE(observables)) {
        self$load_module(DiseasyObservables$new())
      }

      if (inherits(activity, "DiseasyActivity")) {
        self$load_module(activity)
      } else if (isTRUE(activity)) {
        self$load_module(DiseasyActivity$new())
      }

      if (inherits(season, "DiseasySeason")) {
        self$load_module(season)
      } else if (isTRUE(season)) {
        self$load_module(DiseasySeason$new())
      }

      # Store the label of the model
      private$label <- label
    },


    # Roxygen only has limited support for R6 documentation, so the documentation of get_results is assembled
    # from the `rd_*` templates through inline R code
    #' @description `r rd_get_results_description`
    #' @param observable `r rd_observable`
    #' @param prediction_length `r rd_prediction_length`
    #' @param quantiles `r rd_quantiles`
    #' @param aggregation `r rd_aggregation`
    #' @return `r rd_get_results_return`
    #' @seealso `r rd_get_results_seealso`
    get_results = function(observable, prediction_length, quantiles = NULL, aggregation = NULL) {
      private$not_implemented_error("Each model must implement their own `get_results` methods")
    },


    #' @description
    #'   Returns training data for the models based on the model value of `training_length` and
    #'   the `last_queryable_date` of the `DiseasyObservables` module.
    #' @param observable `r rd_observable`
    #' @param aggregation `r rd_aggregation`
    #' @return
    #'   The output of `DiseasyObservables$get_observation` constrained to the training period, with an added
    #'   column `t` holding the number of days between the latest date in the data and the date of the row.
    get_training_data = function(observable, aggregation = NULL) {

      # Check the input and the state the training period is derived from
      assertions <- checkmate::makeAssertCollection()
      checkmate::assert_character(observable, add = assertions)
      checkmate::assert_number(self$parameters$training_length, add = assertions)
      checkmate::assert_date(self$observables$last_queryable_date, add = assertions)
      checkmate::reportAssertions(assertions)

      # The training period ends at the last queryable date and spans `training_length` days backwards
      training_start <- self$observables$last_queryable_date - lubridate::days(self$parameters$training_length)
      training_end <- self$observables$last_queryable_date

      # Get the observable at the aggregation level, only within the training period
      observations <- self$observables$get_observation(observable, aggregation, training_start, training_end)

      # Add the time in days relative to the latest date in the data
      dplyr::mutate(
        observations,
        t = lubridate::interval(max(zoo::as.Date(date)), zoo::as.Date(date)) / lubridate::days(1)
      )
    }
  ),

  # Active bindings exposing the private variables
  active = list(

    #' @field activity (`diseasy::DiseasyActivity`)\cr
    #'   The local copy of a DiseasyActivity module. Read-only.
    #' @seealso [diseasy::DiseasyActivity]
    #' @importFrom diseasystore `%.%`
    activity = purrr::partial(
      .f = active_binding, # nolint: indentation_linter
      name = "activity",
      expr = return(private %.% .DiseasyActivity)),


    #' @field observables (`diseasy::DiseasyObservables`)\cr
    #'   The local copy of a DiseasyObservables module. Read-only.
    #' @seealso [diseasy::DiseasyObservables]
    #' @importFrom diseasystore `%.%`
    observables = purrr::partial(
      .f = active_binding, # nolint: indentation_linter
      name = "observables",
      expr = return(private %.% .DiseasyObservables)),


    #' @field season (`diseasy::DiseasySeason`)\cr
    #'   The local copy of a DiseasySeason module. Read-only.
    #' @seealso [diseasy::DiseasySeason]
    #' @importFrom diseasystore `%.%`
    season = purrr::partial(
      .f = active_binding, # nolint: indentation_linter
      name = "season",
      expr = return(private %.% .DiseasySeason)),


    #' @field parameters (`list()`)\cr
    #'   The parameters used in the model. Read-only.
    #' @importFrom diseasystore `%.%`
    parameters = purrr::partial(
      .f = active_binding, # nolint: indentation_linter
      name = "parameters",
      expr = return(private %.% .parameters))
  ),

  private = list(

    # Storage behind the active bindings above
    .DiseasyActivity = NULL,
    .DiseasyObservables = NULL,
    .DiseasySeason = NULL,
    .parameters = NULL,

    # @param label (`character`)\cr
    #   A human readable label for the model instance.
    label = NULL,

    # Collects one message per non-NULL argument and hands the collection to `checkmate::reportAssertions`.
    # @param observable
    #   If not NULL, reported as an observable the model is not configured to predict for.
    # @param aggregation
    #   If not NULL, reported as an aggregation the model is not configured to predict at.
    model_cannot_predict = function(observable = NULL, aggregation = NULL) {
      issues <- checkmate::makeAssertCollection()

      if (!is.null(observable)) {
        issues$push(glue::glue("Model not configured to predict for observable: {observable}"))
      }

      if (!is.null(aggregation)) {
        issues$push(
          glue::glue(
            "Model not configured to predict at aggregation: ",
            "{private$aggregation_to_string(aggregation)}"
          )
        )
      }

      checkmate::reportAssertions(issues)
    }
  )
)
Loading

0 comments on commit bd8b8fe

Please sign in to comment.