Skip to content

Commit

Permalink
[SandboxVec][Scheduler] Implement rescheduling
Browse files Browse the repository at this point in the history
This patch adds support for re-scheduling already scheduled instructions.
For now this will clear and rebuild the DAG, and will reschedule the code
using the new DAG.
  • Loading branch information
vporpo committed Nov 7, 2024
1 parent 5942a99 commit 3dcdc06
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace llvm::sandboxir {

class DependencyGraph;
class MemDGNode;
class SchedBundle;

/// SubclassIDs for isa/dyn_cast etc.
enum class DGNodeID {
Expand Down Expand Up @@ -100,6 +101,12 @@ class DGNode {
unsigned UnscheduledSuccs = 0;
/// This is true if this node has been scheduled.
bool Scheduled = false;
/// The scheduler bundle that this node belongs to, or nullptr when the node
/// is not currently part of any bundle.
SchedBundle *SB = nullptr;

/// Records \p SB as the bundle this node belongs to. The parameter shadows
/// the member of the same name, hence the explicit `this->`.
void setSchedBundle(SchedBundle &SB) { this->SB = &SB; }
/// Resets the bundle back-pointer to nullptr.
void clearSchedBundle() { this->SB = nullptr; }
friend class SchedBundle; // For setSchedBundle(), clearSchedBundle().

DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
friend class MemDGNode; // For constructor.
Expand All @@ -122,6 +129,8 @@ class DGNode {
/// \Returns true if this node has been scheduled.
bool scheduled() const { return Scheduled; }
void setScheduled(bool NewVal) { Scheduled = NewVal; }
/// \Returns the scheduling bundle that this node belongs to, or nullptr.
SchedBundle *getSchedBundle() const { return SB; }
/// \Returns true if this is before \p Other in program order.
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
using iterator = PredIterator;
Expand Down Expand Up @@ -350,6 +359,10 @@ class DependencyGraph {
getOrCreateNode(I);
// TODO: Update the dependencies for the new node.
}
/// Resets the graph to an empty state: drops every instruction-to-node
/// mapping and resets the DAG interval to a default-constructed value.
/// NOTE(review): presumably DGNode pointers handed out earlier are no longer
/// valid after this call - confirm against the map's ownership of nodes.
void clear() {
InstrToNodeMap.clear();
DAGInterval = {};
}
#ifndef NDEBUG
void print(raw_ostream &OS) const;
LLVM_DUMP_METHOD void dump() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ReadyListContainer {
return Back;
}
bool empty() const { return List.empty(); }
/// Removes all nodes from the ready list by assigning it an empty container.
void clear() { List = {}; }
#ifndef NDEBUG
void dump(raw_ostream &OS) const;
LLVM_DUMP_METHOD void dump() const;
Expand All @@ -70,7 +71,16 @@ class SchedBundle {

public:
SchedBundle() = default;
SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {}
SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {
for (auto *N : this->Nodes)
N->setSchedBundle(*this);
}
~SchedBundle() {
for (auto *N : this->Nodes)
N->clearSchedBundle();
}
bool empty() const { return Nodes.empty(); }
DGNode *back() const { return Nodes.back(); }
using iterator = ContainerTy::iterator;
using const_iterator = ContainerTy::const_iterator;
iterator begin() { return Nodes.begin(); }
Expand All @@ -94,19 +104,34 @@ class Scheduler {
ReadyListContainer ReadyList;
DependencyGraph DAG;
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
SmallVector<std::unique_ptr<SchedBundle>> Bndls;
// TODO: This is wasting memory in exchange for fast removal using a raw ptr.
DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
Context &Ctx;
Context::CallbackID CreateInstrCB;

/// \Returns a scheduling bundle containing \p Instrs.
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
void eraseBundle(SchedBundle *SB);
/// Schedule nodes until we can schedule \p Instrs back-to-back.
bool tryScheduleUntil(ArrayRef<Instruction *> Instrs);
/// Schedules all nodes in \p Bndl, marks them as scheduled, updates the
/// UnscheduledSuccs counter of all dependency predecessors, and adds any of
/// them that become ready to the ready list.
void scheduleAndUpdateReadyList(SchedBundle &Bndl);

/// The scheduling state of the instructions in the bundle.
enum class BndlSchedState {
NoneScheduled, ///< No instruction in the bundle was previously scheduled.
PartiallyOrDifferentlyScheduled, ///< Only some of the instrs in the bundle
///< were previously scheduled, or all of
///< them were but not in the same
///< SchedBundle.
FullyScheduled, ///< All instrs in the bundle were previously scheduled and
///< were in the same SchedBundle.
};
/// \Returns whether none/some/all of \p Instrs have been scheduled.
BndlSchedState getBndlSchedState(ArrayRef<Instruction *> Instrs) const;
/// Destroy the top-most part of the schedule that includes \p Instrs.
void trimSchedule(ArrayRef<Instruction *> Instrs);
/// Disable copies.
Scheduler(const Scheduler &) = delete;
Scheduler &operator=(const Scheduler &) = delete;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ class VecUtils {
}
return FixedVectorType::get(ElemTy, NumElts);
}
/// \Returns the instruction of \p Instrs that is last in program order, i.e.
/// the one that every other member of \p Instrs comes before.
/// \p Instrs must not be empty.
static Instruction *getLowest(ArrayRef<Instruction *> Instrs) {
  Instruction *Lowest = Instrs.front();
  // Keep whichever of the current pick and the candidate is further down.
  for (Instruction *Candidate : drop_begin(Instrs))
    Lowest = Lowest->comesBefore(Candidate) ? Candidate : Lowest;
  return Lowest;
}
};

} // namespace llvm::sandboxir
Expand Down
96 changes: 79 additions & 17 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm::sandboxir {

Expand Down Expand Up @@ -95,10 +96,12 @@ SchedBundle *Scheduler::createBundle(ArrayRef<Instruction *> Instrs) {
Nodes.push_back(DAG.getNode(I));
auto BndlPtr = std::make_unique<SchedBundle>(std::move(Nodes));
auto *Bndl = BndlPtr.get();
Bndls.push_back(std::move(BndlPtr));
Bndls[Bndl] = std::move(BndlPtr);
return Bndl;
}

/// Removes the entry keyed by \p SB from the bundle map. The mapped value is
/// the owning std::unique_ptr, so erasing the entry also destroys the bundle.
void Scheduler::eraseBundle(SchedBundle *SB) { Bndls.erase(SB); }

bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
// Use a set of instructions, instead of `Instrs` for fast lookups.
DenseSet<Instruction *> InstrsToDefer(Instrs.begin(), Instrs.end());
Expand Down Expand Up @@ -133,29 +136,88 @@ bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
return false;
}

/// Classifies \p Instrs by how much of it has already been scheduled:
///  - NoneScheduled: no instruction has a node that is marked as scheduled.
///  - FullyScheduled: every instruction is scheduled and they all belong to
///    the same SchedBundle.
///  - PartiallyOrDifferentlyScheduled: everything else.
/// \p Instrs must not be empty.
Scheduler::BndlSchedState
Scheduler::getBndlSchedState(ArrayRef<Instruction *> Instrs) const {
  assert(!Instrs.empty() && "Expected non-empty bundle");
  bool AnyScheduled = false;
  bool AllScheduled = true;
  for (Instruction *I : Instrs) {
    auto *N = DAG.getNode(I);
    // An instruction without a DAG node counts as not scheduled.
    const bool IsScheduled = N != nullptr && N->scheduled();
    AnyScheduled = AnyScheduled || IsScheduled;
    AllScheduled = AllScheduled && IsScheduled;
  }
  if (!AllScheduled)
    return AnyScheduled ? BndlSchedState::PartiallyOrDifferentlyScheduled
                        : BndlSchedState::NoneScheduled;

  // Everything is scheduled. If the instructions are spread across different
  // SchedBundles we still have to re-schedule, so report that as
  // partially-scheduled.
  SchedBundle *SB = DAG.getNode(Instrs[0])->getSchedBundle();
  assert(SB != nullptr && "FullyScheduled assumes that there is an SB!");
  for (Instruction *I : drop_begin(Instrs)) {
    if (DAG.getNode(I)->getSchedBundle() != SB)
      return BndlSchedState::PartiallyOrDifferentlyScheduled;
  }
  return BndlSchedState::FullyScheduled;
}

/// Destroys the part of the schedule that spans from the lowest instruction
/// of \p Instrs up to the current top of the schedule, so that the region can
/// be scheduled again. For now this is implemented conservatively: after
/// erasing the bundles in that range it drops all remaining bundles, the
/// whole DAG and the ready list.
/// Requires the top-of-schedule iterator to be set.
void Scheduler::trimSchedule(ArrayRef<Instruction *> Instrs) {
  Instruction *TopI = &*ScheduleTopItOpt.value();
  Instruction *LowestI = VecUtils::getLowest(Instrs);
  // Destroy the schedule bundles from LowestI all the way to the top.
  for (auto *I = LowestI, *E = TopI->getPrevNode(); I != E;
       I = I->getPrevNode()) {
    auto *N = DAG.getNode(I);
    // An instruction in the range may have no DAG node (getBndlSchedState()
    // tolerates missing nodes as well), so there is no bundle to erase.
    if (N == nullptr)
      continue;
    if (auto *SB = N->getSchedBundle())
      eraseBundle(SB);
  }
  // TODO: For now we clear the DAG. Trim view once it gets implemented.
  // The bundles go first: ~SchedBundle() still accesses the nodes of the
  // bundle, so the DAG must only be cleared afterwards.
  Bndls.clear();
  DAG.clear();

  // The ready list holds nodes of the DAG that was just cleared, and the
  // region is going to be scheduled from scratch, so empty it as well.
  ReadyList.clear();
}

bool Scheduler::trySchedule(ArrayRef<Instruction *> Instrs) {
assert(all_of(drop_begin(Instrs),
[Instrs](Instruction *I) {
return I->getParent() == (*Instrs.begin())->getParent();
}) &&
"Instrs not in the same BB!");
// Extend the DAG to include Instrs.
Interval<Instruction> Extension = DAG.extend(Instrs);
// TODO: Set the window of the DAG that we are interested in.
// We start scheduling at the bottom instr of Instrs.
auto getBottomI = [](ArrayRef<Instruction *> Instrs) -> Instruction * {
return *min_element(Instrs,
[](auto *I1, auto *I2) { return I1->comesBefore(I2); });
};
ScheduleTopItOpt = std::next(getBottomI(Instrs)->getIterator());
// Add nodes to ready list.
for (auto &I : Extension) {
auto *N = DAG.getNode(&I);
if (N->ready())
ReadyList.insert(N);
auto SchedState = getBndlSchedState(Instrs);
switch (SchedState) {
case BndlSchedState::FullyScheduled:
// Nothing to do.
return true;
case BndlSchedState::PartiallyOrDifferentlyScheduled:
// If one or more instrs are already scheduled we need to destroy the
// top-most part of the schedule that includes the instrs in the bundle and
// re-schedule.
trimSchedule(Instrs);
[[fallthrough]];
case BndlSchedState::NoneScheduled: {
// TODO: Set the window of the DAG that we are interested in.
// We start scheduling at the bottom instr of Instrs.
ScheduleTopItOpt = std::next(VecUtils::getLowest(Instrs)->getIterator());

// Extend the DAG to include Instrs.
Interval<Instruction> Extension = DAG.extend(Instrs);
// Add nodes to ready list.
for (auto &I : Extension) {
auto *N = DAG.getNode(&I);
if (N->ready())
ReadyList.insert(N);
}
// Try schedule all nodes until we can schedule Instrs back-to-back.
return tryScheduleUntil(Instrs);
}
}
// Try schedule all nodes until we can schedule Instrs back-to-back.
return tryScheduleUntil(Instrs);
}

#ifndef NDEBUG
Expand Down
32 changes: 31 additions & 1 deletion llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,37 @@ define void @store_fcmp_zext_load(ptr %ptr) {
ret void
}

; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check
; Two-wide load/fadd/store chain. The expected output keeps every scalar
; instruction and has a <2 x float> instruction following each scalar pair.
define void @store_fadd_load(ptr %ptr) {
; CHECK-LABEL: define void @store_fadd_load(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4
; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; CHECK-NEXT: [[FADD0:%.*]] = fadd float [[LDA0]], [[LDB0]]
; CHECK-NEXT: [[FADD1:%.*]] = fadd float [[LDA1]], [[LDB1]]
; CHECK-NEXT: [[VEC:%.*]] = fadd <2 x float> [[VECL]], [[VECL1]]
; CHECK-NEXT: store float [[FADD0]], ptr [[PTR0]], align 4
; CHECK-NEXT: store float [[FADD1]], ptr [[PTR1]], align 4
; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr float, ptr %ptr, i32 0
%ptr1 = getelementptr float, ptr %ptr, i32 1
%ldA0 = load float, ptr %ptr0
%ldA1 = load float, ptr %ptr1
%ldB0 = load float, ptr %ptr0
%ldB1 = load float, ptr %ptr1
%fadd0 = fadd float %ldA0, %ldB0
%fadd1 = fadd float %ldA1, %ldB1
store float %fadd0, ptr %ptr0
store float %fadd1, ptr %ptr1
ret void
}

define void @store_fneg_load(ptr %ptr) {
; CHECK-LABEL: define void @store_fneg_load(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,10 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
EXPECT_TRUE(Sched.trySchedule({S0}));
}
{
// Try invalid scheduling
// Try invalid scheduling. Dependency S0->S1.
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0}));
EXPECT_FALSE(Sched.trySchedule({S1}));
EXPECT_FALSE(Sched.trySchedule({S0, S1}));
}
}

Expand Down Expand Up @@ -202,3 +201,39 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
}

// Checks that trySchedule() succeeds for a bundle whose instructions have
// already been scheduled, but not together in one bundle, and that a bundle
// scheduled before that can be scheduled again afterwards.
TEST_F(SchedulerTest, RescheduleAlreadyScheduled) {
parseIR(C, R"IR(
define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
%ld0 = load i8, ptr %ptr0
%ld1 = load i8, ptr %ptr1
%add0 = add i8 %ld0, %ld0
%add1 = add i8 %ld1, %ld1
store i8 %add0, ptr %ptr0
store i8 %add1, ptr %ptr1
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
// Pick up the instructions in program order; the order of the It++ calls
// below must match the IR above.
auto It = BB->begin();
auto *L0 = cast<sandboxir::LoadInst>(&*It++);
auto *L1 = cast<sandboxir::LoadInst>(&*It++);
auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
// Schedule bottom-up: the return, then the stores, then the loads. The adds
// sit between the stores and the loads and are never requested as a bundle.
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
// At this point Add0 and Add1 should have been individually scheduled
// as single bundles.
// Check if rescheduling works.
EXPECT_TRUE(Sched.trySchedule({Add0, Add1}));
// The loads were already scheduled once above; scheduling them again after
// the adds have been rescheduled is expected to succeed too.
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
}
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,32 @@ TEST_F(VecUtilsTest, GetWideType) {
auto *Int32X8Ty = sandboxir::FixedVectorType::get(Int32Ty, 8);
EXPECT_EQ(sandboxir::VecUtils::getWideType(Int32X4Ty, 2), Int32X8Ty);
}

// Checks that VecUtils::getLowest() returns the instruction that is last in
// program order, regardless of the order of the instructions in the input.
TEST_F(VecUtilsTest, GetLowest) {
parseIR(R"IR(
define void @foo(i8 %v) {
bb0:
%A = add i8 %v, %v
%B = add i8 %v, %v
%C = add i8 %v, %v
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");

sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
auto It = BB.begin();
auto *IA = &*It++;
auto *IB = &*It++;
auto *IC = &*It++;
// All six permutations of {A, B, C}: the lowest one is always C.
SmallVector<sandboxir::Instruction *> ABC({IA, IB, IC});
EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC), IC);
SmallVector<sandboxir::Instruction *> ACB({IA, IC, IB});
EXPECT_EQ(sandboxir::VecUtils::getLowest(ACB), IC);
SmallVector<sandboxir::Instruction *> BAC({IB, IA, IC});
EXPECT_EQ(sandboxir::VecUtils::getLowest(BAC), IC);
SmallVector<sandboxir::Instruction *> BCA({IB, IC, IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(BCA), IC);
SmallVector<sandboxir::Instruction *> CAB({IC, IA, IB});
EXPECT_EQ(sandboxir::VecUtils::getLowest(CAB), IC);
SmallVector<sandboxir::Instruction *> CBA({IC, IB, IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC);
// Subsets that do not contain C, so that the expected result is not always
// the last instruction that was picked up from the block.
SmallVector<sandboxir::Instruction *> AB({IA, IB});
EXPECT_EQ(sandboxir::VecUtils::getLowest(AB), IB);
SmallVector<sandboxir::Instruction *> BA({IB, IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(BA), IB);
// A single instruction is trivially the lowest one.
SmallVector<sandboxir::Instruction *> OnlyA({IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(OnlyA), IA);
}

0 comments on commit 3dcdc06

Please sign in to comment.