diff --git a/src/compiler.jl b/src/compiler.jl index 32dd293ece..feccbe0403 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4048,7 +4048,7 @@ if VERSION >= v"1.11.0-DEV.1552" always_inline::Any method_table::Core.MethodTable param_type::Type - is_fwd::Bool + mode::API.CDerivativeMode end GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = @@ -4057,15 +4057,15 @@ if VERSION >= v"1.11.0-DEV.1552" job.config.always_inline, GPUCompiler.method_table(job), typeof(job.config.params), - job.config.params.mode == API.DEM_ForwardMode, + API.DEM_ForwardMode, ) GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - Interpreter.EnzymeInterpreter( + GPUCompiler.GPUInterpreter( GPUCompiler.ci_cache_token(job), GPUCompiler.method_table(job), job.world, - job.config.params.mode, + meta=Interpreter.EnzymeMeta(job.config.params.mode), ) else @@ -4074,6 +4074,7 @@ else # rule or not inlining a rev mode rule. Otherwise, all caches can be re-used. const GLOBAL_FWD_CACHE = GPUCompiler.CodeCache() const GLOBAL_REV_CACHE = GPUCompiler.CodeCache() + # TODO: Branch on target... otherwise GPU and CPU code end in the same cache function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) return if job.config.params.mode == API.DEM_ForwardMode GLOBAL_FWD_CACHE @@ -4086,11 +4087,11 @@ else enzyme_ci_cache(job) GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - Interpreter.EnzymeInterpreter( + GPUCompiler.GPUInterpreter( enzyme_ci_cache(job), GPUCompiler.method_table(job), job.world, - job.config.params.mode, + meta=Interpreter.EnzymeMeta(job.config.params.mode), ) end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 4d48297ae5..907f80a29b 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -1,18 +1,5 @@ module Interpreter import Enzyme: API -using Core.Compiler: - AbstractInterpreter, - InferenceResult, - InferenceParams, - InferenceState, - OptimizationParams, - MethodInstance -using GPUCompiler: @safe_debug -if VERSION < v"1.11.0-DEV.1552" - using GPUCompiler: CodeCache, WorldView, @safe_debug -end -const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552" - import ..Enzyme import ..EnzymeRules @@ -22,93 +9,10 @@ import ..EnzymeRules else import Core.Compiler: get_world_counter, get_world_counter as get_inference_world end - -struct EnzymeInterpreter <: AbstractInterpreter - @static if HAS_INTEGRATED_CACHE - token::Any - else - code_cache::CodeCache - end - method_table::Union{Nothing,Core.MethodTable} - - # Cache of inference results for this particular interpreter - local_cache::Vector{InferenceResult} - # The world age we're working inside of - world::UInt - - # Parameters for inference and optimization - inf_params::InferenceParams - opt_params::OptimizationParams - +struct EnzymeMeta mode::API.CDerivativeMode end -function EnzymeInterpreter( - cache_or_token, - mt::Union{Nothing,Core.MethodTable}, - world::UInt, - mode::API.CDerivativeMode, -) - @assert world <= Base.get_world_counter() - - parms = @static if VERSION < v"1.12" - InferenceParams(unoptimize_throw_blocks = false) - else - InferenceParams() - end - - return EnzymeInterpreter( - cache_or_token, - mt, - - # Initially empty cache - Vector{InferenceResult}(), - - # world age counter - world, - - # parameters for inference and optimization - parms, - OptimizationParams(), - mode, - ) -end - -Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params -Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params -get_inference_world(interp::EnzymeInterpreter) = interp.world -Core.Compiler.get_inference_cache(interp::EnzymeInterpreter) = interp.local_cache -@static if HAS_INTEGRATED_CACHE - Core.Compiler.cache_owner(interp::EnzymeInterpreter) = interp.token -else - Core.Compiler.code_cache(interp::EnzymeInterpreter) = - WorldView(interp.code_cache, interp.world) -end - -# No need to do any locking since we're not putting our results into the runtime cache -Core.Compiler.lock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing -Core.Compiler.unlock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing - -Core.Compiler.may_optimize(::EnzymeInterpreter) = true -Core.Compiler.may_compress(::EnzymeInterpreter) = true -# From @aviatesk: -# `may_discard_trees = true`` means a complicated (in terms of inlineability) source will be discarded, -# but as far as I understand Enzyme wants "always inlining, except special cased functions", -# so I guess we really don't want to discard sources? -Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false -Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false - -if isdefined(Base.Experimental, Symbol("@overlay")) - Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = - Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) -else - - # On 1.6- CUDA.jl will poison the method table at the end of the world - # using GPUCompiler: WorldOverlayMethodTable - # Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = - # WorldOverlayMethodTable(interp.world) -end - function is_alwaysinline_func(@nospecialize(TT)) isa(TT, DataType) || return false return false @@ -149,153 +53,41 @@ function simplify_kw(@nospecialize specTypes) end end -import Core.Compiler: CallInfo -struct NoInlineCallInfo <: CallInfo - info::CallInfo # wrapped call - tt::Any # ::Type - kind::Symbol - NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) = - new(info, tt, kind) -end -Core.Compiler.nsplit_impl(info::NoInlineCallInfo) = Core.Compiler.nsplit(info.info) -Core.Compiler.getsplit_impl(info::NoInlineCallInfo, idx::Int) = - Core.Compiler.getsplit(info.info, idx) -Core.Compiler.getresult_impl(info::NoInlineCallInfo, idx::Int) = - Core.Compiler.getresult(info.info, idx) -struct AlwaysInlineCallInfo <: CallInfo - info::CallInfo # wrapped call - tt::Any # ::Type - AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt) -end -Core.Compiler.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info) -Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = - Core.Compiler.getsplit(info.info, idx) -Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = - Core.Compiler.getresult(info.info, idx) - -using Core.Compiler: ArgInfo, StmtInfo, AbsIntState -function Core.Compiler.abstract_call_gf_by_type( - interp::EnzymeInterpreter, - @nospecialize(f), - arginfo::ArgInfo, - si::StmtInfo, - @nospecialize(atype), - sv::AbsIntState, - max_methods::Int, -) - ret = @invoke Core.Compiler.abstract_call_gf_by_type( - interp::AbstractInterpreter, - f::Any, - arginfo::ArgInfo, - si::StmtInfo, - atype::Any, - sv::AbsIntState, - max_methods::Int, - ) - callinfo = ret.info +import GPUCompiler: GPUInterpreter, NoInlineCallInfo, AlwaysInlineCallInfo +function inlining_handler(meta::EnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo) method_table = Core.Compiler.method_table(interp) + world = get_inference_world(interp) + specTypes = simplify_kw(atype) if is_primitive_func(specTypes) - callinfo = NoInlineCallInfo(callinfo, atype, :primitive) + return NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) - callinfo = AlwaysInlineCallInfo(callinfo, atype) - elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) - elseif interp.mode == API.DEM_ForwardMode - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) - callinfo = NoInlineCallInfo(callinfo, atype, :frule) + return AlwaysInlineCallInfo(callinfo, atype) + elseif EnzymeRules.is_inactive_from_sig(specTypes; world, method_table) + return NoInlineCallInfo(callinfo, atype, :inactive) + elseif meta.mode == API.DEM_ForwardMode + if EnzymeRules.has_frule_from_sig(specTypes; world, method_table) + return NoInlineCallInfo(callinfo, atype, :frule) end - elseif EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) - end - @static if VERSION ≥ v"1.11-" - return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) - else - return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo) - end -end - -let # overload `inlining_policy` - @static if VERSION ≥ v"1.11.0-DEV.879" - sigs_ex = :( - interp::EnzymeInterpreter, - @nospecialize(src), - @nospecialize(info::Core.Compiler.CallInfo), - stmt_flag::UInt32, - ) - args_ex = :( - interp::AbstractInterpreter, - src::Any, - info::Core.Compiler.CallInfo, - stmt_flag::UInt32, - ) - else - sigs_ex = :( - interp::EnzymeInterpreter, - @nospecialize(src), - @nospecialize(info::Core.Compiler.CallInfo), - stmt_flag::UInt8, - mi::MethodInstance, - argtypes::Vector{Any}, - ) - args_ex = :( - interp::AbstractInterpreter, - src::Any, - info::Core.Compiler.CallInfo, - stmt_flag::UInt8, - mi::MethodInstance, - argtypes::Vector{Any}, - ) - end - @eval function Core.Compiler.inlining_policy($(sigs_ex.args...)) - if info isa NoInlineCallInfo - if info.kind === :primitive - @safe_debug "Blocking inlining for primitive func" info.tt - elseif info.kind === :inactive - @safe_debug "Blocking inlining due to inactive rule" info.tt - elseif info.kind === :frule - @safe_debug "Blocking inlining due to frule" info.tt - else - @assert info.kind === :rrule - @safe_debug "Blocking inlining due to rrule" info.tt - end - return nothing - elseif info isa AlwaysInlineCallInfo - @safe_debug "Forcing inlining for primitive func" info.tt - return src + elseif meta.mode == API.DEM_ReverseModeCombined || + meta.mode == API.DEM_ReverseModePrimal || + meta.mode == API.DEM_ReverseModeGradient + if EnzymeRules.has_rrule_from_sig(specTypes; world, method_table) + return NoInlineCallInfo(callinfo, atype, :rrule) end - return @invoke Core.Compiler.inlining_policy($(args_ex.args...)) end + return nothing end -import Core.Compiler: - abstract_call, - abstract_call_known, - ArgInfo, - StmtInfo, - AbsIntState, - get_max_methods, - CallMeta, - Effects, - NoCallInfo, - widenconst, - mapany, - MethodResultPure - -struct AutodiffCallInfo <: CallInfo +struct AutodiffCallInfo <: CC.CallInfo # ... - info::CallInfo + info::CC.CallInfo end -function abstract_call_known( - interp::EnzymeInterpreter, - @nospecialize(f), - arginfo::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int = get_max_methods(interp, f, sv), -) - +import GPUCompiler: abstract_call_known +import CC: CallMeta, Effects, NoCallInfo +function abstract_call_known(meta::EnzymeMeta, interp::GPUInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) (; fargs, argtypes) = arginfo if f === Enzyme.within_autodiff @@ -307,18 +99,9 @@ function abstract_call_known( end end @static if VERSION < v"1.11.0-" - return CallMeta( - Core.Const(true), - Core.Compiler.EFFECTS_TOTAL, - MethodResultPure(), - ) + return CallMeta(Core.Const(true), CC.EFFECTS_TOTAL, MethodResultPure()) else - return CallMeta( - Core.Const(true), - Union{}, - Core.Compiler.EFFECTS_TOTAL, - MethodResultPure(), - ) + return CallMeta(Core.Const(true), Union{}, CC.EFFECTS_TOTAL, MethodResultPure(),) end end @@ -331,6 +114,7 @@ function abstract_call_known( [:(Enzyme.autodiff_deferred), fargs[2:end]...], [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...], ) + # FIXME: Use AutodiffCallInfo and a custom inlining handler return abstract_call_known( interp, Enzyme.autodiff_deferred, @@ -341,14 +125,7 @@ function abstract_call_known( ) end end - return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, - f, - arginfo::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int, - ) + return nothing end end