diff --git a/docs_sphinx/resources b/docs_sphinx/resources index 8710177af..ae3160358 160000 --- a/docs_sphinx/resources +++ b/docs_sphinx/resources @@ -1 +1 @@ -Subproject commit 8710177af5509005cbb88d12dad6d3dced14f23e +Subproject commit ae3160358cf7c4e250bac34587e29d99329762e0 diff --git a/examples/frompapers/Diehl_Cook_2015.py b/examples/frompapers/Diehl_Cook_2015.py new file mode 100644 index 000000000..469f66c1b --- /dev/null +++ b/examples/frompapers/Diehl_Cook_2015.py @@ -0,0 +1,317 @@ +""" +Unsupervised learning using STDP +-------------------------------- +Diehl, P. U., & Cook, M. (2015). Unsupervised learning of digit +recognition using spike-timing-dependent plasticity. Frontiers in +computational neuroscience, 9, 99. + +This script replicates the small 2x400-model. It has no command line +parameters. Instead, you control it by changing the constants below +the imports. Run the script with MODE set to "train" which +(eventually) creates the files theta.npy and weights.npy in the +DATA_PATH directory. Rerun it with MODE set to "observe" to create the +assign.npy file in the same directory. Then, run "test" to create a +confusion matrix in confusion.npy. Finally, you can use "plot" to +plot the confusion matrix. The script also creates a few auxilliary +.npy files useful for analysis. The script requires the progressbar2 +library. + +MNIST_PATH should point to the directory storing the unzipped *-byte +MNIST files (e.g. from https://github.com/cvdfoundation/mnist). +For reasonable accuracy, N_TRAIN should be 50,000+ and N_OBSERVE 1,000+. + +Written in 2024 by Björn A. Lindqvist +""" +from brian2 import * +from collections import defaultdict +from pathlib import Path +from progressbar import progressbar +from random import randrange, seed as rseed +from struct import unpack +import numpy as np + +# Switch between "train", "observe", and "test" to tune parameters, +# observe excitatory spiking, and test accuracy, respectively. +# Use "plot" to plot the confusion matrix. +MODE = 'test' + +# Number of training, observation, and testing samples +N_TRAIN = 25_000 +N_OBSERVE = 2_000 +N_TEST = 1_000 + +# Random seed value +SEED = 42 + +# Storage paths +MNIST_PATH = Path('../mnist') +DATA_PATH = Path('data') + +# Number of weight save points +N_SAVE_POINTS = 100 + +# Don't change these values unless you know what you're doing. +N_INP = 784 +N_NEURONS = 400 +V_EXC_REST = -65 * mV +V_INH_REST = -60 * mV +INTENSITY = 2 + +# Weights of exc->inh and inh->exc synapses +W_EXC_INH = 10.4 +W_INH_EXC = 17.0 + +def save_npy(arr, path): + arr = np.array(arr) + print('%-9s %-15s => %-30s' % ('Saving', arr.shape, path)) + np.save(path, arr) + +def load_npy(path): + arr = np.load(path) + print('%-9s %-30s => %-15s' % ('Loading', path, arr.shape)) + return arr + +def read_mnist(training): + tag = 'train' if training else 't10k' + images = open(MNIST_PATH / ('%s-images-idx3-ubyte' % tag), 'rb') + images.read(4) + n_images = unpack('>I', images.read(4))[0] + n_rows = unpack(">I", images.read(4))[0] + n_cols = unpack(">I", images.read(4))[0] + + labels = open(MNIST_PATH / ('%s-labels-idx1-ubyte' % tag), 'rb') + labels.read(4) + x = np.frombuffer(images.read(), dtype = np.uint8) + x = x.reshape(n_images, -1) / 8.0 + y = np.frombuffer(labels.read(), dtype = np.uint8) + return x, y + +def build_network(training): + eqs = ''' + dv/dt = (v_rest - v + i_exc + i_inh) / tau_mem : volt (unless refractory) + i_exc = ge * -v : volt + i_inh = gi * (v_inh_base - v) : volt + dge/dt = -ge/(1 * ms) : 1 + dgi/dt = -gi/(2 * ms) : 1 + dtimer/dt = 1 : second + ''' + reset = 'v = %r; timer = 0 * ms' % V_EXC_REST + if training: + exc_eqs = eqs + ''' + dtheta/dt = -theta / (1e7 * ms) : volt + ''' + arr_theta = np.ones(N_NEURONS) * 20 * mV + reset += '; theta += 0.05 * mV' + else: + exc_eqs = eqs + ''' + theta : volt + ''' + arr_theta = load_npy(DATA_PATH / 'theta.npy') * volt + exc_eqs = Equations(exc_eqs, + tau_mem = 100 * ms, + v_rest = V_EXC_REST, + v_inh_base = -100 * mV) + # Note that this neuron has a bit of un unusual refractoriness mechanism: + # The membrane potential is clamped for 5ms, but spikes are prevented for 50ms + # This has been taken from the original code. + ng_exc = NeuronGroup( + N_NEURONS, exc_eqs, + threshold = 'v > (theta - 72 * mV) and (timer > 50 * ms)', + refractory = 5 * ms, + reset = reset, + method = 'euler', + name = 'exc') + ng_exc.v = V_EXC_REST + ng_exc.theta = arr_theta + + inh_eqs = Equations(eqs, + tau_mem = 10 * ms, + v_rest = V_INH_REST, + v_inh_base = -85 * mV) + ng_inh = NeuronGroup(N_NEURONS, inh_eqs, + threshold = 'v > -40 * mV', + refractory = 2 * ms, + reset = 'v = -45 * mV', + method = 'euler', + name = 'inh') + ng_inh.v = V_INH_REST + + syns_exc_inh = Synapses(ng_exc, ng_inh, + on_pre = 'ge_post += %f' % W_EXC_INH) + syns_exc_inh.connect(j = 'i') + + syns_inh_exc = Synapses(ng_inh, ng_exc, + on_pre = 'gi_post += %f' % W_INH_EXC) + syns_inh_exc.connect("i != j") + + pg_inp = PoissonGroup(N_INP, 0 * Hz, name = 'inp') + + # During training, inp->exc synapse weights are plastic. + model = 'w : 1' + on_post = '' + on_pre = 'ge_post += w' + if training: + on_pre += '; pre = 1.; w = clip(w - 0.0001 * post1, 0, 1.0)' + on_post += 'post2bef = post2; w = clip(w + 0.01 * pre * post2bef, 0, 1.0); post1 = 1.; post2 = 1.' + model += ''' + post2bef : 1 + dpre/dt = -pre/(20 * ms) : 1 (event-driven) + dpost1/dt = -post1/(20 * ms) : 1 (event-driven) + dpost2/dt = -post2/(40 * ms) : 1 (event-driven) + ''' + weights = (np.random.random(N_INP * N_NEURONS) + 0.01) * 0.3 + else: + weights = load_npy(DATA_PATH / 'weights.npy') + + syns_inp_exc = Synapses( + pg_inp, ng_exc, + model = model, + on_pre = on_pre, + on_post = on_post, + name = 'inp_exc' + ) + syns_inp_exc.connect(True) + syns_inp_exc.delay = 'rand() * 10 * ms' + syns_inp_exc.w = weights + + exc_mon = SpikeMonitor(ng_exc, name = 'sp_exc') + net = Network([pg_inp, ng_exc, ng_inh, + syns_inp_exc, syns_exc_inh, syns_inh_exc, + exc_mon]) + # Initialize + net.run(0 * ms) + return net + +def show_sample(net, sample, intensity): + exc_mon = net['sp_exc'] + prev = exc_mon.count[:] + net['inp'].rates = sample * intensity * Hz + net.run(350 * ms) + # Don't count spikes occuring during the 150 ms rest. + next = exc_mon.count[:] + net['inp'].rates = 0 * Hz + net.run(150 * ms) + pat = next - prev + cnt = np.sum(pat) + if cnt < 5: + return show_sample(net, sample, intensity + 1) + return pat + +def predict(groups, rates): + return np.argmax([rates[grp].mean() for grp in groups]) + +def test(): + conf = np.zeros((10, 10)) + assign = np.load(DATA_PATH / 'assign.npy') + groups = [np.where(assign == i)[0] for i in range(10)] + + X, Y = read_mnist(False) + net = build_network(False) + for i in progressbar(range(N_TEST)): + ix = randrange(len(X)) + exc = show_sample(net, X[ix], INTENSITY) + guess = predict(groups, exc) + real = Y[ix] + conf[real, guess] += 1 + + print('Accuracy: %6.3f' % (np.trace(conf) / np.sum(conf))) + conf = conf/conf.sum(axis=1)[:,None] + print(np.around(conf, 2)) + save_npy(conf, DATA_PATH / 'confusion.npy') + +def normalize_plastic_weights(syns): + conns = np.reshape(syns.w, (N_INP, N_NEURONS)) + col_sums = np.sum(conns, axis = 0) + factors = 78./ col_sums + conns *= factors + syns.w = conns.reshape(-1) + +def stats(net): + tick = defaultclock.timestep[:] + cnt = np.sum(net['sp_exc'].count[:]) + + inp_exc = net['inp_exc'] + w_mu = np.mean(inp_exc.w) + w_std = np.std(inp_exc.w) + + exc = net['exc'] + theta = exc.theta / mV + theta_mu = np.mean(theta) + theta_sig = np.std(theta) + return [tick, cnt, w_mu, w_std, theta_mu, theta_sig] + +def train(): + X, Y = read_mnist(True) + n_samples = X.shape[0] + net = build_network(True) + rows = [stats(net) + [-1]] + w_hist = [np.array(net['inp_exc'].w)] + + ratio = max(N_TRAIN // N_SAVE_POINTS, 1) + for i in progressbar(range(N_TRAIN)): + ix = i % n_samples + normalize_plastic_weights(net['inp_exc']) + show_sample(net, X[ix], INTENSITY) + rows.append(stats(net) + [Y[ix]]) + if i % ratio == 0: + w_hist.append(np.array(net['inp_exc'].w)) + + save_npy(rows, DATA_PATH / 'train_stats.npy') + save_npy(w_hist, DATA_PATH / 'train_w_hist.npy') + save_npy(net['inp_exc'].w, DATA_PATH / 'weights.npy') + save_npy(net['exc'].theta, DATA_PATH / 'theta.npy') + +def observe(): + X, Y = read_mnist(True) + n_samples = X.shape[0] + net = build_network(False) + rows = [stats(net) + [-1]] + responses = defaultdict(list) + + for i in progressbar(range(N_OBSERVE)): + ix = i % n_samples + sample = X[ix] + cls = Y[ix] + exc = show_sample(net, sample, INTENSITY) + rows.append(stats(net) + [Y[ix]]) + responses[cls].append(exc) + + res = np.zeros((10, N_NEURONS)) + for cls, vals in responses.items(): + res[cls] = np.array(vals).mean(axis = 0) + + assign = np.argmax(res, axis = 0) + save_npy(assign, DATA_PATH / 'assign.npy') + save_npy(rows, DATA_PATH / 'observe_stats.npy') + +def plot(): + conf = np.load(DATA_PATH / "confusion.npy") + + import matplotlib.pyplot as plt + + plt.imshow(100*conf, interpolation="nearest", cmap=plt.cm.Blues) + for i, j in itertools.product(range(conf.shape[0]), range(conf.shape[1])): + if conf[i, j] == 0: + continue + plt.text( + j, + i, + f"{round(100*conf[i, j])}%", + horizontalalignment="center", + verticalalignment="center", + color="white" if conf[i, j] > 0.5 else "black", + ) + plt.colorbar() + plt.xticks(range(10)) + plt.yticks(range(10)) + plt.xlabel("Predicted label") + plt.ylabel("True label") + plt.show() + + +if __name__ == '__main__': + seed(SEED) + rseed(SEED) + DATA_PATH.mkdir(parents = True, exist_ok = True) + cmds = dict(train=train, observe=observe, test=test, plot=plot) + cmds[MODE]()