Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMDGPU] Do not rewrite or approximate math functions on ROCm #19970

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,16 @@ static void populateMathFunctionsRewritePatterns(

static bool predicateRewrite(StringRef name,
IREE::HAL::ExecutableTargetAttr target) {
(void)target; // Currently unused.
if (clNativeMathPrecision) { // Legacy.
if (name == math::Exp2Op::getOperationName() ||
name == math::RoundEvenOp::getOperationName()) {
return false;
}
}
if (isROCMBackend(target)) {
// On ROCm, we want to use device library functions.
return false;
}
// Currently enable all non-approximative rewrites.
return true;
}
Expand Down Expand Up @@ -109,6 +112,10 @@ static bool predicateApprox(StringRef name,
}
return false;
}
if (isROCMBackend(target)) {
// On ROCm, we want to use device library functions.
return false;
}
StringRef acos = math::AcosOp::getOperationName();
StringRef asin = math::AsinOp::getOperationName();
StringRef atan = math::AtanOp::getOperationName();
Expand All @@ -123,9 +130,6 @@ static bool predicateApprox(StringRef name,
StringRef expm1 = math::ExpM1Op::getOperationName();
StringRef cbrt = math::CbrtOp::getOperationName();
StringRef erf = math::ErfOp::getOperationName();
if (isROCMBackend(target) && name == erf) {
return false;
}
return llvm::is_contained({atan, atan2, tanh, log, log2, log1p, erf, asin,
acos, exp, expm1, cbrt, sin, cos},
name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,36 @@ func.func @rewrite_erf(%arg0: f16) -> f16 attributes {

// -----

// CHECK-LABEL: @no_approx_erf_on_rocm
func.func @no_approx_erf_on_rocm(%arg0: f16) -> f16 attributes {
// CHECK-LABEL: @no_approx_on_rocm
func.func @no_approx_on_rocm(%arg0: f16) -> f16 attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {}>
} {
// On ROCm, we want to use the native device library function, so math.erf
// should not get rewritten. It's OK for f16 to still get casted to f32, as
// the device library function for f16 is casting to f32 anyway.
// On ROCm, we want to use the native device library functions.
// It's OK for f16 to still get casted to f32, as
// the device library functions for f16 are casting to f32 anyway.
// CHECK: math.acos
// CHECK: math.atan
// CHECK: math.sin
// CHECK: math.tanh
// CHECK: math.log
// CHECK: math.log2
// CHECK: math.log1p
// CHECK: math.exp
// CHECK: math.exp2
// CHECK: math.expm1
// CHECK: math.cbrt
// CHECK: math.erf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have we figured out the numerical issue with math.erf library function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have (99%) figured that there was no issue with it, and the issues we ran into were caused by PolynomialApproximationPass being too coarse-grained and too convoluted, so that when we thought earlier that we were enabling/disabling math.erf approximation, we were also enabling/disabling a number of other things, unintentionally. This is what #19922 solved. Now in the present PR we are finally at a place where we have some fine-grained, well-defined levers to play with.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The remaining issue has been diagnosed in #20074 (comment) as an issue with overly strict tests requiring agreement with the less accurate polynomial approximations in f16, as opposed to the ROCm device lib performing the approximation after upcasting to f32.

// CHECK-NOT: math.exp
// CHECK-NOT: math.log
// CHECK-NOT: math.fma
%0 = math.erf %arg0 : f16
return %0 : f16
%0 = math.acos %arg0 : f16
%1 = math.atan %0 : f16
%2 = math.sin %1 : f16
%3 = math.tanh %2 : f16
%4 = math.log %3 : f16
%5 = math.log2 %4 : f16
%6 = math.log1p %5 : f16
%7 = math.exp %6 : f16
%8 = math.exp2 %7 : f16
%9 = math.expm1 %8 : f16
%10 = math.cbrt %9 : f16
%11 = math.erf %10 : f16
return %11 : f16
}
Loading