Commit 40a2e23 (1 parent: 5240c68)
Showing 17 changed files with 723 additions and 36 deletions.
@@ -0,0 +1,120 @@
/* spekin@2024
   Version 2024Dec10_v1.0
*/
#ifndef sptest_header
#define sptest_header

#include <iostream>
#include <string>
#include <vector>
#include <functional>
#include <cmath>

// ANSI escape codes for colored terminal output
#define COLOR_RED "\033[31m"
#define COLOR_GREEN "\033[32m"
#define COLOR_RESET "\033[0m"

namespace sptest
{
    struct TestCase
    {
        std::string name;
        std::function<void()> func;
    };

    // Container for registered test cases
    inline std::vector<TestCase> test_cases;

    // Counters shared by the assertion macros and run_all_tests()
    inline int tests_run = 0;
    inline int tests_failed = 0;

// Macro to define and register a test case
#define TEST_CASE(name)                                                 \
    void name();                                                        \
    bool _##name##_registered = sptest::register_test(#name, name);     \
    void name()

    // Function to register a test case
    inline bool register_test(const std::string &name, std::function<void()> func)
    {
        test_cases.push_back({name, func});
        return true;
    }
// Check a condition, record the result, and print a colored PASS/FAIL line
#define EXPECT(condition)                                                                              \
    do                                                                                                 \
    {                                                                                                  \
        sptest::tests_run++;                                                                           \
        if (!(condition))                                                                              \
        {                                                                                              \
            std::cerr << COLOR_RED << "[FAIL] " << COLOR_RESET                                         \
                      << __FILE__ << ":" << __LINE__ << " Tested [" << #condition << "]" << std::endl; \
            sptest::tests_failed++;                                                                    \
        }                                                                                              \
        else                                                                                           \
        {                                                                                              \
            std::cout << COLOR_GREEN << "[PASS] " << COLOR_RESET                                       \
                      << __FILE__ << ":" << __LINE__ << ": " << #condition << std::endl;               \
        }                                                                                              \
    } while (0)

// Record an unconditional pass with a message
#define OK(msg)                                               \
    do                                                        \
    {                                                         \
        sptest::tests_run++;                                  \
        std::cout << "just passing ... " << msg << std::endl; \
    } while (0)

// Check that a condition is true
#define EXPECT_TRUE(condition)                                                                                       \
    do                                                                                                               \
    {                                                                                                                \
        sptest::tests_run++;                                                                                         \
        if (!(condition))                                                                                            \
        {                                                                                                            \
            std::cerr << "[FAIL] " << __FILE__ << ":" << __LINE__ << ": " << #condition << " is false" << std::endl; \
            sptest::tests_failed++;                                                                                  \
        }                                                                                                            \
        else                                                                                                         \
        {                                                                                                            \
            std::cout << "[PASS] " << __FILE__ << ":" << __LINE__ << ": " << #condition << std::endl;                \
        }                                                                                                            \
    } while (0)

// Check that two values compare equal
#define EXPECT_EQ(actual, expected)                                                                                                    \
    do                                                                                                                                 \
    {                                                                                                                                  \
        sptest::tests_run++;                                                                                                           \
        if ((actual) != (expected))                                                                                                    \
        {                                                                                                                              \
            std::cerr << "[FAIL] " << __FILE__ << ":" << __LINE__ << ": Expected " << (expected) << ", got " << (actual) << std::endl; \
            sptest::tests_failed++;                                                                                                    \
        }                                                                                                                              \
        else                                                                                                                           \
        {                                                                                                                              \
            std::cout << "[PASS] " << __FILE__ << ":" << __LINE__ << ": " << (actual) << " == " << (expected) << std::endl;            \
        }                                                                                                                              \
    } while (0)

// Check that two floating-point values are within a tolerance of each other
#define EXPECT_NEAR(actual, expected, tolerance)                                \
    do                                                                          \
    {                                                                           \
        sptest::tests_run++;                                                    \
        if (std::fabs((actual) - (expected)) > (tolerance))                     \
        {                                                                       \
            std::cerr << "[FAIL] " << __FILE__ << ":" << __LINE__               \
                      << ": Expected " << (expected) << " but got " << (actual) \
                      << " (Tolerance: " << (tolerance) << ")" << std::endl;    \
            sptest::tests_failed++;                                             \
        }                                                                       \
        else                                                                    \
        {                                                                       \
            std::cout << "[PASS] " << __FILE__ << ":" << __LINE__               \
                      << ": " << (actual) << " is within " << (tolerance)       \
                      << " of " << (expected) << std::endl;                     \
        }                                                                       \
    } while (0)
    // Run every registered test case and print a summary
    inline void run_all_tests()
    {
        // Reset the shared counters so repeated runs start clean
        tests_run = 0;
        tests_failed = 0;
        int cases_run = 0;
        for (const auto &test : test_cases)
        {
            std::cout << "\nRunning test: " << test.name << std::endl;
            try
            {
                test.func();
            }
            catch (const std::exception &e)
            {
                std::cerr << "[EXCEPTION] " << e.what() << std::endl;
                tests_failed++;
            }
            cases_run++;
        }
        std::cout << "\nTest cases run: " << cases_run << std::endl;
        std::cout << "Checks run: " << tests_run << std::endl;
        std::cout << "Checks failed: " << tests_failed << std::endl;
        if (tests_failed == 0)
        {
            std::cout << "All tests passed!" << std::endl;
        }
    }
} // namespace sptest

#endif // sptest_header
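For reference, a minimal usage sketch of the header above. The file name sptest.hpp and the checked expressions are hypothetical; the commit does not show how the header is meant to be included.

// test_example.cpp: hypothetical usage of the sptest header above
#include <cmath>
#include "sptest.hpp" // assumed header name; adjust to the actual path in the repository

// Define and auto-register a test case
TEST_CASE(arithmetic_works)
{
    EXPECT(1 + 1 == 2);
    EXPECT_EQ(2 * 3, 6);
    EXPECT_NEAR(std::sqrt(2.0), 1.4142, 1e-3);
}

int main()
{
    // Runs every registered TEST_CASE and prints the pass/fail summary
    sptest::run_all_tests();
    return 0;
}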
@@ -0,0 +1,173 @@
r""" | ||
--- | ||
title: Adam Optimizer | ||
summary: A simple PyTorch implementation/tutorial of Adam optimizer | ||
--- | ||
# Adam Optimizer | ||
This is a [PyTorch](https://pytorch.org) implementation of popular optimizer *Adam* from paper | ||
[Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980). | ||
*Adam* update is, | ||
\begin{align} | ||
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\ | ||
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\ | ||
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\ | ||
\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\ | ||
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} | ||
\end{align} | ||
where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper parameters. | ||
$m_t$ and $v_t$ are first and second order moments. | ||
$\hat{m}_t$ and $\hat{v}_t$ are biased corrected moments. | ||
$\epsilon$ is used as a fix for division by zero error, but also acts as a form of a hyper-parameter | ||
that acts against variance in gradients. | ||
Effective step taken assuming $\epsilon = 0$ is, | ||
$$\Delta t = \alpha \cdot \frac{\hat{m}_t}{\hat{v}_t}$$ | ||
This is bounded by, | ||
$$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1-\beta_2}}$$ | ||
when $1-\beta_1 \gt \sqrt{1-\beta_2}$ | ||
and | ||
$$\vert \Delta t\vert \le \alpha$$ | ||
otherwise. | ||
And in most common scenarios, | ||
$$\vert \Delta t \vert \approx \alpha$$ | ||
""" | ||
import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
class Adam(GenericAdaptiveOptimizer):
    """
    ## Adam Optimizer

    We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
    to implement the Adam optimizer.
    """

    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        r"""
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag that indicates whether to optimize the bias correction of the
          second moment by doing it after adding $\epsilon$
        * `defaults` is a dictionary of defaults for group values.
          This is useful when you want to extend the class `Adam`.
        """
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        r"""
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """
        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        r"""
        ### Calculate $m_t$ and $v_t$

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        """
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Get $m_{t-1}$ and $v_{t-1}$
        m, v = state['exp_avg'], state['exp_avg_sq']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # In-place calculation of $v_t$
        # $$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v
    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        r"""
        ### Get learning rate

        This returns the modified learning rate based on the state.
        For *Adam* this is just the learning rate specified for the parameter group,
        $\alpha$.
        """
        return group['lr']

    def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):
        r"""
        ### Do the *Adam* parameter update

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        * `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$.

        This computes the following

        \begin{align}
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
        \end{align}

        Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and the others are tensors,
        we rearrange this calculation to optimize the computation.

        \begin{align}
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot
                \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon} \\
        \theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
                \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}} \\
        \end{align}

        where

        $$\hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon$$

        is what we should specify as the hyper-parameter.
        """
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

        # Get learning rate
        lr = self.get_lr(state, group)

        # Whether to optimize the computation
        if self.optimized_update:
            # $\sqrt{v_t} + \hat{\epsilon}$
            denominator = v.sqrt().add_(group['eps'])
            # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
            #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
            param.data.addcdiv_(m, denominator, value=-step_size)
        # Computation without optimization
        else:
            # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            # $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot
            #  \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
            param.data.addcdiv_(m, denominator, value=-step_size)
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        r"""
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """
        # Apply weight decay
        grad = self.weight_decay(param, grad, group)
        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)
        # Increment $t$, the number of optimizer steps
        state['step'] += 1
        # Perform the *Adam* update
        self.adam_update(state, group, param, m, v)
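As a usage illustration, here is a minimal sketch of driving this optimizer through the standard torch.optim.Optimizer interface that GenericAdaptiveOptimizer builds on. The import path, model, and data below are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch; assumes Adam behaves like a regular torch.optim.Optimizer
import torch
from torch import nn
from labml_nn.optimizers.adam import Adam  # assumed module path for the file above

model = nn.Linear(10, 1)  # placeholder model
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

x = torch.randn(32, 10)   # placeholder batch
y = torch.randn(32, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()          # updates each parameter via step_param()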