hawk: make conv optional and restore proper initialization, close #5
proger committed Apr 11, 2024
1 parent 02bd37f commit d4d3398
Showing 1 changed file with 16 additions and 7 deletions.
hippogriff.py: 23 changes (16 additions & 7 deletions)
@@ -20,7 +20,7 @@ class GriffinConfig:
     smqa_kv_heads: int = 1
     smqa_window_size: int = 512
     hawk_expansion_factor: float = 1.5
-    hawk_kernel_size: int = 4
+    conv_kernel_size: int = 4
     time_module: Literal['TiedQuasiLSTM', 'Hawk'] = 'Hawk'
     tied_quasi_lstm_num_heads: int = 16
     gmlp_expansion_factor: float = 2
@@ -61,14 +61,22 @@ def forward(self, x):
 
 
 class Hawk(nn.Module):
-    def __init__(self, *, dim=1024, expansion_factor=1.5, kernel_size=4):
+    def __init__(self, *, dim=1024, expansion_factor=1.5, conv_kernel_size=4):
         super().__init__()
         hidden = int(dim * expansion_factor)
         self.input = nn.Linear(dim, 2*hidden, bias=False)
-        self.conv = nn.Conv1d(in_channels=hidden, out_channels=hidden, bias=True,
-                              kernel_size=kernel_size, groups=hidden, padding=kernel_size-1)
+        if conv_kernel_size:
+            self.conv = nn.Conv1d(in_channels=hidden, out_channels=hidden, bias=True,
+                                  kernel_size=conv_kernel_size, groups=hidden, padding=conv_kernel_size-1)
+        else:
+            self.conv = None
         self.gates = nn.Linear(hidden, 2*hidden, bias=True)
-        self.forget_base = nn.Parameter(torch.linspace(-4.323, -9, hidden))
+        def mk(hidden, a=0.001, b=0.1, lo=-4.323, hi=-9):
+            x = torch.log(torch.expm1(torch.linspace(a, b, hidden)))
+            x = (x - x.min()) / (x.max() - x.min())
+            x = x * abs(hi-lo) + hi
+            return x
+        self.forget_base = nn.Parameter(mk(hidden))
         self.output = nn.Linear(hidden, dim, bias=False)
         self.alpha_log_scale = nn.Parameter(torch.tensor([8]).log(), requires_grad=False)
 
@@ -80,7 +88,8 @@ def __init__(self, *, dim=1024, expansion_factor=1.5, kernel_size=4):
     def forward(self, x):
         _N, T, _C = x.shape
         gate, x = self.input(x).chunk(2, dim=-1)
-        x = self.conv(x.mT)[..., :T].mT
+        if self.conv is not None:
+            x = self.conv(x.mT)[..., :T].mT
 
         # RG-LRU: linear recurrent unit with input-dependent gating
         forget, input = self.gates(x).chunk(2, dim=-1)
@@ -152,7 +161,7 @@ def __init__(self, config: GriffinConfig):
             case 'TiedQuasiLSTM':
                 self.time = TiedQuasiLSTM(dim=config.dim, num_heads=config.tied_quasi_lstm_num_heads)
             case 'Hawk':
-                self.time = Hawk(dim=config.dim, expansion_factor=config.hawk_expansion_factor, kernel_size=config.hawk_kernel_size)
+                self.time = Hawk(dim=config.dim, expansion_factor=config.hawk_expansion_factor, conv_kernel_size=config.conv_kernel_size)
         self.gmlp_norm = RMSNorm(dim=config.dim)
         self.gmlp = GatedMLP(dim=config.dim, expansion_factor=config.gmlp_expansion_factor)
 
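
For context, a minimal usage sketch of the new option (assuming PyTorch is installed and hippogriff.py from this repository is importable; the dim value below is arbitrary): a falsy conv_kernel_size leaves self.conv as None, and forward() then skips the depthwise convolution. The renamed GriffinConfig.conv_kernel_size field reaches Hawk through the 'Hawk' case in the last hunk.

from hippogriff import Hawk

# conv_kernel_size=0 is falsy, so __init__ sets self.conv = None
hawk = Hawk(dim=64, expansion_factor=1.5, conv_kernel_size=0)
assert hawk.conv is None

# the default still builds the depthwise conv as before
hawk_conv = Hawk(dim=64, expansion_factor=1.5, conv_kernel_size=4)
assert hawk_conv.conv is not None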

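And a standalone check of the restored forget_base initialization, with the mk helper copied from the hunk above (hoisted out of __init__ only so it can be run on its own): it keeps the endpoints of the previous torch.linspace(-4.323, -9, hidden) initialization but spaces the values as the inverse softplus of an even grid on [0.001, 0.1], rescaled into that range.

import torch

def mk(hidden, a=0.001, b=0.1, lo=-4.323, hi=-9):
    x = torch.log(torch.expm1(torch.linspace(a, b, hidden)))  # inverse softplus of an even grid on [a, b]
    x = (x - x.min()) / (x.max() - x.min())                   # normalize to [0, 1]
    x = x * abs(hi - lo) + hi                                  # rescale so the values span [-9, -4.323]
    return x

vals = mk(8)
print(vals.min().item(), vals.max().item())  # approximately -9.0 and -4.323, same endpoints as before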