Skip to content

Commit

Permalink
Add pad_to_bounding_box image ops (#18503)
Browse files Browse the repository at this point in the history
* add: `pad_to_bounding_box`

* update: format

* update: use f string

* update: rename to `pad_image`

* update: rename arguments

`offset_height` -> `top_padding`
`offset_width` -> `left_padding`

* add: `bottom_padding` and `right_padding`

* update: remove `check_dims`

* update: compute output shape for unbatched and batched inputs

* add: dynamic shape test

* add: static shape test

* add: height & width both dynamic test

* fix: height & width `None` test

* update: remove residual `check_dims`

* add: correctness test

* fix: no `expand` attribute

AttributeError: module 'absl.testing.parameterized' has no attribute 'expand'

* fix: `.numpy()` bug

AttributeError: 'numpy.ndarray' object has no attribute 'numpy'

* update: `pad_image` -> `pad_images`

* update: use tuple for shapes

* update: remove `-` from 3D & 4D

* update: `PadImage` -> `PadImages`

* update: `image` -> `images`

* update: systematic argument checking and inference

* update: make cond more readable

* update: if paddings `None` cond

* update: print the values passed

* update: show received shape

* update: use `ops.shape` instead of `.shape`

* update: show invalid padding values
  • Loading branch information
awsaf49 authored Oct 24, 2023
1 parent ee8b1ea commit dd3acdb
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 0 deletions.
224 changes: 224 additions & 0 deletions keras/ops/image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from keras import backend
from keras import ops
from keras.api_export import keras_export
from keras.backend import KerasTensor
from keras.backend import any_symbolic_tensors
Expand Down Expand Up @@ -501,3 +502,226 @@ def map_coordinates(
fill_mode,
fill_value,
)


class PadImages(Operation):
    """Operation that zero-pads images along their height and width.

    The padding configuration is stored on the instance and forwarded to
    `_pad_images` when the op is called. `compute_output_spec` derives the
    output shape for symbolic inputs.

    Args:
        top_padding: Number of rows added above the images.
        bottom_padding: Number of rows added below the images.
        left_padding: Number of columns added to the left of the images.
        right_padding: Number of columns added to the right of the images.
        target_height: Height of the output images.
        target_width: Width of the output images.
    """

    def __init__(
        self,
        top_padding,
        bottom_padding,
        left_padding,
        right_padding,
        target_height,
        target_width,
    ):
        super().__init__()
        self.top_padding = top_padding
        self.bottom_padding = bottom_padding
        self.left_padding = left_padding
        self.right_padding = right_padding
        self.target_height = target_height
        self.target_width = target_width

    def call(self, images):
        """Pad `images` using the configuration stored on this op."""
        return _pad_images(
            images,
            self.top_padding,
            self.bottom_padding,
            self.left_padding,
            self.right_padding,
            self.target_height,
            self.target_width,
        )

    def compute_output_spec(self, images):
        """Return a `KerasTensor` describing the padded output.

        Args:
            images: Rank-3 `(height, width, channels)` or rank-4
                `(batch, height, width, channels)` symbolic tensor.

        Returns:
            A `KerasTensor` with the dtype of `images` and the padded
            shape. A spatial dimension that has to be inferred from an
            unknown (`None`) input dimension stays `None`.
        """
        images_shape = tuple(ops.shape(images))
        is_batched = len(images_shape) != 3
        # Height and width are the third- and second-to-last axes for
        # both layouts, so negative indices avoid per-rank axis lookups
        # (the rank-3 width used to be read from axis 0, i.e. the height).
        height, width = images_shape[-3], images_shape[-2]

        # Resolve into locals: writing the inferred sizes back to `self`
        # would leak one input's shape into later calls of the same op.
        target_height = self.target_height
        if target_height is None and height is not None:
            target_height = self.top_padding + height + self.bottom_padding
        target_width = self.target_width
        if target_width is None and width is not None:
            target_width = self.left_padding + width + self.right_padding

        out_shape = (target_height, target_width, images_shape[-1])
        if is_batched:
            out_shape = (images_shape[0],) + out_shape
        return KerasTensor(
            shape=out_shape,
            dtype=images.dtype,
        )


def _pad_images(
    images,
    top_padding,
    bottom_padding,
    left_padding,
    right_padding,
    target_height,
    target_width,
):
    """Pad rank-3 or rank-4 `images` along height and width.

    For each spatial axis exactly two of the three related values
    (leading padding, trailing padding, target size) must be given; the
    missing padding, if any, is inferred from the input size.

    Args:
        images: Tensor-like of shape `(height, width, channels)` or
            `(batch, height, width, channels)`.
        top_padding: Rows added above the images, or `None` to infer.
        bottom_padding: Rows added below the images, or `None` to infer.
        left_padding: Columns added on the left, or `None` to infer.
        right_padding: Columns added on the right, or `None` to infer.
        target_height: Output height, or `None` when both vertical
            paddings are given.
        target_width: Output width, or `None` when both horizontal
            paddings are given.

    Returns:
        The padded tensor, with the same rank as `images`.

    Raises:
        ValueError: If `images` is not rank 3 or 4, if the number of
            specified values per axis is not exactly two, or if any
            resulting padding is negative.
    """
    images = backend.convert_to_tensor(images)
    images_shape = ops.shape(images)
    rank = len(images_shape)
    if rank not in (3, 4):
        raise ValueError(
            "Invalid shape for argument `images`: "
            "it must have rank 3 or 4. "
            f"Received: images.shape={images_shape}"
        )

    if [top_padding, bottom_padding, target_height].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "top_padding, bottom_padding, target_height. "
            f"Received: top_padding={top_padding}, "
            f"bottom_padding={bottom_padding}, "
            f"target_height={target_height}"
        )
    if [left_padding, right_padding, target_width].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "left_padding, right_padding, target_width. "
            f"Received: left_padding={left_padding}, "
            f"right_padding={right_padding}, "
            f"target_width={target_width}"
        )

    # Height and width sit at the same negative offsets with or without
    # a leading batch axis.
    height, width = images_shape[-3], images_shape[-2]

    # At most one padding per axis is missing (checked above), so each
    # inference below only uses values that are known.
    if top_padding is None:
        top_padding = target_height - bottom_padding - height
    if bottom_padding is None:
        bottom_padding = target_height - top_padding - height
    if left_padding is None:
        left_padding = target_width - right_padding - width
    if right_padding is None:
        right_padding = target_width - left_padding - width

    # A negative value means the target is smaller than the input plus
    # the explicitly requested padding.
    for name, value in (
        ("top_padding", top_padding),
        ("left_padding", left_padding),
        ("right_padding", right_padding),
        ("bottom_padding", bottom_padding),
    ):
        if value < 0:
            raise ValueError(f"{name} must be >= 0. Received: {name}={value}")

    # One (before, after) pair per axis; the channel axis and, when
    # present, the batch axis are left untouched. Padding the tensor at
    # its native rank makes any expand/squeeze or final reshape
    # unnecessary: the padded result already has the target size.
    pad_width = [
        [top_padding, bottom_padding],
        [left_padding, right_padding],
        [0, 0],
    ]
    if rank == 4:
        pad_width = [[0, 0]] + pad_width
    return backend.numpy.pad(images, pad_width)


