Fitting a GLM to discrete-time events #315
@AamnaLawrence thanks for the discussion and for sharing the data! I’ve looked into this more closely, and your approach to modeling events is sound. However, there are a few important considerations:
Interpreting the results
This first script can be used to experiment with the basis parameters until we are satisfied with the fit (in terms of the score we get and of how well the model predicts the peri-event modulation).

```python
# Fit a GLM
import pynapple as nap
import numpy as np
import matplotlib.pyplot as plt
import nemos as nmo
import jax
jax.config.update("jax_enable_x64", True)
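# (64-bit precision helps the solver converge to the tight tolerance used below)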
# load data and store events in a TsGroup
data = nap.load_file("/Users/ebalzani/Downloads/NWB_phy.nwb")
lever = data["LeverInsert"]
rew = data["Reward"]
events = nap.TsGroup(
    {
        1: nap.Ts(lever.t),
        2: nap.Ts(rew.t),
    },
    metadata={"event_type": ["Lever Insert", "Reward"]},
    time_support=nap.IntervalSet(0, np.max(lever.t)),
)
# print number of events
print(f"Num reward events: {len(rew)}")
print(f"Num lever events: {len(lever)}")
# keep only active neurons (firing rate > 1 Hz)
spikes_real = data["units"]
spikes_real = spikes_real[spikes_real.rate>1]
# select a neuron; to find this one, I plotted all the peri-event histograms and picked one
# that looked modulated by eye, as a starting point for designing a model.
# from the peri-event you can guess whether the rate modulation is causal (you see a change
# in rate only after the event) or acausal (the rate is modulated both before and after the event)
neu_id = 482
f, axs = plt.subplots(1, 2)
for cc, (i, ev) in enumerate(events.items()):
    peth = nap.compute_perievent(
        data=spikes_real[neu_id],
        tref=ev,
        minmax=(-3, 3),
        time_unit="s",
    )
    # average event-triggered rate in 50 ms bins
    axs[cc].plot(np.mean(peth.count(0.05), 1) / 0.05, linewidth=3, color="red")
    axs[cc].set_title(events.metadata["event_type"][i])
    axs[cc].axvline(0)
# use a 50-50 split so that the test set retains enough data to assess prediction
duration = lever.time_support.end[-1] - lever.time_support.start[0]
train = nap.IntervalSet(0, duration * 0.5)
test = nap.IntervalSet(train.end[-1] + 0.0001, lever.time_support.end[-1])
# define a convolution window (2 s, expressed in bins)
bin_size = 0.01
window = int(2 / bin_size)
# bin events and spike times
binned_events = events.count(bin_size)
count = spikes_real.count(bin_size, ep=binned_events.time_support) # pass time support to make sure they span the same range
# define a basis:
# for the fairest comparison, use the same number of basis functions (and the same type,
# at least for the two variables you are directly comparing)
# for this neuron, the lever insert looks causal (the rate increases after the event), while
# the reward looks like an acausal modulation (the rate changes both before and after)
# you can model both types of predictors ("causal" is the default and most common)
add_basis = (
    nmo.basis.RaisedCosineLinearConv(11, window, width=4, label="Lever Insert")  # default is causal
    + nmo.basis.RaisedCosineLinearConv(
        11, window, width=4, label="Reward", conv_kwargs={"predictor_causality": "acausal"}
    )
    + nmo.basis.RaisedCosineLogConv(11, window, width=4, label="Spike History")
)
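# (optional sanity check, my own addition) plot one component's kernels to see the
# window each predictor is convolved with; evaluate_on_grid returns a sample axis
# and one column per basis function
# _ts, kernels = nmo.basis.RaisedCosineLinearConv(11, window, width=4).evaluate_on_grid(window)
# plt.plot(_ts, kernels)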
# define train and test design matrices (restrict first to avoid border effects)
X_train = add_basis.compute_features(
    binned_events[:, 0].restrict(train),
    binned_events[:, 1].restrict(train),
    count.loc[neu_id].restrict(train),
)
X_test = add_basis.compute_features(
    binned_events[:, 0].restrict(test),
    binned_events[:, 1].restrict(test),
    count.loc[neu_id].restrict(test),
)
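# note: the convolution leaves NaNs at the edges of each epoch (at the start for
# causal predictors, at both ends for acausal ones) where the window does not
# fully overlap the data; this is the usual source of a design matrix "full of
# NaNs". The fit below runs fine with them present, but you can check how many
# samples are valid (this check is my own addition):
valid = ~np.any(np.isnan(np.asarray(X_train)), axis=1)
print(f"valid train samples: {valid.sum()} / {valid.size}")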
# first check the performance of an unregularized model
model = nmo.glm.GLM(solver_name="LBFGS", solver_kwargs={"tol": 10**-12}).fit(
    X_train, count.loc[neu_id].restrict(train)
)
# print the scores: a large drop from train to test indicates overfitting.
# if the test set is too small, the test score has high variance, so even K-fold
# cross-validation may not be very informative
score_train = model.score(X_train, count.loc[neu_id].restrict(train), score_type="pseudo-r2-McFadden")
score_test = model.score(X_test, count.loc[neu_id].restrict(test), score_type="pseudo-r2-McFadden")
print(f"Train score: {score_train}\nTest score: {score_test}")
# compare the peri-event modulation in the test set computed from the predicted rate and from the raw spikes
rate_test = model.predict(X_test) / bin_size
f, axs = plt.subplots(1, 2)
for cc, (i, ev) in enumerate(events.items()):
    peth = nap.compute_perievent(
        data=spikes_real[neu_id].restrict(test),
        tref=ev.restrict(test),
        minmax=(-3, 3),
        time_unit="s",
    )
    peth_rate = nap.compute_perievent_continuous(
        data=rate_test,
        tref=ev.restrict(test),
        minmax=(-3, 3),
        time_unit="s",
    )
    axs[cc].plot(np.mean(peth.count(0.05), 1) / 0.05, linewidth=3, color="red", label="raw")
    axs[cc].plot(np.nanmean(peth_rate, 1), label="model")
    axs[cc].set_title(events.metadata["event_type"][i])
    axs[cc].axvline(0)
    axs[cc].legend()
plt.show()
```

Once this is satisfactory, proceed with the GroupLasso ranking of the variables. (As a note, I quickly balanced the events, not in the script below because it was not systematic, and the variables seemed to contribute about equally for this neuron.)
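For reference, here is a minimal sketch of one way to balance the events, subsampling the more frequent type down to the count of the rarer one (the seed, the helper function, and the rebuilt TsGroup are my own additions, not part of the original analysis):

```python
# subsample the more frequent event type so both types have the same count
rng = np.random.default_rng(0)
n_keep = min(len(lever), len(rew))

def subsample(ts, n):
    # keep a random, time-ordered subset of n event times
    idx = np.sort(rng.choice(len(ts), size=n, replace=False))
    return nap.Ts(ts.t[idx])

events_balanced = nap.TsGroup(
    {1: subsample(lever, n_keep), 2: subsample(rew, n_keep)},
    metadata={"event_type": ["Lever Insert", "Reward"]},
    time_support=events.time_support,
)
binned_events_balanced = events_balanced.count(bin_size)
```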
```python
from tqdm import tqdm  # progress bar

# create log-spaced regularization strengths for group lasso
reg_str = np.geomspace(1e-5, 0.1, 20)
# one can use the whole data for this
X = add_basis.compute_features(binned_events[:,0], binned_events[:,1], count.loc[neu_id])
# define the variable grouping by constructing a mask
mask = np.zeros((len(add_basis), X.shape[1]))
cc = 0
for k, bas in enumerate(add_basis):
    mask[k, cc:cc + bas.n_basis_funcs] = 1
    cc += bas.n_basis_funcs
print(mask)
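# with three predictors of 11 basis functions each, the mask has shape (3, 33):
# row k is 1 on the 11 columns that belong to predictor k, so GroupLasso shrinks
# each predictor's coefficients jointly and can zero out a whole group at once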
# initialize the arrays that store the norm of each predictor's coefficients and the score
coeff_norms = np.zeros((len(reg_str), len(add_basis)))
scores = np.zeros(len(reg_str))
# loop over regularization strengths
for i, reg in tqdm(enumerate(reg_str), total=len(reg_str)):
    regularizer = nmo.regularizer.GroupLasso(mask=mask)
    model = nmo.glm.GLM(
        regularizer=regularizer,
        regularizer_strength=reg,
        solver_kwargs={"tol": 10**-12},
    ).fit(X, count.loc[neu_id])
    # split the flat coefficient vector into one entry per predictor
    coeff_dict = add_basis.split_by_feature(model.coef_, axis=0)
    # JAX utility that applies the norm to each vector in coeff_dict, equivalent to looping over the dict items
    coeff_norms[i] = jax.tree_util.tree_leaves(jax.tree_util.tree_map(np.linalg.norm, coeff_dict))
    scores[i] = model.score(X, count.loc[neu_id], score_type="pseudo-r2-McFadden")
# plot the results (since the events are unbalanced, the output may be misleading to interpret
# as-is, but if you balance the number of events you can rank the variables from most to least important)
# in general, as you add more regularization, the score drops and so do the coefficient norms
f, axs = plt.subplots(2, 1, sharex=True)
axs[0].plot(reg_str, scores, "-ok")
axs[0].set_xscale("log")
axs[1].plot(reg_str, coeff_norms)
axs[1].legend([b.label for b in add_basis])
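# a possible explicit ranking (my own addition, only meaningful once the events
# are balanced): sort predictors by their coefficient norm at the strongest
# regularization; the groups that survive the longest matter most
labels = [b.label for b in add_basis]
ranking = np.argsort(coeff_norms[-1])[::-1]
print("predictors, most to least important:", [labels[k] for k in ranking])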
plt.show()
```

If the basis configuration is appropriate for multiple neurons, you can use the …
Hello everyone! I am trying to determine whether the activity of a neuron is modulated by discrete-time events in the task (lever insertion into the box, reward delivery, etc.). I prepare the events variable as follows:
An example of what the Lever Insert event looks like:
When I try to compute the features for these time inputs, I get a bunch of NaNs. I have been playing around with different window sizes, but that did not help. Because of the possibly faulty feature matrix, all my GLM coefficients are 0. Here is the code I am using to compute the features:
Has anyone used discrete-time events as inputs to a GLM? Any insights on choosing the right basis functions in NeMoS would be very helpful.
Thank you!!