Skip to content

Commit

Permalink
Merge pull request #1193 from hlzl:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 730595416
  • Loading branch information
OptaxDev committed Feb 24, 2025
2 parents 9b682ab + e165e99 commit 6e45008
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class MuonState(NamedTuple):
"""State for the Muon algorithm."""
count: chex.Array # shape=(), dtype=jnp.int32.
mu: base.Updates
ns_coeffs: chex.Array # shape=(3,) or (n, 3); Newton-Schulz coefficients.


def scale_by_muon(
Expand Down Expand Up @@ -142,15 +143,19 @@ def scale_by_muon(
<https://arxiv.org/abs/2409.20325>`_, 2024
"""
mu_dtype = utils.canonicalize_dtype(mu_dtype)
ns_coeffs_ = jnp.asarray(ns_coeffs)
if ns_coeffs_.ndim > 2 or ns_coeffs_.shape[-1] != 3:
raise ValueError(
f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}'
)

def init_fn(params):
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
return MuonState(count=jnp.zeros([], jnp.int32), mu=mu)
ns_coeffs_ = jnp.asarray(ns_coeffs)
if ns_coeffs_.ndim > 2 or ns_coeffs_.shape[-1] != 3:
raise ValueError(
f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}'
)
return MuonState(
count=jnp.zeros([], jnp.int32),
mu=mu,
ns_coeffs=ns_coeffs_,
)

def update_fn(updates, state, params=None):
del params
Expand All @@ -168,7 +173,9 @@ def update_fn(updates, state, params=None):
mu_hat = otu.tree_bias_correction(mu, beta, count_inc)
# Apply Newton-schulz orthogonalization.
updates = jax.tree.map(
lambda x: orthogonalize_via_newton_schulz(x, ns_coeffs_, ns_steps, eps),
lambda x: orthogonalize_via_newton_schulz(
x, state.ns_coeffs, ns_steps, eps
),
mu_hat,
)
if adaptive:
Expand All @@ -178,7 +185,11 @@ def update_fn(updates, state, params=None):
lambda x, y: jnp.einsum('ij,ij,ab->ab', x, y, y), mu_hat, updates
)
mu = otu.tree_cast(mu, mu_dtype)
return updates, MuonState(count=count_inc, mu=mu)
return updates, MuonState(
count=count_inc,
mu=mu,
ns_coeffs=state.ns_coeffs,
)
return base.GradientTransformation(init_fn, update_fn)


Expand Down

0 comments on commit 6e45008

Please sign in to comment.