Skip to content

Commit

Permalink
no bias correction
Browse files Browse the repository at this point in the history
  • Loading branch information
evanatyourservice committed Oct 1, 2024
1 parent 9408c40 commit cccb8a8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 11 deletions.
13 changes: 3 additions & 10 deletions psgd_jax/kron.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,11 @@ def update_fn(updates: base.Updates, state: dict, params: base.Params = None):
precond_lr_in = precond_lr(count_inc)

# momentum
momentum_updates = updates
mu = None
momentum_updates = updates
if state["mu"] is not None:
momentum_updates, mu = _apply_momentum(updates, state["mu"], count_inc, b1)
mu = otu.tree_update_moment(updates, state["mu"], b1, 1)
momentum_updates = mu

# flatten pytrees
updates, grads_structure = jax.tree.flatten(updates)
Expand Down Expand Up @@ -370,14 +371,6 @@ def kron(
return chain(*opt)


def _apply_momentum(
updates: base.Updates, momentum: base.Updates, step, b1
) -> Tuple[base.Updates, base.Updates]:
mu = otu.tree_update_moment(updates, momentum, b1, 1)
updates = otu.tree_bias_correction(mu, b1, step)
return updates, mu


def _add_eps(x):
return jnp.clip(x, 1e-30, None)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "psgd-jax"
version = "0.1.9"
version = "0.1.10"
description = "An implementation of PSGD optimizer in JAX."
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
Expand Down

0 comments on commit cccb8a8

Please sign in to comment.