From ad19601b627f9466f7aad0c0038b35d57e57228c Mon Sep 17 00:00:00 2001
From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Date: Wed, 6 Dec 2023 15:03:51 -0600
Subject: [PATCH] r0.7.2 Cherry Pick (#2215)

* Fix Keras 3 version check (#2191)

* Fix Keras 3 version check

* Fix Keras 3 version check

* Fix Keras 3 version check

* Raise error if Keras is not compatible with TF

* Fix bug when upranking passthrough inputs to RandAugment (#2194)

- RandAugment sometimes will choose a "no augmentation" option and passthrough
  inputs unaltered.
- Preprocessing normalization routines were not making copies of inputs and
  sometimes mutating layer input directly (mutating the input dict to cast
  dtypes and uprank tensors).
- RandAugment under the passthrough option would return these inputs directly.

The net effect was sometimes attempting to uprank during a passthrough call,
breaking tf.map_fn

* fix stable diffusion rank error (#2208)

* Simplify running KerasCV with Keras 3 (#2179)

* remove keras_core dependency

* update init

* update readme

* fix model None error (#2176) (#2177)

* Update pycoco_callback.py

* Update waymo_evaluation_callback.py

* fix model None error (#2176) (#2178)

* Update pycoco_callback.py

* Update waymo_evaluation_callback.py

* update readme and conftest

* update readme

* update citation list

* fix mix transformer tests

* fix lint error

* fix all failing tests

* Fix dtype support for SegmentAnythingModel (#2207)

* Fix dtype support for SAM

* Update keras_cv/models/segmentation/segment_anything/sam_test.py

* Fix Keras 2 failures

* Fix F401 lint error; remove unused import

* Version bump to r0.7.2.dev0

---------

Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Co-authored-by: Divyashree Sreepathihalli
Co-authored-by: Tirth Patel
---
 README.md                                     | 61 +++++++-----
 keras_cv/__init__.py                          |  2 +-
 keras_cv/backend/__init__.py                  | 20 +++-
 keras_cv/backend/config.py                    | 84 +++++-----------
 keras_cv/backend/keras.py                     |  5 -
 keras_cv/backend/ops.py                       |  9 +-
 keras_cv/backend/scope.py                     |  4 +-
 keras_cv/conftest.py                          | 14 +--
 .../multi_class_non_max_suppression.py        |  4 +-
 .../object_detection/non_max_suppression.py   |  6 +-
 .../base_image_augmentation_layer.py          |  2 +
 .../layers/preprocessing/rand_augment_test.py | 12 ++-
 .../random_augmentation_pipeline.py           |  2 +-
 ...ectorized_base_image_augmentation_layer.py |  9 +-
 keras_cv/layers/spatial_pyramid.py            |  4 +-
 keras_cv/layers/vit_det_layers.py             | 16 ++-
 keras_cv/losses/focal_test.py                 |  4 +-
 .../center_pillar_backbone_test.py            |  4 +-
 .../object_detection_3d/center_pillar_test.py |  4 +-
 .../segment_anything/sam_layers.py            | 14 +--
 .../segmentation/segment_anything/sam_test.py | 98 +++++++++++--------
 .../stable_diffusion/stable_diffusion.py      |  2 +-
 22 files changed, 194 insertions(+), 186 deletions(-)

diff --git a/README.md b/README.md
index e4ad4bae14..6c7b97c9f2 100644
--- a/README.md
+++ b/README.md
@@ -7,10 +7,10 @@
 [![Contributions Welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/keras-team/keras-cv/issues)

 KerasCV is a library of modular computer vision components that work natively
-with TensorFlow, JAX, or PyTorch. Built on [Keras Core](https://keras.io/keras_core/announcement/),
-these models, layers, metrics, callbacks, etc., can be trained and serialized
-in any framework and re-used in another without costly migrations. See
-"Configuring your backend" below for more details on multi-framework KerasCV.
+with TensorFlow, JAX, or PyTorch. Built on Keras 3, these models, layers,
+metrics, callbacks, etc., can be trained and serialized in any framework and
+re-used in another without costly migrations. See "Configuring your backend"
+below for more details on multi-framework KerasCV.
@@ -34,29 +34,44 @@ these common tasks.
 - [API Design Guidelines](.github/API_DESIGN.md)

 ## Installation
+KerasCV supports both Keras 2 and Keras 3. We recommend Keras 3 for all new
+users, as it enables using KerasCV models and layers with JAX, TensorFlow and
+PyTorch.
-To install the latest official release:
+### Keras 2 Installation
+
+To install the latest KerasCV release with Keras 2, simply run:
 ```
-pip install keras-cv tensorflow --upgrade
+pip install --upgrade keras-cv tensorflow
 ```
-To install the latest unreleased changes to the library, we recommend using
-pip to install directly from the master branch on github:
+### Keras 3 Installation
+
+There are currently two ways to install Keras 3 with KerasCV. To install the
+latest changes for KerasCV and Keras, you can use our nightly package.
+
 ```
-pip install git+https://github.com/keras-team/keras-cv.git tensorflow --upgrade
+pip install --upgrade keras-cv-nightly tf-nightly
 ```
-## Configuring your backend
+To install the stable versions of KerasCV and Keras 3, you should install Keras
+3 **after** installing KerasCV. This is a temporary step while TensorFlow is
+pinned to Keras 2, and will no longer be necessary after TensorFlow 2.16.
-**Keras 3** is an upcoming release of the Keras library which supports
-TensorFlow, Jax or Torch as backends. This is supported today in KerasNLP,
-but will not be enabled by default until the official release of Keras 3. If you
-`pip install keras-cv` and run a script or notebook without changes, you will
-be using TensorFlow and **Keras 2**.
+```
+pip install --upgrade keras-cv tensorflow
+pip install "keras>=3"
+```
+> [!IMPORTANT]
+> Keras 3 will not function with TensorFlow 2.14 or earlier.
+
+## Configuring your backend
-If you would like to enable a preview of the Keras 3 behavior, you can do
+If you have Keras 3 installed in your environment (see installation above),
+you can use KerasCV with any of JAX, TensorFlow and PyTorch. To do so, set the
+`KERAS_BACKEND` environment variable. For example:
-so by setting the `KERAS_BACKEND` environment variable. For example:

 ```shell
@@ -75,21 +90,13 @@ import keras_cv
 > [!IMPORTANT]
 > Make sure to set the `KERAS_BACKEND` before importing any Keras libraries; it
 > will be used to set up Keras when it is first imported.
-Until the Keras 3 release, KerasCV will use a preview of Keras 3 on PyPI named
-[keras-core](https://pypi.org/project/keras-core/).
-
-> [!IMPORTANT]
-> If you set `KERAS_BACKEND` variable, you should `import keras_core as keras`
-> instead of `import keras`. This is a temporary step until Keras 3 is out!
-To restore the default **Keras 2** behavior, `unset KERAS_BACKEND` before
-importing Keras and KerasCV.
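The warning above is the step that most often trips people up: the backend must be chosen before anything from Keras is imported. A minimal sketch of doing this from Python rather than the shell (the `"jax"` value is only an example; `"tensorflow"` and `"torch"` are the other valid choices):

```python
import os

# Select the backend before importing keras or keras_cv; the value is read
# when Keras is first imported, so it must be set ahead of any Keras import.
os.environ["KERAS_BACKEND"] = "jax"

import keras  # noqa: E402
import keras_cv  # noqa: E402

print(keras.backend.backend())  # -> "jax"
```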
Once that configuration step is done, you can just import KerasCV and start using it on top of your backend of choice: ```python import keras_cv -from keras_cv.backend import keras +import keras filepath = keras.utils.get_file(origin="https://i.imgur.com/gCNcJJI.jpg") image = np.array(keras.utils.load_img(filepath)) @@ -108,7 +115,7 @@ predictions = model.predict(image_resized) import tensorflow as tf import keras_cv import tensorflow_datasets as tfds -from keras_cv.backend import keras +import keras # Create a preprocessing pipeline with augmentations BATCH_SIZE = 16 @@ -260,7 +267,7 @@ Here is the BibTeX entry: ```bibtex @misc{wood2022kerascv, title={KerasCV}, - author={Wood, Luke and Tan, Zhenyu and Stenbit, Ian and Bischof, Jonathan and Zhu, Scott and Chollet, Fran\c{c}ois and others}, + author={Wood, Luke and Tan, Zhenyu and Stenbit, Ian and Bischof, Jonathan and Zhu, Scott and Chollet, Fran\c{c}ois and Sreepathihalli, Divyashree and Sampath, Ramesh and others}, year={2022}, howpublished={\url{https://github.com/keras-team/keras-cv}}, } diff --git a/keras_cv/__init__.py b/keras_cv/__init__.py index 9e405eae3b..e256b71051 100644 --- a/keras_cv/__init__.py +++ b/keras_cv/__init__.py @@ -42,4 +42,4 @@ from keras_cv.core import NormalFactorSampler from keras_cv.core import UniformFactorSampler -__version__ = "0.7.1" +__version__ = "0.7.2.dev0" diff --git a/keras_cv/backend/__init__.py b/keras_cv/backend/__init__.py index 890073a1eb..2190534ff8 100644 --- a/keras_cv/backend/__init__.py +++ b/keras_cv/backend/__init__.py @@ -11,6 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Keras backend module. + +This module adds a temporary Keras API surface that is fully under KerasCV +control. The goal is to allow us to write Keras 3-like code everywhere, while +still supporting Keras 2. We do this by using the `keras_core` package with +Keras 2 to backport Keras 3 numerics APIs (`keras.ops` and `keras.random`) into +Keras 2. The sub-modules exposed are as follows: + +- `config`: check which version of Keras is being run. +- `keras`: The full `keras` API with compat shims for older Keras versions. +- `ops`: `keras.ops` for Keras 3 or `keras_core.ops` for Keras 2. +- `random`: `keras.random` for Keras 3 or `keras_core.ops` for Keras 2. +""" from keras_cv.backend import config # noqa: E402 from keras_cv.backend import keras # noqa: E402 from keras_cv.backend import ops # noqa: E402 @@ -19,12 +33,12 @@ def assert_tf_keras(src): - if config.multi_backend(): + if config.keras_3(): raise NotImplementedError( - f"KerasCV component {src} does not yet support Keras Core, and can " + f"KerasCV component {src} does not yet support Keras 3, and can " "only be used in `tf.keras`." ) def supports_ragged(): - return not config.multi_backend() + return not config.keras_3() diff --git a/keras_cv/backend/config.py b/keras_cv/backend/config.py index 8d9c520e88..11d6c8273b 100644 --- a/keras_cv/backend/config.py +++ b/keras_cv/backend/config.py @@ -11,29 +11,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json -import os -_MULTI_BACKEND = False -# Set Keras base dir path given KERAS_HOME env variable, if applicable. -# Otherwise either ~/.keras or /tmp. 
-if "KERAS_HOME" in os.environ: - _keras_dir = os.environ.get("KERAS_HOME") -else: - _keras_base_dir = os.path.expanduser("~") - if not os.access(_keras_base_dir, os.W_OK): - _keras_base_dir = "/tmp" - _keras_dir = os.path.join(_keras_base_dir, ".keras") +from tensorflow import keras + +# We follow the version of keras that tensorflow is configured to use. +_USE_KERAS_3 = False + +# Note that only recent versions of keras have a `version()` function. +if hasattr(keras, "version") and keras.version().startswith("3."): + _USE_KERAS_3 = True def detect_if_tensorflow_uses_keras_3(): # We follow the version of keras that tensorflow is configured to use. - from tensorflow import keras - - # Note that only recent versions of keras have a `version()` function. - if hasattr(keras, "version") and keras.version().startswith("3."): - return True + try: + from tensorflow import keras + + # Note that only recent versions of keras have a `version()` function. + if hasattr(keras, "version") and keras.version().startswith("3."): + return True + except: + raise ValueError( + "Unable to import `keras` with `tensorflow`. Please check your " + "Keras and Tensorflow version are compatible; Keras 3 requires " + "TensorFlow 2.15 or later. See keras.io/getting_started for more " + "information on installing Keras." + ) # No `keras.version()` means we are on an old version of keras. return False @@ -49,53 +53,9 @@ def keras_3(): return _USE_KERAS_3 -# Attempt to read KerasCV config file. -_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras_cv.json")) -if os.path.exists(_config_path): - try: - with open(_config_path) as f: - _config = json.load(f) - except ValueError: - _config = {} - _MULTI_BACKEND = _config.get("multi_backend", _MULTI_BACKEND) - -# Save config file, if possible. -if not os.path.exists(_keras_dir): - try: - os.makedirs(_keras_dir) - except OSError: - # Except permission denied and potential race conditions - # in multi-threaded environments. - pass - -if not os.path.exists(_config_path): - _config = { - "multi_backend": _MULTI_BACKEND, - } - try: - with open(_config_path, "w") as f: - f.write(json.dumps(_config, indent=4)) - except IOError: - # Except permission denied. - pass - -if "KERAS_BACKEND" in os.environ and os.environ["KERAS_BACKEND"]: - _MULTI_BACKEND = True - - -def multi_backend(): - return _MULTI_BACKEND - - def backend(): """Check the backend framework.""" - if not multi_backend(): - return "tensorflow" if not keras_3(): - import keras_core - - return keras_core.config.backend() - - from tensorflow import keras + return "tensorflow" return keras.config.backend() diff --git a/keras_cv/backend/keras.py b/keras_cv/backend/keras.py index 65082573ae..340efad3b1 100644 --- a/keras_cv/backend/keras.py +++ b/keras_cv/backend/keras.py @@ -30,11 +30,6 @@ import keras # noqa: F403, F401 from keras import * # noqa: F403, F401 - keras.backend.name_scope = keras.name_scope -elif config.multi_backend(): - import keras_core as keras # noqa: F403, F401 - from keras_core import * # noqa: F403, F401 - keras.backend.name_scope = keras.name_scope else: from tensorflow import keras # noqa: F403, F401 diff --git a/keras_cv/backend/ops.py b/keras_cv/backend/ops.py index cbb533c103..c02dadfd6b 100644 --- a/keras_cv/backend/ops.py +++ b/keras_cv/backend/ops.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
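The `keras_cv/backend/config.py` rewrite above boils down to one question: is the `keras` that ships with TensorFlow a Keras 3 release? A stripped-down sketch of that check (simplified here; the real module caches the answer in `_USE_KERAS_3` and raises a `ValueError` when `keras` cannot be imported from `tensorflow`):

```python
def tensorflow_uses_keras_3():
    # KerasCV follows whichever Keras version TensorFlow is configured to use.
    from tensorflow import keras

    # Only recent Keras releases expose `keras.version()`; if it is missing,
    # this is an old (Keras 2) installation.
    return hasattr(keras, "version") and keras.version().startswith("3.")
```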
-from keras_cv.backend.config import keras_3 -from keras_cv.backend.config import multi_backend +from keras_cv.backend import config -if keras_3(): +if config.keras_3(): from keras.ops import * # noqa: F403, F401 from keras.preprocessing.image import smart_resize # noqa: F403, F401 @@ -32,5 +31,5 @@ from keras_core.src.utils.image_utils import ( # noqa: F403, F401 smart_resize, ) -if not multi_backend(): - from keras_cv.backend.tf_ops import * # noqa: F403, F401 + if config.backend() == "tensorflow": + from keras_cv.backend.tf_ops import * # noqa: F403, F401 diff --git a/keras_cv/backend/scope.py b/keras_cv/backend/scope.py index d6bd1b40d2..b2eb094147 100644 --- a/keras_cv/backend/scope.py +++ b/keras_cv/backend/scope.py @@ -18,7 +18,7 @@ from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.backend import tf_ops -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 _ORIGINAL_OPS = copy.copy(backend.ops.__dict__) _ORIGINAL_SUPPORTS_RAGGED = backend.supports_ragged @@ -30,7 +30,7 @@ def tf_data(function): @functools.wraps(function) def wrapper(*args, **kwargs): - if multi_backend() and keras.src.utils.backend_utils.in_tf_graph(): + if keras_3() and keras.src.utils.backend_utils.in_tf_graph(): with TFDataScope(): return function(*args, **kwargs) else: diff --git a/keras_cv/conftest.py b/keras_cv/conftest.py index 8e886dfd99..b8be780c39 100644 --- a/keras_cv/conftest.py +++ b/keras_cv/conftest.py @@ -17,7 +17,7 @@ import tensorflow as tf from packaging import version -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 def pytest_addoption(parser): @@ -45,7 +45,7 @@ def pytest_configure(config): ) config.addinivalue_line( "markers", - "tf_keras_only: mark test as a tf.keras-only test", + "tf_keras_only: mark test as a Keras 2-only test", ) config.addinivalue_line( "markers", @@ -69,12 +69,12 @@ def pytest_collection_modifyitems(config, items): skip_extra_large = pytest.mark.skipif( not run_extra_large_tests, reason="need --run_extra_large option to run" ) - skip_tf_keras_only = pytest.mark.skipif( - multi_backend(), - reason="This test is only supported on tf.keras", + skip_keras_2_only = pytest.mark.skipif( + keras_3(), + reason="This test is only supported on Keras 2", ) skip_tf_only = pytest.mark.skipif( - multi_backend() and keras_core.backend.backend() != "tensorflow", + keras_3() and keras_core.backend.backend() != "tensorflow", reason="This test is only supported on TensorFlow", ) for item in items: @@ -87,6 +87,6 @@ def pytest_collection_modifyitems(config, items): if "extra_large" in item.keywords: item.add_marker(skip_extra_large) if "tf_keras_only" in item.keywords: - item.add_marker(skip_tf_keras_only) + item.add_marker(skip_keras_2_only) if "tf_only" in item.keywords: item.add_marker(skip_tf_only) diff --git a/keras_cv/layers/object_detection/multi_class_non_max_suppression.py b/keras_cv/layers/object_detection/multi_class_non_max_suppression.py index 7825268578..08bcd97528 100644 --- a/keras_cv/layers/object_detection/multi_class_non_max_suppression.py +++ b/keras_cv/layers/object_detection/multi_class_non_max_suppression.py @@ -18,7 +18,7 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 @keras_cv_export("keras_cv.layers.MultiClassNonMaxSuppression") @@ -73,7 +73,7 @@ def call( `bounding_box_format` 
specified in the constructor. class_prediction: Dense Tensor of shape [batch, boxes, num_classes]. """ - if multi_backend() and keras.backend.backend() != "tensorflow": + if keras_3() and keras.backend.backend() != "tensorflow": raise NotImplementedError( "MultiClassNonMaxSuppression does not support non-TensorFlow " "backends. Consider using NonMaxSuppression instead." diff --git a/keras_cv/layers/object_detection/non_max_suppression.py b/keras_cv/layers/object_detection/non_max_suppression.py index ed988cf113..2c39fd0d12 100644 --- a/keras_cv/layers/object_detection/non_max_suppression.py +++ b/keras_cv/layers/object_detection/non_max_suppression.py @@ -20,7 +20,7 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 EPSILON = 1e-8 @@ -89,7 +89,7 @@ def call( confidence_prediction = ops.max(class_prediction, axis=-1) - if not multi_backend() or keras.backend.backend() == "tensorflow": + if not keras_3() or keras.backend.backend() == "tensorflow": idx, valid_det = tf.image.non_max_suppression_padded( box_prediction, confidence_prediction, @@ -318,7 +318,7 @@ def suppression_loop_body(boxes, iou_threshold, output_size, idx): # TODO(ianstenbit): Fix bug in tfnp.take_along_axis that causes this hack. # (This will be removed anyway when we use built-in NMS for TF.) - if multi_backend() and keras.backend.backend() != "tensorflow": + if keras_3() and keras.backend.backend() != "tensorflow": idx = ops.take_along_axis( ops.reshape(sorted_indices, [-1]), take_along_axis_idx ) diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py index 3f42a804a3..0a365891b7 100644 --- a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py @@ -571,6 +571,8 @@ def _ensure_inputs_are_compute_dtype(self, inputs): inputs, self.compute_dtype, ) + # Copy the input dict before we mutate it. + inputs = dict(inputs) inputs[IMAGES] = preprocessing.ensure_tensor( inputs[IMAGES], self.compute_dtype, diff --git a/keras_cv/layers/preprocessing/rand_augment_test.py b/keras_cv/layers/preprocessing/rand_augment_test.py index 9ef0ea9fbd..0f9759cc42 100644 --- a/keras_cv/layers/preprocessing/rand_augment_test.py +++ b/keras_cv/layers/preprocessing/rand_augment_test.py @@ -12,17 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
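The `inputs = dict(inputs)` copy added to `_ensure_inputs_are_compute_dtype` above is the core of the RandAugment passthrough fix described in the commit message: the layer must not mutate the caller's dict. A small illustration of the difference, using a hypothetical helper and NumPy only:

```python
import numpy as np


def cast_images(inputs, dtype, copy_first=True):
    # Hypothetical stand-in for the layer's normalization step, which casts
    # (and in the real layer also upranks) the "images" entry of the dict.
    if copy_first:
        inputs = dict(inputs)  # shallow copy; the caller's dict stays intact
    inputs["images"] = inputs["images"].astype(dtype)
    return inputs


user_inputs = {"images": np.ones((2, 8, 8, 3), dtype="uint8")}
cast_images(user_inputs, "float32")
# With the copy, the caller still holds the original uint8 images. Without it,
# a later "no augmentation" passthrough could hand back tensors that were
# already cast or upranked, which is what broke tf.map_fn.
assert user_inputs["images"].dtype == np.uint8
```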
import numpy as np -import pytest import tensorflow as tf from absl.testing import parameterized from keras_cv import layers -from keras_cv.backend.config import keras_3 from keras_cv.tests.test_case import TestCase -@pytest.mark.skipif(keras_3(), reason="imcompatible with Keras 3") class RandAugmentTest(TestCase): + def test_zero_rate_pass_through(self): + rand_augment = layers.RandAugment( + value_range=(0, 255), + rate=0.0, + ) + xs = np.ones((2, 512, 512, 3)) + ys = rand_augment(xs) + self.assertAllClose(ys, xs) + @parameterized.named_parameters( ("0", 0), ("20", 0.2), diff --git a/keras_cv/layers/preprocessing/random_augmentation_pipeline.py b/keras_cv/layers/preprocessing/random_augmentation_pipeline.py index a525680fca..e437a90147 100644 --- a/keras_cv/layers/preprocessing/random_augmentation_pipeline.py +++ b/keras_cv/layers/preprocessing/random_augmentation_pipeline.py @@ -103,7 +103,7 @@ def _augment(self, inputs): ) result = tf.cond( skip_augment > self.rate, - lambda: inputs, + lambda: result, lambda: self._random_choice(result), ) return result diff --git a/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py index c46b5f81b2..8d3dacfa98 100644 --- a/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py @@ -24,7 +24,7 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import scope -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 from keras_cv.utils import preprocessing H_AXIS = -3 @@ -44,7 +44,7 @@ base_class = ( keras.src.layers.preprocessing.tf_data_layer.TFDataLayer - if multi_backend() + if keras_3() else keras.layers.Layer ) @@ -444,6 +444,9 @@ def _format_inputs(self, inputs): # single image input tensor metadata[IS_DICT] = False inputs = {IMAGES: inputs} + else: + # Copy the input dict before we mutate it. + inputs = dict(inputs) metadata[BATCHED] = inputs["images"].shape.rank == 4 if inputs["images"].shape.rank == 3: @@ -504,6 +507,8 @@ def _ensure_inputs_are_compute_dtype(self, inputs): inputs, self.compute_dtype, ) + # Copy the input dict before we mutate it. + inputs = dict(inputs) inputs[IMAGES] = preprocessing.ensure_tensor( inputs[IMAGES], self.compute_dtype, diff --git a/keras_cv/layers/spatial_pyramid.py b/keras_cv/layers/spatial_pyramid.py index b45ee7bda3..8f0e50a7d8 100644 --- a/keras_cv/layers/spatial_pyramid.py +++ b/keras_cv/layers/spatial_pyramid.py @@ -19,7 +19,7 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 @keras_cv_export("keras_cv.layers.SpatialPyramidPooling") @@ -68,7 +68,7 @@ def __init__( self.dropout = dropout # TODO(ianstenbit): Remove this once TF 2.14 is released which adds # XLA support for resizing with bilinear interpolation. 
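Related to the `random_augmentation_pipeline.py` change above: the passthrough branch of the `tf.cond` now returns `result` (the copy the base layer has already normalized) rather than the raw `inputs`. A rough sketch of that pattern with hypothetical names, assuming the TensorFlow backend:

```python
import tensorflow as tf


def maybe_augment(result, rate, random_choice):
    # `result` is the dict the base layer has already copied, cast and
    # upranked; returning it from the skip branch keeps both branches of the
    # cond consistent in shape and dtype, which tf.map_fn requires.
    skip_augment = tf.random.uniform(shape=(), minval=0.0, maxval=1.0)
    return tf.cond(
        skip_augment > rate,
        lambda: result,
        lambda: random_choice(result),
    )
```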
- if multi_backend() and keras.backend.backend() == "tensorflow": + if keras_3() and keras.backend.backend() == "tensorflow": self.supports_jit = False def build(self, input_shape): diff --git a/keras_cv/layers/vit_det_layers.py b/keras_cv/layers/vit_det_layers.py index 78c0b0bfb6..9311a957f5 100644 --- a/keras_cv/layers/vit_det_layers.py +++ b/keras_cv/layers/vit_det_layers.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np - from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops @@ -123,16 +121,16 @@ def _get_rel_pos(self, query_size, key_size, rel_pos): return rel_pos_resized else: rel_pos_resized = rel_pos - query_coordinates = np.arange(query_size, dtype="float32")[:, None] * ( - max(key_size / query_size, 1.0) - ) - key_coordinates = np.arange(key_size, dtype="float32")[None, :] * ( - max(query_size / key_size, 1.0) - ) + query_coordinates = ops.cast( + ops.arange(query_size), dtype=self.compute_dtype + )[:, None] * (max(key_size / query_size, 1.0)) + key_coordinates = ops.cast( + ops.arange(key_size), dtype=self.compute_dtype + )[None, :] * (max(query_size / key_size, 1.0)) relative_coordinates = (query_coordinates - key_coordinates) + ( key_size - 1 ) * max(query_size / key_size, 1.0) - relative_coordinates = relative_coordinates.astype("int32") + relative_coordinates = ops.cast(relative_coordinates, dtype="int32") return ops.take(rel_pos_resized, relative_coordinates, 0) def call(self, attention_map, queries, query_size, key_size): diff --git a/keras_cv/losses/focal_test.py b/keras_cv/losses/focal_test.py index 9594d13d21..a8f4463fba 100644 --- a/keras_cv/losses/focal_test.py +++ b/keras_cv/losses/focal_test.py @@ -16,7 +16,7 @@ from keras_cv.backend import keras from keras_cv.backend import ops -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 from keras_cv.losses import FocalLoss from keras_cv.tests.test_case import TestCase @@ -73,7 +73,7 @@ def test_from_logits_argument(self): # TF backend somehow has different numerics. 
expected_loss = ( 31.11176 - if multi_backend() and keras.backend.backend() != "tensorflow" + if keras_3() and keras.backend.backend() != "tensorflow" else 925.28081 ) self.assertAllClose( diff --git a/keras_cv/models/object_detection_3d/center_pillar_backbone_test.py b/keras_cv/models/object_detection_3d/center_pillar_backbone_test.py index cc89cb1895..c02046b5e1 100644 --- a/keras_cv/models/object_detection_3d/center_pillar_backbone_test.py +++ b/keras_cv/models/object_detection_3d/center_pillar_backbone_test.py @@ -16,13 +16,13 @@ import tensorflow as tf from keras_cv.backend import keras -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 from keras_cv.models.object_detection_3d import CenterPillarBackbone from keras_cv.tests.test_case import TestCase @pytest.mark.skipif( - multi_backend() and keras.backend.backend() == "torch", + keras_3() and keras.backend.backend() == "torch", reason="CenterPillar does not yet support PyTorch.", ) class CenterPillarBackboneTest(TestCase): diff --git a/keras_cv/models/object_detection_3d/center_pillar_test.py b/keras_cv/models/object_detection_3d/center_pillar_test.py index 352d33fbbd..b4be9e6a5e 100644 --- a/keras_cv/models/object_detection_3d/center_pillar_test.py +++ b/keras_cv/models/object_detection_3d/center_pillar_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from keras_cv.backend import keras -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 from keras_cv.layers.object_detection_3d.voxelization import DynamicVoxelization from keras_cv.models.object_detection_3d.center_pillar import ( MultiClassDetectionHead, @@ -34,7 +34,7 @@ @pytest.mark.skipif( - multi_backend() and keras.backend.backend() == "torch", + keras_3() and keras.backend.backend() == "torch", reason="CenterPillar does not yet support PyTorch.", ) class CenterPillarTest(TestCase): diff --git a/keras_cv/models/segmentation/segment_anything/sam_layers.py b/keras_cv/models/segmentation/segment_anything/sam_layers.py index 577031c63c..fffc4faee5 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_layers.py +++ b/keras_cv/models/segmentation/segment_anything/sam_layers.py @@ -99,7 +99,7 @@ def call(self, query, value, key): # Attention C_PH = ops.shape(query)[-1] out = query @ ops.transpose(key, (0, 1, 3, 2)) - out = out / ops.sqrt(ops.cast(C_PH, dtype=self.dtype)) + out = out / ops.sqrt(ops.cast(C_PH, dtype=self.compute_dtype)) out = ops.softmax(out, axis=-1) # Get output @@ -278,7 +278,7 @@ def __init__(self, num_positional_features, scale, **kwargs): self.positional_encoding_gaussian_matrix = self.add_weight( name="positional_encoding_gaussian_matrix", shape=(2, self.num_positional_features), - dtype=self.dtype, + dtype=self.variable_dtype, trainable=False, initializer=keras.initializers.get("normal"), ) @@ -288,7 +288,9 @@ def build(self, input_shape=None): def __positional_encodings(self, coords): coords = coords * 2 - 1 - coords = coords @ self.positional_encoding_gaussian_matrix + coords = coords @ ops.cast( + self.positional_encoding_gaussian_matrix, dtype=self.compute_dtype + ) coords = coords * (2 * math.pi) return ops.concatenate([ops.sin(coords), ops.cos(coords)], axis=-1) @@ -305,11 +307,11 @@ def encode_image(self, size): tensor: Positional encoding of the image. 
""" H, W = size - grid = ops.ones(shape=(H, W), dtype=self.dtype) + grid = ops.ones(shape=(H, W), dtype=self.compute_dtype) y_embed = ops.cumsum(grid, axis=0) - 0.5 x_embed = ops.cumsum(grid, axis=1) - 0.5 - y_embed = y_embed / ops.cast(H, self.dtype) - x_embed = x_embed / ops.cast(W, self.dtype) + y_embed = y_embed / ops.cast(H, self.compute_dtype) + x_embed = x_embed / ops.cast(W, self.compute_dtype) return self.__positional_encodings( ops.stack([x_embed, y_embed], axis=-1) ) diff --git a/keras_cv/models/segmentation/segment_anything/sam_test.py b/keras_cv/models/segmentation/segment_anything/sam_test.py index 3546cb906f..295355a716 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_test.py +++ b/keras_cv/models/segmentation/segment_anything/sam_test.py @@ -220,48 +220,68 @@ def test_mask_decoder(self): self.assertEqual(num_parameters, 4_058_340) @pytest.mark.large - def test_end_to_end_model_predict(self): - model = SegmentAnythingModel( - backbone=self.image_encoder, - prompt_encoder=self.prompt_encoder, - mask_decoder=self.mask_decoder, - ) - - # We use box-only prompting for this test. - mask_prompts = self.get_prompts(1, "boxes") - inputs = { - "images": np.ones((1, 1024, 1024, 3)), - } - inputs.update(mask_prompts) - - # Check the number of parameters - num_parameters = np.sum([np.prod(x.shape) for x in model.weights]) - self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340) + @parameterized.named_parameters( + [ + ("float32", "float32"), + ("mixed_float16", "mixed_float16"), + ("bfloat16", "bfloat16"), + ] + ) + def test_end_to_end_model_predict(self, dtype_policy): + import threading + + with threading.Lock(): + # We are changing the global dtype policy here but don't want any + # other tests to use that policy, so compute under a lock until + # we reset the global policy. + old_policy = getattr( + keras.mixed_precision, "dtype_policy", lambda: "float32" + )() + keras.mixed_precision.set_global_policy(dtype_policy) + model = SegmentAnythingModel( + backbone=self.image_encoder, + prompt_encoder=self.prompt_encoder, + mask_decoder=self.mask_decoder, + ) - # Forward pass through the model - outputs = model.predict(inputs) - masks, iou_pred = outputs["masks"], outputs["iou_pred"] + # We use box-only prompting for this test. + mask_prompts = self.get_prompts(1, "boxes") + inputs = { + "images": np.ones((1, 1024, 1024, 3)), + } + inputs.update(mask_prompts) + + # Check the number of parameters + num_parameters = np.sum([np.prod(x.shape) for x in model.weights]) + self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340) + + # Forward pass through the model + outputs = model.predict(inputs) + masks, iou_pred = outputs["masks"], outputs["iou_pred"] + + # Check the output is equal to the one we expect if we + # run each component separately. This is to confirm that + # the graph is getting compiled correctly i.e. the jitted + # execution is equivalent to the eager execution. + features = self.image_encoder(inputs["images"]) + outputs_ex = self.prompt_encoder( + {k: v for k, v in inputs.items() if k != "images"} + ) + outputs_ex = self.mask_decoder( + { + "image_embeddings": features, + "image_pe": outputs_ex["dense_positional_embeddings"], + "sparse_prompt_embeddings": outputs_ex["sparse_embeddings"], + "dense_prompt_embeddings": outputs_ex["dense_embeddings"], + }, + ) + masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"] - # Check the output is equal to the one we expect if we - # run each component separately. 
This is to confirm that - # the graph is getting compiled correctly i.e. the jitted - # execution is equivalent to the eager execution. - features = self.image_encoder(inputs["images"]) - outputs_ex = self.prompt_encoder( - {k: v for k, v in inputs.items() if k != "images"} - ) - outputs_ex = self.mask_decoder( - { - "image_embeddings": features, - "image_pe": outputs_ex["dense_positional_embeddings"], - "sparse_prompt_embeddings": outputs_ex["sparse_embeddings"], - "dense_prompt_embeddings": outputs_ex["dense_embeddings"], - }, - ) - masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"] + self.assertAllClose(masks, masks_ex, atol=1e-4) + self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4) - self.assertAllClose(masks, masks_ex, atol=1e-4) - self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4) + # Reset the global policy + keras.mixed_precision.set_global_policy(old_policy) @pytest.mark.extra_large def test_end_to_end_model_save(self): diff --git a/keras_cv/models/stable_diffusion/stable_diffusion.py b/keras_cv/models/stable_diffusion/stable_diffusion.py index 975788ac74..768795c956 100644 --- a/keras_cv/models/stable_diffusion/stable_diffusion.py +++ b/keras_cv/models/stable_diffusion/stable_diffusion.py @@ -200,7 +200,7 @@ def generate_image( if diffusion_noise is not None: diffusion_noise = ops.squeeze(diffusion_noise) - if diffusion_noise.shape.rank == 3: + if len(ops.shape(diffusion_noise)) == 3: diffusion_noise = ops.repeat( ops.expand_dims(diffusion_noise, axis=0), batch_size, axis=0 )
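A final note on the `stable_diffusion.py` change that closes this patch: `Tensor.shape.rank` only exists on TensorFlow tensors, while `len(ops.shape(x))` yields the rank on every backend. A standalone sketch of the same pattern, assuming a Keras 3 installation:

```python
import numpy as np
from keras import ops

diffusion_noise = np.ones((64, 64, 4), dtype="float32")
batch_size = 2

# Backend-agnostic rank check, mirroring the fix above: add a batch axis and
# tile the single noise sample out to the requested batch size.
if len(ops.shape(diffusion_noise)) == 3:
    diffusion_noise = ops.repeat(
        ops.expand_dims(diffusion_noise, axis=0), batch_size, axis=0
    )

print(ops.shape(diffusion_noise))  # -> (2, 64, 64, 4)
```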