Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sermon et al. helper functions #378

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ include("gui/GUI.jl")
include("blox/connections.jl")
include("blox/blox_utilities.jl")
include("Neurographs.jl")
include("blox/sermon_dbs.jl")

function simulate(sys::ODESystem, u0, timespan, p, solver = AutoVern7(Rodas4()); kwargs...)
prob = ODEProblem(sys, u0, timespan, p)
Expand Down Expand Up @@ -211,5 +212,6 @@ export run_experiment!, run_trial!
export addnontunableparams
export get_weights, get_dynamic_states, get_idx_tagged_vars, get_eqidx_tagged_vars
export BalloonModel, boldsignal_endo_balloon
export SermonNPool, SermonDBS

end
160 changes: 160 additions & 0 deletions src/blox/sermon_dbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
using IfElse

# Neurotransmitter pool block
# Implements Equation 4
struct SermonNPool <: NeuralMassBlox
params
output
jcn
odesystem
namespace
function SermonNPool(;
name,
namespace=nothing,
τ_pool=0.1,
p_pool=0.3
)
p = paramscoping(τ_pool=τ_pool, p_pool=p_pool)
τ_pool, p_pool = p

sts = @variables n_pool(t)=1.0 [output = true] jcn(t)=0.0 [input=true]
eqs = [D(n_pool) ~ (1-n_pool)/τ_pool - p_pool*n_pool*jcn]
sys = System(eqs, t, sts, p; name=name)
new(p, sts[1], sts[2], sys, namespace)
end
end

# Create a new DBS stimulator block
# Largely based on MTK Standard Library pulse block from the digital electronics
# Implements unnumbered equation below equation 2 for the pulse train, with default parameters from the paper
struct SermonDBS <: NeuralMassBlox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gabrevaya this is the stimulator for an arbitrary pulse train

params
output
jcn
odesystem
namespace
function SermonDBS(;
name,
namespace=nothing,
f_stim=130.0,
amplitude=1.0,
pulse_width=0.0001,
start_time=0.0
)
p = paramscoping(f_stim=f_stim, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time)
f_stim, amplitude, pulse_width, start_time = p

sts = @variables u(t)=0.0 [output = true]
eqs = [u ~ IfElse.ifelse(t > start_time, IfElse.ifelse(t % (1.0/f_stim) < pulse_width, amplitude, 0), 0)]
sys = System(eqs, t, sts, p; name=name)
new(p, sts[1], nothing, sys, namespace)
end
end

# Connects the DBS stimulator to the neurotransmitter pool
function (bc::BloxConnector)(
bloxout::SermonDBS,
bloxin::SermonNPool;
kwargs...
)
sys_out = get_namespaced_sys(bloxout)
sys_in = get_namespaced_sys(bloxin)

w = generate_weight_param(bloxout, bloxin; kwargs...)
push!(bc.weights, w)

x = namespace_expr(bloxout.output, sys_out)

eq = sys_in.jcn ~ w*x
accumulate_equation!(bc, eq)
end

# Redfine Kuramoto oscillator connections with arbitrary connection rule
function (bc::BloxConnector)(
bloxout::KuramotoOscillator,
bloxin::KuramotoOscillator;
kwargs...
)
sys_out = get_namespaced_sys(bloxout)
sys_in = get_namespaced_sys(bloxin)

#technically these two lines aren't needed, but useful to have if weighting occurs here
w = generate_weight_param(bloxout, bloxin; kwargs...)
push!(bc.weights, w)

if haskey(kwargs, :sermon_rule)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see an else to this if. Could we just claim that (for now) :sermon_rule is the only possible connection rule to reduce the kwargs that are exposed to the user? Or were you planning on adding more rules?

if haskey(kwargs, :extra_bloxs) && haskey(kwargs, :extra_params)
RRP, RP, RtP = kwargs[:extra_bloxs]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks hardcoded for the particular tutorial case, i.e. expecting 3 elements in kwargs[:extra_bloxs]. You can imagine this breaking very easily if a user passes a different number in that keyword. We should generalize this in a loop, seems doable?

out_RRP = namespace_expr(RRP.output, get_namespaced_sys(RRP))
out_RP = namespace_expr(RP.output, get_namespaced_sys(RP))
out_RtP = namespace_expr(RtP.output, get_namespaced_sys(RtP))

M_RRP, M_RP, M_RtP, k_μ, I = kwargs[:extra_params]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above about kwargs[:extra_bloxs] but now it's kwargs[:extra_params]. I understand that we might want to keep this conenction rule specific to the tutorial (for now) in the interest of development speed, but we should think about the most likely breaking cases.

