From eec5c201cb79543c2438b60e98fa7c5b6daccfd1 Mon Sep 17 00:00:00 2001
From: Ye Wang
Date: Thu, 17 Oct 2024 16:04:49 -0500
Subject: [PATCH] Fix the use_fused_attention filtering

Gate the ROCm fused-attention filter on fp8_dpa rather than on fp8, so
fused attention is only disabled when the recipe actually enables FP8
dot-product attention.
---
 transformer_engine/pytorch/attention.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index fb2ee07a..d8e0289a 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -4897,9 +4897,6 @@ def forward(
         use_unfused_attention = True
 
         if IS_HIP_EXTENSION:
-            #TODO: rocm does not support fp8 fused attn
-            if self.fp8:
-                use_fused_attention = False
             #TODO: add back once rocm flash-attn is available
             use_flash_attention = False
 
@@ -4976,7 +4973,12 @@ def forward(
                     "Disabling UnfusedDotProductAttention as it does not support FP8 execution."
                 )
                 use_unfused_attention = False
-
+        if IS_HIP_EXTENSION and use_fused_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+            self.logger.debug(
+                "Disabling ROCm fused attention as it does not support FP8 execution."
+            )
+            use_fused_attention = False
+
         # Filter: Device and dimensions.
         # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
         # FAv2 requires head_dim % 8 == 0
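
A minimal standalone sketch of the gating change, for illustration only. The names
Recipe and select_attention_backends below are hypothetical stand-ins, not part of
transformer_engine; only the condition mirrors the patch: on ROCm, fused attention
is disabled only when FP8 is on and the recipe requests FP8 dot-product attention
(fp8_dpa), rather than for every FP8 run.

    from dataclasses import dataclass

    IS_HIP_EXTENSION = True  # pretend we are building/running the ROCm extension


    @dataclass
    class Recipe:
        # Hypothetical stand-in for the recipe object's fp8_dpa flag.
        fp8_dpa: bool = False


    def select_attention_backends(fp8: bool, recipe: Recipe):
        use_flash_attention = True
        use_fused_attention = True
        use_unfused_attention = True

        if IS_HIP_EXTENSION:
            # ROCm flash-attn integration is not available yet.
            use_flash_attention = False

        # Old behavior disabled fused attention whenever fp8 was set.
        # New behavior only disables it when the recipe enables FP8 DPA.
        if IS_HIP_EXTENSION and use_fused_attention and fp8 and recipe.fp8_dpa:
            use_fused_attention = False

        return use_flash_attention, use_fused_attention, use_unfused_attention


    # FP8 enabled but fp8_dpa off: fused attention remains available on ROCm.
    print(select_attention_backends(fp8=True, recipe=Recipe(fp8_dpa=False)))
    # FP8 enabled and fp8_dpa on: fused attention is filtered out.
    print(select_attention_backends(fp8=True, recipe=Recipe(fp8_dpa=True)))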