From 9100b57c789a8088856c60472fe73b40f55d82db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Lindqvist?= Date: Sun, 14 Jul 2024 17:20:54 +0200 Subject: [PATCH 1/9] Add Diehl and Cook (2015) small model Example cleans up and updates existing code to make it compatible with Brian 2 and Python 3. --- examples/frompapers/Diehl_Cook_2015.py | 295 +++++++++++++++++++++++++ 1 file changed, 295 insertions(+) create mode 100644 examples/frompapers/Diehl_Cook_2015.py diff --git a/examples/frompapers/Diehl_Cook_2015.py b/examples/frompapers/Diehl_Cook_2015.py new file mode 100644 index 000000000..ab812e5d5 --- /dev/null +++ b/examples/frompapers/Diehl_Cook_2015.py @@ -0,0 +1,295 @@ +""" +Unsupervised learning using STDP +-------------------------------- +Diehl, P. U., & Cook, M. (2015). Unsupervised learning of digit +recognition using spike-timing-dependent plasticity. Frontiers in +computational neuroscience, 9, 99. + +This script replicates the small 2x400-model. It has no command line +parameters. Instead, you control it by changing the constants below +the imports. Run the script with MODE set to "train" which +(eventually) creates the files theta.npy and weights.npy in the +DATA_PATH directory. Rerun it with MODE set to "observe" to create the +assign.npy file in the same directory. Finally, run "test" to create a +confusion matrix in confusion.npy. The script also creates a few +auxilliary .npy files useful for analysis. The script requires the +progressbar2 library. + +MNIST_PATH should point to the directory storing the unzipped *-byte +MNIST files. For reasonable accuracy, N_TRAIN should be 50,000+ and +N_OBSERVE 1,000+. + +Written in 2024 by Björn A. Lindqvist +""" +from brian2 import * +from collections import defaultdict +from pathlib import Path +from progressbar import progressbar +from random import randrange, seed as rseed +from struct import unpack +import numpy as np + +# Switch between "train", "observe", and "test" to tune parameters, +# observe excitatory spiking, and test accuracy, respectively. +MODE = 'test' + +# Number of training, observation, and testing samples +N_TRAIN = 200_000 +N_OBSERVE = 2_000 +N_TEST = 1_000 + +# Random seed value +SEED = 42 + +# Storage paths +MNIST_PATH = Path('../mnist') +DATA_PATH = Path('data') + +# Number of weight save points +N_SAVE_POINTS = 100 + +# Don't change these values unless you know what you're doing. +N_INP = 784 +N_NEURONS = 400 +V_EXC_REST = -65 * mV +V_INH_REST = -60 * mV +INTENSITY = 2 + +# Weights of exc->inh and inh->exc synapses +W_EXC_INH = 10.4 +W_INH_EXC = 17.0 + +def save_npy(arr, path): + arr = np.array(arr) + print('%-9s %-15s => %-30s' % ('Saving', arr.shape, path)) + np.save(path, arr) + +def load_npy(path): + arr = np.load(path) + print('%-9s %-30s => %-15s' % ('Loading', path, arr.shape)) + return arr + +def read_mnist(training): + tag = 'train' if training else 't10k' + images = open(MNIST_PATH / ('%s-images-idx3-ubyte' % tag), 'rb') + images.read(4) + n_images = unpack('>I', images.read(4))[0] + n_rows = unpack('>I', images.read(4))[0] + n_cols = unpack('>I', images.read(4))[0] + + labels = open(MNIST_PATH / ('%s-labels-idx1-ubyte' % tag), 'rb') + labels.read(4) + n_labels = unpack('>I', labels.read(4))[0] + x = np.frombuffer(images.read(), dtype = np.uint8) + x = x.reshape(n_images, -1) / 8.0 + y = np.frombuffer(labels.read(), dtype = np.uint8) + return x, y + +def build_network(training): + eqs = ''' + dv/dt = (v_rest - v + i_exc + i_inh) / tau_mem : volt (unless refractory) + i_exc = ge * -v : volt + i_inh = gi * (v_inh_base - v) : volt + dge/dt = -ge/(1 * ms) : 1 + dgi/dt = -gi/(2 * ms) : 1 + dtimer/dt = 0.1 : second + ''' + reset = 'v = %r; timer = 0 * ms' % V_EXC_REST + if training: + exc_eqs = eqs + ''' + dtheta/dt = -theta / (1e7 * ms) : volt + ''' + arr_theta = np.ones(N_NEURONS) * 20 * mV + reset += '; theta += 0.05 * mV' + else: + exc_eqs = eqs + ''' + theta : volt + ''' + arr_theta = load_npy(DATA_PATH / 'theta.npy') * volt + exc_eqs = Equations(exc_eqs, + tau_mem = 100 * ms, + v_rest = V_EXC_REST, + v_inh_base = -100 * mV) + ng_exc = NeuronGroup( + N_NEURONS, exc_eqs, + threshold = 'v > (theta - 72 * mV) and (timer > 5 * ms)', + refractory = 5 * ms, + reset = reset, + method = 'euler', + name = 'exc') + ng_exc.v = V_EXC_REST - 40 * mV + ng_exc.theta = arr_theta + + inh_eqs = Equations(eqs, + tau_mem = 10 * ms, + v_rest = V_INH_REST, + v_inh_base = -85 * mV) + ng_inh = NeuronGroup(N_NEURONS, inh_eqs, + threshold = 'v > -40 * mV', + refractory = 2 * ms, + reset = 'v = -45 * mV', + method = 'euler', + name = 'inh') + ng_inh.v = V_INH_REST - 40 * mV + + syns_exc_inh = Synapses(ng_exc, ng_inh, + model = 'w : 1', + on_pre = 'ge_post += w') + syns_exc_inh.connect(j = 'i') + syns_exc_inh.w = W_EXC_INH + + syns_inh_exc = Synapses(ng_inh, ng_exc, + model = 'w : 1', + on_pre = 'gi_post += w') + syns_inh_exc.connect(True) + + weights = (-np.identity(N_NEURONS) + 1) * W_INH_EXC + syns_inh_exc.w = weights.reshape(-1) + + pg_inp = PoissonGroup(N_INP, 0 * Hz, name = 'inp') + + # During training, inp->exc synapse weights are plastic. + model = 'w : 1' + on_post = '' + on_pre = 'ge_post += w' + if training: + on_pre += '; pre = 1.; w = clip(w - 0.0001 * post1, 0, 1.0)' + on_post += 'post2bef = post2; w = clip(w + 0.01 * pre * post2bef, 0, 1.0); post1 = 1.; post2 = 1.' + model += ''' + post2bef : 1 + dpre/dt = -pre/(20 * ms) : 1 (event-driven) + dpost1/dt = -post1/(20 * ms) : 1 (event-driven) + dpost2/dt = -post2/(40 * ms) : 1 (event-driven) + ''' + weights = (np.random.random(N_INP * N_NEURONS) + 0.01) * 0.3 + else: + weights = load_npy(DATA_PATH / 'weights.npy') + + syns_inp_exc = Synapses( + pg_inp, ng_exc, + model = model, + on_pre = on_pre, + on_post = on_post, + name = 'inp_exc' + ) + syns_inp_exc.connect(True) + syns_inp_exc.delay = 'rand() * 10 * ms' + syns_inp_exc.w = weights + + exc_mon = SpikeMonitor(ng_exc, name = 'sp_exc') + net = Network([pg_inp, ng_exc, ng_inh, + syns_inp_exc, syns_exc_inh, syns_inh_exc, + exc_mon]) + # Initialize + net.run(0 * ms) + return net + +def show_sample(net, sample, intensity): + exc_mon = net['sp_exc'] + prev = exc_mon.count[:] + net['inp'].rates = sample * intensity * Hz + net.run(350 * ms) + # Don't count spikes occuring during the 150 ms rest. + next = exc_mon.count[:] + net['inp'].rates = 0 * Hz + net.run(150 * ms) + pat = next - prev + cnt = np.sum(pat) + if cnt < 5: + return show_sample(net, sample, intensity + 1) + return pat + +def predict(groups, rates): + return np.argmax([rates[grp].mean() for grp in groups]) + +def test(): + conf = np.zeros((10, 10)) + assign = np.load(DATA_PATH / 'assign.npy') + groups = [np.where(assign == i)[0] for i in range(10)] + + X, Y = read_mnist(False) + net = build_network(False) + for i in progressbar(range(N_TEST)): + ix = randrange(len(X)) + exc = show_sample(net, X[ix], INTENSITY) + guess = predict(groups, exc) + real = Y[ix] + conf[real, guess] += 1 + + print('Accuracy: %6.3f' % (np.trace(conf) / np.sum(conf))) + conf = conf/conf.sum(axis=1)[:,None] + print(np.around(conf, 2)) + save_npy(conf, DATA_PATH / 'confusion.npy') + +def normalize_plastic_weights(syns): + conns = np.reshape(syns.w, (N_INP, N_NEURONS)) + col_sums = np.sum(conns, axis = 0) + factors = 78./ col_sums + conns *= factors + syns.w = conns.reshape(-1) + +def stats(net): + tick = int(defaultclock.t / defaultclock.dt) + cnt = np.sum(net['sp_exc'].count[:]) + + inp_exc = net['inp_exc'] + w_mu = np.mean(inp_exc.w) + w_std = np.std(inp_exc.w) + + exc = net['exc'] + theta = exc.theta / mV + theta_mu = np.mean(theta) + theta_sig = np.std(theta) + return [tick, cnt, w_mu, w_std, theta_mu, theta_sig] + +def train(): + X, Y = read_mnist(True) + n_samples = X.shape[0] + net = build_network(True) + rows = [stats(net) + [-1]] + w_hist = [np.array(net['inp_exc'].w)] + + ratio = max(N_TRAIN // N_SAVE_POINTS, 1) + for i in progressbar(range(N_TRAIN)): + ix = i % n_samples + normalize_plastic_weights(net['inp_exc']) + show_sample(net, X[ix], INTENSITY) + rows.append(stats(net) + [Y[ix]]) + if i % ratio == 0: + w_hist.append(np.array(net['inp_exc'].w)) + + save_npy(rows, DATA_PATH / 'train_stats.npy') + save_npy(w_hist, DATA_PATH / 'train_w_hist.npy') + save_npy(net['inp_exc'].w, DATA_PATH / 'weights.npy') + save_npy(net['exc'].theta, DATA_PATH / 'theta.npy') + +def observe(): + X, Y = read_mnist(True) + n_samples = X.shape[0] + net = build_network(False) + rows = [stats(net) + [-1]] + responses = defaultdict(list) + + for i in progressbar(range(N_OBSERVE)): + ix = i % n_samples + sample = X[ix] + cls = Y[ix] + exc = show_sample(net, sample, INTENSITY) + rows.append(stats(net) + [Y[ix]]) + responses[cls].append(exc) + + res = np.zeros((10, N_NEURONS)) + for cls, vals in responses.items(): + res[cls] = np.array(vals).mean(axis = 0) + + assign = np.argmax(res, axis = 0) + save_npy(assign, DATA_PATH / 'assign.npy') + save_npy(rows, DATA_PATH / 'observe_stats.npy') + +if __name__ == '__main__': + seed(SEED) + rseed(SEED) + np.random.seed(SEED) + DATA_PATH.mkdir(parents = True, exist_ok = True) + cmds = dict(train = train, observe = observe, test = test) + cmds[MODE]() From b3ace8e663d17391e548eebd97258c661826aac8 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Fri, 29 Nov 2024 10:46:13 +0100 Subject: [PATCH 2/9] Give a link for MNIST [ci skip] --- examples/frompapers/Diehl_Cook_2015.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/frompapers/Diehl_Cook_2015.py b/examples/frompapers/Diehl_Cook_2015.py index ab812e5d5..58afdbc12 100644 --- a/examples/frompapers/Diehl_Cook_2015.py +++ b/examples/frompapers/Diehl_Cook_2015.py @@ -16,8 +16,8 @@ progressbar2 library. MNIST_PATH should point to the directory storing the unzipped *-byte -MNIST files. For reasonable accuracy, N_TRAIN should be 50,000+ and -N_OBSERVE 1,000+. +MNIST files (e.g. from https://github.com/cvdfoundation/mnist). +For reasonable accuracy, N_TRAIN should be 50,000+ and N_OBSERVE 1,000+. Written in 2024 by Björn A. Lindqvist """ From efd0c6471095297825d56685933fd0526dd52923 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Fri, 29 Nov 2024 10:48:37 +0100 Subject: [PATCH 3/9] Simplify recurrent weight definitions [ci skip] --- examples/frompapers/Diehl_Cook_2015.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/examples/frompapers/Diehl_Cook_2015.py b/examples/frompapers/Diehl_Cook_2015.py index 58afdbc12..eaba4c216 100644 --- a/examples/frompapers/Diehl_Cook_2015.py +++ b/examples/frompapers/Diehl_Cook_2015.py @@ -133,18 +133,12 @@ def build_network(training): ng_inh.v = V_INH_REST - 40 * mV syns_exc_inh = Synapses(ng_exc, ng_inh, - model = 'w : 1', - on_pre = 'ge_post += w') + on_pre = 'ge_post += %f' % W_EXC_INH) syns_exc_inh.connect(j = 'i') - syns_exc_inh.w = W_EXC_INH syns_inh_exc = Synapses(ng_inh, ng_exc, - model = 'w : 1', - on_pre = 'gi_post += w') - syns_inh_exc.connect(True) - - weights = (-np.identity(N_NEURONS) + 1) * W_INH_EXC - syns_inh_exc.w = weights.reshape(-1) + on_pre = 'gi_post += %f' % W_INH_EXC) + syns_inh_exc.connect("i != j") pg_inp = PoissonGroup(N_INP, 0 * Hz, name = 'inp') From 418a5c052d870d8194d43fc5b4d9f0372d79a5ba Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Fri, 29 Nov 2024 10:55:22 +0100 Subject: [PATCH 4/9] Minor cleanup [ci skip] --- examples/frompapers/Diehl_Cook_2015.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/frompapers/Diehl_Cook_2015.py b/examples/frompapers/Diehl_Cook_2015.py index eaba4c216..beb203bda 100644 --- a/examples/frompapers/Diehl_Cook_2015.py +++ b/examples/frompapers/Diehl_Cook_2015.py @@ -74,12 +74,9 @@ def read_mnist(training): images = open(MNIST_PATH / ('%s-images-idx3-ubyte' % tag), 'rb') images.read(4) n_images = unpack('>I', images.read(4))[0] - n_rows = unpack('>I', images.read(4))[0] - n_cols = unpack('>I', images.read(4))[0] labels = open(MNIST_PATH / ('%s-labels-idx1-ubyte' % tag), 'rb') labels.read(4) - n_labels = unpack('>I', labels.read(4))[0] x = np.frombuffer(images.read(), dtype = np.uint8) x = x.reshape(n_images, -1) / 8.0 y = np.frombuffer(labels.read(), dtype = np.uint8) @@ -223,7 +220,7 @@ def normalize_plastic_weights(syns): syns.w = conns.reshape(-1) def stats(net): - tick = int(defaultclock.t / defaultclock.dt) + tick = defaultclock.timestep[:] cnt = np.sum(net['sp_exc'].count[:]) inp_exc = net['inp_exc'] @@ -283,7 +280,6 @@ def observe(): if __name__ == '__main__': seed(SEED) rseed(SEED) - np.random.seed(SEED) DATA_PATH.mkdir(parents = True, exist_ok = True) cmds = dict(train = train, observe = observe, test = test) cmds[MODE]() From 5e50ac339b4b7841f0f45dabcddc190a7110fd6d Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Tue, 17 Dec 2024 16:18:43 +0100 Subject: [PATCH 5/9] Change initialisation to use resting potentials --- examples/frompapers/Diehl_Cook_2015.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/frompapers/Diehl_Cook_2015.py b/examples/frompapers/Diehl_Cook_2015.py index beb203bda..d7cbb7e87 100644 --- a/examples/frompapers/Diehl_Cook_2015.py +++ b/examples/frompapers/Diehl_Cook_2015.py @@ -114,7 +114,7 @@ def build_network(training): reset = reset, method = 'euler', name = 'exc') - ng_exc.v = V_EXC_REST - 40 * mV + ng_exc.v = V_EXC_REST ng_exc.theta = arr_theta inh_eqs = Equations(eqs, @@ -127,7 +127,7 @@ def build_network(training): reset = 'v = -45 * mV', method = 'euler', name = 'inh') - ng_inh.v = V_INH_REST - 40 * mV + ng_inh.v = V_INH_REST syns_exc_inh = Synapses(ng_exc, ng_inh, on_pre = 'ge_post += %f' % W_EXC_INH) From fd86ed5d9aac6078679024ec32de20c453100ac6 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Tue, 17 Dec 2024 16:21:33 +0100 Subject: [PATCH 6/9] Comment on the timer variable and make it less confusing --- examples/frompapers/Diehl_Cook_2015.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/frompapers/Diehl_Cook_2015.py b/examples/frompapers/Diehl_Cook_2015.py index d7cbb7e87..fb1c5fa0b 100644 --- a/examples/frompapers/Diehl_Cook_2015.py +++ b/examples/frompapers/Diehl_Cook_2015.py @@ -89,7 +89,7 @@ def build_network(training): i_inh = gi * (v_inh_base - v) : volt dge/dt = -ge/(1 * ms) : 1 dgi/dt = -gi/(2 * ms) : 1 - dtimer/dt = 0.1 : second + dtimer/dt = 1 : second ''' reset = 'v = %r; timer = 0 * ms' % V_EXC_REST if training: @@ -107,9 +107,12 @@ def build_network(training): tau_mem = 100 * ms, v_rest = V_EXC_REST, v_inh_base = -100 * mV) + # Note that this neuron has a bit of un unusual refractoriness mechanism: + # The membrane potential is clamped for 5ms, but spikes are prevented for 50ms + # This has been taken from the original code. ng_exc = NeuronGroup( N_NEURONS, exc_eqs, - threshold = 'v > (theta - 72 * mV) and (timer > 5 * ms)', + threshold = 'v > (theta - 72 * mV) and (timer > 50 * ms)', refractory = 5 * ms, reset = reset, method = 'euler', From 51f5619210b2c50ff485b53819a11e6b01ffcf81 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Tue, 17 Dec 2024 16:27:28 +0100 Subject: [PATCH 7/9] Add a plot function --- examples/frompapers/Diehl_Cook_2015.py | 39 ++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/examples/frompapers/Diehl_Cook_2015.py b/examples/frompapers/Diehl_Cook_2015.py index fb1c5fa0b..03ed3c2b6 100644 --- a/examples/frompapers/Diehl_Cook_2015.py +++ b/examples/frompapers/Diehl_Cook_2015.py @@ -10,10 +10,11 @@ the imports. Run the script with MODE set to "train" which (eventually) creates the files theta.npy and weights.npy in the DATA_PATH directory. Rerun it with MODE set to "observe" to create the -assign.npy file in the same directory. Finally, run "test" to create a -confusion matrix in confusion.npy. The script also creates a few -auxilliary .npy files useful for analysis. The script requires the -progressbar2 library. +assign.npy file in the same directory. Then, run "test" to create a +confusion matrix in confusion.npy. Finally, you can use "plot" to +plot the confusion matrix. The script also creates a few auxilliary +.npy files useful for analysis. The script requires the progressbar2 +library. MNIST_PATH should point to the directory storing the unzipped *-byte MNIST files (e.g. from https://github.com/cvdfoundation/mnist). @@ -31,10 +32,11 @@ # Switch between "train", "observe", and "test" to tune parameters, # observe excitatory spiking, and test accuracy, respectively. +# Use "plot" to plot the confusion matrix. MODE = 'test' # Number of training, observation, and testing samples -N_TRAIN = 200_000 +N_TRAIN = 25_000 N_OBSERVE = 2_000 N_TEST = 1_000 @@ -280,9 +282,34 @@ def observe(): save_npy(assign, DATA_PATH / 'assign.npy') save_npy(rows, DATA_PATH / 'observe_stats.npy') +def plot(): + conf = np.load(DATA_PATH / "confusion.npy") + + import matplotlib.pyplot as plt + + plt.imshow(100*conf, interpolation="nearest", cmap=plt.cm.Blues) + for i, j in itertools.product(range(conf.shape[0]), range(conf.shape[1])): + if conf[i, j] == 0: + continue + plt.text( + j, + i, + f"{round(100*conf[i, j])}%", + horizontalalignment="center", + verticalalignment="center", + color="white" if conf[i, j] > 0.5 else "black", + ) + plt.colorbar() + plt.xticks(range(10)) + plt.yticks(range(10)) + plt.xlabel("Predicted label") + plt.ylabel("True label") + plt.show() + + if __name__ == '__main__': seed(SEED) rseed(SEED) DATA_PATH.mkdir(parents = True, exist_ok = True) - cmds = dict(train = train, observe = observe, test = test) + cmds = dict(train=train, observe=observe, test=test, plot=plot) cmds[MODE]() From 4c85f1c18d01258af2acfa0b8fda74aa8ae80e22 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Fri, 20 Dec 2024 10:25:47 +0100 Subject: [PATCH 8/9] fix incorrectly removed lines in read_mnist These values are not used, but the read pointer needs to advance --- examples/frompapers/Diehl_Cook_2015.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/frompapers/Diehl_Cook_2015.py b/examples/frompapers/Diehl_Cook_2015.py index 03ed3c2b6..469f66c1b 100644 --- a/examples/frompapers/Diehl_Cook_2015.py +++ b/examples/frompapers/Diehl_Cook_2015.py @@ -76,7 +76,9 @@ def read_mnist(training): images = open(MNIST_PATH / ('%s-images-idx3-ubyte' % tag), 'rb') images.read(4) n_images = unpack('>I', images.read(4))[0] - + n_rows = unpack(">I", images.read(4))[0] + n_cols = unpack(">I", images.read(4))[0] + labels = open(MNIST_PATH / ('%s-labels-idx1-ubyte' % tag), 'rb') labels.read(4) x = np.frombuffer(images.read(), dtype = np.uint8) From 8245aa99524b1986c8e6f26fe6b4f47ed1af1d7b Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Fri, 20 Dec 2024 10:29:56 +0100 Subject: [PATCH 9/9] Add figure for Diehl & Cook (2015) --- docs_sphinx/resources | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs_sphinx/resources b/docs_sphinx/resources index 32a89b4fc..ae3160358 160000 --- a/docs_sphinx/resources +++ b/docs_sphinx/resources @@ -1 +1 @@ -Subproject commit 32a89b4fcf80c211352c5077bd396ec7aa11b25f +Subproject commit ae3160358cf7c4e250bac34587e29d99329762e0