Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SandboxVec][Scheduler] Implement rescheduling #115220

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace llvm::sandboxir {

class DependencyGraph;
class MemDGNode;
class SchedBundle;

/// SubclassIDs for isa/dyn_cast etc.
enum class DGNodeID {
Expand Down Expand Up @@ -100,6 +101,12 @@ class DGNode {
unsigned UnscheduledSuccs = 0;
/// This is true if this node has been scheduled.
bool Scheduled = false;
/// The scheduler bundle that this node belongs to.
SchedBundle *SB = nullptr;

void setSchedBundle(SchedBundle &SB) { this->SB = &SB; }
void clearSchedBundle() { this->SB = nullptr; }
friend class SchedBundle; // For setSchedBundle(), clearSchedBundle().

DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
friend class MemDGNode; // For constructor.
Expand All @@ -122,6 +129,8 @@ class DGNode {
/// \Returns true if this node has been scheduled.
bool scheduled() const { return Scheduled; }
void setScheduled(bool NewVal) { Scheduled = NewVal; }
/// \Returns the scheduling bundle that this node belongs to, or nullptr.
SchedBundle *getSchedBundle() const { return SB; }
/// \Returns true if this is before \p Other in program order.
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
using iterator = PredIterator;
Expand Down Expand Up @@ -350,6 +359,10 @@ class DependencyGraph {
getOrCreateNode(I);
// TODO: Update the dependencies for the new node.
}
void clear() {
InstrToNodeMap.clear();
DAGInterval = {};
}
#ifndef NDEBUG
void print(raw_ostream &OS) const;
LLVM_DUMP_METHOD void dump() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ReadyListContainer {
return Back;
}
bool empty() const { return List.empty(); }
void clear() { List = {}; }
#ifndef NDEBUG
void dump(raw_ostream &OS) const;
LLVM_DUMP_METHOD void dump() const;
Expand All @@ -70,7 +71,16 @@ class SchedBundle {

public:
SchedBundle() = default;
SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {}
SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {
for (auto *N : this->Nodes)
N->setSchedBundle(*this);
}
~SchedBundle() {
for (auto *N : this->Nodes)
N->clearSchedBundle();
}
bool empty() const { return Nodes.empty(); }
DGNode *back() const { return Nodes.back(); }
using iterator = ContainerTy::iterator;
using const_iterator = ContainerTy::const_iterator;
iterator begin() { return Nodes.begin(); }
Expand All @@ -94,19 +104,34 @@ class Scheduler {
ReadyListContainer ReadyList;
DependencyGraph DAG;
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
SmallVector<std::unique_ptr<SchedBundle>> Bndls;
// TODO: This is wasting memory in exchange for fast removal using a raw ptr.
DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
Context &Ctx;
Context::CallbackID CreateInstrCB;

/// \Returns a scheduling bundle containing \p Instrs.
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
void eraseBundle(SchedBundle *SB);
/// Schedule nodes until we can schedule \p Instrs back-to-back.
bool tryScheduleUntil(ArrayRef<Instruction *> Instrs);
/// Schedules all nodes in \p Bndl, marks them as scheduled, updates the
/// UnscheduledSuccs counter of all dependency predecessors, and adds any of
/// them that become ready to the ready list.
void scheduleAndUpdateReadyList(SchedBundle &Bndl);

/// The scheduling state of the instructions in the bundle.
enum class BndlSchedState {
NoneScheduled, ///> No instruction in the bundle was previously scheduled.
PartiallyOrDifferentlyScheduled, ///> Only some of the instrs in the bundle
/// were previously scheduled, or all of
/// them were but not in the same
/// SchedBundle.
FullyScheduled, ///> All instrs in the bundle were previously scheduled and
/// were in the same SchedBundle.
};
/// \Returns whether none/some/all of \p Instrs have been scheduled.
BndlSchedState getBndlSchedState(ArrayRef<Instruction *> Instrs) const;
/// Destroy the top-most part of the schedule that includes \p Instrs.
void trimSchedule(ArrayRef<Instruction *> Instrs);
/// Disable copies.
Scheduler(const Scheduler &) = delete;
Scheduler &operator=(const Scheduler &) = delete;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ class VecUtils {
}
return FixedVectorType::get(ElemTy, NumElts);
}
static Instruction *getLowest(ArrayRef<Instruction *> Instrs) {
Instruction *LowestI = Instrs.front();
for (auto *I : drop_begin(Instrs)) {
if (LowestI->comesBefore(I))
LowestI = I;
}
return LowestI;
}
};

} // namespace llvm::sandboxir
Expand Down
96 changes: 79 additions & 17 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm::sandboxir {

Expand Down Expand Up @@ -95,10 +96,12 @@ SchedBundle *Scheduler::createBundle(ArrayRef<Instruction *> Instrs) {
Nodes.push_back(DAG.getNode(I));
auto BndlPtr = std::make_unique<SchedBundle>(std::move(Nodes));
auto *Bndl = BndlPtr.get();
Bndls.push_back(std::move(BndlPtr));
Bndls[Bndl] = std::move(BndlPtr);
return Bndl;
}

void Scheduler::eraseBundle(SchedBundle *SB) { Bndls.erase(SB); }

bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
// Use a set of instructions, instead of `Instrs` for fast lookups.
DenseSet<Instruction *> InstrsToDefer(Instrs.begin(), Instrs.end());
Expand Down Expand Up @@ -133,29 +136,88 @@ bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
return false;
}

Scheduler::BndlSchedState
Scheduler::getBndlSchedState(ArrayRef<Instruction *> Instrs) const {
assert(!Instrs.empty() && "Expected non-empty bundle");
bool PartiallyScheduled = false;
bool FullyScheduled = true;
for (auto *I : Instrs) {
auto *N = DAG.getNode(I);
if (N != nullptr && N->scheduled())
PartiallyScheduled = true;
else
FullyScheduled = false;
}
if (FullyScheduled) {
// If not all instrs in the bundle are in the same SchedBundle then this
// should be considered as partially-scheduled, because we will need to
// re-schedule.
SchedBundle *SB = DAG.getNode(Instrs[0])->getSchedBundle();
assert(SB != nullptr && "FullyScheduled assumes that there is an SB!");
if (any_of(drop_begin(Instrs), [this, SB](sandboxir::Value *SBV) {
return DAG.getNode(cast<sandboxir::Instruction>(SBV))
->getSchedBundle() != SB;
}))
FullyScheduled = false;
}
return FullyScheduled ? BndlSchedState::FullyScheduled
: PartiallyScheduled ? BndlSchedState::PartiallyOrDifferentlyScheduled
: BndlSchedState::NoneScheduled;
}

void Scheduler::trimSchedule(ArrayRef<Instruction *> Instrs) {
Instruction *TopI = &*ScheduleTopItOpt.value();
Instruction *LowestI = VecUtils::getLowest(Instrs);
// Destroy the schedule bundles from LowestI all the way to the top.
for (auto *I = LowestI, *E = TopI->getPrevNode(); I != E;
I = I->getPrevNode()) {
auto *N = DAG.getNode(I);
if (auto *SB = N->getSchedBundle())
eraseBundle(SB);
}
// TODO: For now we clear the DAG. Trim view once it gets implemented.
Bndls.clear();
DAG.clear();

// Since we are scheduling NewRegion from scratch, we clear the ready lists.
// The nodes currently in the list may not be ready after clearing the View.
ReadyList.clear();
}

bool Scheduler::trySchedule(ArrayRef<Instruction *> Instrs) {
assert(all_of(drop_begin(Instrs),
[Instrs](Instruction *I) {
return I->getParent() == (*Instrs.begin())->getParent();
}) &&
"Instrs not in the same BB!");
// Extend the DAG to include Instrs.
Interval<Instruction> Extension = DAG.extend(Instrs);
// TODO: Set the window of the DAG that we are interested in.
// We start scheduling at the bottom instr of Instrs.
auto getBottomI = [](ArrayRef<Instruction *> Instrs) -> Instruction * {
return *min_element(Instrs,
[](auto *I1, auto *I2) { return I1->comesBefore(I2); });
};
ScheduleTopItOpt = std::next(getBottomI(Instrs)->getIterator());
// Add nodes to ready list.
for (auto &I : Extension) {
auto *N = DAG.getNode(&I);
if (N->ready())
ReadyList.insert(N);
auto SchedState = getBndlSchedState(Instrs);
switch (SchedState) {
case BndlSchedState::FullyScheduled:
// Nothing to do.
return true;
case BndlSchedState::PartiallyOrDifferentlyScheduled:
// If one or more instrs are already scheduled we need to destroy the
// top-most part of the schedule that includes the instrs in the bundle and
// re-schedule.
trimSchedule(Instrs);
[[fallthrough]];
case BndlSchedState::NoneScheduled: {
// TODO: Set the window of the DAG that we are interested in.
// We start scheduling at the bottom instr of Instrs.
ScheduleTopItOpt = std::next(VecUtils::getLowest(Instrs)->getIterator());

// Extend the DAG to include Instrs.
Interval<Instruction> Extension = DAG.extend(Instrs);
// Add nodes to ready list.
for (auto &I : Extension) {
auto *N = DAG.getNode(&I);
if (N->ready())
ReadyList.insert(N);
}
// Try schedule all nodes until we can schedule Instrs back-to-back.
return tryScheduleUntil(Instrs);
}
}
// Try schedule all nodes until we can schedule Instrs back-to-back.
return tryScheduleUntil(Instrs);
}

#ifndef NDEBUG
Expand Down
32 changes: 31 additions & 1 deletion llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,37 @@ define void @store_fcmp_zext_load(ptr %ptr) {
ret void
}

; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check
define void @store_fadd_load(ptr %ptr) {
; CHECK-LABEL: define void @store_fadd_load(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4
; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; CHECK-NEXT: [[FADD0:%.*]] = fadd float [[LDA0]], [[LDB0]]
; CHECK-NEXT: [[FADD1:%.*]] = fadd float [[LDA1]], [[LDB1]]
; CHECK-NEXT: [[VEC:%.*]] = fadd <2 x float> [[VECL]], [[VECL1]]
; CHECK-NEXT: store float [[FADD0]], ptr [[PTR0]], align 4
; CHECK-NEXT: store float [[FADD1]], ptr [[PTR1]], align 4
; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr float, ptr %ptr, i32 0
%ptr1 = getelementptr float, ptr %ptr, i32 1
%ldA0 = load float, ptr %ptr0
%ldA1 = load float, ptr %ptr1
%ldB0 = load float, ptr %ptr0
%ldB1 = load float, ptr %ptr1
%fadd0 = fadd float %ldA0, %ldB0
%fadd1 = fadd float %ldA1, %ldB1
store float %fadd0, ptr %ptr0
store float %fadd1, ptr %ptr1
ret void
}

define void @store_fneg_load(ptr %ptr) {
; CHECK-LABEL: define void @store_fneg_load(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,10 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
EXPECT_TRUE(Sched.trySchedule({S0}));
}
{
// Try invalid scheduling
// Try invalid scheduling. Dependency S0->S1.
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0}));
EXPECT_FALSE(Sched.trySchedule({S1}));
EXPECT_FALSE(Sched.trySchedule({S0, S1}));
}
}

Expand Down Expand Up @@ -202,3 +201,39 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
}

TEST_F(SchedulerTest, RescheduleAlreadyScheduled) {
parseIR(C, R"IR(
define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
%ld0 = load i8, ptr %ptr0
%ld1 = load i8, ptr %ptr1
%add0 = add i8 %ld0, %ld0
%add1 = add i8 %ld1, %ld1
store i8 %add0, ptr %ptr0
store i8 %add1, ptr %ptr1
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *L0 = cast<sandboxir::LoadInst>(&*It++);
auto *L1 = cast<sandboxir::LoadInst>(&*It++);
auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
// At this point Add0 and Add1 should have been individually scheduled
// as single bundles.
// Check if rescheduling works.
EXPECT_TRUE(Sched.trySchedule({Add0, Add1}));
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
}
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,32 @@ TEST_F(VecUtilsTest, GetWideType) {
auto *Int32X8Ty = sandboxir::FixedVectorType::get(Int32Ty, 8);
EXPECT_EQ(sandboxir::VecUtils::getWideType(Int32X4Ty, 2), Int32X8Ty);
}

TEST_F(VecUtilsTest, GetLowest) {
parseIR(R"IR(
define void @foo(i8 %v) {
bb0:
%A = add i8 %v, %v
%B = add i8 %v, %v
%C = add i8 %v, %v
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");

sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
auto It = BB.begin();
auto *IA = &*It++;
auto *IB = &*It++;
auto *IC = &*It++;
SmallVector<sandboxir::Instruction *> ABC({IA, IB, IC});
EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC), IC);
SmallVector<sandboxir::Instruction *> ACB({IA, IC, IB});
EXPECT_EQ(sandboxir::VecUtils::getLowest(ACB), IC);
SmallVector<sandboxir::Instruction *> CAB({IC, IA, IB});
EXPECT_EQ(sandboxir::VecUtils::getLowest(CAB), IC);
SmallVector<sandboxir::Instruction *> CBA({IC, IB, IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC);
}
Loading