Skip to content

Commit

Permalink
fix and add tests for /ops/math.py (#19080)
Browse files Browse the repository at this point in the history
* Add new math operations and tests

* fix FFT2 class and add more tests
  • Loading branch information
Faisal-Alsrheed authored Jan 22, 2024
1 parent dad5342 commit dfadf6a
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 2 deletions.
12 changes: 10 additions & 2 deletions keras/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ def extract_sequences(x, sequence_length, sequence_stride):


class FFT(Operation):
def __init__(self, axis=-1):
super().__init__()
self.axis = axis

def compute_output_spec(self, x):
if not isinstance(x, (tuple, list)) or len(x) != 2:
raise ValueError(
Expand Down Expand Up @@ -449,6 +453,10 @@ def fft(x):


class FFT2(Operation):
def __init__(self):
super().__init__()
self.axes = (-2, -1)

def compute_output_spec(self, x):
if not isinstance(x, (tuple, list)) or len(x) != 2:
raise ValueError(
Expand All @@ -473,8 +481,8 @@ def compute_output_spec(self, x):
)

# The axes along which we are calculating FFT should be fully-defined.
m = real.shape[-1]
n = real.shape[-2]
m = real.shape[self.axes[0]]
n = real.shape[self.axes[1]]
if m is None or n is None:
raise ValueError(
f"Input should have its {self.axes} axes fully-defined. "
Expand Down
237 changes: 237 additions & 0 deletions keras/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,3 +976,240 @@ def calculate_expected_shape(
(input_shape[1] - sequence_length) // sequence_stride
) + 1
return (input_shape[0], num_sequences, sequence_length)


class SegmentSumTest(testing.TestCase):
def test_segment_sum_call(self):
data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
segment_ids = np.array([0, 0, 1], dtype=np.int32)
num_segments = 2
sorted_segments = False
segment_sum_op = kmath.SegmentSum(
num_segments=num_segments, sorted=sorted_segments
)
output = segment_sum_op.call(data, segment_ids)
expected_output = np.array([[5, 7, 9], [7, 8, 9]], dtype=np.float32)
self.assertAllClose(output, expected_output)


class SegmentMaxTest(testing.TestCase):
def test_segment_max_call(self):
data = np.array([[1, 4, 7], [2, 5, 8], [3, 6, 9]], dtype=np.float32)
segment_ids = np.array([0, 0, 1], dtype=np.int32)
num_segments = 2
sorted_segments = False
segment_max_op = kmath.SegmentMax(
num_segments=num_segments, sorted=sorted_segments
)
output = segment_max_op.call(data, segment_ids)
expected_output = np.array([[2, 5, 8], [3, 6, 9]], dtype=np.float32)
self.assertAllClose(output, expected_output)


class TopKTest(testing.TestCase):
def test_top_k_call_values(self):
data = np.array([[1, 3, 2], [4, 6, 5]], dtype=np.float32)
k = 2
sorted_flag = True
top_k_op = kmath.TopK(k=k, sorted=sorted_flag)
values, _ = top_k_op.call(data)
expected_values = np.array([[3, 2], [6, 5]], dtype=np.float32)
self.assertAllClose(values, expected_values)

def test_top_k_call_indices(self):
data = np.array([[1, 3, 2], [4, 6, 5]], dtype=np.float32)
k = 2
sorted_flag = True
top_k_op = kmath.TopK(k=k, sorted=sorted_flag)
_, indices = top_k_op.call(data)
expected_indices = np.array([[1, 2], [1, 2]], dtype=np.int32)
self.assertAllClose(indices, expected_indices)


class InTopKTest(testing.TestCase):
def test_in_top_k_call(self):
targets = np.array([2, 0, 1], dtype=np.int32)
predictions = np.array(
[[0.1, 0.2, 0.7], [1.0, 0.2, 0.3], [0.2, 0.6, 0.2]],
dtype=np.float32,
)
k = 2
in_top_k_op = kmath.InTopK(k=k)
output = in_top_k_op.call(targets, predictions)
expected_output = np.array([True, True, True], dtype=bool)
self.assertAllEqual(output, expected_output)


class LogsumexpTest(testing.TestCase):
def test_logsumexp_call(self):
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
axis = 0
keepdims = True
logsumexp_op = kmath.Logsumexp(axis=axis, keepdims=keepdims)
output = logsumexp_op.call(x)
expected_output = np.log(
np.sum(np.exp(x), axis=axis, keepdims=keepdims)
)
self.assertAllClose(output, expected_output)


class FFTTest(testing.TestCase):
def test_fft_input_not_tuple_or_list(self):
fft_op = kmath.FFT()
with self.assertRaisesRegex(
ValueError, "Input `x` should be a tuple of two tensors"
):
fft_op.compute_output_spec(np.array([1, 2, 3]))

def test_fft_input_parts_different_shapes(self):
fft_op = kmath.FFT()
real = np.array([1, 2, 3])
imag = np.array([1, 2])
with self.assertRaisesRegex(
ValueError,
"Both the real and imaginary parts should have the same shape",
):
fft_op.compute_output_spec((real, imag))

def test_fft_input_not_1d(self):
fft_op = kmath.FFT()
real = np.array(1)
imag = np.array(1)
with self.assertRaisesRegex(ValueError, "Input should have rank >= 1"):
fft_op.compute_output_spec((real, imag))

def test_fft_last_axis_not_fully_defined(self):
fft_op = kmath.FFT()
real = KerasTensor(shape=(None,), dtype="float32")
imag = KerasTensor(shape=(None,), dtype="float32")
with self.assertRaisesRegex(
ValueError, "Input should have its -1th axis fully-defined"
):
fft_op.compute_output_spec((real, imag))

def test_fft_init_default_axis(self):
fft_op = kmath.FFT()
self.assertEqual(fft_op.axis, -1, "Default axis should be -1")


class SolveTest(testing.TestCase):
def test_solve_call(self):
solve_op = kmath.Solve()
a = np.array([[3, 2], [1, 2]], dtype=np.float32)
b = np.array([[9, 8], [5, 4]], dtype=np.float32)
output = solve_op.call(a, b)
expected_output = np.linalg.solve(a, b)
np.testing.assert_allclose(output, expected_output, atol=1e-6)


class FFT2Test(testing.TestCase):
def test_fft2_correct_input(self):
fft2_op = kmath.FFT2()
real_part = np.random.rand(2, 3, 4)
imag_part = np.random.rand(2, 3, 4)
# This should not raise any errors
fft2_op.compute_output_spec((real_part, imag_part))

def test_fft2_incorrect_input_type(self):
fft2_op = kmath.FFT2()
incorrect_input = np.array([1, 2, 3]) # Not a tuple or list
with self.assertRaisesRegex(
ValueError, "should be a tuple of two tensors"
):
fft2_op.compute_output_spec(incorrect_input)

def test_fft2_mismatched_shapes(self):
fft2_op = kmath.FFT2()
real_part = np.random.rand(2, 3, 4)
imag_part = np.random.rand(2, 3) # Mismatched shape
with self.assertRaisesRegex(
ValueError,
"Both the real and imaginary parts should have the same shape",
):
fft2_op.compute_output_spec((real_part, imag_part))

def test_fft2_low_rank(self):
fft2_op = kmath.FFT2()
low_rank_input = np.random.rand(3) # Rank of 1
with self.assertRaisesRegex(ValueError, "Input should have rank >= 2"):
fft2_op.compute_output_spec((low_rank_input, low_rank_input))

def test_fft2_undefined_dimensions(self):
fft2_op = kmath.FFT2()
real_part = KerasTensor(shape=(None, None, 3), dtype="float32")
imag_part = KerasTensor(shape=(None, None, 3), dtype="float32")
with self.assertRaisesRegex(
ValueError, "Input should have its .* axes fully-defined"
):
fft2_op.compute_output_spec((real_part, imag_part))


class RFFTTest(testing.TestCase):
def test_rfft_low_rank_input(self):
rfft_op = kmath.RFFT()
low_rank_input = np.array(5)
with self.assertRaisesRegex(ValueError, "Input should have rank >= 1"):
rfft_op.compute_output_spec(low_rank_input)

def test_rfft_defined_fft_length(self):
fft_length = 10
rfft_op = kmath.RFFT(fft_length=fft_length)
input_tensor = np.random.rand(3, 8)

expected_last_dimension = fft_length // 2 + 1
expected_shape = input_tensor.shape[:-1] + (expected_last_dimension,)

output_tensors = rfft_op.compute_output_spec(input_tensor)
for output_tensor in output_tensors:
self.assertEqual(output_tensor.shape, expected_shape)

def test_rfft_undefined_fft_length_defined_last_dim(self):
rfft_op = kmath.RFFT()
input_tensor = np.random.rand(3, 8)
expected_last_dimension = input_tensor.shape[-1] // 2 + 1
expected_shape = input_tensor.shape[:-1] + (
expected_last_dimension,
)
output_tensors = rfft_op.compute_output_spec(input_tensor)
for output_tensor in output_tensors:
self.assertEqual(output_tensor.shape, expected_shape)

def test_rfft_undefined_fft_length_undefined_last_dim(self):
rfft_op = kmath.RFFT()
input_tensor = KerasTensor(shape=(None, None), dtype="float32")
expected_shape = input_tensor.shape[:-1] + (None,)
output_tensors = rfft_op.compute_output_spec(input_tensor)
for output_tensor in output_tensors:
self.assertEqual(output_tensor.shape, expected_shape)


class ISTFTTest(testing.TestCase):
def test_istft_incorrect_input_type(self):
istft_op = kmath.ISTFT(
sequence_length=5, sequence_stride=2, fft_length=10
)
incorrect_input = np.array([1, 2, 3])
with self.assertRaisesRegex(
ValueError, "should be a tuple of two tensors"
):
istft_op.compute_output_spec(incorrect_input)

def test_istft_mismatched_shapes(self):
istft_op = kmath.ISTFT(
sequence_length=5, sequence_stride=2, fft_length=10
)
real_part = np.random.rand(2, 3, 4)
imag_part = np.random.rand(2, 3)
with self.assertRaisesRegex(
ValueError,
"Both the real and imaginary parts should have the same shape",
):
istft_op.compute_output_spec((real_part, imag_part))

def test_istft_low_rank_input(self):
istft_op = kmath.ISTFT(
sequence_length=5, sequence_stride=2, fft_length=10
)
low_rank_input = np.random.rand(3)
with self.assertRaisesRegex(ValueError, "Input should have rank >= 2"):
istft_op.compute_output_spec((low_rank_input, low_rank_input))

0 comments on commit dfadf6a

Please sign in to comment.