From 000c2da73d5de62c7a8859ee56354b169c1c88ca Mon Sep 17 00:00:00 2001
From: Alex Birch
Date: Sat, 10 Jun 2023 15:19:37 +0100
Subject: [PATCH] xformers attn working, so long as length of mask plus null token is a multiple of 8

---
 .vscode/launch.json   |  2 +-
 attn/sdp_attn.py      |  1 -
 attn/xformers_attn.py | 74 +++++++++++++++++++++++++++++++++++++++++++
 attn_test.py          | 18 +++++++----
 setup.py              |  3 +-
 5 files changed, 89 insertions(+), 9 deletions(-)
 create mode 100644 attn/xformers_attn.py

diff --git a/.vscode/launch.json b/.vscode/launch.json
index c7264c4..2537580 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -5,7 +5,7 @@
             "type": "python",
             "request": "launch",
             "module": "attn_test",
-            "justMyCode": true
+            "justMyCode": false
         }
     ]
 }
\ No newline at end of file
diff --git a/attn/sdp_attn.py b/attn/sdp_attn.py
index 97d76c1..da65da6 100644
--- a/attn/sdp_attn.py
+++ b/attn/sdp_attn.py
@@ -17,7 +17,6 @@ def __init__(self, dim):
     def forward(self, x):
         return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
 
-# TODO: change this to use torch sdp attn
 class Attention(nn.Module):
     def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
         super().__init__()
diff --git a/attn/xformers_attn.py b/attn/xformers_attn.py
new file mode 100644
index 0000000..c7f1422
--- /dev/null
+++ b/attn/xformers_attn.py
@@ -0,0 +1,74 @@
+from torch import nn, FloatTensor
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from typing import Optional
+from xformers.ops import memory_efficient_attention
+
+def l2norm(t):
+    return F.normalize(t, dim=-1)
+
+class LayerNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
+
+    def forward(self, x):
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
+class Attention(nn.Module):
+    def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
+        super().__init__()
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.cross_attend = cross_attend
+        self.norm = LayerNorm(dim)
+
+        self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head))
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+
+        typical_scale = dim_head ** -.5
+        scale_ratio = scale/typical_scale
+        self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio))
+        self.k_scale = nn.Parameter(torch.ones(dim_head))
+
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x: FloatTensor, context: Optional[FloatTensor]=None, context_mask=None):
+        assert (context is None) != self.cross_attend
+
+        h = self.heads
+        x = self.norm(x)
+
+        kv_input = context if self.cross_attend else x
+
+        q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1))
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q, k, v))
+
+        nk, nv = self.null_kv
+        nk, nv = map(lambda t: repeat(t, "h 1 d -> b 1 h d", b=x.shape[0]), (nk, nv))
+
+        k = torch.cat((nk, k), dim=-3)
+        v = torch.cat((nv, v), dim=-3)
+
+        q, k = map(l2norm, (q, k))
+        q = q * self.q_scale
+        k = k * self.k_scale
+
+        if context_mask is None:
+            attn_bias = None
+        else:
+            context_mask = F.pad(context_mask, (1, 0), value=True)
+            context_mask = rearrange(context_mask, "b j -> b 1 1 j")
+            attn_bias = torch.where(context_mask == True, 0., -10000.)
+            attn_bias = attn_bias.expand(-1, h, q.size(1), -1)
+
+        out: FloatTensor = memory_efficient_attention(q, k, v, attn_bias)
+
+        out = rearrange(out, "b n h d -> b n (h d)")
+        return self.to_out(out)
\ No newline at end of file
diff --git a/attn_test.py b/attn_test.py
index e99177e..14b2475 100644
--- a/attn_test.py
+++ b/attn_test.py
@@ -1,7 +1,9 @@
 from attn.ein_attn import Attention as EinAttn
 from attn.sdp_attn import Attention as SDPAttn
+from attn.xformers_attn import Attention as XformersAttn
 import torch
 from torch import FloatTensor, BoolTensor, manual_seed, randn, arange, allclose, no_grad
+from torch.backends.cuda import sdp_kernel
 
 device = torch.device('cuda')
 dtype = torch.float32
@@ -23,8 +25,10 @@
     # seed RNG before we initialize any layers, so that both will end up with same params
     manual_seed(seed)
     ein_attn = EinAttn(**attn_init_params).to(device, dtype).eval()
+    # manual_seed(seed)
+    # sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval()
     manual_seed(seed)
-    sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval()
+    xfo_attn = XformersAttn(**attn_init_params).to(device, dtype).eval()
 
     batch_size = 2
 
@@ -34,7 +38,7 @@
     # generate rand on-CPU for cross-platform determinism of results
     x: FloatTensor = randn(batch_size, vision_tokens, vision_dim, dtype=dtype).to(device)
 
-    text_tokens = 16 # CLIP would be 77
+    text_tokens = 15 # CLIP would be 77
     # there's no reason why these would **have** to be the same (in stable-diffusion text_dim is 768)
     # but lucid didn't expose any separate param for customizing the cross attention input dim.
     # easily fixed, but whatever I'll work with what's there.
@@ -42,14 +46,16 @@
     context: FloatTensor = randn(batch_size, text_tokens, text_dim, dtype=dtype).to(device)
 
     # attend to just the first two tokens in each text condition (e.g. if both were uncond, so [BOS, EOS] followed by PAD tokens)
-    context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1)
-    # context_mask = None
+    context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1).contiguous()
 
     ein_result: FloatTensor = ein_attn.forward(x, context, context_mask)
-    sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask)
+    # with sdp_kernel(enable_math=False):
+    #     sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask)
+    xfo_attn: FloatTensor = xfo_attn.forward(x, context, context_mask)
 
     # default relative and absolute tolerance
     rtol=1e-5
     atol=5e-7
-    assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
+    # assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
+    assert allclose(ein_result, xfo_attn, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
     print(f'attention implementations returned equivalent result, to tolerance rtol={rtol}, atol={atol}')
\ No newline at end of file
diff --git a/setup.py b/setup.py
index f55dfee..c89a3d6 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,8 @@
         "tqdm-loggable",
         "vector-quantize-pytorch>=0.10.14",
         "lion-pytorch",
-        "omegaconf"
+        "omegaconf",
+        "xformers>=0.0.20",
     ],
     classifiers=[
         "Development Status :: 4 - Beta",
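Note on the subject-line caveat: the "multiple of 8" restriction is why the test switches to text_tokens = 15; after F.pad prepends the null token, the key/bias length becomes 16. The xformers memory_efficient_attention kernels (in the version pinned here, >=0.0.20) are picky about the alignment of a materialized attn_bias along its key dimension. A commonly used workaround is to allocate the bias with its last dimension rounded up to a multiple of 8 and then slice back to the true key length: the shape is unchanged, but the underlying storage becomes aligned. The sketch below is illustrative only and not part of the patch; the helper name materialize_aligned_bias is made up, and it assumes an additive float bias shaped (batch, heads, q_len, k_len) like the one built in attn/xformers_attn.py.

    import torch

    def materialize_aligned_bias(attn_bias: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper, not part of the patch: copy the bias into a buffer
        # whose last dimension is rounded up to a multiple of 8, then slice back
        # to the true key length. The returned view keeps the original shape, but
        # its storage (stride along dim -2) is now a multiple of 8, which is the
        # alignment the xformers kernels appear to want.
        *lead, k_len = attn_bias.shape
        padded_k = -(-k_len // 8) * 8  # ceil to the next multiple of 8
        buf = attn_bias.new_zeros((*lead, padded_k))
        buf[..., :k_len] = attn_bias
        return buf[..., :k_len]

If something like this were applied to attn_bias right before the memory_efficient_attention call, the test should no longer need text_tokens chosen so that mask length plus the null token lands on a multiple of 8.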