diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 91ec10a9e61d..c4f2e5b7628e 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -335,7 +335,7 @@ def attn_fwd( return is_mqa = hq != hk - off_h_k = off_h_q % hk if is_mqa else off_h_q + off_h_k = tl.where(is_mqa, off_h_q % hk, off_h_q) need_padding = False n_extra_tokens = 0 if seqlen_k < BLOCK_N: