Skip to content

Commit

Permalink
[LoopFlatten] Recognise gep+gep (llvm#72515)
Browse files Browse the repository at this point in the history
Now that InstCombine canonicalises add+gep to gep+gep, LoopFlatten needs
to recognise (gep (gep ptr (i*M)), j) as being something it can
optimise.
  • Loading branch information
john-brawn-arm authored Jan 10, 2024
1 parent 9aa8c82 commit ae978ba
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 27 deletions.
36 changes: 36 additions & 0 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,36 @@ struct ThreeOps_match {
}
};

/// Matches instructions with Opcode and any number of operands
template <unsigned Opcode, typename... OperandTypes> struct AnyOps_match {
std::tuple<OperandTypes...> Operands;

AnyOps_match(const OperandTypes &...Ops) : Operands(Ops...) {}

// Operand matching works by recursively calling match_operands, matching the
// operands left to right. The first version is called for each operand but
// the last, for which the second version is called. The second version of
// match_operands is also used to match each individual operand.
template <int Idx, int Last>
std::enable_if_t<Idx != Last, bool> match_operands(const Instruction *I) {
return match_operands<Idx, Idx>(I) && match_operands<Idx + 1, Last>(I);
}

template <int Idx, int Last>
std::enable_if_t<Idx == Last, bool> match_operands(const Instruction *I) {
return std::get<Idx>(Operands).match(I->getOperand(Idx));
}

template <typename OpTy> bool match(OpTy *V) {
if (V->getValueID() == Value::InstructionVal + Opcode) {
auto *I = cast<Instruction>(V);
return I->getNumOperands() == sizeof...(OperandTypes) &&
match_operands<0, sizeof...(OperandTypes) - 1>(I);
}
return false;
}
};

/// Matches SelectInst.
template <typename Cond, typename LHS, typename RHS>
inline ThreeOps_match<Cond, LHS, RHS, Instruction::Select>
Expand Down Expand Up @@ -1611,6 +1641,12 @@ m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp) {
PointerOp);
}

/// Matches GetElementPtrInst.
template <typename... OperandTypes>
inline auto m_GEP(const OperandTypes &...Ops) {
return AnyOps_match<Instruction::GetElementPtr, OperandTypes...>(Ops...);
}

//===----------------------------------------------------------------------===//
// Matchers for CastInst classes
//
Expand Down
79 changes: 52 additions & 27 deletions llvm/lib/Transforms/Scalar/LoopFlatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ struct FlattenInfo {
match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
m_Value(MatchedItCount)));

// Matches the pattern ptr+i*M+j, with the two additions being done via GEP.
bool IsGEP = match(U, m_GEP(m_GEP(m_Value(), m_Value(MatchedMul)),
m_Specific(InnerInductionPHI))) &&
match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
m_Value(MatchedItCount)));

if (!MatchedItCount)
return false;