@keras_export("keras.ops.image.pad_images")
def pad_images(
    images,
    top_padding=None,
    left_padding=None,
    target_height=None,
    target_width=None,
    bottom_padding=None,
    right_padding=None,
):
    """Pad `images` with zeros to the specified `height` and `width`.

    For each spatial axis, exactly two of the three related arguments
    must be provided: (`top_padding`, `bottom_padding`, `target_height`)
    for the height and (`left_padding`, `right_padding`, `target_width`)
    for the width.

    Args:
        images: 4D Tensor of shape `(batch, height, width, channels)` or 3D
            Tensor of shape `(height, width, channels)`.
        top_padding: Number of rows of zeros to add on top.
        left_padding: Number of columns of zeros to add on the left.
        target_height: Height of output images.
        target_width: Width of output images.
        bottom_padding: Number of rows of zeros to add at the bottom.
        right_padding: Number of columns of zeros to add on the right.

    Returns:
        If `images` were 4D, a 4D Tensor of shape
            `(batch, target_height, target_width, channels)`
        If `images` were 3D, a 3D Tensor of shape
            `(target_height, target_width, channels)`

    Example:

    >>> images = np.random.random((15, 25, 3))
    >>> padded_images = keras.ops.image.pad_images(
    ...     images, 2, 3, target_height=20, target_width=30
    ... )
    >>> padded_images.shape
    (20, 30, 3)

    >>> batch_images = np.random.random((2, 15, 25, 3))
    >>> padded_batch = keras.ops.image.pad_images(
    ...     batch_images, 2, 3, target_height=20, target_width=30
    ... )
    >>> padded_batch.shape
    (2, 20, 30, 3)
    """
    # The op and the eager helper share one argument order, which
    # differs from this public signature; gather the values once.
    pad_args = (
        top_padding,
        bottom_padding,
        left_padding,
        right_padding,
        target_height,
        target_width,
    )
    if not any_symbolic_tensors((images,)):
        return _pad_images(images, *pad_args)
    return PadImages(*pad_args).symbolic_call(images)
