-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Start forward mode AD #389
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.
Co-authored-by: Will Tebbutt <[email protected]> Signed-off-by: Guillaume Dalle <[email protected]>
@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think? |
I think this could work. You could just replace the @inline function call_frule!!(rule::R, fargs::Vararg{Any, N}) where {N}
return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end The optimisation pass will lower this to the what we were thinking about writing out in the IR anyway. I think the other important kinds of nodes would be largely straightforward to handle. |
I think we might need to be slightly more subtle. If an argument to the |
Yes. I think my propose code handles this though, or am I missing something? |
In the spirit of higher-order AD, we may encounter |
Very good point.
Agreed. Specifically, I think we need to distinguish between literals / |
I still need to dig into the different node types we might encounter (and I still don't understand |
I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for |
I think that's okay, the main trouble is adding new lines which insert new variables because it requires manual renumbering. A GoTo should be much simpler. |
Were the difficulties around renumbering etc not resolved by not |
No they weren't. I experimented with |
Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot(%5, #3) i.e. jump to block 3 if not %new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3) Does this not cause the same kind of problems? |
Oh yes you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look |
Do you know what I should do about expressions of type |
Yup -- I just strip them out of the IR entirely in reverse-mode. See
The way to remove an instruction from an |
I think this works for
MWE (requires this branch of Mooncake): const CC = Core.Compiler
using Mooncake
using MistyClosures
f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any) # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
for k in 1:length(ir.stmts)
inst = ir[CC.SSAValue(k)][:stmt]
if inst isa Core.GotoIfNot
Mooncake.replace_call!(ir,CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k-1), inst.dest))
end
end
ir julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ %2 = Base.or_int(%1, false)::Bool ││╻ <
└── goto #3 if not %2 │
2 ─ %4 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %4 │
3 ─ %6 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %6 │
julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ Base.or_int(%1, false)::Bool ││╻ <
│ %3 = (+)(1, 2)::Any │
└── goto #3 if not %3 │
2 ─ %5 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %5 │
3 ─ %7 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %7 |
Merge conflict resolved. I've also modified |
Do we need to keep both "zero" macros or would one be enough? |
I think one would probably be enough. I think I've messed something up though -- going to fix and push. |
Okay. When you get a minute, could you take a look at my latest commit, and see if you're seeing the same problem as before? |
The test failure on case 10 looks the same, but there's a new one on case 6 which wasn't there before the merge. Is that what you were asking? |
I think it must be that something is going wrong with The only way that I can imagine that this has happened is that something has gone wrong with the edit: look at the |
Hey @willtebbutt, making slow progress on the test cases. Here's where I need some help:
The segfault is especially puzzling. I chose to do nothing to |
I've been struggling with this in other places. I think that there's something going on in type inference, but I'm not 100% sure what exactly it is. As you've observed, it is often (somehow) fixed by advancing the world age, but I don't understand why at all. Sorry not to be of more help. Perhaps this is something that is best left to one side for now?
I think what's going on here is that I've got an outer constructor for # Always sharpen the first thing if it's a type so static dispatch remains possible.
function Dual(x::Type{P}, dx::NoTangent) where {P}
return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(P, dx)
end
This is a fun one. Let's look at the optimised IR for the rule: julia> rule.fwd_oc.ir[]
231 1 ─ %1 = (Mooncake._primal)(_3)::Bool │
└── goto #3 if not %1 │
2 ─ goto #4 │
3 ─ nothing::Nothing │
4 ┄ %5 = φ (#2 => true, #3 => false)::Bool │
│ %6 = φ (#2 => Mooncake.TestResources.__x_for_gref_test, #3 => 1)::Union{Float64, Int64} │
232 │ %7 = (Mooncake._primal)(%5)::Bool │
231 └── goto #6 if not %7 │
232 5 ─ π (%6, Dual{Float64, Float64}) │
└── unreachable │
6 ─ nothing::Nothing │
7 ─ π (%6, Dual{Int64, NoTangent}) │
└── unreachable Observe that As to the segfault: I'm pretty sure this is a result of the Segfaults scare me, so I ensure that improperly-implemented rules don't ever cause segfaults in hard-to-locate places by shoving a |
Yes, that is indeed the source (or at least a source) of the error, well spotted. I'm running into a deeper issue when I try to fix it: when we get the signature of a method to replace an |
Ah, right, yes. I handle this in reverse-mode by making the "name" of a statement invariant under insertion / deletion of other statements, meaning that this problem never really happens (this is one aspect of what the I wonder whether someone else who works with IRCode regularly (such as @vchuravy ) has thoughts on / has developed strategies to make this kind of work convenient? I'd very much like to know about them if they exist. |
I think last time we chatted about this with Valentin he recommended using the native Julia IR whenever possible to increase compatibility and readability, which I rather agree with. So I would love to stick with that in forward mode because it's much easier for me than it must have been for you in reverse mode. |
I think in Enzyme we handle this by never touching the primal code and then creating a new IR where we copy over things and maintain a mapping from statement to statement. |
I have one curious question for @gdalle and @willtebbutt: Can we use the same forward-mode infrastructure to perform Jax-style symbolic tracing? Briefly, I think this involves writing a different set of |
@willtebbutt more progress, and more bugs for you to look at:
The stack overflow is especially puzzling to me. Here's a reproducer: julia> using Mooncake
julia> f6 = Mooncake.TestResources.globalref_tester_6
globalref_tester_6 (generic function with 1 method)
julia> build_frule(f6)
ERROR: StackOverflowError:
Stacktrace:
[1] _methods_by_ftype(t::Any, mt::Nothing, lim::Int64, world::UInt64, ambig::Bool, min::Core.Compiler.RefValue{…}, max::Core.Compiler.RefValue{…}, has_ambig::Core.Compiler.RefValue{…})
@ Core.Compiler ./reflection.jl:1182
[2] _findall
@ ./compiler/methodtable.jl:95 [inlined]
[3] #findall#310
@ ./compiler/methodtable.jl:80 [inlined]
[4] findall
@ ./compiler/methodtable.jl:71 [inlined]
[5] find_matching_methods(𝕃::Core.Compiler.InferenceLattice{…}, argtypes::Vector{…}, atype::Any, method_table::Core.Compiler.OverlayMethodTable, max_union_splitting::Int64, max_methods::Int64)
@ Core.Compiler ./compiler/abstractinterpretation.jl:312
[6] abstract_call_gf_by_type(interp::Mooncake.MooncakeInterpreter{…}, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)
@ Core.Compiler ./compiler/abstractinterpretation.jl:25
[7] abstract_call_gf_by_type(interp::Mooncake.MooncakeInterpreter{…}, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/abstract_interpretation.jl:128
[8] abstract_call_known(interp::Mooncake.MooncakeInterpreter{…}, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)
@ Core.Compiler ./compiler/abstractinterpretation.jl:2200
[9] abstract_call(interp::Mooncake.MooncakeInterpreter{…}, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)
@ Core.Compiler ./compiler/abstractinterpretation.jl:2282
[10] abstract_call(interp::Mooncake.MooncakeInterpreter{…}, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/abstractinterpretation.jl:2275
[11] abstract_call(interp::Mooncake.MooncakeInterpreter{…}, arginfo::Core.Compiler.ArgInfo, sv::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/abstractinterpretation.jl:2420
[12] abstract_eval_call(interp::Mooncake.MooncakeInterpreter{…}, e::Expr, vtypes::Vector{…}, sv::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/abstractinterpretation.jl:2435
[13] abstract_eval_statement_expr(interp::Mooncake.MooncakeInterpreter{…}, e::Expr, vtypes::Vector{…}, sv::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/abstractinterpretation.jl:2451
[14] abstract_eval_statement(interp::Mooncake.MooncakeInterpreter{…}, e::Any, vtypes::Vector{…}, sv::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/abstractinterpretation.jl:2749
[15] abstract_eval_basic_statement(interp::Mooncake.MooncakeInterpreter{…}, stmt::Any, pc_vartable::Vector{…}, frame::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/abstractinterpretation.jl:3065
[16] typeinf_local(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, frame::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/abstractinterpretation.jl:3319
[17] typeinf_nocycle(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, frame::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/abstractinterpretation.jl:3401
[18] _typeinf(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, frame::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/typeinfer.jl:244
[19] typeinf(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, frame::Core.Compiler.InferenceState)
@ Core.Compiler ./compiler/typeinfer.jl:215
[20] typeinf_frame(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, mi::Core.MethodInstance, run_optimizer::Bool)
@ Core.Compiler ./compiler/typeinfer.jl:1061
[21] typeinf_ircode(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, mi::Core.MethodInstance, optimize_until::Nothing)
@ Core.Compiler ./compiler/typeinfer.jl:1036
[22] typeinf_ircode(interp::Mooncake.MooncakeInterpreter{…}, method::Method, atype::Any, sparams::Core.SimpleVector, optimize_until::Nothing)
@ Core.Compiler ./compiler/typeinfer.jl:1031
[23] lookup_ir(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, mi::Core.MethodInstance; optimize_until::Nothing)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/ir_utils.jl:269
[24] lookup_ir
@ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/ir_utils.jl:266 [inlined]
[25] generate_dual_ir(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, do_inline::Bool)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:70
[26] generate_dual_ir
@ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:62 [inlined]
[27] build_frule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, silence_debug_messages::Bool)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:40
[28] build_frule
@ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:7 [inlined]
[29] modify_fwd_ad_stmts!(dual_stmt::Expr, primal_stmt::Expr, dual_ir::Core.Compiler.IncrementalCompact, primal_ir::Core.Compiler.IncrementalCompact, dual_ssa::Int64, primal_ssa::Int64, interp::Mooncake.MooncakeInterpreter{…}; debug_mode::Bool)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:317
[30] generate_dual_ir(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, do_inline::Bool)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:97
--- the above 5 lines are repeated 962 more times ---
[4841] generate_dual_ir
@ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:62 [inlined]
[4842] build_frule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:40
[4843] build_frule
@ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:7 [inlined]
[4844] build_frule(args::Function; debug_mode::Bool)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:4
Some type information was truncated. Use `show(err)` to see complete types. |
I've figure out 28, and will comment below once I've got a sense of what's going on with the others. 28So the problem here is that This is probably a good opportunity to discuss the purpose of
If you ctrl+f through that file, you'll find that I construct one of these when I encounter an It also has the advantage of meaning that the IR to do AD for a particular MethodInstance / signature is never generated unless it is actually used (dynamically). This means that lots of code never gets generated, which has the great advantage of meaning that if there's some function that Mooncake can't handle properly due to e.g. the presence of I think I probably skipped over this when we first spoke about doing forwards mode in order to try to limit the amount of information I was trying to convey to a manageable quantity haha. The upshot of all of this is that you should be able to employ the same strategy in forwards mode. If you look a few lines up from the |
17I think this is probably the first example of a Background: you can find out (statically) how many fields of a type must always be initialised by calling So, in the current implementation of
For example, the primal type in question here is julia> fieldcount(Mooncake.TestResources.StructFoo)
2
julia> Core.Compiler.datatype_min_ninitialized(Mooncake.TestResources.StructFoo)
1
julia> Mooncake.always_initialised(Mooncake.TestResources.StructFoo)
(true, false)
julia> tangent_type(Mooncake.TestResources.StructFoo)
Tangent{@NamedTuple{a, b::PossiblyUninitTangent{Vector{Float64}}}} So it has 2 fields, of which 1 must always be initialised. In reverse-mode, we construct Lines 36 to 49 in 5ca50ed
which is called from the rrule!! for _new_ you'll see how this works in reverse-mode.Let me know if the behaviour isn't clear from the above context. |
(I'm going to make adding these notes to the dev docs a todo item for tomorrow morning). |
@yebai I imagine something similar would work. Perhaps we could have a proper discussion in a separate issue? |
I'll take a look at 32 in the morning -- @gdalle hopefully this will keep you occupied until then! |
No problem -- please create an issue for it to continue the discussion. It is not meant to be a distraction but to consider some symbolic tracing design considerations into forward-mode autograd infrastructure so we can build upon it later. |
I've opened a discussion re tracing: #461 |
32I think this is just a case of the constructor for |
Okay, 17 and 32 are fixed, and 28 no longer recurses infinitely. However, I think 28 unearthed another issue with statement insertion in the Here's an example: using Mooncake
f6_inner = Mooncake.TestResources.globalref_tester_6_inner
Base.code_ircode(f6_inner, (Vector{Float64},))
rule = build_frule(f6_inner, rand(10)) Which gives you the following primal IR: julia> Base.code_ircode(f6_inner, (Vector{Float64},))
1-element Vector{Any}:
246 1 ── %1 = Base.identity::typeof(identity) sum
│ %2 = Base.add_sum::typeof(Base.add_sum) #sum#933
│ %3 = Base.getfield(_2, :size)::Tuple{Int64} _sum
│ %4 = $(Expr(:boundscheck, true))::Bool #_sum#935
│ %5 = Base.getfield(%3, 1, %4)::Int64││││││ _sum
│ %6 = (%5 === 0)::Bool ││││││╻ #_sum#936
└─── goto #3 if not %6 │││││││┃│││ mapreduce
2 ── goto #31 ││││││││┃││ #mapreduce#926
3 ── %9 = (%5 === 1)::Bool │││││││││╻ _mapreduce_dim
└─── goto #10 if not %9 ││││││││││┃ _mapreduce
4 ── %11 = $(Expr(:boundscheck, false))::Bool getindex
└─── goto #8 if not %11 ││││││││││││
5 ── %13 = Base.sub_int(1, 1)::Int64│││││││││
│ %14 = Base.bitcast(Base.UInt, %13)::UInt64
│ %15 = Base.getfield(_2, :size)::Tuple{Int64} length
│ %16 = $(Expr(:boundscheck, true))::Bool││╻ getindex
│ %17 = Base.getfield(%15, 1, %16)::Int64│││
│ %18 = Base.bitcast(Base.UInt, %17)::UInt64
│ %19 = Base.ult_int(%14, %18)::Bool││││││
└─── goto #7 if not %19 ││││││││││││
6 ── goto #8 │
7 ── %22 = Core.tuple(1)::Tuple{Int64}│││││││
│ invoke Base.throw_boundserror(_2::Vector{Float64}, %22::Tuple{Int64})::Union{}
└─── unreachable ││││││││││││
8 ┄─ %25 = Base.getfield(_2, :ref)::MemoryRef{Float64}
│ %26 = Base.memoryrefnew(%25, 1, false)::MemoryRef{Float64}
│ %27 = Base.memoryrefget(%26, :not_atomic, false)::Float64
└─── goto #9 │
9 ── goto #31 ││││││││││╻ _mapreduce
10 ─ %30 = Base.slt_int(%5, 16)::Bool│││││││╻ <
└─── goto #30 if not %30 │││││││││││
11 ─ %32 = $(Expr(:boundscheck, false))::Bool getindex
└─── goto #15 if not %32 ││││││││││││
12 ─ %34 = Base.sub_int(1, 1)::Int64│││││││││
│ %35 = Base.bitcast(Base.UInt, %34)::UInt64
│ %36 = Base.getfield(_2, :size)::Tuple{Int64} length
│ %37 = $(Expr(:boundscheck, true))::Bool││╻ getindex
│ %38 = Base.getfield(%36, 1, %37)::Int64│││
│ %39 = Base.bitcast(Base.UInt, %38)::UInt64
│ %40 = Base.ult_int(%35, %39)::Bool││││││
└─── goto #14 if not %40 ││││││││││││
13 ─ goto #15 │
14 ─ %43 = Core.tuple(1)::Tuple{Int64}│││││││
│ invoke Base.throw_boundserror(_2::Vector{Float64}, %43::Tuple{Int64})::Union{}
└─── unreachable ││││││││││││
15 ┄ %46 = Base.getfield(_2, :ref)::MemoryRef{Float64}
│ %47 = Base.memoryrefnew(%46, 1, false)::MemoryRef{Float64}
│ %48 = Base.memoryrefget(%47, :not_atomic, false)::Float64
└─── goto #16 │
16 ─ %50 = $(Expr(:boundscheck, false))::Bool getindex
└─── goto #20 if not %50 ││││││││││││
17 ─ %52 = Base.sub_int(2, 1)::Int64│││││││││
│ %53 = Base.bitcast(Base.UInt, %52)::UInt64
│ %54 = Base.getfield(_2, :size)::Tuple{Int64} length
│ %55 = $(Expr(:boundscheck, true))::Bool││╻ getindex
│ %56 = Base.getfield(%54, 1, %55)::Int64│││
│ %57 = Base.bitcast(Base.UInt, %56)::UInt64
│ %58 = Base.ult_int(%53, %57)::Bool││││││
└─── goto #19 if not %58 ││││││││││││
18 ─ goto #20 │
19 ─ %61 = Core.tuple(2)::Tuple{Int64}│││││││
│ invoke Base.throw_boundserror(_2::Vector{Float64}, %61::Tuple{Int64})::Union{}
└─── unreachable ││││││││││││
20 ┄ %64 = Base.getfield(_2, :ref)::MemoryRef{Float64}
│ %65 = Base.memoryrefnew(%64, 2, false)::MemoryRef{Float64}
│ %66 = Base.memoryrefget(%65, :not_atomic, false)::Float64
└─── goto #21 │
21 ─ %68 = Base.add_float(%48, %66)::Float64╻╷ add_sum
22 ┄ %69 = φ (#21 => %68, #28 => %92)::Float64 _mapreduce
│ %70 = φ (#21 => 2, #28 => %73)::Int64││
│ %71 = Base.slt_int(%70, %5)::Bool││││││╻ <
└─── goto #29 if not %71 │││││││││││
23 ─ %73 = Base.add_int(%70, 1)::Int64││││││╻ +
│ %74 = $(Expr(:boundscheck, false))::Bool getindex
└─── goto #27 if not %74 ││││││││││││
24 ─ %76 = Base.sub_int(%73, 1)::Int64│││││││
│ %77 = Base.bitcast(Base.UInt, %76)::UInt64
│ %78 = Base.getfield(_2, :size)::Tuple{Int64} length
│ %79 = $(Expr(:boundscheck, true))::Bool││╻ getindex
│ %80 = Base.getfield(%78, 1, %79)::Int64│││
│ %81 = Base.bitcast(Base.UInt, %80)::UInt64
│ %82 = Base.ult_int(%77, %81)::Bool││││││
└─── goto #26 if not %82 ││││││││││││
25 ─ goto #27 │
26 ─ %85 = Core.tuple(%73)::Tuple{Int64}│││││
│ invoke Base.throw_boundserror(_2::Vector{Float64}, %85::Tuple{Int64})::Union{}
└─── unreachable ││││││││││││
27 ┄ %88 = Base.getfield(_2, :ref)::MemoryRef{Float64}
│ %89 = Base.memoryrefnew(%88, %73, false)::MemoryRef{Float64}
│ %90 = Base.memoryrefget(%89, :not_atomic, false)::Float64
└─── goto #28 │
28 ─ %92 = Base.add_float(%69, %90)::Float64╻╷ add_sum
└─── goto #22 ││││││││││╻ _mapreduce
29 ─ goto #31 │││││││││││
30 ─ %95 = invoke Base.mapreduce_impl(%1::typeof(identity), %2::typeof(Base.add_sum), _2::Vector{Float64}, 1::Int64, %5::Int64, 1024::Int64)::Float64
└─── goto #31 │││││││││││
31 ┄ %97 = φ (#2 => 0.0, #9 => %27, #29 => %69, #30 => %95)::Float64
└─── goto #32 │
32 ─ goto #33 │
33 ─ goto #34 ││││││││
34 ─ goto #35 │
35 ─ goto #36 ││││││
36 ─ goto #37 │
37 ─ goto #38 ││││
38 ─ goto #39 │
39 ─ goto #40 ││
40 ─ return %97 │
=> Float64 and the following error when building the dual IR: julia> rule = build_frule(f6_inner, rand(10))
Block 4 successors (Array{Int64, 1}(dims=(2,), mem=Memory{Int64}(2, 0x33adf1010)[8, 5])), does not match GotoNode terminator (8)
246 1 ── %1 = Base.identity::typeof(identity)
│ %2 = Base.add_sum::typeof(Base.add_sum)
│ %3 = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
│ %4 = (DualArguments(frule!!))(Mooncake.lgetfield, %3, Val{1}(), Val{true}())::Any
│ %5 = (DualArguments(frule!!))(===, %4, 0)::Any
│ %6 = (Mooncake._primal)(%5)::Any
└─── goto #3 if not %6
2 ── goto #31
3 ── %9 = (DualArguments(frule!!))(===, %4, 1)::Any
│ %10 = (Mooncake._primal)(%9)::Any
└─── goto #10 if not %10
4 ── goto #8
5 ── %13 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.sub_int, 1, 1)::Any
│ %14 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %13)::Any
│ %15 = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
│ %16 = (DualArguments(frule!!))(Mooncake.lgetfield, %15, Val{1}(), Val{true}())::Any
│ %17 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %16)::Any
│ %18 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.ult_int, %14, %17)::Any
│ %19 = (Mooncake._primal)(%18)::Any
└─── goto #7 if not %19
6 ── goto #8
7 ── %22 = (DualArguments(frule!!))(tuple, 1)::Any
│ (DualArguments(frule!!))(Base.throw_boundserror, _3, %22)::Any
│ %24 = unreachable
└─── return %24
8 ┄─ %26 = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:ref}())::Any
│ %27 = (DualArguments(frule!!))(Core.memoryrefnew, %26, 1, false)::Any
│ %28 = (DualArguments(frule!!))(Mooncake.lmemoryrefget, %27, Val{:not_atomic}(), Val{false}())::Any
└─── goto #9
9 ── goto #31
10 ─ %31 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.slt_int, %4, 16)::Any
│ %32 = (Mooncake._primal)(%31)::Any
└─── goto #30 if not %32
11 ─ goto #15
12 ─ %35 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.sub_int, 1, 1)::Any
│ %36 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %35)::Any
│ %37 = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
│ %38 = (DualArguments(frule!!))(Mooncake.lgetfield, %37, Val{1}(), Val{true}())::Any
│ %39 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %38)::Any
│ %40 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.ult_int, %36, %39)::Any
│ %41 = (Mooncake._primal)(%40)::Any
└─── goto #14 if not %41
13 ─ goto #15
14 ─ %44 = (DualArguments(frule!!))(tuple, 1)::Any
│ (DualArguments(frule!!))(Base.throw_boundserror, _3, %44)::Any
│ %46 = unreachable
└─── return %46
15 ┄ %48 = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:ref}())::Any
│ %49 = (DualArguments(frule!!))(Core.memoryrefnew, %48, 1, false)::Any
│ %50 = (DualArguments(frule!!))(Mooncake.lmemoryrefget, %49, Val{:not_atomic}(), Val{false}())::Any
└─── goto #16
16 ─ goto #20
17 ─ %53 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.sub_int, 2, 1)::Any
│ %54 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %53)::Any
│ %55 = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
│ %56 = (DualArguments(frule!!))(Mooncake.lgetfield, %55, Val{1}(), Val{true}())::Any
│ %57 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %56)::Any
│ %58 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.ult_int, %54, %57)::Any
│ %59 = (Mooncake._primal)(%58)::Any
└─── goto #19 if not %59
18 ─ goto #20
19 ─ %62 = (DualArguments(frule!!))(tuple, 2)::Any
│ (DualArguments(frule!!))(Base.throw_boundserror, _3, %62)::Any
│ %64 = unreachable
└─── return %64
20 ┄ %66 = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:ref}())::Any
│ %67 = (DualArguments(frule!!))(Core.memoryrefnew, %66, 2, false)::Any
│ %68 = (DualArguments(frule!!))(Mooncake.lmemoryrefget, %67, Val{:not_atomic}(), Val{false}())::Any
└─── goto #21
21 ─ %70 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.add_float, %50, %68)::Any
22 ┄ %71 = φ (#21 => %70, #28 => %95)::Any
│ %72 = φ (#21 => 2, #28 => %76)::Any
│ %73 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.slt_int, %72, %4)::Any
│ %74 = (Mooncake._primal)(%73)::Any
└─── goto #29 if not %74
23 ─ %76 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.add_int, %72, 1)::Any
└─── goto #27
24 ─ %78 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.sub_int, %76, 1)::Any
│ %79 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %78)::Any
│ %80 = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
│ %81 = (DualArguments(frule!!))(Mooncake.lgetfield, %80, Val{1}(), Val{true}())::Any
│ %82 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %81)::Any
│ %83 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.ult_int, %79, %82)::Any
│ %84 = (Mooncake._primal)(%83)::Any
└─── goto #26 if not %84
25 ─ goto #27
26 ─ %87 = (DualArguments(frule!!))(tuple, %76)::Any
│ (DualArguments(frule!!))(Base.throw_boundserror, _3, %87)::Any
│ %89 = unreachable
└─── return %89
27 ┄ %91 = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:ref}())::Any
│ %92 = (DualArguments(frule!!))(Core.memoryrefnew, %91, %76, false)::Any
│ %93 = (DualArguments(frule!!))(Mooncake.lmemoryrefget, %92, Val{:not_atomic}(), Val{false}())::Any
└─── goto #28
28 ─ %95 = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.add_float, %71, %93)::Any
└─── goto #22
29 ─ goto #31
30 ─ %98 = (DualArguments(Mooncake.LazyFRule{Tuple{typeof(Base.mapreduce_impl), typeof(identity), typeof(Base.add_sum), Vector{Float64}, Int64, Int64, Int64}, Mooncake.DerivedFRule{MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Dual{typeof(Base.mapreduce_impl), NoTangent}, Dual{typeof(identity), NoTangent}, Dual{typeof(Base.add_sum), NoTangent}, Dual{Vector{Float64}, Vector{Float64}}, Dual{Int64, NoTangent}, Dual{Int64, NoTangent}, Dual{Int64, NoTangent}}, Dual{Float64, Float64}}}}}(false, MethodInstance for Base.mapreduce_impl(::typeof(identity), ::typeof(Base.add_sum), ::Vector{Float64}, ::Int64, ::Int64, ::Int64), #undef)))(Base.mapreduce_impl, %1, %2, _3, 1, %4, 1024)::Any
└─── goto #31
31 ┄ %100 = φ (#2 => 0.0, #9 => %28, #29 => %71, #30 => %98)::Any
└─── goto #32
32 ─ goto #33
33 ─ goto #34
34 ─ goto #35
35 ─ goto #36
36 ─ goto #37
37 ─ goto #38
38 ─ goto #39
39 ─ goto #40
40 ─ %110 = (Mooncake._dual)(%100)::Any
└─── return %110
ERROR:
Stacktrace:
[1] error(s::String)
@ Core.Compiler ./error.jl:35
[2] verify_ir(ir::Core.Compiler.IRCode, print::Bool, allow_frontend_forms::Bool, 𝕃ₒ::Core.Compiler.PartialsLattice{…})
@ Core.Compiler ./compiler/ssair/verify.jl:191
[3] verify_ir
@ ./compiler/ssair/verify.jl:98 [inlined]
[4] generate_dual_ir(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, sig_or_mi::Type; debug_mode::Bool, do_inline::Bool)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:113
[5] generate_dual_ir
@ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:62 [inlined]
[6] build_frule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:40
[7] build_frule
@ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:7 [inlined]
[8] build_frule(::Function, ::Vararg{Any}; debug_mode::Bool)
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:4
[9] build_frule(::Function, ::Vector{Float64})
@ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:1
[10] top-level scope
@ ~/Documents/GitHub/Julia/Mooncake.jl/test/playground.jl:4
Some type information was truncated. Use `show(err)` to see complete types. The error I get:
As you can see, in block
|
This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.