Skip to content

Commit

Permalink
[RISCV][WIP] Let RA do the CSR saves.
Browse files Browse the repository at this point in the history
We turn the problem of saving and restoring callee-saved registers efficiently into a register allocation problem. This has the advantage that the register allocator can essentially do shrink-wrapping on a per-register basis. Currently, the shrink-wrapping pass saves all CSRs in the same place, which may be suboptimal. Also, improvements to register allocation / coalescing will translate to improvements in shrink-wrapping.

In finalizeLowering() we copy all callee-saved registers from a physical register to a virtual one. In all return blocks we do the reverse.
  • Loading branch information
Mikhail Gudim committed Nov 6, 2024
1 parent 5cbd4b0 commit ae5d2b4
Show file tree
Hide file tree
Showing 23 changed files with 1,138 additions and 27 deletions.
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,11 @@ class ReachingDefAnalysis : public MachineFunctionPass {
private:
MachineFunction *MF = nullptr;
const TargetRegisterInfo *TRI = nullptr;
const TargetInstrInfo *TII = nullptr;
LoopTraversal::TraversalOrder TraversedMBBOrder;
unsigned NumRegUnits = 0;
unsigned NumStackObjects = 0;
int ObjectIndexBegin = 0;
/// Instruction that defined each register, relative to the beginning of the
/// current basic block. When a LiveRegsDefInfo is used to represent a
/// live-out register, this value is relative to the end of the basic block,
Expand All @@ -138,6 +141,8 @@ class ReachingDefAnalysis : public MachineFunctionPass {
DenseMap<MachineInstr *, int> InstIds;

MBBReachingDefsInfo MBBReachingDefs;
using MBBFrameObjsReachingDefsInfo = std::vector<std::vector<std::vector<int>>>;
MBBFrameObjsReachingDefsInfo MBBFrameObjsReachingDefs;

/// Default values are 'nothing happened a long time ago'.
const int ReachingDefDefaultVal = -(1 << 21);
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/TargetFrameLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace llvm {
class CalleeSavedInfo;
class MachineFunction;
class RegScavenger;
class ReachingDefAnalysis;

namespace TargetStackID {
enum Value {
Expand Down Expand Up @@ -210,6 +211,11 @@ class TargetFrameLowering {
/// for noreturn nounwind functions.
virtual bool enableCalleeSaveSkip(const MachineFunction &MF) const;

/// Target hook that receives the function being lowered and a
/// ReachingDefAnalysis instance. The default implementation is a no-op.
///
/// NOTE(review): going by the name, targets presumably override this to emit
/// CFI directives for callee-saved registers whose saves/restores were placed
/// by the register allocator -- confirm against the target override.
///
/// \param MF  the machine function being processed.
/// \param RDA reaching-definition analysis made available to the override;
///            unused by this default implementation.
virtual void emitCFIsForCSRsHandledByRA(MachineFunction &MF,
                                        ReachingDefAnalysis *RDA) const {}

/// emitProlog/emitEpilog - These methods insert prolog and epilog code into
/// the function.
virtual void emitPrologue(MachineFunction &MF,
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/TargetSubtargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ class TargetSubtargetInfo : public MCSubtargetInfo {
return false;
}

/// Subtarget query with an out-of-line default in TargetSubtargetInfo.cpp
/// that returns false.
/// NOTE(review): presumably a subtarget overrides this to return true when
/// callee-saved register saves/restores should be handled by the register
/// allocator rather than by prologue/epilogue insertion -- confirm.
virtual bool doCSRSavesInRA() const;

/// Classify a global function reference. This mainly used to fetch target
/// special flags for lowering a function address. For example mark a function
/// call should be plt or pc-related addressing.
Expand Down
44 changes: 32 additions & 12 deletions llvm/lib/CodeGen/MachineLICM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,19 @@ namespace {
void HoistOutOfLoop(MachineDomTreeNode *HeaderN, MachineLoop *CurLoop,
MachineBasicBlock *CurPreheader);

void InitRegPressure(MachineBasicBlock *BB);
void InitRegPressure(MachineBasicBlock *BB, const MachineLoop* Loop);

SmallDenseMap<unsigned, int> calcRegisterCost(const MachineInstr *MI,
bool ConsiderSeen,
bool ConsiderUnseenAsDef);
bool ConsiderUnseenAsDef,
bool IgnoreDefs = false);

bool allDefsAreOnlyUsedOutsideOfTheLoop(const MachineInstr &MI, const MachineLoop *Loop);
void UpdateRegPressure(const MachineInstr *MI,
bool ConsiderUnseenAsDef = false);
bool ConsiderUnseenAsDef = false, bool IgnoreDefs = false);

void UpdateRegPressureForUsesOnly(const MachineInstr *MI,
bool ConsiderUnseenAsDef = false);
MachineInstr *ExtractHoistableLoad(MachineInstr *MI, MachineLoop *CurLoop);

MachineInstr *LookForDuplicate(const MachineInstr *MI,
Expand Down Expand Up @@ -884,7 +888,7 @@ void MachineLICMImpl::HoistOutOfLoop(MachineDomTreeNode *HeaderN,
// Compute registers which are livein into the loop headers.
RegSeen.clear();
BackTrace.clear();
InitRegPressure(Preheader);
InitRegPressure(Preheader, CurLoop);

// Now perform LICM.
for (MachineDomTreeNode *Node : Scopes) {
Expand Down Expand Up @@ -934,7 +938,7 @@ static bool isOperandKill(const MachineOperand &MO, MachineRegisterInfo *MRI) {
/// Find all virtual register references that are liveout of the preheader to
/// initialize the starting "register pressure". Note this does not count live
/// through (livein but not used) registers.
void MachineLICMImpl::InitRegPressure(MachineBasicBlock *BB) {
void MachineLICMImpl::InitRegPressure(MachineBasicBlock *BB, const MachineLoop *Loop) {
std::fill(RegPressure.begin(), RegPressure.end(), 0);

// If the preheader has only a single predecessor and it ends with a
Expand All @@ -945,17 +949,32 @@ void MachineLICMImpl::InitRegPressure(MachineBasicBlock *BB) {
MachineBasicBlock *TBB = nullptr, *FBB = nullptr;
SmallVector<MachineOperand, 4> Cond;
if (!TII->analyzeBranch(*BB, TBB, FBB, Cond, false) && Cond.empty())
InitRegPressure(*BB->pred_begin());
InitRegPressure(*BB->pred_begin(), Loop);
}

for (const MachineInstr &MI : *BB)
UpdateRegPressure(&MI, /*ConsiderUnseenAsDef=*/true);
for (const MachineInstr &MI : *BB) {
bool IgnoreDefs = allDefsAreOnlyUsedOutsideOfTheLoop(MI, Loop);
UpdateRegPressure(&MI, /*ConsiderUnseenAsDef=*/true, IgnoreDefs);
}
}

/// Returns true if none of the registers defined by \p MI has a (non-debug)
/// use in a block contained in \p Loop. An instruction with no register defs
/// trivially yields true.
///
/// \param MI   instruction whose register definitions are inspected.
/// \param Loop loop against which the using instructions' parent blocks are
///             tested; must be non-null.
/// \returns false as soon as one use inside \p Loop is found, true otherwise.
bool MachineLICMImpl::allDefsAreOnlyUsedOutsideOfTheLoop(const MachineInstr &MI, const MachineLoop *Loop) {
  // Bind each operand by reference; there is no reason to copy a
  // MachineOperand on every iteration.
  for (const MachineOperand &DefMO : MI.all_defs()) {
    if (!DefMO.isReg())
      continue;
    // Only non-debug uses are considered so that the presence of debug
    // instructions inside the loop cannot alter the result (and thereby the
    // register-pressure estimate that depends on it).
    for (const MachineInstr &UseMI : MRI->use_nodbg_instructions(DefMO.getReg())) {
      if (Loop->contains(UseMI.getParent()))
        return false;
    }
  }
  return true;
}

/// Update estimate of register pressure after the specified instruction.
void MachineLICMImpl::UpdateRegPressure(const MachineInstr *MI,
bool ConsiderUnseenAsDef) {
auto Cost = calcRegisterCost(MI, /*ConsiderSeen=*/true, ConsiderUnseenAsDef);
bool ConsiderUnseenAsDef,
bool IgnoreDefs) {
auto Cost = calcRegisterCost(MI, /*ConsiderSeen=*/true, ConsiderUnseenAsDef, IgnoreDefs);
for (const auto &RPIdAndCost : Cost) {
unsigned Class = RPIdAndCost.first;
if (static_cast<int>(RegPressure[Class]) < -RPIdAndCost.second)
Expand All @@ -973,7 +992,8 @@ void MachineLICMImpl::UpdateRegPressure(const MachineInstr *MI,
/// FIXME: Figure out a way to consider 'RegSeen' from all code paths.
SmallDenseMap<unsigned, int>
MachineLICMImpl::calcRegisterCost(const MachineInstr *MI, bool ConsiderSeen,
bool ConsiderUnseenAsDef) {
bool ConsiderUnseenAsDef,
bool IgnoreDefs) {
SmallDenseMap<unsigned, int> Cost;
if (MI->isImplicitDef())
return Cost;
Expand All @@ -991,7 +1011,7 @@ MachineLICMImpl::calcRegisterCost(const MachineInstr *MI, bool ConsiderSeen,

RegClassWeight W = TRI->getRegClassWeight(RC);
int RCCost = 0;
if (MO.isDef())
if (MO.isDef() && !IgnoreDefs)
RCCost = W.RegWeight;
else {
bool isKill = isOperandKill(MO, MRI);
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/CodeGen/PrologEpilogInserter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/ReachingDefAnalysis.h"
#include "llvm/CodeGen/RegisterScavenging.h"
#include "llvm/CodeGen/TargetFrameLowering.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
Expand Down Expand Up @@ -95,6 +96,7 @@ class PEI : public MachineFunctionPass {
bool runOnMachineFunction(MachineFunction &MF) override;

private:
ReachingDefAnalysis *RDA = nullptr;
RegScavenger *RS = nullptr;

// MinCSFrameIndex, MaxCSFrameIndex - Keeps the range of callee saved
Expand Down Expand Up @@ -153,6 +155,7 @@ INITIALIZE_PASS_BEGIN(PEI, DEBUG_TYPE, "Prologue/Epilogue Insertion", false,
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_DEPENDENCY(ReachingDefAnalysis)
INITIALIZE_PASS_END(PEI, DEBUG_TYPE,
"Prologue/Epilogue Insertion & Frame Finalization", false,
false)
Expand All @@ -169,6 +172,7 @@ void PEI::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addPreserved<MachineLoopInfoWrapperPass>();
AU.addPreserved<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachineOptimizationRemarkEmitterPass>();
AU.addRequired<ReachingDefAnalysis>();
MachineFunctionPass::getAnalysisUsage(AU);
}

Expand Down Expand Up @@ -227,6 +231,7 @@ bool PEI::runOnMachineFunction(MachineFunction &MF) {
RS = TRI->requiresRegisterScavenging(MF) ? new RegScavenger() : nullptr;
FrameIndexVirtualScavenging = TRI->requiresFrameIndexScavenging(MF);
ORE = &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE();
RDA = &getAnalysis<ReachingDefAnalysis>();

// Spill frame pointer and/or base pointer registers if they are clobbered.
// It is placed before call frame instruction elimination so it will not mess
Expand Down Expand Up @@ -262,6 +267,7 @@ bool PEI::runOnMachineFunction(MachineFunction &MF) {
// called functions. Because of this, calculateCalleeSavedRegisters()
// must be called before this function in order to set the AdjustsStack
// and MaxCallFrameSize variables.
RDA->reset();
if (!F.hasFnAttribute(Attribute::Naked))
insertPrologEpilogCode(MF);

Expand Down Expand Up @@ -1164,6 +1170,7 @@ void PEI::calculateFrameObjectOffsets(MachineFunction &MF) {
void PEI::insertPrologEpilogCode(MachineFunction &MF) {
const TargetFrameLowering &TFI = *MF.getSubtarget().getFrameLowering();

TFI.emitCFIsForCSRsHandledByRA(MF, RDA);
// Add prologue to the function...
for (MachineBasicBlock *SaveBlock : SaveBlocks)
TFI.emitPrologue(MF, *SaveBlock);
Expand Down
60 changes: 56 additions & 4 deletions llvm/lib/CodeGen/ReachingDefAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/CodeGen/LiveRegUnits.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -48,12 +50,31 @@ static bool isValidRegDefOf(const MachineOperand &MO, MCRegister PhysReg,
return TRI->regsOverlap(MO.getReg(), PhysReg);
}

/// Returns true if \p MI writes the stack slot identified by \p FrameIndex.
///
/// \p MI counts as a writer when \p TII recognises it either as a store to a
/// stack slot or as a stack-slot-to-stack-slot copy; in both cases the frame
/// index reported as the destination must equal \p FrameIndex.
static bool isFIDef(const MachineInstr &MI, int FrameIndex, const TargetInstrInfo *TII) {
  int WrittenFI = 0;
  int ReadFI = 0;
  // The store query runs first; the copy query is only consulted when the
  // instruction is not recognised as a plain store.
  const bool WritesStackSlot = TII->isStoreToStackSlot(MI, WrittenFI) ||
                               TII->isStackSlotCopy(MI, WrittenFI, ReadFI);
  return WritesStackSlot && WrittenFI == FrameIndex;
}


void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
unsigned MBBNumber = MBB->getNumber();
assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
"Unexpected basic block number.");
MBBReachingDefs.startBasicBlock(MBBNumber, NumRegUnits);

MBBFrameObjsReachingDefs[MBBNumber].resize(NumStackObjects);
for (unsigned FOIdx = 0; FOIdx < NumStackObjects; ++FOIdx) {
MBBFrameObjsReachingDefs[MBBNumber][FOIdx].push_back(-1);
}


// Reset instruction counter in each basic block.
CurInstr = 0;

Expand Down Expand Up @@ -126,6 +147,12 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) {
"Unexpected basic block number.");

for (auto &MO : MI->operands()) {
if (MO.isFI()) {
int FrameIndex = MO.getIndex();
if (!isFIDef(*MI, FrameIndex, TII))
continue;
MBBFrameObjsReachingDefs[MBBNumber][FrameIndex - ObjectIndexBegin].push_back(CurInstr);
}
if (!isValidRegDef(MO))
continue;
for (MCRegUnit Unit : TRI->regunits(MO.getReg().asMCReg())) {
Expand Down Expand Up @@ -211,7 +238,9 @@ void ReachingDefAnalysis::processBasicBlock(

bool ReachingDefAnalysis::runOnMachineFunction(MachineFunction &mf) {
MF = &mf;
TRI = MF->getSubtarget().getRegisterInfo();
const TargetSubtargetInfo &STI = MF->getSubtarget();
TRI = STI.getRegisterInfo();
TII = STI.getInstrInfo();
LLVM_DEBUG(dbgs() << "********** REACHING DEFINITION ANALYSIS **********\n");
init();
traverse();
Expand All @@ -222,6 +251,7 @@ void ReachingDefAnalysis::releaseMemory() {
// Clear the internal vectors.
MBBOutRegsInfos.clear();
MBBReachingDefs.clear();
MBBFrameObjsReachingDefs.clear();
InstIds.clear();
LiveRegs.clear();
}
Expand All @@ -234,7 +264,10 @@ void ReachingDefAnalysis::reset() {

void ReachingDefAnalysis::init() {
NumRegUnits = TRI->getNumRegUnits();
NumStackObjects = MF->getFrameInfo().getNumObjects();
ObjectIndexBegin = MF->getFrameInfo().getObjectIndexBegin();
MBBReachingDefs.init(MF->getNumBlockIDs());
MBBFrameObjsReachingDefs.resize(MF->getNumBlockIDs());
// Initialize the MBBOutRegsInfos
MBBOutRegsInfos.resize(MF->getNumBlockIDs());
LoopTraversal Traversal;
Expand Down Expand Up @@ -269,6 +302,18 @@ int ReachingDefAnalysis::getReachingDef(MachineInstr *MI,
assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
"Unexpected basic block number.");
int LatestDef = ReachingDefDefaultVal;

if (Register::isStackSlot(PhysReg)) {
int FrameIndex = Register::stackSlot2Index(PhysReg);
for (int Def : MBBFrameObjsReachingDefs[MBBNumber][FrameIndex - ObjectIndexBegin]) {
if (Def >= InstId)
break;
DefRes = Def;
}
LatestDef = std::max(LatestDef, DefRes);
return LatestDef;
}

for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
for (int Def : MBBReachingDefs.defs(MBBNumber, Unit)) {
if (Def >= InstId)
Expand Down Expand Up @@ -425,7 +470,7 @@ void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB,
VisitedBBs.insert(MBB);
LiveRegUnits LiveRegs(*TRI);
LiveRegs.addLiveOuts(*MBB);
if (LiveRegs.available(PhysReg))
if (Register::isPhysicalRegister(PhysReg) && LiveRegs.available(PhysReg))
return;

if (auto *Def = getLocalLiveOutMIDef(MBB, PhysReg))
Expand Down Expand Up @@ -508,7 +553,7 @@ bool ReachingDefAnalysis::isReachingDefLiveOut(MachineInstr *MI,
MachineBasicBlock *MBB = MI->getParent();
LiveRegUnits LiveRegs(*TRI);
LiveRegs.addLiveOuts(*MBB);
if (LiveRegs.available(PhysReg))
if (Register::isPhysicalRegister(PhysReg) && LiveRegs.available(PhysReg))
return false;

auto Last = MBB->getLastNonDebugInstr();
Expand All @@ -529,14 +574,21 @@ ReachingDefAnalysis::getLocalLiveOutMIDef(MachineBasicBlock *MBB,
MCRegister PhysReg) const {
LiveRegUnits LiveRegs(*TRI);
LiveRegs.addLiveOuts(*MBB);
if (LiveRegs.available(PhysReg))
if (Register::isPhysicalRegister(PhysReg) && LiveRegs.available(PhysReg))
return nullptr;

auto Last = MBB->getLastNonDebugInstr();
if (Last == MBB->end())
return nullptr;

int Def = getReachingDef(&*Last, PhysReg);

if (Register::isStackSlot(PhysReg)) {
int FrameIndex = Register::stackSlot2Index(PhysReg);
if (isFIDef(*Last, FrameIndex, TII))
return &*Last;
}

for (auto &MO : Last->operands())
if (isValidRegDefOf(MO, PhysReg, TRI))
return &*Last;
Expand Down
15 changes: 13 additions & 2 deletions llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ static cl::opt<bool> EnableLocalReassignment(
"may be compile time intensive"),
cl::init(false));

// Hidden tuning option read by DefaultEvictionAdvisor::shouldEvict: when a
// heavier live range would evict one that currently satisfies a hint, the
// eviction is refused if the weight ratio (evictor / evictee) is below this
// threshold. Initial value is 1.0.
static cl::opt<float> MinWeightRatioNeededToEvictHint(
    "min-weight-ratio-needed-to-evict-hint", cl::Hidden,
    cl::desc("The minimum ratio of weight needed in order for a live range "
             "with bigger weight to evict another live range which satisfies "
             "a hint"),
    cl::init(1.0));

namespace llvm {
cl::opt<unsigned> EvictInterferenceCutoff(
"regalloc-eviction-max-interference-cutoff", cl::Hidden,
Expand Down Expand Up @@ -156,8 +161,14 @@ bool DefaultEvictionAdvisor::shouldEvict(const LiveInterval &A, bool IsHint,
if (CanSplit && IsHint && !BreaksHint)
return true;

if (A.weight() > B.weight()) {
LLVM_DEBUG(dbgs() << "should evict: " << B << " w= " << B.weight() << '\n');
float AWeight = A.weight();
float BWeight = B.weight();
if (AWeight > BWeight) {
float WeightRatio = BWeight == 0.0 ? std::numeric_limits<float>::infinity() : AWeight / BWeight;
if (CanSplit && !IsHint && BreaksHint && (WeightRatio < MinWeightRatioNeededToEvictHint)) {
return false;
}
LLVM_DEBUG(dbgs() << "should evict: " << B << " w= " << BWeight << '\n');
return true;
}
return false;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/TargetSubtargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ bool TargetSubtargetInfo::enableRALocalReassignment(
return true;
}

/// Base-class default for the doCSRSavesInRA() query: always answers false.
/// NOTE(review): subtargets presumably override this to opt in to having the
/// register allocator handle callee-saved register saves -- confirm.
bool TargetSubtargetInfo::doCSRSavesInRA() const { return false; }

bool TargetSubtargetInfo::enablePostRAScheduler() const {
return getSchedModel().PostRAScheduler;
}
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_public_tablegen_target(RISCVCommonTableGen)

add_llvm_target(RISCVCodeGen
RISCVAsmPrinter.cpp
RISCVCFIInserter.cpp
RISCVCallingConv.cpp
RISCVCodeGenPrepare.cpp
RISCVConstantPoolValue.cpp
Expand Down
Loading

0 comments on commit ae5d2b4

Please sign in to comment.