Skip to content

Commit

Permalink
Improve keras 3 detection (keras-team#2132)
Browse files Browse the repository at this point in the history
* improve keras 3 detection

* update __init__

* update config

* update config to update multibackend flag

* updated spelling error

* updatconfig

* update tf_ops

* update tf -ops

* update tf ops

* undo last commit

* update tf_ops.py

* reformat

* update init()

* update minor error in tf ops

* remove check for keras 3

* revert changes to tf_ops

* disable layers using internal base random layer

* update syntax error

* update imports in keras version check layer

* update init method in keras version check

* add seed argument

* rasie error in class itself

* update constructor

* update to import directly from layers

* update import in base image augmenattion layer

* change import sin tf_ops

* update tf ops import

* change init

* update tf_ops

* update backend functions

* keras.src

* update namespace

* update namescope correctly

* code reformat

* reformat and add backend functions

* modified ops import

* reformat

* update ops

* update backend

* update ops

* code reformatted

* update import

* update imports in ops

* update error message

* review changes added

* update keras imports

* update imports

* update imports

* update import in random.py

* add issues link

* code reformat

* code reformat

* review comments addressed

* code reformat
  • Loading branch information
divyashreepathihalli authored and yuvraj-wale committed Feb 8, 2024
1 parent c10fd37 commit 43eefe7
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 120 deletions.
77 changes: 3 additions & 74 deletions keras_cv/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,91 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Keras backend module.
This module adds a temporarily Keras API surface that is fully under KerasCV
control. This allows us to switch between `keras_core` and `tf.keras`, as well
as add shims to support older version of `tf.keras`.
- `config`: check which backend is being run.
- `keras`: The full `keras` API (via `keras_core` or `tf.keras`).
- `ops`: `keras_core.ops`, always tf-backed if using `tf.keras`.
"""

import types

from packaging.version import parse

from keras_cv.backend.config import multi_backend

# Keys are of the form: "module.where.attr.exists->module.where.to.alias"
# Value are of the form: ["attr1", "attr2", ...] or
# [("attr1_original_name", "attr1_alias_name")]
_KERAS_CORE_ALIASES = {
"utils->saving": [
"register_keras_serializable",
"deserialize_keras_object",
"serialize_keras_object",
"get_registered_object",
],
"models->saving": ["load_model"],
}


if multi_backend():
import keras

if not hasattr(keras, "__version__") or parse(keras.__version__) < parse(
"3.0"
):
import keras_core as keras

keras.backend.name_scope = keras.name_scope
else:
from tensorflow import keras

if not hasattr(keras, "saving"):
keras.saving = types.SimpleNamespace()

# add aliases
for key, value in _KERAS_CORE_ALIASES.items():
src, _, dst = key.partition("->")
src = src.split(".")
dst = dst.split(".")

src_mod, dst_mod = keras, keras

# navigate to where we want to alias the attributes
for mod in src:
src_mod = getattr(src_mod, mod)
for mod in dst:
dst_mod = getattr(dst_mod, mod)

# add an alias for each attribute
for attr in value:
if isinstance(attr, tuple):
src_attr, dst_attr = attr
else:
src_attr, dst_attr = attr, attr
attr_val = getattr(src_mod, src_attr)
setattr(dst_mod, dst_attr, attr_val)

# TF Keras doesn't have this rename.
keras.activations.silu = keras.activations.swish

from keras_cv.backend import config # noqa: E402
from keras_cv.backend import keras # noqa: E402
from keras_cv.backend import ops # noqa: E402
from keras_cv.backend import random # noqa: E402
from keras_cv.backend import tf_ops # noqa: E402


def assert_tf_keras(src):
if multi_backend():
if config.multi_backend():
raise NotImplementedError(
f"KerasCV component {src} does not yet support Keras Core, and can "
"only be used in `tf.keras`."
)


def supports_ragged():
return not multi_backend()
return not config.multi_backend()
37 changes: 37 additions & 0 deletions keras_cv/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,29 @@
_keras_base_dir = "/tmp"
_keras_dir = os.path.join(_keras_base_dir, ".keras")


def detect_if_tensorflow_uses_keras_3():
# We follow the version of keras that tensorflow is configured to use.
from tensorflow import keras

# Note that only recent versions of keras have a `version()` function.
if hasattr(keras, "version") and keras.version().startswith("3."):
return True

# No `keras.version()` means we are on an old version of keras.
return False


_USE_KERAS_3 = detect_if_tensorflow_uses_keras_3()
if _USE_KERAS_3:
_MULTI_BACKEND = True


def keras_3():
"""Check if Keras 3 is being used."""
return _USE_KERAS_3


# Attempt to read KerasCV config file.
_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras_cv.json"))
if os.path.exists(_config_path):
Expand Down Expand Up @@ -62,3 +85,17 @@

def multi_backend():
return _MULTI_BACKEND


def backend():
"""Check the backend framework."""
if not multi_backend():
return "tensorflow"
if not keras_3():
import keras_core

return keras_core.config.backend()

from tensorflow import keras

return keras.config.backend()
70 changes: 70 additions & 0 deletions keras_cv/backend/keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types

from keras_cv.backend import config

_KERAS_CORE_ALIASES = {
"utils->saving": [
"register_keras_serializable",
"deserialize_keras_object",
"serialize_keras_object",
"get_registered_object",
],
"models->saving": ["load_model"],
}

if config.keras_3():
import keras # noqa: F403, F401
from keras import * # noqa: F403, F401

keras.backend.name_scope = keras.name_scope
elif config.multi_backend():
import keras_core as keras # noqa: F403, F401
from keras_core import * # noqa: F403, F401

keras.backend.name_scope = keras.name_scope
else:
from tensorflow import keras # noqa: F403, F401
from tensorflow.keras import * # noqa: F403, F401

if not hasattr(keras, "saving"):
keras.saving = types.SimpleNamespace()

# add aliases
for key, value in _KERAS_CORE_ALIASES.items():
src, _, dst = key.partition("->")
src = src.split(".")
dst = dst.split(".")

src_mod, dst_mod = keras, keras

# navigate to where we want to alias the attributes
for mod in src:
src_mod = getattr(src_mod, mod)
for mod in dst:
dst_mod = getattr(dst_mod, mod)

# add an alias for each attribute
for attr in value:
if isinstance(attr, tuple):
src_attr, dst_attr = attr
else:
src_attr, dst_attr = attr, attr
attr_val = getattr(src_mod, src_attr)
setattr(dst_mod, dst_attr, attr_val)

# TF Keras doesn't have this rename.
keras.activations.silu = keras.activations.swish
9 changes: 6 additions & 3 deletions keras_cv/backend/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.backend.config import keras_3
from keras_cv.backend.config import multi_backend

if multi_backend():
if keras_3():
from keras.ops import * # noqa: F403, F401
from keras.preprocessing.image import smart_resize # noqa: F403, F401
else:
try:
from keras.src.backend import vectorized_map # noqa: F403, F401
from keras.src.ops import * # noqa: F403, F401
from keras.src.utils.image_utils import smart_resize # noqa: F403, F401
# Import error means Keras isn't installed, or is Keras 2.
Expand All @@ -25,5 +28,5 @@
from keras_core.src.utils.image_utils import ( # noqa: F403, F401
smart_resize,
)
else:
if not multi_backend():
from keras_cv.backend.tf_ops import * # noqa: F403, F401
8 changes: 4 additions & 4 deletions keras_cv/backend/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.backend.config import multi_backend
from keras_cv.backend.config import keras_3

if multi_backend():
from keras_core.random import * # noqa: F403, F401
if keras_3():
from keras.random import * # noqa: F403, F401
else:
from keras_core.src.backend.tensorflow.random import * # noqa: F403, F401
from keras_core.random import * # noqa: F403, F401
35 changes: 25 additions & 10 deletions keras_cv/backend/tf_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_core.src.backend.tensorflow import * # noqa: F403, F401
from keras_core.src.backend.tensorflow import ( # noqa: F403, F401
convert_to_numpy,
)
from keras_core.src.backend.tensorflow.core import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.math import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.nn import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.numpy import * # noqa: F403, F401
from keras_cv.backend import config

if config.keras_3():
from keras.src.backend.tensorflow import * # noqa: F403, F401
from keras.src.backend.tensorflow import ( # noqa: F403, F401
convert_to_numpy,
)
from keras.src.backend.tensorflow.core import * # noqa: F403, F401
from keras.src.backend.tensorflow.math import * # noqa: F403, F401
from keras.src.backend.tensorflow.nn import * # noqa: F403, F401
from keras.src.backend.tensorflow.numpy import * # noqa: F403, F401
else:
# isort: off
from keras_core.src.backend.tensorflow import * # noqa: F403, F401
from keras_core.src.backend.tensorflow import ( # noqa: F403, F401
convert_to_numpy,
)
from keras_core.src.backend.tensorflow.core import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.math import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.nn import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.numpy import * # noqa: F403, F401, E501


# Some TF APIs where the numpy API doesn't support raggeds that we need
from tensorflow import broadcast_to # noqa: F403, F401
from tensorflow import concat as concatenate # noqa: F403, F401
from tensorflow import repeat # noqa: F403, F401
from tensorflow import reshape # noqa: F403, F401

from tensorflow import range as arange # noqa: F403, F401
from tensorflow import reduce_all as all # noqa: F403, F401
from tensorflow import reduce_max as max # noqa: F403, F401
from tensorflow import repeat # noqa: F403, F401
from tensorflow import reshape # noqa: F403, F401
from tensorflow import split # noqa: F403, F401
from tensorflow.keras.preprocessing.image import ( # noqa: F403, F401
smart_resize,
Expand Down
21 changes: 17 additions & 4 deletions keras_cv/layers/preprocessing_3d/base_augmentation_layer_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@
# limitations under the License.

import tensorflow as tf
from tensorflow import keras

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import config

if config.keras_3():
base_layer = tf.keras.layers.Layer
else:
base_layer = tf.keras.__internal__.layers.BaseRandomLayer

POINT_CLOUDS = "point_clouds"
BOUNDING_BOXES = "bounding_boxes"
Expand All @@ -29,7 +34,7 @@


@keras_cv_export("keras_cv.layers.BaseAugmentationLayer3D")
class BaseAugmentationLayer3D(keras.__internal__.layers.BaseRandomLayer):
class BaseAugmentationLayer3D(base_layer):
"""Abstract base layer for data augmentation for 3D perception.
This layer contains base functionalities for preprocessing layers which
Expand Down Expand Up @@ -99,8 +104,16 @@ def augment_pointclouds(self, point_clouds, transformation):
"""

def __init__(self, seed=None, **kwargs):
super().__init__(seed=seed, **kwargs)
self.auto_vectorize = False
# To-do: remove this once th elayer is ported to keras 3
# https://github.com/keras-team/keras-cv/issues/2136
if config.keras_3():
raise ValueError(
"This layer is not yet compatible with Keras 3."
"Please switch to Keras 2 to use this layer."
)
else:
super().__init__(seed=seed, **kwargs)
self.auto_vectorize = False

@property
def auto_vectorize(self):
Expand Down
Loading

0 comments on commit 43eefe7

Please sign in to comment.