Skip to content

Commit

Permalink
REBASE: fix interface changes in rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Oct 28, 2024
1 parent 223a016 commit 18f3bae
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions flash_attn/flash_attn_triton_amd/interface_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def fwd(q,
else:
if DEBUG:
print("Using Triton implementation")
(output,
(_,
softmax_lse,
exp_scores,
_,
Expand Down Expand Up @@ -122,15 +122,11 @@ def fwd(q,

if DEBUG:
print("fwd outputs")
print("output:", output, output.shape)
print("q:", q, q.shape)
print("k:", k, k.shape)
print("v:", v, v.shape)
print("o:", o, o.shape)
print("softmax_lse:", softmax_lse, softmax_lse.shape)
print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None )

return output, q , k , v, o, softmax_lse, exp_scores, None
return o, softmax_lse, exp_scores, None

def bwd(
dout,
Expand Down Expand Up @@ -330,7 +326,7 @@ def varlen_fwd(
else:
if DEBUG:
print("Using Triton implementation")
(output,
(_,
softmax_lse,
exp_scores,
_,
Expand All @@ -357,16 +353,12 @@ def varlen_fwd(
metadata.use_exp2)
if DEBUG:
print("varlen_fwd outputs")
print("output:", output, output.shape)
print("q:", q, q.shape)
print("k:", k, k.shape)
print("v:", v, v.shape)
print("o:", o, o.shape)
print("softmax_lse:", softmax_lse, softmax_lse.shape)
print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None )


return output, q , k , v, o, softmax_lse, exp_scores, None
return o, softmax_lse, exp_scores, None

def varlen_bwd(
dout,
Expand Down Expand Up @@ -527,6 +519,7 @@ def fwd_kvcache(
metadata.need_alibi(alibi_slopes, batch, nheads_q)

# launch kernel
# TODO: pass output as an arg. Maybe we are copying output which is causing slow down
output, softmax_lse = attention_decode_forward_triton_impl(
q,
k_cache,
Expand Down

0 comments on commit 18f3bae

Please sign in to comment.