From 40e6ad97e0f5a139f589e78cccd12a0015cc6840 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 26 Aug 2024 15:30:50 +0200 Subject: [PATCH] Add optimization callbacks that fire on a marker function --- src/optim.jl | 32 +++++++++++++++++++++++++++++++- test/plugin_testsetup.jl | 30 ++++++++++++++++++++++++++++++ test/ptx_tests.jl | 16 ++++++++++++++++ test/ptx_testsetup.jl | 1 - 4 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 test/plugin_testsetup.jl diff --git a/src/optim.jl b/src/optim.jl index cd561a5c..de83dd45 100644 --- a/src/optim.jl +++ b/src/optim.jl @@ -3,7 +3,7 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=1) tm = llvm_machine(job.config.target) - global current_job + global current_job # ScopedValue? current_job = job @dispose pb=NewPMPassBuilder() begin @@ -14,6 +14,10 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level= register!(pb, LowerKernelStatePass()) register!(pb, CleanupKernelStatePass()) + for (name, callback) in PIPELINE_CALLBACKS + register!(pb, CallbackPass(name, callback)) + end + add!(pb, NewPMModulePassManager()) do mpm buildNewPMPipeline!(mpm, job, opt_level) end @@ -24,6 +28,15 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level= return end +# TODO: Priority heap to provide order between different plugins +const PIPELINE_CALLBACKS = Dict{String, Any}() +function register_plugin!(name::String, plugin) + if haskey(PIPELINE_CALLBACKS, name) + error("GPUCompiler plugin with name $name is already registered") + end + PIPELINE_CALLBACKS[name] = plugin +end + function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level) buildEarlySimplificationPipeline(mpm, job, opt_level) add!(mpm, AlwaysInlinerPass()) @@ -41,6 +54,9 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level) add!(fpm, WarnMissedTransformationsPass()) end end + for (name, callback) in PIPELINE_CALLBACKS + add!(mpm, CallbackPass(name, callback)) + end buildIntrinsicLoweringPipeline(mpm, job, opt_level) buildCleanupPipeline(mpm, job, opt_level) end @@ -423,3 +439,17 @@ function lower_ptls!(mod::LLVM.Module) return changed end LowerPTLSPass() = NewPMModulePass("GPULowerPTLS", lower_ptls!) + + +function callback_pass!(name, callback::F, mod::LLVM.Module) where F + job = current_job::CompilerJob + changed = false + + if haskey(functions(mod), name) + marker = functions(mod)[name] + changed = callback(job, marker, mod) + end + return changed +end + +CallbackPass(name, callback) = NewPMModulePass("CallbackPass<$name>", (mod)->callback_pass!(name, callback, mod)) diff --git a/test/plugin_testsetup.jl b/test/plugin_testsetup.jl new file mode 100644 index 00000000..fdab7ee5 --- /dev/null +++ b/test/plugin_testsetup.jl @@ -0,0 +1,30 @@ +@testsetup module Plugin + +using Test +using ReTestItems +import LLVM +import GPUCompiler + +function mark(x) + ccall("extern gpucompiler.mark", llvmcall, Nothing, (Int,), x) +end + +function remove_mark!(@nospecialize(job), intrinsic, mod::LLVM.Module) + changed = false + + for use in LLVM.uses(intrinsic) + val = LLVM.user(use) + if isempty(LLVM.uses(val)) + LLVM.erase!(val) + changed = true + else + # the validator will detect this + end + end + + return changed +end + +GPUCompiler.register_plugin!("gpucompiler.mark", remove_mark!) + +end \ No newline at end of file diff --git a/test/ptx_tests.jl b/test/ptx_tests.jl index 6caa6c71..600561f5 100644 --- a/test/ptx_tests.jl +++ b/test/ptx_tests.jl @@ -1,6 +1,7 @@ @testitem "PTX" setup=[PTX, Helpers] begin using LLVM +import InteractiveUtils ############################################################################################ @@ -406,7 +407,22 @@ precompile_test_harness("Inference caching") do load_path @test check_presence(identity_mi, token) end end +end # testitem ############################################################################################ +@testitem "PTX plugin" setup=[PTX, Plugin] begin + +import InteractiveUtils + +@testset "Pipeline callbacks" begin + function kernel(x) + Plugin.mark(x) + return + end + ir = sprint(io->InteractiveUtils.code_llvm(io, kernel, Tuple{Int})) + @test occursin("gpucompiler.mark", ir) + ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Int})) + @test !occursin("gpucompiler.mark", ir) end +end #testitem diff --git a/test/ptx_testsetup.jl b/test/ptx_testsetup.jl index ed5026f1..89516283 100644 --- a/test/ptx_testsetup.jl +++ b/test/ptx_testsetup.jl @@ -2,7 +2,6 @@ using GPUCompiler - # create a PTX-based test compiler, and generate reflection methods for it include("runtime.jl")