From 75b4d30133568f8c14d218684178ecb4ee14d41c Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 16 Nov 2023 05:44:54 +0000 Subject: [PATCH] fix dtype errors --- benchmarks/vectorized_random_translation.py | 4 ++-- benchmarks/vectorized_randomly_zoomed_crop.py | 4 ++-- keras_cv/backend/random.py | 5 ++++- keras_cv/layers/preprocessing/grid_mask.py | 6 +++--- keras_cv/layers/preprocessing/mosaic.py | 13 ++++++------- .../preprocessing/random_augmentation_pipeline.py | 2 +- keras_cv/layers/preprocessing/random_choice.py | 3 +-- .../layers/preprocessing/random_crop_and_resize.py | 4 ++-- keras_cv/layers/preprocessing/random_cutout.py | 6 ++---- keras_cv/layers/preprocessing/random_hue.py | 2 +- keras_cv/layers/preprocessing/random_translation.py | 4 ++-- keras_cv/layers/regularization/dropblock_2d.py | 2 +- keras_cv/utils/preprocessing.py | 2 +- 13 files changed, 28 insertions(+), 29 deletions(-) diff --git a/benchmarks/vectorized_random_translation.py b/benchmarks/vectorized_random_translation.py index 9d883d5f36..7d72596125 100644 --- a/benchmarks/vectorized_random_translation.py +++ b/benchmarks/vectorized_random_translation.py @@ -222,14 +222,14 @@ def get_random_transformation(self, image=None, **kwargs): shape=[batch_size, 1], minval=self.height_lower, maxval=self.height_upper, - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) width_translation = random.uniform( shape=[batch_size, 1], minval=self.width_lower, maxval=self.width_upper, - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) return { diff --git a/benchmarks/vectorized_randomly_zoomed_crop.py b/benchmarks/vectorized_randomly_zoomed_crop.py index 434e45555a..88876c73af 100644 --- a/benchmarks/vectorized_randomly_zoomed_crop.py +++ b/benchmarks/vectorized_randomly_zoomed_crop.py @@ -114,7 +114,7 @@ def get_random_transformation( (), minval=tf.minimum(0.0, original_height - new_height), maxval=tf.maximum(0.0, original_height - new_height), - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) @@ -122,7 +122,7 @@ def get_random_transformation( (), minval=tf.minimum(0.0, original_width - new_width), maxval=tf.maximum(0.0, original_width - new_width), - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) diff --git a/keras_cv/backend/random.py b/keras_cv/backend/random.py index 97958479c6..b8fe8eab27 100644 --- a/keras_cv/backend/random.py +++ b/keras_cv/backend/random.py @@ -16,9 +16,11 @@ if keras_3(): from keras.random import * # noqa: F403, F401 + # SeedGenerator is imported from `keras.random` else: from keras_core.random import * # noqa: F403, F401 + class SeedGenerator: def __init__(self, seed=None, **kwargs): self._current_seed = [seed, 0] @@ -86,7 +88,8 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): seed=make_seed(seed), **kwargs, ) - + + def randint(shape, minval=0.0, maxval=1.0, dtype="int32", seed=None): kwargs = {} if dtype: diff --git a/keras_cv/layers/preprocessing/grid_mask.py b/keras_cv/layers/preprocessing/grid_mask.py index 39dd6c35ef..20f54adb21 100644 --- a/keras_cv/layers/preprocessing/grid_mask.py +++ b/keras_cv/layers/preprocessing/grid_mask.py @@ -186,7 +186,7 @@ def _compute_grid_mask(self, input_shape, ratio): shape=(), minval=tf.math.minimum(height * 0.5, width * 0.3), maxval=tf.math.maximum(height * 0.5, width * 0.3) + 1, - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) rectangle_side_len = tf.cast((ratio) * unit_size, tf.float32) @@ -196,14 +196,14 @@ def _compute_grid_mask(self, input_shape, ratio): shape=(), minval=0.0, maxval=unit_size, - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) delta_y = random.uniform( shape=(), minval=0.0, maxval=unit_size, - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) diff --git a/keras_cv/layers/preprocessing/mosaic.py b/keras_cv/layers/preprocessing/mosaic.py index 6e42364516..c60e0de48c 100644 --- a/keras_cv/layers/preprocessing/mosaic.py +++ b/keras_cv/layers/preprocessing/mosaic.py @@ -20,19 +20,19 @@ from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501 BATCHED, ) -from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501 +from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( BOUNDING_BOXES, ) -from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501 +from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( IMAGES, ) -from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501 +from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( LABELS, ) -from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501 +from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( SEGMENTATION_MASKS, ) -from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501 +from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( VectorizedBaseImageAugmentationLayer, ) from keras_cv.utils import preprocessing as preprocessing_utils @@ -97,11 +97,10 @@ def __init__( def get_random_transformation_batch(self, batch_size, **kwargs): # pick 3 indices for every batch to create the mosaic output with. - permutation_order = random.uniform( + permutation_order = random.randint( (batch_size, 3), minval=0, maxval=batch_size, - dtype=tf.int32, seed=self._seed_generator, ) # concatenate the batches with permutation order to get all 4 images of diff --git a/keras_cv/layers/preprocessing/random_augmentation_pipeline.py b/keras_cv/layers/preprocessing/random_augmentation_pipeline.py index 67bf79b615..bee86bb260 100644 --- a/keras_cv/layers/preprocessing/random_augmentation_pipeline.py +++ b/keras_cv/layers/preprocessing/random_augmentation_pipeline.py @@ -103,7 +103,7 @@ def _augment(self, inputs): shape=(), minval=0.0, maxval=1.0, - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) result = tf.cond( diff --git a/keras_cv/layers/preprocessing/random_choice.py b/keras_cv/layers/preprocessing/random_choice.py index e38f1bef8d..386d5d00ab 100644 --- a/keras_cv/layers/preprocessing/random_choice.py +++ b/keras_cv/layers/preprocessing/random_choice.py @@ -88,11 +88,10 @@ def _batch_augment(self, inputs): return super()._batch_augment(inputs) def _augment(self, inputs, *args, **kwargs): - selected_op = random.uniform( + selected_op = random.randint( (), minval=0, maxval=len(self.layers), - dtype=tf.int32, seed=self._seed_generator, ) # Warning: diff --git a/keras_cv/layers/preprocessing/random_crop_and_resize.py b/keras_cv/layers/preprocessing/random_crop_and_resize.py index 7c657beec2..9ad271e708 100644 --- a/keras_cv/layers/preprocessing/random_crop_and_resize.py +++ b/keras_cv/layers/preprocessing/random_crop_and_resize.py @@ -114,7 +114,7 @@ def get_random_transformation( (), minval=tf.minimum(0.0, 1.0 - new_height), maxval=tf.maximum(0.0, 1.0 - new_height), - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) @@ -122,7 +122,7 @@ def get_random_transformation( (), minval=tf.minimum(0.0, 1.0 - new_width), maxval=tf.maximum(0.0, 1.0 - new_width), - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) diff --git a/keras_cv/layers/preprocessing/random_cutout.py b/keras_cv/layers/preprocessing/random_cutout.py index 6b256402e7..c67f66f5a3 100644 --- a/keras_cv/layers/preprocessing/random_cutout.py +++ b/keras_cv/layers/preprocessing/random_cutout.py @@ -132,18 +132,16 @@ def _compute_rectangle_position(self, inputs): input_shape[0], input_shape[1], ) - center_x = random.uniform( + center_x = random.randint( [1], 0, image_width, - dtype=tf.int32, seed=self._seed_generator, ) - center_y = random.uniform( + center_y = random.randint( [1], 0, image_height, - dtype=tf.int32, seed=self._seed_generator, ) return center_x, center_y diff --git a/keras_cv/layers/preprocessing/random_hue.py b/keras_cv/layers/preprocessing/random_hue.py index 65e7a1799a..3cf9a7fe49 100644 --- a/keras_cv/layers/preprocessing/random_hue.py +++ b/keras_cv/layers/preprocessing/random_hue.py @@ -70,7 +70,7 @@ def get_random_transformation_batch(self, batch_size, **kwargs): (batch_size,), 0, 1, - tf.float32, + "float32", seed=self._seed_generator, ) invert = tf.where( diff --git a/keras_cv/layers/preprocessing/random_translation.py b/keras_cv/layers/preprocessing/random_translation.py index 73fd9083d2..6a0537e1a5 100644 --- a/keras_cv/layers/preprocessing/random_translation.py +++ b/keras_cv/layers/preprocessing/random_translation.py @@ -149,14 +149,14 @@ def get_random_transformation_batch(self, batch_size, **kwargs): shape=[batch_size, 1], minval=self.height_lower, maxval=self.height_upper, - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) width_translations = random.uniform( shape=[batch_size, 1], minval=self.width_lower, maxval=self.width_upper, - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) return { diff --git a/keras_cv/layers/regularization/dropblock_2d.py b/keras_cv/layers/regularization/dropblock_2d.py index 1e3f31705f..43acdc61bc 100644 --- a/keras_cv/layers/regularization/dropblock_2d.py +++ b/keras_cv/layers/regularization/dropblock_2d.py @@ -220,7 +220,7 @@ def call(self, x, training=None): random_noise = random.uniform( tf.shape(x), - dtype=tf.float32, + dtype="float32", seed=self._seed_generator, ) valid_block = tf.cast(valid_block, dtype=tf.float32) diff --git a/keras_cv/utils/preprocessing.py b/keras_cv/utils/preprocessing.py index 84d0519ff3..397b8be7fb 100644 --- a/keras_cv/utils/preprocessing.py +++ b/keras_cv/utils/preprocessing.py @@ -195,7 +195,7 @@ def random_inversion(seed_generator): def batch_random_inversion(seed_generator, batch_size): """Same as `random_inversion` but for batched inputs.""" negate = random.uniform( - (batch_size, 1), 0, 1, dtype=tf.float32, seed=seed_generator + (batch_size, 1), 0, 1, dtype="float32", seed=seed_generator ) negate = tf.where(negate > 0.5, -1.0, 1.0) return negate