Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aud filternet #351

Merged
merged 24 commits into from
Mar 20, 2024
Merged
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
cf1d477
Initial draft auditory filternet
xpliu16 Feb 7, 2023
58b61fa
Trying different temporal envelope
xpliu16 Feb 10, 2023
27e5bdc
Work on spectral modulation filters
xpliu16 Feb 11, 2023
3e25b69
Auditory filternet refinement
xpliu16 Feb 26, 2023
298f344
Padding and other adjustments to auditory filters
xpliu16 Mar 2, 2023
6b2260e
Merge pull request #288 from xpliu16/aud_filternet
xpliu16 Mar 2, 2023
e98abf7
Handling delay, better documentation of filters, passing of threshold
xpliu16 Mar 14, 2023
3119343
Merge pull request #290 from xpliu16/aud_filternet
xpliu16 Mar 14, 2023
12dd888
Miscellaneous updates
xpliu16 Mar 29, 2023
55ff413
Merge pull request #295 from xpliu16/aud_filternet
xpliu16 Mar 29, 2023
5397e16
Aud filternet b_t now variable
xpliu16 Jun 2, 2023
4a67a29
Merge branch 'develop' into aud_filternet
xpliu16 Jun 2, 2023
42a3361
Merge pull request #304 from xpliu16/aud_filternet
xpliu16 Jun 2, 2023
94326fe
Fixed psi and direction reading bug
xpliu16 Jun 28, 2023
83a554a
Merge branch 'develop' of https://github.com/xpliu16/bmtk into aud_fi…
xpliu16 Jun 28, 2023
9c1954a
Merge pull request #310 from xpliu16/aud_filternet
xpliu16 Jun 28, 2023
0e1cea0
Fixed bug on "delay"
xpliu16 Jun 28, 2023
f452c2c
Merge pull request #311 from xpliu16/aud_filternet
xpliu16 Jun 28, 2023
dd4ff02
resolving merge conflicts
kaeldai Mar 13, 2024
0187d16
fixing issue with spatial_filter
kaeldai Mar 13, 2024
9a7c039
fixing one last issue
kaeldai Mar 20, 2024
3c4e97f
fixing merge
kaeldai Mar 20, 2024
50ef1cf
merge was not fully saved
kaeldai Mar 20, 2024
ea4e7c3
fixing import issue
kaeldai Mar 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Auditory filternet refinement
xpliu16 committed Feb 26, 2023
commit 3e25b6939ec50013ae2877e3dcc10da397824472
5 changes: 3 additions & 2 deletions bmtk/simulator/filternet/auditory_processing.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ def __init__(self, aud_fn, low_lim=50.0, hi_lim=8000.0, sample_factor=4, downsam
:param hi_lim: float, high end of frequency range (Hz)
:param sample_factor: int,
"""
self.stim_array, self.sr = utils.wav_to_array(aud_fn)
self.stim_array, self.sr = utils.wav_to_array(aud_fn) # Allow relative size of stimulus
self.sample_factor = sample_factor # density of sampling, can be 1,2, or 4
self.low_lim = low_lim
self.hi_lim = hi_lim
@@ -46,9 +46,10 @@ def get_cochleagram(self, desired_sr=1000, interp_to_freq=False):
inds_keep = np.argwhere((center_freqs >= self.low_lim) & (center_freqs <= self.hi_lim))
center_freqs = center_freqs[inds_keep]
human_coch = human_coch[np.squeeze(inds_keep)]
minval = np.min(human_coch)
center_freqs_log = np.log2(center_freqs/np.min(center_freqs))
human_coch = resample_poly(human_coch, desired_sr, self.sr, axis=1)
human_coch[human_coch<=minval] = minval # resampling sometimes produces very small negative values
times = np.linspace(0, 1/desired_sr * (human_coch.shape[1]-1), human_coch.shape[1])


return human_coch, center_freqs_log, times
101 changes: 64 additions & 37 deletions bmtk/simulator/filternet/filtersimulator.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,8 @@
from bmtk.simulator.filternet.io_tools import io
from bmtk.utils.io.ioutils import bmtk_world_comm
from bmtk.simulator.filternet.auditory_processing import AuditoryInput
import scipy.io as syio
import os


class FilterSimulator(Simulator):
@@ -52,7 +54,7 @@ def add_movie(self, movie_type, params):
self.io.log_info('Normalizing movie data to (-1.0, 1.0).')
m_data = m_data*2.0/(contrast_max - contrast_min) - 1.0
else:
self.io.log_info('Movie data range is not normalized to (-1.0, 1.0).')
self.io.log_info('Movie data range is not normalized to (-1.0, 1.0).')

init_params = FilterSimulator.find_params(['row_range', 'col_range', 'labels', 'units', 'frame_rate',
't_range'], **params)
@@ -104,50 +106,69 @@ def add_movie(self, movie_type, params):
def add_audio(self, audio_type, params):
# Create cochleagram "movie" from audio wav file
audio_type = audio_type.lower() if isinstance(audio_type, string_types) else 'movie'
if audio_type == 'wav_file' or not audio_type:
if audio_type in ['wav_file', 'mat_file'] or not audio_type:
if 'data_file' in params:
if 'data_file' in params:
aud_file = params['data_file']
aud_file = params['data_file']
if audio_type == 'mat_file':
n = params['stim_number']
wav_file = os.path.splitext(aud_file)[0] + str(n) + '.wav'
if not os.path.exists(wav_file):
mat = syio.loadmat(params['data_file'])
data = np.squeeze(mat['timit_sents'][0, n])
sr = mat['aud_fs'][0][0]
scaled = np.int16(data / np.max(np.abs(data)) * 32768)
syio.wavfile.write(wav_file, sr, scaled)
else:
io.log_warning('Wav file already exists, please delete to overwrite.')
aud_file = wav_file

#elif 'data' in params:
# m_data = params['data']
else:
raise Exception('Could not find audio "data_file" in config to use as input.')
else:
raise Exception('Could not find audio "data_file" in config to use as input.')

aud = AuditoryInput(aud_file)
aud = AuditoryInput(aud_file)

#if params.get('frame_rate'):
# frame_rate = params.get('frame_rate')
#else:
init_params = FilterSimulator.find_params(['row_range', 'col_range', 'labels', 'units', 'frame_rate',
't_range'], **params)
if 'frame_rate' in init_params.keys():
frame_rate = init_params['frame_rate']
else:
frame_rate = 1000
#if params.get('frame_rate'):
# frame_rate = params.get('frame_rate')
#else:
init_params = FilterSimulator.find_params(['row_range', 'col_range', 'labels', 'units', 'frame_rate',
't_range', 'padding'], **params)
if 'frame_rate' in init_params.keys():
frame_rate = init_params['frame_rate']
else:
frame_rate = 1000

coch, center_freqs_log, times = aud.get_cochleagram(frame_rate, interp_to_freq=params['interp_to_freq'])
coch = coch.T
coch = coch[:,:, np.newaxis]
# Log step?
coch = np.log(coch)
coch, center_freqs_log, times = aud.get_cochleagram(frame_rate, interp_to_freq=params['interp_to_freq'])
coch = coch.T
#coch = np.log(coch)

normalize_data = params.get('normalize', None)
if normalize_data == 'full':
contrast_min, contrast_max = coch.min(), coch.max()
normalize_data = params.get('normalize', False)
if normalize_data:
self.io.log_info('Normalizing auditory input to (-1.0, 1.0).')
coch = (coch-contrast_min)*2.0/(contrast_max - contrast_min) - 1.0
else:
self.io.log_info('Auditory input range is not normalized to (-1.0, 1.0).')

amplitude = 100
coch *= amplitude
# Note, overwrites these if user supplied, instead taken from cochleagram
init_params['row_range'] = center_freqs_log
init_params['col_range'] = [0]
init_params['t_range'] = times
#? Frame_rate
# Dimensions of time, row, column
self._movies.append(Movie(coch, **init_params))
self.io.log_info('Normalizing auditory input to (-1.0, 1.0).')
coch = (coch-contrast_min)*2.0/(contrast_max - contrast_min) - 1.0
elif normalize_data == 'relative':
self.io.log_info('Auditory input is normalized maintaining relative amplitude')
coch = coch*3
else:
self.io.log_info('Auditory input range is not normalized.')

amplitude = 100
coch *= amplitude

#pad = np.full((500,coch.shape[1]), coch[0,:])
#coch = np.concatenate((pad,coch))

coch = coch[:,:, np.newaxis]

# Note, overwrites these if user supplied, instead taken from cochleagram
init_params['row_range'] = center_freqs_log
init_params['col_range'] = [0]
init_params['t_range'] = times
#? Frame_rate
# Dimensions of time, row, column
self._movies.append(Movie(coch, **init_params))
else:
raise Exception('Unknown audio type {}'.format(audio_type))

@@ -172,6 +193,12 @@ def run(self):
if cell_num > 0 and cell_num % ten_percent == 0:
io.log_debug(' Processing cell {} of {}{}.'.format(cell_num, n_cells_on_rank, rank_msg))
ts, f_rates = cell.lgn_cell_obj.evaluate(movie, **options)
if movie.padding:
f_rates = f_rates[int((movie.data.shape[0]-movie.data_orig.shape[0])/2) :
-int((movie.data.shape[0]-movie.data_orig.shape[0])/2)]
ts = ts[int((movie.data.shape[0]-movie.data_orig.shape[0])/2):
-int((movie.data.shape[0]-movie.data_orig.shape[0])/2)]
ts = ts-ts[0]

for mod in self._sim_mods:
mod.save(self, cell, ts, f_rates)
32 changes: 32 additions & 0 deletions bmtk/simulator/filternet/lgnmodel/cursor.py
Original file line number Diff line number Diff line change
@@ -106,11 +106,43 @@ class LNUnitCursor(KernelCursor):
"""
def __init__(self, lnunit, movie, threshold=0):
self.lnunit = lnunit
if hasattr(movie, "t_range_orig"): # Reset padded to original
movie.t_range = movie.t_range_orig
movie.data = movie.data_orig
movie.row_range = movie.row_range_orig
if isinstance(self.lnunit.linear_filter, SpatioTemporalFilter):
kernel = lnunit.get_spatiotemporal_kernel(movie.row_range, movie.col_range, movie.t_range, reverse=True,
threshold=threshold)
elif isinstance(self.lnunit.linear_filter, SpectroTemporalFilter):
kernel = lnunit.get_spectrotemporal_kernel(movie.row_range, movie.t_range, reverse=True, threshold=threshold)
if movie.padding:
if movie.padding == 'edge':
pre_pad = np.full((len(np.unique(kernel.t_inds))-1, movie.data.shape[1]),
movie.data[0, :, 0])
post_pad = np.full((len(np.unique(kernel.t_inds)) - 1, movie.data.shape[1]),
movie.data[-1, :, 0])
pre_pad = pre_pad[:, :, np.newaxis]
post_pad = post_pad[:, :, np.newaxis]
movie.data_orig = movie.data
movie.data = np.concatenate((pre_pad, movie.data, post_pad))
lower_pad = np.full((movie.data.shape[0], len(np.unique(kernel.row_inds)) - 1, 1),
np.reshape(movie.data[:, 0, 0], (-1, 1, 1)))
upper_pad = np.full((movie.data.shape[0], len(np.unique(kernel.row_inds)) - 1, 1),
np.reshape(movie.data[:, -1, 0], (-1, 1, 1)))
movie.data = np.hstack((lower_pad, movie.data, upper_pad))
kernel.t_range = np.linspace(kernel.t_range[0] - 2*pre_pad.shape[0] * 1/movie.frame_rate,
0, movie.data.shape[0])
movie.t_range_orig = movie.t_range
movie.t_range = kernel.t_range - kernel.t_range[0]
# Treating it like an image, although technically strange to pad to negative frequencies
kernel.row_range = np.linspace(kernel.row_range[0] -
lower_pad.shape[1] * (kernel.row_range[1]-kernel.row_range[0]),
kernel.row_range[-1] +
upper_pad.shape[1] * (kernel.row_range[-1]-kernel.row_range[-2]),
movie.data.shape[1])
movie.row_range_orig = movie.row_range
movie.row_range = kernel.row_range
kernel.row_inds = kernel.row_inds + lower_pad.shape[1]
else:
pass
kernel.apply_threshold(threshold)
31 changes: 21 additions & 10 deletions bmtk/simulator/filternet/lgnmodel/kernel.py
Original file line number Diff line number Diff line change
@@ -131,7 +131,7 @@ def rescale(self):
def normalize(self):
self.kernel /= np.abs(self.kernel.sum())

def normalize2(self, remove_offset=False):
def normalize2(self, remove_offset=True):
# Better for kernels that are not all positive
if remove_offset:
self.kernel -= self.kernel.mean() # Set amplitude offset to 0
@@ -207,12 +207,17 @@ def apply_threshold(self, threshold):
self.col_inds = self.col_inds[inds_to_keep]
self.kernel = self.kernel[inds_to_keep]

def full(self):
def full(self, truncate_col=False):
data = np.zeros((len(self.row_range), len(self.col_range)))
data[self.row_inds, self.col_inds] = self.kernel
return data
if truncate_col: # For spectrotemporal receptive fields where col is time dimension
ind_max = np.max(self.col_inds)
return data[:, :ind_max]
else:
return data

def imshow(self, ax=None, show=True, save_file_name=None, clim=None, colorbar=True):
def imshow(self, ax=None, show=True, save_file_name=None, clim=None, colorbar=True, truncate_col=False, xlabel=None,
ylabel=None):

from mpl_toolkits.axes_grid1 import make_axes_locatable

@@ -223,14 +228,20 @@ def imshow(self, ax=None, show=True, save_file_name=None, clim=None, colorbar=Tr
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

data = self.full()

data = self.full(truncate_col=truncate_col)
if truncate_col:
col_max = self.col_range[np.max(self.col_inds)]
else:
col_max = self.col_range[-1]
if clim is not None:
im = ax.imshow(data, extent=(self.col_range[0], self.col_range[-1], self.row_range[0], self.row_range[-1]),
origin='lower', clim=clim, interpolation='none')
im = ax.imshow(data, extent=(self.col_range[0], col_max, np.squeeze(self.row_range[0]),
np.squeeze(self.row_range[-1])), origin='lower', clim=clim, interpolation='none')
else:
im = ax.imshow(data, extent=(self.col_range[0], self.col_range[-1], self.row_range[0], self.row_range[-1]),
origin='lower', interpolation='none')
im = ax.imshow(data, extent=(self.col_range[0], col_max, np.squeeze(self.row_range[0]),
np.squeeze(self.row_range[-1])), origin='lower', interpolation='none',
aspect='auto')
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

if colorbar:
plt.colorbar(im, cax=cax)
3 changes: 2 additions & 1 deletion bmtk/simulator/filternet/lgnmodel/movie.py
Original file line number Diff line number Diff line change
@@ -5,10 +5,11 @@

class Movie(object):
def __init__(self, data, row_range=None, col_range=None, labels=('time', 'y', 'x'),
units=('second', 'pixel', 'pixel'), frame_rate=None, t_range=None):
units=('second', 'pixel', 'pixel'), frame_rate=None, t_range=None, padding=False):
self.data = data
self.labels = labels
self.units = units
self.padding=padding
assert(units[0] == 'second')

if t_range is None:
10 changes: 3 additions & 7 deletions bmtk/simulator/filternet/lgnmodel/waveletfilter.py
Original file line number Diff line number Diff line change
@@ -69,7 +69,6 @@ def get_kernel(self, row_range, col_range, threshold=0):

if self.theta != np.pi/2:
f_t = np.cos(self.theta) / self.Lambda
remove_offset = True
translate_t = -0.1 * self.Lambda / np.cos(self.theta)
env = (f_t * (x - translate_t)) ** (self.order_t - 1) * np.exp(-1 * self.b_t * f_t * (x - translate_t)) \
* np.exp(-.5 * (y - self.translate) ** 2 / self.sigma_f ** 2)
@@ -80,25 +79,22 @@ def get_kernel(self, row_range, col_range, threshold=0):
# The step response adapts slightly to a flat steady-state
f_t = 5
self.b_t = 10
remove_offset = False
translate_t = 0
env = (f_t * (x - translate_t)) ** (self.order_t - 1) * np.exp(-1 * self.b_t * f_t * (x - translate_t)) \
* np.sin(2*np.pi*f_t * (x - translate_t)) \
* np.exp(-.5 * (y - self.translate) ** 2 / self.sigma_f ** 2)
wave = np.cos(2 * np.pi / self.Lambda * ((y - self.translate) + self.psi))
wave = np.cos(2 * np.pi / self.Lambda * (y - self.translate) + self.psi)

filt = env * wave
filt /= np.max(filt)
print('max: ', np.max(filt))
print('min:', np.min(filt))

# Need to translate

threshold = 0.05 * np.max(filt)
kernel = Kernel2D.from_dense(row_range, col_range, filt, threshold=threshold)
#kernel.apply_threshold(threshold) # Already applied?
kernel.normalize2(remove_offset) # Scale up large kernels which can hit low float limit when normalized
kernel.normalize2() # Scale up large kernels which can hit low float limit when normalized
kernel.kernel *= self.amplitude # How do normalize and amplitude work together? seems like they would counteract each other?
#kernel.imshow()
#kernel.imshow(truncate_col=True, xlabel='Time(s)', ylabel='log(freq) re: 50 Hz')

return kernel