Skip to content

Commit

Permalink
Add soj_prob function
Browse files Browse the repository at this point in the history
  • Loading branch information
chjackson committed Sep 26, 2024
1 parent 913748d commit 23b51ea
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 8 deletions.
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
S3method(print,msmbayes)
S3method(summary,msmbayes)
S3method(summary,msmbres)
export(dnphase)
export(edf)
export(hnphase)
export(hr)
export(loghr)
export(mean_sojourn)
Expand All @@ -13,8 +15,10 @@ export(msmhist_bardata)
export(msmprior)
export(pmatrix)
export(pmatrixdf)
export(pnphase)
export(qdf)
export(qmatrix)
export(soj_prob)
export(standardise_to)
export(standardize_to)
export(totlos)
Expand Down Expand Up @@ -63,6 +67,7 @@ importFrom(rlang,.data)
importFrom(rlang,caller_env)
importFrom(stats,delete.response)
importFrom(stats,na.omit)
importFrom(stats,pexp)
importFrom(stats,qnorm)
importFrom(stats,quantile)
importFrom(stats,reshape)
Expand Down
2 changes: 1 addition & 1 deletion R/msmbayes-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' For more resources on multi-state modelling, see the [`msm` package](http://chjackson.github.io/msm) and its documentation.
#'
#' @name msmbayes-package
#' @importFrom stats delete.response na.omit reshape setNames terms quantile runif qnorm
#' @importFrom stats delete.response na.omit reshape setNames terms quantile runif qnorm pexp
#' @importFrom posterior as_draws as_draws_matrix rhat ess_bulk rvar ndraws rvar_sum "%**%" rdo rvar_sum draws_of merge_chains is_rvar
#' @importFrom cli cli_abort cli_warn qty cli_progress_bar cli_progress_update cli_progress_done
#' @importFrom glue glue
Expand Down
37 changes: 34 additions & 3 deletions R/outputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ pmatrixdf <- function(draws, t=1, new_data=NULL){
if (!is.null(new_data))
pdf <- pdf |>
left_join(new_data |> mutate(covid=1:n()), by="covid")
class(pdf) <- c("msmbres", class(pdf))
pdf |>
as_msmbres() |>
select(-covid) |>
relabel_phase_states(draws)
}
Expand Down Expand Up @@ -250,8 +250,7 @@ loghr <- function(draws){
name = cm$Xnames) |>
select(from, to, name, value) |>
relabel_phase_states(draws)
class(loghr) <- c("msmbres", class(loghr))
loghr
as_msmbres(loghr)
}

#' Hazard ratios for covariates on transition intensities
Expand Down Expand Up @@ -279,6 +278,14 @@ summary.msmbres <- function(object, ...){
cbind(object, summ_df)
}

#' Convert to "msmbayes result" class
#' so we can use summary.msmbres
#'
#' @noRd
as_msmbres <- function(object){
class(object) <- c("msmbres", class(object))
object
}


