Skip to content

Commit

Permalink
[CostModel][AArch64] Make extractelement, with fmul user, free whenev…
Browse files Browse the repository at this point in the history
…er possible

In case of Neon, if there exists extractelement from lane != 0 such that
  1. extractelement does not necessitate a move from vector_reg -> GPR.
  2. extractelement result feeds into fmul.
  3. Other operand of fmul is a scalar or extractelement from lane 0 or lane equivalent to 0.
  then the extractelement can be merged with fmul in the backend and it incurs no cost.
  e.g.
  define double @foo(<2 x double> %a) {
    %1 = extractelement <2 x double> %a, i32 0
    %2 = extractelement <2 x double> %a, i32 1
    %res = fmul double %1, %2    ret double %res
  }
  %2 and %res can be merged in the backend to generate:
  fmul    d0, d0, v0.d[1]

The change was tested with SPEC FP(C/C++) on Neoverse-v2.
Compile time impact: None
Performance impact: Observing 1.3-1.7% uplift on lbm benchmark with -flto depending upon the config.
  • Loading branch information
sushgokh committed Oct 9, 2024
1 parent 924a64a commit 85ca8b2
Show file tree
Hide file tree
Showing 9 changed files with 284 additions and 72 deletions.
26 changes: 26 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#define LLVM_ANALYSIS_TARGETTRANSFORMINFO_H

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/FMF.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/PassManager.h"
Expand Down Expand Up @@ -1392,6 +1394,16 @@ class TargetTransformInfo {
unsigned Index = -1, Value *Op0 = nullptr,
Value *Op1 = nullptr) const;

/// \return The expected cost of vector Insert and Extract.
/// Use -1 to indicate that there is no information on the index value.
/// This is used when the instruction is not available; a typical use
/// case is to provision the cost of vectorization/scalarization in
/// vectorizer passes.
InstructionCost getVectorInstrCost(
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
Value *Scalar,
const ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const;

/// \return The expected cost of vector Insert and Extract.
/// This is used when instruction is available, and implementation
/// asserts 'I' is not nullptr.
Expand Down Expand Up @@ -2062,6 +2074,12 @@ class TargetTransformInfo::Concept {
TTI::TargetCostKind CostKind,
unsigned Index, Value *Op0,
Value *Op1) = 0;

virtual InstructionCost getVectorInstrCost(
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
Value *Scalar,
const ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) = 0;

virtual InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index) = 0;
Expand Down Expand Up @@ -2726,6 +2744,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
Value *Op1) override {
return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
}
InstructionCost
getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
unsigned Index, Value *Scalar,
const ArrayRef<std::tuple<Value *, User *, int>>
ScalarUserAndIdx) override {
return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Scalar,
ScalarUserAndIdx);
}
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index) override {
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,13 @@ class TargetTransformInfoImplBase {
return 1;
}

InstructionCost getVectorInstrCost(
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
Value *Scalar,
const ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
return 1;
}

InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index) const {
Expand Down
15 changes: 12 additions & 3 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1277,12 +1277,21 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return 1;
}

InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index, Value *Op0, Value *Op1) {
virtual InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index, Value *Op0,
Value *Op1) {
return getRegUsageForType(Val->getScalarType());
}

InstructionCost getVectorInstrCost(
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
Value *Scalar,
const ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx
) {
return getVectorInstrCost(Opcode, Val, CostKind, Index, nullptr, nullptr);
}

InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index) {
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,19 @@ InstructionCost TargetTransformInfo::getVectorInstrCost(
return Cost;
}

InstructionCost TargetTransformInfo::getVectorInstrCost(
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
Value *Scalar,
const ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
// FIXME: Assert that Opcode is either InsertElement or ExtractElement.
// This is mentioned in the interface description and respected by all
// callers, but never asserted upon.
InstructionCost Cost = TTIImpl->getVectorInstrCost(
Opcode, Val, CostKind, Index, Scalar, ScalarUserAndIdx);
assert(Cost >= 0 && "TTI should not produce negative costs!");
return Cost;
}

InstructionCost
TargetTransformInfo::getVectorInstrCost(const Instruction &I, Type *Val,
TTI::TargetCostKind CostKind,
Expand Down
171 changes: 167 additions & 4 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,28 @@
#include "AArch64ExpandImm.h"
#include "AArch64PerfectShuffle.h"
#include "MCTargetDesc/AArch64AddressingModes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/CostTable.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/User.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
#include <algorithm>
#include <cassert>
#include <optional>
using namespace llvm;
using namespace llvm::PatternMatch;
Expand Down Expand Up @@ -3145,12 +3153,16 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
return 0;
}

InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
Type *Val,
unsigned Index,
bool HasRealUse) {
InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
unsigned Index, bool HasRealUse, Value *Scalar,
const ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
assert(Val->isVectorTy() && "This must be a vector type");

const auto *I = (std::holds_alternative<const Instruction *>(InstOrOpcode)
? get<const Instruction *>(InstOrOpcode)
: nullptr);

if (Index != -1U) {
// Legalize the type.
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
Expand Down Expand Up @@ -3194,6 +3206,149 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
// compile-time considerations.
}

// In case of Neon, if there exists extractelement from lane != 0 such that
// 1. extractelement does not necessitate a move from vector_reg -> GPR.
// 2. extractelement result feeds into fmul.
// 3. Other operand of fmul is a scalar or extractelement from lane 0 or lane
// equivalent to 0.
// then the extractelement can be merged with fmul in the backend and it
// incurs no cost.
// e.g.
// define double @foo(<2 x double> %a) {
// %1 = extractelement <2 x double> %a, i32 0
// %2 = extractelement <2 x double> %a, i32 1
// %res = fmul double %1, %2
// ret double %res
// }
// %2 and %res can be merged in the backend to generate fmul v0, v0, v1.d[1]
auto ExtractCanFuseWithFmul = [&]() {
// We bail out if the extract is from lane 0.
if (Index == 0)
return false;

// Check if the scalar element type of the vector operand of ExtractElement
// instruction is one of the allowed types.
auto IsAllowedScalarTy = [&](const Type *T) {
return T->isFloatTy() || T->isDoubleTy() ||
(T->isHalfTy() && ST->hasFullFP16());
};

// Check if the extractelement user is scalar fmul.
auto IsUserFMulScalarTy = [](const Value *EEUser) {
// Check if the user is scalar fmul.
const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
return BO && BO->getOpcode() == BinaryOperator::FMul &&
!BO->getType()->isVectorTy();
};

// InstCombine combines fmul with fadd/fsub. Hence, extractelement fusion
// with fmul does not happen.
auto IsFMulUserFAddFSub = [](const Value *FMul) {
return any_of(FMul->users(), [](const User *U) {
const auto *BO = dyn_cast_if_present<BinaryOperator>(U);
return (BO && (BO->getOpcode() == BinaryOperator::FAdd ||
BO->getOpcode() == BinaryOperator::FSub));
});
};

// Check if the type constraints on input vector type and result scalar type
// of extractelement instruction are satisfied.
auto TypeConstraintsOnEESatisfied =
[&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
};

// Check if the extract index is from lane 0 or lane equivalent to 0 for a
// certain scalar type and a certain vector register width.
auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
const unsigned &EltSz) {
auto RegWidth =
getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
.getFixedValue();
return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
};

if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
return false;

for (auto &RefT : ScalarUserAndIdx) {
Value *RefS = get<0>(RefT);
User *RefU = get<1>(RefT);
const int &RefL = get<2>(RefT);

// Analayze all the users which have same scalar/index as Scalar/Index.
if (RefS != Scalar || RefL != Index)
continue;

// Check if the user of {Scalar, Index} pair is fmul user.
if (!IsUserFMulScalarTy(RefU) || IsFMulUserFAddFSub(RefU))
return false;

// For RefU, check if the other operand is extract from the same SLP
// tree. If not, we bail out since we can't analyze extracts from other
// SLP tree.
unsigned NumExtractEltsIntoUser = 0;
for (auto &CmpT : ScalarUserAndIdx) {
User *CmpU = get<1>(CmpT);
if (CmpT == RefT || CmpU != RefU)
continue;
Value *CmpS = get<0>(CmpT);
++NumExtractEltsIntoUser;
const int &CmpL = get<2>(CmpT);
if (!IsExtractLaneEquivalentToZero(CmpL, Val->getScalarSizeInBits()))
return false;
}
// We know this is fmul user with just 2 operands, one being RefT. If we
// can't find CmpT, as the other operand, then bail out.
if (NumExtractEltsIntoUser != 1)
return false;
}
} else {
const auto *EE = cast<ExtractElementInst>(I);

const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
if (!IdxOp)
return false;

if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
EE->getType()))
return false;

