Skip to content

Commit

Permalink
fix random inversion test
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli committed Nov 15, 2023
1 parent 6e89b84 commit 332c041
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions keras_cv/utils/preprocessing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import tensorflow as tf

from keras_cv.backend import random
from keras_cv.tests.test_case import TestCase
from keras_cv.utils import preprocessing


class MockRandomGenerator:
def __init__(self, value):
self.value = value

def random_uniform(self, shape, minval, maxval, dtype=None):
del minval, maxval
return tf.constant(self.value, dtype=dtype)


class PreprocessingTestCase(TestCase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -60,7 +54,19 @@ def test_transform_to_value_range(self):
self.assertAllClose(x, [128 / 255, 1, 0])

def test_random_inversion(self):
generator = MockRandomGenerator(0.75)
self.assertEqual(preprocessing.random_inversion(generator), -1.0)
generator = MockRandomGenerator(0.25)
self.assertEqual(preprocessing.random_inversion(generator), 1.0)
with unittest.mock.patch.object(
random,
"uniform",
return_value=0.75,
):
self.assertEqual(
preprocessing.random_inversion(random.SeedGenerator()), -1.0
)
with unittest.mock.patch.object(
random,
"uniform",
return_value=0.25,
):
self.assertEqual(
preprocessing.random_inversion(random.SeedGenerator()), 1.0
)

0 comments on commit 332c041

Please sign in to comment.