Implemented linear approximation method; seems to work in the simple ra25 test case. Will need to test in objrec next (which has different params -- that's a challenge).
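In brief, from the diff below: each neuron's spiking is accumulated into four 50-cycle bins (SpkBin0..SpkBin3) over the theta cycle, DWtSynCortex forms scaled products of the send and recv bins for each synapse, and KinaseCa.FinalCa (not rendered in this view) apparently maps those binned products linearly onto the kinase Ca values (CaM, CaP, CaD). The new kinase/linear.go fits that linear mapping by regression, using the added github.com/sajari/regression dependency.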
rcoreilly committed Jun 12, 2024
1 parent 15fed25 commit b8cac23
Showing 33 changed files with 164 additions and 62 deletions.
10 changes: 5 additions & 5 deletions axon/enumgen.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions axon/gpu_hlsl/gpu_synca.hlsl
@@ -62,7 +62,7 @@ void SynCaSendPath(in Context ctx, in PathParams pj, in LayerParams ly, uint ni,
return;
}

float snCaSyn = pj.Learn.KinaseCa.SpikeG * NrnV(ctx, ni, di, CaSyn);
float snCaSyn = pj.Learn.KinaseCa.CaScale * NrnV(ctx, ni, di, CaSyn);
uint cni = pj.Indexes.SendConSt + lni;
uint synst = pj.Indexes.SynapseSt + SendCon[cni].Start;
uint synn = SendCon[cni].N;
@@ -80,7 +80,7 @@ void SynCaRecvPath(in Context ctx, in PathParams pj, in LayerParams ly, uint ni,
return;
}

float rnCaSyn = pj.Learn.KinaseCa.SpikeG * NrnV(ctx, ni, di, CaSyn);
float rnCaSyn = pj.Learn.KinaseCa.CaScale * NrnV(ctx, ni, di, CaSyn);
uint cni = pj.Indexes.RecvConSt + lni;
uint synst = pj.Indexes.RecvSynSt + RecvCon[cni].Start;
uint synn = RecvCon[cni].N;
17 changes: 17 additions & 0 deletions axon/layerparams.go
@@ -715,6 +715,18 @@ func (ly *LayerParams) SpikeFromG(ctx *Context, ni, di uint32, lpl *Pool) {
SetNrnV(ctx, ni, di, SpkMax, spkmax)
}
}
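// accumulate this cycle's spike into one of four 50-cycle bins over the theta cycle (e.g., cycle 137 falls in bin 2); these bins feed the linear approximation of synaptic Ca in DWtSynCortex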
bin := ctx.Cycle / 50
spk := NrnV(ctx, ni, di, Spike)
switch bin {
case 0:
AddNrnV(ctx, ni, di, SpkBin0, spk)
case 1:
AddNrnV(ctx, ni, di, SpkBin1, spk)
case 2:
AddNrnV(ctx, ni, di, SpkBin2, spk)
case 3:
AddNrnV(ctx, ni, di, SpkBin3, spk)
}
}

// PostSpikeSpecial does updates at neuron level after spiking has been computed.
@@ -973,6 +985,11 @@ func (ly *LayerParams) NewStateNeuron(ctx *Context, ni, di uint32, vals *LayerVa
SetNrnV(ctx, ni, di, SpkMax, 0)
SetNrnV(ctx, ni, di, SpkMaxCa, 0)

SetNrnV(ctx, ni, di, SpkBin0, 0)
SetNrnV(ctx, ni, di, SpkBin1, 0)
SetNrnV(ctx, ni, di, SpkBin2, 0)
SetNrnV(ctx, ni, di, SpkBin3, 0)

ly.Acts.DecayState(ctx, ni, di, ly.Acts.Decay.Act, ly.Acts.Decay.Glong, ly.Acts.Decay.AHP)
// Note: synapse-level Ca decay happens in DWt
ly.Acts.KNaNewState(ctx, ni, di)
11 changes: 8 additions & 3 deletions axon/neuron.go
@@ -94,6 +94,9 @@ const (
/////////////////////////////////////////
// Calcium for learning

// CaSpkM is spike-driven calcium trace used as a neuron-level proxy for synaptic credit assignment factor based on continuous time-integrated spiking: exponential integration of SpikeG * Spike at MTau time constant (typically 5). Simulates a calmodulin (CaM) like signal at the most abstract level.
CaSpkM

// CaSpkP is continuous cascaded integration of CaSpkM at PTau time constant (typically 40), representing neuron-level purely spiking version of plus, LTP direction of weight change and capturing the function of CaMKII in the Kinase learning rule. Used for specialized learning and computational functions, statistics, instead of Act.
CaSpkP

@@ -103,9 +106,6 @@ const (
// CaSyn is spike-driven calcium trace for synapse-level Ca-driven learning: exponential integration of SpikeG * Spike at SynTau time constant (typically 30). Synapses integrate send.CaSyn * recv.CaSyn across M, P, D time integrals for the synaptic trace driving credit assignment in learning. Time constant reflects binding time of Glu to NMDA and Ca buffering postsynaptically, and determines time window where pre * post spiking must overlap to drive learning.
CaSyn

// CaSpkM is spike-driven calcium trace used as a neuron-level proxy for synaptic credit assignment factor based on continuous time-integrated spiking: exponential integration of SpikeG * Spike at MTau time constant (typically 5). Simulates a calmodulin (CaM) like signal at the most abstract level.
CaSpkM

// CaSpkPM is minus-phase snapshot of the CaSpkP value -- similar to ActM but using a more directly spike-integrated value.
CaSpkPM

@@ -133,6 +133,11 @@
// SpkMaxCa is Ca integrated like CaSpkP but only starting at MaxCycStart cycle, to prevent inclusion of carryover spiking from prior theta cycle trial -- the PTau time constant otherwise results in significant carryover. This is the input to SpkMax
SpkMaxCa

// SpkBin0, SpkBin1, SpkBin2, SpkBin3 are counts of spikes accumulated over the four 50-cycle bins of the theta cycle (updated in SpikeFromG, reset in NewStateNeuron), used for the linear approximation of synaptic Ca.
SpkBin0
SpkBin1
SpkBin2
SpkBin3

// SpkMax is maximum CaSpkP across one theta cycle time window (max of SpkMaxCa) -- used for specialized algorithms that have more phasic behavior within a single trial, e.g., BG Matrix layer gating. Also useful for visualization of peak activity of neurons.
SpkMax

24 changes: 22 additions & 2 deletions axon/pathparams.go
@@ -344,8 +344,28 @@ func (pj *PathParams) DWtSynCortex(ctx *Context, syni, si, ri, di uint32, layPoo
syCaP := SynCaV(ctx, syni, di, CaP) // slower but still fast time scale, drives Potentiation
syCaD := SynCaV(ctx, syni, di, CaD) // slow time scale, drives Depression (one trial = 200 cycles)
pj.Learn.KinaseCa.CurCa(ctx.SynCaCtr, caUpT, &syCaM, &syCaP, &syCaD) // always update, getting current Ca (just optimization)
dtr := syCaD // delta trace, caD reflects entire window
if pj.PathType == CTCtxtPath { // layer 6 CT pathway

rb0 := NrnV(ctx, ri, di, SpkBin0)
sb0 := NrnV(ctx, si, di, SpkBin0)
rb1 := NrnV(ctx, ri, di, SpkBin1)
sb1 := NrnV(ctx, si, di, SpkBin1)
rb2 := NrnV(ctx, ri, di, SpkBin2)
sb2 := NrnV(ctx, si, di, SpkBin2)
rb3 := NrnV(ctx, ri, di, SpkBin3)
sb3 := NrnV(ctx, si, di, SpkBin3)

b0 := 0.1 * (rb0 * sb0)
b1 := 0.1 * (rb1 * sb1)
b2 := 0.1 * (rb2 * sb2)
b3 := 0.1 * (rb3 * sb3)

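// apply the regression-based linear approximation: map the four binned spike products onto the kinase Ca values (a sketch of FinalCa follows this hunk)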
pj.Learn.KinaseCa.FinalCa(b0, b1, b2, b3, &syCaM, &syCaP, &syCaD)

SetSynCaV(ctx, syni, di, CaM, syCaM)
SetSynCaV(ctx, syni, di, CaP, syCaP)
SetSynCaV(ctx, syni, di, CaD, syCaD)
dtr := syCaD // delta trace, caD reflects entire window
if pj.PathType == CTCtxtPath { // layer 6 CT pathway
dtr = NrnV(ctx, si, di, BurstPrv)
}
SetSynCaV(ctx, syni, di, DTr, dtr) // save delta trace for GUI
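FinalCa itself is not rendered in this commit view. As a minimal sketch of the idea only -- the name finalCaSketch and all coefficient values below are hypothetical placeholders, not the actual implementation or fitted values -- a regression-based FinalCa could look like:

// finalCaSketch predicts each Ca value as an intercept plus a weighted sum of
// the four binned send*recv spike products, mirroring the formula printed by
// Linear.Regress (intercept first, then one coefficient per bin).
func finalCaSketch(b0, b1, b2, b3 float32, caM, caP, caD *float32) {
	bins := [4]float32{b0, b1, b2, b3}
	coef := [3][5]float32{ // rows: CaM, CaP, CaD; cols: intercept, w0..w3
		{0, 0.1, 0.2, 0.3, 0.4}, // placeholder values only
		{0, 0.1, 0.2, 0.3, 0.4}, // placeholder values only
		{0, 0.1, 0.2, 0.3, 0.4}, // placeholder values only
	}
	for vi, out := range [3]*float32{caM, caP, caD} {
		v := coef[vi][0]
		for bi, b := range bins {
			v += coef[vi][bi+1] * b
		}
		*out = v
	}
}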
4 changes: 3 additions & 1 deletion axon/rand.go
@@ -2,6 +2,7 @@ package axon

import (
"cogentcore.org/core/vgpu/gosl/slrand"
"cogentcore.org/core/vgpu/gosl/sltype"
)

//gosl:hlsl axonrand
@@ -31,7 +32,8 @@ func GetRandomNumber(index uint32, counter slrand.Counter, funIndex RandFunIndex
var randCtr slrand.Counter
randCtr = counter
randCtr.Add(uint32(funIndex))
ctr := randCtr.Uint2()
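// declaring ctr with an explicit sltype.Uint2 type below (rather than := inference) is presumably needed to keep the gosl Go-to-HLSL translation explicit about ctr's type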
var ctr sltype.Uint2
ctr = randCtr.Uint2()
return slrand.Float(&ctr, index)
}

Binary file modified axon/shaders/gpu_applyext.spv
Binary file modified axon/shaders/gpu_betweengi.spv
Binary file modified axon/shaders/gpu_cycle.spv
Binary file modified axon/shaders/gpu_cyclepost.spv
Binary file modified axon/shaders/gpu_dwt.spv
Binary file modified axon/shaders/gpu_dwtfmdi.spv
Binary file modified axon/shaders/gpu_dwtsubmean.spv
Binary file modified axon/shaders/gpu_gather.spv
Binary file modified axon/shaders/gpu_laygi.spv
Binary file modified axon/shaders/gpu_minusneuron.spv
Binary file modified axon/shaders/gpu_minuspool.spv
Binary file modified axon/shaders/gpu_newstate_neuron.spv
Binary file modified axon/shaders/gpu_newstate_pool.spv
Binary file modified axon/shaders/gpu_plusneuron.spv
Binary file modified axon/shaders/gpu_plusstart.spv
Binary file modified axon/shaders/gpu_poolgi.spv
Binary file modified axon/shaders/gpu_postspike.spv
Binary file modified axon/shaders/gpu_sendspike.spv
Binary file modified axon/shaders/gpu_synca.spv
Binary file modified axon/shaders/gpu_wtfmdwt.spv
2 changes: 0 additions & 2 deletions axon/typegen.go

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions go.mod
@@ -35,6 +35,7 @@ require (
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.1.2-0.20240227203013-2b69615b5d55 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/sajari/regression v1.0.1 // indirect
golang.org/x/image v0.15.0 // indirect
golang.org/x/mod v0.16.0 // indirect
golang.org/x/net v0.23.0 // indirect
2 changes: 2 additions & 0 deletions go.sum
@@ -80,6 +80,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/sajari/regression v1.0.1 h1:iTVc6ZACGCkoXC+8NdqH5tIreslDTT/bXxT6OmHR5PE=
github.com/sajari/regression v1.0.1/go.mod h1:NeG/XTW1lYfGY7YV/Z0nYDV/RGh3wxwd1yW46835flM=
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
84 changes: 38 additions & 46 deletions kinase/linear.go
@@ -8,26 +8,10 @@ import (
"fmt"
"math/rand"

"cogentcore.org/core/tensor"
"cogentcore.org/core/tensor/table"
"github.com/chewxy/math32"
)

type LinearState int

const (
StartCaSyn LinearState = iota
StartCaM
StartCaP
StartCaD

FinalCaSyn
FinalCaM
FinalCaP
FinalCaD

TotalSpikes

NLinearState
"github.com/sajari/regression"
)

// Linear performs a linear regression to approximate the synaptic Ca
@@ -39,9 +23,6 @@ type Linear struct {
// Kinase Synapse params
Synapse SynCaParams

// gain on S*R product for BinnedSums
BinProd float32 `default:"10"`

// total number of cycles (1 MSec) to run
NCycles int `min:"10" default:"200"`

@@ -91,13 +72,12 @@
func (ls *Linear) Defaults() {
ls.Neuron.Defaults()
ls.Synapse.Defaults()
ls.BinProd = 10
ls.NCycles = 200
ls.PlusCycles = 50
ls.CyclesPerBin = 25
ls.MaxHz = 120
ls.CyclesPerBin = 50
ls.MaxHz = 100
ls.StepHz = 10
ls.NTrials = 10
ls.NTrials = 2
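// with NCycles = 200 and CyclesPerBin = 50, there are presumably 4 bins (NumBins, computed in Update), matching the four SpkBin neuron variables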
ls.Update()
}

@@ -124,7 +104,7 @@ func (ls *Linear) InitTable() {
if ls.Data.NumColumns() > 0 {
return
}
nneur := int(NLinearState*2) + ls.NumBins
nneur := ls.NumBins
ls.Data.AddIntColumn("Trial")
ls.Data.AddFloat64TensorColumn("Hz", []int{4}, "Send*Recv*Minus*Plus")
ls.Data.AddFloat64TensorColumn("State", []int{nneur}, "States")
@@ -154,9 +134,6 @@ type Neuron struct {
// neuron-level spike-driven Ca integration
CaSpkM, CaSpkP, CaSpkD float32

// regression variables
StartCaSyn float32

TotalSpikes float32

// binned count of spikes, for regression learning
@@ -174,7 +151,6 @@ func (kn *Neuron) Init() {
}

func (kn *Neuron) StartTrial() {
kn.StartCaSyn = kn.CaSyn
kn.TotalSpikes = 0
for i := range kn.BinnedSpikes {
kn.BinnedSpikes[i] = 0
@@ -250,14 +226,6 @@ func (ls *Linear) Run() {
}
}
}
fmt.Println("row:", row)
}

func (ls *Linear) SetNeurState(nr *Neuron, off, row int) {
ls.Data.SetTensorFloat1D("State", row, off, float64(nr.CaSyn))
ls.Data.SetTensorFloat1D("State", row, off+1, float64(nr.CaSpkM))
ls.Data.SetTensorFloat1D("State", row, off+2, float64(nr.CaSpkP))
ls.Data.SetTensorFloat1D("State", row, off+3, float64(nr.CaSpkD))
}

func (ls *Linear) SetSynState(sy *Synapse, row int) {
@@ -270,7 +238,7 @@ func (ls *Linear) SetSynState(sy *Synapse, row int) {
func (ls *Linear) SetBins(sn, rn *Neuron, off, row int) {
for i, s := range sn.BinnedSpikes {
r := rn.BinnedSpikes[i]
bs := r + s + ls.BinProd*r*s
bs := (r * s) / 10.0 // product of binned spike counts, scaled to match the 0.1*(r*s) factors in DWtSynCortex
ls.BinnedSums[i] = bs
ls.Data.SetTensorFloat1D("State", row, off+i, float64(bs))
}
@@ -286,10 +254,6 @@ func (ls *Linear) Trial(sendMinusHz, sendPlusHz, recvMinusHz, recvPlusHz float32
ls.Data.SetTensorFloat1D("Hz", row, 2, float64(recvMinusHz))
ls.Data.SetTensorFloat1D("Hz", row, 3, float64(recvPlusHz))

// capture starting
ls.SetNeurState(&ls.Send, 0, row)
ls.SetNeurState(&ls.Recv, int(NLinearState), row)

minusCycles := ls.NCycles - ls.PlusCycles
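// under the defaults this is 150 minus-phase cycles followed by 50 plus-phase (PlusCycles) cycles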

ls.StartTrial()
@@ -321,9 +285,37 @@ func (ls *Linear) Trial(sendMinusHz, sendPlusHz, recvMinusHz, recvPlusHz float32
ls.StdSyn.DWt = ls.StdSyn.CaP - ls.StdSyn.CaD

// capture final
ls.SetNeurState(&ls.Send, int(FinalCaSyn), row)
ls.SetNeurState(&ls.Recv, int(NLinearState+FinalCaSyn), row)
// ls.SetNeurState(&ls.Send, int(FinalCa), row)
// ls.SetNeurState(&ls.Recv, int(NLinearState+FinalCa), row)
// ls.Data.SetTensorFloat1D("State", row, int(TotalSpikes), float64(ls.Send.TotalSpikes))
// ls.Data.SetTensorFloat1D("State", row, int(NLinearState+TotalSpikes), float64(ls.Recv.TotalSpikes))
ls.SetSynState(&ls.StdSyn, row)

ls.SetBins(&ls.Send, &ls.Recv, int(NLinearState*2), row)
ls.SetBins(&ls.Send, &ls.Recv, 0, row)
}

// Regress runs the linear regression on the data
func (ls *Linear) Regress() {
for vi := 0; vi < 4; vi++ {
r := new(regression.Regression)
r.SetObserved("CaD")
for bi := 0; bi < ls.NumBins; bi++ {
r.SetVar(bi, fmt.Sprintf("Bin_%d", bi))
}

for row := 0; row < ls.Data.Rows; row++ {
st := ls.Data.Tensor("State", row).(*tensor.Float64)
cad := ls.Data.TensorFloat1D("StdCa", row, vi)
r.Train(regression.DataPoint(cad, st.Values))
}
r.Run()
fmt.Printf("Regression formula:\n%v\n", r.Formula)
fmt.Printf("Variance observed = %v\nVariance Predicted = %v", r.Varianceobserved, r.VariancePredicted)
fmt.Printf("\nR2 = %v\n", r.R2)
str := "{"
for ci := 0; ci <= ls.NumBins; ci++ {
str += fmt.Sprintf("%8.6g, ", r.Coeff(ci))
}
fmt.Println(str + "}")
}
}
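The coefficient list is printed as a Go composite-literal body, presumably so the fitted intercept and per-bin weights can be pasted directly into the FinalCa implementation.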
20 changes: 20 additions & 0 deletions kinase/linear_test.go
@@ -0,0 +1,20 @@
// Copyright (c) 2024, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package kinase

import (
"testing"

"cogentcore.org/core/tensor/table"
)

func TestLinear(t *testing.T) {
var ls Linear
ls.Defaults()
ls.Init()
ls.Run()
ls.Data.SaveCSV("linear_data.tsv", table.Tab, table.Headers)
ls.Regress()
}
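Running this test (e.g., go test -run TestLinear ./kinase from the repository root) writes linear_data.tsv and prints one regression formula, the observed and predicted variance, and R2 per Ca variable; the exact numbers depend on the random spike sampling.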