Skip to content

Commit

Permalink
Do not approximate erf on rocm. (#19969)
Browse files Browse the repository at this point in the history
On ROCm, we want to use the device library functions, which we link as
bitcode and inline. In this PR, we start with `math.erf` because that's
the immediate use case, but this will likely be generalized to other
functions in a subsequent PR.

Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob authored Feb 19, 2025
1 parent c8ba691 commit 7891b80
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ static bool predicateF32Cast(StringRef name,

static bool predicateApprox(StringRef name,
IREE::HAL::ExecutableTargetAttr target) {
(void)target; // Currently unused.
if (clNativeMathPrecision) { // Legacy.
if (name == math::ErfOp::getOperationName()) {
// The legacy implementation had a bug: it always applied polynomial
Expand All @@ -124,6 +123,9 @@ static bool predicateApprox(StringRef name,
StringRef expm1 = math::ExpM1Op::getOperationName();
StringRef cbrt = math::CbrtOp::getOperationName();
StringRef erf = math::ErfOp::getOperationName();
if (isROCMBackend(target) && name == erf) {
return false;
}
return llvm::is_contained({atan, atan2, tanh, log, log2, log1p, erf, asin,
acos, exp, expm1, cbrt, sin, cos},
name);
Expand Down
17 changes: 17 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,20 @@ func.func @rewrite_erf(%arg0: f16) -> f16 attributes {
%0 = math.erf %arg0 : f16
return %0 : f16
}

// -----

// CHECK-LABEL: @no_approx_erf_on_rocm
func.func @no_approx_erf_on_rocm(%arg0: f16) -> f16 attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {}>
} {
// On ROCm, we want to use the native device library function, so math.erf
// should not get rewritten. It's OK for f16 to still get casted to f32, as
// the device library function for f16 is casting to f32 anyway.
// CHECK: math.erf
// CHECK-NOT: math.exp
// CHECK-NOT: math.log
// CHECK-NOT: math.fma
%0 = math.erf %arg0 : f16
return %0 : f16
}

0 comments on commit 7891b80

Please sign in to comment.