diff --git a/keras_cv/utils/preprocessing_test.py b/keras_cv/utils/preprocessing_test.py index 96ad303d80..c9812ea862 100644 --- a/keras_cv/utils/preprocessing_test.py +++ b/keras_cv/utils/preprocessing_test.py @@ -12,21 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest + import tensorflow as tf +from keras_cv.backend import random from keras_cv.tests.test_case import TestCase from keras_cv.utils import preprocessing -class MockRandomGenerator: - def __init__(self, value): - self.value = value - - def random_uniform(self, shape, minval, maxval, dtype=None): - del minval, maxval - return tf.constant(self.value, dtype=dtype) - - class PreprocessingTestCase(TestCase): def setUp(self): super().setUp() @@ -60,7 +54,19 @@ def test_transform_to_value_range(self): self.assertAllClose(x, [128 / 255, 1, 0]) def test_random_inversion(self): - generator = MockRandomGenerator(0.75) - self.assertEqual(preprocessing.random_inversion(generator), -1.0) - generator = MockRandomGenerator(0.25) - self.assertEqual(preprocessing.random_inversion(generator), 1.0) + with unittest.mock.patch.object( + random, + "uniform", + return_value=0.75, + ): + self.assertEqual( + preprocessing.random_inversion(random.SeedGenerator()), -1.0 + ) + with unittest.mock.patch.object( + random, + "uniform", + return_value=0.25, + ): + self.assertEqual( + preprocessing.random_inversion(random.SeedGenerator()), 1.0 + )