From 2d14e5e2e8fee7fa8e5b7952cacafe7233684449 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Tue, 11 Feb 2025 21:13:35 -0600 Subject: [PATCH] no-approx-on-rocm Signed-off-by: Benoit Jacob --- .../Codegen/Common/MathTransformPass.cpp | 12 ++++-- .../Codegen/Common/test/math_transform.mlir | 39 ++++++++++++++----- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp b/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp index 65a8e276eb91..94c0c16ce818 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp @@ -63,13 +63,16 @@ static void populateMathFunctionsRewritePatterns( static bool predicateRewrite(StringRef name, IREE::HAL::ExecutableTargetAttr target) { - (void)target; // Currently unused. if (clNativeMathPrecision) { // Legacy. if (name == math::Exp2Op::getOperationName() || name == math::RoundEvenOp::getOperationName()) { return false; } } + if (isROCMBackend(target)) { + // On ROCm, we want to use device library functions. + return false; + } // Currently enable all non-approximative rewrites. return true; } @@ -109,6 +112,10 @@ static bool predicateApprox(StringRef name, } return false; } + if (isROCMBackend(target)) { + // On ROCm, we want to use device library functions. + return false; + } StringRef acos = math::AcosOp::getOperationName(); StringRef asin = math::AsinOp::getOperationName(); StringRef atan = math::AtanOp::getOperationName(); @@ -123,9 +130,6 @@ static bool predicateApprox(StringRef name, StringRef expm1 = math::ExpM1Op::getOperationName(); StringRef cbrt = math::CbrtOp::getOperationName(); StringRef erf = math::ErfOp::getOperationName(); - if (isROCMBackend(target) && name == erf) { - return false; - } return llvm::is_contained({atan, atan2, tanh, log, log2, log1p, erf, asin, acos, exp, expm1, cbrt, sin, cos}, name); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir b/compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir index 5996ceb9067f..4ed8f6c51541 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir @@ -58,17 +58,36 @@ func.func @rewrite_erf(%arg0: f16) -> f16 attributes { // ----- -// CHECK-LABEL: @no_approx_erf_on_rocm -func.func @no_approx_erf_on_rocm(%arg0: f16) -> f16 attributes { +// CHECK-LABEL: @no_approx_on_rocm +func.func @no_approx_on_rocm(%arg0: f16) -> f16 attributes { hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {}> } { - // On ROCm, we want to use the native device library function, so math.erf - // should not get rewritten. It's OK for f16 to still get casted to f32, as - // the device library function for f16 is casting to f32 anyway. + // On ROCm, we want to use the native device library functions. + // It's OK for f16 to still get casted to f32, as + // the device library functions for f16 are casting to f32 anyway. + // CHECK: math.acos + // CHECK: math.atan + // CHECK: math.sin + // CHECK: math.tanh + // CHECK: math.log + // CHECK: math.log2 + // CHECK: math.log1p + // CHECK: math.exp + // CHECK: math.exp2 + // CHECK: math.expm1 + // CHECK: math.cbrt // CHECK: math.erf - // CHECK-NOT: math.exp - // CHECK-NOT: math.log - // CHECK-NOT: math.fma - %0 = math.erf %arg0 : f16 - return %0 : f16 + %0 = math.acos %arg0 : f16 + %1 = math.atan %0 : f16 + %2 = math.sin %1 : f16 + %3 = math.tanh %2 : f16 + %4 = math.log %3 : f16 + %5 = math.log2 %4 : f16 + %6 = math.log1p %5 : f16 + %7 = math.exp %6 : f16 + %8 = math.exp2 %7 : f16 + %9 = math.expm1 %8 : f16 + %10 = math.cbrt %9 : f16 + %11 = math.erf %10 : f16 + return %11 : f16 }