sdp attn
Birch-san authored and korakoe committed Jun 12, 2023
1 parent 9a288b7 commit 5246f21
Showing 2 changed files with 11 additions and 18 deletions.
26 changes: 9 additions & 17 deletions attn/sdp_attn.py
@@ -1,11 +1,9 @@
-from torch import einsum, nn
+from torch import einsum, nn, FloatTensor
 import torch
 import torch.nn.functional as F
+from torch.nn.functional import scaled_dot_product_attention
 from einops import rearrange, repeat
-
-# helpers
-def exists(val):
-    return val is not None
+from typing import Optional
 
 def l2norm(t):
     return F.normalize(t, dim=-1)
@@ -34,15 +32,15 @@ def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
         self.to_q = nn.Linear(dim, inner_dim, bias=False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
 
-        self.typical_scale = dim_head ** -.5
-        scale_ratio = scale/self.typical_scale
+        typical_scale = dim_head ** -.5
+        scale_ratio = scale/typical_scale
         self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio))
         self.k_scale = nn.Parameter(torch.ones(dim_head))
 
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
 
-    def forward(self, x, context=None, context_mask=None):
-        assert not (exists(context) ^ self.cross_attend)
+    def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask=None):
+        assert (context is None) != self.cross_attend
 
         h = self.heads
         x = self.norm(x)
@@ -63,17 +61,11 @@ def forward(self, x, context=None, context_mask=None):
         q = q * self.q_scale
         k = k * self.k_scale
 
-        sim = q @ k.transpose(-2, -1) * self.typical_scale
-
-        if exists(context_mask):
+        if context_mask is not None:
            context_mask = rearrange(context_mask, "b j -> b 1 1 j")
            context_mask = F.pad(context_mask, (1, 0), value=True)
 
-            mask_value = -torch.finfo(sim.dtype).max
-            sim = sim.masked_fill(~context_mask, mask_value)
-
-        attn = sim.softmax(dim=-1)
-        out = attn @ v
+        out: FloatTensor = scaled_dot_product_attention(q, k, v, context_mask)
 
         out = rearrange(out, "b h n d -> b n (h d)")
         return self.to_out(out)
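
The substantive change in this file is that the hand-rolled masked softmax is replaced by `torch.nn.functional.scaled_dot_product_attention`, which accepts a boolean `attn_mask` (True means "attend") and applies the `1/sqrt(dim_head)` scale internally. That is also why `typical_scale` no longer needs to live on the module: it is folded into the initialization of `q_scale`, so the pre-scaled queries combined with SDPA's built-in scaling reproduce the intended logit scale of `scale`. A minimal sketch of the equivalence, with illustrative shapes and the learned per-channel scales taken at their initialization values (none of the names below come from the repository):

```python
import torch
import torch.nn.functional as F

b, h, n, j, d = 2, 8, 16, 24, 64   # batch, heads, query len, key len, head dim (illustrative)
scale = 8                          # the module's fixed logit scale
typical_scale = d ** -0.5          # the factor SDPA applies internally

# Mirror the module at initialization: l2-normalized q and k, with q pre-scaled by
# scale / typical_scale (q_scale's init) so SDPA's 1/sqrt(d) leaves a net scale of `scale`.
q = F.normalize(torch.randn(b, h, n, d), dim=-1) * (scale / typical_scale)
k = F.normalize(torch.randn(b, h, j, d), dim=-1)
v = torch.randn(b, h, j, d)

mask = torch.rand(b, 1, 1, j) > 0.3   # boolean mask, True = attend; broadcast over heads and queries
mask[..., :1] = True                  # keep at least one visible key per row

# What the removed lines computed.
sim = q @ k.transpose(-2, -1) * typical_scale
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
manual = sim.softmax(dim=-1) @ v

# What the new code computes instead.
sdp = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print((manual - sdp).abs().max())                        # small fp32 drift is expected
print(torch.allclose(manual, sdp, rtol=1e-5, atol=5e-7))
```

Because the mask is broadcast from shape `b 1 1 j`, the single padded key mask produced by the `rearrange`/`F.pad` lines above serves every head and every query position.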
3 changes: 2 additions & 1 deletion attn_test.py
@@ -43,12 +43,13 @@
 
 # attend to just the first two tokens in each text condition (e.g. if both were uncond, so [BOS, EOS] followed by PAD tokens)
 context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1)
+# context_mask = None
 
 ein_result: FloatTensor = ein_attn.forward(x, context, context_mask)
 sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask)
 
 # default relative and absolute tolerance
 rtol=1e-5
-atol=1e-8
+atol=5e-7
 assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
 print(f'attention implementations returned equivalent result, to tolerance rtol={rtol}, atol={atol}')

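The loosened absolute tolerance (1e-8 → 5e-7) presumably accounts for SDPA dispatching to a fused kernel whose reduction order differs from the einsum reference, so the two modules agree only to float32 round-off rather than bit-for-bit. To get a feel for the size of that backend-to-backend drift, PyTorch 2.0's `torch.backends.cuda.sdp_kernel` context manager can pin SDPA to its plain math implementation for comparison; a hedged sketch, assuming a CUDA build (the shapes are illustrative, and on CPU both calls take the same path so the difference collapses to zero):

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

# Illustrative shapes only; not part of the repository's test.
q, k, v = (torch.randn(2, 8, 77, 64, device="cuda") for _ in range(3))

# Force the reference math implementation...
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    reference = scaled_dot_product_attention(q, k, v)

# ...then let SDPA choose its default (fused) kernel.
fused = scaled_dot_product_attention(q, k, v)

print((reference - fused).abs().max())  # typically on the order of 1e-7 in fp32, comparable to the relaxed atol above
```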