
Commit

Merge pull request #81 from ROCm/fix_fp8_fused_attention_filtering
Fix the use_fused_attention filtering
wangye805 authored Oct 17, 2024
2 parents 3e697a2 + eec5c20 commit 2e46618
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions transformer_engine/pytorch/attention.py
@@ -4897,9 +4897,6 @@ def forward(
         use_unfused_attention = True
 
         if IS_HIP_EXTENSION:
-            #TODO: rocm does not support fp8 fused attn
-            if self.fp8:
-                use_fused_attention = False
             #TODO: add back once rocm flash-attn is available
             use_flash_attention = False
 
@@ -4976,7 +4973,12 @@ def forward(
                     "Disabling UnfusedDotProductAttention as it does not support FP8 execution."
                 )
                 use_unfused_attention = False
-
+        if IS_HIP_EXTENSION and use_fused_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+            self.logger.debug(
+                "Disabling ROCm fused attention as it does not support FP8 execution."
+            )
+            use_fused_attention = False
+
         # Filter: Device and dimensions.
         # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
         # FAv2 requires head_dim % 8 == 0

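Net effect of the change: before this commit, the ROCm branch turned off use_fused_attention whenever self.fp8 was set; after it, fused attention on ROCm is disabled only when FP8 dot-product attention is actually requested (self.fp8 and self.fp8_meta["recipe"].fp8_dpa). The sketch below illustrates that filtering logic; it is not the library's code, and the flag names is_hip, fp8_enabled, and fp8_dpa_requested are hypothetical stand-ins for IS_HIP_EXTENSION, self.fp8, and self.fp8_meta["recipe"].fp8_dpa.

# Illustrative sketch only; flag names are hypothetical stand-ins, not the
# actual TransformerEngine attributes.
def filter_attention_backends(
    is_hip: bool,
    fp8_enabled: bool,
    fp8_dpa_requested: bool,
    use_flash_attention: bool = True,
    use_fused_attention: bool = True,
    use_unfused_attention: bool = True,
):
    """Return (flash, fused, unfused) availability after the ROCm/FP8 filters."""
    if is_hip:
        # ROCm flash-attn is not available yet, so flash attention is filtered out.
        use_flash_attention = False

    if fp8_enabled and fp8_dpa_requested and use_unfused_attention:
        # Unfused attention does not support FP8 execution.
        use_unfused_attention = False

    if is_hip and use_fused_attention and fp8_enabled and fp8_dpa_requested:
        # After this commit: fused attention on ROCm is dropped only when FP8
        # dot-product attention is actually requested, not merely when FP8 is on.
        use_fused_attention = False

    return use_flash_attention, use_fused_attention, use_unfused_attention

# Example: FP8 enabled but FP8 DPA not requested -> fused attention stays usable on ROCm.
print(filter_attention_backends(is_hip=True, fp8_enabled=True, fp8_dpa_requested=False))
# -> (False, True, True)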