#' Total length of stay in each state over an interval
Expand Down Expand Up @@ -374,3 +381,27 @@ standardise_to <- function(new_data){
#' @rdname standardise_to
#' @export
standardize_to <- standardise_to


##' Sojourn probability in a state of a msmbayes model
##'
##' @inheritParams qmatrix
##'
##' @param t Time since state entry
##'
##' @param state State of interest (integer)
##'
##' @return A data frame with column `value` giving the probability of
##' remaining in `state` by time `t` since state entry, as an `rvar`
##' object. Other columns give the time and any covariate values.
##'
##' See \code{\link{qdf}} for notes on the `rvar` format.
##'
##' @md
##' @export
soj_prob <- function(draws, t, state, new_data=NULL){
if (is_phasetype(draws))
soj_prob_phase(draws, t, state, new_data)
else
soj_prob_nonphase(draws, t, state, new_data)
}
51 changes: 47 additions & 4 deletions R/outputs_internal.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ qvector <- function(draws, new_data=NULL, X=NULL){
}

logq_add_covs <- function(logq, loghr, new_data, X, draws){
## TODO for getting the prior . Support this, or just summary() for now?
## TODO for getting the prior . Support this, or just summary() for now?
}

check_X <- function(X,draws){
Expand Down Expand Up @@ -101,7 +101,7 @@ qvec_rvar_to_Q <- function(qvec, qm){

qvec_rvar_to_mst <- function(qvec, qm){
Q <- qvec_rvar_to_Q(qvec, qm)
-1 / diag(Q)
-1 / diag(Q)
}

#' @param rvarmat An rvar matrix with one row per covariate value.
Expand Down Expand Up @@ -129,8 +129,7 @@ vecbycovs_to_df <- function(rvarmat, new_data){
if (!is.null(new_data) && !isTRUE(attr(new_data, "std")))
res <- res |>
left_join(new_data |> mutate(covid=1:n()), by="covid")
class(res) <- c("msmbres", class(res))
res |> select(-covid)
as_msmbres(res) |> select(-covid)
}

#' Convert transition intensities in a phase-type model to a mixture
Expand Down Expand Up @@ -208,3 +207,47 @@ new_data_to_X <- function(new_data, draws, call=caller_env()){
X <- as.matrix(do.call("cbind", X))
X
}


soj_prob_phase <- function(draws, t, state, new_data=NULL){
fromobs <- ttype <- value <- covid <- NULL
qphase <- qdf(draws, new_data=new_data) |> filter(fromobs==state)
arate <- qphase |> filter(ttype=="abs") |> pull(value) |> draws_of()
prate <- qphase |> filter(ttype=="prog") |> pull(value) |> draws_of()
ntimes <- length(t)
ncovvals <- max(NROW(new_data), 1)
surv <- array(0, dim=c(ndraws(draws), ncovvals, ntimes))
for (i in 1:ntimes){
for (j in 1:ncovvals){
covid_p <- rep(1:ncovvals, length.out=ncol(prate))
covid_a <- rep(1:ncovvals, length.out=ncol(arate))
surv[,j,i] <- 1 - pnphase(t[i], prate[,covid_p==j], arate[,covid_a==j])
}
}
res <- data.frame(time = rep(t, ncovvals),
value = as.vector(rvar(surv))) |>
as_msmbres()
if (!is.null(new_data))
res <- res |>
mutate(covid = rep(1:ncovvals, each=ntimes)) |>
left_join(new_data |> mutate(covid=1:n()), by="covid") |>
select(-covid)
res
}

soj_prob_nonphase <- function(draws, t, state, new_data=NULL){
vecid <- NULL
qv <- - qmatrix(draws, new_data=new_data, drop=FALSE)[, state, state] |> draws_of()
ntimes <- length(t)
ncovvals <- dim(qv)[2]
surv <- array(0, dim=c(ndraws(draws), ncovvals, ntimes))
for (i in 1:ntimes){
surv[,,i] <- pexp(t[i], rate = qv, lower.tail=FALSE)
}
res <- rvar(surv) |>
vecbycovs_to_df(new_data) |>
mutate(time = t[vecid]) |>
select(-vecid) |>
as_msmbres()
res
}
27 changes: 27 additions & 0 deletions man/soj_prob.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test_outputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,14 @@ test_that("edf",{
value(summary(cav_misc) |>filter(name=="e") |>
slice(1) |> pull(value)))
})

test_that("soj_prob",{
nd <- data.frame(sex=c("male","female"))
expect_no_error({
soj_prob(infsim_model, t=c(5), state=2)
soj_prob(infsim_model, t=c(5,10), state=2)
soj_prob(infsim_modelc, t=c(5,10), new_data = nd, state=2)
soj_prob(infsim_modelp, t=c(5,10), state=2)
soj_prob(infsim_modelpc, t=c(5,10), new_data = nd, state=2)
})
})

0 comments on commit 23b51ea

Please sign in to comment.