diff --git a/keras_cv/backend/__init__.py b/keras_cv/backend/__init__.py index a22ce4ca4d..890073a1eb 100644 --- a/keras_cv/backend/__init__.py +++ b/keras_cv/backend/__init__.py @@ -11,86 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Keras backend module. - -This module adds a temporarily Keras API surface that is fully under KerasCV -control. This allows us to switch between `keras_core` and `tf.keras`, as well -as add shims to support older version of `tf.keras`. -- `config`: check which backend is being run. -- `keras`: The full `keras` API (via `keras_core` or `tf.keras`). -- `ops`: `keras_core.ops`, always tf-backed if using `tf.keras`. -""" - -import types - -from packaging.version import parse - -from keras_cv.backend.config import multi_backend - -# Keys are of the form: "module.where.attr.exists->module.where.to.alias" -# Value are of the form: ["attr1", "attr2", ...] or -# [("attr1_original_name", "attr1_alias_name")] -_KERAS_CORE_ALIASES = { - "utils->saving": [ - "register_keras_serializable", - "deserialize_keras_object", - "serialize_keras_object", - "get_registered_object", - ], - "models->saving": ["load_model"], -} - - -if multi_backend(): - import keras - - if not hasattr(keras, "__version__") or parse(keras.__version__) < parse( - "3.0" - ): - import keras_core as keras - - keras.backend.name_scope = keras.name_scope -else: - from tensorflow import keras - - if not hasattr(keras, "saving"): - keras.saving = types.SimpleNamespace() - - # add aliases - for key, value in _KERAS_CORE_ALIASES.items(): - src, _, dst = key.partition("->") - src = src.split(".") - dst = dst.split(".") - - src_mod, dst_mod = keras, keras - - # navigate to where we want to alias the attributes - for mod in src: - src_mod = getattr(src_mod, mod) - for mod in dst: - dst_mod = getattr(dst_mod, mod) - - # add an alias for each attribute - for attr in value: - if isinstance(attr, tuple): - src_attr, dst_attr = attr - else: - src_attr, dst_attr = attr, attr - attr_val = getattr(src_mod, src_attr) - setattr(dst_mod, dst_attr, attr_val) - - # TF Keras doesn't have this rename. - keras.activations.silu = keras.activations.swish - from keras_cv.backend import config # noqa: E402 +from keras_cv.backend import keras # noqa: E402 from keras_cv.backend import ops # noqa: E402 from keras_cv.backend import random # noqa: E402 from keras_cv.backend import tf_ops # noqa: E402 def assert_tf_keras(src): - if multi_backend(): + if config.multi_backend(): raise NotImplementedError( f"KerasCV component {src} does not yet support Keras Core, and can " "only be used in `tf.keras`." @@ -98,4 +27,4 @@ def assert_tf_keras(src): def supports_ragged(): - return not multi_backend() + return not config.multi_backend() diff --git a/keras_cv/backend/config.py b/keras_cv/backend/config.py index a921639bd9..8d9c520e88 100644 --- a/keras_cv/backend/config.py +++ b/keras_cv/backend/config.py @@ -26,6 +26,29 @@ _keras_base_dir = "/tmp" _keras_dir = os.path.join(_keras_base_dir, ".keras") + +def detect_if_tensorflow_uses_keras_3(): + # We follow the version of keras that tensorflow is configured to use. + from tensorflow import keras + + # Note that only recent versions of keras have a `version()` function. + if hasattr(keras, "version") and keras.version().startswith("3."): + return True + + # No `keras.version()` means we are on an old version of keras. + return False + + +_USE_KERAS_3 = detect_if_tensorflow_uses_keras_3() +if _USE_KERAS_3: + _MULTI_BACKEND = True + + +def keras_3(): + """Check if Keras 3 is being used.""" + return _USE_KERAS_3 + + # Attempt to read KerasCV config file. _config_path = os.path.expanduser(os.path.join(_keras_dir, "keras_cv.json")) if os.path.exists(_config_path): @@ -62,3 +85,17 @@ def multi_backend(): return _MULTI_BACKEND + + +def backend(): + """Check the backend framework.""" + if not multi_backend(): + return "tensorflow" + if not keras_3(): + import keras_core + + return keras_core.config.backend() + + from tensorflow import keras + + return keras.config.backend() diff --git a/keras_cv/backend/keras.py b/keras_cv/backend/keras.py new file mode 100644 index 0000000000..65082573ae --- /dev/null +++ b/keras_cv/backend/keras.py @@ -0,0 +1,70 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types + +from keras_cv.backend import config + +_KERAS_CORE_ALIASES = { + "utils->saving": [ + "register_keras_serializable", + "deserialize_keras_object", + "serialize_keras_object", + "get_registered_object", + ], + "models->saving": ["load_model"], +} + +if config.keras_3(): + import keras # noqa: F403, F401 + from keras import * # noqa: F403, F401 + + keras.backend.name_scope = keras.name_scope +elif config.multi_backend(): + import keras_core as keras # noqa: F403, F401 + from keras_core import * # noqa: F403, F401 + + keras.backend.name_scope = keras.name_scope +else: + from tensorflow import keras # noqa: F403, F401 + from tensorflow.keras import * # noqa: F403, F401 + + if not hasattr(keras, "saving"): + keras.saving = types.SimpleNamespace() + + # add aliases + for key, value in _KERAS_CORE_ALIASES.items(): + src, _, dst = key.partition("->") + src = src.split(".") + dst = dst.split(".") + + src_mod, dst_mod = keras, keras + + # navigate to where we want to alias the attributes + for mod in src: + src_mod = getattr(src_mod, mod) + for mod in dst: + dst_mod = getattr(dst_mod, mod) + + # add an alias for each attribute + for attr in value: + if isinstance(attr, tuple): + src_attr, dst_attr = attr + else: + src_attr, dst_attr = attr, attr + attr_val = getattr(src_mod, src_attr) + setattr(dst_mod, dst_attr, attr_val) + + # TF Keras doesn't have this rename. + keras.activations.silu = keras.activations.swish diff --git a/keras_cv/backend/ops.py b/keras_cv/backend/ops.py index f9f1f43a93..cdb116bf63 100644 --- a/keras_cv/backend/ops.py +++ b/keras_cv/backend/ops.py @@ -11,11 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from keras_cv.backend.config import keras_3 from keras_cv.backend.config import multi_backend -if multi_backend(): +if keras_3(): + from keras.ops import * # noqa: F403, F401 + from keras.preprocessing.image import smart_resize # noqa: F403, F401 +else: try: - from keras.src.backend import vectorized_map # noqa: F403, F401 from keras.src.ops import * # noqa: F403, F401 from keras.src.utils.image_utils import smart_resize # noqa: F403, F401 # Import error means Keras isn't installed, or is Keras 2. @@ -25,5 +28,5 @@ from keras_core.src.utils.image_utils import ( # noqa: F403, F401 smart_resize, ) -else: +if not multi_backend(): from keras_cv.backend.tf_ops import * # noqa: F403, F401 diff --git a/keras_cv/backend/random.py b/keras_cv/backend/random.py index 21d4b08c7d..d1d88cd715 100644 --- a/keras_cv/backend/random.py +++ b/keras_cv/backend/random.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_cv.backend.config import multi_backend +from keras_cv.backend.config import keras_3 -if multi_backend(): - from keras_core.random import * # noqa: F403, F401 +if keras_3(): + from keras.random import * # noqa: F403, F401 else: - from keras_core.src.backend.tensorflow.random import * # noqa: F403, F401 + from keras_core.random import * # noqa: F403, F401 diff --git a/keras_cv/backend/tf_ops.py b/keras_cv/backend/tf_ops.py index 106c9d0a33..b29d627b56 100644 --- a/keras_cv/backend/tf_ops.py +++ b/keras_cv/backend/tf_ops.py @@ -11,23 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from keras_core.src.backend.tensorflow import * # noqa: F403, F401 -from keras_core.src.backend.tensorflow import ( # noqa: F403, F401 - convert_to_numpy, -) -from keras_core.src.backend.tensorflow.core import * # noqa: F403, F401 -from keras_core.src.backend.tensorflow.math import * # noqa: F403, F401 -from keras_core.src.backend.tensorflow.nn import * # noqa: F403, F401 -from keras_core.src.backend.tensorflow.numpy import * # noqa: F403, F401 +from keras_cv.backend import config + +if config.keras_3(): + from keras.src.backend.tensorflow import * # noqa: F403, F401 + from keras.src.backend.tensorflow import ( # noqa: F403, F401 + convert_to_numpy, + ) + from keras.src.backend.tensorflow.core import * # noqa: F403, F401 + from keras.src.backend.tensorflow.math import * # noqa: F403, F401 + from keras.src.backend.tensorflow.nn import * # noqa: F403, F401 + from keras.src.backend.tensorflow.numpy import * # noqa: F403, F401 +else: + # isort: off + from keras_core.src.backend.tensorflow import * # noqa: F403, F401 + from keras_core.src.backend.tensorflow import ( # noqa: F403, F401 + convert_to_numpy, + ) + from keras_core.src.backend.tensorflow.core import * # noqa: F403, F401 + from keras_core.src.backend.tensorflow.math import * # noqa: F403, F401 + from keras_core.src.backend.tensorflow.nn import * # noqa: F403, F401 + from keras_core.src.backend.tensorflow.numpy import * # noqa: F403, F401, E501 + # Some TF APIs where the numpy API doesn't support raggeds that we need from tensorflow import broadcast_to # noqa: F403, F401 from tensorflow import concat as concatenate # noqa: F403, F401 +from tensorflow import repeat # noqa: F403, F401 +from tensorflow import reshape # noqa: F403, F401 + from tensorflow import range as arange # noqa: F403, F401 from tensorflow import reduce_all as all # noqa: F403, F401 from tensorflow import reduce_max as max # noqa: F403, F401 -from tensorflow import repeat # noqa: F403, F401 -from tensorflow import reshape # noqa: F403, F401 from tensorflow import split # noqa: F403, F401 from tensorflow.keras.preprocessing.image import ( # noqa: F403, F401 smart_resize, diff --git a/keras_cv/layers/preprocessing_3d/base_augmentation_layer_3d.py b/keras_cv/layers/preprocessing_3d/base_augmentation_layer_3d.py index 49fa2772d1..20efedd972 100644 --- a/keras_cv/layers/preprocessing_3d/base_augmentation_layer_3d.py +++ b/keras_cv/layers/preprocessing_3d/base_augmentation_layer_3d.py @@ -13,9 +13,14 @@ # limitations under the License. import tensorflow as tf -from tensorflow import keras from keras_cv.api_export import keras_cv_export +from keras_cv.backend import config + +if config.keras_3(): + base_layer = tf.keras.layers.Layer +else: + base_layer = tf.keras.__internal__.layers.BaseRandomLayer POINT_CLOUDS = "point_clouds" BOUNDING_BOXES = "bounding_boxes" @@ -29,7 +34,7 @@ @keras_cv_export("keras_cv.layers.BaseAugmentationLayer3D") -class BaseAugmentationLayer3D(keras.__internal__.layers.BaseRandomLayer): +class BaseAugmentationLayer3D(base_layer): """Abstract base layer for data augmentation for 3D perception. This layer contains base functionalities for preprocessing layers which @@ -99,8 +104,16 @@ def augment_pointclouds(self, point_clouds, transformation): """ def __init__(self, seed=None, **kwargs): - super().__init__(seed=seed, **kwargs) - self.auto_vectorize = False + # To-do: remove this once th elayer is ported to keras 3 + # https://github.com/keras-team/keras-cv/issues/2136 + if config.keras_3(): + raise ValueError( + "This layer is not yet compatible with Keras 3." + "Please switch to Keras 2 to use this layer." + ) + else: + super().__init__(seed=seed, **kwargs) + self.auto_vectorize = False @property def auto_vectorize(self): diff --git a/keras_cv/layers/regularization/dropblock_2d.py b/keras_cv/layers/regularization/dropblock_2d.py index 72c5835fd4..abef5b17d1 100644 --- a/keras_cv/layers/regularization/dropblock_2d.py +++ b/keras_cv/layers/regularization/dropblock_2d.py @@ -13,14 +13,22 @@ # limitations under the License. import tensorflow as tf -from tensorflow.keras.__internal__.layers import BaseRandomLayer + +from keras_cv.backend import config + +if config.keras_3(): + base_layer = tf.keras.layers.Layer +else: + from tensorflow.keras.__internal__.layers import BaseRandomLayer + + base_layer = BaseRandomLayer from keras_cv.api_export import keras_cv_export from keras_cv.utils import conv_utils @keras_cv_export("keras_cv.layers.DropBlock2D") -class DropBlock2D(BaseRandomLayer): +class DropBlock2D(base_layer): """Applies DropBlock regularization to input features. DropBlock is a form of structured dropout, where units in a contiguous @@ -145,20 +153,29 @@ def __init__( seed=None, **kwargs, ): - super().__init__(seed=seed, **kwargs) - if not 0.0 <= rate <= 1.0: + # To-do: remove this once th elayer is ported to keras 3 + # https://github.com/keras-team/keras-cv/issues/2136 + if config.keras_3(): raise ValueError( - f"rate must be a number between 0 and 1. " f"Received: {rate}" + "This layer is not yet compatible with Keras 3." + "Please switch to Keras 2 to use this layer." ) - - self._rate = rate - ( - self._dropblock_height, - self._dropblock_width, - ) = conv_utils.normalize_tuple( - value=block_size, n=2, name="block_size", allow_zero=False - ) - self.seed = seed + else: + super().__init__(seed=seed, **kwargs) + if not 0.0 <= rate <= 1.0: + raise ValueError( + f"rate must be a number between 0 and 1. " + f"Received: {rate}" + ) + + self._rate = rate + ( + self._dropblock_height, + self._dropblock_width, + ) = conv_utils.normalize_tuple( + value=block_size, n=2, name="block_size", allow_zero=False + ) + self.seed = seed def call(self, x, training=None): if not training or self._rate == 0.0: diff --git a/keras_cv/models/object_detection/predict_utils.py b/keras_cv/models/object_detection/predict_utils.py index 8eb6e10e62..298f38091d 100644 --- a/keras_cv/models/object_detection/predict_utils.py +++ b/keras_cv/models/object_detection/predict_utils.py @@ -15,12 +15,14 @@ import tensorflow as tf try: - from keras.src.engine.training import _minimum_control_deps - from keras.src.engine.training import reduce_per_replica + # To-do: these imports need to fixed - Issue 2134 + # https://github.com/keras-team/keras-cv/issues/2134 + # from keras.src.engine.training import _minimum_control_deps + # from keras.src.engine.training import reduce_per_replica from keras.src.utils import tf_utils except ImportError: - from keras.engine.training import _minimum_control_deps - from keras.engine.training import reduce_per_replica + # from keras.engine.training import _minimum_control_deps + # from keras.engine.training import reduce_per_replica from keras.utils import tf_utils @@ -34,8 +36,8 @@ def step_function(iterator): def run_step(data): outputs = model.predict_step(data) # Ensure counter is updated only if `test_step` succeeds. - with tf.control_dependencies(_minimum_control_deps(outputs)): - model._predict_counter.assign_add(1) + # with tf.control_dependencies(_minimum_control_deps(outputs)): + model._predict_counter.assign_add(1) return outputs if model._jit_compile: @@ -45,9 +47,9 @@ def run_step(data): data = next(iterator) outputs = model.distribute_strategy.run(run_step, args=(data,)) - outputs = reduce_per_replica( - outputs, model.distribute_strategy, reduction="concat" - ) + # outputs = reduce_per_replica( + # outputs, model.distribute_strategy, reduction="concat" + # ) # Note that this is the only deviation from the base keras.Model # implementation. We add the decode_step inside of the computation # graph but outside of the distribute_strategy (i.e on host CPU). diff --git a/keras_cv/models/object_detection/retinanet/retinanet_test.py b/keras_cv/models/object_detection/retinanet/retinanet_test.py index 45026262f4..bbede75943 100644 --- a/keras_cv/models/object_detection/retinanet/retinanet_test.py +++ b/keras_cv/models/object_detection/retinanet/retinanet_test.py @@ -20,7 +20,7 @@ from absl.testing import parameterized import keras_cv -from keras_cv import backend +from keras_cv.backend import config from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.models.backbones.test_backbone_presets import ( @@ -251,7 +251,7 @@ def test_call_with_custom_label_encoder(self): model(ops.ones(shape=(2, 224, 224, 3))) def test_tf_dataset_data_generator(self): - if backend.multi_backend() and keras.backend.backend() != "tensorflow": + if config.backend() != "tensorflow": pytest.skip("TensorFlow required for `tf.data.Dataset` test.") def data_generator():