diff --git a/bmtk/utils/reports/spike_trains/spike_trains_api.py b/bmtk/utils/reports/spike_trains/spike_trains_api.py index 21d163904..23ffde30d 100644 --- a/bmtk/utils/reports/spike_trains/spike_trains_api.py +++ b/bmtk/utils/reports/spike_trains/spike_trains_api.py @@ -24,7 +24,7 @@ import warnings from .core import SortOrder -from .spikes_file_writers import write_csv, write_sonata +from .spikes_file_writers import write_csv, write_sonata, write_nwb class SpikeTrainsAPI(object): @@ -169,7 +169,7 @@ def to_csv(self, path, mode='w', sort_order=SortOrder.none, **kwargs): write_csv(path=path, spiketrain_reader=self, mode=mode, sort_orders=sort_order, **kwargs) def to_nwb(self, path, mode='w', **kwargs): - raise NotImplemented() + write_nwb(path=path, spiketrain_reader=self, mode=mode, **kwargs) def merge(self, other): """Import Another SpikesTrain object into current file, always in-place. diff --git a/bmtk/utils/reports/spike_trains/spikes_file_writers.py b/bmtk/utils/reports/spike_trains/spikes_file_writers.py index 54ec32402..96dd7b9ac 100644 --- a/bmtk/utils/reports/spike_trains/spikes_file_writers.py +++ b/bmtk/utils/reports/spike_trains/spikes_file_writers.py @@ -24,7 +24,9 @@ import csv import h5py import numpy as np +from datetime import datetime +import bmtk from .core import SortOrder, csv_headers, col_population, find_conversion from .core import MPI_rank, comm_barrier from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version @@ -145,3 +147,47 @@ def write_csv_itr(path, spiketrain_reader, mode='w', sort_order=SortOrder.none, csv_writer.writerow(c_data) comm_barrier() + + +def write_nwb(path, spiketrain_reader, mode='w', include_population=True, units='ms', **kwargs): + import pynwb + + path_dir = os.path.dirname(path) + if MPI_rank == 0 and path_dir and not os.path.exists(path_dir): + os.makedirs(path_dir) + + if MPI_rank == 0: + # Last checked pynwb doesn't support writing on multiple cores, must let first core do all the + # writing to NWB. + nwbfile = pynwb.NWBFile( + session_description='BMTK {} generated NWB spikes file'.format(bmtk.__version__), + identifier='Generated in-silico, no session id', # TODO: No idea what to put here? + session_start_time=datetime.now().astimezone(), + # experiment_description=str(session.experiment_metadata['experiment_id']) + ) + + if include_population: + nwbfile.add_unit_column(name="population", description="node population identifier") + add_unit = lambda nid, pop, st: nwbfile.add_unit(id=nid, spike_times=st, population=pop, node_id=nid) + else: + add_unit = lambda nid, pop, st: nwbfile.add_unit(id=nid, spike_times=st, node_id=nid) + + nwbfile.add_unit_column(name="node_id", description="id of each node within a population") + + for population in spiketrain_reader.populations: + for node_id in spiketrain_reader.node_ids(population=population): + spikes_times = spiketrain_reader.get_times(node_id=node_id, population=population) + if spikes_times is None or len(spikes_times) == 0: + # No spikes for given node, don't try to write to nwb + continue + + # sometimes bmtk/sonata may default to use different 32 or unsigned data-types which will cause + # nwb to throw a fit. Need to explicity convert data-types just in case. + spikes_times = spikes_times.astype('float64') + node_id = int(node_id) + add_unit(node_id, population, spikes_times) + + with pynwb.NWBHDF5IO(path, mode) as io: + io.write(nwbfile) + + comm_barrier()