Skip to content

Commit

Permalink
Merge pull request #370 from Neuroblox/ho/decision_making_components
Browse files Browse the repository at this point in the history
Add features for Drift-Diffusion tutorial
  • Loading branch information
harisorgn authored Aug 7, 2024
2 parents 5ccbcdf + 2beaabb commit 9c2cd12
Show file tree
Hide file tree
Showing 10 changed files with 727 additions and 411 deletions.
8 changes: 4 additions & 4 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ include("measurementmodels/fmri.jl")
include("measurementmodels/lfp.jl")
include("datafitting/spectralDCM.jl")
include("blox/neural_mass.jl")
include("blox/cortical_blox.jl")
include("blox/cortical.jl")
include("blox/canonicalmicrocircuit.jl")
include("blox/neuron_models.jl")
include("blox/DBS_Model_Blox_Adam_Brown.jl")
Expand Down Expand Up @@ -187,15 +187,15 @@ end

export JansenRitSPM12, next_generation, qif_neuron, if_neuron, hh_neuron_excitatory,
hh_neuron_inhibitory, van_der_pol, Generic2dOscillator
export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron,
export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, LIFExciNeuron, LIFInhNeuron,
CanonicalMicroCircuitBlox, WinnerTakeAllBlox, CorticalBlox, SuperCortical, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox,
HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam
HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam, LIFExciCircuitBlox, LIFInhCircuitBlox
export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear, NextGenerationBlox, NextGenerationResolvedBlox, NextGenerationEIBlox, KuramotoOscillator
export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc
export HebbianPlasticity, HebbianModulationPlasticity
export Agent, ClassificationEnvironment, GreedyPolicy, reset!
export LearningBlox
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain
export PowerSpectrumBlox, BandPassFilterBlox
export OUBlox, OUCouplingBlox
export phase_inter, phase_sin_blox, phase_cos_blox
Expand Down
88 changes: 79 additions & 9 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function add_blox!(g::MetaDiGraph,blox)
add_vertex!(g, :blox, blox)
end

function get_blox(g::MetaDiGraph)
function get_bloxs(g::MetaDiGraph)
bs = []
for v in vertices(g)
b = get_prop(g, v, :blox)
Expand All @@ -24,10 +24,15 @@ function get_blox(g::MetaDiGraph)
return bs
end

get_sys(g::MetaDiGraph) = get_sys.(get_blox(g))
get_sys(g::MetaDiGraph) = get_sys.(get_bloxs(g))

get_dynamics_bloxs(blox) = [blox]
get_dynamics_bloxs(blox::Union{CompositeBlox, AbstractComponent}) = get_blox_parts(blox)

flatten_graph(g::MetaDiGraph) = mapreduce(get_dynamics_bloxs, vcat, get_bloxs(g))

function connector_from_graph(g::MetaDiGraph)
bloxs = get_blox(g)
bloxs = get_bloxs(g)
link = BloxConnector(bloxs)

for v in vertices(g)
Expand All @@ -47,6 +52,68 @@ function graph_delays(g::MetaDiGraph)
return bc.delays
end

function generate_discrete_callbacks(g, bc::BloxConnector; t_block=missing)
if !ismissing(t_block)
eqs_params = get_equations_with_parameter_lhs(bc)

neurons_exci = get_exci_neurons(g)
eqs = Equation[]

for neurons in neurons_exci
nn = get_namespaced_sys(neurons)
push!(eqs,nn.spikes_window ~ 0)

end
if !isempty(eqs_params) && !isempty(eqs)
cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs
cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params
return vcat(cbs_params, cbs_spikes, bc.discrete_callbacks)
elseif isempty(eqs_params) && !isempty(eqs)
cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs
return vcat(cbs_spikes, bc.discrete_callbacks)
elseif !isempty(eqs_params) && isempty(eqs)
cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params
return vcat(cbs_params, bc.discrete_callbacks)
else
return bc.discrete_callbacks
end
else
return bc.discrete_callbacks
end
end

generate_continuous_callbacks(blox, states_dst) = []

function generate_continuous_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, states_dst)
sys = get_namespaced_sys(blox)

cb = [sys.V ~ sys.θ] => (
LIF_spike_affect!,
vcat(sys.V, states_dst),
[sys.V_reset, sys.t_refract_duration, sys.t_refract_end, sys.is_refractory],
[],
nothing
)

return cb
end

function generate_continuous_callbacks(g, bc::BloxConnector)
bloxs = flatten_graph(g)
spike_affect_states = get_spike_affect_states(bc)

cbs = []
for blox in bloxs
name_blox = namespaced_nameof(blox)

if haskey(spike_affect_states, name_blox)
push!(cbs, generate_continuous_callbacks(blox, spike_affect_states[name_blox]))
end
end

return reduce(vcat, cbs; init=[])
end

function system_from_graph(g::MetaDiGraph; name, t_block=missing)
bc = connector_from_graph(g)
return system_from_graph(g, bc; name, t_block)
Expand All @@ -62,17 +129,20 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_block=miss
blox_syss = get_sys(g)
connection_eqs = get_equations_with_state_lhs(bc)

cbs = identity.(get_callbacks(g, bc; t_block))
return compose(System(connection_eqs, t, [], params(bc); name, discrete_events = cbs), blox_syss)
discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block))
continuous_cbs = identity.(generate_continuous_callbacks(g, bc))

return compose(System(connection_eqs, t, [], params(bc); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss)
end

function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; name, t_block=missing)

blox_syss = get_sys(g)

connection_eqs = get_equations_with_state_lhs(bc)
cbs = identity.(get_callbacks(g, bc; t_block))
return compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = cbs), blox_syss)

discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block))
continuous_cbs = identity.(generate_continuous_callbacks(g, bc))

return compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs, continuous_events = continuous_cbs), blox_syss)
end

function system_from_parts(parts::AbstractVector; name)
Expand Down
81 changes: 42 additions & 39 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,33 @@ function paramscoping(;kwargs...)
return paramlist
end

get_exci_neurons(n::AbstractExciNeuronBlox) = n
get_exci_neurons(n) = []

function get_exci_neurons(g::MetaDiGraph)
mapreduce(x -> get_exci_neurons(x), vcat, get_blox(g))
mapreduce(x -> get_exci_neurons(x), vcat, get_bloxs(g))
end

function get_exci_neurons(b::AbstractComponent)
function get_exci_neurons(b::Union{AbstractComponent, CompositeBlox})
mapreduce(x -> get_exci_neurons(x), vcat, b.parts)
end

function get_inh_neurons(b::AbstractComponent)
mapreduce(x -> get_inh_neurons(x), vcat, b.parts)
end

function get_discrete_parts(b::AbstractComponent)
mapreduce(x -> get_discrete_parts(x), vcat, b.parts)
end

function get_exci_neurons(b::CompositeBlox)
mapreduce(x -> get_exci_neurons(x), vcat, b.parts)
end
get_inh_neurons(n::AbstractInhNeuronBlox) = n
get_inh_neurons(n) = []

function get_inh_neurons(b::CompositeBlox)
function get_inh_neurons(b::Union{AbstractComponent, CompositeBlox})
mapreduce(x -> get_inh_neurons(x), vcat, b.parts)
end

function get_discrete_parts(b::CompositeBlox)
get_neurons(b::Union{AbstractComponent, CompositeBlox}) = vcat(get_exci_neurons(b), get_inh_neurons(b))

function get_discrete_parts(b::Union{AbstractComponent, CompositeBlox})
mapreduce(x -> get_discrete_parts(x), vcat, b.parts)
end

get_exci_neurons(n::AbstractExciNeuronBlox) = n
get_exci_neurons(n) = []

get_inh_neurons(n::AbstractInhNeuronBlox) = n
get_inh_neurons(n) = []

get_sys(blox) = blox.odesystem
get_sys(sys::AbstractODESystem) = sys
get_sys(stim::PoissonSpikeTrain) = System(Equation[], t, [], []; name=stim.name)

function get_namespaced_sys(blox)
sys = get_sys(blox)
Expand Down Expand Up @@ -110,7 +101,7 @@ end
which holds a `BloxConnector` object with all relevant connections
from lower levels and this level.
"""
function input_equations(blox)
function get_input_equations(blox::Union{AbstractBlox, ObserverBlox})
sys = get_sys(blox)
inps = inputs(sys)
sys_eqs = equations(sys)
Expand All @@ -137,28 +128,37 @@ function input_equations(blox)
return eqs
end

input_equations(blox::AbstractComponent) = blox.connector.eqs
input_equations(blox::CompositeBlox) = blox.connector.eqs
input_equations(::ImageStimulus) = []
get_connector(blox::Union{CompositeBlox, AbstractComponent}) = blox.connector

get_input_equations(bc::BloxConnector) = bc.eqs
get_input_equations(blox::Union{CompositeBlox, AbstractComponent}) = (get_input_equations get_connector)(blox)
get_input_equations(blox) = []

weight_parameters(blox) = Num[]
weight_parameters(blox::AbstractComponent) = blox.connector.weights #I think this is the fix?
weight_parameters(blox::CompositeBlox) = blox.connector.weights #I think this is the fix?
get_weight_parameters(bc::BloxConnector) = bc.weights
get_weight_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_parameters get_connector)(blox)
get_weight_parameters(blox) = Num[]

delay_parameters(blox) = Num[]
delay_parameters(blox::AbstractComponent) = blox.connector.delays
delay_parameters(blox::CompositeBlox) = blox.connector.delays
get_delay_parameters(bc::BloxConnector) = bc.delays
get_delay_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_delay_parameters get_connector)(blox)
get_delay_parameters(blox) = Num[]

event_callbacks(blox) = []
event_callbacks(blox::AbstractComponent) = blox.connector.events
event_callbacks(blox::CompositeBlox) = blox.connector.events
get_discrete_callbacks(bc::BloxConnector) = bc.discrete_callbacks
get_discrete_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_discrete_callbacks get_connector)(blox)
get_discrete_callbacks(blox) = []

weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}()
weight_learning_rules(bc::BloxConnector) = bc.learning_rules
weight_learning_rules(blox::AbstractComponent) = weight_learning_rules(blox.connector)
weight_learning_rules(blox::CompositeBlox) = weight_learning_rules(blox.connector)
get_continuous_callbacks(bc::BloxConnector) = bc.continuous_callbacks
get_continuous_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_continuous_callbacks get_connector)(blox)
get_continuous_callbacks(blox) = []

get_blox_parts(blox) = blox.parts
get_spike_affect_states(bc::BloxConnector) = bc.spike_affect_states
get_spike_affect_states(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affect_states get_connector)(blox)
get_spike_affect_states(blox) = Dict{Symbol, Vector{Num}}()

get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules
get_weight_learning_rules(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_learning_rules get_connector)(blox)
get_weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}()

get_blox_parts(blox::Union{CompositeBlox, AbstractComponent}) = blox.parts

function get_weight(kwargs, name_blox1, name_blox2)
if haskey(kwargs, :weight)
Expand Down Expand Up @@ -355,3 +355,6 @@ function get_connection_rule(kwargs, bloxout, bloxin, w)

return rhs
end

to_vector(v::AbstractVector) = v
to_vector(v) = [v]
Loading

0 comments on commit 9c2cd12

Please sign in to comment.