diff --git a/keras/api/_tf_keras/keras/activations/__init__.py b/keras/api/_tf_keras/keras/activations/__init__.py
index 17624b6ba5d..a56def1a208 100644
--- a/keras/api/_tf_keras/keras/activations/__init__.py
+++ b/keras/api/_tf_keras/keras/activations/__init__.py
@@ -7,6 +7,7 @@
 from keras.src.activations import deserialize
 from keras.src.activations import get
 from keras.src.activations import serialize
+from keras.src.activations.activations import celu
 from keras.src.activations.activations import elu
 from keras.src.activations.activations import exponential
 from keras.src.activations.activations import gelu
diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py
index 20cf46889d2..12a8571fd7d 100644
--- a/keras/api/_tf_keras/keras/ops/__init__.py
+++ b/keras/api/_tf_keras/keras/ops/__init__.py
@@ -62,6 +62,7 @@
 from keras.src.ops.nn import batch_normalization
 from keras.src.ops.nn import binary_crossentropy
 from keras.src.ops.nn import categorical_crossentropy
+from keras.src.ops.nn import celu
 from keras.src.ops.nn import conv
 from keras.src.ops.nn import conv_transpose
 from keras.src.ops.nn import ctc_decode
diff --git a/keras/api/_tf_keras/keras/ops/nn/__init__.py b/keras/api/_tf_keras/keras/ops/nn/__init__.py
index adce3312860..49683dc70bd 100644
--- a/keras/api/_tf_keras/keras/ops/nn/__init__.py
+++ b/keras/api/_tf_keras/keras/ops/nn/__init__.py
@@ -8,6 +8,7 @@
 from keras.src.ops.nn import batch_normalization
 from keras.src.ops.nn import binary_crossentropy
 from keras.src.ops.nn import categorical_crossentropy
+from keras.src.ops.nn import celu
 from keras.src.ops.nn import conv
 from keras.src.ops.nn import conv_transpose
 from keras.src.ops.nn import ctc_decode
diff --git a/keras/api/activations/__init__.py b/keras/api/activations/__init__.py
index 17624b6ba5d..a56def1a208 100644
--- a/keras/api/activations/__init__.py
+++ b/keras/api/activations/__init__.py
@@ -7,6 +7,7 @@
 from keras.src.activations import deserialize
 from keras.src.activations import get
 from keras.src.activations import serialize
+from keras.src.activations.activations import celu
 from keras.src.activations.activations import elu
 from keras.src.activations.activations import exponential
 from keras.src.activations.activations import gelu
diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py
index 20cf46889d2..12a8571fd7d 100644
--- a/keras/api/ops/__init__.py
+++ b/keras/api/ops/__init__.py
@@ -62,6 +62,7 @@
 from keras.src.ops.nn import batch_normalization
 from keras.src.ops.nn import binary_crossentropy
 from keras.src.ops.nn import categorical_crossentropy
+from keras.src.ops.nn import celu
 from keras.src.ops.nn import conv
 from keras.src.ops.nn import conv_transpose
 from keras.src.ops.nn import ctc_decode
diff --git a/keras/api/ops/nn/__init__.py b/keras/api/ops/nn/__init__.py
index adce3312860..49683dc70bd 100644
--- a/keras/api/ops/nn/__init__.py
+++ b/keras/api/ops/nn/__init__.py
@@ -8,6 +8,7 @@
 from keras.src.ops.nn import batch_normalization
 from keras.src.ops.nn import binary_crossentropy
 from keras.src.ops.nn import categorical_crossentropy
+from keras.src.ops.nn import celu
 from keras.src.ops.nn import conv
 from keras.src.ops.nn import conv_transpose
 from keras.src.ops.nn import ctc_decode
diff --git a/keras/src/activations/__init__.py b/keras/src/activations/__init__.py
index 13bc6de5dba..57cd085a173 100644
--- a/keras/src/activations/__init__.py
+++ b/keras/src/activations/__init__.py
@@ -1,5 +1,6 @@
 import types
 
+from keras.src.activations.activations import celu
 from keras.src.activations.activations import elu
 from keras.src.activations.activations import exponential
 from keras.src.activations.activations import gelu
@@ -27,6 +28,7 @@
     leaky_relu,
     relu6,
     softmax,
+    celu,
     elu,
     selu,
     softplus,
diff --git a/keras/src/activations/activations.py b/keras/src/activations/activations.py
index 3f875b64b15..8dc56ee43b7 100644
--- a/keras/src/activations/activations.py
+++ b/keras/src/activations/activations.py
@@ -302,6 +302,27 @@ def gelu(x, approximate=False):
     return ops.gelu(x, approximate=approximate)
 
 
+@keras_export("keras.activations.celu")
+def celu(x, alpha=1.0):
+    """Continuously Differentiable Exponential Linear Unit.
+
+    The CeLU activation function is defined as:
+
+    `celu(x) = alpha * (exp(x / alpha) - 1) for x < 0`, `celu(x) = x for x >= 0`,
+
+    where `alpha` is a scaling parameter that controls the activation's shape.
+
+    Args:
+        x: Input tensor.
+        alpha: The α value for the CeLU formulation. Defaults to `1.0`.
+
+    Reference:
+
+    - [Barron, J. T., 2017](https://arxiv.org/abs/1704.07483)
+    """
+    return ops.celu(x, alpha=alpha)
+
+
 @keras_export("keras.activations.tanh")
 def tanh(x):
     """Hyperbolic tangent activation function.
diff --git a/keras/src/activations/activations_test.py b/keras/src/activations/activations_test.py
index c0ae34a1739..045ffab14d8 100644
--- a/keras/src/activations/activations_test.py
+++ b/keras/src/activations/activations_test.py
@@ -582,6 +582,22 @@ def gelu(x, approximate=False):
         expected = gelu(x, True)
         self.assertAllClose(result, expected, rtol=1e-05)
 
+    def test_celu(self):
+        def celu(x, alpha=1.0):
+            return np.maximum(x, 0.0) + alpha * np.expm1(
+                np.minimum(x, 0.0) / alpha
+            )
+
+        x = np.random.random((2, 5))
+        result = activations.celu(x[np.newaxis, :])[0]
+        expected = celu(x)
+        self.assertAllClose(result, expected, rtol=1e-05)
+
+        x = np.random.random((2, 5))
+        result = activations.celu(x[np.newaxis, :], alpha=0.5)[0]
+        expected = celu(x, alpha=0.5)
+        self.assertAllClose(result, expected, rtol=1e-05)
+
     def test_elu(self):
         x = np.random.random((2, 5))
         result = activations.elu(x[np.newaxis, :])[0]
diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py
index cba73918976..9899f1b65d5 100644
--- a/keras/src/backend/jax/nn.py
+++ b/keras/src/backend/jax/nn.py
@@ -85,6 +85,11 @@ def gelu(x, approximate=True):
     return jnn.gelu(x, approximate)
 
 
+def celu(x, alpha=1.0):
+    x = convert_to_tensor(x)
+    return jnn.celu(x, alpha=alpha)
+
+
 def softmax(x, axis=-1):
     x = convert_to_tensor(x)
     return jnn.softmax(x, axis=axis)
diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py
index eea127e554a..6e3f8203957 100644
--- a/keras/src/backend/numpy/nn.py
+++ b/keras/src/backend/numpy/nn.py
@@ -113,6 +113,14 @@ def gelu(x, approximate=True):
     )
 
 
+def celu(x, alpha=1.0):
+    x = convert_to_tensor(x)
+    alpha = np.array(alpha, x.dtype)
+    return np.maximum(x, np.array(0.0, dtype=x.dtype)) + alpha * np.expm1(
+        np.minimum(x, np.array(0.0, dtype=x.dtype)) / alpha
+    )
+
+
 def softmax(x, axis=None):
     exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
     return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py
index 01a1aca26d0..7c16e5e901b 100644
--- a/keras/src/backend/tensorflow/nn.py
+++ b/keras/src/backend/tensorflow/nn.py
@@ -76,6 +76,12 @@ def gelu(x, approximate=True):
     return tf.nn.gelu(x, approximate=approximate)
 
 
+def celu(x, alpha=1.0):
+    return tf.maximum(x, 0.0) + alpha * tf.math.expm1(
+        tf.minimum(x, 0.0) / alpha
+    )
+
+
 def softmax(x, axis=-1):
     logits = x
     if axis is None:
diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py
index e4291f6b84c..7c253988480 100644
--- a/keras/src/backend/torch/nn.py
+++ b/keras/src/backend/torch/nn.py
@@ -88,6 +88,11 @@ def gelu(x, approximate=True):
     return tnn.gelu(x)
 
 
+def celu(x, alpha=1.0):
+    x = convert_to_tensor(x)
+    return tnn.celu(x, alpha=alpha)
+
+
 def softmax(x, axis=-1):
     x = convert_to_tensor(x)
     dtype = backend.standardize_dtype(x.dtype)
diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py
index 2d779582a5b..0531f87f869 100644
--- a/keras/src/ops/nn.py
+++ b/keras/src/ops/nn.py
@@ -498,6 +498,46 @@ def gelu(x, approximate=True):
     return backend.nn.gelu(x, approximate)
 
 
+class Celu(Operation):
+    def __init__(self, alpha=1.0):
+        super().__init__()
+        self.alpha = alpha
+
+    def call(self, x):
+        return backend.nn.celu(x, self.alpha)
+
+    def compute_output_spec(self, x):
+        return KerasTensor(x.shape, dtype=x.dtype)
+
+
+@keras_export(["keras.ops.celu", "keras.ops.nn.celu"])
+def celu(x, alpha=1.0):
+    """Continuously-differentiable exponential linear unit.
+
+    It is defined as:
+
+    `f(x) = alpha * (exp(x / alpha) - 1) for x < 0`, `f(x) = x for x >= 0`.
+
+    Args:
+        x: Input tensor.
+        alpha: The α value for the CELU formulation. Defaults to `1.0`.
+
+    Returns:
+        A tensor with the same shape as `x`.
+
+    Example:
+
+    >>> x = np.array([-1., 0., 1.])
+    >>> x_celu = keras.ops.celu(x)
+    >>> print(x_celu)
+    array([-0.63212056, 0. , 1. ], shape=(3,), dtype=float64)
+
+    """
+    if any_symbolic_tensors((x,)):
+        return Celu(alpha).symbolic_call(x)
+    return backend.nn.celu(x, alpha)
+
+
 class Softmax(Operation):
     def __init__(self, axis=-1):
         super().__init__()
diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py
index 4d75760d894..a14a32b46f0 100644
--- a/keras/src/ops/nn_test.py
+++ b/keras/src/ops/nn_test.py
@@ -141,6 +141,10 @@ def test_gelu(self):
         x = KerasTensor([None, 2, 3])
         self.assertEqual(knn.gelu(x).shape, (None, 2, 3))
 
+    def test_celu(self):
+        x = KerasTensor([None, 2, 3])
+        self.assertEqual(knn.celu(x).shape, (None, 2, 3))
+
     def test_softmax(self):
         x = KerasTensor([None, 2, 3])
         self.assertEqual(knn.softmax(x).shape, (None, 2, 3))
@@ -786,6 +790,10 @@ def test_gelu(self):
         x = KerasTensor([1, 2, 3])
         self.assertEqual(knn.gelu(x).shape, (1, 2, 3))
 
+    def test_celu(self):
+        x = KerasTensor([1, 2, 3])
+        self.assertEqual(knn.celu(x).shape, (1, 2, 3))
+
     def test_softmax(self):
         x = KerasTensor([1, 2, 3])
         self.assertEqual(knn.softmax(x).shape, (1, 2, 3))
@@ -1292,6 +1300,13 @@ def test_gelu(self):
             [-0.15880796, 0.0, 0.841192, 1.9545977, 2.9963627],
         )
 
+    def test_celu(self):
+        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
+        self.assertAllClose(
+            knn.celu(x),
+            [-0.63212055, 0.0, 1.0, 2.0, 3.0],
+        )
+
     def test_softmax(self):
         x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
         self.assertAllClose(
@@ -2363,6 +2378,24 @@ def test_gelu(self, dtype):
             expected_dtype,
         )
 
+    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
+    def test_celu(self, dtype):
+        import jax.nn as jnn
+        import jax.numpy as jnp
+
+        x = knp.ones((), dtype=dtype)
+        x_jax = jnp.ones((), dtype=dtype)
+        expected_dtype = standardize_dtype(jnn.celu(x_jax).dtype)
+
+        self.assertEqual(
+            standardize_dtype(knn.celu(x).dtype),
+            expected_dtype,
+        )
+        self.assertEqual(
+            standardize_dtype(knn.Celu().symbolic_call(x).dtype),
+            expected_dtype,
+        )
+
     @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
     def test_hard_sigmoid(self, dtype):
         import jax.nn as jnn
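
A minimal usage sketch of the API surface this patch adds, assuming the branch is installed on top of Keras 3. The op-level and activation-level calls are the entry points exported above; resolving the activation by the string name "celu" relies on the ALL_OBJECTS registration in keras/src/activations/__init__.py and is an assumption, not something exercised by the tests in this diff.

# Sketch only: assumes this celu patch is applied to a Keras 3 install.
import numpy as np
import keras
from keras import ops

x = np.array([-2.0, -1.0, 0.0, 1.0], dtype="float32")

# Op-level entry points added above: keras.ops.celu / keras.ops.nn.celu.
y = ops.celu(x, alpha=0.5)

# Activation-function entry point: keras.activations.celu (alpha defaults to 1.0).
z = keras.activations.celu(x)

# Reference formula from the docstring:
# celu(x) = x for x >= 0, alpha * (exp(x / alpha) - 1) for x < 0.
ref = np.maximum(x, 0.0) + 0.5 * np.expm1(np.minimum(x, 0.0) / 0.5)
np.testing.assert_allclose(ops.convert_to_numpy(y), ref, rtol=1e-5)

# String lookup (assumed to work via the ALL_OBJECTS entry added above).
layer = keras.layers.Dense(4, activation="celu")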