From 26e71f5aff68ccac9c7651f6de802af81f720b1d Mon Sep 17 00:00:00 2001
From: hertschuh <1091026+hertschuh@users.noreply.github.com>
Date: Tue, 7 Jan 2025 11:25:40 -0800
Subject: [PATCH] Add support for `dtype` / `DTypePolicy` to `JaxLayer` and
 `FlaxLayer`. (#20732)

The `dtype` / `DTypePolicy` is applied to all float variables.
---
 keras/src/backend/jax/export.py   |  2 +-
 keras/src/layers/layer.py         | 10 ++++------
 keras/src/utils/jax_layer.py      | 38 ++++++++++++++++++++++---------
 keras/src/utils/jax_layer_test.py | 26 +++++++++++++++++++++
 4 files changed, 58 insertions(+), 18 deletions(-)

diff --git a/keras/src/backend/jax/export.py b/keras/src/backend/jax/export.py
index 963648460dc..dd754c14418 100644
--- a/keras/src/backend/jax/export.py
+++ b/keras/src/backend/jax/export.py
@@ -119,7 +119,7 @@ def stateful_fn(*args, **kwargs):
                 self._tf_trackable.non_trainable_variables,
                 non_trainable_variables,
             ):
-                var.assign(new_value)
+                var.assign(tf.cast(new_value, var.dtype))
             return output
 
         stateful_fn.__signature__ = inspect.Signature(
diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index 8e36bb20456..a4f830912d5 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -85,12 +85,10 @@ class Layer(BackendLayer, Operation, KerasSaveable):
         trainable: Boolean, whether the layer's variables should be trainable.
         name: String name of the layer.
         dtype: The dtype of the layer's computations and weights. Can also be a
-            `keras.DTypePolicy`,
-            which allows the computation and
-            weight dtype to differ. Defaults to `None`. `None` means to use
-            `keras.config.dtype_policy()`,
-            which is a `float32` policy unless set to different value
-            (via `keras.config.set_dtype_policy()`).
+            `keras.DTypePolicy`, which allows the computation and weight dtype
+            to differ. Defaults to `None`. `None` means to use
+            `keras.config.dtype_policy()`, which is a `float32` policy unless
+            set to different value (via `keras.config.set_dtype_policy()`).
 
     Attributes:
         name: The name of the layer (string).
diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py
index 8fd69d1f5bf..67416cef61a 100644
--- a/keras/src/utils/jax_layer.py
+++ b/keras/src/utils/jax_layer.py
@@ -5,6 +5,8 @@
 from keras.src import backend
 from keras.src import tree
 from keras.src.api_export import keras_export
+from keras.src.backend.common.variables import is_float_dtype
+from keras.src.backend.common.variables import standardize_dtype
 from keras.src.layers.layer import Layer
 from keras.src.saving import serialization_lib
 from keras.src.utils import jax_utils
@@ -204,6 +206,8 @@ def my_haiku_module_fn(inputs, training):
             argument, then `init_fn` is called at build time to initialize the
             non-trainable state of the model.
         seed: Seed for random number generator. Optional.
+        dtype: The dtype of the layer's computations and weights. Can also be a
+            `keras.DTypePolicy`. Optional. Defaults to the default policy.
     """
 
     def __init__(
@@ -213,6 +217,7 @@ def __init__(
         params=None,
         state=None,
         seed=None,
+        dtype=None,
         **kwargs,
     ):
         if backend.backend() != "jax":
@@ -226,9 +231,10 @@ def __init__(
                 "`init_fn`, `params` and `state` cannot all be `None`."
             )
 
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.call_fn = call_fn
         self.init_fn = init_fn
+        self.has_dtype_policy = dtype is not None
         self.seed_generator = backend.random.SeedGenerator(seed)
         self.tracked_params = self._create_variables(params, trainable=True)
         self.tracked_state = self._create_variables(state, trainable=False)
@@ -291,18 +297,28 @@ def _create_variables(self, values, trainable):
         """
 
         def create_variable(value):
-            if backend.is_tensor(value) or isinstance(value, np.ndarray):
-                variable = self.add_weight(
-                    value.shape, initializer="zeros", trainable=trainable
+            if backend.is_tensor(value) or isinstance(
+                value, (np.ndarray, np.generic)
+            ):
+                dtype = value.dtype
+                if self.has_dtype_policy and is_float_dtype(dtype):
+                    dtype = None  # Use the layer dtype policy
+                return self.add_weight(
+                    value.shape,
+                    initializer=value,
+                    dtype=dtype,
+                    trainable=trainable,
                 )
-                variable.assign(value)
-                return variable
-            elif isinstance(value, (np.generic, int, float)):
-                variable = self.add_weight(
-                    (), initializer="zeros", trainable=trainable
+            elif isinstance(value, (bool, int, float)):
+                dtype = standardize_dtype(type(value))
+                if self.has_dtype_policy and is_float_dtype(dtype):
+                    dtype = None  # Use the layer dtype policy
+                return self.add_weight(
+                    (),
+                    initializer=backend.convert_to_tensor(value),
+                    dtype=dtype,
+                    trainable=trainable,
                 )
-                variable.assign(value)
-                return variable
             else:
                 return value
 
diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py
index 359bdca41c9..306c930660f 100644
--- a/keras/src/utils/jax_layer_test.py
+++ b/keras/src/utils/jax_layer_test.py
@@ -15,6 +15,7 @@
 from keras.src import testing
 from keras.src import tree
 from keras.src import utils
+from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.saving import object_registration
 from keras.src.utils.jax_layer import FlaxLayer
 from keras.src.utils.jax_layer import JaxLayer
@@ -362,6 +363,18 @@ def call(self, inputs):
             "non_trainable_weights": 1,
             "non_trainable_params": 1,
         },
+        {
+            "testcase_name": "training_state_dtype_policy",
+            "init_kwargs": {
+                "call_fn": jax_stateful_apply,
+                "init_fn": jax_stateful_init,
+                "dtype": DTypePolicy("mixed_float16"),
+            },
+            "trainable_weights": 6,
+            "trainable_params": 266610,
+            "non_trainable_weights": 1,
+            "non_trainable_params": 1,
+        },
     )
     def test_jax_layer(
         self,
@@ -414,6 +427,19 @@ def test_jax_layer(
             "non_trainable_weights": 8,
             "non_trainable_params": 536,
         },
+        {
+            "testcase_name": "training_rng_unbound_method_dtype_policy",
+            "flax_model_class": "FlaxDropoutModel",
+            "flax_model_method": None,
+            "init_kwargs": {
+                "method": "flax_dropout_wrapper",
+                "dtype": DTypePolicy("mixed_float16"),
+            },
+            "trainable_weights": 8,
+            "trainable_params": 648226,
+            "non_trainable_weights": 0,
+            "non_trainable_params": 0,
+        },
     )
     @pytest.mark.skipif(flax is None, reason="Flax library is not available.")
     def test_flax_layer(
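
A minimal usage sketch of the new `dtype` argument follows. It is
illustrative only, not part of the patch: the model function `my_model_fn`
and the parameter shapes are assumptions, while `keras.layers.JaxLayer` and
`keras.DTypePolicy` are the public exports of the classes touched above.
Running it requires the JAX backend.

    import jax.numpy as jnp
    import keras

    # Hypothetical stateless model function. `JaxLayer` matches `call_fn`
    # arguments by name, so they must be named `params` and `inputs`.
    def my_model_fn(params, inputs):
        return jnp.dot(inputs, params["w"]) + params["b"]

    params = {
        "w": jnp.ones((3, 4), dtype="float32"),
        "b": jnp.zeros((4,), dtype="float32"),
    }

    # With this patch, the float entries of `params` become layer variables
    # that follow the dtype policy: stored as float32 but cast to float16
    # for computation under "mixed_float16". Non-float variables keep their
    # original dtype.
    layer = keras.layers.JaxLayer(
        my_model_fn,
        params=params,
        dtype=keras.DTypePolicy("mixed_float16"),
    )
    outputs = layer(jnp.ones((2, 3)))  # computation runs in float16

Before this change, variables were always created with the dtype of the
values passed in, and the layer's `dtype` / `DTypePolicy` was ignored.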