Skip to content

Commit

Permalink
[SandboxVec] Notify scheduler about new instructions (#115102)
Browse files Browse the repository at this point in the history
This patch registers the "createInstr" callback that notifies the
scheduler about newly created instructions. This guarantees that all
newly created instructions have a corresponding DAG node associated with
them. Without this the pass crashes when the scheduler encounters the
newly created vector instructions.

This patch also changes the lifetime of the sandboxir Ctx variable in
the SandboxVectorizer pass. It needs to be destroyed after the passes
get destroyed. Without this change when components like the Scheduler
get destroyed Ctx will have already been freed, which is not legal.
  • Loading branch information
vporpo authored Nov 6, 2024
1 parent a878dc8 commit 5942a99
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,11 @@ class DependencyGraph {
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
/// \Returns the range of instructions included in the DAG.
Interval<Instruction> getInterval() const { return DAGInterval; }
/// Called by the scheduler when a new instruction \p I has been created.
void notifyCreateInstr(Instruction *I) {
getOrCreateNode(I);
// TODO: Update the dependencies for the new node.
}
#ifndef NDEBUG
void print(raw_ostream &OS) const;
LLVM_DUMP_METHOD void dump() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,9 @@ class LegalityAnalysis {
const DataLayout &DL;

public:
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL)
: Sched(AA), SE(SE), DL(DL) {}
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
Context &Ctx)
: Sched(AA, Ctx), SE(SE), DL(DL) {}
/// A LegalityResult factory.
template <typename ResultT, typename... ArgsT>
ResultT &createLegalityResult(ArgsT... Args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/PassManager.h"
#include "llvm/SandboxIR/Context.h"
#include "llvm/SandboxIR/PassManager.h"

namespace llvm {
Expand All @@ -24,6 +25,8 @@ class SandboxVectorizerPass : public PassInfoMixin<SandboxVectorizerPass> {
AAResults *AA = nullptr;
ScalarEvolution *SE = nullptr;

std::unique_ptr<sandboxir::Context> Ctx;

// A pipeline of SandboxIR function passes run by the vectorizer.
sandboxir::FunctionPassManager FPM;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class Scheduler {
DependencyGraph DAG;
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
SmallVector<std::unique_ptr<SchedBundle>> Bndls;
Context &Ctx;
Context::CallbackID CreateInstrCB;

/// \Returns a scheduling bundle containing \p Instrs.
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
Expand All @@ -110,8 +112,11 @@ class Scheduler {
Scheduler &operator=(const Scheduler &) = delete;

public:
Scheduler(AAResults &AA) : DAG(AA) {}
~Scheduler() {}
Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
CreateInstrCB = Ctx.registerCreateInstrCallback(
[this](Instruction *I) { DAG.notifyCreateInstr(I); });
}
~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }

bool trySchedule(ArrayRef<Instruction *> Instrs);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
}
NewVec = createVectorInstr(Bndl, VecOperands);

// TODO: Notify DAG/Scheduler about new instruction

// TODO: Collect potentially dead instructions.
break;
}
Expand All @@ -202,7 +200,8 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {

bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
Legality = std::make_unique<LegalityAnalysis>(
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout());
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
F.getContext());
Change = false;
// TODO: Start from innermost BBs first
for (auto &BB : F) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ PreservedAnalyses SandboxVectorizerPass::run(Function &F,
}

bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
if (Ctx == nullptr)
Ctx = std::make_unique<sandboxir::Context>(LLVMF.getContext());

if (PrintPassPipeline) {
FPM.printPipeline(outs());
return false;
Expand All @@ -82,8 +85,7 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
}

// Create SandboxIR for LLVMF and run BottomUpVec on it.
sandboxir::Context Ctx(LLVMF.getContext());
sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
sandboxir::Function &F = *Ctx->createFunction(&LLVMF);
sandboxir::Analyses A(*AA, *SE);
return FPM.runOnFunction(F, A);
}
41 changes: 40 additions & 1 deletion llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,46 @@ define void @store_fpext_load(ptr %ptr) {
ret void
}

; TODO: Test store_zext_fcmp_load once we implement scheduler callbacks and legality diamond check
define void @store_fcmp_zext_load(ptr %ptr) {
; CHECK-LABEL: define void @store_fcmp_zext_load(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
; CHECK-NEXT: [[PTRB0:%.*]] = getelementptr i32, ptr [[PTR]], i32 0
; CHECK-NEXT: [[PTRB1:%.*]] = getelementptr i32, ptr [[PTR]], i32 1
; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4
; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; CHECK-NEXT: [[FCMP0:%.*]] = fcmp ogt float [[LDA0]], [[LDB0]]
; CHECK-NEXT: [[FCMP1:%.*]] = fcmp ogt float [[LDA1]], [[LDB1]]
; CHECK-NEXT: [[VCMP:%.*]] = fcmp ogt <2 x float> [[VECL]], [[VECL1]]
; CHECK-NEXT: [[ZEXT0:%.*]] = zext i1 [[FCMP0]] to i32
; CHECK-NEXT: [[ZEXT1:%.*]] = zext i1 [[FCMP1]] to i32
; CHECK-NEXT: [[VCAST:%.*]] = zext <2 x i1> [[VCMP]] to <2 x i32>
; CHECK-NEXT: store i32 [[ZEXT0]], ptr [[PTRB0]], align 4
; CHECK-NEXT: store i32 [[ZEXT1]], ptr [[PTRB1]], align 4
; CHECK-NEXT: store <2 x i32> [[VCAST]], ptr [[PTRB0]], align 4
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr float, ptr %ptr, i32 0
%ptr1 = getelementptr float, ptr %ptr, i32 1
%ptrb0 = getelementptr i32, ptr %ptr, i32 0
%ptrb1 = getelementptr i32, ptr %ptr, i32 1
%ldB0 = load float, ptr %ptr0
%ldB1 = load float, ptr %ptr1
%ldA0 = load float, ptr %ptr0
%ldA1 = load float, ptr %ptr1
%fcmp0 = fcmp ogt float %ldA0, %ldB0
%fcmp1 = fcmp ogt float %ldA1, %ldB1
%zext0 = zext i1 %fcmp0 to i32
%zext1 = zext i1 %fcmp1 to i32
store i32 %zext0, ptr %ptrb0
store i32 %zext1, ptr %ptrb1
ret void
}

; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);

sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
const auto &Result =
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
Expand Down Expand Up @@ -228,7 +228,7 @@ define void @foo(ptr %ptr) {
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);

sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
{
// Can vectorize St0,St1.
const auto &Result = Legality.canVectorize({St0, St1});
Expand Down Expand Up @@ -262,7 +262,8 @@ define void @foo() {
return Buff == ExpectedStr;
};

sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
sandboxir::Context Ctx(C);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,20 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {

{
// Schedule all instructions in sequence.
sandboxir::Scheduler Sched(getAA(*LLVMF));
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S1}));
EXPECT_TRUE(Sched.trySchedule({S0}));
}
{
// Skip instructions.
sandboxir::Scheduler Sched(getAA(*LLVMF));
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0}));
}
{
// Try invalid scheduling
sandboxir::Scheduler Sched(getAA(*LLVMF));
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0}));
EXPECT_FALSE(Sched.trySchedule({S1}));
Expand Down Expand Up @@ -197,7 +197,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::Scheduler Sched(getAA(*LLVMF));
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
Expand Down

0 comments on commit 5942a99

Please sign in to comment.