-
Notifications
You must be signed in to change notification settings - Fork 19.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add Equalization Layer * api and fix format * lint * Add tf-data test * data format * Update Doc String
- Loading branch information
1 parent
aae0520
commit 58ac150
Showing
5 changed files
with
357 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
213 changes: 213 additions & 0 deletions
213
keras/src/layers/preprocessing/image_preprocessing/equalization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
from keras.src import backend | ||
from keras.src.api_export import keras_export | ||
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 | ||
BaseImagePreprocessingLayer, | ||
) | ||
|
||
|
||
@keras_export("keras.layers.Equalization") | ||
class Equalization(BaseImagePreprocessingLayer): | ||
"""Preprocessing layer for histogram equalization on image channels. | ||
Histogram equalization is a technique to adjust image intensities to | ||
enhance contrast by effectively spreading out the most frequent | ||
intensity values. This layer applies equalization on a channel-wise | ||
basis, which can improve the visibility of details in images. | ||
This layer works with both grayscale and color images, performing | ||
equalization independently on each color channel. At inference time, | ||
the equalization is consistently applied. | ||
**Note:** This layer is safe to use inside a `tf.data` pipeline | ||
(independently of which backend you're using). | ||
Args: | ||
value_range: Optional list/tuple of 2 floats specifying the lower | ||
and upper limits of the input data values. Defaults to `[0, 255]`. | ||
If the input image has been scaled, use the appropriate range | ||
(e.g., `[0.0, 1.0]`). The equalization will be scaled to this | ||
range, and output values will be clipped accordingly. | ||
bins: Integer specifying the number of histogram bins to use for | ||
equalization. Defaults to 256, which is suitable for 8-bit images. | ||
Larger values can provide more granular intensity redistribution. | ||
Input shape: | ||
3D (unbatched) or 4D (batched) tensor with shape: | ||
`(..., height, width, channels)`, in `"channels_last"` format, | ||
or `(..., channels, height, width)`, in `"channels_first"` format. | ||
Output shape: | ||
3D (unbatched) or 4D (batched) tensor with shape: | ||
`(..., target_height, target_width, channels)`, | ||
or `(..., channels, target_height, target_width)`, | ||
in `"channels_first"` format. | ||
Example: | ||
```python | ||
# Create an equalization layer for standard 8-bit images | ||
equalizer = keras.layers.Equalization() | ||
# An image with uneven intensity distribution | ||
image = [...] # your input image | ||
# Apply histogram equalization | ||
equalized_image = equalizer(image) | ||
# For images with custom value range | ||
custom_equalizer = keras.layers.Equalization( | ||
value_range=[0.0, 1.0], # for normalized images | ||
bins=128 # fewer bins for more subtle equalization | ||
) | ||
custom_equalized = custom_equalizer(normalized_image) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, value_range=(0, 255), bins=256, data_format=None, **kwargs | ||
): | ||
super().__init__(**kwargs) | ||
self.bins = bins | ||
self._set_value_range(value_range) | ||
self.data_format = backend.standardize_data_format(data_format) | ||
|
||
def _set_value_range(self, value_range): | ||
if not isinstance(value_range, (tuple, list)): | ||
raise ValueError( | ||
self._VALUE_RANGE_VALIDATION_ERROR | ||
+ f"Received: value_range={value_range}" | ||
) | ||
if len(value_range) != 2: | ||
raise ValueError( | ||
self._VALUE_RANGE_VALIDATION_ERROR | ||
+ f"Received: value_range={value_range}" | ||
) | ||
self.value_range = sorted(value_range) | ||
|
||
def _custom_histogram_fixed_width(self, values, value_range, nbins): | ||
values = self.backend.cast(values, "float32") | ||
value_min, value_max = value_range | ||
value_min = self.backend.cast(value_min, "float32") | ||
value_max = self.backend.cast(value_max, "float32") | ||
|
||
scaled = (values - value_min) * (nbins - 1) / (value_max - value_min) | ||
indices = self.backend.cast(scaled, "int32") | ||
indices = self.backend.numpy.clip(indices, 0, nbins - 1) | ||
flat_indices = self.backend.numpy.reshape(indices, [-1]) | ||
|
||
if backend.backend() == "jax": | ||
# for JAX bincount is never jittable because of output shape | ||
histogram = self.backend.numpy.zeros(nbins, dtype="int32") | ||
for i in range(nbins): | ||
matches = self.backend.cast( | ||
self.backend.numpy.equal(flat_indices, i), "int32" | ||
) | ||
bin_count = self.backend.numpy.sum(matches) | ||
one_hot = self.backend.cast( | ||
self.backend.numpy.arange(nbins) == i, "int32" | ||
) | ||
histogram = histogram + (bin_count * one_hot) | ||
return histogram | ||
else: | ||
# TensorFlow/PyTorch/NumPy implementation using bincount | ||
return self.backend.numpy.bincount( | ||
flat_indices, | ||
minlength=nbins, | ||
) | ||
|
||
def _scale_values(self, values, source_range, target_range): | ||
source_min, source_max = source_range | ||
target_min, target_max = target_range | ||
scale = (target_max - target_min) / (source_max - source_min) | ||
offset = target_min - source_min * scale | ||
return values * scale + offset | ||
|
||
def _equalize_channel(self, channel, value_range): | ||
if value_range != (0, 255): | ||
channel = self._scale_values(channel, value_range, (0, 255)) | ||
|
||
hist = self._custom_histogram_fixed_width( | ||
channel, value_range=(0, 255), nbins=self.bins | ||
) | ||
|
||
nonzero_bins = self.backend.numpy.count_nonzero(hist) | ||
equalized = self.backend.numpy.where( | ||
nonzero_bins <= 1, channel, self._apply_equalization(channel, hist) | ||
) | ||
|
||
if value_range != (0, 255): | ||
equalized = self._scale_values(equalized, (0, 255), value_range) | ||
|
||
return equalized | ||
|
||
def _apply_equalization(self, channel, hist): | ||
cdf = self.backend.numpy.cumsum(hist) | ||
|
||
if backend.backend() == "jax": | ||
mask = cdf > 0 | ||
first_nonzero_idx = self.backend.numpy.argmax(mask) | ||
cdf_min = self.backend.numpy.take(cdf, first_nonzero_idx) | ||
else: | ||
cdf_min = self.backend.numpy.take( | ||
cdf, self.backend.numpy.nonzero(cdf)[0][0] | ||
) | ||
|
||
denominator = cdf[-1] - cdf_min | ||
denominator = self.backend.numpy.where( | ||
denominator == 0, | ||
self.backend.numpy.ones_like(1, dtype=denominator.dtype), | ||
denominator, | ||
) | ||
|
||
lookup_table = ((cdf - cdf_min) * 255) / denominator | ||
lookup_table = self.backend.numpy.clip( | ||
self.backend.numpy.round(lookup_table), 0, 255 | ||
) | ||
|
||
scaled_channel = (channel / 255.0) * (self.bins - 1) | ||
indices = self.backend.cast( | ||
self.backend.numpy.clip(scaled_channel, 0, self.bins - 1), "int32" | ||
) | ||
return self.backend.numpy.take(lookup_table, indices) | ||
|
||
def transform_images(self, images, transformations=None, **kwargs): | ||
images = self.backend.cast(images, self.compute_dtype) | ||
|
||
if self.data_format == "channels_first": | ||
channels = [] | ||
for i in range(self.backend.core.shape(images)[-3]): | ||
channel = images[..., i, :, :] | ||
equalized = self._equalize_channel(channel, self.value_range) | ||
channels.append(equalized) | ||
equalized_images = self.backend.numpy.stack(channels, axis=-3) | ||
else: | ||
channels = [] | ||
for i in range(self.backend.core.shape(images)[-1]): | ||
channel = images[..., i] | ||
equalized = self._equalize_channel(channel, self.value_range) | ||
channels.append(equalized) | ||
equalized_images = self.backend.numpy.stack(channels, axis=-1) | ||
|
||
return self.backend.cast(equalized_images, self.compute_dtype) | ||
|
||
def compute_output_shape(self, input_shape): | ||
return input_shape | ||
|
||
def compute_output_spec(self, inputs, **kwargs): | ||
return inputs | ||
|
||
def transform_bounding_boxes(self, bounding_boxes, **kwargs): | ||
return bounding_boxes | ||
|
||
def transform_labels(self, labels, transformations=None, **kwargs): | ||
return labels | ||
|
||
def transform_segmentation_masks( | ||
self, segmentation_masks, transformations, **kwargs | ||
): | ||
return segmentation_masks | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update({"bins": self.bins, "value_range": self.value_range}) | ||
return config |
135 changes: 135 additions & 0 deletions
135
keras/src/layers/preprocessing/image_preprocessing/equalization_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import numpy as np | ||
import pytest | ||
from absl.testing import parameterized | ||
from tensorflow import data as tf_data | ||
|
||
from keras.src import layers | ||
from keras.src import ops | ||
from keras.src import testing | ||
|
||
|
||
class EqualizationTest(testing.TestCase): | ||
def assertAllInRange(self, array, min_val, max_val): | ||
self.assertTrue(np.all(array >= min_val)) | ||
self.assertTrue(np.all(array <= max_val)) | ||
|
||
@pytest.mark.requires_trainable_backend | ||
def test_layer(self): | ||
self.run_layer_test( | ||
layers.Equalization, | ||
init_kwargs={ | ||
"value_range": (0, 255), | ||
"data_format": "channels_last", | ||
}, | ||
input_shape=(1, 2, 2, 3), | ||
supports_masking=False, | ||
expected_output_shape=(1, 2, 2, 3), | ||
) | ||
|
||
self.run_layer_test( | ||
layers.Equalization, | ||
init_kwargs={ | ||
"value_range": (0, 255), | ||
"data_format": "channels_first", | ||
}, | ||
input_shape=(1, 3, 2, 2), | ||
supports_masking=False, | ||
expected_output_shape=(1, 3, 2, 2), | ||
) | ||
|
||
def test_equalizes_to_all_bins(self): | ||
xs = np.random.uniform(size=(2, 512, 512, 3), low=0, high=255).astype( | ||
np.float32 | ||
) | ||
layer = layers.Equalization(value_range=(0, 255)) | ||
xs = layer(xs) | ||
|
||
for i in range(0, 256): | ||
self.assertTrue(np.any(ops.convert_to_numpy(xs) == i)) | ||
|
||
@parameterized.named_parameters( | ||
("float32", np.float32), ("int32", np.int32), ("int64", np.int64) | ||
) | ||
def test_input_dtypes(self, dtype): | ||
xs = np.random.uniform(size=(2, 512, 512, 3), low=0, high=255).astype( | ||
dtype | ||
) | ||
layer = layers.Equalization(value_range=(0, 255)) | ||
xs = ops.convert_to_numpy(layer(xs)) | ||
|
||
for i in range(0, 256): | ||
self.assertTrue(np.any(xs == i)) | ||
self.assertAllInRange(xs, 0, 255) | ||
|
||
@parameterized.named_parameters(("0_255", 0, 255), ("0_1", 0, 1)) | ||
def test_output_range(self, lower, upper): | ||
xs = np.random.uniform( | ||
size=(2, 512, 512, 3), low=lower, high=upper | ||
).astype(np.float32) | ||
layer = layers.Equalization(value_range=(lower, upper)) | ||
xs = ops.convert_to_numpy(layer(xs)) | ||
self.assertAllInRange(xs, lower, upper) | ||
|
||
def test_constant_regions(self): | ||
xs = np.zeros((1, 64, 64, 3), dtype=np.float32) | ||
xs[:, :21, :, :] = 50 | ||
xs[:, 21:42, :, :] = 100 | ||
xs[:, 42:, :, :] = 200 | ||
|
||
layer = layers.Equalization(value_range=(0, 255)) | ||
equalized = ops.convert_to_numpy(layer(xs)) | ||
|
||
self.assertTrue(len(np.unique(equalized)) >= 3) | ||
self.assertAllInRange(equalized, 0, 255) | ||
|
||
def test_grayscale_images(self): | ||
xs_last = np.random.uniform(0, 255, size=(2, 64, 64, 1)).astype( | ||
np.float32 | ||
) | ||
layer_last = layers.Equalization( | ||
value_range=(0, 255), data_format="channels_last" | ||
) | ||
equalized_last = ops.convert_to_numpy(layer_last(xs_last)) | ||
self.assertEqual(equalized_last.shape[-1], 1) | ||
self.assertAllInRange(equalized_last, 0, 255) | ||
|
||
xs_first = np.random.uniform(0, 255, size=(2, 1, 64, 64)).astype( | ||
np.float32 | ||
) | ||
layer_first = layers.Equalization( | ||
value_range=(0, 255), data_format="channels_first" | ||
) | ||
equalized_first = ops.convert_to_numpy(layer_first(xs_first)) | ||
self.assertEqual(equalized_first.shape[1], 1) | ||
self.assertAllInRange(equalized_first, 0, 255) | ||
|
||
def test_single_color_image(self): | ||
xs_last = np.full((1, 64, 64, 3), 128, dtype=np.float32) | ||
layer_last = layers.Equalization( | ||
value_range=(0, 255), data_format="channels_last" | ||
) | ||
equalized_last = ops.convert_to_numpy(layer_last(xs_last)) | ||
self.assertAllClose(equalized_last, 128.0) | ||
|
||
xs_first = np.full((1, 3, 64, 64), 128, dtype=np.float32) | ||
layer_first = layers.Equalization( | ||
value_range=(0, 255), data_format="channels_first" | ||
) | ||
equalized_first = ops.convert_to_numpy(layer_first(xs_first)) | ||
self.assertAllClose(equalized_first, 128.0) | ||
|
||
def test_different_bin_sizes(self): | ||
xs = np.random.uniform(0, 255, size=(1, 64, 64, 3)).astype(np.float32) | ||
bin_sizes = [16, 64, 128, 256] | ||
for bins in bin_sizes: | ||
layer = layers.Equalization(value_range=(0, 255), bins=bins) | ||
equalized = ops.convert_to_numpy(layer(xs)) | ||
self.assertAllInRange(equalized, 0, 255) | ||
|
||
def test_tf_data_compatibility(self): | ||
layer = layers.Equalization(value_range=(0, 255)) | ||
input_data = np.random.random((2, 8, 8, 3)) * 255 | ||
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) | ||
for output in ds.take(1): | ||
output_array = output.numpy() | ||
self.assertAllInRange(output_array, 0, 255) |