Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe committed Feb 19, 2025
1 parent e0417e0 commit 2537c77
Showing 1 changed file with 54 additions and 1 deletion.
55 changes: 54 additions & 1 deletion tests/linen/linen_linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def test_circular_conv_1d_custom(self):
correct_ans = np.array((5 + 2 * 1 + 2, 3 + 2 * 4 + 5))
correct_ans = np.expand_dims(correct_ans, (0, 2))
np.testing.assert_allclose(y, correct_ans)

def test_circular_conv_local_1d_custom(self):
"""
Test 1d local convolution with circular padding and a stride
Expand Down Expand Up @@ -846,6 +846,59 @@ def test_circular_conv_local_2d_custom(self):
correct_ans = np.expand_dims(correct_ans, (0, 3))
np.testing.assert_allclose(y, correct_ans)

def test_reflect_conv_1d_custom(self):
"""Test 1d convolution with reflection padding and a stride."""
rng = dict(params=random.key(0))
x = np.arange(1, 6)
x = np.expand_dims(x, (0, 2))
kernel = np.array((1, 2, 1))
kernel = np.expand_dims(kernel, (1, 2))

conv_module = nn.Conv(
features=1,
kernel_size=(3,),
strides=(2,),
padding='REFLECT',
kernel_init=lambda *_: kernel,
bias_init=initializers.zeros,
)
y, initial_params = conv_module.init_with_output(rng, x)

self.assertEqual(initial_params['params']['kernel'].shape, (3, 1, 1))
# Compare with manually computed convolution
correct_ans = np.array((2 + 2 * 1 + 2, 2 + 2 * 3 + 4, 4 + 2 * 5 + 4))
correct_ans = np.expand_dims(correct_ans, (0, 2))
np.testing.assert_allclose(y, correct_ans)

def test_reflect_conv_2d_custom(self):
"""Test 2d convolution with reflect padding on a 3x3 example."""
rng = dict(params=random.key(0))
x = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9)))
x = np.expand_dims(x, (0, 3))
kernel = np.array(((0, 1, 0), (1, 2, 1), (0, 1, 0)))
kernel = np.expand_dims(kernel, (2, 3))

conv_module = nn.Conv(
features=1,
kernel_size=(3, 3),
padding='REFLECT',
kernel_init=lambda *_: kernel,
bias_init=initializers.zeros,
)
y, initial_params = conv_module.init_with_output(rng, x)

self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 1, 1))
# Compare with manually computed convolution
correct_ans = np.array(
(
(2 * 1 + 4 + 2 + 4 + 2, 2 * 2 + 5 + 3 + 5 + 1, 2 * 3 + 6 + 2 + 6 + 2),
(2 * 4 + 1 + 5 + 7 + 5, 2 * 5 + 2 + 6 + 8 + 4, 2 * 6 + 3 + 5 + 9 + 5),
(2 * 7 + 4 + 8 + 8 + 4, 2 * 8 + 5 + 9 + 5 + 7, 2 * 9 + 6 + 8 + 6 + 8),
)
)
correct_ans = np.expand_dims(correct_ans, (0, 3))
np.testing.assert_allclose(y, correct_ans)

def test_causal_conv1d(self):
rng = dict(params=random.key(0))
x = jnp.ones((1, 8, 4))
Expand Down

0 comments on commit 2537c77

Please sign in to comment.