From a5476c6573df3c8b3488674d9aa92b4f9a4b434e Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 13 Feb 2023 12:37:05 -0800
Subject: [PATCH] bet on new attention stabilizing technique circulating within brain

---
 README.md                                    |  9 +++++++++
 muse_maskgit_pytorch/muse_maskgit_pytorch.py | 20 +++++++++++++++-----
 setup.py                                     |  2 +-
 3 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 5261a41..96388a9 100644
--- a/README.md
+++ b/README.md
@@ -283,3 +283,12 @@ images # List[PIL.Image.Image]
     year    = {2021}
 }
 ```
+
+```bibtex
+@misc{gilmer2023intriguing,
+    title   = {Intriguing Properties of Transformer Training Instabilities},
+    author  = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
+    year    = {2023},
+    status  = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
+}
+```
diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index d3f1fa6..3bf3226 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -36,6 +36,9 @@ def inner(model, *args, **kwargs):
         return out
     return inner
 
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
 # tensor helpers
 
 def get_mask_subset_prob(mask, prob, min_mask = 0):
@@ -89,10 +92,11 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
-        cross_attend = False
+        cross_attend = False,
+        scale = 8
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5
+        self.scale = scale
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -103,6 +107,10 @@ def __init__(
 
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+
+        self.q_scale = nn.Parameter(torch.ones(dim_head))
+        self.k_scale = nn.Parameter(torch.ones(dim_head))
+
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
     def forward(
@@ -121,8 +129,6 @@ def forward(
 
         q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
 
-        q = q * self.scale
-
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
         nk, nv = self.null_kv
@@ -131,7 +137,11 @@ def forward(
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+        q, k = map(l2norm, (q, k))
+        q = q * self.q_scale
+        k = k * self.k_scale
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         if exists(context_mask):
             context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
diff --git a/setup.py b/setup.py
index a461e2f..91ba65f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'muse-maskgit-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.27',
+  version = '0.1.0',
   license='MIT',
   description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
   author = 'Phil Wang',
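Note: the patch above replaces the usual `q * dim_head ** -0.5` scaling with l2-normalized queries and keys, learned per-dimension `q_scale` / `k_scale` parameters, and a fixed post-similarity temperature (`scale = 8`), i.e. a cosine-similarity / QK-norm style attention. The following is a minimal, self-contained sketch of that pattern outside the repository; the class name `QKNormAttention`, the plain-PyTorch head reshaping (no einops), and the toy shapes are illustrative assumptions, not part of the patch.

```python
# Sketch only: qk-l2norm attention with learned per-dimension scales and a
# fixed temperature, mirroring the change in the diff above. Names and shapes
# here are illustrative, not taken from muse_maskgit_pytorch.
import torch
import torch.nn.functional as F
from torch import nn, einsum

class QKNormAttention(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8, scale = 8):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = scale  # fixed temperature, replaces dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # learned per-dimension scales applied after l2-normalizing q and k
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = -1)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = (t.reshape(b, n, h, -1).transpose(1, 2) for t in (q, k, v))

        # l2-normalize queries and keys, then rescale with learned parameters
        q = F.normalize(q, dim = -1) * self.q_scale
        k = F.normalize(k, dim = -1) * self.k_scale

        # cosine-similarity logits, sharpened by the fixed temperature
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

# toy usage
attn = QKNormAttention(dim = 512)
x = torch.randn(1, 16, 512)
out = attn(x)  # (1, 16, 512)
```

Because both q and k are unit-normalized, the attention logits are bounded by `scale` regardless of how large the projections grow, which is the stabilizing property the commit message refers to.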