xformers attn working, so long as length of mask plus null token is a multiple of 8
Birch-san authored and korakoe committed Jun 12, 2023
1 parent 5246f21 commit 000c2da
Showing 5 changed files with 89 additions and 9 deletions.
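
The constraint in the commit title comes from xformers' memory_efficient_attention: the additive attention bias built from the context mask has to be 8-aligned along the key dimension, which the test below sidesteps by using 15 text tokens, so that mask plus null token comes to 16. A commonly used workaround for arbitrary mask lengths, not part of this commit and sketched here on the assumption that the kernel only needs each bias row to start at an 8-element-aligned offset (the helper name and fill value are illustrative), is to allocate the bias padded up to the next multiple of 8 and then slice it back:

import torch
from torch import BoolTensor, FloatTensor

def aligned_attn_bias(mask: BoolTensor, heads: int, q_len: int, neg: float = -10000.) -> FloatTensor:
    # mask: (batch, kv_len) bools, already including the always-True null-token column
    b, j = mask.shape
    j_pad = -(-j // 8) * 8  # round kv_len up to the next multiple of 8
    bias = torch.full((b, heads, q_len, j_pad), neg, device=mask.device)  # dtype should match q/k/v
    bias[..., :j] = torch.where(mask[:, None, None, :], 0., neg)
    # slice back to the true kv_len; the padded storage keeps each bias row 8-element aligned
    return bias[..., :j]

With a helper like this, the bias for any text length could be handed to memory_efficient_attention without requiring the token count itself to land on a multiple of 8.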
2 changes: 1 addition & 1 deletion .vscode/launch.json
@@ -5,7 +5,7 @@
"type": "python",
"request": "launch",
"module": "attn_test",
"justMyCode": true
"justMyCode": false
}
]
}
1 change: 0 additions & 1 deletion attn/sdp_attn.py
@@ -17,7 +17,6 @@ def __init__(self, dim):
    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# TODO: change this to use torch sdp attn
class Attention(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
        super().__init__()
74 changes: 74 additions & 0 deletions attn/xformers_attn.py
@@ -0,0 +1,74 @@
from torch import nn, FloatTensor
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from typing import Optional
from xformers.ops import memory_efficient_attention

def l2norm(t):
    return F.normalize(t, dim=-1)

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

class Attention(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads

        self.cross_attend = cross_attend
        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head))

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)

        # fold the target scale into q: q_scale starts at scale / typical_scale
        typical_scale = dim_head ** -.5
        scale_ratio = scale / typical_scale
        self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x: FloatTensor, context: Optional[FloatTensor] = None, context_mask=None):
        assert (context is None) != self.cross_attend

        h = self.heads
        x = self.norm(x)

        kv_input = context if self.cross_attend else x

        q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1))

        # xformers expects (batch, seq_len, heads, dim_head)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q, k, v))

        # prepend a learned null key/value, so every query has at least one unmasked token to attend to
        nk, nv = self.null_kv
        nk, nv = map(lambda t: repeat(t, "h 1 d -> b 1 h d", b=x.shape[0]), (nk, nv))

        k = torch.cat((nk, k), dim=-3)
        v = torch.cat((nv, v), dim=-3)

        # cosine-sim attention: l2-normalize q and k, then apply learned per-channel scales
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        if context_mask is None:
            attn_bias = None
        else:
            # the prepended null token is always attendable
            context_mask = F.pad(context_mask, (1, 0), value=True)
            context_mask = rearrange(context_mask, "b j -> b 1 1 j")
            # additive bias: 0 where attendable, a large negative where masked.
            # note: key length (mask + null token) must be a multiple of 8 for the xformers kernel.
            attn_bias = torch.where(context_mask == True, 0., -10000.)
            attn_bias = attn_bias.expand(-1, h, q.size(1), -1)

        out: FloatTensor = memory_efficient_attention(q, k, v, attn_bias)

        out = rearrange(out, "b n h d -> b n (h d)")
        return self.to_out(out)
18 changes: 12 additions & 6 deletions attn_test.py
@@ -1,7 +1,9 @@
from attn.ein_attn import Attention as EinAttn
from attn.sdp_attn import Attention as SDPAttn
from attn.xformers_attn import Attention as XformersAttn
import torch
from torch import FloatTensor, BoolTensor, manual_seed, randn, arange, allclose, no_grad
from torch.backends.cuda import sdp_kernel

device = torch.device('cuda')
dtype = torch.float32
@@ -23,8 +25,10 @@
# seed RNG before we initialize any layers, so that both will end up with same params
manual_seed(seed)
ein_attn = EinAttn(**attn_init_params).to(device, dtype).eval()
# manual_seed(seed)
# sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval()
manual_seed(seed)
sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval()
xfo_attn = XformersAttn(**attn_init_params).to(device, dtype).eval()

batch_size = 2

@@ -34,22 +38,24 @@
# generate rand on-CPU for cross-platform determinism of results
x: FloatTensor = randn(batch_size, vision_tokens, vision_dim, dtype=dtype).to(device)

text_tokens = 16 # CLIP would be 77
text_tokens = 15 # CLIP would be 77
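# 15 text tokens + 1 null token = 16, a multiple of 8: the alignment the xformers attn bias currently needs (see commit message)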
# there's no reason why these would **have** to be the same (in stable-diffusion text_dim is 768)
# but lucid didn't expose any separate param for customizing the cross attention input dim.
# easily fixed, but whatever I'll work with what's there.
text_dim = vision_dim
context: FloatTensor = randn(batch_size, text_tokens, text_dim, dtype=dtype).to(device)

# attend to just the first two tokens in each text condition (e.g. if both were uncond, so [BOS, EOS] followed by PAD tokens)
context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1)
# context_mask = None
context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1).contiguous()

ein_result: FloatTensor = ein_attn.forward(x, context, context_mask)
sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask)
# with sdp_kernel(enable_math=False):
# sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask)
xfo_result: FloatTensor = xfo_attn.forward(x, context, context_mask)

# default relative and absolute tolerance
rtol=1e-5
atol=5e-7
assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
# assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
assert allclose(ein_result, xfo_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
print(f'attention implementations returned equivalent result, to tolerance rtol={rtol}, atol={atol}')
3 changes: 2 additions & 1 deletion setup.py
@@ -39,7 +39,8 @@
"tqdm-loggable",
"vector-quantize-pytorch>=0.10.14",
"lion-pytorch",
"omegaconf"
"omegaconf",
"xformers>=0.0.20",
],
classifiers=[
"Development Status :: 4 - Beta",
