Skip to content

Commit

Permalink
Merge pull request #328 from saumil-sh/master
Browse files Browse the repository at this point in the history
Finer control of `solve` and `sample` args via `turing_inference` args
  • Loading branch information
Vaibhavdixit02 authored Jul 25, 2024
2 parents 3b8f7bb + 3065750 commit 14eb97d
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 105 deletions.
14 changes: 7 additions & 7 deletions .github/scripts/stan.sh
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
JULIA_CMDSTAN_HOME="/home/runner/cmdstan-2.29.2/"
JULIA_CMDSTAN_HOME="/home/runner/cmdstan-2.34.1/"
OLDWD=`pwd`
cd ~
pwd
wget https://github.com/stan-dev/cmdstan/releases/download/v2.29.2/cmdstan-2.29.2.tar.gz
tar -xzpf cmdstan-2.29.2.tar.gz
wget https://github.com/stan-dev/cmdstan/releases/download/v2.34.1/cmdstan-2.34.1.tar.gz
tar -xzpf cmdstan-2.34.1.tar.gz
ls -lia .
ls -lia ./cmdstan-2.29.2
ls -lia ./cmdstan-2.29.2/make
touch ./cmdstan-2.29.2/make/local
echo "STAN_THREADS=true" > ./cmdstan-2.29.2/make/local
ls -lia ./cmdstan-2.34.1
ls -lia ./cmdstan-2.34.1/make
touch ./cmdstan-2.34.1/make/local
echo "STAN_THREADS=true" > ./cmdstan-2.34.1/make/local
make -C $JULIA_CMDSTAN_HOME build
cd $OLDWD
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
- uses: julia-actions/julia-runtest@v1
env:
GROUP: ${{ matrix.group }}
JULIA_CMDSTAN_HOME: "/home/runner/cmdstan-2.29.2/"
JULIA_CMDSTAN_HOME: "/home/runner/cmdstan-2.34.1/"
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key
GROUP: ${{ matrix.group }}
JULIA_CMDSTAN_HOME: "/home/runner/cmdstan-2.29.2/"
JULIA_CMDSTAN_HOME: "/home/runner/cmdstan-2.34.1/"
run: julia --project=docs/ --code-coverage=user docs/make.jl
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqBayes"
uuid = "ebbdde9d-f333-5424-9be2-dbf1e9acfb5e"
authors = ["Vaibhavdixit02 <[email protected]>"]
version = "3.6.1"
version = "3.6.2"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down
30 changes: 30 additions & 0 deletions docs/src/assets/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
DiffEqBayes = "ebbdde9d-f333-5424-9be2-dbf1e9acfb5e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
ParameterizedFunctions = "65888b18-ceab-5e60-b2b9-181511a3b968"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"

[compat]
BenchmarkTools = "1"
CmdStan = "6"
DiffEqBayes = "3"
Distributions = "0.25"
Documenter = "0.27"
DynamicHMC = "3"
OrdinaryDiffEq = "6"
ParameterizedFunctions = "5"
Plots = "1"
RecursiveArrayTools = "2, 3"
StatsBase = "0.33, 0.34"
StatsPlots = "0.15"
TransformVariables = "0.8"
2 changes: 1 addition & 1 deletion docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ data = convert(Array, randomized)

```@example all
using CmdStan #required for using the Stan backend
bayesian_result_stan = stan_inference(prob1, t, data, priors)
bayesian_result_stan = stan_inference(prob1, :rk45, t, data, priors)
```

