Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing issue with writing to nwb with mpi and adding some extra docum… #350

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bmtk/utils/reports/spike_trains/spike_trains_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import warnings

from .core import SortOrder
from .spikes_file_writers import write_csv, write_sonata
from .spikes_file_writers import write_csv, write_sonata, write_nwb


class SpikeTrainsAPI(object):
Expand Down Expand Up @@ -169,7 +169,7 @@ def to_csv(self, path, mode='w', sort_order=SortOrder.none, **kwargs):
write_csv(path=path, spiketrain_reader=self, mode=mode, sort_orders=sort_order, **kwargs)

def to_nwb(self, path, mode='w', **kwargs):
raise NotImplemented()
write_nwb(path=path, spiketrain_reader=self, mode=mode, **kwargs)

def merge(self, other):
"""Import Another SpikesTrain object into current file, always in-place.
Expand Down
46 changes: 46 additions & 0 deletions bmtk/utils/reports/spike_trains/spikes_file_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import csv
import h5py
import numpy as np
from datetime import datetime

import bmtk
from .core import SortOrder, csv_headers, col_population, find_conversion
from .core import MPI_rank, comm_barrier
from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version
Expand Down Expand Up @@ -145,3 +147,47 @@ def write_csv_itr(path, spiketrain_reader, mode='w', sort_order=SortOrder.none,
csv_writer.writerow(c_data)

comm_barrier()


def write_nwb(path, spiketrain_reader, mode='w', include_population=True, units='ms', **kwargs):
import pynwb

path_dir = os.path.dirname(path)
if MPI_rank == 0 and path_dir and not os.path.exists(path_dir):
os.makedirs(path_dir)

if MPI_rank == 0:
# Last checked pynwb doesn't support writing on multiple cores, must let first core do all the
# writing to NWB.
nwbfile = pynwb.NWBFile(
session_description='BMTK {} generated NWB spikes file'.format(bmtk.__version__),
identifier='Generated in-silico, no session id', # TODO: No idea what to put here?
session_start_time=datetime.now().astimezone(),
# experiment_description=str(session.experiment_metadata['experiment_id'])
)

if include_population:
nwbfile.add_unit_column(name="population", description="node population identifier")
add_unit = lambda nid, pop, st: nwbfile.add_unit(id=nid, spike_times=st, population=pop, node_id=nid)
else:
add_unit = lambda nid, pop, st: nwbfile.add_unit(id=nid, spike_times=st, node_id=nid)

nwbfile.add_unit_column(name="node_id", description="id of each node within a population")

for population in spiketrain_reader.populations:
for node_id in spiketrain_reader.node_ids(population=population):
spikes_times = spiketrain_reader.get_times(node_id=node_id, population=population)
if spikes_times is None or len(spikes_times) == 0:
# No spikes for given node, don't try to write to nwb
continue

# sometimes bmtk/sonata may default to use different 32 or unsigned data-types which will cause
# nwb to throw a fit. Need to explicity convert data-types just in case.
spikes_times = spikes_times.astype('float64')
node_id = int(node_id)
add_unit(node_id, population, spikes_times)

with pynwb.NWBHDF5IO(path, mode) as io:
io.write(nwbfile)

comm_barrier()
Loading