Skip to content

Commit

Permalink
Update aug layers (#2147)
Browse files Browse the repository at this point in the history
* Replace random_generator

* Replace random_genereator
  • Loading branch information
sampathweb committed Nov 15, 2023
1 parent aaf144a commit 266ed03
Show file tree
Hide file tree
Showing 21 changed files with 145 additions and 87 deletions.
3 changes: 2 additions & 1 deletion benchmarks/vectorized_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import Mosaic
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -101,7 +102,7 @@ def _batch_augment(self, inputs):
minval=0,
maxval=batch_size,
dtype=tf.int32,
seed=self._seed_generator.next(),
seed=random.make_seed(seed=self._seed_generator),
)
# concatenate the batches with permutation order to get all 4 images of
# the mosaic
Expand Down
1 change: 1 addition & 0 deletions benchmarks/vectorized_randomly_zoomed_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tensorflow import keras

from keras_cv import core
from keras_cv.backend import random
from keras_cv.layers import RandomlyZoomedCrop
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down
36 changes: 15 additions & 21 deletions keras_cv/backend/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,30 @@ def __init__(self, seed=None, **kwargs):
seed=seed, **kwargs
)
else:
self._current_seed = [0, seed]
self._current_seed = [seed, 0]

def next(self, ordered=True):
if keras_3():
return self._seed_generator.next(ordered=ordered)
else:
self._current_seed[0] += 1
self._current_seed[1] += 1
return self._current_seed[:]


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
def make_seed(seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
seed_0, seed_1 = seed.next()
if seed_0 is None:
init_seed = seed_1
else:
init_seed = seed_0 + seed_1
else:
init_seed = seed
return init_seed


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
init_seed = make_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand All @@ -68,11 +75,7 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):


def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
init_seed = make_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand All @@ -97,12 +100,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):


