diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 67416cef61a..7776e7a5ba2 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -217,7 +217,6 @@ def __init__( params=None, state=None, seed=None, - dtype=None, **kwargs, ): if backend.backend() != "jax": @@ -231,10 +230,9 @@ def __init__( "`init_fn`, `params` and `state` cannot all be `None`." ) - super().__init__(dtype=dtype, **kwargs) + super().__init__(**kwargs) self.call_fn = call_fn self.init_fn = init_fn - self.has_dtype_policy = dtype is not None self.seed_generator = backend.random.SeedGenerator(seed) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) @@ -301,7 +299,7 @@ def create_variable(value): value, (np.ndarray, np.generic) ): dtype = value.dtype - if self.has_dtype_policy and is_float_dtype(dtype): + if is_float_dtype(dtype): dtype = None # Use the layer dtype policy return self.add_weight( value.shape, @@ -311,7 +309,7 @@ def create_variable(value): ) elif isinstance(value, (bool, int, float)): dtype = standardize_dtype(type(value)) - if self.has_dtype_policy and is_float_dtype(dtype): + if is_float_dtype(dtype): dtype = None # Use the layer dtype policy return self.add_weight( (),