Skip to content

Commit

Permalink
Add REFLECT padding to convolution layer
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe committed Feb 23, 2025
1 parent 8d3c22b commit d552d4b
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 7 deletions.
8 changes: 5 additions & 3 deletions flax/linen/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,8 @@ class _Conv(Module):
strides: an integer or a sequence of `n` integers, representing the
inter-window strides (default: 1).
padding: either the string ``'SAME'``, the string ``'VALID'``, the string
``'CIRCULAR'`` (periodic boundary conditions), or a sequence of ``n`` ``(low,
``'CIRCULAR'`` (periodic boundary conditions), the string `'REFLECT'`
(reflection across the padding boundary), or a sequence of ``n`` ``(low,
high)`` integer pairs that give the padding to apply before and after each
spatial dimension. A single int is interpreted as applying the same padding
in all dims and assign a single int in a sequence causes the same padding
Expand Down Expand Up @@ -554,7 +555,7 @@ def maybe_broadcast(
kernel_dilation = maybe_broadcast(self.kernel_dilation)

padding_lax = canonicalize_padding(self.padding, len(kernel_size))
if padding_lax == 'CIRCULAR':
if padding_lax in ('CIRCULAR', 'REFLECT'):
kernel_size_dilated = [
(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)
]
Expand All @@ -564,7 +565,8 @@ def maybe_broadcast(
+ [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
+ [(0, 0)]
)
inputs = jnp.pad(inputs, pads, mode='wrap')
padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax]
inputs = jnp.pad(inputs, pads, mode=padding_mode)
padding_lax = 'VALID'
elif padding_lax == 'CAUSAL':
if len(kernel_size) != 1:
Expand Down
8 changes: 5 additions & 3 deletions flax/nnx/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,8 @@ class Conv(Module):
strides: an integer or a sequence of ``n`` integers, representing the
inter-window strides (default: 1).
padding: either the string ``'SAME'``, the string ``'VALID'``, the string
``'CIRCULAR'`` (periodic boundary conditions), or a sequence of ``n``
``'CIRCULAR'`` (periodic boundary conditions), the string `'REFLECT'`
(reflection across the padding boundary), or a sequence of ``n``
``(low, high)`` integer pairs that give the padding to apply before and after each
spatial dimension. A single int is interpeted as applying the same padding
in all dims and passign a single int in a sequence causes the same padding
Expand Down Expand Up @@ -720,7 +721,7 @@ def maybe_broadcast(
kernel_dilation = maybe_broadcast(self.kernel_dilation)

padding_lax = canonicalize_padding(self.padding, len(kernel_size))
if padding_lax == 'CIRCULAR':
if padding_lax in ('CIRCULAR', 'REFLECT'):
kernel_size_dilated = [
(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)
]
Expand All @@ -730,7 +731,8 @@ def maybe_broadcast(
+ [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
+ [(0, 0)]
)
inputs = jnp.pad(inputs, pads, mode='wrap')
padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax]
inputs = jnp.pad(inputs, pads, mode=padding_mode)
padding_lax = 'VALID'
elif padding_lax == 'CAUSAL':
if len(kernel_size) != 1:
Expand Down
53 changes: 53 additions & 0 deletions tests/linen/linen_linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,59 @@ def test_circular_conv_local_2d_custom(self):
correct_ans = np.expand_dims(correct_ans, (0, 3))
np.testing.assert_allclose(y, correct_ans)

def test_reflect_conv_1d_custom(self):
"""Test 1d convolution with reflection padding and a stride."""
rng = dict(params=random.key(0))
x = np.arange(1, 6)
x = np.expand_dims(x, (0, 2))
kernel = np.array((1, 2, 1))
kernel = np.expand_dims(kernel, (1, 2))

conv_module = nn.Conv(
features=1,
kernel_size=(3,),
strides=(2,),
padding='REFLECT',
kernel_init=lambda *_: kernel,
bias_init=initializers.zeros,
)
y, initial_params = conv_module.init_with_output(rng, x)

self.assertEqual(initial_params['params']['kernel'].shape, (3, 1, 1))
# Compare with manually computed convolution
correct_ans = np.array((2 + 2 * 1 + 2, 2 + 2 * 3 + 4, 4 + 2 * 5 + 4))
correct_ans = np.expand_dims(correct_ans, (0, 2))
np.testing.assert_allclose(y, correct_ans)

def test_reflect_conv_2d_custom(self):
"""Test 2d convolution with reflect padding on a 3x3 example."""
rng = dict(params=random.key(0))
x = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9)))
x = np.expand_dims(x, (0, 3))
kernel = np.array(((0, 1, 0), (1, 2, 1), (0, 1, 0)))
kernel = np.expand_dims(kernel, (2, 3))

conv_module = nn.Conv(
features=1,
kernel_size=(3, 3),
padding='REFLECT',
kernel_init=lambda *_: kernel,
bias_init=initializers.zeros,
)
y, initial_params = conv_module.init_with_output(rng, x)

self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 1, 1))
# Compare with manually computed convolution
correct_ans = np.array(
(
(2 * 1 + 4 + 2 + 4 + 2, 2 * 2 + 5 + 3 + 5 + 1, 2 * 3 + 6 + 2 + 6 + 2),
(2 * 4 + 1 + 5 + 7 + 5, 2 * 5 + 2 + 6 + 8 + 4, 2 * 6 + 3 + 5 + 9 + 5),
(2 * 7 + 4 + 8 + 8 + 4, 2 * 8 + 5 + 9 + 5 + 7, 2 * 9 + 6 + 8 + 6 + 8),
)
)
correct_ans = np.expand_dims(correct_ans, (0, 3))
np.testing.assert_allclose(y, correct_ans)

def test_causal_conv1d(self):
rng = dict(params=random.key(0))
x = jnp.ones((1, 8, 4))
Expand Down
2 changes: 1 addition & 1 deletion tests/nnx/nn/conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
class TestConvLinenConsistency(parameterized.TestCase):
@parameterized.product(
strides=[None, (2, 3)],
padding=['VALID', 'CIRCULAR', (4, 2)],
padding=['VALID', 'CIRCULAR', 'REFLECT', (4, 2)],
input_dilation=[(2, 3)],
kernel_dilation=[(2, 3)],
feature_group_count=[3],
Expand Down

0 comments on commit d552d4b

Please sign in to comment.