Skip to content

Commit

Permalink
[VectorCombine] foldExtractedCmps - (re-)enable fold on non-commutati…
Browse files Browse the repository at this point in the history
…ve binops

#114901 exposed that foldExtractedCmps didn't account for non-commutative binops, and were disabled by 05e838f

This patch re-enables support for non-commutative binops by ensuring that the LHS/RHS arg order of the binop is retained.
  • Loading branch information
RKSimon committed Nov 6, 2024
1 parent 38fffa6 commit e3a0775
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 19 deletions.
11 changes: 5 additions & 6 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,10 +1039,6 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
if (!BI || !I.getType()->isIntegerTy(1))
return false;

// TODO: Support non-commutative binary ops.
if (!BI->isCommutative())
return false;

// The compare predicates should match, and each compare should have a
// constant operand.
Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
Expand All @@ -1066,6 +1062,8 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
if (!ConvertToShuf)
return false;
assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
"Unknown ExtractElementInst");

// The original scalar pattern is:
// binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
Expand Down Expand Up @@ -1117,9 +1115,10 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
CmpC[Index0] = C0;
CmpC[Index1] = C1;
Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));

Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), VCmp, Shuf);
Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
replaceValue(I, *NewExt);
++NumVecCmpBO;
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/Transforms/VectorCombine/X86/extract-cmp-binop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ define i1 @icmp_xor_v4i32(<4 x i32> %a) {
; CHECK-LABEL: @icmp_xor_v4i32(
; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt <4 x i32> [[A:%.*]], <i32 poison, i32 -8, i32 poison, i32 42>
; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i1> [[TMP1]], <4 x i1> poison, <4 x i32> <i32 poison, i32 3, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP2:%.*]] = xor <4 x i1> [[TMP1]], [[SHIFT]]
; CHECK-NEXT: [[TMP2:%.*]] = xor <4 x i1> [[SHIFT]], [[TMP1]]
; CHECK-NEXT: [[R:%.*]] = extractelement <4 x i1> [[TMP2]], i64 1
; CHECK-NEXT: ret i1 [[R]]
;
Expand All @@ -80,7 +80,7 @@ define i1 @icmp_add_v8i32(<8 x i32> %a) {
; AVX-LABEL: @icmp_add_v8i32(
; AVX-NEXT: [[TMP1:%.*]] = icmp eq <8 x i32> [[A:%.*]], <i32 poison, i32 poison, i32 -8, i32 poison, i32 poison, i32 poison, i32 poison, i32 42>
; AVX-NEXT: [[SHIFT:%.*]] = shufflevector <8 x i1> [[TMP1]], <8 x i1> poison, <8 x i32> <i32 poison, i32 poison, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
; AVX-NEXT: [[TMP2:%.*]] = add <8 x i1> [[TMP1]], [[SHIFT]]
; AVX-NEXT: [[TMP2:%.*]] = add <8 x i1> [[SHIFT]], [[TMP1]]
; AVX-NEXT: [[R:%.*]] = extractelement <8 x i1> [[TMP2]], i64 2
; AVX-NEXT: ret i1 [[R]]
;
Expand Down Expand Up @@ -131,7 +131,7 @@ define i1 @icmp_xor_v4i32_multiuse(<4 x i32> %a) {
; CHECK-NEXT: call void @use(i32 [[E2]])
; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt <4 x i32> [[A]], <i32 poison, i32 -8, i32 poison, i32 42>
; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i1> [[TMP1]], <4 x i1> poison, <4 x i32> <i32 poison, i32 3, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP2:%.*]] = xor <4 x i1> [[TMP1]], [[SHIFT]]
; CHECK-NEXT: [[TMP2:%.*]] = xor <4 x i1> [[SHIFT]], [[TMP1]]
; CHECK-NEXT: [[R:%.*]] = extractelement <4 x i1> [[TMP2]], i64 1
; CHECK-NEXT: call void @use(i1 [[R]])
; CHECK-NEXT: ret i1 [[R]]
Expand Down
18 changes: 8 additions & 10 deletions llvm/test/Transforms/VectorCombine/X86/pr114901.ll
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ define i1 @PR114901(<4 x i32> %a) {
;
; AVX-LABEL: define i1 @PR114901(
; AVX-SAME: <4 x i32> [[A:%.*]]) #[[ATTR0:[0-9]+]] {
; AVX-NEXT: [[E1:%.*]] = extractelement <4 x i32> [[A]], i32 1
; AVX-NEXT: [[E3:%.*]] = extractelement <4 x i32> [[A]], i32 3
; AVX-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[E1]], -8
; AVX-NEXT: [[CMP3:%.*]] = icmp sgt i32 [[E3]], 42
; AVX-NEXT: [[R:%.*]] = ashr i1 [[CMP3]], [[CMP1]]
; AVX-NEXT: [[TMP1:%.*]] = icmp sgt <4 x i32> [[A]], <i32 poison, i32 -8, i32 poison, i32 42>
; AVX-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i1> [[TMP1]], <4 x i1> poison, <4 x i32> <i32 poison, i32 3, i32 poison, i32 poison>
; AVX-NEXT: [[TMP2:%.*]] = ashr <4 x i1> [[SHIFT]], [[TMP1]]
; AVX-NEXT: [[R:%.*]] = extractelement <4 x i1> [[TMP2]], i64 1
; AVX-NEXT: ret i1 [[R]]
;
%e1 = extractelement <4 x i32> %a, i32 1
Expand All @@ -42,11 +41,10 @@ define i1 @PR114901_flip(<4 x i32> %a) {
;
; AVX-LABEL: define i1 @PR114901_flip(
; AVX-SAME: <4 x i32> [[A:%.*]]) #[[ATTR0]] {
; AVX-NEXT: [[E1:%.*]] = extractelement <4 x i32> [[A]], i32 1
; AVX-NEXT: [[E3:%.*]] = extractelement <4 x i32> [[A]], i32 3
; AVX-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[E1]], -8
; AVX-NEXT: [[CMP3:%.*]] = icmp sgt i32 [[E3]], 42
; AVX-NEXT: [[R:%.*]] = ashr i1 [[CMP1]], [[CMP3]]
; AVX-NEXT: [[TMP1:%.*]] = icmp sgt <4 x i32> [[A]], <i32 poison, i32 -8, i32 poison, i32 42>
; AVX-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i1> [[TMP1]], <4 x i1> poison, <4 x i32> <i32 poison, i32 3, i32 poison, i32 poison>
; AVX-NEXT: [[TMP2:%.*]] = ashr <4 x i1> [[TMP1]], [[SHIFT]]
; AVX-NEXT: [[R:%.*]] = extractelement <4 x i1> [[TMP2]], i64 1
; AVX-NEXT: ret i1 [[R]]
;
%e1 = extractelement <4 x i32> %a, i32 1
Expand Down

0 comments on commit e3a0775

Please sign in to comment.