-
Notifications
You must be signed in to change notification settings - Fork 19.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add random_shear processing layer (#20702)
* Add random_shear processing layer * Update method name * Fix failed test case * Fix failed test case * Fix failed test case
- Loading branch information
Showing
5 changed files
with
348 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
263 changes: 263 additions & 0 deletions
263
keras/src/layers/preprocessing/image_preprocessing/random_shear.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,263 @@ | ||
from keras.src.api_export import keras_export | ||
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 | ||
BaseImagePreprocessingLayer, | ||
) | ||
from keras.src.random.seed_generator import SeedGenerator | ||
|
||
|
||
@keras_export("keras.layers.RandomShear") | ||
class RandomShear(BaseImagePreprocessingLayer): | ||
"""A preprocessing layer that randomly applies shear transformations to | ||
images. | ||
This layer shears the input images along the x-axis and/or y-axis by a | ||
randomly selected factor within the specified range. The shear | ||
transformation is applied to each image independently in a batch. Empty | ||
regions created during the transformation are filled according to the | ||
`fill_mode` and `fill_value` parameters. | ||
Args: | ||
x_factor: A tuple of two floats. For each augmented image, a value | ||
is sampled from the provided range. If a float is passed, the | ||
range is interpreted as `(0, x_factor)`. Values represent a | ||
percentage of the image to shear over. For example, 0.3 shears | ||
pixels up to 30% of the way across the image. All provided values | ||
should be positive. | ||
y_factor: A tuple of two floats. For each augmented image, a value | ||
is sampled from the provided range. If a float is passed, the | ||
range is interpreted as `(0, y_factor)`. Values represent a | ||
percentage of the image to shear over. For example, 0.3 shears | ||
pixels up to 30% of the way across the image. All provided values | ||
should be positive. | ||
interpolation: Interpolation mode. Supported values: `"nearest"`, | ||
`"bilinear"`. | ||
fill_mode: Points outside the boundaries of the input are filled | ||
according to the given mode. Available methods are `"constant"`, | ||
`"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. | ||
- `"reflect"`: `(d c b a | a b c d | d c b a)` | ||
The input is extended by reflecting about the edge of the | ||
last pixel. | ||
- `"constant"`: `(k k k k | a b c d | k k k k)` | ||
The input is extended by filling all values beyond the edge | ||
with the same constant value `k` specified by `fill_value`. | ||
- `"wrap"`: `(a b c d | a b c d | a b c d)` | ||
The input is extended by wrapping around to the opposite edge. | ||
- `"nearest"`: `(a a a a | a b c d | d d d d)` | ||
The input is extended by the nearest pixel. | ||
Note that when using torch backend, `"reflect"` is redirected to | ||
`"mirror"` `(c d c b | a b c d | c b a b)` because torch does | ||
not support `"reflect"`. | ||
Note that torch backend does not support `"wrap"`. | ||
fill_value: A float representing the value to be filled outside the | ||
boundaries when `fill_mode="constant"`. | ||
seed: Integer. Used to create a random seed. | ||
""" | ||
|
||
_USE_BASE_FACTOR = False | ||
_FACTOR_BOUNDS = (0, 1) | ||
_FACTOR_VALIDATION_ERROR = ( | ||
"The `factor` argument should be a number (or a list of two numbers) " | ||
"in the range [0, 1.0]. " | ||
) | ||
_SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest") | ||
_SUPPORTED_INTERPOLATION = ("nearest", "bilinear") | ||
|
||
def __init__( | ||
self, | ||
x_factor=0.0, | ||
y_factor=0.0, | ||
interpolation="bilinear", | ||
fill_mode="reflect", | ||
fill_value=0.0, | ||
data_format=None, | ||
seed=None, | ||
**kwargs, | ||
): | ||
super().__init__(data_format=data_format, **kwargs) | ||
self.x_factor = self._set_factor_with_name(x_factor, "x_factor") | ||
self.y_factor = self._set_factor_with_name(y_factor, "y_factor") | ||
|
||
if fill_mode not in self._SUPPORTED_FILL_MODE: | ||
raise NotImplementedError( | ||
f"Unknown `fill_mode` {fill_mode}. Expected of one " | ||
f"{self._SUPPORTED_FILL_MODE}." | ||
) | ||
if interpolation not in self._SUPPORTED_INTERPOLATION: | ||
raise NotImplementedError( | ||
f"Unknown `interpolation` {interpolation}. Expected of one " | ||
f"{self._SUPPORTED_INTERPOLATION}." | ||
) | ||
|
||
self.fill_mode = fill_mode | ||
self.fill_value = fill_value | ||
self.interpolation = interpolation | ||
self.seed = seed | ||
self.generator = SeedGenerator(seed) | ||
self.supports_jit = False | ||
|
||
def _set_factor_with_name(self, factor, factor_name): | ||
if isinstance(factor, (tuple, list)): | ||
if len(factor) != 2: | ||
raise ValueError( | ||
self._FACTOR_VALIDATION_ERROR | ||
+ f"Received: {factor_name}={factor}" | ||
) | ||
self._check_factor_range(factor[0]) | ||
self._check_factor_range(factor[1]) | ||
lower, upper = sorted(factor) | ||
elif isinstance(factor, (int, float)): | ||
self._check_factor_range(factor) | ||
factor = abs(factor) | ||
lower, upper = [-factor, factor] | ||
else: | ||
raise ValueError( | ||
self._FACTOR_VALIDATION_ERROR | ||
+ f"Received: {factor_name}={factor}" | ||
) | ||
return lower, upper | ||
|
||
def _check_factor_range(self, input_number): | ||
if input_number > 1.0 or input_number < 0.0: | ||
raise ValueError( | ||
self._FACTOR_VALIDATION_ERROR | ||
+ f"Received: input_number={input_number}" | ||
) | ||
|
||
def get_random_transformation(self, data, training=True, seed=None): | ||
if not training: | ||
return None | ||
|
||
if isinstance(data, dict): | ||
images = data["images"] | ||
else: | ||
images = data | ||
|
||
images_shape = self.backend.shape(images) | ||
if len(images_shape) == 3: | ||
batch_size = 1 | ||
else: | ||
batch_size = images_shape[0] | ||
|
||
if seed is None: | ||
seed = self._get_seed_generator(self.backend._backend) | ||
|
||
invert = self.backend.random.uniform( | ||
minval=0, | ||
maxval=1, | ||
shape=[batch_size, 1], | ||
seed=seed, | ||
dtype=self.compute_dtype, | ||
) | ||
invert = self.backend.numpy.where( | ||
invert > 0.5, | ||
-self.backend.numpy.ones_like(invert), | ||
self.backend.numpy.ones_like(invert), | ||
) | ||
|
||
shear_y = self.backend.random.uniform( | ||
minval=self.y_factor[0], | ||
maxval=self.y_factor[1], | ||
shape=[batch_size, 1], | ||
seed=seed, | ||
dtype=self.compute_dtype, | ||
) | ||
shear_x = self.backend.random.uniform( | ||
minval=self.x_factor[0], | ||
maxval=self.x_factor[1], | ||
shape=[batch_size, 1], | ||
seed=seed, | ||
dtype=self.compute_dtype, | ||
) | ||
shear_factor = ( | ||
self.backend.cast( | ||
self.backend.numpy.concatenate([shear_x, shear_y], axis=1), | ||
dtype=self.compute_dtype, | ||
) | ||
* invert | ||
) | ||
return {"shear_factor": shear_factor} | ||
|
||
def transform_images(self, images, transformation, training=True): | ||
images = self.backend.cast(images, self.compute_dtype) | ||
if training: | ||
return self._shear_inputs(images, transformation) | ||
return images | ||
|
||
def _shear_inputs(self, inputs, transformation): | ||
if transformation is None: | ||
return inputs | ||
|
||
inputs_shape = self.backend.shape(inputs) | ||
unbatched = len(inputs_shape) == 3 | ||
if unbatched: | ||
inputs = self.backend.numpy.expand_dims(inputs, axis=0) | ||
|
||
shear_factor = transformation["shear_factor"] | ||
outputs = self.backend.image.affine_transform( | ||
inputs, | ||
transform=self._get_shear_matrix(shear_factor), | ||
interpolation=self.interpolation, | ||
fill_mode=self.fill_mode, | ||
fill_value=self.fill_value, | ||
data_format=self.data_format, | ||
) | ||
|
||
if unbatched: | ||
outputs = self.backend.numpy.squeeze(outputs, axis=0) | ||
return outputs | ||
|
||
def _get_shear_matrix(self, shear_factors): | ||
num_shear_factors = self.backend.shape(shear_factors)[0] | ||
|
||
# The shear matrix looks like: | ||
# [[1 s_x 0] | ||
# [s_y 1 0] | ||
# [0 0 1]] | ||
|
||
return self.backend.numpy.stack( | ||
[ | ||
self.backend.numpy.ones((num_shear_factors,)), | ||
shear_factors[:, 0], | ||
self.backend.numpy.zeros((num_shear_factors,)), | ||
shear_factors[:, 1], | ||
self.backend.numpy.ones((num_shear_factors,)), | ||
self.backend.numpy.zeros((num_shear_factors,)), | ||
self.backend.numpy.zeros((num_shear_factors,)), | ||
self.backend.numpy.zeros((num_shear_factors,)), | ||
], | ||
axis=1, | ||
) | ||
|
||
def transform_labels(self, labels, transformation, training=True): | ||
return labels | ||
|
||
def transform_bounding_boxes( | ||
self, | ||
bounding_boxes, | ||
transformation, | ||
training=True, | ||
): | ||
raise NotImplementedError | ||
|
||
def transform_segmentation_masks( | ||
self, segmentation_masks, transformation, training=True | ||
): | ||
return self.transform_images( | ||
segmentation_masks, transformation, training=training | ||
) | ||
|
||
def get_config(self): | ||
base_config = super().get_config() | ||
config = { | ||
"x_factor": self.x_factor, | ||
"y_factor": self.y_factor, | ||
"fill_mode": self.fill_mode, | ||
"interpolation": self.interpolation, | ||
"seed": self.seed, | ||
"fill_value": self.fill_value, | ||
"data_format": self.data_format, | ||
} | ||
return {**base_config, **config} | ||
|
||
def compute_output_shape(self, input_shape): | ||
return input_shape |
76 changes: 76 additions & 0 deletions
76
keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import numpy as np | ||
import pytest | ||
from tensorflow import data as tf_data | ||
|
||
import keras | ||
from keras.src import backend | ||
from keras.src import layers | ||
from keras.src import testing | ||
|
||
|
||
class RandomShearTest(testing.TestCase): | ||
@pytest.mark.requires_trainable_backend | ||
def test_layer(self): | ||
self.run_layer_test( | ||
layers.RandomShear, | ||
init_kwargs={ | ||
"x_factor": (0.5, 1), | ||
"y_factor": (0.5, 1), | ||
"interpolation": "bilinear", | ||
"fill_mode": "reflect", | ||
"data_format": "channels_last", | ||
"seed": 1, | ||
}, | ||
input_shape=(8, 3, 4, 3), | ||
supports_masking=False, | ||
expected_output_shape=(8, 3, 4, 3), | ||
) | ||
|
||
def test_random_posterization_inference(self): | ||
seed = 3481 | ||
layer = layers.RandomShear(1, 1) | ||
np.random.seed(seed) | ||
inputs = np.random.randint(0, 255, size=(224, 224, 3)) | ||
output = layer(inputs, training=False) | ||
self.assertAllClose(inputs, output) | ||
|
||
def test_shear_pixel_level(self): | ||
image = np.zeros((1, 5, 5, 3)) | ||
image[0, 1:4, 1:4, :] = 1.0 | ||
image[0, 2, 2, :] = [0.0, 1.0, 0.0] | ||
image = keras.ops.convert_to_tensor(image, dtype="float32") | ||
|
||
data_format = backend.config.image_data_format() | ||
if data_format == "channels_first": | ||
image = keras.ops.transpose(image, (0, 3, 1, 2)) | ||
|
||
shear_layer = layers.RandomShear( | ||
x_factor=(0.2, 0.3), | ||
y_factor=(0.2, 0.3), | ||
interpolation="bilinear", | ||
fill_mode="constant", | ||
fill_value=0.0, | ||
seed=42, | ||
data_format=data_format, | ||
) | ||
|
||
sheared_image = shear_layer(image) | ||
|
||
if data_format == "channels_first": | ||
sheared_image = keras.ops.transpose(sheared_image, (0, 2, 3, 1)) | ||
|
||
original_pixel = image[0, 2, 2, :] | ||
sheared_pixel = sheared_image[0, 2, 2, :] | ||
self.assertNotAllClose(original_pixel, sheared_pixel) | ||
|
||
def test_tf_data_compatibility(self): | ||
data_format = backend.config.image_data_format() | ||
if data_format == "channels_last": | ||
input_data = np.random.random((2, 8, 8, 3)) | ||
else: | ||
input_data = np.random.random((2, 3, 8, 8)) | ||
layer = layers.RandomShear(1, 1) | ||
|
||
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) | ||
for output in ds.take(1): | ||
output.numpy() |