diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index 5b4de8ac..d5c6a05d 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -97,6 +97,7 @@ class MuonState(NamedTuple): """State for the Adam algorithm.""" count: chex.Array # shape=(), dtype=jnp.int32. mu: base.Updates + ns_coeffs: chex.Array # shape=(3,) or (n, 3). def scale_by_muon( @@ -142,15 +143,19 @@ def scale_by_muon( `_, 2024 """ mu_dtype = utils.canonicalize_dtype(mu_dtype) - ns_coeffs_ = jnp.asarray(ns_coeffs) - if ns_coeffs_.ndim > 2 or ns_coeffs_.shape[-1] != 3: - raise ValueError( - f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}' - ) def init_fn(params): mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment - return MuonState(count=jnp.zeros([], jnp.int32), mu=mu) + ns_coeffs_ = jnp.asarray(ns_coeffs) + if ns_coeffs_.ndim > 2 or ns_coeffs_.shape[-1] != 3: + raise ValueError( + f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}' + ) + return MuonState( + count=jnp.zeros([], jnp.int32), + mu=mu, + ns_coeffs=ns_coeffs_, + ) def update_fn(updates, state, params=None): del params @@ -168,7 +173,9 @@ def update_fn(updates, state, params=None): mu_hat = otu.tree_bias_correction(mu, beta, count_inc) # Apply Newton-schulz orthogonalization. updates = jax.tree.map( - lambda x: orthogonalize_via_newton_schulz(x, ns_coeffs_, ns_steps, eps), + lambda x: orthogonalize_via_newton_schulz( + x, state.ns_coeffs, ns_steps, eps + ), mu_hat, ) if adaptive: @@ -178,7 +185,11 @@ def update_fn(updates, state, params=None): lambda x, y: jnp.einsum('ij,ij,ab->ab', x, y, y), mu_hat, updates ) mu = otu.tree_cast(mu, mu_dtype) - return updates, MuonState(count=count_inc, mu=mu) + return updates, MuonState( + count=count_inc, + mu=mu, + ns_coeffs=state.ns_coeffs, + ) return base.GradientTransformation(init_fn, update_fn)