Merge pull request #344 from kozistr/feature/looksam-optimizer
[Feature] Implement `GCSAM` and `LookSAM` optimizers
kozistr authored Feb 9, 2025
2 parents 111249d + 40ec30d commit b82f7c4
Showing 12 changed files with 285 additions and 15 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **97 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **98 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested code
* Active maintenance
@@ -204,6 +204,8 @@ get_supported_optimizers(['adam*', 'ranger*'])
| FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |
| PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | <https://arxiv.org/abs/1512.04202> | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) |
| EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
| GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
| LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |

## Supported LR Scheduler

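Both new entries plug into the existing SAM family rather than standing alone: GCSAM is reached through `SAM` with `use_gc=True`, while `LookSAM` wraps a base optimizer of your choice. A minimal construction sketch, assuming a `model` is already defined and using the import path from the tests in this PR together with the defaults visible in the diff:

```python
from pytorch_optimizer.optimizer import SAM, LookSAM, load_optimizer

base_optimizer = load_optimizer('adamp')  # any supported base optimizer class

# GCSAM variant: SAM with gradient centralization applied to the ascent step.
gcsam = SAM(model.parameters(), base_optimizer, rho=0.05, use_gc=True)

# LookSAM: recomputes the SAM direction only every `k` steps and reuses it in between.
looksam = LookSAM(model.parameters(), base_optimizer, rho=0.1, k=10, alpha=0.7)
```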
8 changes: 8 additions & 0 deletions docs/changelogs/v3.4.1.md
@@ -1,5 +1,13 @@
### Change Log

### Feature

* Support `GCSAM` optimizer. (#343, #344)
    * [Gradient Centralized Sharpness Aware Minimization](https://arxiv.org/abs/2501.11584)
    * You can use it via the `SAM` optimizer by setting `use_gc=True`.
* Support `LookSAM` optimizer. (#343, #344)
    * [Towards Efficient and Scalable Sharpness-Aware Minimization](https://arxiv.org/abs/2203.02714)

### Update

* Support alternative precision training for `Shampoo` optimizer. (#339)
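The two-step training loop from the `SAM` docstring carries over unchanged; a rough sketch of the GCSAM usage described above, where `build_model`, `loader`, and `loss_fn` are placeholders rather than library APIs:

```python
from pytorch_optimizer import SAM, AdamP

model = build_model()                     # placeholder model factory
optimizer = SAM(model.parameters(), AdamP, use_gc=True, lr=1e-3)

for x, y in loader:                       # placeholder data loader
    loss_fn(model(x), y).backward()
    optimizer.first_step(zero_grad=True)  # ascend along the centralized gradient

    loss_fn(model(x), y).backward()       # second full forward-backward pass
    optimizer.second_step(zero_grad=True)
```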
4 changes: 3 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **97 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **98 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested code
* Active maintenance
@@ -204,6 +204,8 @@ get_supported_optimizers(['adam*', 'ranger*'])
| FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |
| PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | <https://arxiv.org/abs/1512.04202> | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) |
| EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
| GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
| LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |

## Supported LR Scheduler

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -240,6 +240,10 @@
    :docstring:
    :members:

::: pytorch_optimizer.LookSAM
    :docstring:
    :members:

::: pytorch_optimizer.MADGRAD
    :docstring:
    :members:
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -17,10 +17,11 @@ keywords = [
"DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "FOCUS", "Fromage", "FTRL", "GaLore", "Grams", "Gravity",
"GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG",
"Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM",
"RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam",
"SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW",
"SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice",
"LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
"RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD", "ScheduleFreeAdamW",
"ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM",
"SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine",
"SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
"QGaLore",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
1 change: 1 addition & 0 deletions pytorch_optimizer/__init__.py
@@ -119,6 +119,7 @@
LaProp,
Lion,
Lookahead,
LookSAM,
Muon,
Nero,
NovoGrad,
2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/__init__.py
@@ -79,7 +79,7 @@
from pytorch_optimizer.optimizer.ranger import Ranger
from pytorch_optimizer.optimizer.ranger21 import Ranger21
from pytorch_optimizer.optimizer.rotograd import RotoGrad
from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM, LookSAM
from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD
from pytorch_optimizer.optimizer.sgdp import SGDP
198 changes: 195 additions & 3 deletions pytorch_optimizer/optimizer/sam.py
@@ -11,6 +11,7 @@
from pytorch_optimizer.base.exception import NoClosureError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS
from pytorch_optimizer.optimizer.gc import centralize_gradient
from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats


@@ -58,6 +59,7 @@ def closure():
    :param base_optimizer: OPTIMIZER. base optimizer.
    :param rho: float. size of the neighborhood for computing the max loss.
    :param adaptive: bool. element-wise Adaptive SAM.
    :param use_gc: bool. perform gradient centralization, GCSAM variant.
    :param perturb_eps: float. eps for perturbation.
    :param kwargs: Dict. parameters for optimizer.
    """
@@ -68,12 +70,14 @@ def __init__(
        base_optimizer: OPTIMIZER,
        rho: float = 0.05,
        adaptive: bool = False,
        use_gc: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')
        self.validate_non_negative(perturb_eps, 'perturb_eps')

        self.use_gc = use_gc
        self.perturb_eps = perturb_eps

        defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
Expand All @@ -92,16 +96,20 @@ def reset(self):

@torch.no_grad()
def first_step(self, zero_grad: bool = False):
grad_norm = self.grad_norm()
grad_norm = self.grad_norm().add_(self.perturb_eps)
for group in self.param_groups:
scale = group['rho'] / (grad_norm + self.perturb_eps)
scale = group['rho'] / grad_norm

for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if self.use_gc:
centralize_gradient(grad, gc_conv_only=False)

self.state[p]['old_p'] = p.clone()
e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * scale.to(p)
e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)

p.add_(e_w)

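For readers unfamiliar with the technique, the `centralize_gradient` call above removes the mean of each gradient across its non-leading dimensions before the ascent step (the diff applies it to `p.grad` directly). A standalone illustration of the idea, not the library's exact implementation:

```python
import torch

def centralize(grad: torch.Tensor) -> torch.Tensor:
    """Subtract the mean over all non-leading dims so each output unit's gradient has zero mean."""
    if grad.dim() > 1:
        grad = grad - grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)
    return grad

g = torch.randn(8, 3, 3, 3)               # e.g. a conv weight gradient
print(centralize(g).mean(dim=(1, 2, 3)))  # ~0 for every output channel
```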
@@ -670,3 +678,187 @@ def step(self, closure: CLOSURE = None):
        self.third_step()

        return loss


class LookSAM(BaseOptimizer):
    r"""Towards Efficient and Scalable Sharpness-Aware Minimization.

    Example:
    -------
        Here's an example::

            model = YourModel()
            base_optimizer = Ranger21
            optimizer = LookSAM(model.parameters(), base_optimizer)

            for input, output in data:
                # first forward-backward pass
                loss = loss_function(output, model(input))
                loss.backward()
                optimizer.first_step(zero_grad=True)

                # second forward-backward pass
                # make sure to do a full forward pass
                loss_function(output, model(input)).backward()
                optimizer.second_step(zero_grad=True)

        Alternative example with a single closure-based step function::

            model = YourModel()
            base_optimizer = Ranger21
            optimizer = LookSAM(model.parameters(), base_optimizer)

            def closure():
                loss = loss_function(output, model(input))
                loss.backward()
                return loss

            for input, output in data:
                loss = loss_function(output, model(input))
                loss.backward()
                optimizer.step(closure)
                optimizer.zero_grad()

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param base_optimizer: OPTIMIZER. base optimizer.
    :param rho: float. size of the neighborhood for computing the max loss.
    :param k: int. lookahead step.
    :param alpha: float. lookahead blending alpha.
    :param adaptive: bool. element-wise Adaptive SAM.
    :param use_gc: bool. perform gradient centralization, GCSAM variant.
    :param perturb_eps: float. eps for perturbation.
    :param kwargs: Dict. parameters for optimizer.
    """

    def __init__(
        self,
        params: PARAMETERS,
        base_optimizer: OPTIMIZER,
        rho: float = 0.1,
        k: int = 10,
        alpha: float = 0.7,
        adaptive: bool = False,
        use_gc: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')
        self.validate_positive(k, 'k')
        self.validate_range(alpha, 'alpha', 0.0, 1.0, '()')
        self.validate_non_negative(perturb_eps, 'perturb_eps')

        self.k = k
        self.alpha = alpha
        self.use_gc = use_gc
        self.perturb_eps = perturb_eps

        defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
        defaults.update(kwargs)

        super().__init__(params, defaults)

        self.base_optimizer: Optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'LookSAM'

    @torch.no_grad()
    def reset(self):
        pass

    def get_step(self):
        return (
            self.param_groups[0]['step']
            if 'step' in self.param_groups[0]
            else next(iter(self.base_optimizer.state.values()))['step'] if self.base_optimizer.state else 0
        )

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False) -> None:
        if self.get_step() % self.k != 0:
            return

        grad_norm = self.grad_norm().add_(self.perturb_eps)
        for group in self.param_groups:
            scale = group['rho'] / grad_norm

            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                grad = p.grad
                if self.use_gc:
                    centralize_gradient(grad, gc_conv_only=False)

                self.state[p]['old_p'] = p.clone()
                self.state[f'old_grad_p_{i}']['old_grad_p'] = grad.clone()

                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)
                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        step = self.get_step()

        for group in self.param_groups:
            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                grad = p.grad
                grad_norm = grad.norm(p=2)

                if step % self.k == 0:
                    old_grad_p = self.state[f'old_grad_p_{i}']['old_grad_p']

                    g_grad_norm = old_grad_p / old_grad_p.norm(p=2)
                    g_s_grad_norm = grad / grad_norm

                    self.state[f'gv_{i}']['gv'] = torch.sub(
                        grad, grad_norm * torch.sum(g_grad_norm * g_s_grad_norm) * g_grad_norm
                    )
                else:
                    gv = self.state[f'gv_{i}']['gv']
                    grad.add_(grad_norm / (gv.norm(p=2) + 1e-8) * gv, alpha=self.alpha)

                p.data = self.state[p]['old_p']

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None):
        if closure is None:
            raise NoClosureError(str(self))

        self.first_step(zero_grad=True)

        with torch.enable_grad():
            closure()

        self.second_step()

    def grad_norm(self) -> torch.Tensor:
        shared_device = self.param_groups[0]['params'][0].device
        return torch.norm(
            torch.stack(
                [
                    ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(shared_device)
                    for group in self.param_groups
                    for p in group['params']
                    if p.grad is not None
                ]
            ),
            p=2,
        )

    def load_state_dict(self, state_dict: Dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
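A quick note on the math in `second_step` above: on every `k`-th step the gradient taken at the perturbed weights (`grad`) is decomposed against the clean gradient saved in `first_step` (`old_grad_p`), and only the component orthogonal to it is cached as `gv`; on the remaining steps the fresh gradient is nudged by `alpha * ||grad|| / ||gv|| * gv` instead of paying for a second forward-backward pass. A self-contained check of that decomposition using stand-in tensors rather than the optimizer's state:

```python
import torch

g = torch.randn(10)    # stand-in for the clean gradient stored in first_step
g_s = torch.randn(10)  # stand-in for the gradient at the perturbed weights

g_hat = g / g.norm(p=2)
gv = g_s - g_s.norm(p=2) * torch.dot(g_hat, g_s / g_s.norm(p=2)) * g_hat

print(torch.dot(gv, g_hat))  # ~0: gv keeps only the part of g_s orthogonal to g
```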
1 change: 1 addition & 0 deletions tests/constants.py
@@ -102,6 +102,7 @@
'sam',
'gsam',
'wsam',
'looksam',
'pcgrad',
'lookahead',
'trac',
7 changes: 4 additions & 3 deletions tests/test_gradients.py
@@ -2,7 +2,7 @@
import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, OrthoGrad, load_optimizer
from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, LookSAM, OrthoGrad, load_optimizer
from tests.constants import NO_SPARSE_OPTIMIZERS, SPARSE_OPTIMIZERS, VALID_OPTIMIZER_NAMES
from tests.utils import build_environment, simple_parameter, simple_sparse_parameter, sphere_loss

@@ -116,12 +116,13 @@ def test_sparse_supported(sparse_optimizer):
    optimizer.step()


def test_sam_no_gradient():
@pytest.mark.parametrize('optimizer', [SAM, LookSAM])
def test_sam_no_gradient(optimizer):
    (x_data, y_data), model, loss_fn = build_environment()
    model.fc1.weight.requires_grad = False
    model.fc1.weight.grad = None

    optimizer = SAM(model.parameters(), AdamP)
    optimizer = optimizer(model.parameters(), AdamP)
    optimizer.zero_grad()

    loss = loss_fn(y_data, model(x_data))
Expand Down
7 changes: 7 additions & 0 deletions tests/test_optimizer_parameters.py
@@ -7,6 +7,7 @@
TRAC,
WSAM,
Lookahead,
LookSAM,
OrthoGrad,
PCGrad,
Ranger21,
@@ -110,6 +111,12 @@ def test_wsam_methods():
    optimizer.load_state_dict(optimizer.state_dict())


def test_looksam_methods():
    optimizer = LookSAM([simple_parameter()], load_optimizer('adamp'))
    optimizer.reset()
    optimizer.load_state_dict(optimizer.state_dict())


def test_safe_fp16_methods():
    optimizer = SafeFP16Optimizer(load_optimizer('adamp')([simple_parameter()], lr=5e-1))
    optimizer.load_state_dict(optimizer.state_dict())