Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update/workshop 2024 #372

Merged
merged 12 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions bmtk/analyzer/ecp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import h5py
import matplotlib.pyplot as plt
import numpy as np
from decimal import Decimal

from bmtk.utils.sonata.config import SonataConfig
from bmtk.simulator.utils import simulation_reports
# from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm


def _get_ecp_path(ecp_path=None, config=None, report_name=None):
Expand Down Expand Up @@ -55,30 +57,54 @@ def plot_ecp(config_file=None, report_name=None, ecp_path=None, title=None, show
channels = ecp_h5['/ecp/channel_id'][()]
fig, axes = plt.subplots(len(channels), 1)
fig.text(0.04, 0.5, 'channel id', va='center', rotation='vertical')
v_min, v_max = ecp_h5['/ecp/data'][()].min(), ecp_h5['/ecp/data'][()].max()
# print(v_max - v_min)
# exit()

for idx, channel in enumerate(channels):
data = ecp_h5['/ecp/data'][:, idx]
# print(channel, np.min(data), np.max(data))
axes[idx].plot(time_traces, data)
axes[idx].spines["top"].set_visible(False)
axes[idx].spines["right"].set_visible(False)
axes[idx].set_yticks([])
axes[idx].set_ylabel(channel)
axes[idx].set_ylim([v_min, v_max])

if idx+1 != len(channels):
axes[idx].spines["bottom"].set_visible(False)
axes[idx].set_xticks([])
else:
axes[idx].set_xlabel('timestamps (ms)')
# scalebar = AnchoredSizeBar(axes[idx].transData,
# 2.0, '1 mV', 1,
# pad=0,
# borderpad=0,
# # color='b',
# frameon=True,
# # size_vertical=1.001,
# # fontproperties=fontprops
# )
#
# axes[idx].add_artist(scalebar)


if idx == 0:
scale_bar_size = (v_max-v_min)/2.0
scale_bar_label = f'{scale_bar_size:.2E}'
# print(scale_bar_label)
# exit()
fontprops = fm.FontProperties(size='x-small')

scalebar = AnchoredSizeBar(
axes[idx].transData,
size=scale_bar_size,
label=scale_bar_label,
loc='upper right',
pad=0.1,
borderpad=0.5,
sep=5,
# color='b',
frameon=False,
size_vertical=scale_bar_size,
# size_vertical=1.001,
fontproperties=fontprops
)
axes[idx].add_artist(scalebar)

# label = scalebar.txt_label
# label.set_rotation(270.0)
# label.set_verticalalignment('bottom')
# label.set_horizontalalignment('left')

if title:
fig.set_title(title)
Expand Down
3 changes: 2 additions & 1 deletion bmtk/simulator/bionet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from bmtk.simulator.bionet.pyfunction_cache import synapse_model, synaptic_weight, cell_model, add_weight_function, model_processing
from bmtk.simulator.bionet.pyfunction_cache import synapse_model, synaptic_weight, cell_model, add_weight_function, model_processing, \
spikes_generator
from bmtk.simulator.bionet.config import Config
from bmtk.simulator.bionet.bionetwork import BioNetwork
from bmtk.simulator.bionet.biosimulator import BioSimulator
Expand Down
45 changes: 37 additions & 8 deletions bmtk/simulator/bionet/biocell.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from bmtk.simulator.bionet.morphology import Morphology
import six

import neuron
from neuron import h

pc = h.ParallelContext() # object to access MPI methods
Expand Down Expand Up @@ -74,9 +75,6 @@ class BioCell(Cell):
def __init__(self, node, population_name, bionetwork):
super(BioCell, self).__init__(node=node, population_name=population_name, network=bionetwork)

# Set up netcon object that can be used to detect and communicate cell spikes.
self.set_spike_detector(bionetwork.spike_threshold)

# Determine number of segments and store a list of all sections.
self._secs = []
self._secs_by_id = []
Expand Down Expand Up @@ -105,6 +103,10 @@ def __init__(self, node, population_name, bionetwork):
self._seg_coords = None
self.build_morphology()

# Set up netcon object that can be used to detect and communicate cell spikes.
self.set_spike_detector(bionetwork.spike_threshold)


def build_morphology(self):
morph_base = Morphology.load(hobj=self.hobj, morphology_file=self.morphology_file, cache_seg_props=True)

Expand All @@ -126,6 +128,10 @@ def morphology(self):
"""The actual Morphology object instanstiation"""
return self._morphology

@property
def soma(self):
return self.morphology.soma

@property
def seg_coords(self):
"""Coordinates for segments/sections of the morphology, need to make public for ecp, xstim, and other
Expand All @@ -144,7 +150,7 @@ def seg_coords(self):
return self.morphology.seg_coords

def set_spike_detector(self, spike_threshold):
nc = h.NetCon(self.hobj.soma[0](0.5)._ref_v, None, sec=self.hobj.soma[0]) # attach spike detector to cell
nc = h.NetCon(self.soma[0](0.5)._ref_v, None, sec=self.soma[0])
nc.threshold = spike_threshold
pc.cell(self.gid, nc) # associate gid with spike detector

Expand Down Expand Up @@ -437,18 +443,41 @@ def __init__(self, node, population_name, bionetwork):
self._vecstim = h.VecStim()
self._vecstim.play(self._spike_trains)

self._precell_filter = bionetwork.spont_syns_filter
self._precell_filter = bionetwork.spont_syns_filter_pre
self._postcell_filter = bionetwork.spont_syns_filter_post
assert(isinstance(self._precell_filter, dict))

def _matches_filter(self, src_node):
def _matches_filter(self, src_node, trg_node=None):
"""Check to see if the presynaptic cell matches the criteria specified"""
for k, v in self._precell_filter.items():
# Some key may not show up as node_variable
if k == 'population' and k not in src_node:
key_val = src_node.population_name
else:
key_val = src_node[k]

if isinstance(v, (list, tuple)):
if key_val not in v:
return False
else:
if key_val != v:
return False

trg_node = trg_node or self
for k, v in self._postcell_filter.items():
# Some key may not show up as node_variable
if k == 'population' and k not in trg_node:
key_val = trg_node._node.population_name
else:
key_val = trg_node[k]

if isinstance(v, (list, tuple)):
if src_node[k] not in v:
if key_val not in v:
return False
else:
if src_node[k] != v:
if key_val != v:
return False

return True

def _set_connections(self, edge_prop, src_node, syn_weight, stim=None):
Expand Down
18 changes: 10 additions & 8 deletions bmtk/simulator/bionet/bionetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __init__(self):
self._gid_pool = GidPool()

self.has_spont_syns = False
self.spont_syns_filter = None
self.spont_syns_filter_pre = None
self.spont_syns_filter_post = None
self.spont_syns_times = None

@property
Expand All @@ -88,7 +89,7 @@ def gid_pool(self):
def py_function_caches(self):
return nrn

def set_spont_syn_activity(self, precell_filter, timestamps):
def set_spont_syn_activity(self, precell_filter, postcell_filter, timestamps):
self._model_type_map = {
'biophysical': BioCellSpontSyn,
'point_process': PointProcessCellSpontSyns,
Expand All @@ -98,7 +99,8 @@ def set_spont_syn_activity(self, precell_filter, timestamps):
}

self.has_spont_syns = True
self.spont_syns_filter = precell_filter
self.spont_syns_filter_pre = precell_filter
self.spont_syns_filter_post = postcell_filter
self.spont_syns_times = timestamps

def get_node_id(self, population, node_id):
Expand Down Expand Up @@ -134,12 +136,12 @@ def add_nodes(self, node_population):
self._gid_pool.add_pool(node_population.name, node_population.n_nodes())
super(BioNetwork, self).add_nodes(node_population)

def get_virtual_cells(self, population, node_id, spike_trains):
def get_virtual_cells(self, population, node_id, spike_trains, spikes_generator=None, sim=None):
if node_id in self._virtual_nodes[population]:
return self._virtual_nodes[population][node_id]
else:
node = self.get_node_id(population, node_id)
virt_cell = VirtualCell(node, population, spike_trains)
virt_cell = VirtualCell(node, population, spike_trains, spikes_generator, sim)
self._virtual_nodes[population][node_id] = virt_cell
return virt_cell

Expand All @@ -151,7 +153,7 @@ def get_disconnected_cell(self, population, node_id, spike_trains):
virt_cell = self._disconnected_source_cells[population][node_id]
else:
node = self.get_node_id(population, node_id)
virt_cell = VirtualCell(node, population, spike_trains)
virt_cell = VirtualCell(node, population, spike_trains, self)
self._disconnected_source_cells[population][node_id] = virt_cell

return virt_cell
Expand Down Expand Up @@ -369,7 +371,7 @@ def find_edges(self, source_nodes=None, target_nodes=None):

return selected_edges

def add_spike_trains(self, spike_trains, node_set):
def add_spike_trains(self, spike_trains, node_set, spikes_generator=None, sim=None):
self._init_connections()

src_nodes = [node_pop for node_pop in self.node_populations if node_pop.name in node_set.population_names()]
Expand All @@ -379,7 +381,7 @@ def add_spike_trains(self, spike_trains, node_set):
if edge_pop.virtual_connections:
for trg_nid, trg_cell in self._rank_node_ids[edge_pop.target_nodes].items():
for edge in edge_pop.get_target(trg_nid):
src_cell = self.get_virtual_cells(source_population, edge.source_node_id, spike_trains)
src_cell = self.get_virtual_cells(source_population, edge.source_node_id, spike_trains, spikes_generator, sim)
trg_cell.set_syn_connection(edge, src_cell, src_cell)

elif edge_pop.mixed_connections:
Expand Down
22 changes: 20 additions & 2 deletions bmtk/simulator/bionet/biosimulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def from_config(cls, config, network, set_recordings=True):
if sim_input.input_type == 'syn_activity':
network.set_spont_syn_activity(
precell_filter=sim_input.params['precell_filter'],
postcell_filter=sim_input.params.get('postcell_filter', {}),
timestamps=sim_input.params['timestamps']
)

Expand All @@ -346,13 +347,30 @@ def from_config(cls, config, network, set_recordings=True):

# TODO: Need to create a gid selector
for sim_input in inputs.from_config(config):
if sim_input.input_type == 'spikes' and sim_input.module in ['nwb', 'csv', 'sonata']:
if sim_input.input_type == 'spikes' and sim_input.module in ['nwb', 'csv', 'sonata', 'h5']:
io.log_info('Building virtual cell stimulations for {}'.format(sim_input.name))
path = sim_input.params['input_file']
spikes = SpikeTrains.load(path=path, file_type=sim_input.module, **sim_input.params)
# node_set_opts = sim_input.params.get('node_set', 'all')
node_set = network.get_node_set(sim_input.node_set)
network.add_spike_trains(spikes, node_set)
network.add_spike_trains(
spike_trains=spikes,
node_set=node_set,
spikes_generator=None,
sim=sim
)

elif sim_input.input_type == 'spikes' and sim_input.module == 'function':
io.log_info('Building virtual cell stimulations for {}'.format(sim_input.name))
# path = sim_input.params.get['input_file']
spikes_generator = sim_input.params['spikes_function']
node_set = network.get_node_set(sim_input.node_set)
network.add_spike_trains(
spike_trains=None,
node_set=node_set,
spikes_generator=spikes_generator,
sim=sim
)

elif sim_input.module == 'IClamp':
sim.add_mod(mods.IClampMod(input_type=sim_input.input_type, **sim_input.params))
Expand Down
6 changes: 6 additions & 0 deletions bmtk/simulator/bionet/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def get_connection_info(self):

def set_syn_connections(self, edge_prop, src_node, stim=None):
raise NotImplementedError

def get_section(self, sec_name, sec_index):
raise NotImplementedError

def __contains__(self, node_prop):
return node_prop in self._node

def __getitem__(self, node_prop):
return self._node[node_prop]
12 changes: 11 additions & 1 deletion bmtk/simulator/bionet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ def create_output_dir(self):
io.setup_output_dir(self.output_dir, self.log_file)

def load_nrn_modules(self):
nrn.load_neuron_modules(self.mechanisms_dir, self.templates_dir)
nrn.load_neuron_modules(
mechanisms_dir=self.mechanisms_dir,
templates_dir=self.templates_dir,
default_templates=self.use_default_templates,
use_old_import3d=self.use_old_import3d
)

def build_env(self):
self.io = io
Expand All @@ -52,3 +57,8 @@ def build_env(self):

pc.barrier()
self.load_nrn_modules()

def _set_class_props(self):
super(Config, self)._set_class_props()
self.use_old_import3d = self.run.get('use_old_import3d', False)
self.use_default_templates = self.run.get('use_old_import3d', True)
52 changes: 40 additions & 12 deletions bmtk/simulator/bionet/default_setters/cell_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import numpy as np
from neuron import h
import inspect
try:
from sklearn.decomposition import PCA
except Exception as e:
Expand All @@ -41,19 +42,46 @@
"""

def loadHOC(cell, template_name, dynamics_params):
# Get template to instantiate
template_call = getattr(h, template_name)
if dynamics_params is not None and 'params' in dynamics_params:
template_params = dynamics_params['params']
if isinstance(template_params, list):
# pass in a list of parameters
hobj = template_call(*template_params)
"""A Generic function for creating a cell object from a NEURON HOC Template (eg. a *.hoc file with
`begintemplate template_name` in header). It essentially tries to guess the correct parameters that need to be
called so may not work the majority of the times and require to be overloaded.

:param cell: A SONATA node object, can be used as a dict to get individual properties of current cell.
:param template_name: name of HOCTemplate as stored in "model_template" attribute (hoc:<template_name>).
:param dynamics_params: Dictionary containing contents of cell['dynamics_params'] as loaded from a json file or hdf5.
If cell does not have "dynamics_params" attributes then will be set to None.
"""
try:
# Get template to instantiate
template_call = getattr(h, template_name)
except AttributeError as ae:
io.log_error(
f'loadHOC was unable to load in Neuron HOC Template "{template_name}, '
'Make sure appropiate .hoc file is stored in templates_dir.'
)
raise ae

try:
if dynamics_params is not None and 'params' in dynamics_params:
template_params = dynamics_params['params']
if isinstance(template_params, list):
# pass in a list of parameters
hobj = template_call(*template_params)
else:
# only a single parameter
hobj = template_call(template_params)
elif cell.morphology_file is not None:
# instantiate template with no parameters
hobj = template_call(cell.morphology_file)
else:
# only a single parameter
hobj = template_call(template_params)
else:
# instantiate template with no parameters
hobj = template_call()
hobj = template_call()
except RuntimeError as rte:
io.log_error(
f'bmtk.simualtor.bionet.default_setters.cell_models.loadHOC function failed to load HOC template "{template_call}". '
'If Hoc Templates requires special call to be initialized consider using `bmtk.simulator.bionet.add_cell_model()` '
'to overwrite this function.'
)
raise rte

# TODO: All "all" section if it doesn't exist
# hobj.all = h.SectionList()
Expand Down
Loading
Loading