return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
return false;

// Check if the other operand of extractelement is also extractelement
// from lane equivalent to 0.
const auto *BO = cast<BinaryOperator>(U);
const auto *OtherEE = dyn_cast<ExtractElementInst>(
BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
if (OtherEE) {
const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
if (!IdxOp)
return false;
return IsExtractLaneEquivalentToZero(
cast<ConstantInt>(OtherEE->getIndexOperand())
->getValue()
.getZExtValue(),
OtherEE->getType()->getScalarSizeInBits());
}
return true;
});
}
return true;
};

if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
const unsigned &Opcode = get<const unsigned>(InstOrOpcode);
if (Opcode == Instruction::ExtractElement && ExtractCanFuseWithFmul())
return 0;
} else if (I && I->getOpcode() == Instruction::ExtractElement &&
ExtractCanFuseWithFmul()) {
return 0;
}

// All other insert/extracts cost this much.
return ST->getVectorInsertExtractBaseCost();
}
Expand All @@ -3207,6 +3362,14 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
}

InstructionCost AArch64TTIImpl::getVectorInstrCost(
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
Value *Scalar,
const ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
return getVectorInstrCostHelper(Opcode, Val, Index, false, Scalar,
ScalarUserAndIdx);
}

InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
Type *Val,
TTI::TargetCostKind CostKind,
Expand Down
13 changes: 11 additions & 2 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
// 'Val' and 'Index' are forwarded from 'getVectorInstrCost'; 'HasRealUse'
// indicates whether the vector instruction is available in the input IR or
// just imaginary in vectorizer passes.
InstructionCost getVectorInstrCostHelper(const Instruction *I, Type *Val,
unsigned Index, bool HasRealUse);
InstructionCost getVectorInstrCostHelper(
std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
unsigned Index, bool HasRealUse, Value *Scalar = nullptr,
const ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx =
SmallVector<std::tuple<Value *, User *, int>, 0>());

public:
explicit AArch64TTIImpl(const AArch64TargetMachine *TM, const Function &F)
Expand Down Expand Up @@ -185,6 +188,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index, Value *Op0, Value *Op1);

InstructionCost getVectorInstrCost(
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
Value *Scalar,
const ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx);

InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index);
Expand Down
12 changes: 10 additions & 2 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11633,6 +11633,13 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
std::optional<DenseMap<Value *, unsigned>> ValueToExtUses;
DenseMap<const TreeEntry *, DenseSet<Value *>> ExtractsCount;
SmallPtrSet<Value *, 4> ScalarOpsFromCasts;
// Keep track {Scalar, Index, User} tuple.
// On AArch64, this helps in fusing a mov instruction, associated with
// extractelement, with fmul in the backend so that extractelement is free.
SmallVector<std::tuple<Value *, User *, int>, 4> ScalarUserAndIdx;
for (ExternalUser &EU : ExternalUses) {
ScalarUserAndIdx.emplace_back(std::make_tuple(EU.Scalar, EU.User, EU.Lane));
}
for (ExternalUser &EU : ExternalUses) {
// Uses by ephemeral values are free (because the ephemeral value will be
// removed prior to code generation, and so the extraction will be
Expand Down Expand Up @@ -11739,8 +11746,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
ExtraCost = TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(),
VecTy, EU.Lane);
} else {
ExtraCost = TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
CostKind, EU.Lane);
ExtraCost =
TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
EU.Lane, EU.Scalar, ScalarUserAndIdx);
}
// Leave the scalar instructions as is if they are cheaper than extracts.
if (Entry->Idx != 0 || Entry->getOpcode() == Instruction::GetElementPtr ||
Expand Down
Loading

0 comments on commit 85ca8b2

Please sign in to comment.