Add celu activation to each backend nn.py (#20384)
* Add celu activation to each backend nn.py

* Add celu implementation to ops, activations.py, and test cases

* Roll back __init__.py for celu

* Run the api_gen script

* Add convert_to_tensor call to the NumPy implementation

* Add celu to activations __init__.py

* Fix dtype issue in NNOpsDtypeTest
shashaka authored Oct 21, 2024
1 parent 3591299 commit c31fad7
Showing 15 changed files with 142 additions and 0 deletions.
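For context, a minimal usage sketch of the activation this commit introduces (a hedged sketch, assuming a Keras 3 build that includes these changes; the by-name lookup relies on `celu` being registered in `keras.src.activations`, which this diff does):

```python
import numpy as np
import keras

x = np.array([-2.0, -1.0, 0.0, 1.0], dtype="float32")

# Direct functional call; alpha sets the saturation scale for x < 0.
y = keras.activations.celu(x, alpha=1.0)  # ~[-0.8647, -0.6321, 0., 1.]

# The same op is exposed as keras.ops.celu / keras.ops.nn.celu.
y_ops = keras.ops.celu(x)

# Because celu is added to the activation registry, it can also be
# referenced by name inside layers.
layer = keras.layers.Dense(4, activation="celu")
```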
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/activations/__init__.py
@@ -7,6 +7,7 @@
from keras.src.activations import deserialize
from keras.src.activations import get
from keras.src.activations import serialize
from keras.src.activations.activations import celu
from keras.src.activations.activations import elu
from keras.src.activations.activations import exponential
from keras.src.activations.activations import gelu
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -62,6 +62,7 @@
from keras.src.ops.nn import batch_normalization
from keras.src.ops.nn import binary_crossentropy
from keras.src.ops.nn import categorical_crossentropy
from keras.src.ops.nn import celu
from keras.src.ops.nn import conv
from keras.src.ops.nn import conv_transpose
from keras.src.ops.nn import ctc_decode
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/nn/__init__.py
@@ -8,6 +8,7 @@
from keras.src.ops.nn import batch_normalization
from keras.src.ops.nn import binary_crossentropy
from keras.src.ops.nn import categorical_crossentropy
from keras.src.ops.nn import celu
from keras.src.ops.nn import conv
from keras.src.ops.nn import conv_transpose
from keras.src.ops.nn import ctc_decode
1 change: 1 addition & 0 deletions keras/api/activations/__init__.py
@@ -7,6 +7,7 @@
from keras.src.activations import deserialize
from keras.src.activations import get
from keras.src.activations import serialize
from keras.src.activations.activations import celu
from keras.src.activations.activations import elu
from keras.src.activations.activations import exponential
from keras.src.activations.activations import gelu
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -62,6 +62,7 @@
from keras.src.ops.nn import batch_normalization
from keras.src.ops.nn import binary_crossentropy
from keras.src.ops.nn import categorical_crossentropy
from keras.src.ops.nn import celu
from keras.src.ops.nn import conv
from keras.src.ops.nn import conv_transpose
from keras.src.ops.nn import ctc_decode
1 change: 1 addition & 0 deletions keras/api/ops/nn/__init__.py
@@ -8,6 +8,7 @@
from keras.src.ops.nn import batch_normalization
from keras.src.ops.nn import binary_crossentropy
from keras.src.ops.nn import categorical_crossentropy
from keras.src.ops.nn import celu
from keras.src.ops.nn import conv
from keras.src.ops.nn import conv_transpose
from keras.src.ops.nn import ctc_decode
2 changes: 2 additions & 0 deletions keras/src/activations/__init__.py
@@ -1,5 +1,6 @@
import types

from keras.src.activations.activations import celu
from keras.src.activations.activations import elu
from keras.src.activations.activations import exponential
from keras.src.activations.activations import gelu
@@ -27,6 +28,7 @@
    leaky_relu,
    relu6,
    softmax,
    celu,
    elu,
    selu,
    softplus,
21 changes: 21 additions & 0 deletions keras/src/activations/activations.py
@@ -302,6 +302,27 @@ def gelu(x, approximate=False):
    return ops.gelu(x, approximate=approximate)


@keras_export("keras.activations.celu")
def celu(x, alpha=1.0):
    """Continuously Differentiable Exponential Linear Unit.

    The CeLU activation function is defined as:

    `celu(x) = alpha * (exp(x / alpha) - 1) for x < 0`, `celu(x) = x for x >= 0`.

    where `alpha` is a scaling parameter that controls the activation's shape.

    Args:
        x: Input tensor.
        alpha: The α value for the CeLU formulation. Defaults to `1.0`.

    Reference:

    - [Barron, J. T., 2017](https://arxiv.org/abs/1704.07483)
    """
    return ops.celu(x, alpha=alpha)


@keras_export("keras.activations.tanh")
def tanh(x):
    """Hyperbolic tangent activation function.
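A quick numeric check of the docstring's piecewise definition, including a non-default `alpha` (a sketch outside the test suite; printed values are rounded):

```python
import numpy as np
from keras import activations

x = np.array([-1.0, 0.0, 1.0], dtype="float32")

# alpha = 1.0: celu(-1) = exp(-1) - 1 ≈ -0.6321
print(activations.celu(x))

# alpha = 0.5: celu(-1) = 0.5 * (exp(-1 / 0.5) - 1) ≈ -0.4323
print(activations.celu(x, alpha=0.5))
```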
16 changes: 16 additions & 0 deletions keras/src/activations/activations_test.py
@@ -582,6 +582,22 @@ def gelu(x, approximate=False):
        expected = gelu(x, True)
        self.assertAllClose(result, expected, rtol=1e-05)

    def test_celu(self):
        def celu(x, alpha=1.0):
            return np.maximum(x, 0.0) + alpha * np.expm1(
                np.minimum(x, 0.0) / alpha
            )

        x = np.random.random((2, 5))
        result = activations.celu(x[np.newaxis, :])[0]
        expected = celu(x)
        self.assertAllClose(result, expected, rtol=1e-05)

        x = np.random.random((2, 5))
        result = activations.celu(x[np.newaxis, :], alpha=0.5)[0]
        expected = celu(x, alpha=0.5)
        self.assertAllClose(result, expected, rtol=1e-05)

    def test_elu(self):
        x = np.random.random((2, 5))
        result = activations.elu(x[np.newaxis, :])[0]
5 changes: 5 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -85,6 +85,11 @@ def gelu(x, approximate=True):
    return jnn.gelu(x, approximate)


def celu(x, alpha=1.0):
    x = convert_to_tensor(x)
    return jnn.celu(x, alpha=alpha)


def softmax(x, axis=-1):
    x = convert_to_tensor(x)
    return jnn.softmax(x, axis=axis)
8 changes: 8 additions & 0 deletions keras/src/backend/numpy/nn.py
@@ -113,6 +113,14 @@ def gelu(x, approximate=True):
    )


def celu(x, alpha=1.0):
    x = convert_to_tensor(x)
    alpha = np.array(alpha, x.dtype)
    return np.maximum(x, np.array(0.0, dtype=x.dtype)) + alpha * np.expm1(
        np.minimum(x, np.array(0.0, dtype=x.dtype)) / alpha
    )


def softmax(x, axis=None):
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
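The NumPy (and TensorFlow) backends use the branch-free form `max(x, 0) + alpha * expm1(min(x, 0) / alpha)` rather than a conditional; a standalone check that this matches the piecewise definition, in plain NumPy independent of Keras:

```python
import numpy as np


def celu_closed_form(x, alpha=1.0):
    # Branch-free form used by the NumPy/TensorFlow backends above.
    return np.maximum(x, 0.0) + alpha * np.expm1(np.minimum(x, 0.0) / alpha)


def celu_piecewise(x, alpha=1.0):
    # Docstring definition: alpha * (exp(x / alpha) - 1) for x < 0, else x.
    return np.where(x < 0, alpha * (np.exp(x / alpha) - 1.0), x)


x = np.linspace(-5.0, 5.0, 11)
assert np.allclose(celu_closed_form(x, alpha=0.7), celu_piecewise(x, alpha=0.7))
```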
6 changes: 6 additions & 0 deletions keras/src/backend/tensorflow/nn.py
@@ -76,6 +76,12 @@ def gelu(x, approximate=True):
    return tf.nn.gelu(x, approximate=approximate)


def celu(x, alpha=1.0):
    return tf.maximum(x, 0.0) + alpha * tf.math.expm1(
        tf.minimum(x, 0.0) / alpha
    )


def softmax(x, axis=-1):
    logits = x
    if axis is None:
5 changes: 5 additions & 0 deletions keras/src/backend/torch/nn.py
@@ -88,6 +88,11 @@ def gelu(x, approximate=True):
    return tnn.gelu(x)


def celu(x, alpha=1.0):
    x = convert_to_tensor(x)
    return tnn.celu(x, alpha=alpha)


def softmax(x, axis=-1):
    x = convert_to_tensor(x)
    dtype = backend.standardize_dtype(x.dtype)
40 changes: 40 additions & 0 deletions keras/src/ops/nn.py
@@ -498,6 +498,46 @@ def gelu(x, approximate=True):
    return backend.nn.gelu(x, approximate)


class Celu(Operation):
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

    def call(self, x):
        return backend.nn.celu(x, self.alpha)

    def compute_output_spec(self, x):
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.celu", "keras.ops.nn.celu"])
def celu(x, alpha=1.0):
    """Continuously-differentiable exponential linear unit.

    It is defined as:

    `f(x) = alpha * (exp(x / alpha) - 1) for x < 0`, `f(x) = x for x >= 0`.

    Args:
        x: Input tensor.
        alpha: the α value for the CELU formulation. Defaults to `1.0`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_celu = keras.ops.celu(x)
    >>> print(x_celu)
    array([-0.63212056, 0. , 1. ], shape=(3,), dtype=float64)

    """
    if any_symbolic_tensors((x,)):
        return Celu(alpha).symbolic_call(x)
    return backend.nn.celu(x, alpha)


class Softmax(Operation):
    def __init__(self, axis=-1):
        super().__init__()
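A short sketch of the dispatch pattern used by `celu` above (and shared by the other ops in this module): symbolic inputs go through `Celu.symbolic_call`, which only propagates shape and dtype via `compute_output_spec`, while concrete inputs are forwarded to the backend kernel. This assumes the public `keras.ops` / `keras.KerasTensor` API:

```python
import numpy as np
from keras import KerasTensor, ops

# Eager path: a concrete array is handed straight to backend.nn.celu.
print(ops.celu(np.array([-1.0, 0.0, 1.0], dtype="float32")))

# Symbolic path: a KerasTensor carries no values, so celu() returns another
# KerasTensor with the same shape and dtype.
x = KerasTensor((None, 2, 3))
print(ops.celu(x, alpha=0.5).shape)  # (None, 2, 3)
```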
33 changes: 33 additions & 0 deletions keras/src/ops/nn_test.py
@@ -141,6 +141,10 @@ def test_gelu(self):
        x = KerasTensor([None, 2, 3])
        self.assertEqual(knn.gelu(x).shape, (None, 2, 3))

    def test_celu(self):
        x = KerasTensor([None, 2, 3])
        self.assertEqual(knn.celu(x).shape, (None, 2, 3))

    def test_softmax(self):
        x = KerasTensor([None, 2, 3])
        self.assertEqual(knn.softmax(x).shape, (None, 2, 3))
@@ -786,6 +790,10 @@ def test_gelu(self):
        x = KerasTensor([1, 2, 3])
        self.assertEqual(knn.gelu(x).shape, (1, 2, 3))

    def test_celu(self):
        x = KerasTensor([1, 2, 3])
        self.assertEqual(knn.celu(x).shape, (1, 2, 3))

    def test_softmax(self):
        x = KerasTensor([1, 2, 3])
        self.assertEqual(knn.softmax(x).shape, (1, 2, 3))
@@ -1292,6 +1300,13 @@ def test_gelu(self):
            [-0.15880796, 0.0, 0.841192, 1.9545977, 2.9963627],
        )

    def test_celu(self):
        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
        self.assertAllClose(
            knn.celu(x),
            [-0.63212055, 0.0, 1.0, 2.0, 3.0],
        )

    def test_softmax(self):
        x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
        self.assertAllClose(
@@ -2363,6 +2378,24 @@ def test_gelu(self, dtype):
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
    def test_celu(self, dtype):
        import jax.nn as jnn
        import jax.numpy as jnp

        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnn.celu(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knn.celu(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knn.Celu().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
    def test_hard_sigmoid(self, dtype):
        import jax.nn as jnn
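The dtype test above uses `jax.nn.celu` as the reference for dtype promotion. A minimal sketch of the property it guards, assuming a backend that supports reduced-precision floats:

```python
from keras import ops

# celu should preserve reduced-precision float dtypes rather than
# silently upcasting to float32.
x = ops.ones((2, 2), dtype="float16")
print(ops.celu(x).dtype)  # expected to stay float16
```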
