Skip to content

Commit

Permalink
Support for multi-compartmental neurons (#6)
Browse files Browse the repository at this point in the history
* update class names

* enable to access the elements in `Container` by `__getitem__()` or `__getattr__()`

* add `__module__` for each class

* add `PointBased` for single-compartment neurons and `MultiCompartment` for multi-compartment neurons

* add `PointBased` for single-compartment neurons and `MultiCompartment` for multi-compartment neurons
  • Loading branch information
chaoming0625 authored Jul 9, 2024
1 parent 7ca0f2e commit f6afb3f
Show file tree
Hide file tree
Showing 20 changed files with 495 additions and 49 deletions.
9 changes: 0 additions & 9 deletions dendritex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@

from . import channels
from . import ions
# from .channels import *
# from .channels import __all__ as _channels_all
# from .ions import *
# from .ions import __all__ as _ions_all
# from .neurons import *
# from .neurons import __all__ as _membranes_all
from . import neurons
from ._base import *
from ._base import __all__ as _base_all
Expand All @@ -34,7 +28,4 @@
['neurons', 'ions', 'channels'] +
_base_all +
_integrators_all
# _ions_all +
# _channels_all +
# _membranes_all
)
57 changes: 55 additions & 2 deletions dendritex/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import numpy as np
from brainstate.mixin import _JointGenericAlias

from ._misc import set_module_as

__all__ = [
'DendriticDynamics',
'State4Integral',
Expand All @@ -40,7 +42,7 @@
#
# - DendriticDynamics
# - HHTypedNeuron
# - SingleCompartmentNeuron
# - SingleCompartment
# - IonChannel
# - Ion
# - Calcium
Expand All @@ -50,6 +52,7 @@
# - Channel
#


class State4Integral(bst.ShortTermState):
"""
A state that integrates the state of the system to the integral of the state.
Expand All @@ -60,6 +63,8 @@ class State4Integral(bst.ShortTermState):
"""

__module__ = 'dentritex'

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.derivative = None
Expand All @@ -75,6 +80,7 @@ class DendriticDynamics(bst.Dynamics):
n_compartment: The number of compartments in each neuron.
varshape: The shape of the state variables.
"""
__module__ = 'dentritex'

def __init__(
self,
Expand Down Expand Up @@ -130,6 +136,9 @@ def reset_state(self, *args, **kwargs):


class Container(bst.mixin.Mixin):
__module__ = 'dentritex'

_container_name: str

@staticmethod
def _get_elem_name(elem):
Expand Down Expand Up @@ -169,6 +178,27 @@ def _format_elements(child_type: type, *children_as_tuple, **children_as_dict):
res[k] = v
return res

def __getitem__(self, item):
"""Overwrite the slice access (`self['']`). """
children = self.__getattr__(self._container_name)
if item in children:
return children[item]
else:
raise ValueError(f'Unknown item {item}, we only found {list(children.keys())}')

def __getattr__(self, item):
"""Overwrite the dot access (`self.`). """
name = super().__getattribute__('_container_name')
if item == '_container_name':
return name
children = super().__getattribute__(name)
if item == name:
return children
if item in children:
return children[item]
else:
return super().__getattribute__(item)

def add_elem(self, *elems, **elements):
"""
Add new elements.
Expand All @@ -180,6 +210,8 @@ def add_elem(self, *elems, **elements):


class TreeNode(bst.mixin.Mixin):
__module__ = 'dentritex'

root_type: type

@staticmethod
Expand Down Expand Up @@ -216,8 +248,10 @@ def check_hierarchies(root: type, *leaves, check_fun: Callable = None, **named_l

class HHTypedNeuron(DendriticDynamics, Container):
"""
The base class for the Hodgkin-Huxley typed neuronal dynamics.
The base class for the Hodgkin-Huxley typed neuronal membrane dynamics.
"""
__module__ = 'dentritex'
_container_name = 'ion_channels'

def __init__(
self,
Expand All @@ -231,6 +265,17 @@ def __init__(
# attribute for ``Container``
self.ion_channels = bst.visible_module_dict(self._format_elements(IonChannel, **ion_channels))

def init_state(self, batch_size=None):
nodes = self.nodes(level=1, include_self=False).subset(IonChannel).values()
TreeNode.check_hierarchies(self.__class__, *nodes)
for channel in nodes:
channel.init_state(self.V.value, batch_size=batch_size)

def reset_state(self, batch_size=None):
nodes = self.nodes(level=1, include_self=False).subset(IonChannel).values()
for channel in nodes:
channel.reset_state(self.V.value, batch_size=batch_size)

def add_elem(self, *elems, **elements):
"""
Add new elements.
Expand All @@ -243,6 +288,8 @@ def add_elem(self, *elems, **elements):


class IonChannel(DendriticDynamics, TreeNode):
__module__ = 'dentritex'

def current(self, *args, **kwargs):
raise NotImplementedError

Expand Down Expand Up @@ -274,6 +321,8 @@ class Ion(IonChannel, Container):
size: The size of the simulation target.
name: The name of the object.
"""
__module__ = 'dentritex'
_container_name = 'channels'

# The type of the master object.
root_type = HHTypedNeuron
Expand Down Expand Up @@ -381,8 +430,10 @@ class MixIons(IonChannel, Container):
Args:
ions: Instances of ions. This option defines the master types of all children objects.
"""
__module__ = 'dentritex'

root_type = HHTypedNeuron
_container_name = 'channels'

def __init__(
self,
Expand Down Expand Up @@ -509,6 +560,7 @@ def _check_root(self, leaf):
)


@set_module_as('dentritex')
def mix_ions(*ions) -> MixIons:
"""Create mixed ions.
Expand All @@ -526,3 +578,4 @@ def mix_ions(*ions) -> MixIons:

class Channel(IonChannel):
"""Base class for ion channels."""
__module__ = 'dentritex'
5 changes: 5 additions & 0 deletions dendritex/_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax

from ._base import State4Integral, DendriticDynamics
from ._misc import set_module_as

__all__ = [
'euler_step',
Expand All @@ -33,6 +34,7 @@ def tree_map(f, tree, *rest):
return jax.tree.map(f, tree, *rest, is_leaf=lambda a: isinstance(a, bu.Quantity))


@set_module_as('dentritex')
def euler_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args):
dt = bst.environ.get_dt()

Expand All @@ -57,6 +59,7 @@ def euler_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args):
target.after_integral(*args)


@set_module_as('dentritex')
def rk2_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args):
dt = bst.environ.get_dt()

Expand Down Expand Up @@ -93,6 +96,7 @@ def rk2_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args):
target.after_integral(*args)


@set_module_as('dentritex')
def rk3_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args):
dt = bst.environ.get_dt()

