Add support for dtype / DTypePolicy to JaxLayer and FlaxLayer. (#20732)

The `dtype` / `DTypePolicy` is applied to all float variables.
hertschuh authored Jan 7, 2025
1 parent fbf0af7 commit 26e71f5
Showing 4 changed files with 58 additions and 18 deletions.
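
For context, a minimal sketch of what this commit enables: passing a `dtype` or `DTypePolicy` to `JaxLayer` now controls how its float variables are created and cast. The `call_fn` and parameter names below are illustrative, not from the diff (requires the JAX backend, e.g. `KERAS_BACKEND=jax`).

import jax.numpy as jnp
from keras.layers import JaxLayer

# Hypothetical call_fn; JaxLayer binds arguments by name ("params", "inputs").
def my_call_fn(params, inputs):
    return jnp.matmul(inputs, params["w"])

layer = JaxLayer(
    call_fn=my_call_fn,
    params={"w": jnp.ones((4, 2), dtype="float32")},
    dtype="mixed_float16",  # or keras.DTypePolicy("mixed_float16")
)
y = layer(jnp.ones((1, 4)))
# Float params are stored in the policy's variable dtype (float32) and
# cast to its compute dtype (float16) when call_fn runs.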
2 changes: 1 addition & 1 deletion keras/src/backend/jax/export.py
@@ -119,7 +119,7 @@ def stateful_fn(*args, **kwargs):
                 self._tf_trackable.non_trainable_variables,
                 non_trainable_variables,
             ):
-                var.assign(new_value)
+                var.assign(tf.cast(new_value, var.dtype))
             return output
 
         stateful_fn.__signature__ = inspect.Signature(
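
The `tf.cast` above matters because, under a mixed dtype policy, a recomputed state value can come back in a different float dtype than the tracked `tf.Variable`. A standalone sketch of the failure mode this guards against (plain TensorFlow, outside the export path):

import tensorflow as tf

var = tf.Variable(tf.zeros((2,), dtype=tf.float32))
new_value = tf.ones((2,), dtype=tf.float16)  # e.g. produced under mixed_float16

# var.assign(new_value) would raise a dtype mismatch error;
# casting to the variable's own dtype makes the write-back safe.
var.assign(tf.cast(new_value, var.dtype))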
10 changes: 4 additions & 6 deletions keras/src/layers/layer.py
@@ -85,12 +85,10 @@ class Layer(BackendLayer, Operation, KerasSaveable):
         trainable: Boolean, whether the layer's variables should be trainable.
         name: String name of the layer.
         dtype: The dtype of the layer's computations and weights. Can also be a
-            `keras.DTypePolicy`,
-            which allows the computation and
-            weight dtype to differ. Defaults to `None`. `None` means to use
-            `keras.config.dtype_policy()`,
-            which is a `float32` policy unless set to different value
-            (via `keras.config.set_dtype_policy()`).
+            `keras.DTypePolicy`, which allows the computation and weight dtype
+            to differ. Defaults to `None`. `None` means to use
+            `keras.config.dtype_policy()`, which is a `float32` policy unless
+            set to different value (via `keras.config.set_dtype_policy()`).
 
     Attributes:
         name: The name of the layer (string).
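
The reflowed docstring describes the global-policy fallback; a short sketch of that behavior using the public `keras.config` API:

import keras

# With no explicit `dtype`, a layer picks up the global policy,
# which is "float32" unless overridden.
print(keras.config.dtype_policy().name)  # "float32"

keras.config.set_dtype_policy("mixed_float16")
layer = keras.layers.Dense(4)
print(layer.dtype_policy.name)  # "mixed_float16"
print(layer.compute_dtype)  # "float16"
print(layer.variable_dtype)  # "float32"

# Restore the default so later code is unaffected.
keras.config.set_dtype_policy("float32")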
38 changes: 27 additions & 11 deletions keras/src/utils/jax_layer.py
@@ -5,6 +5,8 @@
 from keras.src import backend
 from keras.src import tree
 from keras.src.api_export import keras_export
+from keras.src.backend.common.variables import is_float_dtype
+from keras.src.backend.common.variables import standardize_dtype
 from keras.src.layers.layer import Layer
 from keras.src.saving import serialization_lib
 from keras.src.utils import jax_utils
 
@@ -204,6 +206,8 @@ def my_haiku_module_fn(inputs, training):
             argument, then `init_fn` is called at build time to initialize the
             non-trainable state of the model.
         seed: Seed for random number generator. Optional.
+        dtype: The dtype of the layer's computations and weights. Can also be a
+            `keras.DTypePolicy`. Optional. Defaults to the default policy.
     """
 
     def __init__(
@@ -213,6 +217,7 @@ def __init__(
         params=None,
         state=None,
         seed=None,
+        dtype=None,
         **kwargs,
     ):
         if backend.backend() != "jax":
@@ -226,9 +231,10 @@ def __init__(
"`init_fn`, `params` and `state` cannot all be `None`."
)

super().__init__(**kwargs)
super().__init__(dtype=dtype, **kwargs)
self.call_fn = call_fn
self.init_fn = init_fn
self.has_dtype_policy = dtype is not None
self.seed_generator = backend.random.SeedGenerator(seed)
self.tracked_params = self._create_variables(params, trainable=True)
self.tracked_state = self._create_variables(state, trainable=False)
Expand Down Expand Up @@ -291,18 +297,28 @@ def _create_variables(self, values, trainable):
"""

def create_variable(value):
if backend.is_tensor(value) or isinstance(value, np.ndarray):
variable = self.add_weight(
value.shape, initializer="zeros", trainable=trainable
if backend.is_tensor(value) or isinstance(
value, (np.ndarray, np.generic)
):
dtype = value.dtype
if self.has_dtype_policy and is_float_dtype(dtype):
dtype = None # Use the layer dtype policy
return self.add_weight(
value.shape,
initializer=value,
dtype=dtype,
trainable=trainable,
)
variable.assign(value)
return variable
elif isinstance(value, (np.generic, int, float)):
variable = self.add_weight(
(), initializer="zeros", trainable=trainable
elif isinstance(value, (bool, int, float)):
dtype = standardize_dtype(type(value))
if self.has_dtype_policy and is_float_dtype(dtype):
dtype = None # Use the layer dtype policy
return self.add_weight(
(),
initializer=backend.convert_to_tensor(value),
dtype=dtype,
trainable=trainable,
)
variable.assign(value)
return variable
else:
return value

Expand Down
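
The heart of the change is the dtype-selection rule in `create_variable`: float values defer to the layer's policy (by passing `dtype=None` to `add_weight`), while bool/int values keep their exact dtype. A standalone restatement of that rule; `select_variable_dtype` is an illustrative helper, not part of the commit:

import numpy as np

from keras.src.backend.common.variables import is_float_dtype
from keras.src.backend.common.variables import standardize_dtype

def select_variable_dtype(value, has_dtype_policy):
    # Arrays and numpy scalars carry their own dtype; plain Python
    # scalars are standardized from their type.
    if isinstance(value, (np.ndarray, np.generic)):
        dtype = value.dtype
    else:  # bool / int / float
        dtype = standardize_dtype(type(value))
    if has_dtype_policy and is_float_dtype(dtype):
        return None  # add_weight(dtype=None) defers to the DTypePolicy
    return dtype

print(select_variable_dtype(np.zeros(3, "float32"), True))  # None
print(select_variable_dtype(np.zeros(3, "int32"), True))  # int32
print(select_variable_dtype(0.5, True))  # None
print(select_variable_dtype(True, True))  # bool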
26 changes: 26 additions & 0 deletions keras/src/utils/jax_layer_test.py
@@ -15,6 +15,7 @@
 from keras.src import testing
 from keras.src import tree
 from keras.src import utils
+from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.saving import object_registration
 from keras.src.utils.jax_layer import FlaxLayer
 from keras.src.utils.jax_layer import JaxLayer
 
@@ -362,6 +363,18 @@ def call(self, inputs):
"non_trainable_weights": 1,
"non_trainable_params": 1,
},
{
"testcase_name": "training_state_dtype_policy",
"init_kwargs": {
"call_fn": jax_stateful_apply,
"init_fn": jax_stateful_init,
"dtype": DTypePolicy("mixed_float16"),
},
"trainable_weights": 6,
"trainable_params": 266610,
"non_trainable_weights": 1,
"non_trainable_params": 1,
},
)
def test_jax_layer(
self,
Expand Down Expand Up @@ -414,6 +427,19 @@ def test_jax_layer(
"non_trainable_weights": 8,
"non_trainable_params": 536,
},
{
"testcase_name": "training_rng_unbound_method_dtype_policy",
"flax_model_class": "FlaxDropoutModel",
"flax_model_method": None,
"init_kwargs": {
"method": "flax_dropout_wrapper",
"dtype": DTypePolicy("mixed_float16"),
},
"trainable_weights": 8,
"trainable_params": 648226,
"non_trainable_weights": 0,
"non_trainable_params": 0,
},
)
@pytest.mark.skipif(flax is None, reason="Flax library is not available.")
def test_flax_layer(
Expand Down