67 changes: 67 additions & 0 deletions keras/ops/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def test_map_coordinates(self):
out = kimage.map_coordinates(input, coordinates, 0)
self.assertEqual(out.shape, coordinates.shape[1:])

def test_pad_images(self):
    """Output shape inference with unknown dimensions in the input."""
    # Batched input with an unknown batch dimension.
    batched = KerasTensor([None, 15, 25, 3])
    batched_out = kimage.pad_images(
        batched,
        top_padding=2,
        left_padding=3,
        target_height=20,
        target_width=30,
    )
    self.assertEqual(batched_out.shape, (None, 20, 30, 3))

    # Unbatched input whose height and width are both unknown.
    unbatched = KerasTensor([None, None, 3])
    unbatched_out = kimage.pad_images(
        unbatched,
        top_padding=2,
        left_padding=3,
        target_height=20,
        target_width=30,
    )
    self.assertEqual(unbatched_out.shape, (20, 30, 3))


class ImageOpsStaticShapeTest(testing.TestCase):
def test_resize(self):
Expand Down Expand Up @@ -69,6 +78,17 @@ def test_map_coordinates(self):
out = kimage.map_coordinates(input, coordinates, 0)
self.assertEqual(out.shape, coordinates.shape[1:])

def test_pad_images(self):
    """Output shape inference for fully static input shapes."""
    pad_kwargs = dict(
        top_padding=2,
        left_padding=3,
        target_height=20,
        target_width=30,
    )

    # Single image: (height, width, channels).
    single = KerasTensor([15, 25, 3])
    single_out = kimage.pad_images(single, **pad_kwargs)
    self.assertEqual(single_out.shape, (20, 30, 3))

    # Batch of images: (batch, height, width, channels).
    batched = KerasTensor([2, 15, 25, 3])
    batched_out = kimage.pad_images(batched, **pad_kwargs)
    self.assertEqual(batched_out.shape, (2, 20, 30, 3))


AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
"nearest": 0,
Expand Down Expand Up @@ -413,3 +433,50 @@ def test_map_coordinates(self, shape, dtype, order, fill_mode):
expected = _fixed_map_coordinates(input, coordinates, order, fill_mode)

self.assertAllClose(output, expected)

@parameterized.parameters(
    [
        (0, 0, 3, 3, None, None),
        (1, 0, 4, 3, None, None),
        (0, 1, 3, 4, None, None),
        (0, 0, 4, 3, None, None),
        (0, 0, 3, 4, None, None),
        (0, 0, None, None, 0, 1),
        (0, 0, None, None, 1, 0),
        (1, 2, None, None, 3, 4),
    ]
)
def test_pad_images(
    self,
    top_padding,
    left_padding,
    target_height,
    target_width,
    bottom_padding,
    right_padding,
):
    """Compare `pad_images` against `tf.image.pad_to_bounding_box`."""
    image = np.random.uniform(size=(3, 3, 1))
    actual = kimage.pad_images(
        image,
        top_padding=top_padding,
        left_padding=left_padding,
        target_height=target_height,
        target_width=target_width,
        bottom_padding=bottom_padding,
        right_padding=right_padding,
    )

    # The reference op always needs explicit target sizes, so derive
    # them from the paddings when the case leaves them unspecified.
    in_height, in_width = image.shape[0], image.shape[1]
    if target_height is None:
        target_height = in_height + top_padding + bottom_padding
    if target_width is None:
        target_width = in_width + left_padding + right_padding
    expected = tf.image.pad_to_bounding_box(
        image, top_padding, left_padding, target_height, target_width
    )

    self.assertEqual(tuple(actual.shape), tuple(expected.shape))
    self.assertAllClose(
        expected.numpy(),
        backend.convert_to_numpy(actual),
        atol=1e-5,
    )

0 comments on commit dd3acdb

Please sign in to comment.