### Turing
Expand Down
10 changes: 5 additions & 5 deletions docs/src/examples/pendulum.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ priors = [
Finally, let's run the estimation routine from DiffEqBayes.jl with the Turing.jl backend to check if we indeed recover the parameters!

```@example pendulum
bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 10_000,
syms = [:omega, :L])
bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors;
syms = [:omega, :L], sample_args = (num_samples = 10_000,))
```

Notice that while our guesses had the wrong means, the learned parameters converged
Expand Down Expand Up @@ -120,12 +120,12 @@ to get better understanding of the performance.

```@example pendulum
@btime bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors;
syms = [:omega, :L], num_samples = 10_000)
syms = [:omega, :L], sample_args = (num_samples = 10_000,))
```

```@example pendulum
@btime bayesian_result = stan_inference(prob1, t, data, priors; num_samples = 10_000,
print_summary = false)
@btime bayesian_result = stan_inference(prob1, :rk45, t, data, priors;
sample_kwargs = Dict(:num_samples => 10_000), print_summary = false)
```

```@example pendulum
Expand Down
42 changes: 21 additions & 21 deletions docs/src/methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,24 @@ using DiffEqBayes
### stan_inference

```julia
stan_inference(prob::ODEProblem, t, data, priors = nothing; alg = :rk45,
num_samples = 1000, num_warmups = 1000, reltol = 1e-3,
abstol = 1e-6, maxiter = Int(1e5), likelihood = Normal,
vars = (StanODEData(), InverseGamma(2, 3)))
stan_inference(prob::DiffEqBase.DEProblem, alg, t, data, priors = nothing;
stanmodel = nothing, likelihood = Normal, vars = (StanODEData(), InverseGamma(3, 3)), sample_u0 = false, solve_kwargs = Dict(), diffeq_string = nothing, sample_kwargs = Dict(), output_format = :mcmcchains, print_summary = true, tmpdir = mktempdir())
```

`stan_inference` uses [Stan.jl](https://stanjulia.github.io/CmdStan.jl/latest/INTRO/)
to perform the Bayesian inference. The
[Stan installation process](https://stanjulia.github.io/CmdStan.jl/latest/INSTALLATION/)
is required to use this function. `t` is the array of time
and `data` is the array where the first dimension (columns) corresponds to the
array of system values. `priors` is an array of prior distributions for each
parameter, specified via a [Distributions.jl](https://juliastats.github.io/Distributions.jl/dev/)
type. `alg` is a choice between `:rk45` and `:bdf`, the two internal integrators
of Stan. `num_samples` is the number of samples to take per chain, and `num_warmups`
is the number of MCMC warm-up steps. `abstol` and `reltol` are the keyword
arguments for the internal integrator. `likelihood` is the likelihood distribution
to use with the arguments from `vars`, and `vars` is a tuple of priors for the
distributions of the likelihood hyperparameters. The special value `StanODEData()`
in this tuple denotes the position that the ODE solution takes in the likelihood's
parameter list.
is required to use this function. Currently `CmdStan v2.34.1` is supported.

`prob` can be any `DEProblem` with a corresponding `alg` choice. `alg` is a choice between `:rk45` and `:bdf`, the two internal integrators of Stan. `t` is the array of time and `data` is the array where the first dimension (columns) corresponds to the array of system values. `priors` is an array of prior distributions for each parameter, specified via a [Distributions.jl](https://juliastats.github.io/Distributions.jl/dev/) type. `likelihood` is the likelihood distribution to use with the arguments from `vars`, and `vars` is a tuple of priors for the distributions of the likelihood hyperparameters. The special value `StanODEData()` in this tuple denotes the position that the ODE solution takes in the likelihood's parameter list.

`solve_kwargs` is a `Dict` that is passed to the Stan differential equation solver. It may contain `save_idxs`, `reltol`, `abstol`, and `maxiter`. `save_idxs` is documented in the [`DifferentialEquations.jl` common solver options](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/). `sample_kwargs` is a `Dict` that is passed to the Stan sampler and accepts `num_samples`, `num_warmups`, `num_cpp_chains`, `num_chains`, `num_threads`, and `delta`. Please refer to the [Stan documentation](https://mc-stan.org/docs/cmdstan-guide/mcmc-intro.html) for more information.

### turing_inference

```julia
function turing_inference(prob::DiffEqBase.DEProblem, alg, t, data, priors;
likelihood_dist_priors, likelihood, num_samples = 1000,
sampler = Turing.NUTS(num_samples, 0.65), parallel_type = MCMCSerial(), n_chains = 1, syms, kwargs...)
end
turing_inference(prob::DiffEqBase.DEProblem, alg, t, data, priors;
likelihood_dist_priors, likelihood, syms, sample_u0 = false, progress = false, solve_kwargs = Dict(), sample_args = NamedTuple(), sample_kwargs= Dict())
```

`turing_inference` uses [Turing.jl](https://github.com/TuringLang/Turing.jl) to
Expand All @@ -49,7 +38,18 @@ observations for the differential equation system at time point `t[i]` (or highe
dimensional). `priors` is an array of prior distributions for each
parameter, specified via a
[Distributions.jl](https://juliastats.github.io/Distributions.jl/dev/)
type. `num_samples` is the number of samples per MCMC chain. Sampling from multiple chains is possible, see [`Turing.jl` documentation](https://turinglang.org/v0.26/docs/using-turing/guide#sampling-multiple-chains), serially or parallelly using `parallel_type` and `n_chains`. The extra `kwargs` are given to the internal differential equation solver.
type.

`turing_inference` interacts with `SciML.CommonSolve.solve` and `StatsBase.sample`. Both accept many arguments depending on the solver and sampling algorithm.
These arguments are supplied to the `turing_inference` function via the `solve_kwargs`, `sample_args`, and `sample_kwargs` arguments. Please refer to [the `solve` documentation](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/) for `solve_kwargs`, e.g. `solve_kwargs = Dict(:save_idxs => [1])`.
The `solve` keyword arguments default to `save_idxs = nothing`. Similarly, please refer to [the `sample` documentation](https://turinglang.org/v0.26/docs/using-turing/guide#sampling-multiple-chains) for `sample_args` and `sample_kwargs`. `sample_args` is a `NamedTuple` holding the four positional arguments of `sample`, in order: `sampler`, the sampling algorithm; `parallel_type`, which selects whether multiple chains are sampled serially or in parallel; `num_samples`, the number of samples per MCMC chain; and `n_chains`, the number of MCMC chains. Any subset may be given, e.g. `sample_args = (num_samples = 10_000,)` overrides only the number of samples. The positional arguments default to the following values.

```julia
sampler = Turing.NUTS(0.65)
parallel_type = MCMCSerial()
num_samples = 1000
n_chains = 1
```

### dynamichmc_inference

Expand Down
78 changes: 60 additions & 18 deletions src/stan_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,66 @@ function generate_theta(n, priors)
return theta
end

function stan_inference(prob::DiffEqBase.DEProblem,
function stan_inference(
prob::DiffEqBase.DEProblem,
alg,
# Positional arguments
t, data, priors = nothing, stanmodel = nothing;
t,
data,
priors = nothing;
stanmodel = nothing,
# DiffEqBayes keyword arguments
likelihood = Normal, vars = (StanODEData(), InverseGamma(3, 3)),
sample_u0 = false, save_idxs = nothing, diffeq_string = nothing,
likelihood = Normal,
vars = (StanODEData(), InverseGamma(3, 3)),
sample_u0 = false,
# Stan differential equation function keyword arguments
alg = :rk45, reltol = 1e-3, abstol = 1e-6, maxiter = Int(1e5),
solve_kwargs = Dict(),
diffeq_string = nothing,
# stan_sample keyword arguments
num_samples = 1000, num_warmups = 1000,
num_cpp_chains = 1, num_chains = 1, num_threads = 1,
delta = 0.8,
sample_kwargs = Dict(),
# read_samples arguments
output_format = :mcmcchains,
# read_summary arguments
print_summary = true,
# pass in existing tmpdir
tmpdir = mktempdir())
tmpdir = mktempdir()
)
# update default stan diff eq function kw args
solve_kwargs = merge(
Dict(
:save_idxs => nothing,
:reltol => 1e-3,
:abstol => 1e-6,
:maxiter => Int(1e5)
),
solve_kwargs
)
# update default stan sample kw args
sample_kwargs = merge(
Dict(
:num_samples => 1000,
:num_warmups => 1000,
:num_cpp_chains => 1,
:num_chains => 1,
:num_threads => 1,
:delta => 0.8
),
sample_kwargs
)

# make kw arg from dicts available as old vars
save_idxs = solve_kwargs[:save_idxs]
reltol = solve_kwargs[:reltol]
abstol = solve_kwargs[:abstol]
maxiter = solve_kwargs[:maxiter]

num_samples = sample_kwargs[:num_samples]
num_warmups = sample_kwargs[:num_warmups]
num_cpp_chains = sample_kwargs[:num_cpp_chains]
num_chains = sample_kwargs[:num_chains]
num_threads = sample_kwargs[:num_threads]
delta = sample_kwargs[:delta]

save_idxs !== nothing && length(save_idxs) == 1 ? save_idxs = save_idxs[1] :
save_idxs = save_idxs
length_of_y = length(prob.u0)
Expand Down Expand Up @@ -167,29 +209,29 @@ function stan_inference(prob::DiffEqBase.DEProblem,

parameter_estimation_model = "
functions {
$diffeq_string
$(diffeq_string)
}
data {
vector[$length_of_y] u0;
vector[$(length_of_y)] u0;
int<lower=1> T;
real internal_var___u[T,$(length(save_idxs))];
array[T,$(length(save_idxs))] real internal_var___u;
real t0;
real ts[T];
array[T] real ts;
}
parameters {
$setup_params
$theta_string
$(setup_params)
$(theta_string)
}
model{
vector[$length_of_y] u_hat[T];
array[T] vector[$length_of_y] u_hat;
$hyper_params
$priors_string
$integral_string
for (t in 1:T){
internal_var___u[t,:] ~ $stan_likelihood($tuple_hyper_params);
}
}
"
"

stanmodel = SampleModel("parameter_estimation_model",
parameter_estimation_model,
Expand All @@ -208,4 +250,4 @@ function stan_inference(prob::DiffEqBase.DEProblem,
else
rc.err
end
end
end
43 changes: 25 additions & 18 deletions src/turing_inference.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
function turing_inference(prob::DiffEqBase.DEProblem,
function turing_inference(
prob::DiffEqBase.DEProblem,
alg,
t,
data,
priors;
likelihood_dist_priors = [InverseGamma(2, 3)],
likelihood = (u, p, t, σ) -> MvNormal(u,
Diagonal((σ[1])^2 *
ones(length(u)))),
num_samples = 1000,
sampler = Turing.NUTS(0.65),
parallel_type = MCMCSerial(),
n_chains = 1,
likelihood = (u, p, t, σ) -> MvNormal(
u, σ[1]^2 * Diagonal(ones(length(u)))
),
syms = [Turing.@varname(theta[i]) for i in 1:length(priors)],
sample_u0 = false,
save_idxs = nothing,
progress = false,
kwargs...)
solve_kwargs = Dict(), # accept SciML DiffEq solve kwargs
sample_args = NamedTuple(), # accept Turing.jl sample args
sample_kwargs = Dict() # accept Turing.jl sample kwargs
)
N = length(priors)
# default args are updated with user supplied args
solve_kwargs = merge(Dict(:save_idxs => nothing), solve_kwargs)
sample_args = (;
sampler = Turing.NUTS(0.65),
parallel_type = MCMCSerial(),
num_samples = 1000,
n_chains = 1,
sample_args...
)
Turing.@model function infer(x, ::Type{T} = Float64) where {T <: Real}
theta = Vector{T}(undef, length(priors))
for i in 1:length(priors)
Expand All @@ -26,7 +34,8 @@ function turing_inference(prob::DiffEqBase.DEProblem,
for i in 1:length(likelihood_dist_priors)
σ[i] ~ likelihood_dist_priors[i]
end
nu = save_idxs === nothing ? length(prob.u0) : length(save_idxs)
nu = solve_kwargs[:save_idxs] === nothing ? length(prob.u0) :
length(solve_kwargs[:save_idxs])
u0 = convert.(T, sample_u0 ? theta[1:nu] : prob.u0)
p = convert.(T, sample_u0 ? theta[(nu + 1):end] : theta)
if length(u0) < length(prob.u0)
Expand All @@ -36,8 +45,8 @@ function turing_inference(prob::DiffEqBase.DEProblem,
end
end
_saveat = t === nothing ? Float64[] : t
sol = solve(prob, alg; u0 = u0, p = p, saveat = _saveat, progress = progress,
save_idxs = save_idxs, kwargs...)
sol = solve(prob, alg; u0 = u0, p = p, saveat = _saveat,
progress = progress, solve_kwargs...)
failure = size(sol, 2) < length(_saveat)

if failure
Expand All @@ -58,11 +67,9 @@ function turing_inference(prob::DiffEqBase.DEProblem,
model = infer(data)
chn = sample(
model,
sampler,
parallel_type,
num_samples,
n_chains;
progress = progress
sample_args...;
progress = progress,
sample_kwargs...
)
return chn
end
Loading

0 comments on commit 14eb97d

Please sign in to comment.