Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: in newer versions of triton, tl.dot should take as input only q … #1288

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

EdouardYvinec
Copy link

One can use the following minimal example to reproduce the problem:

import torch
from flash_attn.flash_attn_triton import FlashAttnFunc

flash_att = FlashAttnFunc.apply

shape = (1, 5, 1, 5)
q = torch.ones(shape).cuda().to(torch.float16)
k = torch.ones(shape).cuda().to(torch.float16)
v = torch.tensor([
    [[[1,1,1,1,1]],[[2,2,2,2,2]],[[3,3,3,3,3]],[[4,4,4,4,4]],[[5,5,5,5,5]]
]]).cuda().to(torch.float16)
bias = torch.zeros([shape[0], shape[2], shape[1], shape[1]]).cuda().to(torch.float16)

print(q)
print(k)
print(v)
bias[0,0] = -(torch.randn(5,5) > 0.01).to(torch.float16)
bias *= torch.finfo(torch.float16).max
print(bias)

out = flash_att(q, k, v, bias)
print(out)

In the original repository with recent versions of triton, we would get an error regarding the use of arg trans_b in tl.dot. In this PR, we fix the issue. The former example can now run properly. This triton implementation enables arbitrary masking through the argument bias.

In order to test in colab, you add the following above the previous example to fix the current version of flash-attn

import flash_attn.flash_attn_triton as module_to_edit
import importlib
import os

path = os.path.abspath(module_to_edit.__file__)
with open(path, 'r') as f:
    lines = f.readlines()
    print('changing:', lines[182])
    lines[182] = '        qk += tl.dot(q, tl.trans(k))\n'
    print('to:', lines[182])
with open(path, 'w') as f:
    f.writelines(lines)

importlib.reload(module_to_edit)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant