Skip to content

Commit

Permalink
Make GPUInterpreter extensible
Browse files Browse the repository at this point in the history
Currently Enzyme uses its own AbstractInterpreter, in particular to
handle inlining blocking of functions with custom rules and to handle
nested autodiff operations.

- [ ] Create a version of Enzyme with this
- [ ] Support a version of `gpuc.deferred(meta)`
  • Loading branch information
vchuravy committed Sep 27, 2024
1 parent 828ee63 commit 9af6e00
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 24 deletions.
9 changes: 5 additions & 4 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end
## deferred compilation

"""
var"gpuc.deferred"(f, args...)::Ptr{Cvoid}
var"gpuc.deferred"(meta, f, args...)::Ptr{Cvoid}
As if we were to call `f(args...)` but instead we are
putting down a marker and returning a function pointer to later
Expand Down Expand Up @@ -199,18 +199,19 @@ const __llvm_initialized = Ref(false)
return val
end

worklist = Dict{Any, Vector{LLVM.CallInst}}()
worklist = Dict{MethodInstance, Vector{LLVM.CallInst}}()
for use in uses(dyn_marker)
# decode the call
call = user(use)::LLVM.CallInst
dyn_mi_inst = find_base_object(operands(call)[1])
dyn_mi_inst = find_base_object(operands(call)[2])
@compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
dyn_mi = Base.unsafe_pointer_to_objref(
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))::MethodInstance
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
end

for dyn_mi in keys(worklist)
# TODO: Should compiled become Edge[]
dyn_fn_name = compiled[dyn_mi].specfunc
dyn_fn = functions(ir)[dyn_fn_name]

Expand Down
4 changes: 2 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
# provide a specific interpreter to use.
if VERSION >= v"1.11.0-DEV.1552"
get_interpreter(@nospecialize(job::CompilerJob)) =
GPUInterpreter(job.world; method_table=method_table(job),
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
token=ci_cache_token(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
else
get_interpreter(@nospecialize(job::CompilerJob)) =
GPUInterpreter(job.world; method_table=method_table(job),
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
code_cache=ci_cache(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
end
Expand Down
206 changes: 189 additions & 17 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ else
end

struct GPUInterpreter <: CC.AbstractInterpreter
meta::Any
world::UInt
method_table::GPUMethodTableView

Expand All @@ -336,6 +337,7 @@ end

@static if HAS_INTEGRATED_CACHE
function GPUInterpreter(world::UInt=Base.get_world_counter();
meta = nothing,
method_table::MTType,
token::Any,
inf_params::CC.InferenceParams,
Expand All @@ -345,26 +347,28 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
method_table = get_method_table_view(world, method_table)
inf_cache = Vector{CC.InferenceResult}()

return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
token, inf_cache,
inf_params, opt_params)
end

# Copy constructor: derive a new GPUInterpreter from an existing one, with
# keyword overrides for individual fields (integrated-cache variant).
# NOTE: a stale pre-change line (`return GPUInterpreter(world, method_table,`)
# was interleaved with the updated call and has been removed.
function GPUInterpreter(interp::GPUInterpreter;
                        meta=interp.meta,
                        world::UInt=interp.world,
                        method_table::GPUMethodTableView=interp.method_table,
                        token::Any=interp.token,
                        inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
                        inf_params::CC.InferenceParams=interp.inf_params,
                        opt_params::CC.OptimizationParams=interp.opt_params)
    return GPUInterpreter(meta, world, method_table,
                          token, inf_cache,
                          inf_params, opt_params)
end

else

function GPUInterpreter(world::UInt=Base.get_world_counter();
meta=nothing,
method_table::MTType,
code_cache::CodeCache,
inf_params::CC.InferenceParams,
Expand All @@ -374,19 +378,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
method_table = get_method_table_view(world, method_table)
inf_cache = Vector{CC.InferenceResult}()

return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
code_cache, inf_cache,
inf_params, opt_params)
end

# Copy constructor: derive a new GPUInterpreter from an existing one, with
# keyword overrides for individual fields (external code_cache variant).
# NOTE: a stale pre-change line (`return GPUInterpreter(world, method_table,`)
# was interleaved with the updated call and has been removed.
function GPUInterpreter(interp::GPUInterpreter;
                        meta=interp.meta,
                        world::UInt=interp.world,
                        method_table::GPUMethodTableView=interp.method_table,
                        code_cache::CodeCache=interp.code_cache,
                        inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
                        inf_params::CC.InferenceParams=interp.inf_params,
                        opt_params::CC.OptimizationParams=interp.opt_params)
    return GPUInterpreter(meta, world, method_table,
                          code_cache, inf_cache,
                          inf_params, opt_params)
end
Expand Down Expand Up @@ -437,28 +442,76 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
end


# Returns `false` when executed normally. During abstract interpretation the
# GPUInterpreter constant-folds this call to `true` (see `abstract_call_known`),
# so code can query whether it is being compiled by GPUCompiler.
function within_gpucompiler()
    return false
end

## deferred compilation

# Call info recorded for a `gpuc.deferred(meta, f, args...)` marker call during
# abstract interpretation, so the deferred callee can be lowered later by
# `CC.handle_call!` and compiled separately.
struct DeferredCallInfo <: CC.CallInfo
    meta::Any          # compile-time value of the user-supplied `meta` argument
    rt::DataType       # inferred return type of the underlying `f(args...)` call
    info::CC.CallInfo  # wrapped call info for the underlying call
end

# Recognize calls to `gpuc.deferred` (and `within_gpucompiler`) and save
# `DeferredCallInfo` metadata. Default implementation, extensible through the
# `meta` argument; returns `nothing` for calls it does not handle.
# XXX: (or should we dispatch on `f`)?
# NOTE: stale pre-change diff lines (old `CC.abstract_call_known` header,
# `argtypes[2:end]`, and 2-arg `DeferredCallInfo`) were interleaved here and
# have been removed to restore the post-change definition.
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, @nospecialize(f),
                             arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
                             max_methods::Int = CC.get_max_methods(interp, f, sv))
    (; fargs, argtypes) = arginfo
    if f === var"gpuc.deferred"
        # Infer the underlying call, skipping the marker (1) and meta (2) arguments.
        argvec = argtypes[3:end]
        call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)
        # Recover the compile-time value of the `meta` argument.
        metaT = argtypes[2]
        meta = CC.singleton_type(metaT)
        if meta === nothing
            if metaT isa Core.Const
                meta = metaT.val
            else
                # meta is not a singleton type; the result may depend on runtime configuration
                add_remark!(interp, sv, "Skipped gpuc.deferred since meta not constant")
                @static if VERSION < v"1.11.0-"
                    return CC.CallMeta(Union{}, CC.Effects(), CC.NoCallInfo())
                else
                    return CC.CallMeta(Union{}, Union{}, CC.Effects(), CC.NoCallInfo())
                end
            end
        end

        # The marker itself evaluates to a function pointer at runtime.
        callinfo = DeferredCallInfo(meta, call.rt, call.info)
        @static if VERSION < v"1.11.0-"
            return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
        else
            return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
        end
    elseif f === within_gpucompiler
        if length(argtypes) != 1
            @static if VERSION < v"1.11.0-"
                return CC.CallMeta(Union{}, CC.Effects(), CC.NoCallInfo())
            else
                return CC.CallMeta(Union{}, Union{}, CC.Effects(), CC.NoCallInfo())
            end
        end
        # Constant-fold to `true`: we are interpreting under the GPU compiler.
        @static if VERSION < v"1.11.0-"
            return CC.CallMeta(Core.Const(true), CC.EFFECTS_TOTAL, CC.MethodResultPure())
        else
            return CC.CallMeta(Core.Const(true), Union{}, CC.EFFECTS_TOTAL, CC.MethodResultPure(),)
        end
    end
    return nothing  # not a special call; caller falls back to the default behavior
end

function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int = CC.get_max_methods(interp, f, sv))
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
if candidate === nothing && interp.meta !== nothing
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
end
if candidate !== nothing
return candidate
end

return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int)
Expand All @@ -485,32 +538,39 @@ function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::CC.Int
args = Any[
"extern gpuc.lookup",
Ptr{Cvoid},
Core.svec(Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
Core.svec(Any, Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
0,
QuoteNode(:llvmcall),
info.meta,
case.invoke,
stmt.args[2:end]...
stmt.args[3:end]...
]
stmt.head = :foreigncall
stmt.args = args
return nothing
end

# A deferred-compilation edge: a method instance that still needs to be
# compiled, together with the metadata it was requested with via `gpuc.deferred`.
struct Edge
    meta::Any           # metadata propagated from the `gpuc.deferred` call site
    mi::MethodInstance  # callee to compile
end

# Collection of deferred-compilation edges discovered in a function's IR
# (attached by `CC.finish` and consumed during compilation).
# NOTE: the stale pre-change field line `edges::Vector{MethodInstance}` was
# interleaved with the updated field and has been removed.
struct DeferredEdges
    edges::Vector{Edge}
end

function find_deferred_edges(ir::CC.IRCode)
edges = MethodInstance[]
edges = Edge[]
# XXX: can we add this instead in handle_call?
for stmt in ir.stmts
inst = stmt[:inst]
inst isa Expr || continue
expr = inst::Expr
if expr.head === :foreigncall &&
expr.args[1] == "extern gpuc.lookup"
deferred_mi = expr.args[6]
push!(edges, deferred_mi)
deferred_meta = expr.args[6]
deferred_mi = expr.args[7]
push!(edges, Edge(deferred_meta, deferred_mi))
end
end
unique!(edges)
Expand Down Expand Up @@ -542,6 +602,116 @@ function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRC
end
end

import .CC: CallInfo
# Wrapper call info marking a call as not-inlinable; the `inlining_policy` /
# `src_inlining_policy` overloads below block inlining when they see it.
struct NoInlineCallInfo <: CallInfo
    info::CallInfo # wrapped call
    tt::Any # ::Type
    kind::Symbol # presumably tags why inlining is blocked — TODO confirm against producers
    NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) =
        new(info, tt, kind)
end

# Forward the CallInfo query API to the wrapped call info.
CC.nsplit_impl(info::NoInlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::NoInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::NoInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
# Wrapper call info marking a call as must-inline; the `inlining_policy` /
# `src_inlining_policy` overloads below force inlining when they see it.
struct AlwaysInlineCallInfo <: CallInfo
    info::CallInfo # wrapped call
    tt::Any # ::Type
    AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt)
end

# Forward the CallInfo query API to the wrapped call info.
# Consistency fix: use the `CC` alias throughout instead of mixing in a fully
# qualified `Core.Compiler.nsplit`, matching the NoInlineCallInfo methods above.
CC.nsplit_impl(info::AlwaysInlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)


# Default inlining handler (no metadata): expresses no opinion. Returning
# `nothing` makes `CC.abstract_call_gf_by_type` keep the original call info.
inlining_handler(meta::Nothing, interp::GPUInterpreter, @nospecialize(atype), callinfo) = nothing

using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
# Intercept generic-call inference so an `inlining_handler` (dispatched on
# `interp.meta`, falling back to the `meta::Nothing` default) can replace the
# call info, e.g. with `NoInlineCallInfo`/`AlwaysInlineCallInfo`.
function CC.abstract_call_gf_by_type(interp::GPUInterpreter, @nospecialize(f), arginfo::ArgInfo,
                                     si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int)
    ret = @invoke CC.abstract_call_gf_by_type(interp::CC.AbstractInterpreter, f::Any, arginfo::ArgInfo,
                                              si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int)

    # Give the meta-specific handler first crack, then the default handler,
    # then keep the unmodified call info.
    callinfo = nothing
    if interp.meta !== nothing
        callinfo = inlining_handler(interp.meta, interp, atype, ret.info)
    end
    if callinfo === nothing
        callinfo = inlining_handler(nothing, interp, atype, ret.info)
    end
    if callinfo === nothing
        callinfo = ret.info
    end

    # Fix: the comparison operator was garbled in the source (`VERSION v"1.11-"`,
    # a syntax error); restored to `>=`, matching the 4-arg CallMeta on 1.11+.
    @static if VERSION >= v"1.11-"
        return CC.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
    else
        return CC.CallMeta(ret.rt, ret.effects, callinfo)
    end
end

# Inlining policy overloads: block inlining for calls tagged `NoInlineCallInfo`
# and force inlining for `AlwaysInlineCallInfo`. Pre-1.12 the hook is
# `CC.inlining_policy` (whose signature differs across 1.x versions, hence the
# @eval over version-specific signature expressions); 1.12+ uses the boolean
# `CC.src_inlining_policy`.
@static if VERSION < v"1.12.0-DEV.45"
let # overload `inlining_policy`
    # Fix: the inner comparison operator was garbled in the source
    # (`VERSION v"1.11.0-DEV.879"`); restored to `>=`.
    @static if VERSION >= v"1.11.0-DEV.879"
        sigs_ex = :(
            interp::GPUInterpreter,
            @nospecialize(src),
            @nospecialize(info::CC.CallInfo),
            stmt_flag::UInt32,
        )
        args_ex = :(
            interp::CC.AbstractInterpreter,
            src::Any,
            info::CC.CallInfo,
            stmt_flag::UInt32,
        )
    else
        sigs_ex = :(
            interp::GPUInterpreter,
            @nospecialize(src),
            @nospecialize(info::CC.CallInfo),
            stmt_flag::UInt8,
            mi::MethodInstance,
            argtypes::Vector{Any},
        )
        args_ex = :(
            interp::CC.AbstractInterpreter,
            src::Any,
            info::CC.CallInfo,
            stmt_flag::UInt8,
            mi::MethodInstance,
            argtypes::Vector{Any},
        )
    end
    @eval function CC.inlining_policy($(sigs_ex.args...))
        if info isa NoInlineCallInfo
            @safe_debug "Blocking inlining" info.tt info.kind
            return nothing   # `nothing` tells the optimizer not to inline
        elseif info isa AlwaysInlineCallInfo
            @safe_debug "Forcing inlining for" info.tt
            return src       # returning the source forces inlining
        end
        return @invoke CC.inlining_policy($(args_ex.args...))
    end
end
else
function CC.src_inlining_policy(interp::GPUInterpreter,
                                @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::UInt32)
    if info isa NoInlineCallInfo
        @safe_debug "Blocking inlining" info.tt info.kind
        return false
    elseif info isa AlwaysInlineCallInfo
        @safe_debug "Forcing inlining for" info.tt
        return true
    end
    return @invoke CC.src_inlining_policy(interp::CC.AbstractInterpreter, src, info::CC.CallInfo, stmt_flag::UInt32)
end
end


## world view of the cache
using Core.Compiler: WorldView
Expand Down Expand Up @@ -704,7 +874,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
source = pop!(worklist)
haskey(compiled, source) && continue # We have fulfilled the request already
# Create a new compiler job for this edge, reusing the config settings from the inital one
job2 = CompilerJob(source, job.config)
job2 = CompilerJob(source, job.config) # TODO: GPUInterpreter.meta in config?
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
append!(worklist, outstanding) # merge worklist with new outstanding edges
@assert context(llvm_mod) == context(llvm_mod2)
Expand Down Expand Up @@ -844,7 +1014,9 @@ function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDi
end
end
if edges !== nothing
for deferred_mi in (edges::DeferredEdges).edges
for edge in (edges::DeferredEdges).edges
# TODO
deferred_mi = edge.mi
if !haskey(compiled, deferred_mi)
push!(outstanding, deferred_mi)
end
Expand Down
2 changes: 1 addition & 1 deletion test/native_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ end
@testset "deferred" begin
@gensym child kernel unrelated
@eval @noinline $child(i) = i
@eval $kernel(i) = GPUCompiler.var"gpuc.deferred"($child, i)
@eval $kernel(i) = GPUCompiler.var"gpuc.deferred"(nothing, $child, i)

# smoke test
job, _ = Native.create_job(eval(kernel), (Int64,))
Expand Down
Loading

0 comments on commit 9af6e00

Please sign in to comment.