diff --git a/.gitignore b/.gitignore index d955216fd45..a4a90053b6b 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,5 @@ dist/** examples/**/*.jpg .python-version .coverage -*coverage.xml \ No newline at end of file +*coverage.xml +.ruff_cache \ No newline at end of file diff --git a/keras/api/_tf_keras/keras/activations/__init__.py b/keras/api/_tf_keras/keras/activations/__init__.py index ad5ae3e352b..2bf9c2f9b5e 100644 --- a/keras/api/_tf_keras/keras/activations/__init__.py +++ b/keras/api/_tf_keras/keras/activations/__init__.py @@ -33,6 +33,7 @@ from keras.src.activations.activations import softplus from keras.src.activations.activations import softsign from keras.src.activations.activations import sparse_plus +from keras.src.activations.activations import sparsemax from keras.src.activations.activations import squareplus from keras.src.activations.activations import tanh from keras.src.activations.activations import tanh_shrink diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index af0063b5635..9ff1eb69749 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -100,6 +100,7 @@ from keras.src.ops.nn import softsign from keras.src.ops.nn import sparse_categorical_crossentropy from keras.src.ops.nn import sparse_plus +from keras.src.ops.nn import sparsemax from keras.src.ops.nn import squareplus from keras.src.ops.nn import tanh_shrink from keras.src.ops.numpy import abs diff --git a/keras/api/_tf_keras/keras/ops/nn/__init__.py b/keras/api/_tf_keras/keras/ops/nn/__init__.py index 04a4d7a471a..6077373b28d 100644 --- a/keras/api/_tf_keras/keras/ops/nn/__init__.py +++ b/keras/api/_tf_keras/keras/ops/nn/__init__.py @@ -45,5 +45,6 @@ from keras.src.ops.nn import softsign from keras.src.ops.nn import sparse_categorical_crossentropy from keras.src.ops.nn import sparse_plus +from keras.src.ops.nn import sparsemax from keras.src.ops.nn import squareplus from keras.src.ops.nn import tanh_shrink diff --git a/keras/api/activations/__init__.py b/keras/api/activations/__init__.py index ad5ae3e352b..2bf9c2f9b5e 100644 --- a/keras/api/activations/__init__.py +++ b/keras/api/activations/__init__.py @@ -33,6 +33,7 @@ from keras.src.activations.activations import softplus from keras.src.activations.activations import softsign from keras.src.activations.activations import sparse_plus +from keras.src.activations.activations import sparsemax from keras.src.activations.activations import squareplus from keras.src.activations.activations import tanh from keras.src.activations.activations import tanh_shrink diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index af0063b5635..9ff1eb69749 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -100,6 +100,7 @@ from keras.src.ops.nn import softsign from keras.src.ops.nn import sparse_categorical_crossentropy from keras.src.ops.nn import sparse_plus +from keras.src.ops.nn import sparsemax from keras.src.ops.nn import squareplus from keras.src.ops.nn import tanh_shrink from keras.src.ops.numpy import abs diff --git a/keras/api/ops/nn/__init__.py b/keras/api/ops/nn/__init__.py index 04a4d7a471a..6077373b28d 100644 --- a/keras/api/ops/nn/__init__.py +++ b/keras/api/ops/nn/__init__.py @@ -45,5 +45,6 @@ from keras.src.ops.nn import softsign from keras.src.ops.nn import sparse_categorical_crossentropy from keras.src.ops.nn import sparse_plus +from keras.src.ops.nn import sparsemax from keras.src.ops.nn import squareplus from keras.src.ops.nn import tanh_shrink diff --git a/keras/src/activations/__init__.py b/keras/src/activations/__init__.py index 42c0f2d0565..2924c7e006d 100644 --- a/keras/src/activations/__init__.py +++ b/keras/src/activations/__init__.py @@ -24,6 +24,7 @@ from keras.src.activations.activations import softplus from keras.src.activations.activations import softsign from keras.src.activations.activations import sparse_plus +from keras.src.activations.activations import sparsemax from keras.src.activations.activations import squareplus from keras.src.activations.activations import tanh from keras.src.activations.activations import tanh_shrink @@ -59,6 +60,7 @@ mish, log_softmax, log_sigmoid, + sparsemax, } ALL_OBJECTS_DICT = {fn.__name__: fn for fn in ALL_OBJECTS} diff --git a/keras/src/activations/activations.py b/keras/src/activations/activations.py index 115f6e6959a..e50edd9c96e 100644 --- a/keras/src/activations/activations.py +++ b/keras/src/activations/activations.py @@ -617,3 +617,28 @@ def log_softmax(x, axis=-1): axis: Integer, axis along which the softmax is applied. """ return ops.log_softmax(x, axis=axis) + + +@keras_export(["keras.activations.sparsemax"]) +def sparsemax(x, axis=-1): + """Sparsemax activation function. + + For each batch `i`, and class `j`, + sparsemax activation function is defined as: + + `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).` + + Args: + x: Input tensor. + axis: `int`, axis along which the sparsemax operation is applied. + + Returns: + A tensor, output of sparsemax transformation. Has the same type and + shape as `x`. + + Reference: + + - [Martins et.al., 2016](https://arxiv.org/abs/1602.02068) + """ + x = backend.convert_to_tensor(x) + return ops.sparsemax(x, axis) diff --git a/keras/src/activations/activations_test.py b/keras/src/activations/activations_test.py index 5c9a809f509..2cf543a5212 100644 --- a/keras/src/activations/activations_test.py +++ b/keras/src/activations/activations_test.py @@ -896,6 +896,55 @@ def test_linear(self): x_int32 = np.random.randint(-10, 10, (10, 5)).astype(np.int32) self.assertAllClose(x_int32, activations.linear(x_int32)) + def test_sparsemax(self): + # result check with 1d + x_1d = np.linspace(1, 12, num=12) + expected_result = np.zeros_like(x_1d) + expected_result[-1] = 1.0 + self.assertAllClose(expected_result, activations.sparsemax(x_1d)) + + # result check with 2d + x_2d = np.linspace(1, 12, num=12).reshape(-1, 2) + expected_result = np.zeros_like(x_2d) + expected_result[:, -1] = 1.0 + self.assertAllClose(expected_result, activations.sparsemax(x_2d)) + + # result check with 3d + x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3) + expected_result = np.zeros_like(x_3d) + expected_result[:, :, -1] = 1.0 + self.assertAllClose(expected_result, activations.sparsemax(x_3d)) + + # result check with axis=-2 with 2d input + x_2d = np.linspace(1, 12, num=12).reshape(-1, 2) + expected_result = np.zeros_like(x_2d) + expected_result[-1, :] = 1.0 + self.assertAllClose( + expected_result, activations.sparsemax(x_2d, axis=-2) + ) + + # result check with axis=-2 with 3d input + x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3) + expected_result = np.ones_like(x_3d) + self.assertAllClose( + expected_result, activations.sparsemax(x_3d, axis=-2) + ) + + # result check with axis=-3 with 3d input + x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3) + expected_result = np.zeros_like(x_3d) + expected_result[-1, :, :] = 1.0 + self.assertAllClose( + expected_result, activations.sparsemax(x_3d, axis=-3) + ) + + # result check with axis=-3 with 4d input + x_4d = np.linspace(1, 12, num=12).reshape(-1, 1, 1, 2) + expected_result = np.ones_like(x_4d) + self.assertAllClose( + expected_result, activations.sparsemax(x_4d, axis=-3) + ) + def test_get_method(self): obj = activations.get("relu") self.assertEqual(obj, activations.relu) diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 56a11e74d83..d0bb721dd30 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -142,6 +142,24 @@ def log_softmax(x, axis=-1): return jnn.log_softmax(x, axis=axis) +def sparsemax(logits, axis=-1): + # Sort logits along the specified axis in descending order + logits = convert_to_tensor(logits) + logits_sorted = -1.0 * jnp.sort(logits * -1.0, axis=axis) + logits_cumsum = jnp.cumsum(logits_sorted, axis=axis) # find cumulative sum + r = jnp.arange(1, logits.shape[axis] + 1) # Determine the sparsity + r_shape = [1] * logits.ndim + r_shape[axis] = -1 # Broadcast to match the target axis + r = r.reshape(r_shape) + support = logits_sorted - (logits_cumsum - 1) / r > 0 + # Find the threshold + k = jnp.sum(support, axis=axis, keepdims=True) + logits_cumsum_safe = jnp.where(support, logits_cumsum, 0.0) + tau = (jnp.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k + output = jnp.maximum(logits - tau, 0.0) + return output + + def _convert_to_spatial_operand( x, num_spatial_dims, diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index a2438325df6..d9ecefca7c5 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -191,6 +191,24 @@ def log_softmax(x, axis=None): return x - max_x - logsumexp +def sparsemax(logits, axis=-1): + # Sort logits along the specified axis in descending order + logits = convert_to_tensor(logits) + logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis) + logits_cumsum = np.cumsum(logits_sorted, axis=axis) + r = np.arange(1, logits.shape[axis] + 1) + r_shape = [1] * logits.ndim + r_shape[axis] = -1 # Broadcast to match the target axis + r = r.reshape(r_shape) + support = logits_sorted - (logits_cumsum - 1) / r > 0 + # Find the threshold + k = np.sum(support, axis=axis, keepdims=True) + logits_cumsum_safe = np.where(support, logits_cumsum, 0.0) + tau = (np.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k + output = np.maximum(logits - tau, 0.0) + return output + + def _convert_to_spatial_operand( x, num_spatial_dims, diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py index a043c8bc074..0f47ea99a1d 100644 --- a/keras/src/backend/tensorflow/nn.py +++ b/keras/src/backend/tensorflow/nn.py @@ -151,6 +151,24 @@ def log_softmax(x, axis=-1): return tf.nn.log_softmax(x, axis=axis) +def sparsemax(logits, axis=-1): + # Sort logits along the specified axis in descending order + logits = convert_to_tensor(logits) + logits_sorted = tf.sort(logits, direction="DESCENDING", axis=axis) + logits_cumsum = tf.cumsum(logits_sorted, axis=axis) + r = tf.range(1, tf.shape(logits)[axis] + 1, dtype=logits.dtype) + r_shape = [1] * len(logits.shape) + r_shape[axis] = -1 # Broadcast to match the target axis + r = tf.reshape(r, r_shape) # Reshape for broadcasting + support = logits_sorted - (logits_cumsum - 1) / r > 0 + # Find the threshold + logits_cumsum_safe = tf.where(support, logits_cumsum, 0.0) + k = tf.reduce_sum(tf.cast(support, logits.dtype), axis=axis, keepdims=True) + tau = (tf.reduce_sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k + output = tf.maximum(logits - tau, 0.0) + return output + + def _transpose_spatial_inputs(inputs): num_spatial_dims = len(inputs.shape) - 2 # Tensorflow pooling does not support `channels_first` format, so diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index c8bc15df837..11fb6edba24 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -174,6 +174,28 @@ def log_softmax(x, axis=-1): return cast(output, dtype) +def sparsemax(logits, axis=-1): + # Sort logits along the specified axis in descending order + logits = convert_to_tensor(logits) + logits_sorted, _ = torch.sort(logits, dim=axis, descending=True) + logits_cumsum = torch.cumsum(logits_sorted, dim=axis) + r = torch.arange( + 1, logits.size(axis) + 1, device=logits.device, dtype=logits.dtype + ) + r_shape = [1] * logits.ndim + r_shape[axis] = -1 # Broadcast to match the target axis + r = r.view(r_shape) + support = logits_sorted - (logits_cumsum - 1) / r > 0 + # Find the threshold + k = torch.sum(support, dim=axis, keepdim=True) + logits_cumsum_safe = torch.where( + support, logits_cumsum, torch.tensor(0.0, device=logits.device) + ) + tau = (torch.sum(logits_cumsum_safe, dim=axis, keepdim=True) - 1) / k + output = torch.clamp(logits - tau, min=0.0) + return output + + def _compute_padding_length( input_length, kernel_length, stride, dilation_rate=1 ): diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index 8533c83be69..5112525f6fa 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -951,6 +951,48 @@ def log_softmax(x, axis=-1): return backend.nn.log_softmax(x, axis=axis) +class Sparsemax(Operation): + def __init__(self, axis=-1): + super().__init__() + self.axis = axis + + def call(self, x): + return backend.nn.sparsemax(x, axis=self.axis) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.sparsemax", "keras.ops.nn.sparsemax"]) +def sparsemax(x, axis=-1): + """Sparsemax activation function. + + For each batch `i`, and class `j`, + sparsemax activation function is defined as: + + `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).` + + Args: + x: Input tensor. + axis: `int`, axis along which the sparsemax operation is applied. + + Returns: + A tensor, output of sparsemax transformation. Has the same type and + shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_sparsemax = keras.ops.sparsemax(x) + >>> print(x_sparsemax) + array([0., 0., 1.], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return Sparsemax(axis).symbolic_call(x) + return backend.nn.sparsemax(x, axis=axis) + + class MaxPool(Operation): def __init__( self, diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index c73e555fd02..8bf758cbfff 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -200,6 +200,10 @@ def test_log_softmax(self): self.assertEqual(knn.log_softmax(x, axis=1).shape, (None, 2, 3)) self.assertEqual(knn.log_softmax(x, axis=-1).shape, (None, 2, 3)) + def test_sparsemax(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.sparsemax(x).shape, (None, 2, 3)) + def test_max_pool(self): data_format = backend.config.image_data_format() if data_format == "channels_last": @@ -861,6 +865,10 @@ def test_log_softmax(self): self.assertEqual(knn.log_softmax(x, axis=1).shape, (1, 2, 3)) self.assertEqual(knn.log_softmax(x, axis=-1).shape, (1, 2, 3)) + def test_sparsemax(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.sparsemax(x).shape, (1, 2, 3)) + def test_max_pool(self): data_format = backend.config.image_data_format() if data_format == "channels_last": @@ -1487,6 +1495,13 @@ def test_log_softmax_correctness_with_axis_tuple(self): ) self.assertAllClose(normalized_sum_by_axis, 1.0) + def test_sparsemax(self): + x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.sparsemax(x), + [0.0, 0.0, 0.0, 0.0, 1.0], + ) + def test_max_pool(self): data_format = backend.config.image_data_format() # Test 1D max pooling.