Expand All @@ -224,7 +230,7 @@ struct FlattenInfo {

// Look through extends if the IV has been widened. Don't look through
// extends if we already looked through a trunc.
if (Widened && IsAdd &&
if (Widened && (IsAdd || IsGEP) &&
(isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
"Unexpected type mismatch in types after widening");
Expand All @@ -236,7 +242,7 @@ struct FlattenInfo {
LLVM_DEBUG(dbgs() << "Looking for inner trip count: ";
InnerTripCount->dump());

if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
if ((IsAdd || IsAddTrunc || IsGEP) && MatchedItCount == InnerTripCount) {
LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n");
ValidOuterPHIUses.insert(MatchedMul);
LinearIVUses.insert(U);
Expand Down Expand Up @@ -646,33 +652,40 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
if (OR != OverflowResult::MayOverflow)
return OR;

for (Value *V : FI.LinearIVUses) {
for (Value *U : V->users()) {
if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
for (Value *GEPUser : U->users()) {
auto *GEPUserInst = cast<Instruction>(GEPUser);
if (!isa<LoadInst>(GEPUserInst) &&
!(isa<StoreInst>(GEPUserInst) &&
GEP == GEPUserInst->getOperand(1)))
continue;
if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst,
FI.InnerLoop))
continue;
// The IV is used as the operand of a GEP which dominates the loop
// latch, and the IV is at least as wide as the address space of the
// GEP. In this case, the GEP would wrap around the address space
// before the IV increment wraps, which would be UB.
if (GEP->isInBounds() &&
V->getType()->getIntegerBitWidth() >=
DL.getPointerTypeSizeInBits(GEP->getType())) {
LLVM_DEBUG(
dbgs() << "use of linear IV would be UB if overflow occurred: ";
GEP->dump());
return OverflowResult::NeverOverflows;
}
}
auto CheckGEP = [&](GetElementPtrInst *GEP, Value *GEPOperand) {
for (Value *GEPUser : GEP->users()) {
auto *GEPUserInst = cast<Instruction>(GEPUser);
if (!isa<LoadInst>(GEPUserInst) &&
!(isa<StoreInst>(GEPUserInst) && GEP == GEPUserInst->getOperand(1)))
continue;
if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst, FI.InnerLoop))
continue;
// The IV is used as the operand of a GEP which dominates the loop
// latch, and the IV is at least as wide as the address space of the
// GEP. In this case, the GEP would wrap around the address space
// before the IV increment wraps, which would be UB.
if (GEP->isInBounds() &&
GEPOperand->getType()->getIntegerBitWidth() >=
DL.getPointerTypeSizeInBits(GEP->getType())) {
LLVM_DEBUG(
dbgs() << "use of linear IV would be UB if overflow occurred: ";
GEP->dump());
return true;
}
}
return false;
};

// Check if any IV user is, or is used by, a GEP that would cause UB if the
// multiply overflows.
for (Value *V : FI.LinearIVUses) {
if (auto *GEP = dyn_cast<GetElementPtrInst>(V))
if (GEP->getNumIndices() == 1 && CheckGEP(GEP, GEP->getOperand(1)))
return OverflowResult::NeverOverflows;
for (Value *U : V->users())
if (auto *GEP = dyn_cast<GetElementPtrInst>(U))
if (CheckGEP(GEP, V))
return OverflowResult::NeverOverflows;
}

return OverflowResult::MayOverflow;
Expand Down Expand Up @@ -778,6 +791,18 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
"flatten.trunciv");

if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
// Replace the GEP with one that uses OuterValue as the offset.
auto *InnerGEP = cast<GetElementPtrInst>(GEP->getOperand(0));
Value *Base = InnerGEP->getOperand(0);
// When the base of the GEP doesn't dominate the outer induction phi then
// we need to insert the new GEP where the old GEP was.
if (!DT->dominates(Base, &*Builder.GetInsertPoint()))
Builder.SetInsertPoint(cast<Instruction>(V));
OuterValue = Builder.CreateGEP(GEP->getSourceElementType(), Base,
OuterValue, "flatten." + V->getName());
}

LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: ";
OuterValue->dump());
V->replaceAllUsesWith(OuterValue);
Expand Down
137 changes: 137 additions & 0 deletions llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
; RUN: opt < %s -S -passes='loop(loop-flatten),verify' -verify-loop-info -verify-dom-info -verify-scev | FileCheck %s

target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"

