Skip to content

Commit

Permalink
JaxLayer now uses the global dtype policy by default. (#20767)
Browse files Browse the repository at this point in the history
All floats will now follow the global dtype policy unless a specific dtype policy is passed to the layer.
  • Loading branch information
hertschuh authored Jan 16, 2025
1 parent 0454f06 commit 25d6d80
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions keras/src/utils/jax_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def __init__(
params=None,
state=None,
seed=None,
dtype=None,
**kwargs,
):
if backend.backend() != "jax":
Expand All @@ -231,10 +230,9 @@ def __init__(
"`init_fn`, `params` and `state` cannot all be `None`."
)

super().__init__(dtype=dtype, **kwargs)
super().__init__(**kwargs)
self.call_fn = call_fn
self.init_fn = init_fn
self.has_dtype_policy = dtype is not None
self.seed_generator = backend.random.SeedGenerator(seed)
self.tracked_params = self._create_variables(params, trainable=True)
self.tracked_state = self._create_variables(state, trainable=False)
Expand Down Expand Up @@ -301,7 +299,7 @@ def create_variable(value):
value, (np.ndarray, np.generic)
):
dtype = value.dtype
if self.has_dtype_policy and is_float_dtype(dtype):
if is_float_dtype(dtype):
dtype = None # Use the layer dtype policy
return self.add_weight(
value.shape,
Expand All @@ -311,7 +309,7 @@ def create_variable(value):
)
elif isinstance(value, (bool, int, float)):
dtype = standardize_dtype(type(value))
if self.has_dtype_policy and is_float_dtype(dtype):
if is_float_dtype(dtype):
dtype = None # Use the layer dtype policy
return self.add_weight(
(),
Expand Down

0 comments on commit 25d6d80

Please sign in to comment.