Skip to content

Commit

Permalink
~260 ms optimization of bias matrix numba
Browse files Browse the repository at this point in the history
  • Loading branch information
kushaangupta committed Jan 18, 2025
1 parent dd3b421 commit 5d5f140
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions neuro_py/ensemble/pairwise_bias_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Tuple

import nelpy as nel
import numba
import numpy as np
import sklearn
import sklearn.metrics
Expand All @@ -22,7 +23,7 @@ def bias_matrix(
fillneutral: float = 0.5
) -> np.ndarray:
r"""
Compute the bias matrix for a given sequence of spikes.
Compute the pairwise bias matrix for a given sequence of spikes.
Parameters
----------
Expand Down Expand Up @@ -85,7 +86,7 @@ def bias_matrix_fast(
return_counts: bool = False
) -> np.ndarray:
r"""
Compute the bias matrix for a given sequence of spikes.
Compute the pairwise bias matrix for a given sequence of spikes.
Parameters
----------
Expand Down Expand Up @@ -172,7 +173,7 @@ def bias_matrix_njit(
fillneutral: float = 0.5
) -> np.ndarray:
r"""
Compute the bias matrix for a given sequence of spikes.
Compute the pairwise bias matrix for a given sequence of spikes.
Parameters
----------
Expand Down Expand Up @@ -209,19 +210,24 @@ def bias_matrix_njit(
.. [1] Roth, Z. (2016). Analysis of neuronal sequences using pairwise
biases. arXiv, 65-67 (2016). https://arxiv.org/abs/1603.02916
"""
ibeforej = np.zeros((total_neurons, total_neurons)) # rows: i, cols: j
ibeforej = np.zeros((total_neurons, total_neurons))
prod_nspikes_ij = np.zeros((total_neurons, total_neurons))
bias = np.empty((total_neurons, total_neurons))

# Create boolean masks for all neurons in advance
masks = [neuron_ids == i for i in range(total_neurons)]
nrns_st = numba.typed.List()
for _ in range(total_neurons):
nrns_st.append(numba.typed.List.empty_list(np.float64))
for i, nrn_id in enumerate(neuron_ids):
nrns_st[nrn_id].append(spike_times[i])

# Build bias matrix
for i in range(total_neurons):
spikes_i = spike_times[masks[i]]
nspikes_i = spikes_i.size
spikes_i = np.asarray(nrns_st[i])
nspikes_i = len(spikes_i)

for j in range(i + 1, total_neurons):
spikes_j = spike_times[masks[j]]
nspikes_j = spikes_j.size
spikes_j = np.asarray(nrns_st[j])
nspikes_j = len(spikes_j)

if nspikes_i > 0 and nspikes_j > 0:
nspikes_ij = np.searchsorted(
Expand All @@ -232,7 +238,7 @@ def bias_matrix_njit(
jbeforei = prod_nspikes_ij - ibeforej
prod_nspikes_ij = prod_nspikes_ij + prod_nspikes_ij.T
ibeforej = ibeforej + jbeforei.T
bias = ibeforej / prod_nspikes_ij
bias = ibeforej / prod_nspikes_ij # no need to check for zero division
bias = np.where(prod_nspikes_ij == 0, fillneutral, bias)

return bias
Expand Down

0 comments on commit 5d5f140

Please sign in to comment.