Merge pull request #331 from kozistr/feature/focus-optimizer
[Feature] Implement `FOCUS` optimizer
kozistr authored Jan 25, 2025
2 parents cdbd9bc + fbe62af commit 7a9377a
Showing 15 changed files with 141 additions and 25 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

## The reasons why you should use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **92 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **93 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested code
* Active maintenance
@@ -200,6 +200,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |
| TAM | *Torque-Aware Momentum* | | <https://arxiv.org/abs/2412.18790> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) |
| FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |

## Supported LR Scheduler

3 changes: 2 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

## The reasons why you should use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **92 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **93 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested code
* Active maintenance
@@ -200,6 +200,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |
| TAM | *Torque-Aware Momentum* | | <https://arxiv.org/abs/2412.18790> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) |
| FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |

## Supported LR Scheduler

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -176,6 +176,10 @@
    :docstring:
    :members:

::: pytorch_optimizer.FOCUS
    :docstring:
    :members:

::: pytorch_optimizer.Fromage
    :docstring:
    :members:
8 changes: 8 additions & 0 deletions docs/visualization.md
@@ -154,6 +154,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_FAdam.png)

### FOCUS

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_FOCUS.png)

### Fromage

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Fromage.png)
@@ -496,6 +500,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_FAdam.png)

### FOCUS

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_FOCUS.png)

### Fromage

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Fromage.png)
Binary file added docs/visualizations/rastrigin_FOCUS.png
Binary file added docs/visualizations/rosenbrock_FOCUS.png
38 changes: 19 additions & 19 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -14,8 +14,8 @@ keywords = [
"AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
"AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
"Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD",
"DAdaptLion", "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Grams", "Gravity", "GrokFast", "GSAM",
"Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
"DAdaptLion", "DeMo", "DiffGrad", "FAdam", "FOCUS", "Fromage", "FTRL", "GaLore", "Grams", "Gravity", "GrokFast",
"GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
"NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
"Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo",
"ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "TAM",
1 change: 1 addition & 0 deletions pytorch_optimizer/__init__.py
@@ -46,6 +46,7 @@
    ASGD,
    BSAM,
    CAME,
    FOCUS,
    FTRL,
    GSAM,
    LARS,
2 changes: 2 additions & 0 deletions pytorch_optimizer/optimizer/__init__.py
@@ -42,6 +42,7 @@
from pytorch_optimizer.optimizer.diffgrad import DiffGrad
from pytorch_optimizer.optimizer.experimental.ranger25 import Ranger25
from pytorch_optimizer.optimizer.fadam import FAdam
from pytorch_optimizer.optimizer.focus import FOCUS
from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
from pytorch_optimizer.optimizer.fromage import Fromage
from pytorch_optimizer.optimizer.ftrl import FTRL
@@ -289,6 +290,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
    LaProp,
    MARS,
    SGDSaI,
    FOCUS,
    Grams,
    SPAM,
    Ranger25,
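
Because `FOCUS` is appended to the registry consumed by `load_optimizer` above, it should also be resolvable by its name (matching the new `'focus'` entry added to `tests/constants.py` below). A rough sketch, assuming `load_optimizer` returns the optimizer class as it does for the other registered optimizers:

```python
import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(16, 1)

# Resolve the optimizer class by its registered name, then instantiate it
# like any other torch optimizer.
opt_class = load_optimizer('focus')
optimizer = opt_class(model.parameters(), lr=1e-2)
```
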
95 changes: 95 additions & 0 deletions pytorch_optimizer/optimizer/focus.py
@@ -0,0 +1,95 @@
import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class FOCUS(BaseOptimizer):
    r"""First Order Concentrated Updating Scheme.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and the parameters.
    :param gamma: float. controls the strength of the attraction toward the moving average of the parameters.
    :param weight_decay: float. weight decay (L2 penalty).
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        betas: BETAS = (0.9, 0.999),
        gamma: float = 0.1,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(gamma, 'gamma', 0.0, 1.0, '[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'gamma': gamma, 'weight_decay': weight_decay}

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'FOCUS'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['pbar'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            weight_decay: float = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['pbar'] = torch.zeros_like(p)

                exp_avg, pbar = state['exp_avg'], state['pbar']

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                pbar.mul_(beta2).add_(p, alpha=1.0 - beta2)

                pbar_hat = pbar / bias_correction2

                if weight_decay > 0.0:
                    p.add_(pbar_hat, alpha=-group['lr'] * weight_decay)

                update = (p - pbar_hat).sign_().mul_(group['gamma']).add_(torch.sign(exp_avg))

                p.add_(update, alpha=-group['lr'])

        return loss
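
For orientation, this is a reading of the update rule implemented above, with learning rate $\eta$, weight decay $\lambda$, `exp_avg` written as $m$, and the parameter EMA `pbar` as $\bar{p}$ (when $\lambda > 0$, the decay step is applied to $\theta_t$ before the sign term is computed):

$$
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t, \\
\bar{p}_t &= \beta_2\, \bar{p}_{t-1} + (1 - \beta_2)\, \theta_t, \qquad \hat{p}_t = \frac{\bar{p}_t}{1 - \beta_2^{\,t}}, \\
\theta_{t+1} &= \theta_t - \eta\,\lambda\,\hat{p}_t - \eta\,\bigl(\gamma\,\operatorname{sign}(\theta_t - \hat{p}_t) + \operatorname{sign}(m_t)\bigr).
\end{aligned}
$$

And a minimal usage sketch, assuming only the exports added in this PR (`from pytorch_optimizer import FOCUS`); the toy model and hyper-parameter values are illustrative, not recommendations:

```python
import torch

from pytorch_optimizer import FOCUS

model = torch.nn.Linear(16, 1)
optimizer = FOCUS(model.parameters(), lr=1e-2, betas=(0.9, 0.999), gamma=0.1, weight_decay=0.0)

for _ in range(10):
    optimizer.zero_grad()
    # Dummy quadratic loss on random inputs, just to drive a few steps.
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
```
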
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -22,7 +22,7 @@ platformdirs==4.3.6 ; python_version >= "3.8"
pluggy==1.5.0 ; python_version >= "3.8"
pytest-cov==5.0.0 ; python_version >= "3.8"
pytest==8.3.4 ; python_version >= "3.8"
ruff==0.9.2 ; python_version >= "3.8"
ruff==0.9.3 ; python_version >= "3.8"
setuptools==75.8.0 ; python_version >= "3.12"
sympy==1.12.1 ; python_version == "3.8"
sympy==1.13.1 ; python_version >= "3.9"
3 changes: 3 additions & 0 deletions tests/constants.py
@@ -5,6 +5,7 @@
    APOLLO,
    ASGD,
    CAME,
    FOCUS,
    FTRL,
    LARS,
    MADGRAD,
@@ -161,6 +162,7 @@
    'apollo',
    'mars',
    'adatam',
    'focus',
]

VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -554,6 +556,7 @@
    (SPAM, {'lr': 1e0, 'weight_decay': 1e-3, 'warmup_epoch': 1, 'grad_accu_steps': 1, 'update_proj_gap': 1}, 5),
    (TAM, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
    (AdaTAM, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
    (FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
    (Ranger25, {'lr': 1e0}, 5),
    (Ranger25, {'lr': 1e0, 't_alpha_beta3': 5}, 5),
    (Ranger25, {'lr': 1e-1, 'stable_adamw': False, 'eps': None}, 5),
1 change: 1 addition & 0 deletions tests/test_general_optimizer_parameters.py
@@ -54,6 +54,7 @@ def test_epsilon(optimizer_name):
        'ftrl',
        'demo',
        'muon',
        'focus',
    ):
        pytest.skip(f'skip {optimizer_name} optimizer')

2 changes: 1 addition & 1 deletion tests/test_load_modules.py
@@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


def test_get_supported_optimizers():
assert len(get_supported_optimizers()) == 90
assert len(get_supported_optimizers()) == 91
assert len(get_supported_optimizers('adam*')) == 7
assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 10

Expand Down
