Skip to content

Commit

Permalink
various from cherry picked commits
Browse files Browse the repository at this point in the history
  • Loading branch information
goldingn authored and njtierney committed Dec 2, 2024
1 parent 59eb35c commit 17ee6d3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
5 changes: 3 additions & 2 deletions R/iterate_dynamic_function.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ iterate_dynamic_function <- function(
...,
parameter_is_time_varying = c(),
state_limits = c(-Inf, Inf)
) {
) {

# generalise checking of inputs from iterate_matrix into functions
niter <- as.integer(niter)
Expand All @@ -96,6 +96,7 @@ iterate_dynamic_function <- function(
"{.var initial_state} must be either a column vector, or a 3D array \\
with final dimension 1"
)

}

# if this is multisite
Expand Down Expand Up @@ -216,6 +217,7 @@ as_tf_transition_function <- function (transition_function, state, iter, dots) {

# tf_dots will have been added to this environment by
# tf_iterate_dynamic_function
# tf_iterate_dynamic_matrix
args <- list(state = state, iter = iter)
do.call(tf_fun, c(args, tf_dots))

Expand Down Expand Up @@ -287,7 +289,6 @@ tf_iterate_dynamic_function <- function (state,
}
assign("tf_dots", tf_dots,
environment(tf_transition_function))

# evaluate function to get the new state (dots have been inserted into its
# environment, since TF while loops are treacherous things)
new_state <- tf_transition_function(old_state, iter)
Expand Down
25 changes: 25 additions & 0 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,31 @@ r_iterate_dynamic_function <- function(transition_function,
max_iter = i)
}

r_iterate_dynamic_function <- function(transition_function, initial_state, niter = 100, tol = 1e-6, ...) {

states <- list(initial_state)

i <- 0L
diff <- Inf

while(i < niter & diff > tol) {
i <- i + 1L
states[[i + 1]] <- transition_function(states[[i]], i, ...)
growth <- states[[i + 1]] / states[[i]]
diffs <- growth - 1
diff <- max(abs(diffs))
}

all_states <- matrix(0, length(states[[1]]), niter)
states_keep <- states[-1]
all_states[, seq_along(states_keep)] <- t(do.call(rbind, states_keep))

list(stable_state = states[[i]],
all_states = all_states,
converged = as.integer(diff < tol),
max_iter = i)
}

# a midpoint solver for use in deSolve, from the vignette p8
rk_midpoint <- deSolve::rkMethod(
ID = "midpoint",
Expand Down
3 changes: 0 additions & 3 deletions tests/testthat/test_iterate_dynamic_function.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
test_that("single iteration works", {
skip_if_not(check_tf_version())
set.seed(2017 - 05 - 01)

n <- 4
init <- rep(1, n)
niter <- 100
Expand Down Expand Up @@ -60,8 +59,6 @@ test_that("single iteration works", {

})



test_that("iteration works with time-varying parameters", {
skip_if_not(check_tf_version())
set.seed(2017 - 05 - 01)
Expand Down

0 comments on commit 17ee6d3

Please sign in to comment.