From b1e80e802b9fa39e3bd2b2e48fc9d6d1a676efdd Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Sun, 10 Dec 2023 14:45:46 +0800
Subject: [PATCH] [math & dyn] add ``brainpy.math.exprel``, and change the code
 in the corresponding HH neuron models to improve numerical computation
 accuracy (#557)

* merge (#6)

* [running] fix multiprocessing bugs (#547)

  * [running] fix multiprocessing bugs
  * fix tests

* [docs] Fix typo in docs (#549)

* :arrow_up: Bump conda-incubator/setup-miniconda from 2 to 3 (#551)

  Bumps [conda-incubator/setup-miniconda](https://github.com/conda-incubator/setup-miniconda) from 2 to 3.
  - [Release notes](https://github.com/conda-incubator/setup-miniconda/releases)
  - [Changelog](https://github.com/conda-incubator/setup-miniconda/blob/main/CHANGELOG.md)
  - [Commits](https://github.com/conda-incubator/setup-miniconda/compare/v2...v3)

  updated-dependencies:
  - dependency-name: conda-incubator/setup-miniconda
    dependency-type: direct:production
    update-type: version-update:semver-major
  ...

  Signed-off-by: dependabot[bot]
  Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* updates (#550)

  * [running] fix multiprocessing bugs
  * fix tests
  * [doc] update the documentation
  * [math] add `brainpy.math.gpu_memory_preallocation()` for controlling GPU memory preallocation
  * [math] `clear_buffer_memory()` now supports clearing both array buffers and compilation caches
  * [dyn] keep compatibility with the old-style `.reset_state()` function
  * [setup] update the installation information

* ``brainpy.math.defjvp`` and ``brainpy.math.XLACustomOp.defjvp`` (#554)

  * [install] upgrade dependencies
  * [math] add `brainpy.math.defjvp`, which supports defining JVP rules for primitives with
    multiple results; see the examples in `test_ad_support.py`
  * [math] add `brainpy.math.XLACustomOp.defjvp`
  * [doc] improve the `brainpy.math.defjvp` docstring

* :arrow_up: Bump actions/setup-python from 4 to 5 (#555)

  Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5.
  - [Release notes](https://github.com/actions/setup-python/releases)
  - [Commits](https://github.com/actions/setup-python/compare/v4...v5)

  updated-dependencies:
  - dependency-name: actions/setup-python
    dependency-type: direct:production
    update-type: version-update:semver-major
  ...

  Signed-off-by: dependabot[bot]
  Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

---------

Signed-off-by: dependabot[bot]
Co-authored-by: Sichao He <1310722434@qq.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [math] add the experimental `brainpy.math.exprel`

* [delay] move the delays in the `brainpy.synapses` models onto the new delay-registration API
  `DynamicalSystem.register_local_delay()` and `DynamicalSystem.get_local_delay()`

* [math & dyn] add `brainpy.math.exprel`, and change the code in the corresponding HH neuron
  models to improve numerical computation accuracy

---------

Signed-off-by: dependabot[bot]
Co-authored-by: Sichao He <1310722434@qq.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 brainpy/_src/dyn/neurons/hh.py                  | 24 +++++++-----
 .../_src/dynold/synapses/abstract_models.py     |  8 ++--
 brainpy/_src/dynold/synapses/base.py            |  4 +-
 brainpy/_src/dynsys.py                          |  2 +-
 brainpy/_src/integrators/ode/exponential.py     |  6 +--
 brainpy/_src/integrators/sde/normal.py          |  3 +-
 brainpy/_src/math/ndarray.py                    | 13 ++++---
 brainpy/_src/math/others.py                     | 39 ++++++++++++++++++-
 brainpy/_src/math/tests/test_others.py          | 21 ++++++++++
 brainpy/math/others.py                          |  1 +
 10 files changed, 90 insertions(+), 31 deletions(-)
 create mode 100644 brainpy/_src/math/tests/test_others.py
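Note (illustrative only, not part of the applied diff): the HH changes below replace rate
expressions of the form a*(V - V0) / (1 - exp(-(V - V0)/b)), which evaluate to 0/0 at V = V0
and lose precision nearby, with an equivalent 1/exprel form. A minimal sketch for the m-channel
opening rate; the helper names are invented for the example, and only ``brainpy.math`` as
extended by this patch is assumed:

    import brainpy.math as bm

    def m_alpha_old(V):
      # previous form: 0/0 at V = -40 mV, and inaccurate just around it
      return 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))

    def m_alpha_new(V):
      # exprel(z) = (exp(z) - 1) / z stays finite and accurate as z -> 0
      return 1. / bm.exprel(-(V + 40) / 10)

    print(m_alpha_old(-40.))  # nan, from the 0/0
    print(m_alpha_new(-40.))  # 1.0, the correct limiting value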
diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py
index 7a985cb9d..97e612097 100644
--- a/brainpy/_src/dyn/neurons/hh.py
+++ b/brainpy/_src/dyn/neurons/hh.py
@@ -348,7 +348,8 @@ def __init__(
     self.reset_state(self.mode)

   # m channel
-  m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
+  # m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
+  m_alpha = lambda self, V: 1. / bm.exprel(-(V + 40) / 10)
   m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
   m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
   dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m
@@ -360,7 +361,8 @@ def __init__(
   dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h

   # n channel
-  n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
+  # n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
+  n_alpha = lambda self, V: 0.1 / bm.exprel(-(V + 55) / 10)
   n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
   n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
   dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n
@@ -383,8 +385,9 @@ def reset_state(self, batch_size=None, **kwargs):

   def dV(self, V, t, m, h, n, I):
     I = self.sum_inputs(V, init=I)
-    I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
-    I_K = (self.gK * n ** 4.0) * (V - self.EK)
+    I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
+    n2 = n * n
+    I_K = (self.gK * n2 * n2) * (V - self.EK)
     I_leak = self.gL * (V - self.EL)
     dVdt = (- I_Na - I_K - I_leak + I) / self.C
     return dVdt
@@ -516,8 +519,9 @@ class HH(HHLTC):
   """

   def dV(self, V, t, m, h, n, I):
-    I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
-    I_K = (self.gK * n ** 4.0) * (V - self.EK)
+    I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
+    n2 = n * n
+    I_K = (self.gK * n2 * n2) * (V - self.EK)
     I_leak = self.gL * (V - self.EL)
     dVdt = (- I_Na - I_K - I_leak + I) / self.C
     return dVdt
@@ -680,9 +684,7 @@ def update(self, x=None):
     t = share.load('t')
     dt = share.load('dt')
     x = 0. if x is None else x
-
     V, W = self.integral(self.V, self.W, t, x, dt)
-
     spike = bm.logical_and(self.V < self.V_th, V >= self.V_th)
     self.V.value = V
     self.W.value = W
@@ -930,7 +932,8 @@ def reset_state(self, batch_size=None):
     self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size)

   def m_inf(self, V):
-    alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
+    # alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
+    alpha = 1. / bm.exprel(-0.1 * (V + 35))
     beta = 4. * bm.exp(-(V + 60.) / 18.)
     return alpha / (alpha + beta)
@@ -941,7 +944,8 @@ def dh(self, h, t, V):
     return self.phi * dhdt

   def dn(self, n, t, V):
-    alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
+    # alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
+    alpha = 1. / bm.exprel(-0.1 * (V + 34))
     beta = 0.125 * bm.exp(-(V + 44) / 80)
     dndt = alpha * (1 - n) - beta * n
     return self.phi * dndt
diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py
index 62b55a0e7..904cdd889 100644
--- a/brainpy/_src/dynold/synapses/abstract_models.py
+++ b/brainpy/_src/dynold/synapses/abstract_models.py
@@ -114,7 +114,7 @@ def __init__(
     self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

     # register delay
-    self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
+    self.pre.register_local_delay("spike", self.name, delay_step)

   def reset_state(self, batch_size=None):
     self.output.reset_state(batch_size)
@@ -124,7 +124,7 @@ def reset_state(self, batch_size=None):
   def update(self, pre_spike=None):
     # pre-synaptic spikes
     if pre_spike is None:
-      pre_spike = self.pre.get_delay_data("spike", self.delay_step)
+      pre_spike = self.pre.get_local_delay("spike", self.name)
       pre_spike = bm.as_jax(pre_spike)
     if self.stop_spike_gradient:
       pre_spike = jax.lax.stop_gradient(pre_spike)
@@ -317,7 +317,7 @@ def __init__(
     self.g = self.syn.g

     # delay
-    self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
+    self.pre.register_local_delay("spike", self.name, delay_step)

   def reset_state(self, batch_size=None):
     self.syn.reset_state(batch_size)
@@ -328,7 +328,7 @@ def reset_state(self, batch_size=None):
   def update(self, pre_spike=None):
     # delays
     if pre_spike is None:
-      pre_spike = self.pre.get_delay_data("spike", self.delay_step)
+      pre_spike = self.pre.get_local_delay("spike", self.name)
       pre_spike = bm.as_jax(pre_spike)
     if self.stop_spike_gradient:
       pre_spike = jax.lax.stop_gradient(pre_spike)
diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py
index 02a0355aa..a2bc1bdd5 100644
--- a/brainpy/_src/dynold/synapses/base.py
+++ b/brainpy/_src/dynold/synapses/base.py
@@ -296,7 +296,7 @@ def __init__(
                      mode=mode)

     # delay
-    self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
+    self.pre.register_local_delay("spike", self.name, delay_step)

     # synaptic dynamics
     self.syn = syn
@@ -317,7 +317,7 @@ def __init__(

   def update(self, pre_spike=None, stop_spike_gradient: bool = False):
     if pre_spike is None:
-      pre_spike = self.pre.get_delay_data("spike", self.delay_step)
+      pre_spike = self.pre.get_local_delay("spike", self.name)
     if stop_spike_gradient:
       pre_spike = jax.lax.stop_gradient(pre_spike)
     if self.stp is not None:
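Note (illustrative only): the three synapse files above migrate from the removed pair
``register_delay``/``get_delay_data``, where the caller had to keep the returned ``delay_step``
handle, to the name-keyed ``DynamicalSystem.register_local_delay()`` and ``get_local_delay()``.
A sketch of the resulting pattern for downstream code, with ``pre``, ``name`` and ``delay_step``
standing in for the synapse's own attributes:

    # old API (removed above): the caller stores the handle returned by register_delay
    #   self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
    #   pre_spike = self.pre.get_delay_data("spike", self.delay_step)

    def register_spike_delay(pre, name, delay_step):
      # new API: register once (e.g. in the synapse's __init__), keyed by name
      pre.register_local_delay("spike", name, delay_step)

    def read_delayed_spikes(pre, name):
      # new API: read back inside update(); no stored delay_step is needed
      return pre.get_local_delay("spike", name)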
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 10d2de792..ee1fb2b8f 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -336,7 +336,7 @@ def _compatible_reset_state(self, *args, **kwargs):
     the_top_layer_reset_state = True
     warnings.warn(
       '''
-  From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.tech/docs/tutorial_toolbox/state_saving_and_loading.html for details.
+  From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_saving_and_loading.html for details.

   1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use
      "bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)".
diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py
index 2e577e6ab..e44e324e7 100644
--- a/brainpy/_src/integrators/ode/exponential.py
+++ b/brainpy/_src/integrators/ode/exponential.py
@@ -105,8 +105,6 @@
 .. [2] Hochbruck, M., & Ostermann, A. (2010). Exponential integrators. Acta Numerica, 19, 209-286.
 """

-import logging
-
 from functools import wraps
 from brainpy import errors
 from brainpy._src import math as bm
@@ -360,9 +358,7 @@ def integral(*args, **kwargs):
       assert len(args) > 0
       dt = kwargs.pop(C.DT, self.dt)
       linear, derivative = value_and_grad(*args, **kwargs)
-      phi = bm.where(linear == 0.,
-                     bm.ones_like(linear),
-                     (bm.exp(dt * linear) - 1) / (dt * linear))
+      phi = bm.exprel(dt * linear)
       return args[0] + dt * phi * derivative

     return [(integral, vars, pars), ]
diff --git a/brainpy/_src/integrators/sde/normal.py b/brainpy/_src/integrators/sde/normal.py
index b7de12515..34dbafff1 100644
--- a/brainpy/_src/integrators/sde/normal.py
+++ b/brainpy/_src/integrators/sde/normal.py
@@ -626,8 +626,7 @@ def integral(*args, **kwargs):
       assert len(args) > 0
       dt = kwargs.pop('dt', self.dt)
       linear, derivative = value_and_grad(*args, **kwargs)
-      linear = bm.as_jax(linear)
-      phi = jnp.where(linear == 0., jnp.ones_like(linear), (jnp.exp(dt * linear) - 1) / (dt * linear))
+      phi = bm.as_jax(bm.exprel(dt * linear))
       return args[0] + dt * phi * derivative

     return [(integral, vars, pars), ]
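Note (illustrative only): both integrator hunks above compute the phi_1 factor of the exponential
Euler step, x_new = x + dt * phi(dt * a) * f(x) with phi(z) = (exp(z) - 1) / z; ``exprel`` supplies
the z -> 0 limit (1.0) without the explicit ``where`` branch. A quick numerical check of the
equivalence, assuming only ``jax.numpy`` and the new ``brainpy.math.exprel``:

    import jax.numpy as jnp
    import brainpy.math as bm

    dt = 0.1
    for linear in (-2.0, 0.0, 3.0):
      # previous formulation, branching explicitly at linear == 0
      old = jnp.where(linear == 0., 1., (jnp.exp(dt * linear) - 1) / (dt * linear))
      # formulation introduced by this patch
      new = bm.exprel(dt * linear)
      print(float(old), float(new))  # the two values agree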
""" + v = self._value # keep sharding constraints - if self._keep_sharding and hasattr(self._value, 'sharding') and (self._value.sharding is not None): - return jax.lax.with_sharding_constraint(self._value, self._value.sharding) + if self._keep_sharding and hasattr(v, 'sharding') and (v.sharding is not None): + return jax.lax.with_sharding_constraint(v, v.sharding) # return the value - return self._value + return v @value.setter def value(self, value): @@ -1574,6 +1575,6 @@ def value(self, value): if value.dtype != self_value.dtype: raise MathError(f"The dtype of the original data is {self_value.dtype}, " f"while we got {value.dtype}.") - self._value = value.value if isinstance(value, Array) else value + self._value = value diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py index 31e97df88..f3cf4f516 100644 --- a/brainpy/_src/math/others.py +++ b/brainpy/_src/math/others.py @@ -7,14 +7,16 @@ from jax.tree_util import tree_map from brainpy import check, tools +from .compat_numpy import fill_diagonal from .environment import get_dt, get_int from .ndarray import Array -from .compat_numpy import fill_diagonal +from .interoperability import as_jax __all__ = [ 'shared_args_over_time', 'remove_diag', 'clip_by_norm', + 'exprel', ] @@ -82,3 +84,38 @@ def f(l): return l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm) return tree_map(f, t) + + +def _exprel(x, threshold): + def true_f(x): + x2 = x * x + return 1. + x / 2. + x2 / 6. + x2 * x / 24.0 # + x2 * x2 / 120. + + def false_f(x): + return (jnp.exp(x) - 1) / x + + # return jax.lax.cond(jnp.abs(x) < threshold, true_f, false_f, x) + return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x) + + +def exprel(x, threshold: float = None): + """Relative error exponential, ``(exp(x) - 1)/x``. + + When ``x`` is near zero, ``exp(x)`` is near 1, so the numerical calculation of ``exp(x) - 1`` can + suffer from catastrophic loss of precision. ``exprel(x)`` is implemented to avoid the loss of + precision that occurs when ``x`` is near zero. + + Args: + x: ndarray. Input array. ``x`` must contain real numbers. + threshold: float. + + Returns: + ``(exp(x) - 1)/x``, computed element-wise. + """ + x = as_jax(x) + if threshold is None: + if hasattr(x, 'dtype') and x.dtype == jnp.float64: + threshold = 1e-8 + else: + threshold = 1e-5 + return _exprel(x, threshold) diff --git a/brainpy/_src/math/tests/test_others.py b/brainpy/_src/math/tests/test_others.py new file mode 100644 index 000000000..084b8664d --- /dev/null +++ b/brainpy/_src/math/tests/test_others.py @@ -0,0 +1,21 @@ + +import brainpy.math as bm +from scipy.special import exprel + +from unittest import TestCase + + +class Test_exprel(TestCase): + def test1(self): + for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]: + print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}') + # self.assertEqual(exprel(x)) + + def test2(self): + bm.enable_x64() + for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]: + print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}') + # self.assertEqual(exprel(x)) + + + diff --git a/brainpy/math/others.py b/brainpy/math/others.py index 23d9b0816..9b9d7b368 100644 --- a/brainpy/math/others.py +++ b/brainpy/math/others.py @@ -4,6 +4,7 @@ shared_args_over_time as shared_args_over_time, remove_diag as remove_diag, clip_by_norm as clip_by_norm, + exprel as exprel, ) from brainpy._src.math.object_transform.naming import (