Skip to content

Commit

Permalink
update base_augmentation_layer_3d and all subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli committed Nov 15, 2023
1 parent 0e65720 commit 6e89b84
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import config
from keras_cv.backend.random import SeedGenerator

if config.keras_3():
base_layer = tf.keras.layers.Layer
Expand Down Expand Up @@ -101,15 +100,12 @@ def augment_pointclouds(self, point_clouds, transformation):
Note that since the randomness is also a common functionality, this layer
also includes a keras.backend.RandomGenerator, which can be used to
produce the random numbers. The random number generator is stored in the
`self._seed_generator` attribute.
`self._random_generator` attribute.
"""

def __init__(self, seed=None, **kwargs):
# To-do: remove this once th elayer is ported to keras 3
# https://github.com/keras-team/keras-cv/issues/2136
self._seed_generator = SeedGenerator(
seed=seed,
)
if config.keras_3():
raise ValueError(
"This layer is not yet compatible with Keras 3."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import numpy as np
import tensorflow as tf

from keras_cv.backend import random
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
from keras_cv.tests.test_case import TestCase

Expand All @@ -28,23 +27,20 @@ def __init__(self, translate_noise=(0.0, 0.0, 0.0), **kwargs):
self._translate_noise = translate_noise

def get_random_transformation(self, **kwargs):
random_x = random.normal(
random_x = self._random_generator.random_normal(
(),
mean=0.0,
stddev=self._translate_noise[0],
seed=self._seed_generator,
)
random_y = random.normal(
random_y = self._random_generator.random_normal(
(),
mean=0.0,
stddev=self._translate_noise[1],
seed=self._seed_generator,
)
random_z = random.normal(
random_z = self._random_generator.random_normal(
(),
mean=0.0,
stddev=self._translate_noise[2],
seed=self._seed_generator,
)

return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from keras_cv import point_cloud
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d

POINT_CLOUDS = base_augmentation_layer_3d.POINT_CLOUDS
Expand Down Expand Up @@ -123,11 +122,10 @@ def get_random_transformation(self, point_clouds, **kwargs):
frustum_mask = tf.concat(frustum_mask, axis=0)
# Generate mask along point dimension.
random_point_mask = (
random.uniform(
self._random_generator.random_uniform(
[1, num_points, 1],
minval=0.0,
maxval=1,
seed=self._seed_generator,
)
< self._keep_probability
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tensorflow as tf

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d

POINT_CLOUDS = base_augmentation_layer_3d.POINT_CLOUDS
Expand Down Expand Up @@ -64,11 +63,10 @@ def get_random_transformation(self, point_clouds, **kwargs):
num_points = point_clouds.get_shape().as_list()[-2]
# Generate mask along point dimension.
random_point_mask = (
random.uniform(
self._random_generator.random_uniform(
[1, num_points, 1],
minval=0.0,
maxval=1,
seed=self._seed_generator,
)
< self._keep_probability
)
Expand Down
10 changes: 3 additions & 7 deletions keras_cv/layers/preprocessing_3d/waymo/global_random_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tensorflow as tf

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.bounding_box_3d import CENTER_XYZ_DXDYDZ_PHI
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
from keras_cv.point_cloud import coordinate_transform
Expand Down Expand Up @@ -85,26 +84,23 @@ def get_config(self):
}

def get_random_transformation(self, **kwargs):
random_rotation_x = random.uniform(
random_rotation_x = self._random_generator.random_uniform(
(),
minval=-self._max_rotation_angle_x,
maxval=self._max_rotation_angle_x,
dtype=self.compute_dtype,
seed=self._seed_generator,
)
random_rotation_y = random.uniform(
random_rotation_y = self._random_generator.random_uniform(
(),
minval=-self._max_rotation_angle_y,
maxval=self._max_rotation_angle_y,
dtype=self.compute_dtype,
seed=self._seed_generator,
)
random_rotation_z = random.uniform(
random_rotation_z = self._random_generator.random_uniform(
(),
minval=-self._max_rotation_angle_z,
maxval=self._max_rotation_angle_z,
dtype=self.compute_dtype,
seed=self._seed_generator,
)
return {
"pose": tf.stack(
Expand Down
10 changes: 3 additions & 7 deletions keras_cv/layers/preprocessing_3d/waymo/global_random_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tensorflow as tf

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.bounding_box_3d import CENTER_XYZ_DXDYDZ_PHI
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d

Expand Down Expand Up @@ -134,26 +133,23 @@ def get_config(self):
}

def get_random_transformation(self, **kwargs):
random_scaling_x = random.uniform(
random_scaling_x = self._random_generator.random_uniform(
(),
minval=self._min_x_factor,
maxval=self._max_x_factor,
dtype=self.compute_dtype,
seed=self._seed_generator,
)
random_scaling_y = random.uniform(
random_scaling_y = self._random_generator.random_uniform(
(),
minval=self._min_y_factor,
maxval=self._max_y_factor,
dtype=self.compute_dtype,
seed=self._seed_generator,
)
random_scaling_z = random.uniform(
random_scaling_z = self._random_generator.random_uniform(
(),
minval=self._min_z_factor,
maxval=self._max_z_factor,
dtype=self.compute_dtype,
seed=self._seed_generator,
)
if not self._preserve_aspect_ratio:
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tensorflow as tf

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.bounding_box_3d import CENTER_XYZ_DXDYDZ_PHI
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
from keras_cv.point_cloud import coordinate_transform
Expand Down Expand Up @@ -68,26 +67,23 @@ def get_config(self):
}

def get_random_transformation(self, **kwargs):
random_x_translation = random.normal(
random_x_translation = self._random_generator.random_normal(
(),
mean=0.0,
stddev=self._x_stddev,
dtype=self.compute_dtype,
seed=self._seed_generator,
)
random_y_translation = random.normal(
random_y_translation = self._random_generator.random_normal(
(),
mean=0.0,
stddev=self._y_stddev,
dtype=self.compute_dtype,
seed=self._seed_generator,
)
random_z_translation = random.normal(
random_z_translation = self._random_generator.random_normal(
(),
mean=0.0,
stddev=self._z_stddev,
dtype=self.compute_dtype,
seed=self._seed_generator,
)
return {
"pose": tf.stack(
Expand Down
4 changes: 1 addition & 3 deletions keras_cv/layers/preprocessing_3d/waymo/random_copy_paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tensorflow as tf

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.bounding_box_3d import CENTER_XYZ_DXDYDZ_PHI
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
from keras_cv.ops import iou_3d
Expand Down Expand Up @@ -99,11 +98,10 @@ def get_random_transformation(
**kwargs
):
del point_clouds
num_paste_bounding_boxes = random.uniform(
num_paste_bounding_boxes = self._random_generator.random_uniform(
(),
minval=self._min_paste_bounding_boxes,
maxval=self._max_paste_bounding_boxes,
seed=self._seed_generator,
)
num_paste_bounding_boxes = tf.cast(
num_paste_bounding_boxes, dtype=tf.int32
Expand Down

0 comments on commit 6e89b84

Please sign in to comment.