From c821577695cb9adb32e2ee078110bc3d980e5c75 Mon Sep 17 00:00:00 2001 From: gabrevaya Date: Mon, 23 Dec 2024 14:35:53 -0500 Subject: [PATCH 1/5] - fix GPe bug - add I_syn_msn and I_syn_fsi variables - use multithreading for computing power spectrum of EnsembleSolutions - new more general script for tuning parameters - initial tuning of full model in baseline condition --- docs/src/tutorials/basal_ganglia.jl | 24 +++++++++++------------ src/Neuroblox.jl | 2 +- src/blox/DBS_Model_Blox_Adam_Brown.jl | 18 ++++++++--------- src/blox/blox_utilities.jl | 9 +++++---- src/blox/connections.jl | 28 ++++++++------------------- src/blox/neuron_models.jl | 22 ++++++++++++++++++++- 6 files changed, 56 insertions(+), 47 deletions(-) diff --git a/docs/src/tutorials/basal_ganglia.jl b/docs/src/tutorials/basal_ganglia.jl index 6a2de78a..a90c81cb 100644 --- a/docs/src/tutorials/basal_ganglia.jl +++ b/docs/src/tutorials/basal_ganglia.jl @@ -55,10 +55,10 @@ rasterplot(msn, sol, threshold = -35, title = "Neuron's Spikes - Mean Firing Rat # Compute and plot the power spectrum of the GABAa current fig = Figure(size = (1500, 500)) -powerspectrumplot(fig[1,1], msn, sol, state = "G", +powerspectrumplot(fig[1,1], msn, sol, state = "I_syn_msn", title = "FFT with no window") -powerspectrumplot(fig[1,2], msn, sol, state = "G", +powerspectrumplot(fig[1,2], msn, sol, state = "I_syn_msn", method = welch_pgram, window = hanning, title = "Welch's method + Hanning window") fig @@ -105,12 +105,12 @@ fig = Figure(size = (1000, 800)) rasterplot(fig[1,1], msn, ens_sol[1], threshold = -35, title = "MSN - Mean Firing Rate: $(round(fr_msn[1], digits=2)) spikes/s") rasterplot(fig[1,2], fsi, ens_sol[1], threshold = -35, title = "FSI - Mean Firing Rate: $(round(fr_fsi[1], digits=2)) spikes/s") -powerspectrumplot(fig[2,1], msn, ens_sol, state = "G", +powerspectrumplot(fig[2,1], msn, ens_sol, state = "I_syn_msn", method = welch_pgram, window = hanning, ylims= (-35, 15), xlims= (8, 100)) -powerspectrumplot(fig[2,2], fsi, ens_sol, state = "G", +powerspectrumplot(fig[2,2], fsi, ens_sol, state = "I_syn_fsi", method=welch_pgram, window=hanning, ylims= (-35, 15), xlims= (8, 100)) @@ -153,22 +153,22 @@ ens_sol = solve(ens_prob, RKMil(), dt=dt, saveat = dt, trajectories = 3); # Compute and plot power spectra for all components fig = Figure(size = (1600, 450)) -powerspectrumplot(fig[1,1], msn, ens_sol, state = "G", +powerspectrumplot(fig[1,1], msn, ens_sol, state = "I_syn_msn", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "MSN (Baseline)") -powerspectrumplot(fig[1,2], fsi, ens_sol, state = "G", +powerspectrumplot(fig[1,2], fsi, ens_sol, state = "I_syn_fsi", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "FSI (Baseline)") -powerspectrumplot(fig[1,3], gpe, ens_sol, state = "G", +powerspectrumplot(fig[1,3], gpe, ens_sol, state = "V", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "GPe (Baseline)") -powerspectrumplot(fig[1,4], stn, ens_sol, state = "G", +powerspectrumplot(fig[1,4], stn, ens_sol, state = "V", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "STN (Baseline)") @@ -215,22 +215,22 @@ ens_prob = EnsembleProblem(prob) ens_sol = solve(ens_prob, RKMil(), dt = dt, saveat = dt, trajectories = 3); # Compute and compare power spectra for all neural populations in Parkinsonian condition against their counterparts in baseline conditions. -powerspectrumplot(fig[2,1], msn, ens_sol, state = "G", +powerspectrumplot(fig[2,1], msn, ens_sol, state = "I_syn_msn", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "MSN (PD)") -powerspectrumplot(fig[2,2], fsi, ens_sol, state = "G", +powerspectrumplot(fig[2,2], fsi, ens_sol, state = "I_syn_fsi", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "FSI (PD)") -powerspectrumplot(fig[2,3], gpe, ens_sol, state = "G", +powerspectrumplot(fig[2,3], gpe, ens_sol, state = "V", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "GPe (PD)") -powerspectrumplot(fig[2,4], stn, ens_sol, state = "G", +powerspectrumplot(fig[2,4], stn, ens_sol, state = "V", method = welch_pgram, window = hanning, ylims=(-40, 25), title = "STN (PD)") diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 01b6a2c9..e72005a8 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -4,7 +4,7 @@ import Base: merge using Base.Threads: nthreads -using OhMyThreads: tmapreduce +using OhMyThreads: tmapreduce, tmap using Reexport @reexport using ModelingToolkit diff --git a/src/blox/DBS_Model_Blox_Adam_Brown.jl b/src/blox/DBS_Model_Blox_Adam_Brown.jl index a8cf658e..7a783f52 100644 --- a/src/blox/DBS_Model_Blox_Adam_Brown.jl +++ b/src/blox/DBS_Model_Blox_Adam_Brown.jl @@ -64,11 +64,11 @@ struct Striatum_MSN_Adam <: CompositeBlox namespace = nothing, N_inhib = 100, E_syn_inhib=-80, - I_bg=1.172*ones(N_inhib), + I_bg=1.153064742988923*ones(N_inhib), freq=zeros(N_inhib), phase=zeros(N_inhib), τ_inhib=13, - σ=0.11, + σ=0.17256774881503584, density=0.3, weight=0.1, G_M=1.3, @@ -134,12 +134,12 @@ struct Striatum_FSI_Adam <: CompositeBlox namespace = nothing, N_inhib = 50, E_syn_inhib=-80, - I_bg=6.2*ones(N_inhib), + I_bg=6.196201739395473*ones(N_inhib), freq=zeros(N_inhib), phase=zeros(N_inhib), τ_inhib=11, τ_inhib_s=6.5, - σ=1.2, + σ=0.9548801242101033, density=0.58, g_density=0.33, weight=0.6, @@ -215,17 +215,17 @@ struct GPe_Adam <: CompositeBlox namespace = nothing, N_inhib = 80, E_syn_inhib=-80, - I_bg=3.4*ones(N_inhib), + I_bg=3.272893843123162*ones(N_inhib), freq=zeros(N_inhib), phase=zeros(N_inhib), τ_inhib=10, - σ=1.7, + σ=1.0959782801317943, density=0.0, weight=0.0, connection_matrix=nothing ) n_inh = [ - HHNeuronInhib_MSN_Adam_Blox( + HHNeuronInhib_GPe_Adam_Blox( name = Symbol("inh$i"), namespace = namespaced_name(namespace, name), E_syn = E_syn_inhib, @@ -284,11 +284,11 @@ struct STN_Adam <: CompositeBlox namespace = nothing, N_exci = 40, E_syn_exci=0.0, - I_bg=1.8*ones(N_exci), + I_bg=2.2010777359961953*ones(N_exci), freq=zeros(N_exci), phase=zeros(N_exci), τ_exci=2, - σ=1.7, + σ=2.9158528502583545, density=0.0, weight=0.0, connection_matrix=nothing diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 168c4a1c..46bfc629 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -578,11 +578,12 @@ function powerspectrum(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronB t_sampled, sampling_freq = get_sampling_info(sols[1]; sampling_rate=sampling_rate) powspecs = DSP.Periodograms.Periodogram[] - for sol in sols - s = meanfield_timeseries(cb, sol, state; ts = t_sampled) - powspec = method(s, fs=sampling_freq, window=window) - push!(powspecs, powspec) + powspecs = tmap(eachindex(sols)) do i + sol = sols[i] + s = meanfield_timeseries(cb, sol, state; ts=t_sampled) + method(s, fs=sampling_freq, window=window) end + powspecs = collect(powspecs) return powspecs end diff --git a/src/blox/connections.jl b/src/blox/connections.jl index f92514bc..cfb5ff32 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -316,27 +316,14 @@ function Connector( sys_dest.I_syn ~ -w * sys_src.G * (sys_dest.V - sys_src.E_syn) end - return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, learning_rule=Dict(w => lr)) -end - -function Connector( - blox_src::Union{HHNeuronInhib_MSN_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}, - blox_dest::Union{HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}; - kwargs... -) - sys_src = get_namespaced_sys(blox_src) - sys_dest = get_namespaced_sys(blox_dest) - - w = generate_weight_param(blox_src, blox_dest; kwargs...) - - STA = get_sta(kwargs, nameof(blox_src), nameof(blox_dest)) - eq = if STA - sys_dest.I_syn ~ -w * sys_dest.Gₛₜₚ * sys_src.G * (sys_dest.V - sys_src.E_syn) + if blox_src isa HHNeuronInhib_MSN_Adam_Blox && blox_dest isa HHNeuronInhib_MSN_Adam_Blox + eq2 = sys_dest.I_syn_msn ~ -w * sys_src.G * (sys_dest.V - sys_src.E_syn) + eqs = [eq, eq2] else - sys_dest.I_syn ~ -w * sys_src.G * (sys_dest.V - sys_src.E_syn) + eqs = eq end - return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eqs, weight=w, learning_rule=Dict(w => lr)) end function Connector( @@ -365,6 +352,7 @@ function Connector( w = generate_weight_param(blox_src, blox_dest; kwargs...) eq = sys_dest.I_syn ~ -w * sys_src.Gₛ * (sys_dest.V - sys_src.E_syn) + eq4 = sys_dest.I_syn_fsi ~ -w * sys_src.Gₛ * (sys_dest.V - sys_src.E_syn) GAP = get_gap(kwargs, nameof(blox_src), nameof(blox_dest)) if GAP @@ -372,9 +360,9 @@ function Connector( eq2 = sys_dest.I_gap ~ -w_gap * (sys_dest.V - sys_src.V) eq3 = sys_src.I_gap ~ -w_gap * (sys_src.V - sys_dest.V) - return Connector(nameof(sys_src), nameof(sys_dest); equation=[eq, eq2, eq3], weight=[w, w_gap]) + return Connector(nameof(sys_src), nameof(sys_dest); equation=[eq, eq2, eq3, eq4], weight=[w, w_gap]) else - return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) + return Connector(nameof(sys_src), nameof(sys_dest); equation=[eq, eq4], weight=w) end end diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index b1162246..8a976220 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -199,7 +199,14 @@ struct HHNeuronInhib_MSN_Adam_Blox <: AbstractInhNeuronBlox I_asc(t) [input=true] G(t)=0.0 - [output = true] + [output = true] + + spikes_cumulative(t)=0.0 + spikes_window(t)=0.0 + + # observables + I_syn_msn(t)=0.0 + [input=true] end ps = @parameters begin @@ -292,6 +299,13 @@ struct HHNeuronInhib_FSI_Adam_Blox <: AbstractInhNeuronBlox [output = true] Gₛ(t)=0.0 [output = true] + + spikes_cumulative(t)=0.0 + spikes_window(t)=0.0 + + # observables + I_syn_fsi(t)=0.0 + [input=true] end ps = @parameters begin @@ -379,6 +393,9 @@ struct HHNeuronExci_STN_Adam_Blox <: AbstractExciNeuronBlox [input=true] G(t)=0.0 [output = true] + + spikes_cumulative(t)=0.0 + spikes_window(t)=0.0 end ps = @parameters begin @@ -458,6 +475,9 @@ struct HHNeuronInhib_GPe_Adam_Blox <: AbstractInhNeuronBlox [input=true] G(t)=0.0 [output = true] + + spikes_cumulative(t)=0.0 + spikes_window(t)=0.0 end ps = @parameters begin From 0707fc4186763f79a0fea18e5c4c01e5ef62c5b3 Mon Sep 17 00:00:00 2001 From: gabrevaya Date: Tue, 24 Dec 2024 07:49:55 -0500 Subject: [PATCH 2/5] add parameter tuning script --- examples/tune_parameters.jl | 395 ++++++++++++++++++++++++++++++++++++ 1 file changed, 395 insertions(+) create mode 100644 examples/tune_parameters.jl diff --git a/examples/tune_parameters.jl b/examples/tune_parameters.jl new file mode 100644 index 00000000..1e337562 --- /dev/null +++ b/examples/tune_parameters.jl @@ -0,0 +1,395 @@ +using Neuroblox +using StochasticDiffEq +using Random +using Statistics +using Optimization +using OptimizationOptimJL +using ModelingToolkit: setp, getp + + +struct PopConfig + name::Symbol + blox + firing_rate_target::Union{Nothing, Float64} + firing_rate_weight::Float64 + freq_target::Union{Nothing, Float64} + freq_weight::Float64 + freq_min::Union{Nothing, Float64} + freq_max::Union{Nothing, Float64} + freq_aggregate_state::Union{Nothing, String} + tunable::Bool +end + +function PopConfig( + name, + blox; + firing_rate_target=nothing, + firing_rate_weight=0.0, + freq_target=nothing, + freq_weight=0.0, + freq_min=nothing, + freq_max=nothing, + freq_aggregate_state=nothing, + tunable=false +) + PopConfig( + name, + blox, + firing_rate_target, + firing_rate_weight, + freq_target, + freq_weight, + freq_min, + freq_max, + freq_aggregate_state, + tunable + ) +end + +function get_peak_freq(powspec, freq_min, freq_max; + freq_ind = findall(x -> x > freq_min && x < freq_max, powspec.freq)) + + if isempty(freq_ind) + return NaN + end + ind = argmax(powspec.power[freq_ind]) + return powspec.freq[freq_ind][ind] +end + +function get_peak_freq(powspecs::Vector, freq_min, freq_max) + freq_ind = findall(x -> x > freq_min && x < freq_max, powspecs[1].freq) + if isempty(freq_ind) + return NaN + end + + # peak_freqs = [get_peak_freq(powspec, freq_min, freq_max, freq_ind=freq_ind) for powspec in powspecs] + # mean_peak_freq = mean(peak_freqs) + # std_peak_freq = std(peak_freqs) + + mean_power = mean(powspec.power[freq_ind] for powspec in powspecs) + ind = argmax(mean_power) + + return powspecs[1].freq[freq_ind][ind], NaN64 + # return mean_peak_freq, std_peak_freq +end + +function get_inds(params_str::Vector{String}, pattern::Vector{String}) + # Find indices where all patterns appear + findall(str -> all(pat -> occursin(pat, str), pattern), params_str) +end + +function create_problem(size, T) + @info "Size = $size" + Random.seed!(123) + N_MSN = round(Int64, 100*size) + N_FSI = round(Int64, 50*size) + N_GPe = round(Int64, 80*size) + N_STN = round(Int64, 40*size) + + make_conn = Neuroblox.indegree_constrained_connection_matrix + + global_ns = :g + + ḡ_FSI_MSN = 0.6 + density_FSI_MSN = 0.15 + weight_FSI_MSN = ḡ_FSI_MSN / (N_FSI * density_FSI_MSN) + conn_FSI_MSN = make_conn(density_FSI_MSN, N_FSI, N_MSN) + + ḡ_MSN_GPe = 2.5 + density_MSN_GPe = 0.33 + weight_MSN_GPe = ḡ_MSN_GPe / (N_MSN * density_MSN_GPe) + conn_MSN_GPe = make_conn(density_MSN_GPe, N_MSN, N_GPe) + + ḡ_GPe_STN = 0.3 + density_GPe_STN = 0.05 + weight_GPe_STN = ḡ_GPe_STN / (N_GPe * density_GPe_STN) + conn_GPe_STN = make_conn(density_GPe_STN, N_GPe, N_STN) + + ḡ_STN_FSI = 0.165 + density_STN_FSI = 0.1 + weight_STN_FSI = ḡ_STN_FSI / (N_STN * density_STN_FSI) + conn_STN_FSI = make_conn(density_STN_FSI, N_STN, N_FSI) + + @named msn = Striatum_MSN_Adam(namespace=global_ns, N_inhib=N_MSN, I_bg=1.153064742988923*ones(N_MSN), σ=0.17256774881503584) + @named fsi = Striatum_FSI_Adam(namespace=global_ns, N_inhib=N_FSI, I_bg=6.196201739395473*ones(N_FSI), σ=0.9548801242101033) + @named gpe = GPe_Adam(namespace=global_ns, N_inhib=N_GPe, I_bg=3.272893843123162, σ=1.0959782801317943) + @named stn = STN_Adam(namespace=global_ns, N_exci=N_STN, I_bg=2.2010777359961953*ones(N_STN), σ=2.9158528502583545) + + g = MetaDiGraph() + add_edge!(g, fsi => msn, weight=weight_FSI_MSN, connection_matrix=conn_FSI_MSN) + add_edge!(g, msn => gpe, weight=weight_MSN_GPe, connection_matrix=conn_MSN_GPe) + add_edge!(g, gpe => stn, weight=weight_GPe_STN, connection_matrix=conn_GPe_STN) + add_edge!(g, stn => fsi, weight=weight_STN_FSI, connection_matrix=conn_STN_FSI) + + @info "Creating system from graph" + @named sys = system_from_graph(g) + + # For MSN-only + # g = MetaDiGraph() ## defines a graph + # add_blox!(g, msn) ## adds the defined blocks into the graph + # @named sys = system_from_graph(g) + + tspan = (0.0, T) + @info "Creating SDEProblem" + prob = SDEProblem(sys, [], tspan, []) + + return prob, sys, msn, fsi, gpe, stn +end + +function build_pop_param_indices(params_str::Vector{String}, populations) + indices = NamedTuple{keys(populations)}(map(pop -> begin + pop_str = string(pop.name) + I_bg_inds = get_inds(params_str, ["I_bg", pop_str]) + σ_inds = get_inds(params_str, ["σ", pop_str]) + if isempty(I_bg_inds) + error("No I_bg parameters found for population $(pop.name). Check naming.") + end + if isempty(σ_inds) + error("No σ parameters found for population $(pop.name). Check naming.") + end + (I_bg_ind = I_bg_inds, σ_ind = σ_inds) + end, populations)) + return indices +end + +function remake_prob!(prob, args, p) + @unpack populations, pop_params, get_ps, set_ps! = args + ps_new = get_ps(prob) + + offset = 1 + for (popname, popconfig) in pairs(populations) + if popconfig.tunable + # This population is tuned: assign parameters from p + I_bg_inds = pop_params[popname].I_bg_ind + σ_inds = pop_params[popname].σ_ind + + ps_new[I_bg_inds] .= abs(p[offset]) + offset += 1 + ps_new[σ_inds] .= abs(p[offset]) + offset += 1 + end + end + + set_ps!(prob, ps_new) +end + +function loss(p, args) + @unpack dt, saveat, solver, seed, other_diffeq_kwargs, threshold, transient, populations, prob, pop_params = args + @unpack trajectories, ensemblealg = args + Random.seed!(seed) + remake_prob!(prob, args, p) + + ens_prob = EnsembleProblem(prob) + sol = solve(ens_prob, solver, ensemblealg, trajectories=trajectories, dt=dt, saveat=saveat; other_diffeq_kwargs...) + + total_err = 0.0 + + # Compute firing rates and frequencies + for (popname, popconfig) in pairs(populations) + if popconfig.firing_rate_target !== nothing + fr_res = firing_rate(popconfig.blox, sol, threshold=threshold, transient=transient, scheduler=:dynamic) + fr, fr_std = fr_res + err_fr = (fr[1] - popconfig.firing_rate_target)^2 * popconfig.firing_rate_weight + total_err += err_fr + @info "[$popname] FR = $(fr[1]), FR std = $(fr_std[1]), FR Error = $err_fr" + end + + if popconfig.freq_target !== nothing + powspecs = powerspectrum(popconfig.blox, sol, popconfig.freq_aggregate_state, method=welch_pgram, window=hamming) + + # peak_freqs = Float64[] + # for powspec in powspecs + # peak_freq = get_peak_freq(powspec, popconfig.freq_min, popconfig.freq_max) + # @show peak_freq + # push!(peak_freqs, peak_freq) + # end + # @show mean(peak_freqs) + + peak_freq, peak_freq_std = get_peak_freq(powspecs, popconfig.freq_min, popconfig.freq_max) + err_freq = abs(peak_freq - popconfig.freq_target) * popconfig.freq_weight + total_err += err_freq + @info "[$popname] Peak freq = $peak_freq, Freq std = $peak_freq_std, Freq Error = $err_freq" + end + end + + return total_err +end + +function run_optimization() + prob, sys, msn, fsi, gpe, stn = create_problem(1.0, 5500.0) + params_str = string.(tunable_parameters(sys)) + get_ps = getp(prob, tunable_parameters(sys)) + set_ps! = setp(prob, tunable_parameters(sys)) + + populations = ( + + ## only MSN + + # msn = PopConfig( + # :msn, + # msn, + # firing_rate_target=1.46, + # firing_rate_weight=3.0, + # freq_target=17.53, + # freq_weight=1.0, + # freq_min=5.0, + # freq_max=25.0, + # freq_aggregate_state="I_syn_msn", + # tunable=true + # ), + + ## MSN - FSI + + # msn = PopConfig( + # :msn, + # msn, + # firing_rate_target=1.88, + # firing_rate_weight=3.0, + # tunable=false + # ), + # fsi = PopConfig( + # :fsi, + # fsi, + # firing_rate_target=10.66, + # firing_rate_weight=3.0, + # freq_target=58.63, + # freq_weight=1.0, + # freq_min=40.0, + # freq_max=90.0, + # freq_aggregate_state="I_syn_fsi", + # tunable=true + # ) + + ## full model in baseline conditions + + msn = PopConfig( + :msn, + msn, + firing_rate_target=1.21, + firing_rate_weight=60.0, + freq_target=10, + freq_weight=0.5, + freq_min=3.0, + freq_max=20.0, + freq_aggregate_state="I_syn_msn", + tunable=false + ), + fsi = PopConfig( + :fsi, + fsi, + firing_rate_target=13.00, + firing_rate_weight=10.0, + freq_target=61.14, + freq_weight=1.0, + freq_min=40.0, + freq_max=90.0, + freq_aggregate_state="I_syn_fsi", + tunable=false + ), + gpe = PopConfig( + :gpe, + gpe, + freq_target=85.0, + freq_weight=0.5, + freq_min=40.0, + freq_max=90.0, + freq_aggregate_state="V", + tunable=true + ), + stn = PopConfig( + :stn, + stn, + tunable=true + ), + + + ## full model PD conditions + + # msn = PopConfig( + # :msn, + # msn, + # firing_rate_target=4.85, + # firing_rate_weight=60.0, + # freq_target=15.80, + # freq_weight=1.0, + # freq_min=3.0, + # freq_max=50.0, + # freq_aggregate_state="I_syn_msn", + # tunable=true + # ), + # fsi = PopConfig( + # :fsi, + # fsi, + # freq_target=15.80, + # freq_weight=1.0, + # freq_min=3.0, + # freq_max=50.0, + # freq_aggregate_state="I_syn_fsi", + # tunable=false + # ), + # gpe = PopConfig( + # :gpe, + # gpe, + # firing_rate_target=50.0, + # firing_rate_weight=1.0, + # freq_target=15.80, + # freq_weight=1.0, + # freq_min=3.0, + # freq_max=50.0, + # freq_aggregate_state="V", + # tunable=false + # ), + # stn = PopConfig( + # :stn, + # stn, + # freq_target=15.80, + # freq_weight=1.0, + # freq_min=3.0, + # freq_max=50.0, + # firing_rate_target=42.30, + # firing_rate_weight=1.0, + # tunable=false + # ), + ); + + pop_params = build_pop_param_indices(params_str, populations) + + args = ( + dt = 0.05, + saveat = 0.05, + solver = RKMil(), + seed = 1234, + other_diffeq_kwargs = (abstol=1e-3, reltol=1e-6, maxiters=1e10), + ensemblealg = EnsembleThreads(), + trajectories = 3, + threshold = -35, + transient = 200, + populations = populations, + pop_params = pop_params, + prob = prob, + get_ps = get_ps, + set_ps! = set_ps! + ) + + callback = function (state, l) + @info "Iteration: $(state.iter)" + @info "Parameters: $(state.u)" + @info "Loss: $l" + return false + end + + # Starting guess + p = [3.272893843123162, 1.0959782801317943, 2.2010777359961953, 2.9158528502583545] + + # Solve optimization + result = solve( + Optimization.OptimizationProblem(loss, p, args), + Optim.NelderMead(); + callback=callback, + maxiters=200 + ) + + return result +end + +result = run_optimization() \ No newline at end of file From e14cc4aa28a83aa5d22786bfb12f04371b1587bb Mon Sep 17 00:00:00 2001 From: gabrevaya Date: Fri, 27 Dec 2024 04:51:40 -0500 Subject: [PATCH 3/5] - make powerspectrum dispatch for EnsembleSolution type-stable - fix memory leak in parameter tuning code - improve parameter tuning code type-stability --- examples/tune_parameters.jl | 251 ++++++++++++++++++------------------ src/Neuroblox.jl | 2 +- src/blox/blox_utilities.jl | 22 ++-- 3 files changed, 141 insertions(+), 134 deletions(-) diff --git a/examples/tune_parameters.jl b/examples/tune_parameters.jl index 1e337562..6e6f197e 100644 --- a/examples/tune_parameters.jl +++ b/examples/tune_parameters.jl @@ -5,61 +5,49 @@ using Statistics using Optimization using OptimizationOptimJL using ModelingToolkit: setp, getp +using DSP +using AbstractFFTs - -struct PopConfig +@kwdef struct PopConfig{BLOX} name::Symbol - blox - firing_rate_target::Union{Nothing, Float64} - firing_rate_weight::Float64 - freq_target::Union{Nothing, Float64} - freq_weight::Float64 - freq_min::Union{Nothing, Float64} - freq_max::Union{Nothing, Float64} - freq_aggregate_state::Union{Nothing, String} - tunable::Bool + blox::BLOX + firing_rate_target::Union{Nothing, Float64} = nothing + firing_rate_weight::Float64 = 1.0 + freq_target::Union{Nothing, Float64} = nothing + freq_weight::Float64 = 1.0 + freq_min::Union{Nothing, Float64} = nothing + freq_max::Union{Nothing, Float64} = nothing + freq_aggregate_state::Union{Nothing, String} = nothing + tunable::Bool = false end -function PopConfig( - name, - blox; - firing_rate_target=nothing, - firing_rate_weight=0.0, - freq_target=nothing, - freq_weight=0.0, - freq_min=nothing, - freq_max=nothing, - freq_aggregate_state=nothing, - tunable=false -) - PopConfig( - name, - blox, - firing_rate_target, - firing_rate_weight, - freq_target, - freq_weight, - freq_min, - freq_max, - freq_aggregate_state, - tunable - ) +const PopIndexNT = NamedTuple{(:I_bg_ind, :σ_ind), Tuple{Vector{Int}, Vector{Int}}} + +@kwdef struct Args{GetterT,SetterT,BLOXTuple,ProblemT,SolverT,DiffeqkargsT,PopT,PopparamsT} + dt::Float64 = 0.05 + saveat::Float64 = 0.05 + solver::SolverT = RKMil() + seed::Int = 1234 + other_diffeq_kwargs::DiffeqkargsT = (abstol=1e-3, reltol=1e-6, maxiters=1e10) + ensemblealg::EnsembleThreads = EnsembleThreads() + trajectories::Int = 3 + threshold::Float64 = -35 + transient::Float64 = 200 + populations::PopT + pop_params::PopparamsT + prob::ProblemT + get_ps::GetterT + set_ps!::SetterT end -function get_peak_freq(powspec, freq_min, freq_max; - freq_ind = findall(x -> x > freq_min && x < freq_max, powspec.freq)) - - if isempty(freq_ind) - return NaN - end - ind = argmax(powspec.power[freq_ind]) - return powspec.freq[freq_ind][ind] -end -function get_peak_freq(powspecs::Vector, freq_min, freq_max) - freq_ind = findall(x -> x > freq_min && x < freq_max, powspecs[1].freq) +function get_peak_freq( + powspecs, + freq_min, + freq_max) + freq_ind = get_freq_inds(powspecs[1].freq, freq_min, freq_max) if isempty(freq_ind) - return NaN + return NaN64, NaN64 end # peak_freqs = [get_peak_freq(powspec, freq_min, freq_max, freq_ind=freq_ind) for powspec in powspecs] @@ -70,7 +58,10 @@ function get_peak_freq(powspecs::Vector, freq_min, freq_max) ind = argmax(mean_power) return powspecs[1].freq[freq_ind][ind], NaN64 - # return mean_peak_freq, std_peak_freq +end + +function get_freq_inds(freq, freq_min, freq_max) + return findall(x -> x > freq_min && x < freq_max, freq) end function get_inds(params_str::Vector{String}, pattern::Vector{String}) @@ -112,7 +103,7 @@ function create_problem(size, T) @named msn = Striatum_MSN_Adam(namespace=global_ns, N_inhib=N_MSN, I_bg=1.153064742988923*ones(N_MSN), σ=0.17256774881503584) @named fsi = Striatum_FSI_Adam(namespace=global_ns, N_inhib=N_FSI, I_bg=6.196201739395473*ones(N_FSI), σ=0.9548801242101033) - @named gpe = GPe_Adam(namespace=global_ns, N_inhib=N_GPe, I_bg=3.272893843123162, σ=1.0959782801317943) + @named gpe = GPe_Adam(namespace=global_ns, N_inhib=N_GPe, I_bg=3.272893843123162*ones(N_GPe), σ=1.0959782801317943) @named stn = STN_Adam(namespace=global_ns, N_exci=N_STN, I_bg=2.2010777359961953*ones(N_STN), σ=2.9158528502583545) g = MetaDiGraph() @@ -131,37 +122,39 @@ function create_problem(size, T) tspan = (0.0, T) @info "Creating SDEProblem" - prob = SDEProblem(sys, [], tspan, []) + prob = SDEProblem{true}(sys, [], tspan, []) return prob, sys, msn, fsi, gpe, stn end -function build_pop_param_indices(params_str::Vector{String}, populations) - indices = NamedTuple{keys(populations)}(map(pop -> begin - pop_str = string(pop.name) +function build_pop_param_indices(params_str::Vector{String}, + populations::Tuple{Vararg{PopConfig}}) + idxs = map(popconfig -> begin + pop_str = string(popconfig.name) I_bg_inds = get_inds(params_str, ["I_bg", pop_str]) - σ_inds = get_inds(params_str, ["σ", pop_str]) + σ_inds = get_inds(params_str, ["σ", pop_str]) if isempty(I_bg_inds) - error("No I_bg parameters found for population $(pop.name). Check naming.") + error("No I_bg parameters found for population $(popconfig.name). Check naming.") end if isempty(σ_inds) - error("No σ parameters found for population $(pop.name). Check naming.") + error("No σ parameters found for population $(popconfig.name). Check naming.") end (I_bg_ind = I_bg_inds, σ_ind = σ_inds) - end, populations)) - return indices + end, populations) + + return collect(idxs) end + function remake_prob!(prob, args, p) @unpack populations, pop_params, get_ps, set_ps! = args ps_new = get_ps(prob) offset = 1 - for (popname, popconfig) in pairs(populations) + for (i, popconfig) in enumerate(populations) if popconfig.tunable - # This population is tuned: assign parameters from p - I_bg_inds = pop_params[popname].I_bg_ind - σ_inds = pop_params[popname].σ_ind + I_bg_inds = pop_params[i][:I_bg_ind] + σ_inds = pop_params[i][:σ_ind] ps_new[I_bg_inds] .= abs(p[offset]) offset += 1 @@ -171,44 +164,48 @@ function remake_prob!(prob, args, p) end set_ps!(prob, ps_new) + return nothing end function loss(p, args) - @unpack dt, saveat, solver, seed, other_diffeq_kwargs, threshold, transient, populations, prob, pop_params = args - @unpack trajectories, ensemblealg = args + @unpack dt, saveat, solver, seed, other_diffeq_kwargs, threshold, transient = args + @unpack populations, prob, pop_params, trajectories, ensemblealg = args Random.seed!(seed) + + # Update prob in-place remake_prob!(prob, args, p) ens_prob = EnsembleProblem(prob) - sol = solve(ens_prob, solver, ensemblealg, trajectories=trajectories, dt=dt, saveat=saveat; other_diffeq_kwargs...) + sol = solve(ens_prob, solver, ensemblealg; + trajectories=trajectories, + dt=dt, + saveat=saveat, + other_diffeq_kwargs...) total_err = 0.0 # Compute firing rates and frequencies - for (popname, popconfig) in pairs(populations) + for popconfig in populations if popconfig.firing_rate_target !== nothing - fr_res = firing_rate(popconfig.blox, sol, threshold=threshold, transient=transient, scheduler=:dynamic) - fr, fr_std = fr_res + fr, fr_std = firing_rate(popconfig.blox, sol; + threshold=threshold, + transient=transient, + scheduler=:dynamic) err_fr = (fr[1] - popconfig.firing_rate_target)^2 * popconfig.firing_rate_weight total_err += err_fr - @info "[$popname] FR = $(fr[1]), FR std = $(fr_std[1]), FR Error = $err_fr" + @info "[$(popconfig.name)] FR = $(fr[1]), FR std = $(fr_std[1]), FR Error = $err_fr" end if popconfig.freq_target !== nothing - powspecs = powerspectrum(popconfig.blox, sol, popconfig.freq_aggregate_state, method=welch_pgram, window=hamming) - - # peak_freqs = Float64[] - # for powspec in powspecs - # peak_freq = get_peak_freq(powspec, popconfig.freq_min, popconfig.freq_max) - # @show peak_freq - # push!(peak_freqs, peak_freq) - # end - # @show mean(peak_freqs) - + powspecs = powerspectrum(popconfig.blox, sol, + popconfig.freq_aggregate_state; + method=welch_pgram, + window=hamming) peak_freq, peak_freq_std = get_peak_freq(powspecs, popconfig.freq_min, popconfig.freq_max) + err_freq = abs(peak_freq - popconfig.freq_target) * popconfig.freq_weight total_err += err_freq - @info "[$popname] Peak freq = $peak_freq, Freq std = $peak_freq_std, Freq Error = $err_freq" + @info "[$(popconfig.name)] Peak freq = $peak_freq, Freq std = $peak_freq_std, Freq Error = $err_freq" end end @@ -216,7 +213,7 @@ function loss(p, args) end function run_optimization() - prob, sys, msn, fsi, gpe, stn = create_problem(1.0, 5500.0) + prob, sys, msn, fsi, gpe, stn = create_problem(0.1, 5500.0) params_str = string.(tunable_parameters(sys)) get_ps = getp(prob, tunable_parameters(sys)) set_ps! = setp(prob, tunable_parameters(sys)) @@ -225,9 +222,9 @@ function run_optimization() ## only MSN - # msn = PopConfig( - # :msn, - # msn, + # PopConfig( + # name=:msn, + # blox=msn, # firing_rate_target=1.46, # firing_rate_weight=3.0, # freq_target=17.53, @@ -240,16 +237,16 @@ function run_optimization() ## MSN - FSI - # msn = PopConfig( - # :msn, - # msn, + # PopConfig( + # name=:msn, + # blox=msn, # firing_rate_target=1.88, # firing_rate_weight=3.0, # tunable=false # ), - # fsi = PopConfig( - # :fsi, - # fsi, + # PopConfig( + # name=:fsi, + # blox=fsi, # firing_rate_target=10.66, # firing_rate_weight=3.0, # freq_target=58.63, @@ -262,21 +259,21 @@ function run_optimization() ## full model in baseline conditions - msn = PopConfig( - :msn, - msn, + PopConfig( + name=:msn, + blox=msn, firing_rate_target=1.21, firing_rate_weight=60.0, - freq_target=10, + freq_target=10.0, freq_weight=0.5, freq_min=3.0, freq_max=20.0, freq_aggregate_state="I_syn_msn", tunable=false ), - fsi = PopConfig( - :fsi, - fsi, + PopConfig( + name=:fsi, + blox=fsi, firing_rate_target=13.00, firing_rate_weight=10.0, freq_target=61.14, @@ -286,9 +283,9 @@ function run_optimization() freq_aggregate_state="I_syn_fsi", tunable=false ), - gpe = PopConfig( - :gpe, - gpe, + PopConfig( + name=:gpe, + blox=gpe, freq_target=85.0, freq_weight=0.5, freq_min=40.0, @@ -296,18 +293,18 @@ function run_optimization() freq_aggregate_state="V", tunable=true ), - stn = PopConfig( - :stn, - stn, + PopConfig( + name=:stn, + blox=stn, tunable=true ), ## full model PD conditions - # msn = PopConfig( - # :msn, - # msn, + # PopConfig( + # name=:msn, + # blox=msn, # firing_rate_target=4.85, # firing_rate_weight=60.0, # freq_target=15.80, @@ -317,9 +314,9 @@ function run_optimization() # freq_aggregate_state="I_syn_msn", # tunable=true # ), - # fsi = PopConfig( - # :fsi, - # fsi, + # PopConfig( + # name=:fsi, + # blox=fsi, # freq_target=15.80, # freq_weight=1.0, # freq_min=3.0, @@ -327,9 +324,9 @@ function run_optimization() # freq_aggregate_state="I_syn_fsi", # tunable=false # ), - # gpe = PopConfig( - # :gpe, - # gpe, + # PopConfig( + # name=:gpe, + # blox=gpe, # firing_rate_target=50.0, # firing_rate_weight=1.0, # freq_target=15.80, @@ -339,12 +336,13 @@ function run_optimization() # freq_aggregate_state="V", # tunable=false # ), - # stn = PopConfig( - # :stn, - # stn, + # PopConfig( + # name=:stn, + # blox=stn, # freq_target=15.80, # freq_weight=1.0, # freq_min=3.0, + # freq_aggregate_state="V", # freq_max=50.0, # firing_rate_target=42.30, # firing_rate_weight=1.0, @@ -354,19 +352,21 @@ function run_optimization() pop_params = build_pop_param_indices(params_str, populations) - args = ( + other_diffeq_kwargs = (abstol=1e-3, reltol=1e-6, maxiters=1e10) + solver = RKMil() + args = Args{typeof(get_ps), typeof(set_ps!), typeof(populations), typeof(prob), typeof(solver), typeof(other_diffeq_kwargs), typeof(populations), typeof(pop_params)}( dt = 0.05, saveat = 0.05, - solver = RKMil(), + solver = solver, seed = 1234, - other_diffeq_kwargs = (abstol=1e-3, reltol=1e-6, maxiters=1e10), + other_diffeq_kwargs = other_diffeq_kwargs, ensemblealg = EnsembleThreads(), trajectories = 3, - threshold = -35, - transient = 200, - populations = populations, - pop_params = pop_params, - prob = prob, + threshold = -35.0, + transient = 200.0, + populations = populations, # the tuple + pop_params = pop_params, # the vector of indices + prob = prob, # the SDEProblem get_ps = get_ps, set_ps! = set_ps! ) @@ -375,15 +375,16 @@ function run_optimization() @info "Iteration: $(state.iter)" @info "Parameters: $(state.u)" @info "Loss: $l" + println("\n") return false end # Starting guess - p = [3.272893843123162, 1.0959782801317943, 2.2010777359961953, 2.9158528502583545] + u = [3.272893843123162, 1.0959782801317943, 2.2010777359961953, 2.9158528502583545] # Solve optimization result = solve( - Optimization.OptimizationProblem(loss, p, args), + Optimization.OptimizationProblem(loss, u, args), Optim.NelderMead(); callback=callback, maxiters=200 diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index e72005a8..20cfe50f 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -4,7 +4,7 @@ import Base: merge using Base.Threads: nthreads -using OhMyThreads: tmapreduce, tmap +using OhMyThreads: tmapreduce, tforeach using Reexport @reexport using ModelingToolkit diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 46bfc629..c7e21d13 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -572,18 +572,24 @@ function powerspectrum(blox::AbstractNeuronBlox, sol::SciMLBase.AbstractSolution end function powerspectrum(cb::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}}, - sols::SciMLBase.EnsembleSolution, state::String; sampling_rate=nothing, - method=periodogram, window=nothing)::Vector{DSP.Periodograms.Periodogram} - - t_sampled, sampling_freq = get_sampling_info(sols[1]; sampling_rate=sampling_rate) - powspecs = DSP.Periodograms.Periodogram[] + sols::SciMLBase.EnsembleSolution{T}, + state::String; + sampling_rate=nothing, + method=periodogram, + window=nothing + ) where {T} - powspecs = tmap(eachindex(sols)) do i + t_sampled, sampling_freq = get_sampling_info(sols[1]; sampling_rate=sampling_rate) + + # Pre-allocate concretely typed array + powspecs = Vector{DSP.Periodograms.Periodogram{T, + DSP.Frequencies{T}, + Vector{T}}}(undef, length(sols)) + tforeach(eachindex(sols)) do i sol = sols[i] s = meanfield_timeseries(cb, sol, state; ts=t_sampled) - method(s, fs=sampling_freq, window=window) + powspecs[i] = method(s, fs=sampling_freq, window=window) end - powspecs = collect(powspecs) return powspecs end From d2b4ae3b9b3c02b8bc6310d2580dccbfa7ac967a Mon Sep 17 00:00:00 2001 From: gabrevaya Date: Fri, 27 Dec 2024 21:05:46 -0300 Subject: [PATCH 4/5] refactor parameter tuning script to adopt a better design pattern and make it type-stable --- examples/tune_parameters.jl | 611 +++++++++++++++++------------------- src/blox/blox_utilities.jl | 3 +- 2 files changed, 298 insertions(+), 316 deletions(-) diff --git a/examples/tune_parameters.jl b/examples/tune_parameters.jl index 6e6f197e..50a14489 100644 --- a/examples/tune_parameters.jl +++ b/examples/tune_parameters.jl @@ -5,58 +5,79 @@ using Statistics using Optimization using OptimizationOptimJL using ModelingToolkit: setp, getp -using DSP -using AbstractFFTs +using Unrolled + +# Abstract type for all metrics +abstract type AbstractPopulationMetric end + +########### +# Metrics # +########### +struct FiringRateMetric{T} <: AbstractPopulationMetric + target::T + weight::T + threshold::T + transient::T +end +struct FrequencyMetric{T} <: AbstractPopulationMetric + target::T + weight::T + min_freq::T + max_freq::T + aggregate_state::String +end -@kwdef struct PopConfig{BLOX} - name::Symbol - blox::BLOX - firing_rate_target::Union{Nothing, Float64} = nothing - firing_rate_weight::Float64 = 1.0 - freq_target::Union{Nothing, Float64} = nothing - freq_weight::Float64 = 1.0 - freq_min::Union{Nothing, Float64} = nothing - freq_max::Union{Nothing, Float64} = nothing - freq_aggregate_state::Union{Nothing, String} = nothing - tunable::Bool = false +####################### +# Metric Computations # +####################### + +function compute_metric(m::FiringRateMetric, pop, sol; logging::Bool=false) + fr, fr_std = firing_rate(pop.blox, sol; + threshold=m.threshold, + transient=m.transient, + scheduler=:dynamic) + val = (fr[1] - m.target)^2 * m.weight + + if logging + @info "[$(pop.name)] Firing Rate = $(fr[1]) ± $(fr_std[1]) " * + " (target=$(m.target)), metric error = $val" + end + return val end -const PopIndexNT = NamedTuple{(:I_bg_ind, :σ_ind), Tuple{Vector{Int}, Vector{Int}}} - -@kwdef struct Args{GetterT,SetterT,BLOXTuple,ProblemT,SolverT,DiffeqkargsT,PopT,PopparamsT} - dt::Float64 = 0.05 - saveat::Float64 = 0.05 - solver::SolverT = RKMil() - seed::Int = 1234 - other_diffeq_kwargs::DiffeqkargsT = (abstol=1e-3, reltol=1e-6, maxiters=1e10) - ensemblealg::EnsembleThreads = EnsembleThreads() - trajectories::Int = 3 - threshold::Float64 = -35 - transient::Float64 = 200 - populations::PopT - pop_params::PopparamsT - prob::ProblemT - get_ps::GetterT - set_ps!::SetterT +function compute_metric(m::FrequencyMetric, pop, sol; logging::Bool=false) + powspecs = powerspectrum(pop.blox, sol, m.aggregate_state; + method=welch_pgram, + window=hamming) + peak_freq, peak_freq_std = get_peak_freq(powspecs, m.min_freq, m.max_freq) + val = abs(peak_freq - m.target) * m.weight + + if logging + @info "[$(pop.name)] Peak Frequency = $peak_freq ± $peak_freq_std " * + " in [$(m.min_freq), $(m.max_freq)] (target=$(m.target)), metric error = $val" + end + return val end +#################### +# Helper Functions # +#################### -function get_peak_freq( - powspecs, - freq_min, - freq_max) +function get_peak_freq(powspecs, freq_min, freq_max) freq_ind = get_freq_inds(powspecs[1].freq, freq_min, freq_max) if isempty(freq_ind) - return NaN64, NaN64 + return NaN64, NaN64 end + # Average the power across the different trajectories + mean_power = mean(powspec.power[freq_ind] for powspec in powspecs) + ind = argmax(mean_power) + + # alternative method # peak_freqs = [get_peak_freq(powspec, freq_min, freq_max, freq_ind=freq_ind) for powspec in powspecs] # mean_peak_freq = mean(peak_freqs) # std_peak_freq = std(peak_freqs) - mean_power = mean(powspec.power[freq_ind] for powspec in powspecs) - ind = argmax(mean_power) - return powspecs[1].freq[freq_ind][ind], NaN64 end @@ -64,11 +85,155 @@ function get_freq_inds(freq, freq_min, freq_max) return findall(x -> x > freq_min && x < freq_max, freq) end -function get_inds(params_str::Vector{String}, pattern::Vector{String}) - # Find indices where all patterns appear - findall(str -> all(pat -> occursin(pat, str), pattern), params_str) +struct TuningSpec + param_map::Dict{String, Vector{Int}} +end + +# Find indexes of the parameters to be tuned +function build_tuning_spec(prob, pop_name::String, param_names::Vector{String}) + paramlist = string.(tunable_parameters(prob.f.sys)) + param_map = Dict{String,Vector{Int}}() + for pname in param_names + inds = findall(str -> occursin(pname, str) && occursin(pop_name, str), + paramlist) + param_map[pname] = inds + end + return TuningSpec(param_map) +end + +################ +# Populations # +################ + +struct Population{B,N,MT<:NTuple{N,AbstractPopulationMetric}} + name::Symbol + blox::B + metrics::MT + tuning::TuningSpec + tunable::Bool +end + +""" + compute_metrics(pop, sol; logging=false) + +Sum the contributions of all metrics in `pop.metrics`, optionally logging. +""" +function compute_metrics(pop::Population, sol; logging::Bool=false) + total = zero(eltype(sol)) + for m in pop.metrics + total += compute_metric(m, pop, sol; logging=logging) + end + return total +end + +function Population( + name, + blox; + frm=nothing, + freqm=nothing, + tuning_params=String[], + prob=nothing, + tunable::Bool=false +) + mt = () + if frm !== nothing + fr_target, fr_weight, fr_threshold, fr_transient = frm + mt = (FiringRateMetric(fr_target, fr_weight, fr_threshold, fr_transient),) + end + if freqm !== nothing + freq_target, freq_weight, fmin, fmax, agg = freqm + freq = FrequencyMetric(freq_target, freq_weight, fmin, fmax, agg) + mt = tuple(mt..., freq) + end + + tspec = isempty(tuning_params) || prob === nothing ? + TuningSpec(Dict()) : + build_tuning_spec(prob, string(name), tuning_params) + + local N = length(mt) + return Population{typeof(blox), N, typeof(mt)}(name, blox, mt, tspec, tunable) +end + +""" + update_parameters!(prob, populations, p, get_ps, set_ps!) + +Update the `prob` in place, assigning the parameter values from `p` +according to each population's `tuning` spec. +""" +function update_parameters!(prob, populations, p, get_ps, set_ps!) + ps_new = get_ps(prob) + offset = 1 + + for pop in populations + if !pop.tunable + continue + end + + for (param_name, inds) in pop.tuning.param_map + ps_new[inds] .= abs.(p[offset]) + offset += 1 + end + end + + set_ps!(prob, ps_new) + return nothing +end + +############################# +# OptimizationConfig + Loss # +############################# + +struct OptimizationConfig{P,PopT,GetterT,SetterT,SolverT,EnsembleAlgT,dtT,DiffeqkargsT} + prob::P + populations::PopT + get_ps::GetterT + set_ps!::SetterT + solver::SolverT + ensemblealg::EnsembleAlgT + dt::dtT + other_diffeq_kwargs::DiffeqkargsT + trajectories::Int + seed::Int end +@unroll function sum_metrics_unrolled(pops, sol; logging=false) + total_err = zero(eltype(sol)) + for pop in pops + total_err += compute_metrics(pop, sol; logging=logging) + end + return total_err +end + +""" + loss(p, config::OptimizationConfig; logging=false) + +Update parameters, solve the ensemble problem, +compute total error, and optionally log each metric's value. +""" +function loss(p, config::OptimizationConfig; logging::Bool=false) + # Set random seed + Random.seed!(config.seed) + + # Update prob in-place + update_parameters!(config.prob, config.populations, p, config.get_ps, config.set_ps!) + + # Solve + ens_prob = EnsembleProblem(config.prob) + sol = solve(ens_prob, config.solver, config.ensemblealg; + trajectories=config.trajectories, + dt=config.dt, + saveat=config.dt, + config.other_diffeq_kwargs...) + + # Sum errors from each population + total_err = sum_metrics_unrolled(config.populations, sol; logging=logging) + return total_err +end + +################# +# Example usage # +################# + function create_problem(size, T) @info "Size = $size" Random.seed!(123) @@ -101,10 +266,30 @@ function create_problem(size, T) weight_STN_FSI = ḡ_STN_FSI / (N_STN * density_STN_FSI) conn_STN_FSI = make_conn(density_STN_FSI, N_STN, N_FSI) - @named msn = Striatum_MSN_Adam(namespace=global_ns, N_inhib=N_MSN, I_bg=1.153064742988923*ones(N_MSN), σ=0.17256774881503584) - @named fsi = Striatum_FSI_Adam(namespace=global_ns, N_inhib=N_FSI, I_bg=6.196201739395473*ones(N_FSI), σ=0.9548801242101033) - @named gpe = GPe_Adam(namespace=global_ns, N_inhib=N_GPe, I_bg=3.272893843123162*ones(N_GPe), σ=1.0959782801317943) - @named stn = STN_Adam(namespace=global_ns, N_exci=N_STN, I_bg=2.2010777359961953*ones(N_STN), σ=2.9158528502583545) + @named msn = Striatum_MSN_Adam( + namespace=global_ns, + N_inhib=N_MSN, + I_bg=1.153064742988923*ones(N_MSN), + σ=0.17256774881503584 + ) + @named fsi = Striatum_FSI_Adam( + namespace=global_ns, + N_inhib=N_FSI, + I_bg=6.196201739395473*ones(N_FSI), + σ=0.9548801242101033 + ) + @named gpe = GPe_Adam( + namespace=global_ns, + N_inhib=N_GPe, + I_bg=3.272893843123162*ones(N_GPe), + σ=1.0959782801317943 + ) + @named stn = STN_Adam( + namespace=global_ns, + N_exci=N_STN, + I_bg=2.2010777359961953*ones(N_STN), + σ=2.9158528502583545 + ) g = MetaDiGraph() add_edge!(g, fsi => msn, weight=weight_FSI_MSN, connection_matrix=conn_FSI_MSN) @@ -115,11 +300,6 @@ function create_problem(size, T) @info "Creating system from graph" @named sys = system_from_graph(g) - # For MSN-only - # g = MetaDiGraph() ## defines a graph - # add_blox!(g, msn) ## adds the defined blocks into the graph - # @named sys = system_from_graph(g) - tspan = (0.0, T) @info "Creating SDEProblem" prob = SDEProblem{true}(sys, [], tspan, []) @@ -127,270 +307,71 @@ function create_problem(size, T) return prob, sys, msn, fsi, gpe, stn end -function build_pop_param_indices(params_str::Vector{String}, - populations::Tuple{Vararg{PopConfig}}) - idxs = map(popconfig -> begin - pop_str = string(popconfig.name) - I_bg_inds = get_inds(params_str, ["I_bg", pop_str]) - σ_inds = get_inds(params_str, ["σ", pop_str]) - if isempty(I_bg_inds) - error("No I_bg parameters found for population $(popconfig.name). Check naming.") - end - if isempty(σ_inds) - error("No σ parameters found for population $(popconfig.name). Check naming.") - end - (I_bg_ind = I_bg_inds, σ_ind = σ_inds) - end, populations) - - return collect(idxs) -end - - -function remake_prob!(prob, args, p) - @unpack populations, pop_params, get_ps, set_ps! = args - ps_new = get_ps(prob) - - offset = 1 - for (i, popconfig) in enumerate(populations) - if popconfig.tunable - I_bg_inds = pop_params[i][:I_bg_ind] - σ_inds = pop_params[i][:σ_ind] - - ps_new[I_bg_inds] .= abs(p[offset]) - offset += 1 - ps_new[σ_inds] .= abs(p[offset]) - offset += 1 - end - end - - set_ps!(prob, ps_new) - return nothing -end - -function loss(p, args) - @unpack dt, saveat, solver, seed, other_diffeq_kwargs, threshold, transient = args - @unpack populations, prob, pop_params, trajectories, ensemblealg = args - Random.seed!(seed) - - # Update prob in-place - remake_prob!(prob, args, p) - - ens_prob = EnsembleProblem(prob) - sol = solve(ens_prob, solver, ensemblealg; - trajectories=trajectories, - dt=dt, - saveat=saveat, - other_diffeq_kwargs...) - - total_err = 0.0 - - # Compute firing rates and frequencies - for popconfig in populations - if popconfig.firing_rate_target !== nothing - fr, fr_std = firing_rate(popconfig.blox, sol; - threshold=threshold, - transient=transient, - scheduler=:dynamic) - err_fr = (fr[1] - popconfig.firing_rate_target)^2 * popconfig.firing_rate_weight - total_err += err_fr - @info "[$(popconfig.name)] FR = $(fr[1]), FR std = $(fr_std[1]), FR Error = $err_fr" - end - - if popconfig.freq_target !== nothing - powspecs = powerspectrum(popconfig.blox, sol, - popconfig.freq_aggregate_state; - method=welch_pgram, - window=hamming) - peak_freq, peak_freq_std = get_peak_freq(powspecs, popconfig.freq_min, popconfig.freq_max) - - err_freq = abs(peak_freq - popconfig.freq_target) * popconfig.freq_weight - total_err += err_freq - @info "[$(popconfig.name)] Peak freq = $peak_freq, Freq std = $peak_freq_std, Freq Error = $err_freq" - end - end - - return total_err -end - -function run_optimization() - prob, sys, msn, fsi, gpe, stn = create_problem(0.1, 5500.0) - params_str = string.(tunable_parameters(sys)) - get_ps = getp(prob, tunable_parameters(sys)) - set_ps! = setp(prob, tunable_parameters(sys)) - - populations = ( - - ## only MSN - - # PopConfig( - # name=:msn, - # blox=msn, - # firing_rate_target=1.46, - # firing_rate_weight=3.0, - # freq_target=17.53, - # freq_weight=1.0, - # freq_min=5.0, - # freq_max=25.0, - # freq_aggregate_state="I_syn_msn", - # tunable=true - # ), - - ## MSN - FSI - - # PopConfig( - # name=:msn, - # blox=msn, - # firing_rate_target=1.88, - # firing_rate_weight=3.0, - # tunable=false - # ), - # PopConfig( - # name=:fsi, - # blox=fsi, - # firing_rate_target=10.66, - # firing_rate_weight=3.0, - # freq_target=58.63, - # freq_weight=1.0, - # freq_min=40.0, - # freq_max=90.0, - # freq_aggregate_state="I_syn_fsi", - # tunable=true - # ) - - ## full model in baseline conditions - - PopConfig( - name=:msn, - blox=msn, - firing_rate_target=1.21, - firing_rate_weight=60.0, - freq_target=10.0, - freq_weight=0.5, - freq_min=3.0, - freq_max=20.0, - freq_aggregate_state="I_syn_msn", - tunable=false - ), - PopConfig( - name=:fsi, - blox=fsi, - firing_rate_target=13.00, - firing_rate_weight=10.0, - freq_target=61.14, - freq_weight=1.0, - freq_min=40.0, - freq_max=90.0, - freq_aggregate_state="I_syn_fsi", - tunable=false - ), - PopConfig( - name=:gpe, - blox=gpe, - freq_target=85.0, - freq_weight=0.5, - freq_min=40.0, - freq_max=90.0, - freq_aggregate_state="V", - tunable=true - ), - PopConfig( - name=:stn, - blox=stn, - tunable=true - ), - - - ## full model PD conditions - - # PopConfig( - # name=:msn, - # blox=msn, - # firing_rate_target=4.85, - # firing_rate_weight=60.0, - # freq_target=15.80, - # freq_weight=1.0, - # freq_min=3.0, - # freq_max=50.0, - # freq_aggregate_state="I_syn_msn", - # tunable=true - # ), - # PopConfig( - # name=:fsi, - # blox=fsi, - # freq_target=15.80, - # freq_weight=1.0, - # freq_min=3.0, - # freq_max=50.0, - # freq_aggregate_state="I_syn_fsi", - # tunable=false - # ), - # PopConfig( - # name=:gpe, - # blox=gpe, - # firing_rate_target=50.0, - # firing_rate_weight=1.0, - # freq_target=15.80, - # freq_weight=1.0, - # freq_min=3.0, - # freq_max=50.0, - # freq_aggregate_state="V", - # tunable=false - # ), - # PopConfig( - # name=:stn, - # blox=stn, - # freq_target=15.80, - # freq_weight=1.0, - # freq_min=3.0, - # freq_aggregate_state="V", - # freq_max=50.0, - # firing_rate_target=42.30, - # firing_rate_weight=1.0, - # tunable=false - # ), - ); - - pop_params = build_pop_param_indices(params_str, populations) - - other_diffeq_kwargs = (abstol=1e-3, reltol=1e-6, maxiters=1e10) - solver = RKMil() - args = Args{typeof(get_ps), typeof(set_ps!), typeof(populations), typeof(prob), typeof(solver), typeof(other_diffeq_kwargs), typeof(populations), typeof(pop_params)}( - dt = 0.05, - saveat = 0.05, - solver = solver, - seed = 1234, - other_diffeq_kwargs = other_diffeq_kwargs, - ensemblealg = EnsembleThreads(), - trajectories = 3, - threshold = -35.0, - transient = 200.0, - populations = populations, # the tuple - pop_params = pop_params, # the vector of indices - prob = prob, # the SDEProblem - get_ps = get_ps, - set_ps! = set_ps! - ) - - callback = function (state, l) - @info "Iteration: $(state.iter)" - @info "Parameters: $(state.u)" - @info "Loss: $l" - println("\n") - return false - end - - # Starting guess - u = [3.272893843123162, 1.0959782801317943, 2.2010777359961953, 2.9158528502583545] - - # Solve optimization - result = solve( - Optimization.OptimizationProblem(loss, u, args), - Optim.NelderMead(); - callback=callback, - maxiters=200 - ) - - return result +# Example +prob, sys, msn, fsi, gpe, stn = create_problem(0.1, 5500.0) + +msn_pop = Population( + :msn, msn; + frm = (1.21, 60.0, -35.0, 200.0), # FiringRateMetric: (target, weight, threshold, transient) + freqm = (10.0, 0.5, 3.0, 20.0, "I_syn_msn"), # FrequencyMetric: (target, weight, fmin, fmax, aggregate_state) + prob = prob, + tunable = false +) + +fsi_pop = Population( + :fsi, fsi; + frm = (13.0, 10.0, -35.0, 200.0), + freqm = (61.14, 1.0, 40.0, 90.0, "I_syn_fsi"), + prob = prob, + tunable = false +) + +gpe_pop = Population( + :gpe, gpe; + freqm = (85.0, 0.5, 40.0, 90.0, "V"), + tuning_params = ["I_bg", "σ"], + prob = prob, + tunable = true +) + +stn_pop = Population( + :stn, stn; + tuning_params = ["I_bg", "σ"], + prob = prob, + tunable = false +) + +other_diffeq_kwargs = (abstol=1e-3, reltol=1e-6, maxiters=1e10) +get_ps = getp(prob, tunable_parameters(sys)) +set_ps! = setp(prob, tunable_parameters(sys)) + +config = OptimizationConfig( + prob, + (msn_pop, fsi_pop, gpe_pop, stn_pop), + get_ps, + set_ps!, + RKMil(), + EnsembleThreads(), + 0.1, + other_diffeq_kwargs, + 3, + 1234 +) + +u = [3.272893843123162, 1.0959782801317943, 2.2010777359961953, 2.9158528502583545] +# optprob = Optimization.OptimizationProblem(loss, p0, config) +optprob = Optimization.OptimizationProblem((p, config)->loss(p, config; logging=true), p0, config) +callback = function (state, l) + println("\n") + @info "Iteration: $(state.iter)" + @info "Parameters: $(state.u)" + @info "Loss: $l" + println("\n") + return false end -result = run_optimization() \ No newline at end of file +# Example run +res = solve(optprob, Optim.NelderMead(); + maxiters=2, + callback=callback +) \ No newline at end of file diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index c7e21d13..5b4a6eb5 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -469,7 +469,8 @@ function firing_rate( 1000.0 * (nnz(spikes[idx_start:idx_end, :]) / N_neurons) / win_size end - return fr + T = eltype(sol) + return fr::Vector{T} end function firing_rate( From c150d71bc81b72dbd9b05d8ab2899030a3687deb Mon Sep 17 00:00:00 2001 From: gabrevaya Date: Mon, 30 Dec 2024 17:08:49 -0300 Subject: [PATCH 5/5] - use Mason's unroll macro in order to achieve full type stability - some code restructure --- examples/tune_parameters.jl | 216 ++++++++++++++++++------------------ 1 file changed, 105 insertions(+), 111 deletions(-) diff --git a/examples/tune_parameters.jl b/examples/tune_parameters.jl index 50a14489..4f8d6d4e 100644 --- a/examples/tune_parameters.jl +++ b/examples/tune_parameters.jl @@ -5,11 +5,69 @@ using Statistics using Optimization using OptimizationOptimJL using ModelingToolkit: setp, getp -using Unrolled -# Abstract type for all metrics abstract type AbstractPopulationMetric end +struct Population{B,N,MT<:NTuple{N,AbstractPopulationMetric}} + name::Symbol + blox::B + metrics::MT + tuning::Dict{String, Vector{Int}} + tunable::Bool +end + +function Population( + name, + blox; + frm=nothing, + freqm=nothing, + tuning_params=String[], + prob=nothing, + tunable::Bool=false +) + mt = () + if frm !== nothing + fr_target, fr_weight, fr_threshold, fr_transient = frm + mt = (FiringRateMetric(fr_target, fr_weight, fr_threshold, fr_transient),) + end + if freqm !== nothing + freq_target, freq_weight, fmin, fmax, agg = freqm + freq = FrequencyMetric(freq_target, freq_weight, fmin, fmax, agg) + mt = tuple(mt..., freq) + end + + tspec = isempty(tuning_params) || prob === nothing ? + Dict() : build_tuning_spec(prob, string(name), tuning_params) + + local N = length(mt) + return Population{typeof(blox), N, typeof(mt)}(name, blox, mt, tspec, tunable) +end + +# Find indexes of the parameters to be tuned +function build_tuning_spec(prob, pop_name::String, param_names::Vector{String}) + paramlist = string.(tunable_parameters(prob.f.sys)) + param_map = Dict{String,Vector{Int}}() + for pname in param_names + inds = findall(str -> occursin(pname, str) && occursin(pop_name, str), + paramlist) + param_map[pname] = inds + end + return param_map +end + +struct OptimizationConfig{P,PopT,GetterT,SetterT,SolverT,EnsembleAlgT,dtT,DiffeqkargsT} + prob::P + populations::PopT + get_ps::GetterT + set_ps!::SetterT + solver::SolverT + ensemblealg::EnsembleAlgT + dt::dtT + other_diffeq_kwargs::DiffeqkargsT + trajectories::Int + seed::Int +end + ########### # Metrics # ########### @@ -27,11 +85,7 @@ struct FrequencyMetric{T} <: AbstractPopulationMetric aggregate_state::String end -####################### -# Metric Computations # -####################### - -function compute_metric(m::FiringRateMetric, pop, sol; logging::Bool=false) +function compute_error(m::FiringRateMetric, pop, sol; logging::Bool=false) fr, fr_std = firing_rate(pop.blox, sol; threshold=m.threshold, transient=m.transient, @@ -45,7 +99,7 @@ function compute_metric(m::FiringRateMetric, pop, sol; logging::Bool=false) return val end -function compute_metric(m::FrequencyMetric, pop, sol; logging::Bool=false) +function compute_error(m::FrequencyMetric, pop, sol; logging::Bool=false) powspecs = powerspectrum(pop.blox, sol, m.aggregate_state; method=welch_pgram, window=hamming) @@ -85,93 +139,55 @@ function get_freq_inds(freq, freq_min, freq_max) return findall(x -> x > freq_min && x < freq_max, freq) end -struct TuningSpec - param_map::Dict{String, Vector{Int}} -end - -# Find indexes of the parameters to be tuned -function build_tuning_spec(prob, pop_name::String, param_names::Vector{String}) - paramlist = string.(tunable_parameters(prob.f.sys)) - param_map = Dict{String,Vector{Int}}() - for pname in param_names - inds = findall(str -> occursin(pname, str) && occursin(pop_name, str), - paramlist) - param_map[pname] = inds +# Mason's unroll macro, from https://github.com/Neuroblox/GraphDynamics.jl/blob/6c0bbb81abf1981c52a4605dc32d7073fea2ff0d/src/utils.jl#L10 +macro unroll(N::Int, loop) + Base.isexpr(loop, :for) || error("only works on for loops") + Base.isexpr(loop.args[1], :(=)) || error("This loop pattern isn't supported") + val, itr = esc.(loop.args[1].args) + body = esc(loop.args[2]) + @gensym loopend + label = :(@label $loopend) + goto = :(@goto $loopend) + out = Expr(:block, :(itr = $itr), :(next = iterate(itr))) + unrolled = map(1:N) do _ + quote + isnothing(next) && @goto loopend + $val, state = next + $body + next = iterate(itr, state) + end end - return TuningSpec(param_map) -end - -################ -# Populations # -################ - -struct Population{B,N,MT<:NTuple{N,AbstractPopulationMetric}} - name::Symbol - blox::B - metrics::MT - tuning::TuningSpec - tunable::Bool + append!(out.args, unrolled) + remainder = quote + while !isnothing(next) + $val, state = next + $body + next = iterate(itr, state) + end + @label loopend + end + push!(out.args, remainder) + out end -""" - compute_metrics(pop, sol; logging=false) - -Sum the contributions of all metrics in `pop.metrics`, optionally logging. -""" -function compute_metrics(pop::Population, sol; logging::Bool=false) +function compute_errors(pop::Population, sol; logging::Bool=false) total = zero(eltype(sol)) for m in pop.metrics - total += compute_metric(m, pop, sol; logging=logging) + total += compute_error(m, pop, sol; logging=logging) end return total end -function Population( - name, - blox; - frm=nothing, - freqm=nothing, - tuning_params=String[], - prob=nothing, - tunable::Bool=false -) - mt = () - if frm !== nothing - fr_target, fr_weight, fr_threshold, fr_transient = frm - mt = (FiringRateMetric(fr_target, fr_weight, fr_threshold, fr_transient),) - end - if freqm !== nothing - freq_target, freq_weight, fmin, fmax, agg = freqm - freq = FrequencyMetric(freq_target, freq_weight, fmin, fmax, agg) - mt = tuple(mt..., freq) - end - - tspec = isempty(tuning_params) || prob === nothing ? - TuningSpec(Dict()) : - build_tuning_spec(prob, string(name), tuning_params) - - local N = length(mt) - return Population{typeof(blox), N, typeof(mt)}(name, blox, mt, tspec, tunable) -end - -""" - update_parameters!(prob, populations, p, get_ps, set_ps!) - -Update the `prob` in place, assigning the parameter values from `p` -according to each population's `tuning` spec. -""" function update_parameters!(prob, populations, p, get_ps, set_ps!) ps_new = get_ps(prob) offset = 1 - for pop in populations - if !pop.tunable - continue - end - - for (param_name, inds) in pop.tuning.param_map - ps_new[inds] .= abs.(p[offset]) - offset += 1 + @unroll 16 for pop in populations + if pop.tunable + for (param_name, inds) in pop.tuning + ps_new[inds] .= abs.(p[offset]) + offset += 1 + end end end @@ -179,44 +195,21 @@ function update_parameters!(prob, populations, p, get_ps, set_ps!) return nothing end -############################# -# OptimizationConfig + Loss # -############################# - -struct OptimizationConfig{P,PopT,GetterT,SetterT,SolverT,EnsembleAlgT,dtT,DiffeqkargsT} - prob::P - populations::PopT - get_ps::GetterT - set_ps!::SetterT - solver::SolverT - ensemblealg::EnsembleAlgT - dt::dtT - other_diffeq_kwargs::DiffeqkargsT - trajectories::Int - seed::Int -end - -@unroll function sum_metrics_unrolled(pops, sol; logging=false) +function sum_errors(pops, sol; logging=false) total_err = zero(eltype(sol)) - for pop in pops - total_err += compute_metrics(pop, sol; logging=logging) + @unroll 16 for pop in pops + total_err += compute_errors(pop, sol; logging=logging) end return total_err end -""" - loss(p, config::OptimizationConfig; logging=false) - -Update parameters, solve the ensemble problem, -compute total error, and optionally log each metric's value. -""" function loss(p, config::OptimizationConfig; logging::Bool=false) # Set random seed Random.seed!(config.seed) # Update prob in-place update_parameters!(config.prob, config.populations, p, config.get_ps, config.set_ps!) - + # Solve ens_prob = EnsembleProblem(config.prob) sol = solve(ens_prob, config.solver, config.ensemblealg; @@ -226,7 +219,7 @@ function loss(p, config::OptimizationConfig; logging::Bool=false) config.other_diffeq_kwargs...) # Sum errors from each population - total_err = sum_metrics_unrolled(config.populations, sol; logging=logging) + total_err = sum_errors(config.populations, sol; logging=logging) return total_err end @@ -358,7 +351,8 @@ config = OptimizationConfig( 1234 ) -u = [3.272893843123162, 1.0959782801317943, 2.2010777359961953, 2.9158528502583545] +p0 = [3.272893843123162, 1.0959782801317943, 2.2010777359961953, 2.9158528502583545] + # optprob = Optimization.OptimizationProblem(loss, p0, config) optprob = Optimization.OptimizationProblem((p, config)->loss(p, config; logging=true), p0, config) callback = function (state, l)