From 6956047b164dfd365f8faa0ff44be3572ac3b782 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 25 Jul 2024 18:31:56 +0300 Subject: [PATCH 01/50] delete unused file --- test/final_dde_before_merge.jl | 237 --------------------------------- 1 file changed, 237 deletions(-) delete mode 100644 test/final_dde_before_merge.jl diff --git a/test/final_dde_before_merge.jl b/test/final_dde_before_merge.jl deleted file mode 100644 index 77b65e95..00000000 --- a/test/final_dde_before_merge.jl +++ /dev/null @@ -1,237 +0,0 @@ -# Intro setup -using ModelingToolkit, DelayDiffEq, Graphs, MetaGraphs, Plots -using ModelingToolkit: get_namespace, get_systems, renamespace, namespace_equation, namespace_expr -import ModelingToolkit: inputs, outputs, nameof -@variables t -D = Differential(t) - -# neural_mass.jl -mutable struct JansenRitCBloxDelay - p_dict::Dict{Symbol,Union{Real,Num}} - eqs::Vector{Equation} - sts::Vector{Any} - connector - jcn - odesystem - namespace - function JansenRitCBloxDelay(;name, τ=0.001, H=20.0, λ=5.0, r=0.15) - para_dict = scope_dict(Dict{Symbol,Union{Real,Num}}(:τ => τ, :H => H, :λ => λ, :r => r)) - τ=para_dict[:τ] - H=para_dict[:H] - λ=para_dict[:λ] - r=para_dict[:r] - sts = @variables x(..)=1.0 y(t)=1.0 jcn(t)=0.0 [input=true] - eqs = [D(x(t)) ~ y - ((2/τ)*x(t)), - D(y) ~ -x(t)/(τ*τ) + (H/τ)*((2*λ)/(1 + exp(-r*(jcn))) - λ)] - odesystem = System(eqs, name=name) - new(para_dict, eqs, sts, sts[1], sts[3], odesystem, nothing) - end -end - -# blox_utilities.jl -namespaced_name(parent_name, name) = Symbol(parent_name, :₊, name) -namespaced_name(nothing, name) = Symbol(name) -namespaceof(blox) = blox.namespace -get_sys(blox) = blox.odesystem - -function scope_dict(para_dict::Dict{Symbol, Union{Real,Num}}) - para_dict_copy = copy(para_dict) - for (n,v) in para_dict_copy - if typeof(v) == Num - para_dict_copy[n] = ParentScope(v) - else - para_dict_copy[n] = (@parameters $n=v)[1] - end - end - return para_dict_copy -end - -# REDEFINED -nameof(blox) = (nameof ∘ get_sys)(blox) # Redefined because no odesys to get name from - -function get_namespaced_sys(blox) - sys = get_sys(blox) - ODESystem( - equations(sys), - independent_variable(sys), - unknowns(sys), - parameters(sys); - name = namespaced_name(inner_namespaceof(blox), nameof(blox)) - ) -end - -function inner_namespaceof(blox) - parts = split((string ∘ namespaceof)(blox), '₊') - if length(parts) == 1 - return nothing - else - return join(parts[2:end], '₊') - end -end - -function find_eq(eqs::AbstractVector{<:Equation}, lhs) - findfirst(eqs) do eq - lhs_vars = get_variables(eq.lhs) - length(lhs_vars) == 1 && isequal(only(lhs_vars), lhs) - end -end - -# redefined to be used for delays -function input_equations(blox) - sys = get_sys(blox) - inps = inputs(sys) - sys_eqs = equations(sys) - # CHANGE HERE - @variables t - eqs = map(inps) do inp - idx = find_eq(sys_eqs, inp) - if isnothing(idx) - namespace_equation( - inp ~ 0, - nothing, - namespaced_name(inner_namespaceof(blox), nameof(blox)); - ivs = t - ) - else - namespace_equation( - sys_eqs[idx], - nothing, - namespaced_name(inner_namespaceof(blox), nameof(blox)); - ivs = t - ) - end - end - - return eqs -end - -# Neurographs.jl -function add_blox!(g::MetaDiGraph,blox) - add_vertex!(g, :blox, blox) -end - -weight_parameters(blox) = Num[] - -# connections.jl -mutable struct BloxConnector - eqs::Vector{Equation} - params::Vector{Num} - - BloxConnector() = new(Equation[], Num[]) - - function BloxConnector(bloxs) - eqs = reduce(vcat, input_equations.(bloxs)) - params = reduce(vcat, weight_parameters.(bloxs)) - #eqs = namespace_equation.(eqs, nothing, namespace) - new(eqs, params) - end -end - -function (bc::BloxConnector)( - jc::JansenRitCBloxDelay, - bloxin; - weight = 1, - delay = 0 -) - # Need t for the delay term - @variables t - - sys_out = get_namespaced_sys(jc) - sys_in = get_namespaced_sys(bloxin) - - # Define & accumulate delay parameter - τ_name = Symbol("τ_$(nameof(sys_out))_$(nameof(sys_in))") - τ = only(@parameters $(τ_name)=delay) - push!(bc.params, τ) - - w_name = Symbol("w_$(nameof(sys_out))_$(nameof(sys_in))") - w = only(@parameters $(w_name)=weight) - push!(bc.params, w) - - x = namespace_expr(jc.connector, nothing, nameof(sys_out)) - eq = sys_in.jcn ~ x(t-τ)*w - - accumulate_equation!(bc, eq) -end - -function accumulate_equation!(bc::BloxConnector, eq) - lhs = eq.lhs - idx = find_eq(bc.eqs, lhs) - bc.eqs[idx] = bc.eqs[idx].lhs ~ bc.eqs[idx].rhs + eq.rhs - -end - -# Neurographs.jl -function get_sys(g::MetaDiGraph) - map(vertices(g)) do v - b = get_prop(g, v, :blox) - get_sys(b) - end -end - -function system_from_graph(g::MetaDiGraph; name) - bc = connector_from_graph(g) - return system_from_graph(g, bc; name) -end - -function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name) - @variables t - blox_syss = get_sys(g) - return compose(ODESystem(bc.eqs, t, [], bc.params; name), blox_syss) -end - -function get_blox(g::MetaDiGraph) - map(vertices(g)) do v - get_prop(g, v, :blox) - end -end - -# switch to system_from_graph -function connector_from_graph(g::MetaDiGraph) - bloxs = get_blox(g) - link = BloxConnector(bloxs) - for v in vertices(g) - b = get_prop(g, v, :blox) - for vn in inneighbors(g, v) - bn = get_prop(g, vn, :blox) - w = get_prop(g, vn, v, :weight) - d = get_prop(g, vn, v, :delay) # CHANGE HERE - link(bn, b; weight = w, delay = d) - end - end - - return link -end - -# test blox -@named PY = JansenRitCBloxDelay(τ=0.001, H=20, λ=5, r=0.15) -@named EI = JansenRitCBloxDelay(τ=0.01, H=20, λ=5, r=5) -@named II = JansenRitCBloxDelay(τ=2.0, H=60, λ=5, r=5) - -# test graphs -g = MetaDiGraph() -add_blox!(g, PY) -add_blox!(g, EI) -add_blox!(g, II) -add_edge!(g, 1, 2, Dict(:weight => 1.0, :delay => 0.5)) -add_edge!(g, 2, 3, Dict(:weight => 1.0, :delay => 0.5)) -add_edge!(g, 3, 1, Dict(:weight => 1.0, :delay => 1.5)) - -@named final_system = system_from_graph(g) -sim_dur = 10.0 # Simulate for 10 Seconds -sys = structural_simplify(final_system) -prob = DDEProblem(sys, - [], - (0.0, sim_dur), - constant_lags = [1]) -alg = MethodOfSteps(Tsit5()) -sol_mtk = solve(prob, alg, reltol = 1e-7, abstol = 1e-10, saveat=0.001) - -### MERGED SO FAR ### - -# Notes for 9/6 meeting - -# Review changes (namespacing with iv = t) -# Review changes (delay) -# Multiple dispatch or separate bloxs for delay? -# Where to collect lags? From adcc69c4ee75bb7ffe0f42234c5504b46cb4e66b Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 25 Jul 2024 18:33:53 +0300 Subject: [PATCH 02/50] rename `events -> discrete_callbacks` in `BloxConnector` --- src/blox/blox_utilities.jl | 13 +++++-------- src/blox/connections.jl | 35 ++++++++++++++++------------------- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 993d0925..57268753 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -149,14 +149,11 @@ delay_parameters(blox) = Num[] delay_parameters(blox::AbstractComponent) = blox.connector.delays delay_parameters(blox::CompositeBlox) = blox.connector.delays -event_callbacks(blox) = [] -event_callbacks(blox::AbstractComponent) = blox.connector.events -event_callbacks(blox::CompositeBlox) = blox.connector.events - -weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}() -weight_learning_rules(bc::BloxConnector) = bc.learning_rules -weight_learning_rules(blox::AbstractComponent) = weight_learning_rules(blox.connector) -weight_learning_rules(blox::CompositeBlox) = weight_learning_rules(blox.connector) +get_discrete_callbacks(blox) = [] +get_discrete_callbacks(blox::AbstractComponent) = blox.connector.discrete_callbacks +get_discrete_callbacks(blox::CompositeBlox) = blox.connector.discrete_callbacks + +get_continuous_callbacks(blox) = [] get_blox_parts(blox) = blox.parts diff --git a/src/blox/connections.jl b/src/blox/connections.jl index c4450e3e..93098d8d 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -2,17 +2,14 @@ mutable struct BloxConnector eqs::Vector{Equation} weights::Vector{Num} delays::Vector{Num} - events + discrete_callbacks learning_rules BloxConnector() = new(Equation[], Num[], Num[], Pair{Any, Vector{Equation}}[], Dict{Num, AbstractLearningRule}()) function BloxConnector(bloxs) - eqs = mapreduce(input_equations, vcat, bloxs) - weights = mapreduce(weight_parameters, vcat, bloxs) - delays = mapreduce(delay_parameters, vcat, bloxs) - events = mapreduce(event_callbacks, vcat, bloxs) - learning_rules = mapreduce(weight_learning_rules, merge, bloxs) + discrete_callbacks = mapreduce(get_discrete_callbacks, vcat, bloxs) + continuous_callbacks = mapreduce(get_continuous_callbacks, vcat, bloxs) new(eqs, weights, delays, events, learning_rules) end @@ -43,18 +40,18 @@ function get_callbacks(g, bc; t_block=missing) if !isempty(eqs_params) && !isempty(eqs) cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params - return vcat(cbs_params, cbs_spikes, bc.events) + return vcat(cbs_params, cbs_spikes, bc.discrete_callbacks) elseif isempty(eqs_params) && !isempty(eqs) cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs - return vcat(cbs_spikes, bc.events) + return vcat(cbs_spikes, bc.discrete_callbacks) elseif !isempty(eqs_params) && isempty(eqs) cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params - return vcat(cbs_params, bc.events) + return vcat(cbs_params, bc.discrete_callbacks) else - return bc.events + return bc.discrete_callbacks end else - return bc.events + return bc.discrete_callbacks end end @@ -649,10 +646,10 @@ function (bc::BloxConnector)( cb_matr_init = [0.1] => [sys_matr_in.H ~ 1] cb_strios_init = [0.1] => [sys_strios_in.H ~ 1] - push!(bc.events, cb_matr) - push!(bc.events, cb_strios) - push!(bc.events, cb_matr_init) - push!(bc.events, cb_strios_init) + push!(bc.discrete_callbacks, cb_matr) + push!(bc.discrete_callbacks, cb_strios) + push!(bc.discrete_callbacks, cb_matr_init) + push!(bc.discrete_callbacks, cb_strios_init) for neuron in neurons_in sys_neuron = get_namespaced_sys(neuron) @@ -661,8 +658,8 @@ function (bc::BloxConnector)( cb_neuron = [t_event] => [sys_neuron.I_bg ~ ifelse(sys_matr_out.H*sys_matr_out.jcn > sys_matr_in.H*sys_matr_in.jcn, -2, 0)] # lateral inhibition current I_bg should be set to 0 at the beginning of each trial cb_neuron_init = [0.1] => [sys_neuron.I_bg ~ 0] - push!(bc.events, cb_neuron) - push!(bc.events, cb_neuron_init) + push!(bc.discrete_callbacks, cb_neuron) + push!(bc.discrete_callbacks, cb_neuron_init) end end @@ -735,7 +732,7 @@ function (bc::BloxConnector)( t_event = get_event_time(kwargs, nameof(discr_out), nameof(discr_in)) cb = [t_event+sqrt(eps(t_event))] => (sample_affect!, [], [sys_out.κ, sys_out.jcn, sys_in.TAN_spikes], []) - push!(bc.events, cb) + push!(bc.discrete_callbacks, cb) eq = sys_in.jcn ~ w*sys_in.TAN_spikes @@ -752,7 +749,7 @@ function (bc::BloxConnector)( t_event = get_event_time(kwargs, nameof(discr_out), nameof(discr_in)) cb = [t_event] => [sys_in.H ~ ifelse(sys_out.H*sys_out.jcn > sys_in.H*sys_in.jcn, 0, 1)] - push!(bc.events, cb) + push!(bc.discrete_callbacks, cb) end function (bc::BloxConnector)( From 28eb1d85e5f5a5f101e15483550730cc215f7ea7 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 25 Jul 2024 18:35:50 +0300 Subject: [PATCH 03/50] add seperate continuous_callbacks field in `BloxConnector` --- src/blox/blox_utilities.jl | 3 +++ src/blox/connections.jl | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 57268753..97f9ff20 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -154,6 +154,9 @@ get_discrete_callbacks(blox::AbstractComponent) = blox.connector.discrete_callba get_discrete_callbacks(blox::CompositeBlox) = blox.connector.discrete_callbacks get_continuous_callbacks(blox) = [] +get_continuous_callbacks(blox::AbstractComponent) = blox.connector.discrete_callbacks +get_continuous_callbacks(blox::CompositeBlox) = blox.connector.discrete_callbacks + get_blox_parts(blox) = blox.parts diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 93098d8d..96489c27 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -3,6 +3,7 @@ mutable struct BloxConnector weights::Vector{Num} delays::Vector{Num} discrete_callbacks + continuous_callbacks learning_rules BloxConnector() = new(Equation[], Num[], Num[], Pair{Any, Vector{Equation}}[], Dict{Num, AbstractLearningRule}()) @@ -11,7 +12,7 @@ mutable struct BloxConnector discrete_callbacks = mapreduce(get_discrete_callbacks, vcat, bloxs) continuous_callbacks = mapreduce(get_continuous_callbacks, vcat, bloxs) - new(eqs, weights, delays, events, learning_rules) + new(eqs, weights, delays, discrete_callbacks, continuous_callbacks, learning_rules) end end From 8ab49f49a1d9406d1e09e025e1bb1fcb770fb75a Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 25 Jul 2024 18:37:30 +0300 Subject: [PATCH 04/50] use `get_` prefix in getter function names for clarity --- src/Neurographs.jl | 4 ++-- src/blox/blox_utilities.jl | 24 ++++++++++++++---------- src/blox/connections.jl | 6 +++++- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 6a385981..ce1469bf 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -62,7 +62,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_block=miss blox_syss = get_sys(g) connection_eqs = get_equations_with_state_lhs(bc) - cbs = identity.(get_callbacks(g, bc; t_block)) + cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) return compose(ODESystem(connection_eqs, t, [], params(bc); name, discrete_events = cbs), blox_syss) end @@ -71,7 +71,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; na blox_syss = get_sys(g) connection_eqs = get_equations_with_state_lhs(bc) - cbs = identity.(get_callbacks(g, bc; t_block)) + cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) return compose(ODESystem(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = cbs), blox_syss) end diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 97f9ff20..c2be5284 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -110,7 +110,7 @@ end which holds a `BloxConnector` object with all relevant connections from lower levels and this level. """ -function input_equations(blox) +function get_input_equations(blox) sys = get_sys(blox) inps = inputs(sys) sys_eqs = equations(sys) @@ -137,17 +137,17 @@ function input_equations(blox) return eqs end -input_equations(blox::AbstractComponent) = blox.connector.eqs -input_equations(blox::CompositeBlox) = blox.connector.eqs -input_equations(::ImageStimulus) = [] +get_input_equations(blox::AbstractComponent) = blox.connector.eqs +get_input_equations(blox::CompositeBlox) = blox.connector.eqs +get_input_equations(::ImageStimulus) = [] -weight_parameters(blox) = Num[] -weight_parameters(blox::AbstractComponent) = blox.connector.weights #I think this is the fix? -weight_parameters(blox::CompositeBlox) = blox.connector.weights #I think this is the fix? +get_weight_parameters(blox) = Num[] +get_weight_parameters(blox::AbstractComponent) = blox.connector.weights #I think this is the fix? +get_weight_parameters(blox::CompositeBlox) = blox.connector.weights #I think this is the fix? -delay_parameters(blox) = Num[] -delay_parameters(blox::AbstractComponent) = blox.connector.delays -delay_parameters(blox::CompositeBlox) = blox.connector.delays +get_delay_parameters(blox) = Num[] +get_delay_parameters(blox::AbstractComponent) = blox.connector.delays +get_delay_parameters(blox::CompositeBlox) = blox.connector.delays get_discrete_callbacks(blox) = [] get_discrete_callbacks(blox::AbstractComponent) = blox.connector.discrete_callbacks @@ -157,6 +157,10 @@ get_continuous_callbacks(blox) = [] get_continuous_callbacks(blox::AbstractComponent) = blox.connector.discrete_callbacks get_continuous_callbacks(blox::CompositeBlox) = blox.connector.discrete_callbacks +get_weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}() +get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules +get_weight_learning_rules(blox::AbstractComponent) = weight_learning_rules(blox.connector) +get_weight_learning_rules(blox::CompositeBlox) = weight_learning_rules(blox.connector) get_blox_parts(blox) = blox.parts diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 96489c27..09f53eba 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -9,8 +9,12 @@ mutable struct BloxConnector BloxConnector() = new(Equation[], Num[], Num[], Pair{Any, Vector{Equation}}[], Dict{Num, AbstractLearningRule}()) function BloxConnector(bloxs) + eqs = mapreduce(get_input_equations, vcat, bloxs) + weights = mapreduce(get_weight_parameters, vcat, bloxs) + delays = mapreduce(get_delay_parameters, vcat, bloxs) discrete_callbacks = mapreduce(get_discrete_callbacks, vcat, bloxs) continuous_callbacks = mapreduce(get_continuous_callbacks, vcat, bloxs) + learning_rules = mapreduce(get_weight_learning_rules, merge, bloxs) new(eqs, weights, delays, discrete_callbacks, continuous_callbacks, learning_rules) end @@ -26,7 +30,7 @@ get_equations_with_parameter_lhs(bc) = filter(eq -> isparameter(eq.lhs), bc.eqs) get_equations_with_state_lhs(bc) = filter(eq -> !isparameter(eq.lhs), bc.eqs) -function get_callbacks(g, bc; t_block=missing) +function generate_discrete_callbacks(g, bc; t_block=missing) if !ismissing(t_block) eqs_params = get_equations_with_parameter_lhs(bc) From cd12a15b0c81b2f4ee49006cbe35a99d19a841a4 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 25 Jul 2024 18:38:00 +0300 Subject: [PATCH 05/50] add new `LIF` exci and inhib neuron bloxs --- src/blox/neuron_models.jl | 120 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index 68805f9a..2afa35e3 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -624,6 +624,126 @@ struct LIFNeuron <: AbstractNeuronBlox end end +struct LIFInhNeuron <: AbstractNeuronBlox + odesystem + namespace + + function LIFInhNeuron(; + name, + namespace = nothing, + g_L = 20, # nS + V_L = -70, # mV + V_E = 0, # mV + V_I = -70, # mV + θ = -50, # mV + V_reset = -55, # mV + C = 0.2, # nF + τ_AMPA = 2, # ms + τ_GABA = 5, # ms + τ_NMDA_decay = 100, # ms + τ_NMDA_rise = 2, # ms + α = 0.5, # ms⁻¹ + g_AMPA = 0.04, # nS + g_AMPA_external = 1.62, # nS + g_GABA = 1, # nS + g_NMDA = 0.13, # nS + Mg = 1 # mM + ) + + ps = @parameters begin + g_L=g_L + V_L=V_L + V_E=V_E + V_I=V_I + C=C + τ_AMPA=τ_AMPA + τ_GABA=τ_GABA + τ_NMDA_decay=τ_NMDA_decay + τ_NMDA_rise=τ_NMDA_rise + g_AMPA = g_AMPA + g_AMPA_external = g_AMPA_external + g_GABA = g_GABA + g_NMDA = g_NMDA + α=α + Mg=Mg + end + + sts = @variables V(t)=V_L S_AMPA(t)=0 S_GABA(t)=0 S_NMDA(t)=0 x(t)=0 jcn(t)=0 [input=true] + eqs = [ + D(V) ~ - (g_L * (V - V_L) + jcn) / C, + D(S_AMPA) ~ - S_AMPA / τ_AMPA, + D(S_GABA) ~ - S_GABA / τ_GABA, + D(S_NMDA) ~ - S_NMDA / τ_NMDA_decay + α * x * (1 - S_NMDA), + D(x) ~ - x / τ_NMDA_rise + ] + + ev = [V ~ θ] => [V ~ V_reset] + sys = System(eqs, t, sts, ps; continuous_events=[ev], name=name) + + new(sys, namespace) + end +end + +struct LIFExciNeuron <: AbstractNeuronBlox + odesystem + namespace + + function LIFExciNeuron(; + name, + namespace = nothing, + g_L = 25, # nS + V_L = -70, # mV + V_E = 0, # mV + V_I = -70, # mV + θ = -50, # mV + V_reset = -55, # mV + C = 0.5, # nF + τ_AMPA = 2, # ms + τ_GABA = 5, # ms + τ_NMDA_decay = 100, # ms + τ_NMDA_rise = 2, # ms + α = 0.5, # ms⁻¹ + g_AMPA = 0.05, # nS + g_AMPA_external = 2.1, # nS + g_GABA = 1.3, # nS + g_NMDA = 0.165, # nS + Mg = 1 # mM + ) + + ps = @parameters begin + g_L=g_L + V_L=V_L + V_E=V_E + V_I=V_I + C=C + τ_AMPA=τ_AMPA + τ_GABA=τ_GABA + τ_NMDA_decay=τ_NMDA_decay + τ_NMDA_rise=τ_NMDA_rise + g_AMPA = g_AMPA + g_AMPA_external = g_AMPA_external + g_GABA = g_GABA + g_NMDA = g_NMDA + α=α + Mg=Mg + end + + sts = @variables V(t)=V_L S_AMPA(t)=0 S_GABA(t)=0 S_NMDA(t)=0 x(t)=0 jcn(t)=0 [input=true] + eqs = [ + D(V) ~ - (g_L * (V - V_L) + jcn) / C, + D(S_AMPA) ~ - S_AMPA / τ_AMPA, + D(S_GABA) ~ - S_GABA / τ_GABA, + D(S_NMDA) ~ - S_NMDA / τ_NMDA_decay + α * x * (1 - S_NMDA), + D(x) ~ - x / τ_NMDA_rise + ] + + ev = [V ~ θ] => [V ~ V_reset] + sys = System(eqs, t, sts, ps; continuous_events=[ev], name=name) + + new(sys, namespace) + end +end + # Paramater bounds for GUI # C = [0.1, 100] μF # E_syn = [1, 100] kΩ From ab8b87e6a4be4946b7884757447685cedc099ac1 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 25 Jul 2024 18:38:16 +0300 Subject: [PATCH 06/50] add conenction rules for new LIF neurons --- src/blox/connections.jl | 44 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 09f53eba..deb93f5a 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -853,3 +853,47 @@ function (bc::BloxConnector)( accumulate_equation!(bc, eq) end + +function (bc::BloxConnector)( + bloxout::LIFExciNeuron, + bloxin::Union{LIFExciNeuron, LIFInhNeuron}; + kwargs... +) + sys_out = get_namespaced_sys(bloxout) + sys_in = get_namespaced_sys(bloxin) + + w = generate_weight_param(bloxout, bloxin; kwargs...) + push!(bc.weights, w) + + eq = sys_in.jcn ~ w * sys_in.S_AMPA * sys_in.g_AMPA * (sys_in.V - sys_in.V_E) + + w * sys_in.S_NMDA * sys_in.g_NMDA * (sys_in.V - sys_in.V_E) / + (1 + sys_in.Mg * exp(-0.062 * sys_in.V) / 3.57) + + accumulate_equation!(bc, eq) + + cb = [sys_out.V ~ sys_out.θ] => [ + sys_in.S_AMPA ~ sysin.S_AMPA + 1, + sys_in.S_NMDA ~ sysin.S_NMDA + 1 + ] + push!(bc.continuous_callbacks, cb) +end + +function (bc::BloxConnector)( + bloxout::LIFInhNeuron, + bloxin::Union{LIFExciNeuron, LIFInhNeuron}; + kwargs... +) + sys_out = get_namespaced_sys(bloxout) + sys_in = get_namespaced_sys(bloxin) + + w = generate_weight_param(bloxout, bloxin; kwargs...) + push!(bc.weights, w) + + eq = sys_in.jcn ~ w * sys_in.S_GABA * sys_in.g_GABA * (sys_in.V - sys_in.V_I) + + accumulate_equation!(bc, eq) + + cb = [sys_out.V ~ sys_out.θ] => [sys_in.S_GABA ~ sys_in.S_GABA + 1] + push!(bc.continuous_callbacks, cb) +end + From 1f4714cbcb2cbd0bd78ccda26ec2bbfcc7a8bbe3 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Fri, 26 Jul 2024 13:34:44 +0300 Subject: [PATCH 07/50] tidy up `BloxConnector` getter functions --- src/blox/blox_utilities.jl | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index c2be5284..3d08abef 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -137,32 +137,33 @@ function get_input_equations(blox) return eqs end -get_input_equations(blox::AbstractComponent) = blox.connector.eqs -get_input_equations(blox::CompositeBlox) = blox.connector.eqs -get_input_equations(::ImageStimulus) = [] +get_connector(blox::Union{CompositeBlox, AbstractComponent}) = blox.connector +get_input_equations(bc::BloxConnector) = bc.eqs +get_input_equations(blox::Union{CompositeBlox, AbstractComponent}) = (get_input_equations ∘ get_connector)(blox) +get_input_equations(blox) = [] + +get_weight_parameters(bc::BloxConnector) = bc.weights +get_weight_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_parameters ∘ get_connector)(blox) get_weight_parameters(blox) = Num[] -get_weight_parameters(blox::AbstractComponent) = blox.connector.weights #I think this is the fix? -get_weight_parameters(blox::CompositeBlox) = blox.connector.weights #I think this is the fix? +get_delay_parameters(bc::BloxConnector) = bc.delays +get_delay_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_delay_parameters ∘ get_connector)(blox) get_delay_parameters(blox) = Num[] -get_delay_parameters(blox::AbstractComponent) = blox.connector.delays -get_delay_parameters(blox::CompositeBlox) = blox.connector.delays +get_discrete_callbacks(bc::BloxConnector) = bc.discrete_callbacks +get_discrete_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_discrete_callbacks ∘ get_connector)(blox) get_discrete_callbacks(blox) = [] -get_discrete_callbacks(blox::AbstractComponent) = blox.connector.discrete_callbacks -get_discrete_callbacks(blox::CompositeBlox) = blox.connector.discrete_callbacks +get_continuous_callbacks(bc::BloxConnector) = bc.continuous_callbacks +get_continuous_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_continuous_callbacks ∘ get_connector)(blox) get_continuous_callbacks(blox) = [] -get_continuous_callbacks(blox::AbstractComponent) = blox.connector.discrete_callbacks -get_continuous_callbacks(blox::CompositeBlox) = blox.connector.discrete_callbacks -get_weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}() get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules -get_weight_learning_rules(blox::AbstractComponent) = weight_learning_rules(blox.connector) -get_weight_learning_rules(blox::CompositeBlox) = weight_learning_rules(blox.connector) +get_weight_learning_rules(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_learning_rules ∘ get_connector)(blox) +get_weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}() -get_blox_parts(blox) = blox.parts +get_blox_parts(blox::Union{CompositeBlox, AbstractComponent}) = blox.parts function get_weight(kwargs, name_blox1, name_blox2) if haskey(kwargs, :weight) From 2d201c421502116249eb76e508d23c0b45a4009c Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Fri, 26 Jul 2024 15:42:28 +0300 Subject: [PATCH 08/50] specify type in input equation getter dispatch --- src/blox/blox_utilities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 3d08abef..ea553690 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -110,7 +110,7 @@ end which holds a `BloxConnector` object with all relevant connections from lower levels and this level. """ -function get_input_equations(blox) +function get_input_equations(blox::AbstractBlox) sys = get_sys(blox) inps = inputs(sys) sys_eqs = equations(sys) From 51ab09d6b7cb3994f943b6c373ddb84168489a83 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Fri, 26 Jul 2024 15:44:49 +0300 Subject: [PATCH 09/50] export new LIF Bloxs --- src/Neuroblox.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 8ca3ba85..5125e303 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -186,7 +186,7 @@ end export JansenRitSPM12, next_generation, qif_neuron, if_neuron, hh_neuron_excitatory, hh_neuron_inhibitory, van_der_pol, Generic2dOscillator -export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, +export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, LIFExciNeuron, LIFInhNeuron, CanonicalMicroCircuitBlox, WinnerTakeAllBlox, CorticalBlox, SuperCortical, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear, NextGenerationBlox, NextGenerationResolvedBlox, NextGenerationEIBlox From cba8df430032ccb0f52e06c377464742d2f68528 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Fri, 26 Jul 2024 18:36:41 +0300 Subject: [PATCH 10/50] include `ObserverBlox` in input equation getter dispatch --- src/blox/blox_utilities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index ea553690..0782a6e0 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -110,7 +110,7 @@ end which holds a `BloxConnector` object with all relevant connections from lower levels and this level. """ -function get_input_equations(blox::AbstractBlox) +function get_input_equations(blox::Union{AbstractBlox, ObserverBlox}) sys = get_sys(blox) inps = inputs(sys) sys_eqs = equations(sys) From a0169d64f5d40e0366a148e15fd8569298f85dc1 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Fri, 26 Jul 2024 18:57:22 +0300 Subject: [PATCH 11/50] add continuous callbacks to the final system `compose` --- src/Neurographs.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index ce1469bf..5bb819c4 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -62,17 +62,20 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_block=miss blox_syss = get_sys(g) connection_eqs = get_equations_with_state_lhs(bc) - cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) - return compose(ODESystem(connection_eqs, t, [], params(bc); name, discrete_events = cbs), blox_syss) + discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) + continuous_cbs = get_continuous_callbacks(bc) + + return compose(ODESystem(connection_eqs, t, [], params(bc); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) end function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; name, t_block=missing) - blox_syss = get_sys(g) - connection_eqs = get_equations_with_state_lhs(bc) - cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) - return compose(ODESystem(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = cbs), blox_syss) + + discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) + continuous_cbs = get_continuous_callbacks(bc) + + return compose(ODESystem(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) end function system_from_parts(parts::AbstractVector; name) From 1ba935706c74b0404d35d27942fa0b86bb1de679 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Fri, 26 Jul 2024 18:57:48 +0300 Subject: [PATCH 12/50] fix typo --- src/blox/connections.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index deb93f5a..accf34ac 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -872,8 +872,8 @@ function (bc::BloxConnector)( accumulate_equation!(bc, eq) cb = [sys_out.V ~ sys_out.θ] => [ - sys_in.S_AMPA ~ sysin.S_AMPA + 1, - sys_in.S_NMDA ~ sysin.S_NMDA + 1 + sys_in.S_AMPA ~ sys_in.S_AMPA + 1, + sys_in.S_NMDA ~ sys_in.S_NMDA + 1 ] push!(bc.continuous_callbacks, cb) end From 4d5665b2badd1a677b1b53f203fc8861ab7747ff Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Fri, 26 Jul 2024 18:58:05 +0300 Subject: [PATCH 13/50] updates units in new LIF params --- src/blox/neuron_models.jl | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index 2afa35e3..6db75e5c 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -631,22 +631,22 @@ struct LIFInhNeuron <: AbstractNeuronBlox function LIFInhNeuron(; name, namespace = nothing, - g_L = 20, # nS + g_L = 20 * 1e-3, # mS V_L = -70, # mV V_E = 0, # mV V_I = -70, # mV θ = -50, # mV V_reset = -55, # mV - C = 0.2, # nF + C = 0.2 * 1e-3, # mF τ_AMPA = 2, # ms τ_GABA = 5, # ms τ_NMDA_decay = 100, # ms τ_NMDA_rise = 2, # ms α = 0.5, # ms⁻¹ - g_AMPA = 0.04, # nS - g_AMPA_external = 1.62, # nS - g_GABA = 1, # nS - g_NMDA = 0.13, # nS + g_AMPA = 0.04 * 1e-3, # mS + g_AMPA_external = 1.62 * 1e-3, # mS + g_GABA = 1 * 1e-3, # mS + g_NMDA = 0.13 * 1e-3, # mS Mg = 1 # mM ) @@ -655,6 +655,7 @@ struct LIFInhNeuron <: AbstractNeuronBlox V_L=V_L V_E=V_E V_I=V_I + θ=θ C=C τ_AMPA=τ_AMPA τ_GABA=τ_GABA @@ -691,22 +692,22 @@ struct LIFExciNeuron <: AbstractNeuronBlox function LIFExciNeuron(; name, namespace = nothing, - g_L = 25, # nS + g_L = 25 * 1e-3, # mS V_L = -70, # mV V_E = 0, # mV V_I = -70, # mV θ = -50, # mV V_reset = -55, # mV - C = 0.5, # nF + C = 0.5 * 1e-3, # mF τ_AMPA = 2, # ms τ_GABA = 5, # ms τ_NMDA_decay = 100, # ms τ_NMDA_rise = 2, # ms α = 0.5, # ms⁻¹ - g_AMPA = 0.05, # nS - g_AMPA_external = 2.1, # nS - g_GABA = 1.3, # nS - g_NMDA = 0.165, # nS + g_AMPA = 0.05 * 1e-3, # mS + g_AMPA_external = 2.1 * 1e-3, # mS + g_GABA = 1.3 * 1e-3, # mS + g_NMDA = 0.165 * 1e-3, # mS Mg = 1 # mM ) @@ -715,6 +716,7 @@ struct LIFExciNeuron <: AbstractNeuronBlox V_L=V_L V_E=V_E V_I=V_I + θ=θ C=C τ_AMPA=τ_AMPA τ_GABA=τ_GABA From 2a2c9efb7af1530d2a8d273e0bd40f35894fa3b0 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Fri, 26 Jul 2024 19:02:22 +0300 Subject: [PATCH 14/50] add initial test --- test/components.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/components.jl b/test/components.jl index adaea00d..1bd1ddbd 100644 --- a/test/components.jl +++ b/test/components.jl @@ -580,3 +580,26 @@ end sol = solve(prob, Tsit5()) @test sol.retcode == ReturnCode.Success end + +@testset "LIFExciBlox - LIFInhBlox network" begin + global_ns = :g # global namespace + @named n1 = LIFExciNeuron(; namespace = global_ns) + @named n2 = LIFExciNeuron(; namespace = global_ns) + @named n3 = LIFInhNeuron(; namespace = global_ns) + + neurons = [n1, n2, n3] + g = MetaDiGraph() + add_blox!.(Ref(g), neurons) + + for i in eachindex(neurons) + for j in eachindex(neurons) + add_edge!(g, i, j, Dict(:weight => 1)) + end + end + + @named sys = system_from_graph(g) + sys_simpl = structural_simplify(sys) + prob = ODEProblem(sys_simpl, [], (0, 200.0)) + sol = solve(prob, Tsit5()) + @test sol.retcode == ReturnCode.Success +end \ No newline at end of file From 9e166f89e54690d33e3715ee6c475dc46bf819d4 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 15:30:58 +0300 Subject: [PATCH 15/50] use `identity` for concrete type --- src/Neurographs.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 5bb819c4..604c663d 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -63,7 +63,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_block=miss connection_eqs = get_equations_with_state_lhs(bc) discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) - continuous_cbs = get_continuous_callbacks(bc) + continuous_cbs = identity.(get_continuous_callbacks(bc)) return compose(ODESystem(connection_eqs, t, [], params(bc); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) end @@ -73,7 +73,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; na connection_eqs = get_equations_with_state_lhs(bc) discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) - continuous_cbs = get_continuous_callbacks(bc) + continuous_cbs = identity.(get_continuous_callbacks(bc)) return compose(ODESystem(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) end From 6e38adba1504929624ff82390fede53bd72bdcca Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 15:32:31 +0300 Subject: [PATCH 16/50] switch from `ODESystem` to `System` --- src/Neurographs.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 604c663d..0b3eaae4 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -65,7 +65,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_block=miss discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) continuous_cbs = identity.(get_continuous_callbacks(bc)) - return compose(ODESystem(connection_eqs, t, [], params(bc); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) + return compose(System(connection_eqs, t, [], params(bc); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) end function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; name, t_block=missing) @@ -75,7 +75,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; na discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) continuous_cbs = identity.(get_continuous_callbacks(bc)) - return compose(ODESystem(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) + return compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) end function system_from_parts(parts::AbstractVector; name) From 833702c805bede4779dc399495c05c46fabc7201 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 15:33:02 +0300 Subject: [PATCH 17/50] add `PoissonSpikeTrain` as a spike train source --- src/blox/sources.jl | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/blox/sources.jl b/src/blox/sources.jl index fa0b74cd..9527ccd5 100644 --- a/src/blox/sources.jl +++ b/src/blox/sources.jl @@ -154,3 +154,28 @@ mutable struct ImageStimulus <: StimulusBlox ImageStimulus(data; name, namespace, t_stimulus, t_pause) end end + +struct PoissonSpikeTrain <: StimulusBlox + name + namespace + rate + dt + tspan + rng + + function PoissonSpikeTrain(; name, namespace, rate, tspan, dt=0.5/rate, rng = MersenneTwister(1234)) + new(name, namespace, rate, dt, tspan, rng) + end +end + +function generate_spike_times(stim::PoissonSpikeTrain) + # This could also change to a dispatch of Random.rand() + t_spikes = Float64[] + for t in range(stim.tspan...; step = stim.dt) + if rand(stim.rng) < stim.rate * stim.dt + push!(t_spikes, t) + end + end + + return t_spikes +end From 47bdd606edc98bcf041c871161b00adbb80784a3 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 15:33:28 +0300 Subject: [PATCH 18/50] export `PoissonSpikeTrain` --- src/Neuroblox.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 5125e303..a539b9f6 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -194,7 +194,7 @@ export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc export HebbianPlasticity, HebbianModulationPlasticity export Agent, ClassificationEnvironment, GreedyPolicy, reset! export LearningBlox -export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput +export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain export PowerSpectrumBlox, BandPassFilterBlox export OUBlox, OUCouplingBlox export phase_inter, phase_sin_blox, phase_cos_blox From d149c4751080f10fa32fae3647e0145c4e99c452 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 15:34:06 +0300 Subject: [PATCH 19/50] `get_sys` return empty `System` for `PoissonSpikeTrain` --- src/blox/blox_utilities.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 0782a6e0..8b06c8b0 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -53,6 +53,7 @@ get_inh_neurons(n) = [] get_sys(blox) = blox.odesystem get_sys(sys::AbstractODESystem) = sys +get_sys(stim::PoissonSpikeTrain) = System(Equation[], t, [], []; name=stim.name) function get_namespaced_sys(blox) sys = get_sys(blox) From 44fa957d1de81483eeb39a88b0a3d376ecefe1ae Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 15:34:27 +0300 Subject: [PATCH 20/50] add connection rule for `PoissonSpikeTrain` to new LIF exci neuron --- src/blox/connections.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index accf34ac..711e3a72 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -897,3 +897,23 @@ function (bc::BloxConnector)( push!(bc.continuous_callbacks, cb) end +function (bc::BloxConnector)( + stim::PoissonSpikeTrain, + neuron::Union{LIFExciNeuron, LIFInhNeuron}; + kwargs... +) + sys_in = get_namespaced_sys(neuron) + + w = generate_weight_param(stim, neuron; kwargs...) + push!(bc.weights, w) + + eq = sys_in.jcn ~ w * sys_in.S_AMPA * sys_in.g_AMPA_external * (sys_in.V - sys_in.V_E) + + accumulate_equation!(bc, eq) + + t_spikes = generate_spike_times(stim) + + cb = t_spikes => [sys_in.S_AMPA ~ sys_in.S_AMPA + 1] + push!(bc.discrete_callbacks, cb) +end + From 92451e7a909da86afd0007c56abb526372c0e908 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 15:34:42 +0300 Subject: [PATCH 21/50] test `PoissonSpikeTain` with new LIF exci neuron --- test/components.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/components.jl b/test/components.jl index 1bd1ddbd..e499832c 100644 --- a/test/components.jl +++ b/test/components.jl @@ -602,4 +602,23 @@ end prob = ODEProblem(sys_simpl, [], (0, 200.0)) sol = solve(prob, Tsit5()) @test sol.retcode == ReturnCode.Success +end + +@testset "PoissonSpikeTrain - LIFExciBlox network" begin + global_ns = :g # global namespace + tspan = (0, 200) + @named n1 = LIFExciNeuron(; namespace = global_ns) + @named s = PoissonSpikeTrain(; namespace = global_ns, rate=5, tspan) + + neurons = [s, n1] + g = MetaDiGraph() + add_blox!.(Ref(g), neurons) + + add_edge!(g, 1, 2, Dict(:weight => 1)) + + @named sys = system_from_graph(g) + sys_simpl = structural_simplify(sys) + prob = ODEProblem(sys_simpl, [], tspan) + sol = solve(prob, Tsit5()) + @test sol.retcode == ReturnCode.Success end \ No newline at end of file From 63bfa4cf8363dd0098434f992d57326c6a391417 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 18:07:38 +0300 Subject: [PATCH 22/50] rename `cortical_blox.jl` to `cortical.jl` to include more cortical-like types in it --- src/blox/{cortical_blox.jl => cortical.jl} | 71 ++++++++++++++++++++++ 1 file changed, 71 insertions(+) rename src/blox/{cortical_blox.jl => cortical.jl} (54%) diff --git a/src/blox/cortical_blox.jl b/src/blox/cortical.jl similarity index 54% rename from src/blox/cortical_blox.jl rename to src/blox/cortical.jl index a59253c7..523536a7 100644 --- a/src/blox/cortical_blox.jl +++ b/src/blox/cortical.jl @@ -68,3 +68,74 @@ struct CorticalBlox <: CompositeBlox new(namespace, vcat(wtas, n_ff_inh), sys, bc) end end + +struct LIFExciCircuitBlox <: CompositeBlox + namespace + parts + odesystem + connector + + function LIFExciCircuitBlox( + name, + N_neurons, + namespace=nothing, + g_L = 25 * 1e-3, # mS + V_L = -70, # mV + V_E = 0, # mV + V_I = -70, # mV + θ = -50, # mV + V_reset = -55, # mV + C = 0.5 * 1e-3, # mF + τ_AMPA = 2, # ms + τ_GABA = 5, # ms + τ_NMDA_decay = 100, # ms + τ_NMDA_rise = 2, # ms + α = 0.5, # ms⁻¹ + g_AMPA = 0.05 * 1e-3, # mS + g_AMPA_external = 2.1 * 1e-3, # mS + g_GABA = 1.3 * 1e-3, # mS + g_NMDA = 0.165 * 1e-3, # mS + Mg = 1, # mM + kwargs... + ) + + neurons = map(Base.OneTo(N_neurons)) do i + LIFExciNeuron(; + name = Symbol("neuron$i"), + namespace = namespaced_name(namespace, name), + g_L, + V_L, + V_E, + V_I, + θ, + V_reset, + C, + τ_AMPA, + τ_GABA, + τ_NMDA_decay, + τ_NMDA_rise, + α, + g_AMPA, + g_AMPA_external, + g_GABA, + g_NMDA, + Mg + ) + end + + g = MetaDiGraph() + add_blox!.(Ref(g), neurons) + + for i in eachindex(neurons) + for j in eachindex(neurons) + add_edge!(g, i, j, Dict(kwargs)) + end + end + + bc = connector_from_graph(g) + + sys = isnothing(namespace) ? system_from_graph(g, bc; name) : system_from_parts(vcat(wtas, n_ff_inh); name) + + new(namespace, neurons, sys, bc) + end +end \ No newline at end of file From d9c9a9a4daa1ee855db72de486b98a44e54a9fb9 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 18:08:33 +0300 Subject: [PATCH 23/50] include renamed file --- src/Neuroblox.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 3cd15392..a07dd807 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -95,7 +95,7 @@ include("control/controlerror.jl") include("measurementmodels/fmri.jl") include("datafitting/spectralDCM.jl") include("blox/neural_mass.jl") -include("blox/cortical_blox.jl") +include("blox/cortical.jl") include("blox/canonicalmicrocircuit.jl") include("blox/neuron_models.jl") include("blox/DBS_Model_Blox_Adam_Brown.jl") From 8aa4c2d540671037b5a0afa186a5558f923d7bcc Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 30 Jul 2024 18:09:26 +0300 Subject: [PATCH 24/50] update spike train test --- test/components.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/components.jl b/test/components.jl index 5791c516..70990210 100644 --- a/test/components.jl +++ b/test/components.jl @@ -697,19 +697,25 @@ end @testset "PoissonSpikeTrain - LIFExciBlox network" begin global_ns = :g # global namespace + tspan = (0, 200) - @named n1 = LIFExciNeuron(; namespace = global_ns) + @named s = PoissonSpikeTrain(; namespace = global_ns, rate=5, tspan) + @named n1 = LIFExciNeuron(; namespace = global_ns) + @named n2 = LIFExciNeuron(; namespace = global_ns) - neurons = [s, n1] + neurons = [s, n1, n2] g = MetaDiGraph() add_blox!.(Ref(g), neurons) add_edge!(g, 1, 2, Dict(:weight => 1)) + add_edge!(g, 1, 3, Dict(:weight => 1)) + add_edge!(g, 2, 3, Dict(:weight => 1)) + add_edge!(g, 3, 2, Dict(:weight => 1)) @named sys = system_from_graph(g) sys_simpl = structural_simplify(sys) prob = ODEProblem(sys_simpl, [], tspan) - sol = solve(prob, Tsit5()) + sol = solve(prob, Vern7(), reltol=1e-9,abstol=1e-9) @test sol.retcode == ReturnCode.Success -end \ No newline at end of file +end From 04df20e8cb2fdb9acb75cb24c63f3209172a4283 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 31 Jul 2024 12:33:41 +0300 Subject: [PATCH 25/50] add note for alternative spike generation mechanism --- src/blox/connections.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 5b5d9298..64b139c7 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -932,6 +932,12 @@ function (bc::BloxConnector)( t_spikes = generate_spike_times(stim) cb = t_spikes => [sys_in.S_AMPA ~ sys_in.S_AMPA + 1] + # TO DO : Consider generating spikes during simulation + # to make PoissonSpikeTrain independent of `t_span` of the simulation. + # something like : + # discrete_event = t > -Inf => (generate_spike, [sys_in.S_AMPA], [stim.relevant_params...], [], nothing) + # This way we need to resolve the case of multiple spikes potentially being generated within a single integrator step. + push!(bc.discrete_callbacks, cb) end From 0089a9031c27866b58599ea7e2af1056125f2654 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 31 Jul 2024 16:50:06 +0300 Subject: [PATCH 26/50] simplify test --- test/components.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/components.jl b/test/components.jl index 70990210..017bbdb0 100644 --- a/test/components.jl +++ b/test/components.jl @@ -702,20 +702,18 @@ end @named s = PoissonSpikeTrain(; namespace = global_ns, rate=5, tspan) @named n1 = LIFExciNeuron(; namespace = global_ns) - @named n2 = LIFExciNeuron(; namespace = global_ns) - neurons = [s, n1, n2] + neurons = [s, n1] + g = MetaDiGraph() add_blox!.(Ref(g), neurons) add_edge!(g, 1, 2, Dict(:weight => 1)) - add_edge!(g, 1, 3, Dict(:weight => 1)) - add_edge!(g, 2, 3, Dict(:weight => 1)) - add_edge!(g, 3, 2, Dict(:weight => 1)) - + @named sys = system_from_graph(g) sys_simpl = structural_simplify(sys) prob = ODEProblem(sys_simpl, [], tspan) - sol = solve(prob, Vern7(), reltol=1e-9,abstol=1e-9) + sol = solve(prob, Tsit5()) @test sol.retcode == ReturnCode.Success end + From 84df57524925dd78e8ce8c0f69d29657093e612f Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 31 Jul 2024 18:52:29 +0300 Subject: [PATCH 27/50] add test for LIF circuit composite Blox --- test/components.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/components.jl b/test/components.jl index 017bbdb0..bab380d2 100644 --- a/test/components.jl +++ b/test/components.jl @@ -695,6 +695,15 @@ end @test sol.retcode == ReturnCode.Success end +@testset "LIFExciCircuitBlox" begin + @named n = LIFExciCircuitBlox(; N_neurons = 10, weight=1) + + sys_simpl = structural_simplify(n.odesystem) + prob = ODEProblem(sys_simpl, [], (0, 200.0)) + sol = solve(prob, Vern7()) + @test sol.retcode == ReturnCode.Success  +end + @testset "PoissonSpikeTrain - LIFExciBlox network" begin global_ns = :g # global namespace From a3c275ffa21ec962170421768a4dc09d7a280021 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 31 Jul 2024 18:52:52 +0300 Subject: [PATCH 28/50] only kwargs in `LIFExciCircuitBlox` --- src/blox/cortical.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blox/cortical.jl b/src/blox/cortical.jl index 523536a7..e91b46a0 100644 --- a/src/blox/cortical.jl +++ b/src/blox/cortical.jl @@ -75,7 +75,7 @@ struct LIFExciCircuitBlox <: CompositeBlox odesystem connector - function LIFExciCircuitBlox( + function LIFExciCircuitBlox(; name, N_neurons, namespace=nothing, From ba52d544cd20b1b90965ef28e36cd94d36039840 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 31 Jul 2024 18:53:05 +0300 Subject: [PATCH 29/50] export `LIFExciCircuitBlox` --- src/Neuroblox.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index a07dd807..dad1f8ef 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -188,7 +188,7 @@ export JansenRitSPM12, next_generation, qif_neuron, if_neuron, hh_neuron_excitat hh_neuron_inhibitory, van_der_pol, Generic2dOscillator export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, LIFExciNeuron, LIFInhNeuron, CanonicalMicroCircuitBlox, WinnerTakeAllBlox, CorticalBlox, SuperCortical, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, - HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam + HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam, LIFExciCircuitBlox export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear, NextGenerationBlox, NextGenerationResolvedBlox, NextGenerationEIBlox, KuramotoOscillator export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc export HebbianPlasticity, HebbianModulationPlasticity From 0164318f988c8149187a40f8714bc9cfd4b1e6ec Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 1 Aug 2024 16:26:00 +0300 Subject: [PATCH 30/50] implement composite blox of LIF circuits --- src/blox/cortical.jl | 73 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/src/blox/cortical.jl b/src/blox/cortical.jl index e91b46a0..7ff00df0 100644 --- a/src/blox/cortical.jl +++ b/src/blox/cortical.jl @@ -134,8 +134,79 @@ struct LIFExciCircuitBlox <: CompositeBlox bc = connector_from_graph(g) + sys = isnothing(namespace) ? system_from_graph(g, bc; name) : system_from_parts(neurons; name) + + new(namespace, neurons, sys, bc) + end +end + +struct LIFInhCircuitBlox <: CompositeBlox + namespace + parts + odesystem + connector + + function LIFInhCircuitBlox(; + name, + N_neurons, + namespace=nothing, + g_L = 20 * 1e-3, # mS + V_L = -70, # mV + V_E = 0, # mV + V_I = -70, # mV + θ = -50, # mV + V_reset = -55, # mV + C = 0.2 * 1e-3, # mF + τ_AMPA = 2, # ms + τ_GABA = 5, # ms + τ_NMDA_decay = 100, # ms + τ_NMDA_rise = 2, # ms + α = 0.5, # ms⁻¹ + g_AMPA = 0.04 * 1e-3, # mS + g_AMPA_external = 1.62 * 1e-3, # mS + g_GABA = 1 * 1e-3, # mS + g_NMDA = 0.13 * 1e-3, # mS + Mg = 1, # mM + kwargs... + ) + + neurons = map(Base.OneTo(N_neurons)) do i + LIFInhNeuron(; + name = Symbol("neuron$i"), + namespace = namespaced_name(namespace, name), + g_L, + V_L, + V_E, + V_I, + θ, + V_reset, + C, + τ_AMPA, + τ_GABA, + τ_NMDA_decay, + τ_NMDA_rise, + α, + g_AMPA, + g_AMPA_external, + g_GABA, + g_NMDA, + Mg + ) + end + + g = MetaDiGraph() + add_blox!.(Ref(g), neurons) + + for i in eachindex(neurons) + for j in eachindex(neurons) + add_edge!(g, i, j, Dict(kwargs)) + end + end + + bc = connector_from_graph(g) + sys = isnothing(namespace) ? system_from_graph(g, bc; name) : system_from_parts(vcat(wtas, n_ff_inh); name) new(namespace, neurons, sys, bc) end -end \ No newline at end of file +end From ada1c957fa1cf7334dd930116a05d7e76c1be5e3 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 1 Aug 2024 16:26:12 +0300 Subject: [PATCH 31/50] export circuit bloxs --- src/Neuroblox.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index dad1f8ef..d803c849 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -188,7 +188,7 @@ export JansenRitSPM12, next_generation, qif_neuron, if_neuron, hh_neuron_excitat hh_neuron_inhibitory, van_der_pol, Generic2dOscillator export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, LIFExciNeuron, LIFInhNeuron, CanonicalMicroCircuitBlox, WinnerTakeAllBlox, CorticalBlox, SuperCortical, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, - HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam, LIFExciCircuitBlox + HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam, LIFExciCircuitBlox, LIFInhCircuitBlox export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear, NextGenerationBlox, NextGenerationResolvedBlox, NextGenerationEIBlox, KuramotoOscillator export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc export HebbianPlasticity, HebbianModulationPlasticity From 66464ac67a6accdb0f421213cc70eaeea954bf82 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 1 Aug 2024 16:26:52 +0300 Subject: [PATCH 32/50] reorganise neuron getter functions for composite bloxs --- src/blox/blox_utilities.jl | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 797199b7..bae94371 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -17,40 +17,30 @@ function paramscoping(;kwargs...) return paramlist end +get_exci_neurons(n::AbstractExciNeuronBlox) = n +get_exci_neurons(n) = [] + function get_exci_neurons(g::MetaDiGraph) mapreduce(x -> get_exci_neurons(x), vcat, get_blox(g)) end -function get_exci_neurons(b::AbstractComponent) +function get_exci_neurons(b::Union{AbstractComponent, CompositeBlox}) mapreduce(x -> get_exci_neurons(x), vcat, b.parts) end -function get_inh_neurons(b::AbstractComponent) - mapreduce(x -> get_inh_neurons(x), vcat, b.parts) -end - -function get_discrete_parts(b::AbstractComponent) - mapreduce(x -> get_discrete_parts(x), vcat, b.parts) -end - -function get_exci_neurons(b::CompositeBlox) - mapreduce(x -> get_exci_neurons(x), vcat, b.parts) -end +get_inh_neurons(n::AbstractInhNeuronBlox) = n +get_inh_neurons(n) = [] -function get_inh_neurons(b::CompositeBlox) +function get_inh_neurons(b::Union{AbstractComponent, CompositeBlox}) mapreduce(x -> get_inh_neurons(x), vcat, b.parts) end -function get_discrete_parts(b::CompositeBlox) +get_neurons(b::Union{AbstractComponent, CompositeBlox}) = vcat(get_exci_neurons(b), get_inh_neurons(b)) + +function get_discrete_parts(b::Union{AbstractComponent, CompositeBlox}) mapreduce(x -> get_discrete_parts(x), vcat, b.parts) end -get_exci_neurons(n::AbstractExciNeuronBlox) = n -get_exci_neurons(n) = [] - -get_inh_neurons(n::AbstractInhNeuronBlox) = n -get_inh_neurons(n) = [] - get_sys(blox) = blox.odesystem get_sys(sys::AbstractODESystem) = sys get_sys(stim::PoissonSpikeTrain) = System(Equation[], t, [], []; name=stim.name) From b56df51250c1ee12286bd3fcc6053df0147f8f90 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 1 Aug 2024 16:27:27 +0300 Subject: [PATCH 33/50] subtype new LIFs as excitatory/inhibitory --- src/blox/neuron_models.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index 6db75e5c..0a928d6c 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -624,7 +624,7 @@ struct LIFNeuron <: AbstractNeuronBlox end end -struct LIFInhNeuron <: AbstractNeuronBlox +struct LIFInhNeuron <: AbstractInhNeuronBlox odesystem namespace @@ -647,7 +647,7 @@ struct LIFInhNeuron <: AbstractNeuronBlox g_AMPA_external = 1.62 * 1e-3, # mS g_GABA = 1 * 1e-3, # mS g_NMDA = 0.13 * 1e-3, # mS - Mg = 1 # mM + Mg = 1e-3 # mM ) ps = @parameters begin @@ -685,7 +685,7 @@ struct LIFInhNeuron <: AbstractNeuronBlox end end -struct LIFExciNeuron <: AbstractNeuronBlox +struct LIFExciNeuron <: AbstractExciNeuronBlox odesystem namespace @@ -708,7 +708,7 @@ struct LIFExciNeuron <: AbstractNeuronBlox g_AMPA_external = 2.1 * 1e-3, # mS g_GABA = 1.3 * 1e-3, # mS g_NMDA = 0.165 * 1e-3, # mS - Mg = 1 # mM + Mg = 1e-3 # mM ) ps = @parameters begin @@ -716,6 +716,7 @@ struct LIFExciNeuron <: AbstractNeuronBlox V_L=V_L V_E=V_E V_I=V_I + V_reset=V_reset θ=θ C=C τ_AMPA=τ_AMPA From 15d48ecfba0cbb7efeaae7230962ce2edcdbc777 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 1 Aug 2024 16:27:57 +0300 Subject: [PATCH 34/50] fix connection rule --- src/blox/connections.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 64b139c7..a9ee284b 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -891,7 +891,7 @@ function (bc::BloxConnector)( cb = [sys_out.V ~ sys_out.θ] => [ sys_in.S_AMPA ~ sys_in.S_AMPA + 1, - sys_in.S_NMDA ~ sys_in.S_NMDA + 1 + sys_in.x ~ sys_in.x + 1 ] push!(bc.continuous_callbacks, cb) end From 77b814624b907d26b1ae7853051d30d41595d262 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 1 Aug 2024 16:28:31 +0300 Subject: [PATCH 35/50] write connection rule for spike train to LIF circuit bloxs --- src/blox/connections.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index a9ee284b..211478ff 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -941,3 +941,14 @@ function (bc::BloxConnector)( push!(bc.discrete_callbacks, cb) end +function (bc::BloxConnector)( + stim::PoissonSpikeTrain, + cb::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}; + kwargs... +) + neurons_in = get_neurons(cb) + + for neuron in neurons_in + bc(stim, neuron; kwargs...) + end +end From 79d46cfd5c22ec6a6bbb3aebfa55395366d4a91d Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 1 Aug 2024 17:08:30 +0300 Subject: [PATCH 36/50] add test for Poisson spike train to LIF exci circuit blox --- test/components.jl | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/test/components.jl b/test/components.jl index bab380d2..0cd88474 100644 --- a/test/components.jl +++ b/test/components.jl @@ -707,9 +707,10 @@ end @testset "PoissonSpikeTrain - LIFExciBlox network" begin global_ns = :g # global namespace - tspan = (0, 200) + tspan = (0, 200) # ms + spike_rate = 10* 1e-3 # spikes / ms - @named s = PoissonSpikeTrain(; namespace = global_ns, rate=5, tspan) + @named s = PoissonSpikeTrain(; namespace = global_ns, rate=spike_rate, tspan) @named n1 = LIFExciNeuron(; namespace = global_ns) neurons = [s, n1] @@ -726,3 +727,26 @@ end @test sol.retcode == ReturnCode.Success end + +@testset "PoissonSpikeTrain - LIFExciCircuitBlox" begin + global_ns = :g # global namespace + + tspan = (0, 1000) # ms + spike_rate = 10* 1e-3 # spikes / ms + + @named s = PoissonSpikeTrain(; namespace = global_ns, rate=spike_rate, tspan) + @named n = LIFExciCircuitBlox(; namespace = global_ns, N_neurons = 10, weight=1) + + neurons = [s, n] + + g = MetaDiGraph() + add_blox!.(Ref(g), neurons) + + add_edge!(g, 1, 2, Dict(:weight => 1)) + + @named sys = system_from_graph(g) + sys_simpl = structural_simplify(sys) + prob = ODEProblem(sys_simpl, [], tspan) + sol = solve(prob, Tsit5()) + @test sol.retcode == ReturnCode.Success +end From d87dc3ca8df259f0b277625fa33b881f2c51110d Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 6 Aug 2024 10:02:20 +0300 Subject: [PATCH 37/50] move callback generator function --- src/Neurographs.jl | 30 ++++++++++++++++++++++++++++ src/blox/connections.jl | 43 ++++++++++------------------------------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 232b28a7..2ccfc5aa 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -47,6 +47,36 @@ function graph_delays(g::MetaDiGraph) return bc.delays end +function generate_discrete_callbacks(g, bc::BloxConnector; t_block=missing) + if !ismissing(t_block) + eqs_params = get_equations_with_parameter_lhs(bc) + + neurons_exci = get_exci_neurons(g) + eqs = Equation[] + + for neurons in neurons_exci + nn = get_namespaced_sys(neurons) + push!(eqs,nn.spikes_window ~ 0) + + end + if !isempty(eqs_params) && !isempty(eqs) + cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs + cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params + return vcat(cbs_params, cbs_spikes, bc.discrete_callbacks) + elseif isempty(eqs_params) && !isempty(eqs) + cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs + return vcat(cbs_spikes, bc.discrete_callbacks) + elseif !isempty(eqs_params) && isempty(eqs) + cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params + return vcat(cbs_params, bc.discrete_callbacks) + else + return bc.discrete_callbacks + end + else + return bc.discrete_callbacks + end +end + function system_from_graph(g::MetaDiGraph; name, t_block=missing) bc = connector_from_graph(g) return system_from_graph(g, bc; name, t_block) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 211478ff..4fbffb25 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -26,46 +26,23 @@ function accumulate_equation!(bc::BloxConnector, eq) bc.eqs[idx] = bc.eqs[idx].lhs ~ bc.eqs[idx].rhs + eq.rhs end -get_equations_with_parameter_lhs(bc) = filter(eq -> isparameter(eq.lhs), bc.eqs) +function accumulate_continuous_callback!(bc::BloxConnector, cb_new) + eqs_new = equations(cb_new) + cbs = get_continuous_callbacks(bc) -get_equations_with_state_lhs(bc) = filter(eq -> !isparameter(eq.lhs), bc.eqs) + for (i,c) in enumerate(cbs) + eqs = eqs(c) -function generate_discrete_callbacks(g, bc; t_block=missing) - if !ismissing(t_block) - eqs_params = get_equations_with_parameter_lhs(bc) - - neurons_exci = get_exci_neurons(g) - eqs = Equation[] - - for neurons in neurons_exci - nn = get_namespaced_sys(neurons) - push!(eqs,nn.spikes_window ~ 0) - - end - if !isempty(eqs_params) && !isempty(eqs) - cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs - cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params - return vcat(cbs_params, cbs_spikes, bc.discrete_callbacks) - elseif isempty(eqs_params) && !isempty(eqs) - cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs - return vcat(cbs_spikes, bc.discrete_callbacks) - elseif !isempty(eqs_params) && isempty(eqs) - cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params - return vcat(cbs_params, bc.discrete_callbacks) - else - return bc.discrete_callbacks + is_all_eqs_equal = reduce(&, [any(neq .== eqs) for neq in eqs_new]) + if is_all_eqs_equal + bc.continuous_callbacks[i] = eqs => 0 end - else - return bc.discrete_callbacks end end -function generate_callbacks_for_parameter_lhs(bc) - eqs = get_equations_with_parameter_lhs(bc) - cbs = [bc.param_update_times[eq.lhs] => eq for eq in eqs] +get_equations_with_parameter_lhs(bc) = filter(eq -> isparameter(eq.lhs), bc.eqs) - return cbs -end +get_equations_with_state_lhs(bc) = filter(eq -> !isparameter(eq.lhs), bc.eqs) function generate_weight_param(blox_out, blox_in; kwargs...) name_out = namespaced_nameof(blox_out) From cfe4ebbcceeebdd67d7223909d0126488a24abb2 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 6 Aug 2024 10:02:58 +0300 Subject: [PATCH 38/50] add spiking callback including refractory --- src/blox/connections.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 4fbffb25..dac84355 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -849,6 +849,22 @@ function (bc::BloxConnector)( accumulate_equation!(bc, eq) end + +function LIF_spike_affect!(integ, u, p, ctx) + integ.u[u[1]] = integ.p[p[1]] + + t_refract_end = integ.t + integ.p[p[2]] + integ.p[p[3]] = t_refract_end + + integ.p[p[4]] = 1 + + SciMLBase.add_tstop!(integ, t_refract_end) + + integ.u[u[2]] += 1 + integ.u[u[3]] += 1 +end + + function (bc::BloxConnector)( bloxout::LIFExciNeuron, bloxin::Union{LIFExciNeuron, LIFInhNeuron}; @@ -866,10 +882,21 @@ function (bc::BloxConnector)( accumulate_equation!(bc, eq) + #= cb = [sys_out.V ~ sys_out.θ] => [ sys_in.S_AMPA ~ sys_in.S_AMPA + 1, sys_in.x ~ sys_in.x + 1 ] + =# + + cb = [sys_out.V ~ sys_out.θ] => ( + LIF_spike_affect!, + [sys_out.V, sys_in.S_AMPA, sys_in.x], + [sys_out.V_reset, sys_out.t_refract_duration, sys_out.t_refract_end, sys_out.is_refractory], + [], + nothing + ) + push!(bc.continuous_callbacks, cb) end From a282670b73f7d15b06a2fcf17179613be95fb2a4 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 6 Aug 2024 10:03:50 +0300 Subject: [PATCH 39/50] add refractory period for LIFExci neuron --- src/blox/neuron_models.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index 0a928d6c..42b2fad7 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -703,12 +703,13 @@ struct LIFExciNeuron <: AbstractExciNeuronBlox τ_GABA = 5, # ms τ_NMDA_decay = 100, # ms τ_NMDA_rise = 2, # ms + t_refract = 2, # ms α = 0.5, # ms⁻¹ g_AMPA = 0.05 * 1e-3, # mS g_AMPA_external = 2.1 * 1e-3, # mS g_GABA = 1.3 * 1e-3, # mS g_NMDA = 0.165 * 1e-3, # mS - Mg = 1e-3 # mM + Mg = 1 # mM ) ps = @parameters begin @@ -723,25 +724,30 @@ struct LIFExciNeuron <: AbstractExciNeuronBlox τ_GABA=τ_GABA τ_NMDA_decay=τ_NMDA_decay τ_NMDA_rise=τ_NMDA_rise + t_refract_duration=t_refract + t_refract_end=-Inf g_AMPA = g_AMPA g_AMPA_external = g_AMPA_external g_GABA = g_GABA g_NMDA = g_NMDA α=α Mg=Mg + is_refractory=0 end sts = @variables V(t)=V_L S_AMPA(t)=0 S_GABA(t)=0 S_NMDA(t)=0 x(t)=0 jcn(t)=0 [input=true] eqs = [ - D(V) ~ - (g_L * (V - V_L) + jcn) / C, + D(V) ~ - (1 - is_refractory) * (g_L * (V - V_L) + jcn) / C, D(S_AMPA) ~ - S_AMPA / τ_AMPA, D(S_GABA) ~ - S_GABA / τ_GABA, D(S_NMDA) ~ - S_NMDA / τ_NMDA_decay + α * x * (1 - S_NMDA), D(x) ~ - x / τ_NMDA_rise ] - ev = [V ~ θ] => [V ~ V_reset] - sys = System(eqs, t, sts, ps; continuous_events=[ev], name=name) + + refract_end = (t == t_refract_end) => [is_refractory ~ 0] + + sys = System(eqs, t, sts, ps; discrete_events = [refract_end], name=name) new(sys, namespace) end From ca066f47fda2e8036b3f01b33a1e0de7dbae4bba Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 6 Aug 2024 10:04:29 +0300 Subject: [PATCH 40/50] add refractory time on LIF circuit composite blox --- src/blox/cortical.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/blox/cortical.jl b/src/blox/cortical.jl index 7ff00df0..5f5cf02b 100644 --- a/src/blox/cortical.jl +++ b/src/blox/cortical.jl @@ -90,6 +90,7 @@ struct LIFExciCircuitBlox <: CompositeBlox τ_GABA = 5, # ms τ_NMDA_decay = 100, # ms τ_NMDA_rise = 2, # ms + t_refract = 2, # ms α = 0.5, # ms⁻¹ g_AMPA = 0.05 * 1e-3, # mS g_AMPA_external = 2.1 * 1e-3, # mS @@ -114,6 +115,7 @@ struct LIFExciCircuitBlox <: CompositeBlox τ_GABA, τ_NMDA_decay, τ_NMDA_rise, + t_refract, α, g_AMPA, g_AMPA_external, From c859175489c209faf18c585c727a46466d5875c0 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 6 Aug 2024 12:25:11 +0300 Subject: [PATCH 41/50] try a new way of merging spike callbacks --- src/blox/blox_utilities.jl | 4 ++++ src/blox/connections.jl | 9 ++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index bae94371..b45d63ed 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -150,6 +150,10 @@ get_continuous_callbacks(bc::BloxConnector) = bc.continuous_callbacks get_continuous_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_continuous_callbacks ∘ get_connector)(blox) get_continuous_callbacks(blox) = [] +get_spike_affect_states(bc::BloxConnector) = bc.spike_affect_states +get_spike_affect_states(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affect_states ∘ get_connector)(blox) +get_spike_affect_states(blox) = Dict{Symbol, Num}() + get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules get_weight_learning_rules(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_learning_rules ∘ get_connector)(blox) get_weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}() diff --git a/src/blox/connections.jl b/src/blox/connections.jl index dac84355..cba998d3 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -4,9 +4,10 @@ mutable struct BloxConnector delays::Vector{Num} discrete_callbacks continuous_callbacks + spike_affect_states::Dict{Symbol, Num} learning_rules - BloxConnector() = new(Equation[], Num[], Num[], Pair{Any, Vector{Equation}}[], Dict{Num, AbstractLearningRule}()) + BloxConnector() = new(Equation[], Num[], Num[], Pair{Any, Vector{Equation}}[], Dict{Symbol, Num}(), Dict{Num, AbstractLearningRule}()) function BloxConnector(bloxs) eqs = mapreduce(get_input_equations, vcat, bloxs) @@ -14,6 +15,7 @@ mutable struct BloxConnector delays = mapreduce(get_delay_parameters, vcat, bloxs) discrete_callbacks = mapreduce(get_discrete_callbacks, vcat, bloxs) continuous_callbacks = mapreduce(get_continuous_callbacks, vcat, bloxs) + spike_affect_states = mapreduce(get_spike_affect_states, vcat, bloxs) learning_rules = mapreduce(get_weight_learning_rules, merge, bloxs) new(eqs, weights, delays, discrete_callbacks, continuous_callbacks, learning_rules) @@ -860,8 +862,9 @@ function LIF_spike_affect!(integ, u, p, ctx) SciMLBase.add_tstop!(integ, t_refract_end) - integ.u[u[2]] += 1 - integ.u[u[3]] += 1 + for ui in u[2:end] + integ.u[ui] += 1 + end end From 355ee1fa6f3633a5eab5f16b3a1048832e4fa0db Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 6 Aug 2024 12:40:46 +0300 Subject: [PATCH 42/50] rename `get_blox` -> `get_bloxs` for clarity --- src/Neurographs.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 2ccfc5aa..e8509503 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -12,7 +12,7 @@ function add_blox!(g::MetaDiGraph,blox) add_vertex!(g, :blox, blox) end -function get_blox(g::MetaDiGraph) +function get_bloxs(g::MetaDiGraph) bs = [] for v in vertices(g) b = get_prop(g, v, :blox) @@ -24,10 +24,10 @@ function get_blox(g::MetaDiGraph) return bs end -get_sys(g::MetaDiGraph) = get_sys.(get_blox(g)) +get_sys(g::MetaDiGraph) = get_sys.(get_bloxs(g)) function connector_from_graph(g::MetaDiGraph) - bloxs = get_blox(g) + bloxs = get_bloxs(g) link = BloxConnector(bloxs) for v in vertices(g) From 5e49bae3e01e998cee48c499c1e5968b8dad0a43 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 6 Aug 2024 15:17:28 +0300 Subject: [PATCH 43/50] accumulate all spike affects into a single continuous callback per LIF neuron --- src/Neurographs.jl | 39 ++++++++++++++++++++++++++++- src/blox/blox_utilities.jl | 2 +- src/blox/connections.jl | 50 +++++++++----------------------------- src/blox/neuron_models.jl | 16 +++++++++++- 4 files changed, 65 insertions(+), 42 deletions(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index e8509503..981c3a64 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -26,6 +26,11 @@ end get_sys(g::MetaDiGraph) = get_sys.(get_bloxs(g)) +get_dynamics_bloxs(blox::AbstractBlox) = blox +get_dynamics_bloxs(blox::Union{CompositeBlox, AbstractComponent}) = get_blox_parts(blox) + +flatten_graph(g::MetaDiGraph) = mapreduce(get_dynamics_bloxs, vcat, get_bloxs(g)) + function connector_from_graph(g::MetaDiGraph) bloxs = get_bloxs(g) link = BloxConnector(bloxs) @@ -77,6 +82,38 @@ function generate_discrete_callbacks(g, bc::BloxConnector; t_block=missing) end end +generate_continuous_callbacks(blox, states_dst) = [] + +function generate_continuous_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, states_dst) + sys = get_namespaced_sys(blox) + + cb = [sys.V ~ sys.θ] => ( + LIF_spike_affect!, + vcat(sys.V, states_dst), + [sys.V_reset, sys.t_refract_duration, sys.t_refract_end, sys.is_refractory], + [], + nothing + ) + + return cb +end + +function generate_continuous_callbacks(g, bc::BloxConnector) + bloxs = flatten_graph(g) + spike_affect_states = get_spike_affect_states(bc) + + cbs = [] + for blox in bloxs + name_blox = namespaced_nameof(blox) + + if haskey(spike_affect_states, name_blox) + push!(cbs, generate_continuous_callbacks(blox, spike_affect_states[name_blox])) + end + end + + return reduce(vcat, identity.(cbs)) +end + function system_from_graph(g::MetaDiGraph; name, t_block=missing) bc = connector_from_graph(g) return system_from_graph(g, bc; name, t_block) @@ -93,7 +130,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_block=miss connection_eqs = get_equations_with_state_lhs(bc) discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) - continuous_cbs = identity.(get_continuous_callbacks(bc)) + continuous_cbs = identity.(generate_continuous_callbacks(g, bc)) return compose(System(connection_eqs, t, [], params(bc); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) end diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index b45d63ed..5ddf5b5d 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -152,7 +152,7 @@ get_continuous_callbacks(blox) = [] get_spike_affect_states(bc::BloxConnector) = bc.spike_affect_states get_spike_affect_states(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affect_states ∘ get_connector)(blox) -get_spike_affect_states(blox) = Dict{Symbol, Num}() +get_spike_affect_states(blox) = Dict{Symbol, Vector{Num}}() get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules get_weight_learning_rules(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_learning_rules ∘ get_connector)(blox) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index cba998d3..69c7633a 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -4,10 +4,10 @@ mutable struct BloxConnector delays::Vector{Num} discrete_callbacks continuous_callbacks - spike_affect_states::Dict{Symbol, Num} + spike_affect_states::Dict{Symbol, Vector{Num}} learning_rules - BloxConnector() = new(Equation[], Num[], Num[], Pair{Any, Vector{Equation}}[], Dict{Symbol, Num}(), Dict{Num, AbstractLearningRule}()) + BloxConnector() = new(Equation[], Num[], Num[], Pair{Any, Vector{Equation}}[], Dict{Symbol, Vector{Num}}(), Dict{Num, AbstractLearningRule}()) function BloxConnector(bloxs) eqs = mapreduce(get_input_equations, vcat, bloxs) @@ -15,10 +15,10 @@ mutable struct BloxConnector delays = mapreduce(get_delay_parameters, vcat, bloxs) discrete_callbacks = mapreduce(get_discrete_callbacks, vcat, bloxs) continuous_callbacks = mapreduce(get_continuous_callbacks, vcat, bloxs) - spike_affect_states = mapreduce(get_spike_affect_states, vcat, bloxs) + spike_affect_states = mapreduce(get_spike_affect_states, merge, bloxs) learning_rules = mapreduce(get_weight_learning_rules, merge, bloxs) - new(eqs, weights, delays, discrete_callbacks, continuous_callbacks, learning_rules) + new(eqs, weights, delays, discrete_callbacks, continuous_callbacks, spike_affect_states, learning_rules) end end @@ -28,17 +28,11 @@ function accumulate_equation!(bc::BloxConnector, eq) bc.eqs[idx] = bc.eqs[idx].lhs ~ bc.eqs[idx].rhs + eq.rhs end -function accumulate_continuous_callback!(bc::BloxConnector, cb_new) - eqs_new = equations(cb_new) - cbs = get_continuous_callbacks(bc) - - for (i,c) in enumerate(cbs) - eqs = eqs(c) - - is_all_eqs_equal = reduce(&, [any(neq .== eqs) for neq in eqs_new]) - if is_all_eqs_equal - bc.continuous_callbacks[i] = eqs => 0 - end +function accumulate_spike_affect_states!(bc::BloxConnector, name_blox_src, states_dst) + if haskey(bc.spike_affect_states, name_blox_src) + append!(bc.spike_affect_states[name_blox_src], states_dst) + else + bc.spike_affect_states[name_blox_src] = states_dst end end @@ -851,23 +845,6 @@ function (bc::BloxConnector)( accumulate_equation!(bc, eq) end - -function LIF_spike_affect!(integ, u, p, ctx) - integ.u[u[1]] = integ.p[p[1]] - - t_refract_end = integ.t + integ.p[p[2]] - integ.p[p[3]] = t_refract_end - - integ.p[p[4]] = 1 - - SciMLBase.add_tstop!(integ, t_refract_end) - - for ui in u[2:end] - integ.u[ui] += 1 - end -end - - function (bc::BloxConnector)( bloxout::LIFExciNeuron, bloxin::Union{LIFExciNeuron, LIFInhNeuron}; @@ -884,14 +861,9 @@ function (bc::BloxConnector)( (1 + sys_in.Mg * exp(-0.062 * sys_in.V) / 3.57) accumulate_equation!(bc, eq) - - #= - cb = [sys_out.V ~ sys_out.θ] => [ - sys_in.S_AMPA ~ sys_in.S_AMPA + 1, - sys_in.x ~ sys_in.x + 1 - ] - =# + accumulate_spike_affect_states!(bc, nameof(sys_out), [sys_in.S_AMPA, sys_in.x]) + cb = [sys_out.V ~ sys_out.θ] => ( LIF_spike_affect!, [sys_out.V, sys_in.S_AMPA, sys_in.x], diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index 42b2fad7..552f3529 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -624,6 +624,21 @@ struct LIFNeuron <: AbstractNeuronBlox end end +function LIF_spike_affect!(integ, u, p, ctx) + integ.u[u[1]] = integ.p[p[1]] + + t_refract_end = integ.t + integ.p[p[2]] + integ.p[p[3]] = t_refract_end + + integ.p[p[4]] = 1 + + SciMLBase.add_tstop!(integ, t_refract_end) + + for i in eachindex(u)[2:end] + integ.u[u[i]] += 1 + end +end + struct LIFInhNeuron <: AbstractInhNeuronBlox odesystem namespace @@ -744,7 +759,6 @@ struct LIFExciNeuron <: AbstractExciNeuronBlox D(x) ~ - x / τ_NMDA_rise ] - refract_end = (t == t_refract_end) => [is_refractory ~ 0] sys = System(eqs, t, sts, ps; discrete_events = [refract_end], name=name) From 6ea55f8c73cec2a924557db1990f4e41a3dda68c Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 6 Aug 2024 16:45:50 +0300 Subject: [PATCH 44/50] update continuous callback generation before `compose` --- src/Neurographs.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 981c3a64..8b037fef 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -111,7 +111,7 @@ function generate_continuous_callbacks(g, bc::BloxConnector) end end - return reduce(vcat, identity.(cbs)) + return reduce(vcat, cbs; init=[]) end function system_from_graph(g::MetaDiGraph; name, t_block=missing) @@ -140,7 +140,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; na connection_eqs = get_equations_with_state_lhs(bc) discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) - continuous_cbs = identity.(get_continuous_callbacks(bc)) + continuous_cbs = identity.(generate_continuous_callbacks(g, bc)) return compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss) end From 2dea7212c3e73779a5112ece85ad92252ee58b71 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 6 Aug 2024 16:45:59 +0300 Subject: [PATCH 45/50] fix typo --- src/blox/blox_utilities.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 5ddf5b5d..f350d8cb 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -21,7 +21,7 @@ get_exci_neurons(n::AbstractExciNeuronBlox) = n get_exci_neurons(n) = [] function get_exci_neurons(g::MetaDiGraph) - mapreduce(x -> get_exci_neurons(x), vcat, get_blox(g)) + mapreduce(x -> get_exci_neurons(x), vcat, get_bloxs(g)) end function get_exci_neurons(b::Union{AbstractComponent, CompositeBlox}) From 63fb054a7b353f16d3b500512b5a2e1c48c5e798 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 7 Aug 2024 14:26:22 +0300 Subject: [PATCH 46/50] always return a vector of bloxs when flattening graph --- src/Neurographs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 8b037fef..1007d92d 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -26,7 +26,7 @@ end get_sys(g::MetaDiGraph) = get_sys.(get_bloxs(g)) -get_dynamics_bloxs(blox::AbstractBlox) = blox +get_dynamics_bloxs(blox) = [blox] get_dynamics_bloxs(blox::Union{CompositeBlox, AbstractComponent}) = get_blox_parts(blox) flatten_graph(g::MetaDiGraph) = mapreduce(get_dynamics_bloxs, vcat, get_bloxs(g)) From da848e849962cdf1b426b7fae12c0e82519afa8b Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 7 Aug 2024 14:26:46 +0300 Subject: [PATCH 47/50] add utility to wrap non-vector elements in vector --- src/blox/blox_utilities.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index f350d8cb..b1cd22fc 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -355,3 +355,6 @@ function get_connection_rule(kwargs, bloxout, bloxin, w) return rhs end + +to_vector(v::AbstractVector) = v +to_vector(v) = [v] From 53e19024be6972aa1769d0a6363cbbd3b4305726 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 7 Aug 2024 14:32:20 +0300 Subject: [PATCH 48/50] improve stepping in `PoissonSpikeTrain` and allow for changing spike rates with matching tspans --- src/blox/sources.jl | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/blox/sources.jl b/src/blox/sources.jl index 9527ccd5..d8febf44 100644 --- a/src/blox/sources.jl +++ b/src/blox/sources.jl @@ -159,21 +159,33 @@ struct PoissonSpikeTrain <: StimulusBlox name namespace rate - dt tspan + prob_dt rng - function PoissonSpikeTrain(; name, namespace, rate, tspan, dt=0.5/rate, rng = MersenneTwister(1234)) - new(name, namespace, rate, dt, tspan, rng) + function PoissonSpikeTrain(; name, namespace, rate, tspan, prob_dt = 0.01, rng = MersenneTwister(1234)) + rate = to_vector(rate) + tspan = to_vector(tspan) + + @assert length(rate) == length(tspan) "The number of Poisson rates need to match the number of tspan intervals." + + new(name, namespace, rate, tspan, prob_dt, rng) end end function generate_spike_times(stim::PoissonSpikeTrain) # This could also change to a dispatch of Random.rand() t_spikes = Float64[] - for t in range(stim.tspan...; step = stim.dt) - if rand(stim.rng) < stim.rate * stim.dt - push!(t_spikes, t) + for (rate, tspan) in zip(stim.rate, stim.tspan) + # The dt step is determined by the CDF of the Exponential distribution. + # The Exponential is the distribution of the inter-event times for Poisson-distributed events. + # `prob_dt` determines the probability so that `P_CDF_Exponential(dt) = prob_dt` , and then we solve for dt. + # This way we make sure that with probability `1 - prob_dt` there won't be any events within a single dt step. + dt = - log(1 - stim.prob_dt) / (1 / rate) + for t in range(tspan...; step = dt) + if rand(stim.rng) < rate * dt + push!(t_spikes, t) + end end end From 9dbf9ea830cfd91ea049fb89a0524e94778ad40a Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 7 Aug 2024 18:43:36 +0300 Subject: [PATCH 49/50] fix typo in `LIFInhCircuitBlox` --- src/blox/cortical.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blox/cortical.jl b/src/blox/cortical.jl index 5f5cf02b..8b50c694 100644 --- a/src/blox/cortical.jl +++ b/src/blox/cortical.jl @@ -207,7 +207,7 @@ struct LIFInhCircuitBlox <: CompositeBlox bc = connector_from_graph(g) - sys = isnothing(namespace) ? system_from_graph(g, bc; name) : system_from_parts(vcat(wtas, n_ff_inh); name) + sys = isnothing(namespace) ? system_from_graph(g, bc; name) : system_from_parts(neurons; name) new(namespace, neurons, sys, bc) end From 2beaabb5c6e355c533e74ecd198cb39451a64f25 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Wed, 7 Aug 2024 18:44:01 +0300 Subject: [PATCH 50/50] add connection rule for LIF circuit to circuit --- src/blox/connections.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 69c7633a..22b3bde9 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -920,6 +920,21 @@ function (bc::BloxConnector)( push!(bc.discrete_callbacks, cb) end +function (bc::BloxConnector)( + bloxout::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}, + bloxin::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}; + kwargs... +) + neurons_out = get_neurons(bloxout) + neurons_in = get_neurons(bloxin) + + for neuron_out in neurons_out + for neuron_in in neurons_in + bc(neuron_out, neuron_in; kwargs...) + end + end +end + function (bc::BloxConnector)( stim::PoissonSpikeTrain, cb::Union{LIFExciCircuitBlox, LIFInhCircuitBlox};