
[NVPTX] Add patterns for fma.relu.{f16|bf16} #114977

Open. Wants to merge 1 commit into main from the fma-relu branch.

Conversation

@hdelan (Contributor) commented Nov 5, 2024

Add patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 to fma.rn.relu for f16 and bf16 types.
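For readers skimming the PR, the select idiom being lowered can be sketched in plain Python (illustrative only: Python floats are f64 and a*b + c below is not actually fused, so this models the selection logic rather than f16/bf16 rounding):

```python
# Sketch of the operation the new patterns lower to fma.rn.relu:
#   relu(fma(a, b, c))  ==  fma(a, b, c) > 0 ? fma(a, b, c) : 0
def fma_relu(a: float, b: float, c: float) -> float:
    t = a * b + c          # stand-in for a fused multiply-add
    return t if t > 0 else 0.0

print(fma_relu(2.0, 3.0, -1.0))  # 5.0
print(fma_relu(2.0, 3.0, -7.0))  # 0.0 (negative result clamped by the relu)
```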

@llvmbot (Collaborator) commented Nov 5, 2024

@llvm/pr-subscribers-backend-nvptx

Author: Hugh Delaney (hdelan)

Changes

Add patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 for f16 and bf16 types.


Full diff: https://github.com/llvm/llvm-project/pull/114977.diff

2 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+16)
  • (added) llvm/test/CodeGen/NVPTX/fma-relu.ll (+77)
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5f6cba397c5352..52312fa9afbd7e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3917,3 +3917,19 @@ def atomic_thread_fence_seq_cst_cta :
 def atomic_thread_fence_acq_rel_cta :
   NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
   Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm0 : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(+0.0);
+}]>;
+
+def FMARELU :
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+            "fma.rn.relu \t$dst, $a, $b, $c;", []>;
+
+def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+  (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+  Requires<[useFP16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
+
+def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+  (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+  Requires<[hasBF16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu.ll b/llvm/test/CodeGen/NVPTX/fma-relu.ll
new file mode 100644
index 00000000000000..6c340ef9d53015
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-relu.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | %ptxas-verify -arch=sm_80 %}
+
+define half @fma_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_f16_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_f16_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_f16_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = call half @llvm.fma.f16(half %a, half %b, half %c)
+  %2 = fcmp ogt half %1, 0.0
+  %3 = select i1 %2, half %1, half 0.0
+  ret half %3
+}
+
+define half @fma_f16_expanded(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16_expanded(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_f16_expanded_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_f16_expanded_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_f16_expanded_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = fmul half %a, %b
+  %2 = fadd half %1, %c
+  %3 = fcmp ogt half %2, 0.0
+  %4 = select i1 %3, half %2, half 0.0
+  ret half %4
+}
+
+define bfloat @fma_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
+  %2 = fcmp ogt bfloat %1, 0.0
+  %3 = select i1 %2, bfloat %1, bfloat 0.0
+  ret bfloat %3
+}
+
+define bfloat @fma_bf16_expanded(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16_expanded(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_expanded_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_expanded_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_expanded_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = fmul bfloat %a, %b
+  %2 = fadd bfloat %1, %c
+  %3 = fcmp ogt bfloat %2, 0.0
+  %4 = select i1 %3, bfloat %2, bfloat 0.0
+  ret bfloat %4
+}

@hdelan changed the title from "Add patterns for fma.relu.{f16|bf16}" to "[NVPTX] Add patterns for fma.relu.{f16|bf16}" on Nov 5, 2024
@hdelan (Contributor Author) commented Nov 5, 2024

Ping @ldrumm @frasercrmck

@justinfargnoli (Contributor) left a comment

Overall, LGTM!

Please wait for @AlexMaclean's review though as he's more familiar with NVPTXInstrInfo.td than I am.


def FMARELU_F16 :
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
"fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>;
Contributor:

Do we need the Requires<...> on the instruction too?

hdelan (Contributor Author):

It's only being used by the anonymous patterns below, which have the necessary Requires. I don't think we need to introduce extra noise by repeating them here.

Member:

I think applying constraint to the instruction itself is the right thing to do. We do not want them to be emitted unintentionally, even if we do not do it now.

I do not know whether the constraint propagates to the pattern, but I think it may, so applying it here should do the job. It's easy enough to test by running the tests while targeting an older GPU.

hdelan (Contributor Author):

I've added the PTX and arch requirements on the instruction, and the pattern Requires just on the pattern.

@AlexMaclean (Member) commented:

Suppose the fma has more uses in addition to the fmaxnum. If this optimization kicks in, it may increase register pressure and won't be a clear win in terms of performance. I'm not sure this will be a problem, but to be conservative it may be better to implement this as a DAG combine and verify the fma has a single use.
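The concern can be modeled abstractly: if the fma value feeds both the max and another user, folding it into fma.relu still leaves the other user needing the plain fma result, so both values stay live. The toy Python below mirrors the suggested single-use guard (the Node class and names are hypothetical, not the SelectionDAG API):

```python
class Node:
    """Toy dataflow node with a use count, standing in for a DAG node."""
    def __init__(self, op: str, *operands: "Node"):
        self.op = op
        self.operands = list(operands)
        self.uses = 0
        for o in operands:
            o.uses += 1

def try_fold_fma_relu(max_node: Node):
    # Fold max(fma(a, b, c), 0) -> fma_relu(a, b, c) only when the fma
    # has no other users; otherwise the plain fma stays live anyway and
    # the fold would just add register pressure.
    fma = max_node.operands[0]
    if fma.op != "fma" or fma.uses != 1:
        return None
    return Node("fma_relu", *fma.operands)

a, b, c, zero = Node("reg"), Node("reg"), Node("reg"), Node("zero")
single = Node("max", Node("fma", a, b, c), zero)
print(try_fold_fma_relu(single).op)       # fma_relu (fma has one use)

shared = Node("fma", a, b, c)
extra_user = Node("add", shared, c)        # second user keeps the fma live
print(try_fold_fma_relu(Node("max", shared, zero)))  # None (fold refused)
```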

Comment on lines +35 to +304
%1 = fmul half %a, %b
%2 = fadd half %1, %c
Member:

It would be good to add a couple more test runs:

  • one without mul/add -> fma contraction, to make sure we do not use fma.rn.relu unintentionally.
  • one targeting older GPUs to make sure we do not emit fma.rn.relu there.

hdelan (Contributor Author):

I've added more tests to cover these cases.

Contributor:

Is it worth also having a test case that uses llvm.maxnum? I believe that if the IR was given the right fast-math flags, InstCombine would transform this select into an llvm.maxnum anyway.

Speaking of, should we also have tests with fast-math flags? My feeling is that we should see fast-math flags in the IR as if this was really coming from a frontend with -ffast-math (or equivalent). IIRC the NVPTX backend relies on the unsafe-fp-math function attribute being set, which enables these fast-math optimizations. I think we should have a test with fast-math flags, fast-math function attributes, and the default llc flags (no --enable-unsafe-fp-math, no -nvptx-fma-level). We should still generate fma.relu in that case, right? This, imo, should be "the" canonical test of this optimization; using various llc flags like this is a less standardised approach.
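On the llvm.maxnum point: the fcmp ogt/select idiom and IEEE-754 maxNum (which llvm.maxnum follows) agree on ordinary inputs but can diverge on NaN, which is why the canonicalization depends on fast-math flags. A rough Python model (my own sketch, not LLVM's exact semantics):

```python
import math

nan = float("nan")

def select_gt(x: float, y: float) -> float:
    # Models: %c = fcmp ogt x, y ; select %c, x, y
    # 'ogt' is false whenever either operand is NaN.
    return x if x > y else y

def maxnum(x: float, y: float) -> float:
    # Models IEEE-754 maxNum / llvm.maxnum: a single quiet NaN
    # operand is dropped in favor of the numeric operand.
    if math.isnan(x):
        return y
    if math.isnan(y):
        return x
    return max(x, y)

print(select_gt(3.0, 0.0), maxnum(3.0, 0.0))  # 3.0 3.0 (agree on numbers)
print(select_gt(1.0, nan))                    # nan (select falls through to y)
print(maxnum(1.0, nan))                       # 1.0 (maxNum ignores the NaN)
```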


(An outdated comment thread on llvm/lib/Target/NVPTX/NVPTXInstrInfo.td was resolved.)

def FMARELU_F16 :
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
"fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>,
Member:

Next question: what do we want to do about .ftz?
We handle it for regular FMA instructions and it's probably needed here, too.
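For context on .ftz: PTX's flush-to-zero modifier replaces subnormal operands and results with sign-preserving zero. A rough numeric model in Python, using the fp16 smallest-normal threshold 2^-14 (illustrative only; this is my sketch of the behavior, not a PTX reference):

```python
import math

FP16_MIN_NORMAL = 2.0 ** -14  # smallest positive normal fp16 value

def ftz(x: float) -> float:
    # Rough model of flush-to-zero: subnormal magnitudes become
    # sign-preserving zero; zeros and normals pass through unchanged.
    if x != 0.0 and abs(x) < FP16_MIN_NORMAL:
        return math.copysign(0.0, x)
    return x

print(ftz(2.0 ** -20))  # 0.0 (subnormal in fp16, flushed)
print(ftz(2.0 ** -10))  # 0.0009765625 (normal, kept)
```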

; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | %ptxas-verify -arch=sm_80 %}
; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=0 | FileCheck %s --check-prefixes=CHECK-NO-FMA
; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_70 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s --check-prefixes=CHECK-NO-ARCH
; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_70 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s --check-prefixes=CHECK-NO-PTX
Contributor:

This RUN line is the same as the one above.

Also maybe CHECK-SM80 and CHECK-SM70 are better check names? CHECK-NO-ARCH and CHECK-NO-PTX don't really explain to me what they're checking or why.


@hdelan force-pushed the fma-relu branch 4 times, most recently from 4759e15 to 9456007, on November 6, 2024 11:20
@ldrumm (Contributor) left a comment

Looks good!

@@ -0,0 +1,349 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
Contributor:

The tests look reasonable to me in their behaviour, but I'd prefer they use -stop-after=finalize-isel so we can test just the isel in isolation.

@@ -3917,3 +3917,40 @@ def atomic_thread_fence_seq_cst_cta :
def atomic_thread_fence_acq_rel_cta :
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
Requires<[hasPTX<60>, hasSM<70>]>;

def fpimm0 : FPImmLeaf<fAny, [{
Contributor:

It's not clear from the name that this is strictly positive zero. Maybe def positive_zero_fp, or a better equivalent if you can think of one.

Add patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 for f16 and
bf16 types.