-
Notifications
You must be signed in to change notification settings - Fork 11.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Add patterns for fma.relu.{f16|bf16} #114977
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-backend-nvptx Author: Hugh Delaney (hdelan) ChangesAdd patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 for f16 and bf16 types. Full diff: https://github.com/llvm/llvm-project/pull/114977.diff 2 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5f6cba397c5352..52312fa9afbd7e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3917,3 +3917,19 @@ def atomic_thread_fence_seq_cst_cta :
def atomic_thread_fence_acq_rel_cta :
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm0 : FPImmLeaf<fAny, [{
+ return Imm.isExactlyValue(+0.0);
+}]>;
+
+def FMARELU :
+ NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+ "fma.rn.relu \t$dst, $a, $b, $c;", []>;
+
+def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+ (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+ Requires<[useFP16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
+
+def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+ (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+ Requires<[hasBF16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu.ll b/llvm/test/CodeGen/NVPTX/fma-relu.ll
new file mode 100644
index 00000000000000..6c340ef9d53015
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-relu.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | %ptxas-verify -arch=sm_80 %}
+
+define half @fma_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_f16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_f16_param_2];
+; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.fma.f16(half %a, half %b, half %c)
+ %2 = fcmp ogt half %1, 0.0
+ %3 = select i1 %2, half %1, half 0.0
+ ret half %3
+}
+
+define half @fma_f16_expanded(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16_expanded(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_f16_expanded_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_f16_expanded_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_f16_expanded_param_2];
+; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = fmul half %a, %b
+ %2 = fadd half %1, %c
+ %3 = fcmp ogt half %2, 0.0
+ %4 = select i1 %3, half %2, half 0.0
+ ret half %4
+}
+
+define bfloat @fma_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_param_2];
+; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
+ %2 = fcmp ogt bfloat %1, 0.0
+ %3 = select i1 %2, bfloat %1, bfloat 0.0
+ ret bfloat %3
+}
+
+define bfloat @fma_bf16_expanded(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16_expanded(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_expanded_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_expanded_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_expanded_param_2];
+; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = fmul bfloat %a, %b
+ %2 = fadd bfloat %1, %c
+ %3 = fcmp ogt bfloat %2, 0.0
+ %4 = select i1 %3, bfloat %2, bfloat 0.0
+ ret bfloat %4
+}
|
Ping @ldrumm @frasercrmck |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, LGTM!
Please wait for @AlexMaclean's review though as he's more familiar with NVPTXInstrInfo.td
than I am.
|
||
def FMARELU_F16 : | ||
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), | ||
"fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need the Requires<...>
on the instruction too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's only being used by the anonymous patterns below, which have the necessary Requires
. I don't think we need to introduce extra noise by repeating them here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think applying constraint to the instruction itself is the right thing to do. We do not want them to be emitted unintentionally, even if we do not do it now.
I do not know whether the constraint propagates to the pattern, but I think it may, so applying it here should do the job. It's easy enough to test by running the tests while targeting an older GPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added the PTX and arch requirements on the instruction, and the pattern Requires just on the pattern.
Suppose the |
%1 = fmul half %a, %b | ||
%2 = fadd half %1, %c |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to add a couple more test runs:
- one with w/o mul/add -> fma contraction to make sure we do not use
fma.rn.relu
unintentionally. - one targeting older GPUs to make sure we do not emit
fma.rn.relu
there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added more tests to cover these cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it worth also having a test case that uses llvm.maxnum
? I believe that if the IR was given the right fast-math flags, InstCombine would transform this select
into an llvm.maxnum
anyway.
Speaking of, should we also have tests with fast-math flags? My feeling is that we should see fast-math flags in the IR as if this was really coming from a frontend with -ffast-math
(or equivalent). IIRC the NVPTX backend relies on the unsafe-fp-math
function attribute being set, which enables these fast math optimizations. I think we should have a test with fast-math flags, fast-math function attributes, and the default llc
flags (no --enable-unsafe-fp-math
, no -nvptx-fma-level
. We should still generate fma.relu
in that case, right? This, imo, should be "the" canonical test of this optimization - using various llc
flags like this is a less standardised approach.
|
||
def FMARELU_F16 : | ||
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), | ||
"fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think applying constraint to the instruction itself is the right thing to do. We do not want them to be emitted unintentionally, even if we do not do it now.
I do not know whether the constraint propagates to the pattern, but I think it may, so applying it here should do the job. It's easy enough to test by running the tests while targeting an older GPU.
|
||
def FMARELU_F16 : | ||
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), | ||
"fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Next question is -- what do we want to do about .ftz
?
We handle it for regular FMA instructions and it's probably needed here, too.
llvm/test/CodeGen/NVPTX/fma-relu.ll
Outdated
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | %ptxas-verify -arch=sm_80 %} | ||
; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=0 | FileCheck %s --check-prefixes=CHECK-NO-FMA | ||
; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_70 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s --check-prefixes=CHECK-NO-ARCH | ||
; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_70 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s --check-prefixes=CHECK-NO-PTX |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This RUN line is the same as the one above.
Also maybe CHECK-SM80
and CHECK-SM70
are better check names? CHECK-NO-ARCH
and CHECK-NO-PTX
don't really explain to me what they're checking or why.
%1 = fmul half %a, %b | ||
%2 = fadd half %1, %c |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it worth also having a test case that uses llvm.maxnum
? I believe that if the IR was given the right fast-math flags, InstCombine would transform this select
into an llvm.maxnum
anyway.
Speaking of, should we also have tests with fast-math flags? My feeling is that we should see fast-math flags in the IR as if this was really coming from a frontend with -ffast-math
(or equivalent). IIRC the NVPTX backend relies on the unsafe-fp-math
function attribute being set, which enables these fast math optimizations. I think we should have a test with fast-math flags, fast-math function attributes, and the default llc
flags (no --enable-unsafe-fp-math
, no -nvptx-fma-level
. We should still generate fma.relu
in that case, right? This, imo, should be "the" canonical test of this optimization - using various llc
flags like this is a less standardised approach.
4759e15
to
9456007
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
@@ -0,0 +1,349 @@ | |||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests look reasonable to me in the behaviour, but I'd prefer they -stop-after=finalize-isel
so we can test just the isel in isolation
@@ -3917,3 +3917,40 @@ def atomic_thread_fence_seq_cst_cta : | |||
def atomic_thread_fence_acq_rel_cta : | |||
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>, | |||
Requires<[hasPTX<60>, hasSM<70>]>; | |||
|
|||
def fpimm0 : FPImmLeaf<fAny, [{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not clear from the name that this is strictly positive zero. Maybe def positive_zero_fp
or a better equivalent if you can
Add patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 for f16 and bf16 types.
Add patterns to lower
fma(a, b, c) > 0 ? fma(a, b, c) : 0
for f16 and bf16 types.