diff --git a/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp b/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp index 497da0f65dc4..65a8e276eb91 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp @@ -99,7 +99,6 @@ static bool predicateF32Cast(StringRef name, static bool predicateApprox(StringRef name, IREE::HAL::ExecutableTargetAttr target) { - (void)target; // Currently unused. if (clNativeMathPrecision) { // Legacy. if (name == math::ErfOp::getOperationName()) { // The legacy implementation had a bug: it always applied polynomial @@ -124,6 +123,9 @@ static bool predicateApprox(StringRef name, StringRef expm1 = math::ExpM1Op::getOperationName(); StringRef cbrt = math::CbrtOp::getOperationName(); StringRef erf = math::ErfOp::getOperationName(); + if (isROCMBackend(target) && name == erf) { + return false; + } return llvm::is_contained({atan, atan2, tanh, log, log2, log1p, erf, asin, acos, exp, expm1, cbrt, sin, cos}, name); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir b/compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir index 1ddd94ba02bd..5996ceb9067f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir @@ -55,3 +55,20 @@ func.func @rewrite_erf(%arg0: f16) -> f16 attributes { %0 = math.erf %arg0 : f16 return %0 : f16 } + +// ----- + +// CHECK-LABEL: @no_approx_erf_on_rocm +func.func @no_approx_erf_on_rocm(%arg0: f16) -> f16 attributes { + hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {}> +} { + // On ROCm, we want to use the native device library function, so math.erf + // should not get rewritten. It's OK for f16 to still get casted to f32, as + // the device library function for f16 is casting to f32 anyway. + // CHECK: math.erf + // CHECK-NOT: math.exp + // CHECK-NOT: math.log + // CHECK-NOT: math.fma + %0 = math.erf %arg0 : f16 + return %0 : f16 +}