Expand Down Expand Up @@ -149,6 +153,7 @@ def rk3_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args):
target.after_integral(*args)


@set_module_as('dentritex')
def rk4_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args):
dt = bst.environ.get_dt()

Expand Down
26 changes: 26 additions & 0 deletions dendritex/_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import brainunit as bu


def set_module_as(name: str):
def decorator(module):
module.__name__ = name
return module

return decorator

9 changes: 8 additions & 1 deletion dendritex/channels/calcium.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
class CalciumChannel(Channel):
"""Base class for Calcium ion channels."""

__module__ = 'dendritex.channels'

root_type = Calcium

def before_integral(self, V, Ca: IonInfo):
Expand Down Expand Up @@ -87,8 +89,8 @@ class ICaN_IS2008(CalciumChannel):
increase in the excitability of olfactory bulb interneurons.
J Neurophysiol 99: 187–199.
"""
__module__ = 'dendritex.channels'

'''The type of the master object.'''
root_type = Calcium

def __init__(
Expand Down Expand Up @@ -339,6 +341,7 @@ class ICaT_HM1992(_ICa_p2q_ss):
--------
ICa_p2q_form
"""
__module__ = 'dendritex.channels'

def __init__(
self,
Expand Down Expand Up @@ -439,6 +442,7 @@ class ICaT_HP1992(_ICa_p2q_ss):
--------
ICa_p2q_form
"""
__module__ = 'dendritex.channels'

def __init__(
self,
Expand Down Expand Up @@ -536,6 +540,7 @@ class ICaHT_HM1992(_ICa_p2q_ss):
--------
ICa_p2q_form
"""
__module__ = 'dendritex.channels'

def __init__(
self,
Expand Down Expand Up @@ -632,6 +637,7 @@ class ICaHT_Re1993(_ICa_p2q_markov):
Neuroscience 13.11 (1993): 4609-4621.
"""
__module__ = 'dendritex.channels'

def __init__(
self,
Expand Down Expand Up @@ -723,6 +729,7 @@ class ICaL_IS2008(_ICa_p2q_ss):
--------
ICa_p2q_form
"""
__module__ = 'dendritex.channels'

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions dendritex/channels/hyperpolarization_activated.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Ih_HM1992(Channel):
of neurophysiology 68, no. 4 (1992): 1373-1383.
"""
__module__ = 'dendritex.channels'

root_type = HHTypedNeuron

Expand Down
2 changes: 2 additions & 0 deletions dendritex/channels/leaky.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class LeakageChannel(Channel):
"""
Base class for leakage channel dynamics.
"""
__module__ = 'dendritex.channels'

root_type = HHTypedNeuron

Expand Down Expand Up @@ -56,6 +57,7 @@ class IL(LeakageChannel):
E : float
The reversal potential.
"""
__module__ = 'dendritex.channels'

def __init__(
self,
Expand Down
Loading

0 comments on commit f6afb3f

Please sign in to comment.