diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index b498e0f189465c..5211c7922ea2fd 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -33,6 +33,7 @@ namespace llvm::sandboxir { class DependencyGraph; class MemDGNode; +class SchedBundle; /// SubclassIDs for isa/dyn_cast etc. enum class DGNodeID { @@ -100,6 +101,12 @@ class DGNode { unsigned UnscheduledSuccs = 0; /// This is true if this node has been scheduled. bool Scheduled = false; + /// The scheduler bundle that this node belongs to. + SchedBundle *SB = nullptr; + + void setSchedBundle(SchedBundle &SB) { this->SB = &SB; } + void clearSchedBundle() { this->SB = nullptr; } + friend class SchedBundle; // For setSchedBundle(), clearSchedBundle(). DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {} friend class MemDGNode; // For constructor. @@ -122,6 +129,8 @@ class DGNode { /// \Returns true if this node has been scheduled. bool scheduled() const { return Scheduled; } void setScheduled(bool NewVal) { Scheduled = NewVal; } + /// \Returns the scheduling bundle that this node belongs to, or nullptr. + SchedBundle *getSchedBundle() const { return SB; } /// \Returns true if this is before \p Other in program order. bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); } using iterator = PredIterator; @@ -350,6 +359,10 @@ class DependencyGraph { getOrCreateNode(I); // TODO: Update the dependencies for the new node. } + void clear() { + InstrToNodeMap.clear(); + DAGInterval = {}; + } #ifndef NDEBUG void print(raw_ostream &OS) const; LLVM_DUMP_METHOD void dump() const; diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h index 0e4eea3880efbd..2d6b4035b67408 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h @@ -53,6 +53,7 @@ class ReadyListContainer { return Back; } bool empty() const { return List.empty(); } + void clear() { List = {}; } #ifndef NDEBUG void dump(raw_ostream &OS) const; LLVM_DUMP_METHOD void dump() const; @@ -70,7 +71,16 @@ class SchedBundle { public: SchedBundle() = default; - SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {} + SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) { + for (auto *N : this->Nodes) + N->setSchedBundle(*this); + } + ~SchedBundle() { + for (auto *N : this->Nodes) + N->clearSchedBundle(); + } + bool empty() const { return Nodes.empty(); } + DGNode *back() const { return Nodes.back(); } using iterator = ContainerTy::iterator; using const_iterator = ContainerTy::const_iterator; iterator begin() { return Nodes.begin(); } @@ -94,19 +104,30 @@ class Scheduler { ReadyListContainer ReadyList; DependencyGraph DAG; std::optional ScheduleTopItOpt; - SmallVector> Bndls; + // TODO: This is wasting memory in exchange for fast removal using a raw ptr. + DenseMap> Bndls; Context &Ctx; Context::CallbackID CreateInstrCB; /// \Returns a scheduling bundle containing \p Instrs. SchedBundle *createBundle(ArrayRef Instrs); + void eraseBundle(SchedBundle *SB); /// Schedule nodes until we can schedule \p Instrs back-to-back. bool tryScheduleUntil(ArrayRef Instrs); /// Schedules all nodes in \p Bndl, marks them as scheduled, updates the /// UnscheduledSuccs counter of all dependency predecessors, and adds any of /// them that become ready to the ready list. void scheduleAndUpdateReadyList(SchedBundle &Bndl); - + /// The scheduling state of the instructions in the bundle. + enum class BndlSchedState { + NoneScheduled, + PartiallyOrDifferentlyScheduled, + FullyScheduled, + }; + /// \Returns whether none/some/all of \p Instrs have been scheduled. + BndlSchedState getBndlSchedState(ArrayRef Instrs) const; + /// Destroy the top-most part of the schedule that includes \p Instrs. + void trimSchedule(ArrayRef Instrs); /// Disable copies. Scheduler(const Scheduler &) = delete; Scheduler &operator=(const Scheduler &) = delete; diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h index 85229150de2b6c..d44c845bfbf4e9 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h @@ -100,6 +100,14 @@ class VecUtils { } return FixedVectorType::get(ElemTy, NumElts); } + static Instruction *getLowest(ArrayRef Instrs) { + Instruction *LowestI = Instrs.front(); + for (auto *I : drop_begin(Instrs)) { + if (LowestI->comesBefore(I)) + LowestI = I; + } + return LowestI; + } }; } // namespace llvm::sandboxir diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp index 6140c2a8dcec82..2c869d4619d8d3 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" +#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" namespace llvm::sandboxir { @@ -95,10 +96,12 @@ SchedBundle *Scheduler::createBundle(ArrayRef Instrs) { Nodes.push_back(DAG.getNode(I)); auto BndlPtr = std::make_unique(std::move(Nodes)); auto *Bndl = BndlPtr.get(); - Bndls.push_back(std::move(BndlPtr)); + Bndls[Bndl] = std::move(BndlPtr); return Bndl; } +void Scheduler::eraseBundle(SchedBundle *SB) { Bndls.erase(SB); } + bool Scheduler::tryScheduleUntil(ArrayRef Instrs) { // Use a set of instructions, instead of `Instrs` for fast lookups. DenseSet InstrsToDefer(Instrs.begin(), Instrs.end()); @@ -133,29 +136,87 @@ bool Scheduler::tryScheduleUntil(ArrayRef Instrs) { return false; } +Scheduler::BndlSchedState +Scheduler::getBndlSchedState(ArrayRef Instrs) const { + assert(!Instrs.empty() && "Expected non-empty bundle"); + bool PartiallyScheduled = false; + bool FullyScheduled = true; + for (auto *I : Instrs) { + auto *N = DAG.getNode(I); + if (N != nullptr && N->scheduled()) + PartiallyScheduled = true; + else + FullyScheduled = false; + } + if (FullyScheduled) { + // If not all instrs in the bundle are in the same SchedBundle then this + // should be considered as partially-scheduled, because we will need to + // re-schedule. + SchedBundle *SB = DAG.getNode(Instrs[0])->getSchedBundle(); + assert(SB != nullptr && "FullyScheduled assumes that there is an SB!"); + if (any_of(drop_begin(Instrs), [this, SB](sandboxir::Value *SBV) { + return DAG.getNode(cast(SBV)) + ->getSchedBundle() != SB; + })) + FullyScheduled = false; + } + return FullyScheduled ? BndlSchedState::FullyScheduled + : PartiallyScheduled ? BndlSchedState::PartiallyOrDifferentlyScheduled + : BndlSchedState::NoneScheduled; +} + +void Scheduler::trimSchedule(ArrayRef Instrs) { + Instruction *TopI = &*ScheduleTopItOpt.value(); + Instruction *LowestI = VecUtils::getLowest(Instrs); + // Destroy the schedule bundles from LowestI all the way to the top. + for (auto *I = LowestI, *E = TopI->getPrevNode(); I != E; + I = I->getPrevNode()) { + auto *N = DAG.getNode(I); + if (auto *SB = N->getSchedBundle()) + eraseBundle(SB); + } + // TODO: For now we clear the DAG. Trim view once it gets implemented. + DAG.clear(); + + // Since we are scheduling NewRegion from scratch, we clear the ready lists. + // The nodes currently in the list may not be ready after clearing the View. + ReadyList.clear(); +} + bool Scheduler::trySchedule(ArrayRef Instrs) { assert(all_of(drop_begin(Instrs), [Instrs](Instruction *I) { return I->getParent() == (*Instrs.begin())->getParent(); }) && "Instrs not in the same BB!"); - // Extend the DAG to include Instrs. - Interval Extension = DAG.extend(Instrs); - // TODO: Set the window of the DAG that we are interested in. - // We start scheduling at the bottom instr of Instrs. - auto getBottomI = [](ArrayRef Instrs) -> Instruction * { - return *min_element(Instrs, - [](auto *I1, auto *I2) { return I1->comesBefore(I2); }); - }; - ScheduleTopItOpt = std::next(getBottomI(Instrs)->getIterator()); - // Add nodes to ready list. - for (auto &I : Extension) { - auto *N = DAG.getNode(&I); - if (N->ready()) - ReadyList.insert(N); + auto SchedState = getBndlSchedState(Instrs); + switch (SchedState) { + case BndlSchedState::FullyScheduled: + // Nothing to do. + return true; + case BndlSchedState::PartiallyOrDifferentlyScheduled: + // If one or more instrs are already scheduled we need to destroy the + // top-most part of the schedule that includes the instrs in the bundle and + // re-schedule. + trimSchedule(Instrs); + [[fallthrough]]; + case BndlSchedState::NoneScheduled: { + // TODO: Set the window of the DAG that we are interested in. + // We start scheduling at the bottom instr of Instrs. + ScheduleTopItOpt = std::next(VecUtils::getLowest(Instrs)->getIterator()); + + // Extend the DAG to include Instrs. + Interval Extension = DAG.extend(Instrs); + // Add nodes to ready list. + for (auto &I : Extension) { + auto *N = DAG.getNode(&I); + if (N->ready()) + ReadyList.insert(N); + } + // Try schedule all nodes until we can schedule Instrs back-to-back. + return tryScheduleUntil(Instrs); + } } - // Try schedule all nodes until we can schedule Instrs back-to-back. - return tryScheduleUntil(Instrs); } #ifndef NDEBUG diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll index 45c701a18fd9bf..e56dbd75963f7a 100644 --- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll +++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll @@ -96,7 +96,37 @@ define void @store_fcmp_zext_load(ptr %ptr) { ret void } -; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check +define void @store_fadd_load(ptr %ptr) { +; CHECK-LABEL: define void @store_fadd_load( +; CHECK-SAME: ptr [[PTR:%.*]]) { +; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0 +; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1 +; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4 +; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4 +; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[FADD0:%.*]] = fadd float [[LDA0]], [[LDB0]] +; CHECK-NEXT: [[FADD1:%.*]] = fadd float [[LDA1]], [[LDB1]] +; CHECK-NEXT: [[VEC:%.*]] = fadd <2 x float> [[VECL]], [[VECL1]] +; CHECK-NEXT: store float [[FADD0]], ptr [[PTR0]], align 4 +; CHECK-NEXT: store float [[FADD1]], ptr [[PTR1]], align 4 +; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4 +; CHECK-NEXT: ret void +; + %ptr0 = getelementptr float, ptr %ptr, i32 0 + %ptr1 = getelementptr float, ptr %ptr, i32 1 + %ldA0 = load float, ptr %ptr0 + %ldA1 = load float, ptr %ptr1 + %ldB0 = load float, ptr %ptr0 + %ldB1 = load float, ptr %ptr1 + %fadd0 = fadd float %ldA0, %ldB0 + %fadd1 = fadd float %ldA1, %ldB1 + store float %fadd0, ptr %ptr0 + store float %fadd1, ptr %ptr1 + ret void +} define void @store_fneg_load(ptr %ptr) { ; CHECK-LABEL: define void @store_fneg_load( diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp index 4a8b0ba1d7c12b..94a57914429748 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp @@ -168,11 +168,10 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { EXPECT_TRUE(Sched.trySchedule({S0})); } { - // Try invalid scheduling + // Try invalid scheduling. Dependency S0->S1. sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); EXPECT_TRUE(Sched.trySchedule({Ret})); - EXPECT_TRUE(Sched.trySchedule({S0})); - EXPECT_FALSE(Sched.trySchedule({S1})); + EXPECT_FALSE(Sched.trySchedule({S0, S1})); } } @@ -202,3 +201,39 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) { EXPECT_TRUE(Sched.trySchedule({S0, S1})); EXPECT_TRUE(Sched.trySchedule({L0, L1})); } + +TEST_F(SchedulerTest, RescheduleAlreadyScheduled) { + parseIR(C, R"IR( +define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) { + %ld0 = load i8, ptr %ptr0 + %ld1 = load i8, ptr %ptr1 + %add0 = add i8 %ld0, %ld0 + %add1 = add i8 %ld1, %ld1 + store i8 %add0, ptr %ptr0 + store i8 %add1, ptr %ptr1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *L0 = cast(&*It++); + auto *L1 = cast(&*It++); + auto *Add0 = cast(&*It++); + auto *Add1 = cast(&*It++); + auto *S0 = cast(&*It++); + auto *S1 = cast(&*It++); + auto *Ret = cast(&*It++); + + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); + EXPECT_TRUE(Sched.trySchedule({Ret})); + EXPECT_TRUE(Sched.trySchedule({S0, S1})); + EXPECT_TRUE(Sched.trySchedule({L0, L1})); + // At this point Add0 and Add1 should have been individually scheduled + // as single bundles. + // Check if rescheduling works. + EXPECT_TRUE(Sched.trySchedule({Add0, Add1})); + EXPECT_TRUE(Sched.trySchedule({L0, L1})); +} diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp index 6d1ab95ce31440..835b9285c9d9ff 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp @@ -410,3 +410,32 @@ TEST_F(VecUtilsTest, GetWideType) { auto *Int32X8Ty = sandboxir::FixedVectorType::get(Int32Ty, 8); EXPECT_EQ(sandboxir::VecUtils::getWideType(Int32X4Ty, 2), Int32X8Ty); } + +TEST_F(VecUtilsTest, GetLowest) { + parseIR(R"IR( +define void @foo(i8 %v) { +bb0: + %A = add i8 %v, %v + %B = add i8 %v, %v + %C = add i8 %v, %v + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + + sandboxir::Context Ctx(C); + auto &F = *Ctx.createFunction(&LLVMF); + auto &BB = *F.begin(); + auto It = BB.begin(); + auto *IA = &*It++; + auto *IB = &*It++; + auto *IC = &*It++; + SmallVector ABC({IA, IB, IC}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC), IC); + SmallVector ACB({IA, IC, IB}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(ACB), IC); + SmallVector CAB({IC, IA, IB}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(CAB), IC); + SmallVector CBA({IC, IB, IA}); + EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC); +}