From 5942a99f8b7dd361c35eb1c9c32b2475dce2c0b2 Mon Sep 17 00:00:00 2001 From: vporpo Date: Wed, 6 Nov 2024 13:26:14 -0800 Subject: [PATCH] [SandboxVec] Notify scheduler about new instructions (#115102) This patch registers the "createInstr" callback that notifies the scheduler about newly created instructions. This guarantees that all newly created instructions have a corresponding DAG node associated with them. Without this the pass crashes when the scheduler encounters the newly created vector instructions. This patch also changes the lifetime of the sandboxir Ctx variable in the SandboxVectorizer pass. It needs to be destroyed after the passes get destroyed. Without this change when components like the Scheduler get destroyed Ctx will have already been freed, which is not legal. --- .../SandboxVectorizer/DependencyGraph.h | 5 +++ .../Vectorize/SandboxVectorizer/Legality.h | 5 ++- .../SandboxVectorizer/SandboxVectorizer.h | 3 ++ .../Vectorize/SandboxVectorizer/Scheduler.h | 9 +++- .../SandboxVectorizer/Passes/BottomUpVec.cpp | 5 +-- .../SandboxVectorizer/SandboxVectorizer.cpp | 6 ++- .../SandboxVectorizer/bottomup_basic.ll | 41 ++++++++++++++++++- .../SandboxVectorizer/LegalityTest.cpp | 7 ++-- .../SandboxVectorizer/SchedulerTest.cpp | 8 ++-- 9 files changed, 72 insertions(+), 17 deletions(-) diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 5be05bc80c4925..b498e0f189465c 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -345,6 +345,11 @@ class DependencyGraph { Interval extend(ArrayRef Instrs); /// \Returns the range of instructions included in the DAG. Interval getInterval() const { return DAGInterval; } + /// Called by the scheduler when a new instruction \p I has been created. + void notifyCreateInstr(Instruction *I) { + getOrCreateNode(I); + // TODO: Update the dependencies for the new node. + } #ifndef NDEBUG void print(raw_ostream &OS) const; LLVM_DUMP_METHOD void dump() const; diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h index 58dcb2eeadbc2d..63d6ef31c86453 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -162,8 +162,9 @@ class LegalityAnalysis { const DataLayout &DL; public: - LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL) - : Sched(AA), SE(SE), DL(DL) {} + LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL, + Context &Ctx) + : Sched(AA, Ctx), SE(SE), DL(DL) {} /// A LegalityResult factory. template ResultT &createLegalityResult(ArgsT... Args) { diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h index 46b953ff9b7f49..09369dbb496fce 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h @@ -13,6 +13,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/PassManager.h" +#include "llvm/SandboxIR/Context.h" #include "llvm/SandboxIR/PassManager.h" namespace llvm { @@ -24,6 +25,8 @@ class SandboxVectorizerPass : public PassInfoMixin { AAResults *AA = nullptr; ScalarEvolution *SE = nullptr; + std::unique_ptr Ctx; + // A pipeline of SandboxIR function passes run by the vectorizer. sandboxir::FunctionPassManager FPM; diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h index 08972d460b406e..0e4eea3880efbd 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h @@ -95,6 +95,8 @@ class Scheduler { DependencyGraph DAG; std::optional ScheduleTopItOpt; SmallVector> Bndls; + Context &Ctx; + Context::CallbackID CreateInstrCB; /// \Returns a scheduling bundle containing \p Instrs. SchedBundle *createBundle(ArrayRef Instrs); @@ -110,8 +112,11 @@ class Scheduler { Scheduler &operator=(const Scheduler &) = delete; public: - Scheduler(AAResults &AA) : DAG(AA) {} - ~Scheduler() {} + Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) { + CreateInstrCB = Ctx.registerCreateInstrCallback( + [this](Instruction *I) { DAG.notifyCreateInstr(I); }); + } + ~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); } bool trySchedule(ArrayRef Instrs); diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index 37713e7da6432d..0a930d30aeab58 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -182,8 +182,6 @@ Value *BottomUpVec::vectorizeRec(ArrayRef Bndl) { } NewVec = createVectorInstr(Bndl, VecOperands); - // TODO: Notify DAG/Scheduler about new instruction - // TODO: Collect potentially dead instructions. break; } @@ -202,7 +200,8 @@ bool BottomUpVec::tryVectorize(ArrayRef Bndl) { bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) { Legality = std::make_unique( - A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout()); + A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(), + F.getContext()); Change = false; // TODO: Start from innermost BBs first for (auto &BB : F) { diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp index 790bee4a4d7f39..c22eb01d74a1cb 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp @@ -64,6 +64,9 @@ PreservedAnalyses SandboxVectorizerPass::run(Function &F, } bool SandboxVectorizerPass::runImpl(Function &LLVMF) { + if (Ctx == nullptr) + Ctx = std::make_unique(LLVMF.getContext()); + if (PrintPassPipeline) { FPM.printPipeline(outs()); return false; @@ -82,8 +85,7 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) { } // Create SandboxIR for LLVMF and run BottomUpVec on it. - sandboxir::Context Ctx(LLVMF.getContext()); - sandboxir::Function &F = *Ctx.createFunction(&LLVMF); + sandboxir::Function &F = *Ctx->createFunction(&LLVMF); sandboxir::Analyses A(*AA, *SE); return FPM.runOnFunction(F, A); } diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll index 2b9aac93b74851..45c701a18fd9bf 100644 --- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll +++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll @@ -55,7 +55,46 @@ define void @store_fpext_load(ptr %ptr) { ret void } -; TODO: Test store_zext_fcmp_load once we implement scheduler callbacks and legality diamond check +define void @store_fcmp_zext_load(ptr %ptr) { +; CHECK-LABEL: define void @store_fcmp_zext_load( +; CHECK-SAME: ptr [[PTR:%.*]]) { +; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0 +; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1 +; CHECK-NEXT: [[PTRB0:%.*]] = getelementptr i32, ptr [[PTR]], i32 0 +; CHECK-NEXT: [[PTRB1:%.*]] = getelementptr i32, ptr [[PTR]], i32 1 +; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4 +; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4 +; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[FCMP0:%.*]] = fcmp ogt float [[LDA0]], [[LDB0]] +; CHECK-NEXT: [[FCMP1:%.*]] = fcmp ogt float [[LDA1]], [[LDB1]] +; CHECK-NEXT: [[VCMP:%.*]] = fcmp ogt <2 x float> [[VECL]], [[VECL1]] +; CHECK-NEXT: [[ZEXT0:%.*]] = zext i1 [[FCMP0]] to i32 +; CHECK-NEXT: [[ZEXT1:%.*]] = zext i1 [[FCMP1]] to i32 +; CHECK-NEXT: [[VCAST:%.*]] = zext <2 x i1> [[VCMP]] to <2 x i32> +; CHECK-NEXT: store i32 [[ZEXT0]], ptr [[PTRB0]], align 4 +; CHECK-NEXT: store i32 [[ZEXT1]], ptr [[PTRB1]], align 4 +; CHECK-NEXT: store <2 x i32> [[VCAST]], ptr [[PTRB0]], align 4 +; CHECK-NEXT: ret void +; + %ptr0 = getelementptr float, ptr %ptr, i32 0 + %ptr1 = getelementptr float, ptr %ptr, i32 1 + %ptrb0 = getelementptr i32, ptr %ptr, i32 0 + %ptrb1 = getelementptr i32, ptr %ptr, i32 1 + %ldB0 = load float, ptr %ptr0 + %ldB1 = load float, ptr %ptr1 + %ldA0 = load float, ptr %ptr0 + %ldA1 = load float, ptr %ptr1 + %fcmp0 = fcmp ogt float %ldA0, %ldB0 + %fcmp1 = fcmp ogt float %ldA1, %ldB1 + %zext0 = zext i1 %fcmp0 to i32 + %zext1 = zext i1 %fcmp1 to i32 + store i32 %zext0, ptr %ptrb0 + store i32 %zext1, ptr %ptrb1 + ret void +} ; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp index 51e7a14013299b..b5e2c302f5901e 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -110,7 +110,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float auto *CmpSLT = cast(&*It++); auto *CmpSGT = cast(&*It++); - sandboxir::LegalityAnalysis Legality(*AA, *SE, DL); + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); const auto &Result = Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); @@ -228,7 +228,7 @@ define void @foo(ptr %ptr) { auto *St0 = cast(&*It++); auto *St1 = cast(&*It++); - sandboxir::LegalityAnalysis Legality(*AA, *SE, DL); + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); { // Can vectorize St0,St1. const auto &Result = Legality.canVectorize({St0, St1}); @@ -262,7 +262,8 @@ define void @foo() { return Buff == ExpectedStr; }; - sandboxir::LegalityAnalysis Legality(*AA, *SE, DL); + sandboxir::Context Ctx(C); + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); EXPECT_TRUE( Matches(Legality.createLegalityResult(), "Widen")); EXPECT_TRUE(Matches(Legality.createLegalityResult( diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp index 92e767e55fbddb..4a8b0ba1d7c12b 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp @@ -156,20 +156,20 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { { // Schedule all instructions in sequence. - sandboxir::Scheduler Sched(getAA(*LLVMF)); + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); EXPECT_TRUE(Sched.trySchedule({Ret})); EXPECT_TRUE(Sched.trySchedule({S1})); EXPECT_TRUE(Sched.trySchedule({S0})); } { // Skip instructions. - sandboxir::Scheduler Sched(getAA(*LLVMF)); + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); EXPECT_TRUE(Sched.trySchedule({Ret})); EXPECT_TRUE(Sched.trySchedule({S0})); } { // Try invalid scheduling - sandboxir::Scheduler Sched(getAA(*LLVMF)); + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); EXPECT_TRUE(Sched.trySchedule({Ret})); EXPECT_TRUE(Sched.trySchedule({S0})); EXPECT_FALSE(Sched.trySchedule({S1})); @@ -197,7 +197,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) { auto *S1 = cast(&*It++); auto *Ret = cast(&*It++); - sandboxir::Scheduler Sched(getAA(*LLVMF)); + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); EXPECT_TRUE(Sched.trySchedule({Ret})); EXPECT_TRUE(Sched.trySchedule({S0, S1})); EXPECT_TRUE(Sched.trySchedule({L0, L1}));