Skip to content

Commit

Permalink
Fix interpreter for Julia 1.11
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio Sánchez Ramírez committed Jul 15, 2024
1 parent 810c1bb commit c2a1100
Showing 1 changed file with 61 additions and 35 deletions.
96 changes: 61 additions & 35 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,63 +6,89 @@
const CC = Core.Compiler
using Enzyme

const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552"

Base.Experimental.@MethodTable(ReactantMethodTable)

macro reactant_override(expr)
return :(Base.Experimental.@overlay ReactantMethodTable $(esc(expr)))
return :(Base.Experimental.@overlay ReactantMethodTable $(expr))
end

struct ReactantCache
dict::IdDict{Core.MethodInstance,Core.CodeInstance}
end
ReactantCache() = ReactantCache(IdDict{Core.MethodInstance,Core.CodeInstance}())
@static if !HAS_INTEGRATED_CACHE
struct ReactantCache
dict::IdDict{Core.MethodInstance,Core.CodeInstance}
end
ReactantCache() = ReactantCache(IdDict{Core.MethodInstance,Core.CodeInstance}())

function CC.get(wvc::CC.WorldView{ReactantCache}, mi::Core.MethodInstance, default)
return get(wvc.cache.dict, mi, default)
end
function CC.getindex(wvc::CC.WorldView{ReactantCache}, mi::Core.MethodInstance)
return getindex(wvc.cache.dict, mi)
end
function CC.haskey(wvc::CC.WorldView{ReactantCache}, mi::Core.MethodInstance)
return haskey(wvc.cache.dict, mi)
end
function CC.setindex!(
wvc::CC.WorldView{ReactantCache}, ci::Core.CodeInstance, mi::Core.MethodInstance
)
return setindex!(wvc.cache.dict, ci, mi)
function CC.get(wvc::CC.WorldView{ReactantCache}, mi::Core.MethodInstance, default)
return get(wvc.cache.dict, mi, default)
end
function CC.getindex(wvc::CC.WorldView{ReactantCache}, mi::Core.MethodInstance)
return getindex(wvc.cache.dict, mi)
end
function CC.haskey(wvc::CC.WorldView{ReactantCache}, mi::Core.MethodInstance)
return haskey(wvc.cache.dict, mi)
end
function CC.setindex!(
wvc::CC.WorldView{ReactantCache}, ci::Core.CodeInstance, mi::Core.MethodInstance
)
return setindex!(wvc.cache.dict, ci, mi)
end
end

struct ReactantInterpreter <: CC.AbstractInterpreter
# compiler::ReactantCompiler
world::UInt
inf_params::CC.InferenceParams
opt_params::CC.OptimizationParams
inf_cache::Vector{CC.InferenceResult}
code_cache::ReactantCache

function ReactantInterpreter(;
world::UInt=Base.get_world_counter(),
inf_params::CC.InferenceParams=CC.InferenceParams(),
opt_params::CC.OptimizationParams=CC.OptimizationParams(),
inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[],
code_cache::ReactantCache=ReactantCache(),
)
return new(world, inf_params, opt_params, inf_cache, code_cache)
@static if !HAS_INTEGRATED_CACHE
code_cache::ReactantCache
end

@static if HAS_INTEGRATED_CACHE
function ReactantInterpreter(;
world::UInt=Base.get_world_counter(),
inf_params::CC.InferenceParams=CC.InferenceParams(),
opt_params::CC.OptimizationParams=CC.OptimizationParams(),
inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[],
)
return new(world, inf_params, opt_params, inf_cache)
end
else
function ReactantInterpreter(;
world::UInt=Base.get_world_counter(),
inf_params::CC.InferenceParams=CC.InferenceParams(),
opt_params::CC.OptimizationParams=CC.OptimizationParams(),
inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[],
code_cache=ReactantCache(),
)
return new(world, inf_params, opt_params, inf_cache, code_cache)
end
end
end

@static if HAS_INTEGRATED_CACHE
CC.get_inference_world(interp::ReactantInterpreter) = interp.world
else
CC.get_world_counter(interp::ReactantInterpreter) = interp.world
end

CC.InferenceParams(interp::ReactantInterpreter) = interp.inf_params
CC.OptimizationParams(interp::ReactantInterpreter) = interp.opt_params
# CC.get_inference_world(interp::ReactantInterpreter) = interp.world
CC.get_world_counter(interp::ReactantInterpreter) = interp.world
CC.get_inference_cache(interp::ReactantInterpreter) = interp.inf_cache
# CC.cache_owner(interp::ReactantInterpreter) = interp.compiler
function CC.code_cache(interp::ReactantInterpreter)
return CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))

@static if HAS_INTEGRATED_CACHE
# TODO what does this do? taken from https://github.com/JuliaLang/julia/blob/v1.11.0-rc1/test/compiler/newinterp.jl
@eval CC.cache_owner(interp::ReactantInterpreter) =
$(QuoteNode(gensym(:ReactantInterpreterCache)))
else
function CC.code_cache(interp::ReactantInterpreter)
return CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))
end
end

function CC.method_table(interp::ReactantInterpreter)
return CC.OverlayMethodTable(CC.get_world_counter(interp), ReactantMethodTable)
return CC.OverlayMethodTable(interp.world, ReactantMethodTable)
end

const enzyme_out = 0
Expand Down

0 comments on commit c2a1100

Please sign in to comment.