From ae5d2b4ba2a4d6c82f40a03cd344a228eeda3683 Mon Sep 17 00:00:00 2001 From: Mikhail Gudim Date: Fri, 1 Nov 2024 01:13:59 -0700 Subject: [PATCH] [RISCV][WIP] Let RA do the CSR saves. We turn the problem of saving and restoring callee-saved registers efficiently into a register allocation problem. This has the advantage that the register allocator can essentially do shrink-wrapping on a per-register basis. Currently, the shrink-wrapping pass saves all CSRs in the same place, which may be suboptimal. Also, improvements to register allocation / coalescing will translate to improvements in shrink-wrapping. In finalizeLowering() we copy all callee-saved registers from a physical register to a virtual one. In all return blocks we do the reverse. --- .../llvm/CodeGen/ReachingDefAnalysis.h | 5 + .../llvm/CodeGen/TargetFrameLowering.h | 6 + .../llvm/CodeGen/TargetSubtargetInfo.h | 2 + llvm/lib/CodeGen/MachineLICM.cpp | 44 +- llvm/lib/CodeGen/PrologEpilogInserter.cpp | 7 + llvm/lib/CodeGen/ReachingDefAnalysis.cpp | 60 +- llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp | 15 +- llvm/lib/CodeGen/TargetSubtargetInfo.cpp | 4 + llvm/lib/Target/RISCV/CMakeLists.txt | 1 + llvm/lib/Target/RISCV/RISCV.h | 3 + llvm/lib/Target/RISCV/RISCVCFIInserter.cpp | 569 ++++++++++++++++++ llvm/lib/Target/RISCV/RISCVFrameLowering.cpp | 244 +++++++- llvm/lib/Target/RISCV/RISCVFrameLowering.h | 6 + llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 119 ++++ llvm/lib/Target/RISCV/RISCVISelLowering.h | 2 + llvm/lib/Target/RISCV/RISCVInstrInfo.h | 8 + llvm/lib/Target/RISCV/RISCVInstrInfo.td | 4 + .../Target/RISCV/RISCVMachineFunctionInfo.cpp | 28 + .../Target/RISCV/RISCVMachineFunctionInfo.h | 17 + llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp | 8 + llvm/lib/Target/RISCV/RISCVSubtarget.cpp | 9 + llvm/lib/Target/RISCV/RISCVSubtarget.h | 2 + llvm/lib/Target/RISCV/RISCVTargetMachine.cpp | 2 + 23 files changed, 1138 insertions(+), 27 deletions(-) create mode 100644 llvm/lib/Target/RISCV/RISCVCFIInserter.cpp diff --git 
a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h index d6a1f064ec0a58..6e6684ae53e0c5 100644 --- a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h +++ b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h @@ -114,8 +114,11 @@ class ReachingDefAnalysis : public MachineFunctionPass { private: MachineFunction *MF = nullptr; const TargetRegisterInfo *TRI = nullptr; + const TargetInstrInfo *TII = nullptr; LoopTraversal::TraversalOrder TraversedMBBOrder; unsigned NumRegUnits = 0; + unsigned NumStackObjects = 0; + int ObjectIndexBegin = 0; /// Instruction that defined each register, relative to the beginning of the /// current basic block. When a LiveRegsDefInfo is used to represent a /// live-out register, this value is relative to the end of the basic block, @@ -138,6 +141,8 @@ class ReachingDefAnalysis : public MachineFunctionPass { DenseMap InstIds; MBBReachingDefsInfo MBBReachingDefs; + using MBBFrameObjsReachingDefsInfo = std::vector>>; + MBBFrameObjsReachingDefsInfo MBBFrameObjsReachingDefs; /// Default values are 'nothing happened a long time ago'. const int ReachingDefDefaultVal = -(1 << 21); diff --git a/llvm/include/llvm/CodeGen/TargetFrameLowering.h b/llvm/include/llvm/CodeGen/TargetFrameLowering.h index 97de0197da9b40..db7c9f3fce4398 100644 --- a/llvm/include/llvm/CodeGen/TargetFrameLowering.h +++ b/llvm/include/llvm/CodeGen/TargetFrameLowering.h @@ -24,6 +24,7 @@ namespace llvm { class CalleeSavedInfo; class MachineFunction; class RegScavenger; + class ReachingDefAnalysis; namespace TargetStackID { enum Value { @@ -210,6 +211,11 @@ class TargetFrameLowering { /// for noreturn nounwind functions. virtual bool enableCalleeSaveSkip(const MachineFunction &MF) const; + virtual void emitCFIsForCSRsHandledByRA(MachineFunction &MF, + ReachingDefAnalysis *RDA) const { + return; + } + /// emitProlog/emitEpilog - These methods insert prolog and epilog code into /// the function. 
virtual void emitPrologue(MachineFunction &MF, diff --git a/llvm/include/llvm/CodeGen/TargetSubtargetInfo.h b/llvm/include/llvm/CodeGen/TargetSubtargetInfo.h index bfaa6450779ae0..ea8237cdbac7d0 100644 --- a/llvm/include/llvm/CodeGen/TargetSubtargetInfo.h +++ b/llvm/include/llvm/CodeGen/TargetSubtargetInfo.h @@ -317,6 +317,8 @@ class TargetSubtargetInfo : public MCSubtargetInfo { return false; } + virtual bool doCSRSavesInRA() const; + /// Classify a global function reference. This mainly used to fetch target /// special flags for lowering a function address. For example mark a function /// call should be plt or pc-related addressing. diff --git a/llvm/lib/CodeGen/MachineLICM.cpp b/llvm/lib/CodeGen/MachineLICM.cpp index 7ea07862b839d0..01fdc102961895 100644 --- a/llvm/lib/CodeGen/MachineLICM.cpp +++ b/llvm/lib/CodeGen/MachineLICM.cpp @@ -262,15 +262,19 @@ namespace { void HoistOutOfLoop(MachineDomTreeNode *HeaderN, MachineLoop *CurLoop, MachineBasicBlock *CurPreheader); - void InitRegPressure(MachineBasicBlock *BB); + void InitRegPressure(MachineBasicBlock *BB, const MachineLoop* Loop); SmallDenseMap calcRegisterCost(const MachineInstr *MI, bool ConsiderSeen, - bool ConsiderUnseenAsDef); + bool ConsiderUnseenAsDef, + bool IgnoreDefs = false); + bool allDefsAreOnlyUsedOutsideOfTheLoop(const MachineInstr &MI, const MachineLoop *Loop); void UpdateRegPressure(const MachineInstr *MI, - bool ConsiderUnseenAsDef = false); + bool ConsiderUnseenAsDef = false, bool IgnoreDefs = false); + void UpdateRegPressureForUsesOnly(const MachineInstr *MI, + bool ConsiderUnseenAsDef = false); MachineInstr *ExtractHoistableLoad(MachineInstr *MI, MachineLoop *CurLoop); MachineInstr *LookForDuplicate(const MachineInstr *MI, @@ -884,7 +888,7 @@ void MachineLICMImpl::HoistOutOfLoop(MachineDomTreeNode *HeaderN, // Compute registers which are livein into the loop headers. 
RegSeen.clear(); BackTrace.clear(); - InitRegPressure(Preheader); + InitRegPressure(Preheader, CurLoop); // Now perform LICM. for (MachineDomTreeNode *Node : Scopes) { @@ -934,7 +938,7 @@ static bool isOperandKill(const MachineOperand &MO, MachineRegisterInfo *MRI) { /// Find all virtual register references that are liveout of the preheader to /// initialize the starting "register pressure". Note this does not count live /// through (livein but not used) registers. -void MachineLICMImpl::InitRegPressure(MachineBasicBlock *BB) { +void MachineLICMImpl::InitRegPressure(MachineBasicBlock *BB, const MachineLoop *Loop) { std::fill(RegPressure.begin(), RegPressure.end(), 0); // If the preheader has only a single predecessor and it ends with a @@ -945,17 +949,32 @@ void MachineLICMImpl::InitRegPressure(MachineBasicBlock *BB) { MachineBasicBlock *TBB = nullptr, *FBB = nullptr; SmallVector Cond; if (!TII->analyzeBranch(*BB, TBB, FBB, Cond, false) && Cond.empty()) - InitRegPressure(*BB->pred_begin()); + InitRegPressure(*BB->pred_begin(), Loop); } - for (const MachineInstr &MI : *BB) - UpdateRegPressure(&MI, /*ConsiderUnseenAsDef=*/true); + for (const MachineInstr &MI : *BB) { + bool IgnoreDefs = allDefsAreOnlyUsedOutsideOfTheLoop(MI, Loop); + UpdateRegPressure(&MI, /*ConsiderUnseenAsDef=*/true, IgnoreDefs); + } +} + +bool MachineLICMImpl::allDefsAreOnlyUsedOutsideOfTheLoop(const MachineInstr &MI, const MachineLoop *Loop) { + for (const MachineOperand DefMO : MI.all_defs()) { + if (!DefMO.isReg()) + continue; + for(const MachineInstr &UseMI : MRI->use_instructions(DefMO.getReg())) { + if (Loop->contains(UseMI.getParent())) + return false; + } + } + return true; } /// Update estimate of register pressure after the specified instruction. 
void MachineLICMImpl::UpdateRegPressure(const MachineInstr *MI, - bool ConsiderUnseenAsDef) { - auto Cost = calcRegisterCost(MI, /*ConsiderSeen=*/true, ConsiderUnseenAsDef); + bool ConsiderUnseenAsDef, + bool IgnoreDefs) { + auto Cost = calcRegisterCost(MI, /*ConsiderSeen=*/true, ConsiderUnseenAsDef, IgnoreDefs); for (const auto &RPIdAndCost : Cost) { unsigned Class = RPIdAndCost.first; if (static_cast(RegPressure[Class]) < -RPIdAndCost.second) @@ -973,7 +992,8 @@ void MachineLICMImpl::UpdateRegPressure(const MachineInstr *MI, /// FIXME: Figure out a way to consider 'RegSeen' from all code paths. SmallDenseMap MachineLICMImpl::calcRegisterCost(const MachineInstr *MI, bool ConsiderSeen, - bool ConsiderUnseenAsDef) { + bool ConsiderUnseenAsDef, + bool IgnoreDefs) { SmallDenseMap Cost; if (MI->isImplicitDef()) return Cost; @@ -991,7 +1011,7 @@ MachineLICMImpl::calcRegisterCost(const MachineInstr *MI, bool ConsiderSeen, RegClassWeight W = TRI->getRegClassWeight(RC); int RCCost = 0; - if (MO.isDef()) + if (MO.isDef() && !IgnoreDefs) RCCost = W.RegWeight; else { bool isKill = isOperandKill(MO, MRI); diff --git a/llvm/lib/CodeGen/PrologEpilogInserter.cpp b/llvm/lib/CodeGen/PrologEpilogInserter.cpp index ee03eaa8ae527c..78b10d665a0157 100644 --- a/llvm/lib/CodeGen/PrologEpilogInserter.cpp +++ b/llvm/lib/CodeGen/PrologEpilogInserter.cpp @@ -36,6 +36,7 @@ #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/ReachingDefAnalysis.h" #include "llvm/CodeGen/RegisterScavenging.h" #include "llvm/CodeGen/TargetFrameLowering.h" #include "llvm/CodeGen/TargetInstrInfo.h" @@ -95,6 +96,7 @@ class PEI : public MachineFunctionPass { bool runOnMachineFunction(MachineFunction &MF) override; private: + ReachingDefAnalysis *RDA = nullptr; RegScavenger *RS = nullptr; // MinCSFrameIndex, MaxCSFrameIndex - Keeps the range of callee saved @@ -153,6 +155,7 @@ 
INITIALIZE_PASS_BEGIN(PEI, DEBUG_TYPE, "Prologue/Epilogue Insertion", false, INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass) +INITIALIZE_PASS_DEPENDENCY(ReachingDefAnalysis) INITIALIZE_PASS_END(PEI, DEBUG_TYPE, "Prologue/Epilogue Insertion & Frame Finalization", false, false) @@ -169,6 +172,7 @@ void PEI::getAnalysisUsage(AnalysisUsage &AU) const { AU.addPreserved(); AU.addPreserved(); AU.addRequired(); + AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } @@ -227,6 +231,7 @@ bool PEI::runOnMachineFunction(MachineFunction &MF) { RS = TRI->requiresRegisterScavenging(MF) ? new RegScavenger() : nullptr; FrameIndexVirtualScavenging = TRI->requiresFrameIndexScavenging(MF); ORE = &getAnalysis().getORE(); + RDA = &getAnalysis(); // Spill frame pointer and/or base pointer registers if they are clobbered. // It is placed before call frame instruction elimination so it will not mess @@ -262,6 +267,7 @@ bool PEI::runOnMachineFunction(MachineFunction &MF) { // called functions. Because of this, calculateCalleeSavedRegisters() // must be called before this function in order to set the AdjustsStack // and MaxCallFrameSize variables. + RDA->reset(); if (!F.hasFnAttribute(Attribute::Naked)) insertPrologEpilogCode(MF); @@ -1164,6 +1170,7 @@ void PEI::calculateFrameObjectOffsets(MachineFunction &MF) { void PEI::insertPrologEpilogCode(MachineFunction &MF) { const TargetFrameLowering &TFI = *MF.getSubtarget().getFrameLowering(); + TFI.emitCFIsForCSRsHandledByRA(MF, RDA); // Add prologue to the function... 
for (MachineBasicBlock *SaveBlock : SaveBlocks) TFI.emitPrologue(MF, *SaveBlock); diff --git a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp index 0e8220ec6251cb..2120d15465ff9a 100644 --- a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp +++ b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp @@ -10,6 +10,8 @@ #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallSet.h" #include "llvm/CodeGen/LiveRegUnits.h" +#include "llvm/CodeGen/MachineFrameInfo.h" +#include "llvm/CodeGen/TargetInstrInfo.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/Support/Debug.h" @@ -48,12 +50,31 @@ static bool isValidRegDefOf(const MachineOperand &MO, MCRegister PhysReg, return TRI->regsOverlap(MO.getReg(), PhysReg); } +static bool isFIDef(const MachineInstr &MI, int FrameIndex, const TargetInstrInfo *TII) { + int DefFrameIndex = 0; + int SrcFrameIndex = 0; + if ( + TII->isStoreToStackSlot(MI, DefFrameIndex) || + TII->isStackSlotCopy(MI, DefFrameIndex, SrcFrameIndex) + ) { + return DefFrameIndex == FrameIndex; + } + return false; +} + + void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) { unsigned MBBNumber = MBB->getNumber(); assert(MBBNumber < MBBReachingDefs.numBlockIDs() && "Unexpected basic block number."); MBBReachingDefs.startBasicBlock(MBBNumber, NumRegUnits); + MBBFrameObjsReachingDefs[MBBNumber].resize(NumStackObjects); + for (unsigned FOIdx = 0; FOIdx < NumStackObjects; ++FOIdx) { + MBBFrameObjsReachingDefs[MBBNumber][FOIdx].push_back(-1); + } + + // Reset instruction counter in each basic block. 
CurInstr = 0; @@ -126,6 +147,12 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) { "Unexpected basic block number."); for (auto &MO : MI->operands()) { + if (MO.isFI()) { + int FrameIndex = MO.getIndex(); + if (!isFIDef(*MI, FrameIndex, TII)) + continue; + MBBFrameObjsReachingDefs[MBBNumber][FrameIndex - ObjectIndexBegin].push_back(CurInstr); + } if (!isValidRegDef(MO)) continue; for (MCRegUnit Unit : TRI->regunits(MO.getReg().asMCReg())) { @@ -211,7 +238,9 @@ void ReachingDefAnalysis::processBasicBlock( bool ReachingDefAnalysis::runOnMachineFunction(MachineFunction &mf) { MF = &mf; - TRI = MF->getSubtarget().getRegisterInfo(); + const TargetSubtargetInfo &STI = MF->getSubtarget(); + TRI = STI.getRegisterInfo(); + TII = STI.getInstrInfo(); LLVM_DEBUG(dbgs() << "********** REACHING DEFINITION ANALYSIS **********\n"); init(); traverse(); @@ -222,6 +251,7 @@ void ReachingDefAnalysis::releaseMemory() { // Clear the internal vectors. MBBOutRegsInfos.clear(); MBBReachingDefs.clear(); + MBBFrameObjsReachingDefs.clear(); InstIds.clear(); LiveRegs.clear(); } @@ -234,7 +264,10 @@ void ReachingDefAnalysis::reset() { void ReachingDefAnalysis::init() { NumRegUnits = TRI->getNumRegUnits(); + NumStackObjects = MF->getFrameInfo().getNumObjects(); + ObjectIndexBegin = MF->getFrameInfo().getObjectIndexBegin(); MBBReachingDefs.init(MF->getNumBlockIDs()); + MBBFrameObjsReachingDefs.resize(MF->getNumBlockIDs()); // Initialize the MBBOutRegsInfos MBBOutRegsInfos.resize(MF->getNumBlockIDs()); LoopTraversal Traversal; @@ -269,6 +302,18 @@ int ReachingDefAnalysis::getReachingDef(MachineInstr *MI, assert(MBBNumber < MBBReachingDefs.numBlockIDs() && "Unexpected basic block number."); int LatestDef = ReachingDefDefaultVal; + + if (Register::isStackSlot(PhysReg)) { + int FrameIndex = Register::stackSlot2Index(PhysReg); + for (int Def : MBBFrameObjsReachingDefs[MBBNumber][FrameIndex - ObjectIndexBegin]) { + if (Def >= InstId) + break; + DefRes = Def; + } + LatestDef = 
std::max(LatestDef, DefRes); + return LatestDef; + } + for (MCRegUnit Unit : TRI->regunits(PhysReg)) { for (int Def : MBBReachingDefs.defs(MBBNumber, Unit)) { if (Def >= InstId) @@ -425,7 +470,7 @@ void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, VisitedBBs.insert(MBB); LiveRegUnits LiveRegs(*TRI); LiveRegs.addLiveOuts(*MBB); - if (LiveRegs.available(PhysReg)) + if (Register::isPhysicalRegister(PhysReg) && LiveRegs.available(PhysReg)) return; if (auto *Def = getLocalLiveOutMIDef(MBB, PhysReg)) @@ -508,7 +553,7 @@ bool ReachingDefAnalysis::isReachingDefLiveOut(MachineInstr *MI, MachineBasicBlock *MBB = MI->getParent(); LiveRegUnits LiveRegs(*TRI); LiveRegs.addLiveOuts(*MBB); - if (LiveRegs.available(PhysReg)) + if (Register::isPhysicalRegister(PhysReg) && LiveRegs.available(PhysReg)) return false; auto Last = MBB->getLastNonDebugInstr(); @@ -529,7 +574,7 @@ ReachingDefAnalysis::getLocalLiveOutMIDef(MachineBasicBlock *MBB, MCRegister PhysReg) const { LiveRegUnits LiveRegs(*TRI); LiveRegs.addLiveOuts(*MBB); - if (LiveRegs.available(PhysReg)) + if (Register::isPhysicalRegister(PhysReg) && LiveRegs.available(PhysReg)) return nullptr; auto Last = MBB->getLastNonDebugInstr(); @@ -537,6 +582,13 @@ ReachingDefAnalysis::getLocalLiveOutMIDef(MachineBasicBlock *MBB, return nullptr; int Def = getReachingDef(&*Last, PhysReg); + + if (Register::isStackSlot(PhysReg)) { + int FrameIndex = Register::stackSlot2Index(PhysReg); + if (isFIDef(*Last, FrameIndex, TII)) + return &*Last; + } + for (auto &MO : Last->operands()) if (isValidRegDefOf(MO, PhysReg, TRI)) return &*Last; diff --git a/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp b/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp index a1dccc4d59723b..023c03c5a2a922 100644 --- a/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp +++ b/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp @@ -44,6 +44,11 @@ static cl::opt EnableLocalReassignment( "may be compile time intensive"), cl::init(false)); +static cl::opt 
MinWeightRatioNeededToEvictHint( + "min-weight-ratio-needed-to-evict-hint", cl::Hidden, + cl::desc("The minimum ration of weight needed in order for a live range with bigger weight to evict another live range which satisfies a hint"), + cl::init(1.0)); + namespace llvm { cl::opt EvictInterferenceCutoff( "regalloc-eviction-max-interference-cutoff", cl::Hidden, @@ -156,8 +161,14 @@ bool DefaultEvictionAdvisor::shouldEvict(const LiveInterval &A, bool IsHint, if (CanSplit && IsHint && !BreaksHint) return true; - if (A.weight() > B.weight()) { - LLVM_DEBUG(dbgs() << "should evict: " << B << " w= " << B.weight() << '\n'); + float AWeight = A.weight(); + float BWeight = B.weight(); + if (AWeight > BWeight) { + float WeightRatio = BWeight == 0.0 ? std::numeric_limits::infinity() : AWeight / BWeight; + if (CanSplit && !IsHint && BreaksHint && (WeightRatio < MinWeightRatioNeededToEvictHint)) { + return false; + } + LLVM_DEBUG(dbgs() << "should evict: " << B << " w= " << BWeight << '\n'); return true; } return false; diff --git a/llvm/lib/CodeGen/TargetSubtargetInfo.cpp b/llvm/lib/CodeGen/TargetSubtargetInfo.cpp index 6c97bc0568bdee..566d5420c638ef 100644 --- a/llvm/lib/CodeGen/TargetSubtargetInfo.cpp +++ b/llvm/lib/CodeGen/TargetSubtargetInfo.cpp @@ -45,6 +45,10 @@ bool TargetSubtargetInfo::enableRALocalReassignment( return true; } +bool TargetSubtargetInfo::doCSRSavesInRA() const { + return false; +} + bool TargetSubtargetInfo::enablePostRAScheduler() const { return getSchedModel().PostRAScheduler; } diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt index fd049d1a57860e..e8897ed14dcea1 100644 --- a/llvm/lib/Target/RISCV/CMakeLists.txt +++ b/llvm/lib/Target/RISCV/CMakeLists.txt @@ -29,6 +29,7 @@ add_public_tablegen_target(RISCVCommonTableGen) add_llvm_target(RISCVCodeGen RISCVAsmPrinter.cpp + RISCVCFIInserter.cpp RISCVCallingConv.cpp RISCVCodeGenPrepare.cpp RISCVConstantPoolValue.cpp diff --git a/llvm/lib/Target/RISCV/RISCV.h 
b/llvm/lib/Target/RISCV/RISCV.h index d7bab601d545cc..65d6f7725726f8 100644 --- a/llvm/lib/Target/RISCV/RISCV.h +++ b/llvm/lib/Target/RISCV/RISCV.h @@ -105,6 +105,9 @@ void initializeRISCVPreLegalizerCombinerPass(PassRegistry &); FunctionPass *createRISCVVLOptimizerPass(); void initializeRISCVVLOptimizerPass(PassRegistry &); + +FunctionPass *createRISCVCFIInstrInserter(); +void initializeRISCVCFIInstrInserterPass(PassRegistry &); } // namespace llvm #endif diff --git a/llvm/lib/Target/RISCV/RISCVCFIInserter.cpp b/llvm/lib/Target/RISCV/RISCVCFIInserter.cpp new file mode 100644 index 00000000000000..9f6109524b314e --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVCFIInserter.cpp @@ -0,0 +1,569 @@ +//===------ RISCVCFIInstrInserter.cpp - Insert additional CFI instructions -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file This pass verifies incoming and outgoing CFA information of basic +/// blocks. CFA information is information about offset and register set by CFI +/// directives, valid at the start and end of a basic block. This pass checks +/// that outgoing information of predecessors matches incoming information of +/// their successors. Then it checks if blocks have correct CFA calculation rule +/// set and inserts additional CFI instruction at their beginnings if they +/// don't. CFI instructions are inserted if basic blocks have incorrect offset +/// or register set by previous blocks, as a result of a non-linear layout of +/// blocks in a function. 
+//===----------------------------------------------------------------------===// + +#include "RISCV.h" +#include "RISCVMachineFunctionInfo.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/CodeGen/TargetFrameLowering.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/InitializePasses.h" +#include "llvm/MC/MCDwarf.h" +#include "llvm/Support/LEB128.h" + +using namespace llvm; + +#define DEBUG_TYPE "riscv-cfi-inserter" + +//static cl::opt VerifyCFI("verify-cfiinstrs", +// cl::desc("Verify Call Frame Information instructions"), +// cl::init(false), +// cl::Hidden); + +namespace { +class RISCVCFIInstrInserter : public MachineFunctionPass { + public: + static char ID; + + RISCVCFIInstrInserter() : MachineFunctionPass(ID) { + initializeRISCVCFIInstrInserterPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + MachineFunctionPass::getAnalysisUsage(AU); + } + + bool runOnMachineFunction(MachineFunction &MF) override { + if (!MF.needsFrameMoves()) + return false; + + if (!MF.getSubtarget().doCSRSavesInRA()) + return false; + + RVFI = MF.getInfo(); + MBBVector.resize(MF.getNumBlockIDs()); + calculateCFAInfo(MF); + + //if (VerifyCFI) { + // if (unsigned ErrorNum = verify(MF)) + // report_fatal_error("Found " + Twine(ErrorNum) + + // " in/out CFI information errors."); + //} + bool insertedCFI = insertCFIInstrs(MF); + MBBVector.clear(); + return insertedCFI; + } + + private: +#define INVALID_REG UINT_MAX +#define INVALID_OFFSET INT_MAX + /// contains the location where CSR register is saved. + /// Registers are recorded by their Dwarf numbers. 
+ struct CSRLocation { + bool IsReg = true; + int Reg = 0; + int FrameReg = 0; + int Offset = 0; + bool isEqual(const CSRLocation &Other) const { + if (IsReg) + return Other.IsReg ? (Reg == Other.Reg) : false; + return !Other.IsReg ? ((Offset == Other.Offset) && FrameReg == Other.FrameReg) : false; + } + }; + + struct MBBCFAInfo { + MachineBasicBlock *MBB; + /// Value of cfa offset valid at basic block entry. + int IncomingCFAOffset = -1; + /// Value of cfa offset valid at basic block exit. + int OutgoingCFAOffset = -1; + /// Value of cfa register valid at basic block entry. + int IncomingCFARegister = 0; + /// Value of cfa register valid at basic block exit. + int OutgoingCFARegister = 0; + /// Set of callee saved registers saved at basic block entry. + SmallVector IncomingCSRLocations; + /// Set of callee saved registers saved at basic block exit. + SmallVector OutgoingCSRLocations; + /// If in/out cfa offset and register values for this block have already + /// been set or not. + bool Processed = false; + }; + + RISCVMachineFunctionInfo *RVFI; + /// Contains cfa offset and register values valid at entry and exit of basic + /// blocks. + std::vector MBBVector; + + /// Calculate cfa offset and register values valid at entry and exit for all + /// basic blocks in a function. + void calculateCFAInfo(MachineFunction &MF); + /// Calculate cfa offset and register values valid at basic block exit by + /// checking the block for CFI instructions. Block's incoming CFA info remains + /// the same. + void calculateOutgoingCFAInfo(MBBCFAInfo &MBBInfo); + /// Update in/out cfa offset and register values for successors of the basic + /// block. + void updateSuccCFAInfo(MBBCFAInfo &MBBInfo); + + /// Check if incoming CFA information of a basic block matches outgoing CFA + /// information of the previous block. If it doesn't, insert CFI instruction + /// at the beginning of the block that corrects the CFA calculation rule for + /// that block. 
+ bool insertCFIInstrs(MachineFunction &MF); + /// Return the cfa offset value that should be set at the beginning of a MBB + /// if needed. The negated value is needed when creating CFI instructions that + /// set absolute offset. + int getCorrectCFAOffset(MachineBasicBlock *MBB) { + return MBBVector[MBB->getNumber()].IncomingCFAOffset; + } + + void reportCFAError(const MBBCFAInfo &Pred, const MBBCFAInfo &Succ); + void reportCSRError(const MBBCFAInfo &Pred, const MBBCFAInfo &Succ); + /// Go through each MBB in a function and check that outgoing offset and + /// register of its predecessors match incoming offset and register of that + /// MBB, as well as that incoming offset and register of its successors match + /// outgoing offset and register of the MBB. + unsigned verify(MachineFunction &MF); +}; +} // namespace + +char RISCVCFIInstrInserter::ID = 0; +INITIALIZE_PASS(RISCVCFIInstrInserter, "cfi-instr-inserter", + "Check CFA info and insert CFI instructions if needed", false, + false) +FunctionPass *llvm::createRISCVCFIInstrInserter() { return new RISCVCFIInstrInserter(); } + +void RISCVCFIInstrInserter::calculateCFAInfo(MachineFunction &MF) { + const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo(); + // Initial CFA offset value i.e. the one valid at the beginning of the + // function. + int InitialOffset = + MF.getSubtarget().getFrameLowering()->getInitialCFAOffset(MF); + // Initial CFA register value i.e. the one valid at the beginning of the + // function. + int InitialRegister = + TRI.getDwarfRegNum(MF.getSubtarget().getFrameLowering()->getInitialCFARegister(MF), true); + unsigned NumRegs = TRI.getNumSupportedRegs(MF); + + // Initialize MBBMap. 
+ for (MachineBasicBlock &MBB : MF) { + MBBCFAInfo &MBBInfo = MBBVector[MBB.getNumber()]; + MBBInfo.MBB = &MBB; + MBBInfo.IncomingCFAOffset = InitialOffset; + MBBInfo.OutgoingCFAOffset = InitialOffset; + MBBInfo.IncomingCFARegister = InitialRegister; + MBBInfo.OutgoingCFARegister = InitialRegister; + MBBInfo.IncomingCSRLocations.resize(NumRegs); + MBBInfo.OutgoingCSRLocations.resize(NumRegs); + } + + MBBCFAInfo &EntryMBBInfo = MBBVector[MF.front().getNumber()]; + const MCPhysReg * CSRegs = MF.getRegInfo().getCalleeSavedRegs(); + for (int i = 0; CSRegs[i]; ++i) { + unsigned Reg = TRI.getDwarfRegNum(CSRegs[i], true); + CSRLocation &CSRLoc = EntryMBBInfo.IncomingCSRLocations[Reg]; + CSRLoc.IsReg = true; + CSRLoc.Reg = Reg; + } + // Set in/out cfa info for all blocks in the function. This traversal is based + // on the assumption that the first block in the function is the entry block + // i.e. that it has initial cfa offset and register values as incoming CFA + // information. + updateSuccCFAInfo(MBBVector[MF.front().getNumber()]); + + LLVM_DEBUG( + dbgs() << "Calculated CFI info for " << MF.getName() << "\n"; + for (MachineBasicBlock &MBB : MF) { + dbgs() << "BasicBlock: " <getParent(); + const std::vector &Instrs = MF->getFrameInstructions(); + + int &OutgoingCFAOffset = MBBInfo.OutgoingCFAOffset; + int &OutgoingCFARegister = MBBInfo.OutgoingCFARegister; + SmallVector &OutgoingCSRLocations = MBBInfo.OutgoingCSRLocations; + + OutgoingCSRLocations = MBBInfo.IncomingCSRLocations; + // Determine cfa offset and register set by the block. 
+ for (MachineInstr &MI : *MBBInfo.MBB) { + if (MI.isCFIInstruction()) { + unsigned CFIIndex = MI.getOperand(0).getCFIIndex(); + const MCCFIInstruction &CFI = Instrs[CFIIndex]; + switch (CFI.getOperation()) { + case MCCFIInstruction::OpDefCfaRegister: { + int Reg = CFI.getRegister(); + assert(Reg >= 0 && "Negative dwarf register number!"); + OutgoingCFARegister = Reg; + break; + } + case MCCFIInstruction::OpDefCfaOffset: { + OutgoingCFAOffset = CFI.getOffset(); + break; + } + case MCCFIInstruction::OpAdjustCfaOffset: { + OutgoingCFAOffset += CFI.getOffset(); + break; + } + case MCCFIInstruction::OpDefCfa: { + int Reg = CFI.getRegister(); + assert(Reg >= 0 && "Negative dwarf register number!"); + OutgoingCFARegister = Reg; + OutgoingCFAOffset = CFI.getOffset(); + break; + } + case MCCFIInstruction::OpOffset: { + int Reg = CFI.getRegister(); + assert(Reg >= 0 && "Negative dwarf register number!"); + OutgoingCSRLocations[Reg].Offset = CFI.getOffset(); + OutgoingCSRLocations[Reg].FrameReg = CFI.getOffset(); + OutgoingCSRLocations[Reg].IsReg = false; + break; + } + case MCCFIInstruction::OpEscape: { + int Reg; + int FrameReg; + int64_t Offset; + bool isRegPlusOffset = RVFI->getCFIInfo(&MI, Reg, FrameReg, Offset); + if (!isRegPlusOffset) { + break; + } + assert(Reg >= 0 && "Negative dwarf register number!"); + assert(FrameReg >= 0 && "Negative dwarf register number!"); + OutgoingCSRLocations[Reg].IsReg = false; + OutgoingCSRLocations[Reg].Offset = Offset; + OutgoingCSRLocations[Reg].FrameReg = FrameReg; + break; + } + case MCCFIInstruction::OpRegister: { + int Reg = CFI.getRegister(); + assert(Reg >= 0 && "Negative dwarf register number!"); + int Reg2 = CFI.getRegister(); + assert(Reg2 >= 0 && "Negative dwarf register number!"); + OutgoingCSRLocations[Reg].Reg = Reg2; + OutgoingCSRLocations[Reg].IsReg = true; + break; + } + case MCCFIInstruction::OpRelOffset: + report_fatal_error( + "Support for .cfi_rel_offset not implemented! 
Value of CFA " + "may be incorrect!\n"); + break; + case MCCFIInstruction::OpRestore: + report_fatal_error( + "Support for .cfi_restore not implemented! Value of CFA " + "may be incorrect!\n"); + break; + case MCCFIInstruction::OpLLVMDefAspaceCfa: + // TODO: Add support for handling cfi_def_aspace_cfa. +#ifndef NDEBUG + report_fatal_error( + "Support for cfi_llvm_def_aspace_cfa not implemented! Value of CFA " + "may be incorrect!\n"); +#endif + break; + case MCCFIInstruction::OpRememberState: + // TODO: Add support for handling cfi_remember_state. +#ifndef NDEBUG + report_fatal_error( + "Support for cfi_remember_state not implemented! Value of CFA " + "may be incorrect!\n"); +#endif + break; + case MCCFIInstruction::OpRestoreState: + // TODO: Add support for handling cfi_restore_state. +#ifndef NDEBUG + report_fatal_error( + "Support for cfi_restore_state not implemented! Value of CFA may " + "be incorrect!\n"); +#endif + break; + case MCCFIInstruction::OpUndefined: + case MCCFIInstruction::OpSameValue: + case MCCFIInstruction::OpWindowSave: + case MCCFIInstruction::OpNegateRAState: + case MCCFIInstruction::OpGnuArgsSize: + break; + } + } + } + + MBBInfo.Processed = true; +} + +void RISCVCFIInstrInserter::updateSuccCFAInfo(MBBCFAInfo &MBBInfo) { + SmallVector Stack; + Stack.push_back(MBBInfo.MBB); + + do { + MachineBasicBlock *Current = Stack.pop_back_val(); + MBBCFAInfo &CurrentInfo = MBBVector[Current->getNumber()]; + calculateOutgoingCFAInfo(CurrentInfo); + for (auto *Succ : CurrentInfo.MBB->successors()) { + MBBCFAInfo &SuccInfo = MBBVector[Succ->getNumber()]; + if (!SuccInfo.Processed) { + SuccInfo.IncomingCFAOffset = CurrentInfo.OutgoingCFAOffset; + SuccInfo.IncomingCFARegister = CurrentInfo.OutgoingCFARegister; + SuccInfo.IncomingCSRLocations = CurrentInfo.OutgoingCSRLocations; + Stack.push_back(Succ); + } + } + } while (!Stack.empty()); +} + +bool RISCVCFIInstrInserter::insertCFIInstrs(MachineFunction &MF) { + const MBBCFAInfo *PrevMBBInfo = 
&MBBVector[MF.front().getNumber()]; + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + bool InsertedCFIInstr = false; + + BitVector SetDifference; + for (MachineBasicBlock &MBB : MF) { + // Skip the first MBB in a function + if (MBB.getNumber() == MF.front().getNumber()) continue; + + const MBBCFAInfo &MBBInfo = MBBVector[MBB.getNumber()]; + auto MBBI = MBBInfo.MBB->begin(); + DebugLoc DL = MBBInfo.MBB->findDebugLoc(MBBI); + + // If the current MBB will be placed in a unique section, a full DefCfa + // must be emitted. + const bool ForceFullCFA = MBB.isBeginSection(); + + if ((PrevMBBInfo->OutgoingCFAOffset != MBBInfo.IncomingCFAOffset && + PrevMBBInfo->OutgoingCFARegister != MBBInfo.IncomingCFARegister) || + ForceFullCFA) { + // If both outgoing offset and register of a previous block don't match + // incoming offset and register of this block, or if this block begins a + // section, add a def_cfa instruction with the correct offset and + // register for this block. + unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::cfiDefCfa( + nullptr, MBBInfo.IncomingCFARegister, getCorrectCFAOffset(&MBB))); + BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION)) + .addCFIIndex(CFIIndex); + InsertedCFIInstr = true; + } else if (PrevMBBInfo->OutgoingCFAOffset != MBBInfo.IncomingCFAOffset) { + // If outgoing offset of a previous block doesn't match incoming offset + // of this block, add a def_cfa_offset instruction with the correct + // offset for this block. 
+ unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::cfiDefCfaOffset( + nullptr, getCorrectCFAOffset(&MBB))); + BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION)) + .addCFIIndex(CFIIndex); + InsertedCFIInstr = true; + } else if (PrevMBBInfo->OutgoingCFARegister != + MBBInfo.IncomingCFARegister) { + unsigned CFIIndex = + MF.addFrameInst(MCCFIInstruction::createDefCfaRegister( + nullptr, MBBInfo.IncomingCFARegister)); + BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION)) + .addCFIIndex(CFIIndex); + InsertedCFIInstr = true; + } + + if (ForceFullCFA) { + MF.getSubtarget().getFrameLowering()->emitCalleeSavedFrameMovesFullCFA( + *MBBInfo.MBB, MBBI); + InsertedCFIInstr = true; + PrevMBBInfo = &MBBInfo; + continue; + } + + for (unsigned i = 0; i < PrevMBBInfo->OutgoingCSRLocations.size(); ++i) { + const CSRLocation &OutgoingCSRLoc = PrevMBBInfo->OutgoingCSRLocations[i]; + const CSRLocation &IncomingCSRLoc = MBBInfo.IncomingCSRLocations[i]; + if (IncomingCSRLoc.IsReg && (IncomingCSRLoc.Reg == 0)) + continue; + if (MBBInfo.IncomingCSRLocations[i].isEqual(OutgoingCSRLoc)) + continue; + unsigned CFIIndex; + if (IncomingCSRLoc.IsReg) { + CFIIndex = MF.addFrameInst( + MCCFIInstruction::createRegister(nullptr, i, IncomingCSRLoc.Reg) + ); + } + else { + //CFIIndex = MF.addFrameInst( + // MCCFIInstruction::createOffset(nullptr, i, IncomingCSRLoc.Offset) + //); + std::string CommentBuffer; + llvm::raw_string_ostream Comment(CommentBuffer); + int DwarfRegSP = IncomingCSRLoc.FrameReg; + int DwarfEHRegNum = i; + int64_t FixedOffset = IncomingCSRLoc.Offset; + // Build up the expression (SP + FixedOffset) + SmallString<64> Expr; + uint8_t Buffer[16]; + + Comment << FixedOffset; + //0x11 + Expr.push_back(dwarf::DW_OP_consts); + Expr.append(Buffer, Buffer + encodeSLEB128(FixedOffset, Buffer)); + + //0x92 + Expr.push_back((uint8_t)dwarf::DW_OP_bregx); + //0x02 + Expr.append(Buffer, Buffer + encodeULEB128(DwarfRegSP, Buffer)); + 
Expr.push_back(0); + + //0x22 + Expr.push_back((uint8_t)dwarf::DW_OP_plus); + // Wrap this into DW_CFA_def_cfa. + SmallString<64> DefCfaExpr; + // 0x10 + DefCfaExpr.push_back(dwarf::DW_CFA_expression); + DefCfaExpr.append(Buffer, Buffer + encodeULEB128(DwarfEHRegNum, Buffer)); + DefCfaExpr.append(Buffer, Buffer + encodeULEB128(Expr.size(), Buffer)); + DefCfaExpr.append(Expr.str()); + CFIIndex = MF.addFrameInst( + MCCFIInstruction::createEscape( + nullptr, + DefCfaExpr.str(), + SMLoc(), + Comment.str() + ) + ); + } + BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION)) + .addCFIIndex(CFIIndex); + InsertedCFIInstr = true; + } + //BitVector::apply([](auto x, auto y) { return x & ~y; }, SetDifference, + // PrevMBBInfo->OutgoingCSRSaved, MBBInfo.IncomingCSRSaved); + //for (int Reg : SetDifference.set_bits()) { + // unsigned CFIIndex = + // MF.addFrameInst(MCCFIInstruction::createRestore(nullptr, Reg)); + // BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION)) + // .addCFIIndex(CFIIndex); + // InsertedCFIInstr = true; + //} + + //BitVector::apply([](auto x, auto y) { return x & ~y; }, SetDifference, + // MBBInfo.IncomingCSRSaved, PrevMBBInfo->OutgoingCSRSaved); + //for (int Reg : SetDifference.set_bits()) { + // auto it = CSRLocMap.find(Reg); + // assert(it != CSRLocMap.end() && "Reg should have an entry in CSRLocMap"); + // unsigned CFIIndex; + // CSRSavedLocation RO = it->second; + // if (!RO.Reg && RO.Offset) { + // CFIIndex = MF.addFrameInst( + // MCCFIInstruction::createOffset(nullptr, Reg, *RO.Offset)); + // } else if (RO.Reg && !RO.Offset) { + // CFIIndex = MF.addFrameInst( + // MCCFIInstruction::createRegister(nullptr, Reg, *RO.Reg)); + // } else { + // llvm_unreachable("RO.Reg and RO.Offset cannot both be valid/invalid"); + // } + // BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION)) + // .addCFIIndex(CFIIndex); + // InsertedCFIInstr = true; + //} + + PrevMBBInfo = &MBBInfo; + } + return 
InsertedCFIInstr; +} + +//void RISCVCFIInstrInserter::reportCFAError(const MBBCFAInfo &Pred, +// const MBBCFAInfo &Succ) { +// errs() << "*** Inconsistent CFA register and/or offset between pred and succ " +// "***\n"; +// errs() << "Pred: " << Pred.MBB->getName() << " #" << Pred.MBB->getNumber() +// << " in " << Pred.MBB->getParent()->getName() +// << " outgoing CFA Reg:" << Pred.OutgoingCFARegister << "\n"; +// errs() << "Pred: " << Pred.MBB->getName() << " #" << Pred.MBB->getNumber() +// << " in " << Pred.MBB->getParent()->getName() +// << " outgoing CFA Offset:" << Pred.OutgoingCFAOffset << "\n"; +// errs() << "Succ: " << Succ.MBB->getName() << " #" << Succ.MBB->getNumber() +// << " incoming CFA Reg:" << Succ.IncomingCFARegister << "\n"; +// errs() << "Succ: " << Succ.MBB->getName() << " #" << Succ.MBB->getNumber() +// << " incoming CFA Offset:" << Succ.IncomingCFAOffset << "\n"; +//} +// +//void RISCVCFIInstrInserter::reportCSRError(const MBBCFAInfo &Pred, +// const MBBCFAInfo &Succ) { +// errs() << "*** Inconsistent CSR Saved between pred and succ in function " +// << Pred.MBB->getParent()->getName() << " ***\n"; +// errs() << "Pred: " << Pred.MBB->getName() << " #" << Pred.MBB->getNumber() +// << " outgoing CSR Saved: "; +// for (int Reg : Pred.OutgoingCSRSaved.set_bits()) +// errs() << Reg << " "; +// errs() << "\n"; +// errs() << "Succ: " << Succ.MBB->getName() << " #" << Succ.MBB->getNumber() +// << " incoming CSR Saved: "; +// for (int Reg : Succ.IncomingCSRSaved.set_bits()) +// errs() << Reg << " "; +// errs() << "\n"; +//} + +//unsigned RISCVCFIInstrInserter::verify(MachineFunction &MF) { +// unsigned ErrorNum = 0; +// for (auto *CurrMBB : depth_first(&MF)) { +// const MBBCFAInfo &CurrMBBInfo = MBBVector[CurrMBB->getNumber()]; +// for (MachineBasicBlock *Succ : CurrMBB->successors()) { +// const MBBCFAInfo &SuccMBBInfo = MBBVector[Succ->getNumber()]; +// // Check that incoming offset and register values of successors match the +// // outgoing offset 
and register values of CurrMBB +// if (SuccMBBInfo.IncomingCFAOffset != CurrMBBInfo.OutgoingCFAOffset || +// SuccMBBInfo.IncomingCFARegister != CurrMBBInfo.OutgoingCFARegister) { +// // Inconsistent offsets/registers are ok for 'noreturn' blocks because +// // we don't generate epilogues inside such blocks. +// if (SuccMBBInfo.MBB->succ_empty() && !SuccMBBInfo.MBB->isReturnBlock()) +// continue; +// reportCFAError(CurrMBBInfo, SuccMBBInfo); +// ErrorNum++; +// } +// // Check that IncomingCSRSaved of every successor matches the +// // OutgoingCSRSaved of CurrMBB +// if (SuccMBBInfo.IncomingCSRSaved != CurrMBBInfo.OutgoingCSRSaved) { +// reportCSRError(CurrMBBInfo, SuccMBBInfo); +// ErrorNum++; +// } +// } +// } +// return ErrorNum; +//} diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp index f5851f37154519..16d11d7d320b55 100644 --- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp @@ -18,12 +18,14 @@ #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/ReachingDefAnalysis.h" #include "llvm/CodeGen/RegisterScavenging.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/MC/MCDwarf.h" #include "llvm/Support/LEB128.h" #include <algorithm> +#include <unordered_set> using namespace llvm; @@ -525,6 +527,181 @@ static MCCFIInstruction createDefCFAOffset(const TargetRegisterInfo &TRI, Comment.str()); } +struct CFIBuildInfo { + MachineBasicBlock *MBB; + MachineInstr *InsertAfterMI; // nullptr means insert at MBB.begin() + DebugLoc DL; + unsigned CFIIndex; + bool ShouldRecord = false; + int DwarfEHRegNum = 0; + int DwarfFrameReg = 0; + int64_t FixedOffset = 0; +}; +
+// Walks the reaching definitions of Reg at MI. For every defining store to a +// stack slot, load from a stack slot or copy it queues a CFI instruction to be +// inserted after that definition (appended to CFIBuildInfos) and then recurses +// on the location the definition read from. Any other kind of defining +// instruction is treated as unreachable. The two Visited sets keep an +// instruction from being expanded twice. +static void trackRegisterAndEmitCFIs( + MachineFunction &MF, MachineInstr &MI, MCRegister Reg, int DwarfEHRegNum, + const ReachingDefAnalysis &RDA, const TargetInstrInfo &TII, + const MachineFrameInfo &MFI, const RISCVRegisterInfo &TRI, + std::vector<CFIBuildInfo> &CFIBuildInfos, + std::unordered_set<MachineInstr *> &VisitedRestorePoints, + std::unordered_set<MachineInstr *> &VisitedDefs) { + // insert().second is false when MI has already been expanded. + if (!VisitedRestorePoints.insert(&MI).second) + return; + + // Collect every definition of Reg that may reach MI. + SmallPtrSet<MachineInstr *, 4> Defs; + RDA.getGlobalReachingDefs(&MI, Reg, Defs); + if (Defs.empty()) { + // it's a live-in register at the entry block. + // Nothing is queued for it; the disabled code below is kept for reference. + //unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createSameValue(nullptr, DwarfEHRegNum)); + //CFIBuildInfos.push_back({&MF.front(), nullptr, DebugLoc(), CFIIndex}); + return; + } + + int FrameIndex = std::numeric_limits<int>::min(); + for (MachineInstr *Def : Defs) { + // Skip definitions that have already been expanded. + if (!VisitedDefs.insert(Def).second) + continue; + + MachineBasicBlock &MBB = *Def->getParent(); + const DebugLoc &DL = Def->getDebugLoc(); + + if (Register StoredReg = TII.isStoreToStackSlot(*Def, FrameIndex)) { + assert(FrameIndex == Register::stackSlot2Index(Reg)); + + Register FrameReg; + StackOffset Offset = MF.getSubtarget().getFrameLowering()->getFrameIndexReference(MF, FrameIndex, FrameReg); + int64_t FixedOffset = Offset.getFixed(); + // TODO: + assert(Offset.getScalable() == 0); + + // TODO: use getSPReg + std::string CommentBuffer; + llvm::raw_string_ostream Comment(CommentBuffer); + int DwarfFrameReg = TRI.getDwarfRegNum(FrameReg, true); + // Build up the expression (SP + FixedOffset) + SmallString<64> Expr; + uint8_t Buffer[16]; + + Comment << FixedOffset; + //0x11 + Expr.push_back(dwarf::DW_OP_consts); + Expr.append(Buffer, Buffer + encodeSLEB128(FixedOffset, Buffer)); + + //0x92 + Expr.push_back((uint8_t)dwarf::DW_OP_bregx); + //0x02 + Expr.append(Buffer, Buffer + encodeULEB128(DwarfFrameReg, Buffer)); + Expr.push_back(0); + + //0x22 + Expr.push_back((uint8_t)dwarf::DW_OP_plus); + // Wrap this into DW_CFA_expression. 
+ SmallString<64> DefCfaExpr; + // 0x10 + DefCfaExpr.push_back(dwarf::DW_CFA_expression); + DefCfaExpr.append(Buffer, Buffer + encodeULEB128(DwarfEHRegNum, Buffer)); + DefCfaExpr.append(Buffer, Buffer + encodeULEB128(Expr.size(), Buffer)); + DefCfaExpr.append(Expr.str()); + unsigned CFIIndex = MF.addFrameInst( + MCCFIInstruction::createEscape( + nullptr, + DefCfaExpr.str(), + SMLoc(), + Comment.str() + ) + ); + + CFIBuildInfos.push_back({&MBB, Def, DL, CFIIndex, true, DwarfEHRegNum, DwarfFrameReg, FixedOffset}); + trackRegisterAndEmitCFIs(MF, *Def, StoredReg, DwarfEHRegNum, RDA, TII, MFI, TRI, CFIBuildInfos, VisitedRestorePoints, VisitedDefs); + } + else if (Register LoadedReg = TII.isLoadFromStackSlot(*Def, FrameIndex)) { + assert(LoadedReg == Reg); + + unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createRegister( + nullptr, DwarfEHRegNum, TRI.getDwarfRegNum(LoadedReg, true))); + CFIBuildInfos.push_back({&MBB, Def, DL, CFIIndex}); + trackRegisterAndEmitCFIs(MF, *Def, Register::index2StackSlot(FrameIndex), DwarfEHRegNum, RDA, TII, MFI, TRI, CFIBuildInfos, VisitedRestorePoints, VisitedDefs); + } + else if (auto DstSrc = TII.isCopyInstr(*Def)) { + Register DstReg = DstSrc->Destination->getReg(); + Register SrcReg = DstSrc->Source->getReg(); + assert(DstReg == Reg); + + unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createRegister( + nullptr, DwarfEHRegNum, TRI.getDwarfRegNum(DstReg, true))); + CFIBuildInfos.push_back({&MBB, Def, DL, CFIIndex}); + trackRegisterAndEmitCFIs(MF, *Def, SrcReg, DwarfEHRegNum, RDA, TII, MFI, TRI, CFIBuildInfos, VisitedRestorePoints, VisitedDefs); + } + else { + llvm_unreachable("Unexpected instruction"); + } + } + return; +} + +int RISCVFrameLowering::getInitialCFAOffset(const MachineFunction &MF) const { + return 0; +} + +Register +RISCVFrameLowering::getInitialCFARegister(const MachineFunction &MF) const { + return RISCV::X2; +} + +void RISCVFrameLowering::emitCFIsForCSRsHandledByRA(MachineFunction &MF, ReachingDefAnalysis 
*RDA) const { + if (!STI.doCSRSavesInRA()) + return; + const RISCVInstrInfo &TII = *STI.getInstrInfo(); + const RISCVRegisterInfo &TRI = *STI.getRegisterInfo(); + const MachineFrameInfo &MFI = MF.getFrameInfo(); + + BitVector MustCalleeSavedRegs; + determineMustCalleeSaves(MF, MustCalleeSavedRegs); + const MCPhysReg * CSRegs = MF.getRegInfo().getCalleeSavedRegs(); + SmallVector EligibleRegs; + for (int i = 0; CSRegs[i]; ++i) { + unsigned Reg = CSRegs[i]; + if (!MustCalleeSavedRegs.test(Reg)) + EligibleRegs.push_back(CSRegs[i]); + } + + SmallVector RestorePoints; + for (MachineBasicBlock &MBB : MF) { + if (MBB.isReturnBlock()) + RestorePoints.push_back(&MBB.back()); + } + std::vector CFIBuildInfos; + for (MCPhysReg Reg : EligibleRegs) { + std::unordered_set VisitedDefs; + for (MachineInstr *RestorePoint : RestorePoints) { + std::unordered_set VisitedRestorePoints; + trackRegisterAndEmitCFIs(MF, *RestorePoint, Reg, TRI.getDwarfRegNum(Reg, true), *RDA, TII, MFI, TRI, CFIBuildInfos, VisitedRestorePoints, VisitedDefs); + } + } + for (CFIBuildInfo &Info : CFIBuildInfos) { + MachineBasicBlock::iterator InsertPos = Info.InsertAfterMI ? 
++(Info.InsertAfterMI->getIterator()) : Info.MBB->begin(); + MachineInstr *CFIInstr = BuildMI(*Info.MBB, InsertPos, Info.DL, TII.get(TargetOpcode::CFI_INSTRUCTION)) + .addCFIIndex(Info.CFIIndex) + .setMIFlag(MachineInstr::FrameSetup); + if (Info.ShouldRecord) { + RISCVMachineFunctionInfo &RVFI = *MF.getInfo(); + RVFI.recordCFIInfo(CFIInstr, Info.DwarfEHRegNum, Info.DwarfFrameReg, Info.FixedOffset); + } + } + return; +} + + void RISCVFrameLowering::emitPrologue(MachineFunction &MF, MachineBasicBlock &MBB) const { MachineFrameInfo &MFI = MF.getFrameInfo(); @@ -1057,17 +1234,55 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI, return Offset; } -void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF, - BitVector &SavedRegs, - RegScavenger *RS) const { - TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS); - // Unconditionally spill RA and FP only if the function uses a frame - // pointer. +void RISCVFrameLowering::determineMustCalleeSaves(MachineFunction &MF, + BitVector &SavedRegs) const { + const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo(); + + // Resize before the early returns. Some backends expect that + // SavedRegs.size() == TRI.getNumRegs() after this call even if there are no + // saved registers. + SavedRegs.resize(TRI.getNumRegs()); + + // When interprocedural register allocation is enabled caller saved registers + // are preferred over callee saved registers. + if (MF.getTarget().Options.EnableIPRA && + isSafeForNoCSROpt(MF.getFunction()) && + isProfitableForNoCSROpt(MF.getFunction())) + return; + + // Get the callee saved register list... + const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs(); + + // Early exit if there are no callee saved registers. + if (!CSRegs || CSRegs[0] == 0) + return; + + // In Naked functions we aren't going to save any registers. 
+ if (MF.getFunction().hasFnAttribute(Attribute::Naked)) + return; + + // Noreturn+nounwind functions never restore CSR, so no saves are needed. + // Purely noreturn functions may still return through throws, so those must + // save CSR for caller exception handlers. + // + // If the function uses longjmp to break out of its current path of + // execution we do not need the CSR spills either: setjmp stores all CSRs + // it was called with into the jmp_buf, which longjmp then restores. + if (MF.getFunction().hasFnAttribute(Attribute::NoReturn) && + MF.getFunction().hasFnAttribute(Attribute::NoUnwind) && + !MF.getFunction().hasFnAttribute(Attribute::UWTable) && + enableCalleeSaveSkip(MF)) + return; + + // Functions which call __builtin_unwind_init get all their registers saved. + if (MF.callsUnwindInit()) { + SavedRegs.set(); + return; + } if (hasFP(MF)) { - SavedRegs.set(RAReg); - SavedRegs.set(FPReg); + SavedRegs.set(RISCV::X1); + SavedRegs.set(RISCV::X8); } - // Mark BP as used if function has dedicated base pointer. 
if (hasBP(MF)) SavedRegs.set(RISCVABI::getBPReg()); @@ -1077,6 +1292,17 @@ void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF, SavedRegs.set(RISCV::X27); } +void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF, + BitVector &SavedRegs, + RegScavenger *RS) const { + const auto &ST = MF.getSubtarget(); + determineMustCalleeSaves(MF, SavedRegs); + if (ST.doCSRSavesInRA()) + return; + + TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS); +} + std::pair RISCVFrameLowering::assignRVVStackObjectOffsets(MachineFunction &MF) const { MachineFrameInfo &MFI = MF.getFrameInfo(); diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.h b/llvm/lib/Target/RISCV/RISCVFrameLowering.h index f45fcdb0acd6bc..e97c6ca7335de3 100644 --- a/llvm/lib/Target/RISCV/RISCVFrameLowering.h +++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.h @@ -23,6 +23,11 @@ class RISCVFrameLowering : public TargetFrameLowering { public: explicit RISCVFrameLowering(const RISCVSubtarget &STI); + int getInitialCFAOffset(const MachineFunction &MF) const override; + Register + getInitialCFARegister(const MachineFunction &MF) const override; + void emitCFIsForCSRsHandledByRA(MachineFunction &MF, ReachingDefAnalysis *RDA) const override; + void emitPrologue(MachineFunction &MF, MachineBasicBlock &MBB) const override; void emitEpilogue(MachineFunction &MF, MachineBasicBlock &MBB) const override; @@ -31,6 +36,7 @@ class RISCVFrameLowering : public TargetFrameLowering { StackOffset getFrameIndexReference(const MachineFunction &MF, int FI, Register &FrameReg) const override; + void determineMustCalleeSaves(MachineFunction &MF, BitVector &SavedRegs) const; void determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs, RegScavenger *RS) const override; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index af7a39b2580a37..c21d8782d5aeb1 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ 
b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -21874,6 +21874,125 @@ bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const { return false; } +static MachineInstr *findInstrWhichNeedAllCSRs(MachineBasicBlock &MBB) { + // Some instructions may require (implicitly) all CSRs to be saved. + // For example, call to __cxa_throw is noreturn, but expects that all CSRs are taken care of. + // TODO: try to speedup this? + for (MachineInstr &MI : MBB) { + unsigned Opc = MI.getOpcode(); + if (Opc != RISCV::PseudoCALL && Opc != RISCV::PseudoTAIL) + continue; + MachineOperand &MO = MI.getOperand(0); + StringRef Name = ""; + if (MO.isSymbol()) { + Name = MO.getSymbolName(); + } else if (MO.isGlobal()) { + Name = MO.getGlobal()->getName(); + } else { + llvm_unreachable("Unexpected operand type."); + } + if ( + Name == "__cxa_throw" + || Name == "__cxa_rethrow" + || Name == "_Unwind_Resume" + ) + return &MI; + } + return nullptr; +} + +void RISCVTargetLowering::finalizeLowering(MachineFunction &MF) const { + if (!Subtarget.doCSRSavesInRA()) { + TargetLoweringBase::finalizeLowering(MF); + return; + } + + MachineRegisterInfo &MRI = MF.getRegInfo(); + const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); + const RISCVRegisterInfo &TRI = *Subtarget.getRegisterInfo(); + const RISCVFrameLowering &TFI = *Subtarget.getFrameLowering(); + + SmallVector RestorePoints; + SmallVector SaveMBBs; + SaveMBBs.push_back(&MF.front()); + for (MachineBasicBlock &MBB : MF) { + if (MBB.isReturnBlock()) + RestorePoints.push_back(&MBB.back()); + if (MachineInstr *CallToCxaThrow = findInstrWhichNeedAllCSRs(MBB)) { + //MachineBasicBlock::iterator MII = CallToCxaThrow->getIterator(); + //++MII; + //assert(MII->getOpcode() == RISCV::ADJCALLSTACKUP && "Unexpected instruction"); + //++MII; + MachineBasicBlock::iterator MII = MBB.getFirstTerminator(); + MachineInstr *NewRetMI = BuildMI( + MBB, + MII, + CallToCxaThrow->getDebugLoc(), + TII.get(RISCV::UnreachableRET) + ); + 
RestorePoints.push_back(NewRetMI); + MII = ++NewRetMI->getIterator(); + MBB.erase(MII, MBB.end()); + } + } + + const MCPhysReg * CSRegs = MF.getRegInfo().getCalleeSavedRegs(); + SmallVector EligibleRegs; + BitVector MustCalleeSavedRegs; + TFI.determineMustCalleeSaves(MF, MustCalleeSavedRegs); + for (int i = 0; CSRegs[i]; ++i) { + unsigned Reg = CSRegs[i]; + if (!MustCalleeSavedRegs.test(Reg)) { + EligibleRegs.push_back(CSRegs[i]); + } + } + + SmallVector VRegs; + for (MachineBasicBlock *SaveMBB : SaveMBBs) { + for (MCPhysReg Reg : EligibleRegs) { + SaveMBB->addLiveIn(Reg); + // TODO: should we use Maximal register class instead? + Register VReg = MRI.createVirtualRegister(TRI.getLargestLegalSuperClass(TRI.getMinimalPhysRegClass(Reg), MF)); + VRegs.push_back(VReg); + BuildMI( + *SaveMBB, + SaveMBB->begin(), + SaveMBB->findDebugLoc(SaveMBB->begin()), + TII.get(TargetOpcode::COPY), + VReg + ) + .addReg(Reg); + MRI.setSimpleHint(VReg, Reg); + } + } + + for (MachineInstr *RestorePoint : RestorePoints) { + auto VRegI = VRegs.begin(); + for (MCPhysReg Reg : EligibleRegs) { + Register VReg = *VRegI; + BuildMI( + *RestorePoint->getParent(), + RestorePoint->getIterator(), + RestorePoint->getDebugLoc(), + TII.get(TargetOpcode::COPY), + Reg + ) + .addReg(VReg); + RestorePoint->addOperand( + MF, + MachineOperand::CreateReg( + Reg, + /*isDef=*/false, + /*isImplicit=*/true + ) + ); + VRegI++; + } + } + + TargetLoweringBase::finalizeLowering(MF); +} + SDValue RISCVTargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG, diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 0b07ad7d7a423f..f625176bdcad40 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -883,6 +883,8 @@ class RISCVTargetLowering : public TargetLowering { bool fallBackToDAGISel(const Instruction &Inst) const override; + void finalizeLowering(MachineFunction &MF) const override; + bool 
lowerInterleavedLoad(LoadInst *LI, ArrayRef Shuffles, ArrayRef Indices, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index c3aa367486627a..83eccedb204619 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -293,6 +293,14 @@ class RISCVInstrInfo : public RISCVGenInstrInfo { unsigned getTailDuplicateSize(CodeGenOptLevel OptLevel) const override; + bool expandPostRAPseudo(MachineInstr &MI) const override { + if (MI.getOpcode() == RISCV::UnreachableRET) { + MI.eraseFromParent(); + return true; + } + return false; + } + protected: const RISCVSubtarget &STI; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index a867368235584c..0826f6daa390c3 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -1615,6 +1615,10 @@ let isBarrier = 1, isReturn = 1, isTerminator = 1 in def PseudoRET : Pseudo<(outs), (ins), [(riscv_ret_glue)]>, PseudoInstExpansion<(JALR X0, X1, 0)>; +let isBarrier = 1, isReturn = 1, isTerminator = 1, isMeta = 1, hasSideEffects = 1, mayLoad = 0, mayStore = 0 in +def UnreachableRET : Pseudo<(outs), (ins), []>; + + // PseudoTAIL is a pseudo instruction similar to PseudoCALL and will eventually // expand to auipc and jalr while encoding. // Define AsmString to print "tail" when compile with -S flag. 
diff --git a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp index d0c363042f5118..b2a582d0ae79f6 100644 --- a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp @@ -43,3 +43,31 @@ void RISCVMachineFunctionInfo::addSExt32Register(Register Reg) { bool RISCVMachineFunctionInfo::isSExt32Register(Register Reg) const { return is_contained(SExt32Registers, Reg); } + +// Associate the DWARF register number Reg, the DWARF number of the frame +// register FrameReg and the fixed Offset with the instruction MI so that +// getCFIInfo() can return them later. Overwrites any earlier entry for MI. +void RISCVMachineFunctionInfo::recordCFIInfo(MachineInstr *MI, int Reg, + int FrameReg, int64_t Offset) { + assert(Reg >= 0 && "Negative dwarf reg number!"); + assert(FrameReg >= 0 && "Negative dwarf reg number!"); + CFIInfoMap[MI] = std::make_tuple(Reg, FrameReg, Offset); +} + +// Fetch the values recorded for MI. Returns false, leaving the output +// parameters untouched, when nothing was recorded for MI. +bool RISCVMachineFunctionInfo::getCFIInfo(MachineInstr *MI, int &Reg, + int &FrameReg, int64_t &Offset) { + auto Found = CFIInfoMap.find(MI); + if (Found == CFIInfoMap.end()) + return false; + + // Spell out std::get: an unqualified get<N>(...) only compiles if some other + // function template named `get` happens to be visible here. + Reg = std::get<0>(Found->second); + FrameReg = std::get<1>(Found->second); + Offset = std::get<2>(Found->second); + assert(Reg >= 0 && "Negative dwarf reg number!"); + assert(FrameReg >= 0 && "Negative dwarf reg number!"); + return true; +} diff --git a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h index 779c652b4d8fc4..09aa81fdaaee1c 100644 --- a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h +++ b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h @@ -14,6 +14,7 @@ #define LLVM_LIB_TARGET_RISCV_RISCVMACHINEFUNCTIONINFO_H #include "RISCVSubtarget.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/CodeGen/MIRYamlMapping.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" @@ -76,6 +77,9 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo { unsigned RVPushRegs = 0; int RVPushRlist = llvm::RISCVZC::RLISTENCODE::INVALID_RLIST; + + SmallDenseMap<MachineInstr *, std::tuple<int, int, int64_t>> CFIInfoMap; + public: RISCVMachineFunctionInfo(const Function &F, const TargetSubtargetInfo *STI) {} @@ -157,6 +161,19 @@ class 
RISCVMachineFunctionInfo : public MachineFunctionInfo { bool isVectorCall() const { return IsVectorCall; } void setIsVectorCall() { IsVectorCall = true; } + + void recordCFIInfo( + MachineInstr* MI, + int Reg, + int FrameReg, + int64_t Offset + ); + bool getCFIInfo( + MachineInstr* MI, + int &Reg, + int &FrameReg, + int64_t &Offset + ); }; } // end namespace llvm diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp index 26195ef721db39..6c8dec48492727 100644 --- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp @@ -755,6 +755,14 @@ RISCVRegisterInfo::getCallPreservedMask(const MachineFunction & MF, const TargetRegisterClass * RISCVRegisterInfo::getLargestLegalSuperClass(const TargetRegisterClass *RC, const MachineFunction &) const { + if (RC == &RISCV::GPRX1RegClass) + return &RISCV::GPRRegClass; + if (RC == &RISCV::GPRCRegClass) + return &RISCV::GPRRegClass; + if (RC == &RISCV::SR07RegClass) + return &RISCV::GPRRegClass; + if (RC == &RISCV::GPRJALRRegClass) + return &RISCV::GPRRegClass; if (RC == &RISCV::VMV0RegClass) return &RISCV::VRRegClass; if (RC == &RISCV::VRNoV0RegClass) diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp index e7db1ededf383b..51bdd757f3006d 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp @@ -61,6 +61,11 @@ static cl::opt RISCVMinimumJumpTableEntries( "riscv-min-jump-table-entries", cl::Hidden, cl::desc("Set minimum number of entries to use a jump table on RISCV")); +// Hidden, off by default. RISCVSubtarget::doCSRSavesInRA() returns its value. +static cl::opt<bool> RISCVEnableSaveCSRByRA( + "riscv-enable-save-csr-in-ra", cl::init(false), cl::Hidden, + cl::desc("Let register allocator do csr saves/restores")); + void RISCVSubtarget::anchor() {} RISCVSubtarget & @@ -129,6 +134,10 @@ bool RISCVSubtarget::useConstantPoolForLargeInts() const { return !RISCVDisableUsingConstantPoolForLargeInts; } +bool RISCVSubtarget::doCSRSavesInRA() 
const { + return RISCVEnableSaveCSRByRA; +} + unsigned RISCVSubtarget::getMaxBuildIntsCost() const { // Loading integer from constant pool needs two instructions (the reason why // the minimum cost is 2): an address calculation instruction and a load diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h index bf9ed3f3d71655..72e693c19ad14b 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.h +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h @@ -271,6 +271,8 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo { bool useConstantPoolForLargeInts() const; + bool doCSRSavesInRA() const override; + // Maximum cost used for building integers, integers will be put into constant // pool if exceeded. unsigned getMaxBuildIntsCost() const; diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp index 72d74d2d79b1d5..5c6f8503b30acc 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp @@ -542,6 +542,8 @@ void RISCVPassConfig::addPreEmitPass2() { addPass(createUnpackMachineBundles([&](const MachineFunction &MF) { return MF.getFunction().getParent()->getModuleFlag("kcfi"); })); + + addPass(createRISCVCFIInstrInserter()); } void RISCVPassConfig::addMachineSSAOptimization() {