diff --git a/keras/ops/math.py b/keras/ops/math.py index 6b6bc3b0c5c..94d83e0c2a3 100644 --- a/keras/ops/math.py +++ b/keras/ops/math.py @@ -381,6 +381,10 @@ def extract_sequences(x, sequence_length, sequence_stride): class FFT(Operation): + def __init__(self, axis=-1): + super().__init__() + self.axis = axis + def compute_output_spec(self, x): if not isinstance(x, (tuple, list)) or len(x) != 2: raise ValueError( @@ -449,6 +453,10 @@ def fft(x): class FFT2(Operation): + def __init__(self): + super().__init__() + self.axes = (-2, -1) + def compute_output_spec(self, x): if not isinstance(x, (tuple, list)) or len(x) != 2: raise ValueError( @@ -473,8 +481,8 @@ def compute_output_spec(self, x): ) # The axes along which we are calculating FFT should be fully-defined. - m = real.shape[-1] - n = real.shape[-2] + m = real.shape[self.axes[0]] + n = real.shape[self.axes[1]] if m is None or n is None: raise ValueError( f"Input should have its {self.axes} axes fully-defined. " diff --git a/keras/ops/math_test.py b/keras/ops/math_test.py index ee43a01a408..abf8074fd18 100644 --- a/keras/ops/math_test.py +++ b/keras/ops/math_test.py @@ -976,3 +976,240 @@ def calculate_expected_shape( (input_shape[1] - sequence_length) // sequence_stride ) + 1 return (input_shape[0], num_sequences, sequence_length) + + +class SegmentSumTest(testing.TestCase): + def test_segment_sum_call(self): + data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32) + segment_ids = np.array([0, 0, 1], dtype=np.int32) + num_segments = 2 + sorted_segments = False + segment_sum_op = kmath.SegmentSum( + num_segments=num_segments, sorted=sorted_segments + ) + output = segment_sum_op.call(data, segment_ids) + expected_output = np.array([[5, 7, 9], [7, 8, 9]], dtype=np.float32) + self.assertAllClose(output, expected_output) + + +class SegmentMaxTest(testing.TestCase): + def test_segment_max_call(self): + data = np.array([[1, 4, 7], [2, 5, 8], [3, 6, 9]], dtype=np.float32) + segment_ids = np.array([0, 0, 1], dtype=np.int32) + num_segments = 2 + sorted_segments = False + segment_max_op = kmath.SegmentMax( + num_segments=num_segments, sorted=sorted_segments + ) + output = segment_max_op.call(data, segment_ids) + expected_output = np.array([[2, 5, 8], [3, 6, 9]], dtype=np.float32) + self.assertAllClose(output, expected_output) + + +class TopKTest(testing.TestCase): + def test_top_k_call_values(self): + data = np.array([[1, 3, 2], [4, 6, 5]], dtype=np.float32) + k = 2 + sorted_flag = True + top_k_op = kmath.TopK(k=k, sorted=sorted_flag) + values, _ = top_k_op.call(data) + expected_values = np.array([[3, 2], [6, 5]], dtype=np.float32) + self.assertAllClose(values, expected_values) + + def test_top_k_call_indices(self): + data = np.array([[1, 3, 2], [4, 6, 5]], dtype=np.float32) + k = 2 + sorted_flag = True + top_k_op = kmath.TopK(k=k, sorted=sorted_flag) + _, indices = top_k_op.call(data) + expected_indices = np.array([[1, 2], [1, 2]], dtype=np.int32) + self.assertAllClose(indices, expected_indices) + + +class InTopKTest(testing.TestCase): + def test_in_top_k_call(self): + targets = np.array([2, 0, 1], dtype=np.int32) + predictions = np.array( + [[0.1, 0.2, 0.7], [1.0, 0.2, 0.3], [0.2, 0.6, 0.2]], + dtype=np.float32, + ) + k = 2 + in_top_k_op = kmath.InTopK(k=k) + output = in_top_k_op.call(targets, predictions) + expected_output = np.array([True, True, True], dtype=bool) + self.assertAllEqual(output, expected_output) + + +class LogsumexpTest(testing.TestCase): + def test_logsumexp_call(self): + x = np.array([[1, 2], [3, 4]], dtype=np.float32) + axis = 0 + keepdims = True + logsumexp_op = kmath.Logsumexp(axis=axis, keepdims=keepdims) + output = logsumexp_op.call(x) + expected_output = np.log( + np.sum(np.exp(x), axis=axis, keepdims=keepdims) + ) + self.assertAllClose(output, expected_output) + + +class FFTTest(testing.TestCase): + def test_fft_input_not_tuple_or_list(self): + fft_op = kmath.FFT() + with self.assertRaisesRegex( + ValueError, "Input `x` should be a tuple of two tensors" + ): + fft_op.compute_output_spec(np.array([1, 2, 3])) + + def test_fft_input_parts_different_shapes(self): + fft_op = kmath.FFT() + real = np.array([1, 2, 3]) + imag = np.array([1, 2]) + with self.assertRaisesRegex( + ValueError, + "Both the real and imaginary parts should have the same shape", + ): + fft_op.compute_output_spec((real, imag)) + + def test_fft_input_not_1d(self): + fft_op = kmath.FFT() + real = np.array(1) + imag = np.array(1) + with self.assertRaisesRegex(ValueError, "Input should have rank >= 1"): + fft_op.compute_output_spec((real, imag)) + + def test_fft_last_axis_not_fully_defined(self): + fft_op = kmath.FFT() + real = KerasTensor(shape=(None,), dtype="float32") + imag = KerasTensor(shape=(None,), dtype="float32") + with self.assertRaisesRegex( + ValueError, "Input should have its -1th axis fully-defined" + ): + fft_op.compute_output_spec((real, imag)) + + def test_fft_init_default_axis(self): + fft_op = kmath.FFT() + self.assertEqual(fft_op.axis, -1, "Default axis should be -1") + + +class SolveTest(testing.TestCase): + def test_solve_call(self): + solve_op = kmath.Solve() + a = np.array([[3, 2], [1, 2]], dtype=np.float32) + b = np.array([[9, 8], [5, 4]], dtype=np.float32) + output = solve_op.call(a, b) + expected_output = np.linalg.solve(a, b) + np.testing.assert_allclose(output, expected_output, atol=1e-6) + + +class FFT2Test(testing.TestCase): + def test_fft2_correct_input(self): + fft2_op = kmath.FFT2() + real_part = np.random.rand(2, 3, 4) + imag_part = np.random.rand(2, 3, 4) + # This should not raise any errors + fft2_op.compute_output_spec((real_part, imag_part)) + + def test_fft2_incorrect_input_type(self): + fft2_op = kmath.FFT2() + incorrect_input = np.array([1, 2, 3]) # Not a tuple or list + with self.assertRaisesRegex( + ValueError, "should be a tuple of two tensors" + ): + fft2_op.compute_output_spec(incorrect_input) + + def test_fft2_mismatched_shapes(self): + fft2_op = kmath.FFT2() + real_part = np.random.rand(2, 3, 4) + imag_part = np.random.rand(2, 3) # Mismatched shape + with self.assertRaisesRegex( + ValueError, + "Both the real and imaginary parts should have the same shape", + ): + fft2_op.compute_output_spec((real_part, imag_part)) + + def test_fft2_low_rank(self): + fft2_op = kmath.FFT2() + low_rank_input = np.random.rand(3) # Rank of 1 + with self.assertRaisesRegex(ValueError, "Input should have rank >= 2"): + fft2_op.compute_output_spec((low_rank_input, low_rank_input)) + + def test_fft2_undefined_dimensions(self): + fft2_op = kmath.FFT2() + real_part = KerasTensor(shape=(None, None, 3), dtype="float32") + imag_part = KerasTensor(shape=(None, None, 3), dtype="float32") + with self.assertRaisesRegex( + ValueError, "Input should have its .* axes fully-defined" + ): + fft2_op.compute_output_spec((real_part, imag_part)) + + +class RFFTTest(testing.TestCase): + def test_rfft_low_rank_input(self): + rfft_op = kmath.RFFT() + low_rank_input = np.array(5) + with self.assertRaisesRegex(ValueError, "Input should have rank >= 1"): + rfft_op.compute_output_spec(low_rank_input) + + def test_rfft_defined_fft_length(self): + fft_length = 10 + rfft_op = kmath.RFFT(fft_length=fft_length) + input_tensor = np.random.rand(3, 8) + + expected_last_dimension = fft_length // 2 + 1 + expected_shape = input_tensor.shape[:-1] + (expected_last_dimension,) + + output_tensors = rfft_op.compute_output_spec(input_tensor) + for output_tensor in output_tensors: + self.assertEqual(output_tensor.shape, expected_shape) + + def test_rfft_undefined_fft_length_defined_last_dim(self): + rfft_op = kmath.RFFT() + input_tensor = np.random.rand(3, 8) + expected_last_dimension = input_tensor.shape[-1] // 2 + 1 + expected_shape = input_tensor.shape[:-1] + ( + expected_last_dimension, + ) + output_tensors = rfft_op.compute_output_spec(input_tensor) + for output_tensor in output_tensors: + self.assertEqual(output_tensor.shape, expected_shape) + + def test_rfft_undefined_fft_length_undefined_last_dim(self): + rfft_op = kmath.RFFT() + input_tensor = KerasTensor(shape=(None, None), dtype="float32") + expected_shape = input_tensor.shape[:-1] + (None,) + output_tensors = rfft_op.compute_output_spec(input_tensor) + for output_tensor in output_tensors: + self.assertEqual(output_tensor.shape, expected_shape) + + +class ISTFTTest(testing.TestCase): + def test_istft_incorrect_input_type(self): + istft_op = kmath.ISTFT( + sequence_length=5, sequence_stride=2, fft_length=10 + ) + incorrect_input = np.array([1, 2, 3]) + with self.assertRaisesRegex( + ValueError, "should be a tuple of two tensors" + ): + istft_op.compute_output_spec(incorrect_input) + + def test_istft_mismatched_shapes(self): + istft_op = kmath.ISTFT( + sequence_length=5, sequence_stride=2, fft_length=10 + ) + real_part = np.random.rand(2, 3, 4) + imag_part = np.random.rand(2, 3) + with self.assertRaisesRegex( + ValueError, + "Both the real and imaginary parts should have the same shape", + ): + istft_op.compute_output_spec((real_part, imag_part)) + + def test_istft_low_rank_input(self): + istft_op = kmath.ISTFT( + sequence_length=5, sequence_stride=2, fft_length=10 + ) + low_rank_input = np.random.rand(3) + with self.assertRaisesRegex(ValueError, "Input should have rank >= 2"): + istft_op.compute_output_spec((low_rank_input, low_rank_input))