diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_flip.py b/keras/src/layers/preprocessing/image_preprocessing/random_flip.py index 799a2477915..25e3cb2e520 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_flip.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_flip.py @@ -101,6 +101,9 @@ def transform_bounding_boxes( transformation, training=True, ): + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + def _flip_boxes_horizontal(boxes): x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1) outputs = self.backend.numpy.concatenate( @@ -116,9 +119,6 @@ def _flip_boxes_vertical(boxes): return outputs def _transform_xyxy(boxes, box_flips): - if backend_utils.in_tf_graph(): - self.backend.set_backend("tensorflow") - bboxes = boxes["boxes"] if self.mode in {HORIZONTAL, HORIZONTAL_AND_VERTICAL}: bboxes = self.backend.numpy.where( @@ -132,9 +132,6 @@ def _transform_xyxy(boxes, box_flips): _flip_boxes_vertical(bboxes), bboxes, ) - - self.backend.reset() - return bboxes flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1) @@ -176,6 +173,8 @@ def _transform_xyxy(boxes, box_flips): width=input_width, ) + self.backend.reset() + return bounding_boxes def transform_segmentation_masks(