bet on new attention stabilizing technique circulating within brain
lucidrains committed Feb 13, 2023
1 parent 5912e9d commit a5476c6
Showing 3 changed files with 25 additions and 6 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -283,3 +283,12 @@ images # List[PIL.Image.Image]
year = {2021}
}
```
+
+```bibtex
+@misc{gilmer2023intriguing,
+    title   = {Intriguing Properties of Transformer Training Instabilities},
+    author  = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
+    year    = {2023},
+    status  = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
+}
+```
20 changes: 15 additions & 5 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -36,6 +36,9 @@ def inner(model, *args, **kwargs):
        return out
    return inner

+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
# tensor helpers

def get_mask_subset_prob(mask, prob, min_mask = 0):
@@ -89,10 +92,11 @@ def __init__(
        dim,
        dim_head = 64,
        heads = 8,
-        cross_attend = False
+        cross_attend = False,
+        scale = 8
    ):
        super().__init__()
-        self.scale = dim_head ** -0.5
+        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads

@@ -103,6 +107,10 @@ def __init__(

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+
+        self.q_scale = nn.Parameter(torch.ones(dim_head))
+        self.k_scale = nn.Parameter(torch.ones(dim_head))
+
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
@@ -121,8 +129,6 @@ def forward(

        q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))

-        q = q * self.scale
-
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        nk, nv = self.null_kv
@@ -131,7 +137,11 @@
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+        q, k = map(l2norm, (q, k))
+        q = q * self.q_scale
+        k = k * self.k_scale
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
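The net effect of the diff above is to swap standard scaled dot-product attention for an l2-normalized ("QK-norm") variant: queries and keys are unit-normalized along the head dimension, rescaled by learned per-dimension gains (`q_scale`, `k_scale`), and the logits use a fixed scale (8) instead of `dim_head ** -0.5`, so the softmax inputs become bounded cosine similarities rather than unbounded dot products. Below is a minimal self-contained sketch of just that mechanism; the `QKNormAttention` name, the self-attention-only signature, and the shapes are illustrative assumptions, not code from this repository:

```python
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange

class QKNormAttention(nn.Module):
    # hypothetical standalone module isolating the stabilization used in the diff above
    def __init__(self, dim, dim_head = 64, heads = 8, scale = 8):
        super().__init__()
        self.heads = heads
        self.scale = scale  # fixed logit scale, replacing dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # learned per-dimension gains applied after l2-normalization
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # unit-normalize queries and keys, then rescale - every logit is now a
        # scaled cosine similarity, so attention scores cannot grow without bound
        q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# usage sketch: tokens of shape (batch, seq, dim)
# attn = QKNormAttention(dim = 512)
# out = attn(torch.randn(2, 256, 512))  # -> (2, 256, 512)
```

Ignoring the learned gains, each logit lies in roughly [-scale, scale], which keeps the softmax away from the saturated regimes associated with attention training instabilities.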
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'muse-maskgit-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.0.27',
+  version = '0.1.0',
  license='MIT',
  description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
  author = 'Phil Wang',
