Skip to content

Commit

Permalink
Merge pull request #307 from saumil-sh/master
Browse files Browse the repository at this point in the history
Added support for sampling from multiple chain using `Turing.jl`
  • Loading branch information
ChrisRackauckas authored Jun 26, 2023
2 parents 8bb0144 + 5004545 commit 6994120
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
5 changes: 2 additions & 3 deletions docs/src/methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ parameter list.
```julia
function turing_inference(prob::DiffEqBase.DEProblem, alg, t, data, priors;
likelihood_dist_priors, likelihood, num_samples = 1000,
sampler = Turing.NUTS(num_samples, 0.65), syms, kwargs...)
sampler = Turing.NUTS(num_samples, 0.65), parallel_type = MCMCSerial(), n_chains = 1, syms, kwargs...)
end
```

Expand All @@ -49,8 +49,7 @@ observations for the differential equation system at time point `t[i]` (or highe
dimensional). `priors` is an array of prior distributions for each
parameter, specified via a
[Distributions.jl](https://juliastats.github.io/Distributions.jl/dev/)
type. `num_samples` is the number of samples per MCMC chain. The extra `kwargs` are given to the internal differential
equation solver.
type. `num_samples` is the number of samples per MCMC chain. Sampling from multiple chains is possible, see [`Turing.jl` documentation](https://turinglang.org/v0.26/docs/using-turing/guide#sampling-multiple-chains), serially or parallelly using `parallel_type` and `n_chains`. The extra `kwargs` are given to the internal differential equation solver.

### dynamichmc_inference

Expand Down
18 changes: 14 additions & 4 deletions src/turing_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ function turing_inference(prob::DiffEqBase.DEProblem,
likelihood = (u, p, t, σ) -> MvNormal(u,
Diagonal((σ[1])^2 *
ones(length(u)))),
num_samples = 1000, sampler = Turing.NUTS(0.65),
num_samples = 1000,
sampler = Turing.NUTS(0.65),
parallel_type = MCMCSerial(),
n_chains = 1,
syms = [Turing.@varname(theta[i]) for i in 1:length(priors)],
sample_u0 = false,
save_idxs = nothing,
progress = false,
kwargs...)
N = length(priors)
Turing.@model function mf(x, ::Type{T} = Float64) where {T <: Real}
Turing.@model function infer(x, ::Type{T} = Float64) where {T <: Real}
theta = Vector{T}(undef, length(priors))
for i in 1:length(priors)
theta[i] ~ NamedDist(priors[i], syms[i])
Expand Down Expand Up @@ -54,7 +57,14 @@ function turing_inference(prob::DiffEqBase.DEProblem,
end false

# Instantiate a Model object.
model = mf(data)
chn = sample(model, sampler, num_samples; progress = progress)
model = infer(data)
chn = sample(
model,
sampler,
parallel_type,
num_samples,
n_chains;
progress = progress
)
return chn
end
21 changes: 13 additions & 8 deletions test/turing.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions, RecursiveArrayTools
using Test, Distributions, SteadyStateDiffEq
using Turing

println("One parameter case")
f1 = @ode_def begin
dx = a * x - x * y
Expand All @@ -14,25 +16,28 @@ randomized = VectorOfArray([(sol(t[i]) + 0.01randn(2)) for i in 1:length(t)])
data = convert(Array, randomized)
priors = [Normal(1.5, 0.01)]

bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500,
syms = [:a])
bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500, syms = [:a])

@show bayesian_result

@test mean(get(bayesian_result, :a)[1])1.5 atol=3e-1

bayesian_result = turing_inference(prob1, Rosenbrock23(autodiff = false), t, data, priors;
num_samples = 500,
syms = [:a])
bayesian_result = turing_inference(prob1, Rosenbrock23(autodiff = false), t, data, priors; num_samples = 500, syms = [:a])

bayesian_result = turing_inference(prob1, Rosenbrock23(), t, data, priors;
num_samples = 500,
syms = [:a])

# --- test Multithreaded sampling
println("Multithreaded case")
result_threaded = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500, syms = [:a], parallel_type=MCMCThreads(), n_chains=2)

@test length(result_threaded.value.axes[3]) == 2
@test mean(get(result_threaded, :a)[1])1.5 atol=3e-1
# ---

priors = [Normal(1.0, 0.01), Normal(1.0, 0.01), Normal(1.5, 0.01)]
bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500,
sample_u0 = true,
syms = [:u1, :u2, :a])
bayesian_result = turing_inference(prob1, Tsit5(), t, data, priors; num_samples = 500, sample_u0 = true, syms = [:u1, :u2, :a])

@test mean(get(bayesian_result, :a)[1])1.5 atol=3e-1
@test mean(get(bayesian_result, :u1)[1])1.0 atol=3e-1
Expand Down

0 comments on commit 6994120

Please sign in to comment.