def shuffle(x, axis=0, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed

init_seed = make_seed(seed)
if keras_3():
return keras.random.shuffle(x=x, axis=axis, seed=init_seed)
else:
Expand All @@ -112,11 +110,7 @@ def shuffle(x, axis=0, seed=None):


def categorical(logits, num_samples, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
init_seed = make_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand Down
10 changes: 7 additions & 3 deletions keras_cv/layers/preprocessing/aug_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,22 @@ def _sample_from_dirichlet(self, alpha):
gamma_sample = tf.random.gamma(
shape=(),
alpha=alpha,
seed=self._seed_generator.next(),
seed=random.make_seed(seed=self._seed_generator),
)
return gamma_sample / tf.reduce_sum(
gamma_sample, axis=-1, keepdims=True
)

def _sample_from_beta(self, alpha, beta):
sample_alpha = tf.random.gamma(
(), alpha=alpha, seed=self._seed_generator.next()
(),
alpha=alpha,
seed=random.make_seed(seed=self._seed_generator),
)
sample_beta = tf.random.gamma(
(), alpha=beta, seed=self._seed_generator.next()
(),
alpha=beta,
seed=random.make_seed(seed=self._seed_generator),
)
return sample_alpha / (sample_alpha + sample_beta)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
import tensorflow as tf

from keras_cv import bounding_box
Expand Down Expand Up @@ -126,6 +125,8 @@ def augment_image(self, image, transformation):
"""

def __init__(self, seed=None, **kwargs):
# TODO: Remove unused force_generator arg
_ = kwargs.pop("force_generator", None)
self._seed_generator = SeedGenerator(
seed=seed,
)
Expand Down
9 changes: 7 additions & 2 deletions keras_cv/layers/preprocessing/cut_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import tensorflow as tf

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
Expand Down Expand Up @@ -49,10 +50,14 @@ def __init__(

def _sample_from_beta(self, alpha, beta, shape):
sample_alpha = tf.random.gamma(
shape, alpha=alpha, seed=self._seed_generator.next()
shape,
alpha=alpha,
seed=random.make_seed(seed=self._seed_generator),
)
sample_beta = tf.random.gamma(
shape, alpha=beta, seed=self._seed_generator.next()
shape,
alpha=beta,
seed=random.make_seed(seed=self._seed_generator),
)
return sample_alpha / (sample_alpha + sample_beta)

Expand Down
8 changes: 6 additions & 2 deletions keras_cv/layers/preprocessing/fourier_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,14 @@ def __init__(self, alpha=0.5, decay_power=3, seed=None, **kwargs):

def _sample_from_beta(self, alpha, beta, shape):
sample_alpha = tf.random.gamma(
shape, alpha=alpha, seed=self._seed_generator.next()
shape,
alpha=alpha,
seed=random.make_seed(seed=self._seed_generator),
)
sample_beta = tf.random.gamma(
shape, alpha=beta, seed=self._seed_generator.next()
shape,
alpha=beta,
seed=random.make_seed(seed=self._seed_generator),
)
return sample_alpha / (sample_alpha + sample_beta)

Expand Down
9 changes: 7 additions & 2 deletions keras_cv/layers/preprocessing/mix_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
Expand Down Expand Up @@ -57,10 +58,14 @@ def __init__(self, alpha=0.2, seed=None, **kwargs):

def _sample_from_beta(self, alpha, beta, shape):
sample_alpha = tf.random.gamma(
shape, alpha=alpha, seed=self._seed_generator.next()
shape,
alpha=alpha,
seed=random.make_seed(seed=self._seed_generator),
)
sample_beta = tf.random.gamma(
shape, alpha=beta, seed=self._seed_generator.next()
shape,
alpha=beta,
seed=random.make_seed(seed=self._seed_generator),
)
return sample_alpha / (sample_alpha + sample_beta)

Expand Down
15 changes: 11 additions & 4 deletions keras_cv/layers/preprocessing/random_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from keras_cv import bounding_box
from keras_cv import layers as cv_layers
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501
VectorizedBaseImageAugmentationLayer,
)
Expand Down Expand Up @@ -79,14 +80,20 @@ def compute_ragged_image_signature(self, images):

def get_random_transformation_batch(self, batch_size, **kwargs):
tops = tf.cast(
self._random_generator.random_uniform(
shape=(batch_size, 1), minval=0, maxval=1
random.uniform(
shape=(batch_size, 1),
minval=0,
maxval=1,
seed=self._seed_generator,
),
self.compute_dtype,
)
lefts = tf.cast(
self._random_generator.random_uniform(
shape=(batch_size, 1), minval=0, maxval=1
random.uniform(
shape=(batch_size, 1),
minval=0,
maxval=1,
seed=self._seed_generator,
),
self.compute_dtype,
)
Expand Down
13 changes: 7 additions & 6 deletions keras_cv/layers/preprocessing/random_crop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from absl.testing import parameterized

from keras_cv import layers as cv_layers
from keras_cv.backend import random
from keras_cv.layers.preprocessing.random_crop import RandomCrop
from keras_cv.tests.test_case import TestCase

Expand Down Expand Up @@ -105,8 +106,8 @@ def test_unbatched_image(self):
mock_offset = np.ones(shape=(1, 1), dtype="float32") * 0.25
layer = RandomCrop(8, 8)
with unittest.mock.patch.object(
layer._random_generator,
"random_uniform",
random,
"uniform",
return_value=mock_offset,
):
actual_output = layer(inp, training=True)
Expand All @@ -119,8 +120,8 @@ def test_batched_input(self):
mock_offset = np.ones(shape=(20, 1), dtype="float32") * 2 / (16 - 8)
layer = RandomCrop(8, 8)
with unittest.mock.patch.object(
layer._random_generator,
"random_uniform",
random,
"uniform",
return_value=mock_offset,
):
actual_output = layer(inp, training=True)
Expand Down Expand Up @@ -194,8 +195,8 @@ def augment(x):
return layer(x, training=True)

with unittest.mock.patch.object(
layer._random_generator,
"random_uniform",
random,
"uniform",
return_value=mock_offset,
):
actual_output = augment(inp)
Expand Down
11 changes: 7 additions & 4 deletions keras_cv/layers/preprocessing/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501
VectorizedBaseImageAugmentationLayer,
)
Expand Down Expand Up @@ -98,13 +99,15 @@ def get_random_transformation_batch(self, batch_size, **kwargs):
flip_verticals = tf.zeros(shape=(batch_size, 1))

if self.horizontal:
flip_horizontals = self._random_generator.random_uniform(
shape=(batch_size, 1)
flip_horizontals = random.uniform(
shape=(batch_size, 1),
seed=self._seed_generator,
)

if self.vertical:
flip_verticals = self._random_generator.random_uniform(
shape=(batch_size, 1)
flip_verticals = random.uniform(
shape=(batch_size, 1),
seed=self._seed_generator,
)

return {
Expand Down
Loading

0 comments on commit 266ed03

Please sign in to comment.