
Commit

Merge pull request #350 from kaeldai/documentation/nwb_update
fixing issue with writing to nwb with mpi and adding some extra docum…
kaeldai authored Feb 27, 2024
2 parents ba481de + aaa0882 commit 6e2db7a
Showing 2 changed files with 48 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bmtk/utils/reports/spike_trains/spike_trains_api.py
@@ -24,7 +24,7 @@
import warnings

from .core import SortOrder
-from .spikes_file_writers import write_csv, write_sonata
+from .spikes_file_writers import write_csv, write_sonata, write_nwb


class SpikeTrainsAPI(object):
@@ -169,7 +169,7 @@ def to_csv(self, path, mode='w', sort_order=SortOrder.none, **kwargs):
        write_csv(path=path, spiketrain_reader=self, mode=mode, sort_orders=sort_order, **kwargs)

    def to_nwb(self, path, mode='w', **kwargs):
-        raise NotImplemented()
+        write_nwb(path=path, spiketrain_reader=self, mode=mode, **kwargs)

    def merge(self, other):
        """Import another SpikeTrains object into the current file, always in-place.
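
With this change, SpikeTrainsAPI.to_nwb() delegates to the new write_nwb() writer instead of raising. For context, a minimal usage sketch (not part of the commit; the population name, node id, timestamps, and output path are illustrative, and pynwb must be installed):

    from bmtk.utils.reports.spike_trains import SpikeTrains

    # Build a small in-memory spike container, then export it to NWB.
    spikes = SpikeTrains()
    spikes.add_spikes(node_ids=0, timestamps=[10.0, 15.5, 102.1], population='cortex')
    spikes.to_nwb('output/spikes.nwb', mode='w')  # previously raised NotImplemented()
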
46 changes: 46 additions & 0 deletions bmtk/utils/reports/spike_trains/spikes_file_writers.py
@@ -24,7 +24,9 @@
import csv
import h5py
import numpy as np
+from datetime import datetime

+import bmtk
from .core import SortOrder, csv_headers, col_population, find_conversion
from .core import MPI_rank, comm_barrier
from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version
@@ -145,3 +147,47 @@ def write_csv_itr(path, spiketrain_reader, mode='w', sort_order=SortOrder.none,
        csv_writer.writerow(c_data)

    comm_barrier()


def write_nwb(path, spiketrain_reader, mode='w', include_population=True, units='ms', **kwargs):
    import pynwb

    path_dir = os.path.dirname(path)
    if MPI_rank == 0 and path_dir and not os.path.exists(path_dir):
        os.makedirs(path_dir)

    if MPI_rank == 0:
        # As of last check, pynwb doesn't support writing from multiple cores, so the first core
        # (rank 0) must do all of the writing to NWB.
        nwbfile = pynwb.NWBFile(
            session_description='BMTK {} generated NWB spikes file'.format(bmtk.__version__),
            identifier='Generated in-silico, no session id',  # TODO: No idea what to put here?
            session_start_time=datetime.now().astimezone(),
            # experiment_description=str(session.experiment_metadata['experiment_id'])
        )

        if include_population:
            nwbfile.add_unit_column(name="population", description="node population identifier")
            add_unit = lambda nid, pop, st: nwbfile.add_unit(id=nid, spike_times=st, population=pop, node_id=nid)
        else:
            add_unit = lambda nid, pop, st: nwbfile.add_unit(id=nid, spike_times=st, node_id=nid)

        nwbfile.add_unit_column(name="node_id", description="id of each node within a population")

        for population in spiketrain_reader.populations:
            for node_id in spiketrain_reader.node_ids(population=population):
                spikes_times = spiketrain_reader.get_times(node_id=node_id, population=population)
                if spikes_times is None or len(spikes_times) == 0:
                    # No spikes for the given node; don't try to write it to nwb.
                    continue

                # bmtk/sonata may sometimes default to 32-bit or unsigned data-types, which nwb
                # will reject. Explicitly convert the data-types just in case.
                spikes_times = spikes_times.astype('float64')
                node_id = int(node_id)
                add_unit(node_id, population, spikes_times)

        with pynwb.NWBHDF5IO(path, mode) as io:
            io.write(nwbfile)

    comm_barrier()
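
To sanity-check the result, the file can be read back with pynwb and the units table inspected. A sketch under the same assumptions as the earlier example (not part of the commit):

    import pynwb

    # Read the NWB file back and inspect the units table written by write_nwb().
    with pynwb.NWBHDF5IO('output/spikes.nwb', 'r') as io:
        nwbfile = io.read()
        units_df = nwbfile.units.to_dataframe()  # includes spike_times, node_id, population
        print(units_df.head())

Note that under MPI every rank still needs to call write_nwb() so the trailing comm_barrier() is entered collectively; only rank 0 actually opens and writes the file.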