else
error("Extra states and parameters need to be specified to use the Sermon rule")
end
xₒ = namespace_expr(bloxout.output, sys_out)
xᵢ = namespace_expr(bloxin.output, sys_in) #needed because this is also the θ term of the block receiving the connection

# Custom values for the connection rule
# Values from supplementary material Table A
f₀ = -0.780
f₁ = 0.198
f₂ = 0.302
f₃ = 0.851
f₄ = 0.998
x = xₒ - xᵢ
# Equation 6 from the paper
eq = sys_in.jcn ~ w * max(M_RRP * out_RRP, M_RP * out_RP, M_RtP * out_RtP)*(f₀ + f₁*cos(x) + f₂*sin(x) + f₃*cos(2*x) + f₄*sin(2*x))
accumulate_equation!(bc, eq)
end
end

# Connects the DBS stimulator to the Kuramoto oscillators (Iₜ term in equation 1)
function (bc::BloxConnector)(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gabrevaya this is the connection rule for the DBS stimulator to Kuramoto blocks

bloxout::SermonDBS,
bloxin::KuramotoOscillator;
kwargs...
)
sys_out = get_namespaced_sys(bloxout)
sys_in = get_namespaced_sys(bloxin)

w = generate_weight_param(bloxout, bloxin; kwargs...)
push!(bc.weights, w)

xₒ = namespace_expr(bloxout.output, sys_out)
xᵢ = namespace_expr(bloxin.output, sys_in) #needed because this is also the θ term of the block receiving the connection

eq = sys_in.jcn ~ w*xₒ*sin(xᵢ)
accumulate_equation!(bc, eq)
end

# Debugging purposes only
# Connects the Kuramoto oscillators without neurotransmitter modulation
function (bc::BloxConnector)(
bloxout::KuramotoOscillator,
bloxin::KuramotoOscillator;
kwargs...
)
sys_out = get_namespaced_sys(bloxout)
sys_in = get_namespaced_sys(bloxin)

#technically these two lines aren't needed, but useful to have if weighting occurs here
w = generate_weight_param(bloxout, bloxin; kwargs...)
push!(bc.weights, w)

if haskey(kwargs, :sermon_rule_no_neurotransmitters)
xₒ = namespace_expr(bloxout.output, sys_out)
xᵢ = namespace_expr(bloxin.output, sys_in) #needed because this is also the θ term of the block receiving the connection

# Custom values for the connection rule
f₀ = -0.780
f₁ = 0.198
f₂ = 0.302
f₃ = 0.851
f₄ = 0.998
x = xₒ - xᵢ
eq = sys_in.jcn ~ w *(f₀ + f₁*cos(x) + f₂*sin(x) + f₃*cos(2*x) + f₄*sin(2*x))
accumulate_equation!(bc, eq)
end
end
140 changes: 140 additions & 0 deletions src/blox/sermon_dbs_tutorial.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
using Neuroblox
using DifferentialEquations
#using DataFrames
using Distributions, Random
#using Statistics
#using LinearAlgebra
using Graphs, MetaGraphs
using DSP

# First, create the three neurotransmitter pools. As they have the same structure,
# we can use the same constructor with different parameters to create the systems.
# Parameters are taken from the supplementary material Table A.
@named RRP = SermonNPool(τ_pool=1.0, p_pool=0.3)
@named RP = SermonNPool(τ_pool=10.0, p_pool=0.02)
@named RtP = SermonNPool(τ_pool=50.0, p_pool=0.00035)

# Next, create the DBS stimulator block. This block is a simple pulse with a given width
# and period. To try and recreate Fig. 2A, start stimulation at 50s.
@named DBS = SermonDBS(start_time=50.0, pulse_width=0.0001)

# Next, create a collection of 50 Kuramotor oscillators with noise.
# Set the number of oscillators (default from the paper)
N = 50

# Define the natural distribution of oscillator frequencies
# Parameters from supplementary material Table A
Ω = 249
σ = 26.317

ks_blocks = [KuramotoOscillator(name=Symbol("KO$i"),
ω=rand(Normal(Ω, σ)),
ζ=5.920,
include_noise=true) for i in 1:N]

# Create a graph and add all the blocks as nodes
g = MetaDiGraph()
add_blox!.(Ref(g), vcat(ks_blocks, DBS, RRP, RP, RtP))

