diff --git a/flax/linen/linear.py b/flax/linen/linear.py index e5b5c3ba3..7aa0dab6e 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -444,7 +444,8 @@ class _Conv(Module): strides: an integer or a sequence of `n` integers, representing the inter-window strides (default: 1). padding: either the string ``'SAME'``, the string ``'VALID'``, the string - ``'CIRCULAR'`` (periodic boundary conditions), or a sequence of ``n`` ``(low, + ``'CIRCULAR'`` (periodic boundary conditions), the string `'REFLECT'` + (reflection across the padding boundary), or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and assign a single int in a sequence causes the same padding @@ -554,7 +555,7 @@ def maybe_broadcast( kernel_dilation = maybe_broadcast(self.kernel_dilation) padding_lax = canonicalize_padding(self.padding, len(kernel_size)) - if padding_lax == 'CIRCULAR': + if padding_lax in ('CIRCULAR', 'REFLECT'): kernel_size_dilated = [ (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation) ] @@ -564,7 +565,8 @@ def maybe_broadcast( + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)] ) - inputs = jnp.pad(inputs, pads, mode='wrap') + padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax] + inputs = jnp.pad(inputs, pads, mode=padding_mode) padding_lax = 'VALID' elif padding_lax == 'CAUSAL': if len(kernel_size) != 1: diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py index e6cd308bd..12d2aac75 100644 --- a/flax/nnx/nn/linear.py +++ b/flax/nnx/nn/linear.py @@ -584,7 +584,8 @@ class Conv(Module): strides: an integer or a sequence of ``n`` integers, representing the inter-window strides (default: 1). padding: either the string ``'SAME'``, the string ``'VALID'``, the string - ``'CIRCULAR'`` (periodic boundary conditions), or a sequence of ``n`` + ``'CIRCULAR'`` (periodic boundary conditions), the string `'REFLECT'` + (reflection across the padding boundary), or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpeted as applying the same padding in all dims and passign a single int in a sequence causes the same padding @@ -720,7 +721,7 @@ def maybe_broadcast( kernel_dilation = maybe_broadcast(self.kernel_dilation) padding_lax = canonicalize_padding(self.padding, len(kernel_size)) - if padding_lax == 'CIRCULAR': + if padding_lax in ('CIRCULAR', 'REFLECT'): kernel_size_dilated = [ (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation) ] @@ -730,7 +731,8 @@ def maybe_broadcast( + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)] ) - inputs = jnp.pad(inputs, pads, mode='wrap') + padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax] + inputs = jnp.pad(inputs, pads, mode=padding_mode) padding_lax = 'VALID' elif padding_lax == 'CAUSAL': if len(kernel_size) != 1: diff --git a/tests/linen/linen_linear_test.py b/tests/linen/linen_linear_test.py index d88af68e8..bb7cfce0a 100644 --- a/tests/linen/linen_linear_test.py +++ b/tests/linen/linen_linear_test.py @@ -846,6 +846,59 @@ def test_circular_conv_local_2d_custom(self): correct_ans = np.expand_dims(correct_ans, (0, 3)) np.testing.assert_allclose(y, correct_ans) + def test_reflect_conv_1d_custom(self): + """Test 1d convolution with reflection padding and a stride.""" + rng = dict(params=random.key(0)) + x = np.arange(1, 6) + x = np.expand_dims(x, (0, 2)) + kernel = np.array((1, 2, 1)) + kernel = np.expand_dims(kernel, (1, 2)) + + conv_module = nn.Conv( + features=1, + kernel_size=(3,), + strides=(2,), + padding='REFLECT', + kernel_init=lambda *_: kernel, + bias_init=initializers.zeros, + ) + y, initial_params = conv_module.init_with_output(rng, x) + + self.assertEqual(initial_params['params']['kernel'].shape, (3, 1, 1)) + # Compare with manually computed convolution + correct_ans = np.array((2 + 2 * 1 + 2, 2 + 2 * 3 + 4, 4 + 2 * 5 + 4)) + correct_ans = np.expand_dims(correct_ans, (0, 2)) + np.testing.assert_allclose(y, correct_ans) + + def test_reflect_conv_2d_custom(self): + """Test 2d convolution with reflect padding on a 3x3 example.""" + rng = dict(params=random.key(0)) + x = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9))) + x = np.expand_dims(x, (0, 3)) + kernel = np.array(((0, 1, 0), (1, 2, 1), (0, 1, 0))) + kernel = np.expand_dims(kernel, (2, 3)) + + conv_module = nn.Conv( + features=1, + kernel_size=(3, 3), + padding='REFLECT', + kernel_init=lambda *_: kernel, + bias_init=initializers.zeros, + ) + y, initial_params = conv_module.init_with_output(rng, x) + + self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 1, 1)) + # Compare with manually computed convolution + correct_ans = np.array( + ( + (2 * 1 + 4 + 2 + 4 + 2, 2 * 2 + 5 + 3 + 5 + 1, 2 * 3 + 6 + 2 + 6 + 2), + (2 * 4 + 1 + 5 + 7 + 5, 2 * 5 + 2 + 6 + 8 + 4, 2 * 6 + 3 + 5 + 9 + 5), + (2 * 7 + 4 + 8 + 8 + 4, 2 * 8 + 5 + 9 + 5 + 7, 2 * 9 + 6 + 8 + 6 + 8), + ) + ) + correct_ans = np.expand_dims(correct_ans, (0, 3)) + np.testing.assert_allclose(y, correct_ans) + def test_causal_conv1d(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 4)) diff --git a/tests/nnx/nn/conv_test.py b/tests/nnx/nn/conv_test.py index de0207a57..fe362036c 100644 --- a/tests/nnx/nn/conv_test.py +++ b/tests/nnx/nn/conv_test.py @@ -30,7 +30,7 @@ class TestConvLinenConsistency(parameterized.TestCase): @parameterized.product( strides=[None, (2, 3)], - padding=['VALID', 'CIRCULAR', (4, 2)], + padding=['VALID', 'CIRCULAR', 'REFLECT', (4, 2)], input_dilation=[(2, 3)], kernel_dilation=[(2, 3)], feature_group_count=[3],