
Commit

Version 1.0 Disabled flash attention
Alex authored Oct 12, 2023
1 parent f6e958c commit 9490b1b
Showing 1 changed file with 1 addition and 1 deletion.
tegridy-tools/X-Transformer/x_transformer_1_23_2.py: 1 addition & 1 deletion
@@ -256,7 +256,7 @@ def flash_attn(

         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

-        with torch.backends.cuda.sdp_kernel(**config._asdict()):
+        with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
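
For context on what the changed line touches: torch.backends.cuda.sdp_kernel is the PyTorch 2.x context manager that selects which backends F.scaled_dot_product_attention may dispatch to (flash, memory-efficient, math). Below is a minimal sketch of steering SDPA away from the flash kernel, assuming PyTorch 2.0+ on a CUDA device; the tensor shapes and the explicit enable_flash=False flag are illustrative and are not taken from the repository's code.

    import torch
    import torch.nn.functional as F

    # Illustrative shapes and dtype; not taken from x_transformer_1_23_2.py.
    q = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)
    k = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)
    v = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)

    # Allow only the math and memory-efficient kernels; enable_flash=False rules
    # the flash kernel out explicitly (in PyTorch 2.0/2.1 these flags default to True).
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)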
