diff --git a/hippogriff.py b/hippogriff.py
index 0f7d915..04de67a 100644
--- a/hippogriff.py
+++ b/hippogriff.py
@@ -44,7 +44,7 @@ def __init__(self, *, dim=1024, expansion_factor=1.5, kernel_size=4):
         self.gates = nn.Linear(hidden, 2*hidden, bias=True)
         self.forget_base = nn.Parameter(torch.linspace(-4.323, -9, hidden))
         self.output = nn.Linear(hidden, dim, bias=False)
-        self.alpha_log_scale = nn.Parameter(-8 * torch.ones(1), requires_grad=False)
+        self.alpha_log_scale = nn.Parameter(torch.tensor([8]).log(), requires_grad=False)
 
         with torch.no_grad():
             self.input.weight.normal_(std=dim**-0.5)
@@ -58,8 +58,8 @@ def forward(self, x):
 
         # RG-LRU: linear recurrent unit with input-dependent gating
         forget, input = self.gates(x).chunk(2, dim=-1)
-        alpha = (self.alpha_log_scale * softplus(self.forget_base) * forget.sigmoid()).exp()
-        beta = (1 - alpha**2 + 1e-6).sqrt()
+        alpha = (-self.alpha_log_scale.exp() * softplus(self.forget_base) * forget.sigmoid()).exp()
+        beta = (1 - alpha**2 + 1e-6).sqrt() # stabilizes variance
         x = beta * input.sigmoid() * x
 
         h = scan(alpha.mT.contiguous(), x.mT.contiguous()).mT
diff --git a/sweeps/alpha_log_scale.py b/sweeps/alpha_log_scale.py
index ed0ef42..23ac21a 100644
--- a/sweeps/alpha_log_scale.py
+++ b/sweeps/alpha_log_scale.py
@@ -4,6 +4,7 @@ from pathlib import Path
 
 import torch
 import wandb
+import math
 
 from train import train, Tapes, parser, device
 from hippogriff import GriffinLM, GriffinConfig
@@ -21,7 +22,7 @@ def make_model(alpha_log_scale, vocab_size=16384, device='cuda'):
         if alpha_log_scale == 'learn':
             param.requires_grad = True
         else:
-            param.data.fill_(alpha_log_scale)
+            param.data.fill_(math.log(alpha_log_scale))
             param.requires_grad = False
     return model
 
@@ -47,7 +48,7 @@ def run():
         "method": "grid",
         "metric": {"goal": "minimize", "name": "eval/loss"},
         "parameters": {
-            "alpha_log_scale": {"values": ["learn", -8, -7, -6, -5, -4, -3, -2, -1]},
+            "alpha_log_scale": {"values": ["learn", 14, 8, 4]},
         },
     },
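
Note on the reparameterization (not part of the diff): alpha_log_scale now stores log(c) rather than the negative scale -c directly, and the forward pass recovers the same multiplier via -alpha_log_scale.exp(), so the sweep can pass the raw positive scale c and make_model stores math.log(c). A minimal standalone sketch, assuming the default scale c = 8 and a hidden size of 16, checking that both parameterizations produce identical alpha:

import torch
from torch.nn.functional import softplus

# hypothetical shapes for illustration only
forget_base = torch.linspace(-4.323, -9, 16)
forget = torch.randn(2, 16)

old_scale = -8 * torch.ones(1)          # old: parameter holds -c directly
new_scale = torch.tensor([8.0]).log()   # new: parameter holds log(c)

alpha_old = (old_scale * softplus(forget_base) * forget.sigmoid()).exp()
alpha_new = (-new_scale.exp() * softplus(forget_base) * forget.sigmoid()).exp()

assert torch.allclose(alpha_old, alpha_new)  # -exp(log 8) == -8

Under this mapping the new sweep values 14, 8, 4 correspond to effective multipliers of -14, -8, -4, with -8 reproducing the previous default.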