Merge pull request #347 from kozistr/fix/adopt-optimizer
[Fix] Updating `exp_avg_sq` after calculating the `denominator` in `ADOPT` optimizer
kozistr authored Feb 13, 2025
2 parents b82f7c4 + db3a9ab commit 0386506
Showing 4 changed files with 16 additions and 7 deletions.
9 changes: 9 additions & 0 deletions docs/changelogs/v3.4.1.md
@@ -15,7 +15,16 @@
* `AGC` + `Lookahead` variants
* change default beta1, beta2 to 0.95 and 0.98 respectively
* Skip adding `Lookahead` wrapper in case of `Ranger*` optimizers, which already have it in `create_optimizer()`. (#340)
+* Improved optimizer visualization. (#345)
+
+### Bug
+
+* Fix to update exp_avg_sq after calculating the denominator in `ADOPT` optimizer. (#346, #347)
+
+### Docs
+
+* Update the visualizations. (#340)

### Contributions

thanks to @AidinHamedi
4 changes: 2 additions & 2 deletions pytorch_optimizer/optimizer/adopt.py
@@ -115,8 +115,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg_sq.addcmul_(grad, grad.conj())
continue

-exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)
-
de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps'])

normed_grad = grad.div(de_nom)
@@ -137,4 +135,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:

p.add_(update, alpha=-lr)

+exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)
+
return loss
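
The change above restores ADOPT's decoupled ordering: `de_nom` is built from the second-moment estimate carried over from the previous step, the parameter is updated, and only then is `exp_avg_sq` refreshed with the current gradient, so the normalizer never contains the square of the gradient it is normalizing. A minimal single-tensor sketch of that ordering (hyperparameter defaults are illustrative, the first-step initialization handled by the `continue` branch above is assumed to have happened already, and this is not the library's full `step()`):

```python
import torch

def adopt_like_step(p, grad, exp_avg, exp_avg_sq,
                    lr=1e-3, beta1=0.9, beta2=0.9999, eps=1e-6):
    """Sketch of the corrected update order; not the library's implementation."""
    # Normalize the gradient with the second-moment estimate from the previous step.
    de_nom = exp_avg_sq.sqrt().clamp_(min=eps)
    normed_grad = grad.div(de_nom)

    # First-moment update and parameter step use the normalized gradient.
    exp_avg.mul_(beta1).add_(normed_grad, alpha=1.0 - beta1)
    p.add_(exp_avg, alpha=-lr)

    # Only now fold the current gradient into the second-moment estimate,
    # so it is first used at the next step.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)
```

Updating `exp_avg_sq` before computing `de_nom`, as the removed line did, would divide each gradient by a statistic that already includes that gradient's own square, which is the correlation the ADOPT method is designed to avoid.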
4 changes: 2 additions & 2 deletions pytorch_optimizer/optimizer/experimental/ranger25.py
@@ -183,10 +183,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
fixed_decay=group['fixed_decay'],
)

-exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']
-
grad.copy_(agc(p, grad))

+exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']
+
normed_grad = grad.div(
exp_avg_sq.sqrt().clamp_(min=group['eps'] if group['eps'] is not None else 1e-8)
).clamp_(-clip, clip)
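
The `ranger25.py` change is a behavior-neutral tidy-up: the moment buffers are unpacked right before their first use, after the gradient has been rewritten in place by `agc`. For context, `agc` refers to adaptive gradient clipping, which caps the gradient norm relative to the parameter norm before `normed_grad` is computed from it. A rough whole-tensor sketch of the idea with illustrative constants (not the library's `agc` implementation):

```python
import torch

def agc_like(param: torch.Tensor, grad: torch.Tensor,
             clip_factor: float = 1e-2, eps: float = 1e-3) -> torch.Tensor:
    """Generic adaptive-gradient-clipping sketch; constants are illustrative."""
    p_norm = param.norm().clamp(min=eps)   # guard against near-zero parameters
    g_norm = grad.norm()
    max_norm = p_norm * clip_factor        # largest gradient norm allowed for this tensor
    if g_norm > max_norm:
        return grad * (max_norm / g_norm)  # rescale so the ratio equals clip_factor
    return grad
```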
6 changes: 3 additions & 3 deletions tests/constants.py
@@ -563,9 +563,9 @@
(FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
(Kron, {'lr': 1e0, 'weight_decay': 1e-3}, 3),
(EXAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
-(Ranger25, {'lr': 5e-2}, 5),
-(Ranger25, {'lr': 5e-2, 't_alpha_beta3': 5}, 5),
-(Ranger25, {'lr': 5e-2, 'stable_adamw': False, 'orthograd': False, 'eps': None}, 5),
+(Ranger25, {'lr': 1e-1}, 3),
+(Ranger25, {'lr': 1e-1, 't_alpha_beta3': 5}, 3),
+(Ranger25, {'lr': 5e-2, 'stable_adamw': False, 'orthograd': False, 'eps': None, 'lookahead_merge_time': 2}, 3),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
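
In `tests/constants.py` each entry pairs an optimizer class with its constructor kwargs and an integer budget (apparently the number of optimization steps the test runs), so the `Ranger25` cases now exercise the adjusted learning rates, the extra `lookahead_merge_time` setting, and a smaller budget. A hedged sketch of how such `(optimizer_class, config, num_iterations)` tuples could drive a parametrized convergence test; the case list and test function here are hypothetical, not the repository's actual test code:

```python
import pytest
import torch

# Hypothetical case list in the same (optimizer_class, kwargs, budget) shape as above;
# in the repository the real list would be imported from tests/constants.py.
CASES = [
    (torch.optim.Adam, {'lr': 1e-1}, 25),
]

@pytest.mark.parametrize('optimizer_class,config,num_iterations', CASES)
def test_optimizer_reduces_loss(optimizer_class, config, num_iterations):
    torch.manual_seed(42)
    w = torch.randn(2, requires_grad=True)   # toy parameter vector
    target = torch.tensor([1.0, -1.0])

    optimizer = optimizer_class([w], **config)
    initial_loss = (w - target).pow(2).sum().item()

    for _ in range(num_iterations):          # run for the per-case budget
        optimizer.zero_grad()
        loss = (w - target).pow(2).sum()     # simple quadratic objective
        loss.backward()
        optimizer.step()

    final_loss = (w - target).pow(2).sum().item()
    assert final_loss < initial_loss         # the optimizer should make progress
```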
