Commit

Update brainpy.math.softplus
chaoming0625 committed Jan 3, 2024
1 parent 46d4c46 commit a7ffdef
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions brainpy/_src/dnn/activations.py
@@ -825,7 +825,7 @@ class Softplus(Layer):
   Args:
     beta: the :math:`\beta` value for the Softplus formulation. Default: 1
-    threshold: values above this revert to a linear function. Default: 40
+    threshold: values above this revert to a linear function. Default: 20

   Shape:
     - Input: :math:`(*)`, where :math:`*` means any number of dimensions.

@@ -840,10 +840,10 @@ class Softplus(Layer):
   >>> output = m(input)
   """
   __constants__ = ['beta', 'threshold']
-  beta: int
-  threshold: int
+  beta: float
+  threshold: float

-  def __init__(self, beta: int = 1, threshold: int = 40) -> None:
+  def __init__(self, beta: float = 1, threshold: float = 20.) -> None:
     super().__init__()
     self.beta = beta
     self.threshold = threshold
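The layer change mirrors the functional change in the second file: the beta and threshold annotations widen from int to float, and the default threshold drops from 40 to 20. A minimal usage sketch, assuming the class is exposed as brainpy.dnn.Softplus and that brainpy.math.random provides randn (both are assumptions about the public API, not shown in this diff):

    # A minimal sketch, not part of the commit; the paths brainpy.dnn.Softplus
    # and brainpy.math.random.randn are assumed.
    import brainpy as bp

    m = bp.dnn.Softplus(beta=1., threshold=20.)  # new float defaults
    x = bp.math.random.randn(2)                  # small test input
    y = m(x)                                     # element-wise softplus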
8 changes: 4 additions & 4 deletions brainpy/_src/math/activations.py
@@ -298,7 +298,7 @@ def leaky_relu(x, negative_slope=1e-2):
   return jnp.where(x >= 0, x, negative_slope * x)


-def softplus(x, beta=1, threshold=40):
+def softplus(x, beta: float = 1., threshold: float = 20.):
   r"""Softplus activation function.

   Computes the element-wise function
@@ -315,12 +315,12 @@ def softplus(x, beta=1, threshold=40):
   Parameters
   ----------
   x: The input array.
-  beta: the :math:`\beta` value for the Softplus formulation. Default: 1
-  threshold: values above this revert to a linear function. Default: 40
+  beta: the :math:`\beta` value for the Softplus formulation. Default: 1.
+  threshold: values above this revert to a linear function. Default: 20.
   """
   x = x.value if isinstance(x, Array) else x
-  return jnp.where(x > threshold, x, 1 / beta * jnp.logaddexp(beta * x, 0))
+  return jnp.where(x * beta > threshold, x, 1 / beta * jnp.logaddexp(beta * x, 0))


def log_sigmoid(x):
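Why the last hunk matters: softplus with a beta parameter computes (1 / beta) * log(1 + exp(beta * x)), so the function enters its linear regime when beta * x is large, not when x alone is large. The old guard x > threshold therefore switched branches at the wrong point whenever beta != 1. A standalone sketch of the corrected logic (plain JAX, mirroring the patched brainpy.math.softplus but not the library code itself):

    # Sketch of the fixed thresholding; assumes only jax.numpy.
    import jax.numpy as jnp

    def softplus(x, beta: float = 1., threshold: float = 20.):
      # Once beta * x exceeds the threshold, log(1 + exp(beta * x)) ~= beta * x,
      # so the result is effectively x; returning x directly avoids overflow.
      return jnp.where(x * beta > threshold, x,
                       1. / beta * jnp.logaddexp(beta * x, 0.))

    # With the old guard (x > threshold), beta=10 and x=5 would take the log
    # branch even though beta * x = 50 is deep in the linear regime.
    print(softplus(jnp.array([-1., 0., 5.]), beta=10.))  # ~[4.5e-06, 0.069, 5.0]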
