[math & dyn] add brainpy.math.exprel, and change the code in the corresponding HH neuron models to improve numerical computation accuracy (#557)

* merge (#6)

* [running] fix multiprocessing bugs (#547)

* [running] fix multiprocessing bugs

* fix tests

* [docs] Fix typo in docs (#549)

* ⬆️ Bump conda-incubator/setup-miniconda from 2 to 3 (#551)

Bumps [conda-incubator/setup-miniconda](https://github.com/conda-incubator/setup-miniconda) from 2 to 3.
- [Release notes](https://github.com/conda-incubator/setup-miniconda/releases)
- [Changelog](https://github.com/conda-incubator/setup-miniconda/blob/main/CHANGELOG.md)
- [Commits](https://github.com/conda-incubator/setup-miniconda/compare/v2...v3)

---
updated-dependencies:
- dependency-name: conda-incubator/setup-miniconda
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* updates  (#550)

* [running] fix multiprocessing bugs

* fix tests

* [doc] update doc

* update

* [math] add `brainpy.math.gpu_memory_preallocation()` for controlling GPU memory preallocation

* [math] `clear_buffer_memory` supports clearing both arrays and compilation caches

* [dyn] compatible with the old version of the `.reset_state()` function

* [setup] update installation info

* ``brainpy.math.defjvp`` and ``brainpy.math.XLACustomOp.defjvp`` (#554)

* [running] fix multiprocessing bugs

* fix tests

* [doc] update doc

* update

* [math] add `brainpy.math.gpu_memory_preallocation()` for controlling GPU memory preallocation

* [math] `clear_buffer_memory` supports clearing both arrays and compilation caches

* [dyn] compatible with the old version of the `.reset_state()` function

* [setup] update installation info

* [install] upgrade dependency

* updates

* [math] add `brainpy.math.defjvp`, supporting the definition of JVP rules for Primitives with multiple results. See examples in `test_ad_support.py`

* [math] add `brainpy.math.XLACustomOp.defjvp`

* [doc] upgrade `brainpy.math.defjvp` docstring

* ⬆️ Bump actions/setup-python from 4 to 5 (#555)

Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5.
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/setup-python
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: Sichao He <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [math] add experimental `brainpy.math.exprel`

* [delay] move the delays in the models of the `brainpy.synapses` module into the new delay registration APIs `DynamicalSystem.register_local_delay()` and `DynamicalSystem.get_local_delay()`

* [math & dyn] add `brainpy.math.exprel`, and change the code in the corresponding HH neuron models to improve numerical computation accuracy

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: Sichao He <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
3 people authored Dec 10, 2023
1 parent 5be1834 commit b1e80e8
Showing 10 changed files with 90 additions and 31 deletions.
24 changes: 14 additions & 10 deletions brainpy/_src/dyn/neurons/hh.py
@@ -348,7 +348,8 @@ def __init__(
self.reset_state(self.mode)

# m channel
m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
# m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
m_alpha = lambda self, V: 1. / bm.exprel(-(V + 40) / 10)
m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m
@@ -360,7 +361,8 @@ def __init__(
dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h

# n channel
n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
# n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
n_alpha = lambda self, V: 0.1 / bm.exprel(-(V + 55) / 10)
n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n
@@ -383,8 +385,9 @@ def reset_state(self, batch_size=None, **kwargs):

def dV(self, V, t, m, h, n, I):
I = self.sum_inputs(V, init=I)
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
I_K = (self.gK * n ** 4.0) * (V - self.EK)
I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
n2 = n * n
I_K = (self.gK * n2 * n2) * (V - self.EK)
I_leak = self.gL * (V - self.EL)
dVdt = (- I_Na - I_K - I_leak + I) / self.C
return dVdt
@@ -516,8 +519,9 @@ class HH(HHLTC):
"""

def dV(self, V, t, m, h, n, I):
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
I_K = (self.gK * n ** 4.0) * (V - self.EK)
I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
n2 = n * n
I_K = (self.gK * n2 * n2) * (V - self.EK)
I_leak = self.gL * (V - self.EL)
dVdt = (- I_Na - I_K - I_leak + I) / self.C
return dVdt
@@ -680,9 +684,7 @@ def update(self, x=None):
t = share.load('t')
dt = share.load('dt')
x = 0. if x is None else x

V, W = self.integral(self.V, self.W, t, x, dt)

spike = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.W.value = W
@@ -930,7 +932,8 @@ def reset_state(self, batch_size=None):
self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size)

def m_inf(self, V):
alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
# alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
alpha = 1. / bm.exprel(-0.1 * (V + 35))
beta = 4. * bm.exp(-(V + 60.) / 18.)
return alpha / (alpha + beta)

@@ -941,7 +944,8 @@ def dh(self, h, t, V):
return self.phi * dhdt

def dn(self, n, t, V):
alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
# alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
alpha = 0.1 / bm.exprel(-0.1 * (V + 34))
beta = 0.125 * bm.exp(-(V + 44) / 80)
dndt = alpha * (1 - n) - beta * n
return self.phi * dndt
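
A quick aside on why this rewrite helps (not part of the commit): textbook HH rate constants such as `0.1 * (V + 40) / (1 - exp(-(V + 40) / 10))` hit a removable 0/0 singularity at V = -40 mV and lose precision to cancellation in the denominator nearby, while the algebraically identical form `1 / exprel(-(V + 40) / 10)` stays finite and accurate. A minimal sketch assuming the `bm.exprel` added by this commit; the helper names are illustrative only.

import brainpy.math as bm

def m_alpha_naive(V):
  # textbook form: 0/0 at V = -40 mV, catastrophic cancellation nearby
  return 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))

def m_alpha_exprel(V):
  # algebraically identical rewrite used in the diff above
  return 1. / bm.exprel(-(V + 40) / 10)

print(m_alpha_naive(-40.))     # nan (0 / 0)
print(m_alpha_exprel(-40.))    # 1.0, the correct limiting value
print(m_alpha_naive(-39.9999), m_alpha_exprel(-39.9999))  # both near 1.0, but the naive form already loses digits in float32
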
8 changes: 4 additions & 4 deletions brainpy/_src/dynold/synapses/abstract_models.py
@@ -114,7 +114,7 @@ def __init__(
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

# register delay
self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
self.pre.register_local_delay("spike", self.name, delay_step)

def reset_state(self, batch_size=None):
self.output.reset_state(batch_size)
@@ -124,7 +124,7 @@ def reset_state(self, batch_size=None):
def update(self, pre_spike=None):
# pre-synaptic spikes
if pre_spike is None:
pre_spike = self.pre.get_delay_data("spike", self.delay_step)
pre_spike = self.pre.get_local_delay("spike", self.name)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
@@ -317,7 +317,7 @@ def __init__(
self.g = self.syn.g

# delay
self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
self.pre.register_local_delay("spike", self.name, delay_step)

def reset_state(self, batch_size=None):
self.syn.reset_state(batch_size)
@@ -328,7 +328,7 @@ def update(self, pre_spike=None):
def update(self, pre_spike=None):
# delays
if pre_spike is None:
pre_spike = self.pre.get_delay_data("spike", self.delay_step)
pre_spike = self.pre.get_local_delay("spike", self.name)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
4 changes: 2 additions & 2 deletions brainpy/_src/dynold/synapses/base.py
@@ -296,7 +296,7 @@ def __init__(
mode=mode)

# delay
self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
self.pre.register_local_delay("spike", self.name, delay_step)

# synaptic dynamics
self.syn = syn
@@ -317,7 +317,7 @@ def __init__(

def update(self, pre_spike=None, stop_spike_gradient: bool = False):
if pre_spike is None:
pre_spike = self.pre.get_delay_data("spike", self.delay_step)
pre_spike = self.pre.get_local_delay("spike", self.name)
if stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
if self.stp is not None:
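
For readers tracking the API change (not part of the commit): the old pattern `self.delay_step = pre.register_delay("spike", delay_step, pre.spike)` followed by `pre.get_delay_data("spike", self.delay_step)` becomes a per-consumer local delay keyed by name, as the hunks above show. A hedged usage sketch built only from the calls visible in this diff; the neuron model, the key "proj", and the delay value are illustrative stand-ins.

import brainpy as bp

pre = bp.dyn.Lif(10)                                  # any pre-synaptic population with a `spike` variable
pre.register_local_delay("spike", "proj", 2)          # register a delayed view of `pre.spike` under the key "proj"
# ... later, typically inside the consumer's update():
delayed_spike = pre.get_local_delay("spike", "proj")  # read the delayed spikes back by the same key
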
2 changes: 1 addition & 1 deletion brainpy/_src/dynsys.py
@@ -336,7 +336,7 @@ def _compatible_reset_state(self, *args, **kwargs):
the_top_layer_reset_state = True
warnings.warn(
'''
From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.tech/docs/tutorial_toolbox/state_saving_and_loading.html for details.
From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_saving_and_loading.html for details.
1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use
"bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)".
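
The warning above spells out the new policy; in short, resetting a whole network now goes through the functional form or `.reset()`. A minimal sketch following that description (the network construction here is illustrative, not from the diff):

import brainpy as bp

net = bp.DynSysGroup(bp.dyn.Lif(10), bp.dyn.Lif(10))

# net.reset_state()        # old habit: since 2.4.6 this resets only the top-level node itself
bp.reset_state(net)        # recommended: resets the states of every node in the network
net.reset()                # equivalent method form
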
6 changes: 1 addition & 5 deletions brainpy/_src/integrators/ode/exponential.py
@@ -105,8 +105,6 @@
.. [2] Hochbruck, M., & Ostermann, A. (2010). Exponential integrators. Acta Numerica, 19, 209-286.
"""

import logging

from functools import wraps
from brainpy import errors
from brainpy._src import math as bm
@@ -360,9 +358,7 @@ def integral(*args, **kwargs):
assert len(args) > 0
dt = kwargs.pop(C.DT, self.dt)
linear, derivative = value_and_grad(*args, **kwargs)
phi = bm.where(linear == 0.,
bm.ones_like(linear),
(bm.exp(dt * linear) - 1) / (dt * linear))
phi = bm.exprel(dt * linear)
return args[0] + dt * phi * derivative

return [(integral, vars, pars), ]
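
For context (not part of the diff): the exponential Euler update computed here is y_{n+1} = y_n + dt * phi * f(y_n) with phi = (exp(dt * a) - 1) / (dt * a) and a = df/dy evaluated at y_n, so `bm.exprel(dt * linear)` is exactly phi and handles the a -> 0 limit (where the scheme reduces to plain Euler) without the old `bm.where` special case. A scalar sketch assuming the commit's `bm.exprel`; the helper name is illustrative.

import jax
import brainpy.math as bm

def exp_euler_step(f, y, dt):
  fy, a = jax.value_and_grad(f)(y)        # f(y) and its local linear coefficient a = df/dy
  return y + dt * bm.exprel(dt * a) * fy  # phi = exprel(dt * a); phi -> 1 as a -> 0

# one step of a membrane-like decay dV/dt = (-V + 10.) / 5. from V = 0; exact for linear ODEs
print(exp_euler_step(lambda v: (-v + 10.) / 5., 0., 0.1))   # ~0.19801
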
3 changes: 1 addition & 2 deletions brainpy/_src/integrators/sde/normal.py
@@ -626,8 +626,7 @@ def integral(*args, **kwargs):
assert len(args) > 0
dt = kwargs.pop('dt', self.dt)
linear, derivative = value_and_grad(*args, **kwargs)
linear = bm.as_jax(linear)
phi = jnp.where(linear == 0., jnp.ones_like(linear), (jnp.exp(dt * linear) - 1) / (dt * linear))
phi = bm.as_jax(bm.exprel(dt * linear))
return args[0] + dt * phi * derivative

return [(integral, vars, pars), ]
13 changes: 7 additions & 6 deletions brainpy/_src/math/ndarray.py
@@ -79,7 +79,7 @@ class Array(object):
"""

__slots__ = ('_value', '_keep_sharding')
__slots__ = ('_value', )

def __init__(self, value, dtype: Any = None):
# array value
@@ -132,7 +132,7 @@ def value(self, value):
if value.dtype != self_value.dtype:
raise MathError(f"The dtype of the original data is {self_value.dtype}, "
f"while we got {value.dtype}.")
self._value = value.value if isinstance(value, Array) else value
self._value = value

def update(self, value):
"""Update the value of this Array.
@@ -1549,11 +1549,12 @@ def value(self):
Returns:
The stored data.
"""
v = self._value
# keep sharding constraints
if self._keep_sharding and hasattr(self._value, 'sharding') and (self._value.sharding is not None):
return jax.lax.with_sharding_constraint(self._value, self._value.sharding)
if self._keep_sharding and hasattr(v, 'sharding') and (v.sharding is not None):
return jax.lax.with_sharding_constraint(v, v.sharding)
# return the value
return self._value
return v

@value.setter
def value(self, value):
Expand All @@ -1574,6 +1575,6 @@ def value(self, value):
if value.dtype != self_value.dtype:
raise MathError(f"The dtype of the original data is {self_value.dtype}, "
f"while we got {value.dtype}.")
self._value = value.value if isinstance(value, Array) else value
self._value = value


39 changes: 38 additions & 1 deletion brainpy/_src/math/others.py
@@ -7,14 +7,16 @@
from jax.tree_util import tree_map

from brainpy import check, tools
from .compat_numpy import fill_diagonal
from .environment import get_dt, get_int
from .ndarray import Array
from .compat_numpy import fill_diagonal
from .interoperability import as_jax

__all__ = [
'shared_args_over_time',
'remove_diag',
'clip_by_norm',
'exprel',
]


@@ -82,3 +84,38 @@ def f(l):
return l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm)

return tree_map(f, t)


def _exprel(x, threshold):
def true_f(x):
x2 = x * x
return 1. + x / 2. + x2 / 6. + x2 * x / 24.0 # + x2 * x2 / 120.

def false_f(x):
return (jnp.exp(x) - 1) / x

# return jax.lax.cond(jnp.abs(x) < threshold, true_f, false_f, x)
return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x)


def exprel(x, threshold: float = None):
"""Relative error exponential, ``(exp(x) - 1)/x``.
When ``x`` is near zero, ``exp(x)`` is near 1, so the numerical calculation of ``exp(x) - 1`` can
suffer from catastrophic loss of precision. ``exprel(x)`` is implemented to avoid the loss of
precision that occurs when ``x`` is near zero.
Args:
x: ndarray. Input array. ``x`` must contain real numbers.
threshold: float.
Returns:
``(exp(x) - 1)/x``, computed element-wise.
"""
x = as_jax(x)
if threshold is None:
if hasattr(x, 'dtype') and x.dtype == jnp.float64:
threshold = 1e-8
else:
threshold = 1e-5
return _exprel(x, threshold)
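
The tests added below print comparisons against `scipy.special.exprel`; the payoff is easiest to see in single precision, where the naive formula cancels away most significant digits. A small illustration (not part of the commit), assuming brainpy's default float32 mode:

import jax.numpy as jnp
import brainpy.math as bm

x = jnp.float32(1e-7)
naive = (jnp.exp(x) - 1) / x   # ~1.19 in float32: `exp(x) - 1` has lost almost all of its digits
good = bm.exprel(1e-7)         # ~1.0 (true value is 1.00000005)
print(naive, good)
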
21 changes: 21 additions & 0 deletions brainpy/_src/math/tests/test_others.py
@@ -0,0 +1,21 @@

import brainpy.math as bm
from scipy.special import exprel

from unittest import TestCase


class Test_exprel(TestCase):
def test1(self):
for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]:
print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}')
# self.assertEqual(exprel(x))

def test2(self):
bm.enable_x64()
for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]:
print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}')
# self.assertEqual(exprel(x))



1 change: 1 addition & 0 deletions brainpy/math/others.py
@@ -4,6 +4,7 @@
shared_args_over_time as shared_args_over_time,
remove_diag as remove_diag,
clip_by_norm as clip_by_norm,
exprel as exprel,
)

from brainpy._src.math.object_transform.naming import (