; We should be able to flatten the loops and turn the two geps into one.
; CHECK-LABEL: test1
define void @test1(i32 %N, ptr %A) {
entry:
%cmp3 = icmp ult i32 0, %N
br i1 %cmp3, label %for.outer.preheader, label %for.end

; CHECK-LABEL: for.outer.preheader:
; CHECK: %flatten.tripcount = mul i32 %N, %N
for.outer.preheader:
br label %for.inner.preheader

; CHECK-LABEL: for.inner.preheader:
; CHECK: %flatten.arrayidx = getelementptr i32, ptr %A, i32 %i
for.inner.preheader:
%i = phi i32 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
br label %for.inner

; CHECK-LABEL: for.inner:
; CHECK: store i32 0, ptr %flatten.arrayidx, align 4
; CHECK: br label %for.outer
for.inner:
%j = phi i32 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
%mul = mul i32 %i, %N
%gep = getelementptr inbounds i32, ptr %A, i32 %mul
%arrayidx = getelementptr inbounds i32, ptr %gep, i32 %j
store i32 0, ptr %arrayidx, align 4
%inc1 = add nuw i32 %j, 1
%cmp2 = icmp ult i32 %inc1, %N
br i1 %cmp2, label %for.inner, label %for.outer

; CHECK-LABEL: for.outer:
; CHECK: %cmp1 = icmp ult i32 %inc2, %flatten.tripcount
for.outer:
%inc2 = add i32 %i, 1
%cmp1 = icmp ult i32 %inc2, %N
br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit

for.end.loopexit:
br label %for.end

for.end:
ret void
}

; We can flatten, but the flattened gep has to be inserted after the load it
; depends on.
; CHECK-LABEL: test2
define void @test2(i32 %N, ptr %A) {
entry:
%cmp3 = icmp ult i32 0, %N
br i1 %cmp3, label %for.outer.preheader, label %for.end

; CHECK-LABEL: for.outer.preheader:
; CHECK: %flatten.tripcount = mul i32 %N, %N
for.outer.preheader:
br label %for.inner.preheader

; CHECK-LABEL: for.inner.preheader:
; CHECK-NOT: getelementptr i32, ptr %ptr, i32 %i
for.inner.preheader:
%i = phi i32 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
br label %for.inner

; CHECK-LABEL: for.inner:
; CHECK: %flatten.arrayidx = getelementptr i32, ptr %ptr, i32 %i
; CHECK: store i32 0, ptr %flatten.arrayidx, align 4
; CHECK: br label %for.outer
for.inner:
%j = phi i32 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
%ptr = load volatile ptr, ptr %A, align 4
%mul = mul i32 %i, %N
%gep = getelementptr inbounds i32, ptr %ptr, i32 %mul
%arrayidx = getelementptr inbounds i32, ptr %gep, i32 %j
store i32 0, ptr %arrayidx, align 4
%inc1 = add nuw i32 %j, 1
%cmp2 = icmp ult i32 %inc1, %N
br i1 %cmp2, label %for.inner, label %for.outer

; CHECK-LABEL: for.outer:
; CHECK: %cmp1 = icmp ult i32 %inc2, %flatten.tripcount
for.outer:
%inc2 = add i32 %i, 1
%cmp1 = icmp ult i32 %inc2, %N
br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit

for.end.loopexit:
br label %for.end

for.end:
ret void
}

; We can't flatten if the gep offset is smaller than the pointer size.
; CHECK-LABEL: test3
define void @test3(i16 %N, ptr %A) {
entry:
%cmp3 = icmp ult i16 0, %N
br i1 %cmp3, label %for.outer.preheader, label %for.end

for.outer.preheader:
br label %for.inner.preheader

; CHECK-LABEL: for.inner.preheader:
; CHECK-NOT: getelementptr i32, ptr %A, i16 %i
for.inner.preheader:
%i = phi i16 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
br label %for.inner

; CHECK-LABEL: for.inner:
; CHECK-NOT: getelementptr i32, ptr %A, i16 %i
; CHECK: br i1 %cmp2, label %for.inner, label %for.outer
for.inner:
%j = phi i16 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
%mul = mul i16 %i, %N
%gep = getelementptr inbounds i32, ptr %A, i16 %mul
%arrayidx = getelementptr inbounds i32, ptr %gep, i16 %j
store i32 0, ptr %arrayidx, align 4
%inc1 = add nuw i16 %j, 1
%cmp2 = icmp ult i16 %inc1, %N
br i1 %cmp2, label %for.inner, label %for.outer

for.outer:
%inc2 = add i16 %i, 1
%cmp1 = icmp ult i16 %inc2, %N
br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit

for.end.loopexit:
br label %for.end

for.end:
ret void
}

0 comments on commit ae978ba

Please sign in to comment.