# Additional connection parameters specified from supplementary material Table A
p = @parameters M_RRP=1.0 M_RP=0.5 M_RtP = 0.3 k_μ = 11976.0 I =72053.0

# Connect all oscillators to each other
for i in 1:N
for j in 1:N
# Weight is k_μ to be multiplied by the other terms in equation 6
add_edge!(g, i, j, Dict(:weight => k_μ, :sermon_rule => true, :extra_bloxs => [RRP, RP, RtP], :extra_params => p))
end
end

# Connect DBS stimulator to oscillators
for i in 1:N
add_edge!(g, N+1, i, Dict(:weight => I))
end

# Connect DBS stimulator to neurotransmitter pools
add_edge!(g, N+1, N+2, Dict(:weight => 1.0))
add_edge!(g, N+1, N+3, Dict(:weight => 1.0))
add_edge!(g, N+1, N+4, Dict(:weight => 1.0))

@named sys = system_from_graph(g, p)
sys = structural_simplify(sys)

# Simulate for the length of Fig. 2A
sim_len = 150.0
@time prob = SDEProblem(sys, [], (0.0, sim_len))
# EulerHeun because you need to force timestops at the pulse width of the DBS stimulator
# so fixed timestep solver is the easiest way to do this. Maxiter problems can happen with
# the adaptive ones I've tried.
@time sol = solve(prob, EulerHeun(), dt=0.0001, saveat=0.001)

# Helper function because θ from the Kuramoto oscillators is unbounded and needs wrapping
# to get the correct spectrogram
function wrapTo2Pi(x)
posinput = x .> 0
wrapped = mod.(x, 2*π)
wrapped[wrapped .== 0 .&& posinput] .= 2*π
return wrapped
end

# Helper code to plot the spectrogram averaged acrossed the 50 oscillators
data = Array(sol)
spec = spectrogram(wrapTo2Pi(data[1, :]), div(150001, 200); fs=1000)

freq = spec.freq
time = spec.time
hmmpower = spec.power

for i = 2:50
spec = spectrogram(wrapTo2Pi(data[i, :]), div(150001, 200); fs=1000)
hmmpower .+= spec.power
end

using Plots
# This should reproduce Fig. 2A. Instead, what I'm seeing is oscillating around
# the mean value (~20 Hz) as in the paper for the first 50s, then a jump up to a mean
# around the stimulation frequency (130Hz) but no emergent coherence at the 300Hz range.
heatmap(time, freq, pow2db.(hmmpower))

# To check if the coupling matches the values in Fig. 2D, compute equation 5 from the three
# neurotransmitter pools.
kₜ = zeros(length(sol))
for i = 1:length(sol)
kₜ[i] = max(data[N+1, i], data[N+2, i]*0.5, data[N+3, i]*0.3)
end

# This should reproduce the blue line in Fig. 2D. The shape is roughly accurate, but the
# scale is entirely off. I think this is because the authors dropped a factor in connecting
# the DBS stimulator to the neurotransmitter pools. Running just that simulation shows that
# they're very far off the actual values while preserving the shape of the dynamics.
plot(sol.t, kₜ .* 11976.0)

# Transmitter pool only simulation
# This *should* reproduce Figure 1B, but it doesn't. That's why I'm convinced there's a term
# missing in equation 4. The shape is correct, but the scale is off.

# Setup neurotransmitter pools
@named RRP = SermonNPool(τ_pool=1.0, p_pool=0.3)
@named RP = SermonNPool(τ_pool=10.0, p_pool=0.02)
@named RtP = SermonNPool(τ_pool=50.0, p_pool=0.00035)

# Next, create the DBS stimulator block. This block is a simple pulse with a given width
# and period.
@named DBS = SermonDBS(start_time=50.0, pulse_width=0.0001)

p = @parameters M_RRP=1.0 M_RP=0.5 M_RtP = 0.3 k_μ = 11976.0

g = MetaDiGraph()
add_blox!.(Ref(g), vcat(DBS, RRP, RP, RtP))
add_edge!(g, 1, 2, Dict(:weight => 1.0))
add_edge!(g, 1, 3, Dict(:weight => 1.0))
add_edge!(g, 1, 4, Dict(:weight => 1.0))

@named sys = system_from_graph(g, p)
sys = structural_simplify(sys)

prob = ODEProblem(sys, [], (0.0, 150.0))
sol = solve(prob, Euler(), dt=0.0001, saveat=0.001)

# This should reproduce Fig. 1B. It doesn't :()
plot(sol)