Skip to content

Commit

Permalink
store alpha_log_scale in log space
Browse files Browse the repository at this point in the history
  • Loading branch information
proger committed Mar 22, 2024
1 parent 7bf5732 commit 92c94f5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
6 changes: 3 additions & 3 deletions hippogriff.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, *, dim=1024, expansion_factor=1.5, kernel_size=4):
self.gates = nn.Linear(hidden, 2*hidden, bias=True)
self.forget_base = nn.Parameter(torch.linspace(-4.323, -9, hidden))
self.output = nn.Linear(hidden, dim, bias=False)
- self.alpha_log_scale = nn.Parameter(-8 * torch.ones(1), requires_grad=False)
+ self.alpha_log_scale = nn.Parameter(torch.tensor([8]).log(), requires_grad=False)

with torch.no_grad():
self.input.weight.normal_(std=dim**-0.5)
Expand All @@ -58,8 +58,8 @@ def forward(self, x):

# RG-LRU: linear recurrent unit with input-dependent gating
forget, input = self.gates(x).chunk(2, dim=-1)
- alpha = (self.alpha_log_scale * softplus(self.forget_base) * forget.sigmoid()).exp()
- beta = (1 - alpha**2 + 1e-6).sqrt()
+ alpha = (-self.alpha_log_scale.exp() * softplus(self.forget_base) * forget.sigmoid()).exp()
+ beta = (1 - alpha**2 + 1e-6).sqrt() # stabilizes variance
x = beta * input.sigmoid() * x

h = scan(alpha.mT.contiguous(), x.mT.contiguous()).mT
Expand Down
5 changes: 3 additions & 2 deletions sweeps/alpha_log_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
import torch
import wandb
+ import math

from train import train, Tapes, parser, device
from hippogriff import GriffinLM, GriffinConfig
Expand All @@ -21,7 +22,7 @@ def make_model(alpha_log_scale, vocab_size=16384, device='cuda'):
if alpha_log_scale == 'learn':
param.requires_grad = True
else:
- param.data.fill_(alpha_log_scale)
+ param.data.fill_(math.log(alpha_log_scale))
param.requires_grad = False

return model
Expand All @@ -47,7 +48,7 @@ def run():
"method": "grid",
"metric": {"goal": "minimize", "name": "eval/loss"},
"parameters": {
- "alpha_log_scale": {"values": ["learn", -8, -7, -6, -5, -4, -3, -2, -1]},
+ "alpha_log_scale": {"values": ["learn", 14, 8, 4]},
},
}

Expand Down

0 comments on commit 92c94f5

Please sign in to comment.