diff --git a/jax_triton/pallas/ops/attention.py b/jax_triton/pallas/ops/attention.py
index 8eab39a5..dd818925 100644
--- a/jax_triton/pallas/ops/attention.py
+++ b/jax_triton/pallas/ops/attention.py
@@ -27,7 +27,6 @@ def mha_forward_kernel(
     q_ref, k_ref, v_ref, # Input arrays
     o_ref, # Output
-    tmp_ref, # Temporary scratch space to deal with compiler bug
     *residual_refs, # Residual outputs
     sm_scale: float,
     block_q: int, block_d: int, block_k: int):
   seq_len = q_ref.shape[0]
@@ -79,11 +78,6 @@ def body(i, refs):
     p_ij = p_ij * p_scale[:, None] # Shape [block_q].
     # Update the scaling of the output buffer acc.
     acc_scale = l_i / l_i_new * alpha # Shape [block_q].
-    # Compiler bug! Use tmp real quick
-
-    tmp_idx = (pl.dslice(start_q * block_q, block_q),)
-    pl.store(tmp_ref, tmp_idx, acc_scale)
-    acc_scale = pl.load(tmp_ref, tmp_idx)
     acc = acc * acc_scale[:, None]
     l_i_ref[:] = l_i_new
     # Update m_i and l_i for the next block_k.
@@ -135,8 +129,6 @@ def mha(q, k, v,
                              block_d=head_dim)
   out_shape = [
       jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
-      jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),
-                           dtype=jnp.float32)
   ]
   out, _ = pl.pallas_call(
       kernel,
@@ -179,8 +171,6 @@ def _mha_forward(q, k, v, sm_scale: float, block_q: int, block_k: int,
                              block_d=head_dim)
   out_shape = [
       jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
-      jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),  # tmp
-                           dtype=jnp.float32),
       jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),  # l
                            dtype=jnp.float32),
       jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),  # m
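
For reference, a minimal smoke-test sketch of the forward entry point after this change, not part of the patch. It assumes the keyword names visible in the hunks above (sm_scale, block_q, block_k), a [batch, seq_len, num_heads, head_dim] input layout consistent with the (batch_size, num_heads, seq_len) residual shapes, and illustrative sizes and dtype:

# Hedged usage sketch (not part of the patch): exercises mha() after the
# tmp scratch buffer is removed. Shapes, dtype, and block sizes are illustrative.
import jax
import jax.numpy as jnp
from jax_triton.pallas.ops.attention import mha

batch, seq_len, num_heads, head_dim = 2, 1024, 4, 64  # assumed layout
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch, seq_len, num_heads, head_dim), jnp.float16)
k = jax.random.normal(k2, (batch, seq_len, num_heads, head_dim), jnp.float16)
v = jax.random.normal(k3, (batch, seq_len, num_heads, head_dim), jnp.float16)

# Keyword names follow the signatures shown in the hunks above.
out = mha(q, k, v, sm_scale=1.0, block_q=128, block_k=128)
print(out.shape)  # same shape as q; no temporary scratch output is allocated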