From 9490b1b3c9dee43c83d3e419bbb158e86b76cdcb Mon Sep 17 00:00:00 2001
From: Alex
Date: Wed, 11 Oct 2023 22:23:42 -0700
Subject: [PATCH] Version 1.0

Disabled flash attention
---
 tegridy-tools/X-Transformer/x_transformer_1_23_2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tegridy-tools/X-Transformer/x_transformer_1_23_2.py b/tegridy-tools/X-Transformer/x_transformer_1_23_2.py
index 2993865..105715d 100644
--- a/tegridy-tools/X-Transformer/x_transformer_1_23_2.py
+++ b/tegridy-tools/X-Transformer/x_transformer_1_23_2.py
@@ -256,7 +256,7 @@ def flash_attn(
 
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 
-        with torch.backends.cuda.sdp_kernel(**config._asdict()):
+        with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
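
For context, the changed line uses PyTorch 2.x's torch.backends.cuda.sdp_kernel context manager to choose which scaled_dot_product_attention backends are allowed. Below is a minimal standalone sketch of that mechanism, not part of the patch itself; it assumes the PyTorch 2.0/2.1 signature in which enable_flash, enable_math, and enable_mem_efficient all default to True, so the flash backend is only excluded when enable_flash=False is passed explicitly.

# Minimal sketch, assuming the PyTorch 2.0/2.1 sdp_kernel API where all
# backend flags default to True. Passing enable_flash=False rules out the
# flash kernel, leaving the math and memory-efficient implementations.
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
q = torch.randn(1, 8, 128, 64, device=device)
k = torch.randn(1, 8, 128, 64, device=device)
v = torch.randn(1, 8, 128, 64, device=device)

# Backend selection only matters on CUDA; on CPU, scaled_dot_product_attention
# uses its default (math) path regardless of these flags.
with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                    enable_math=True,
                                    enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])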