Start forward mode AD #389

Draft -- wants to merge 42 commits into main
Conversation

@gdalle (Collaborator) commented Nov 24, 2024

This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.

codecov bot commented Nov 24, 2024

Codecov Report

Attention: Patch coverage is 9.30736% with 419 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/interpreter/s2s_forward_mode_ad.jl 0.00% 153 Missing ⚠️
src/test_utils.jl 16.03% 110 Missing ⚠️
src/interpreter/diffractor_compiler_utils.jl 0.00% 66 Missing ⚠️
src/dual.jl 18.18% 18 Missing ⚠️
src/rrules/builtins.jl 0.00% 17 Missing ⚠️
src/rrules/new.jl 0.00% 17 Missing ⚠️
src/tools_for_rules.jl 60.00% 12 Missing ⚠️
src/interpreter/bbcode.jl 0.00% 9 Missing ⚠️
src/rrules/misc.jl 0.00% 9 Missing ⚠️
src/rrules/low_level_maths.jl 0.00% 6 Missing ⚠️
... and 2 more
Files with missing lines Coverage Δ
src/Mooncake.jl 100.00% <ø> (ø)
src/interpreter/ir_normalisation.jl 39.71% <ø> (-50.36%) ⬇️
src/debug_mode.jl 86.11% <0.00%> (-11.04%) ⬇️
src/interpreter/ir_utils.jl 62.18% <0.00%> (-25.32%) ⬇️
src/rrules/low_level_maths.jl 0.00% <0.00%> (-100.00%) ⬇️
src/interpreter/bbcode.jl 70.83% <0.00%> (-25.26%) ⬇️
src/rrules/misc.jl 0.00% <0.00%> (-97.44%) ⬇️
src/tools_for_rules.jl 78.15% <60.00%> (-20.81%) ⬇️
src/rrules/builtins.jl 1.96% <0.00%> (-97.01%) ⬇️
src/rrules/new.jl 2.89% <0.00%> (-85.57%) ⬇️
... and 4 more

... and 28 files with indirect coverage changes

@willtebbutt (Member) left a comment

This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.

(Review threads on src/interpreter/s2s_forward_mode_ad.jl, test/forward.jl, and src/frules/basic.jl -- outdated, resolved.)
@gdalle (Collaborator, Author) commented Nov 26, 2024

@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?

@willtebbutt (Member):

I think this could work.

You could just replace the frule!! calls with a call to a function call_frule!!, which would be something like:

@inline function call_frule!!(rule::R, fargs::Vararg{Any, N}) where {N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

The optimisation pass will lower this to what we were thinking of writing out in the IR anyway.

I think the other important kinds of nodes would be largely straightforward to handle.

@gdalle (Collaborator, Author) commented Nov 26, 2024

I think we might need to be slightly more subtle. If an argument to the :call or :invoke expression is a CC.Argument or a CC.SSAValue, we don't wrap it in a Dual because we assume it will already be one, right?

@willtebbutt (Member) commented Nov 26, 2024

Yes. I think my proposed code handles this though, or am I missing something?

@gdalle (Collaborator, Author) commented Nov 26, 2024

In the spirit of higher-order AD, we may encounter Dual inputs that we want to wrap with a second Dual, and Dual inputs that we want to leave as-is. So I think this wrapping needs to be decided from the type of each argument in the IR?

@willtebbutt (Member):

Very good point.

So I think this wrapping needs to be decided from the type of each argument in the IR?

Agreed. Specifically, I think we need to distinguish between literals / QuoteNodes / GlobalRefs, and Argument / SSAValues?

@gdalle (Collaborator, Author) commented Nov 26, 2024

I still need to dig into the different node types we might encounter (and I still don't understand QuoteNodes) but yeah, Argument and SSAValue don't need to be wrapped.
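To make the distinction concrete, here is a minimal sketch of the per-argument check being discussed (the helper names are illustrative, not Mooncake API, and it assumes a zero_dual function like the one in the call_frule!! sketch above):

# Hypothetical helpers sketching the idea: Arguments and SSAValues are assumed to already
# refer to Duals in the dual IR, whereas literals, QuoteNodes and GlobalRefs still need to
# be wrapped in a zero Dual at the call site.
is_already_dual(x) = x isa Core.Argument || x isa Core.SSAValue

function dualize_call_arg(x)
    return is_already_dual(x) ? x : Expr(:call, zero_dual, x)
end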

@gdalle mentioned this pull request Nov 27, 2024
@willtebbutt (Member):

I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" invariant won't work for Core.GotoIfNot nodes. See https://compintell.github.io/Mooncake.jl/previews/PR386/developer_documentation/forwards_mode_design/#Statement-Transformation .

@gdalle (Collaborator, Author) commented Nov 27, 2024

I think that's okay: the main trouble is adding new lines that define new variables, because that requires manual renumbering. A GotoIfNot should be much simpler.

@willtebbutt (Member):

Were the difficulties around renumbering etc not resolved by not compact!ing until the end? I feel like I might be missing something.

@gdalle (Collaborator, Author) commented Nov 27, 2024

No they weren't. I experimented with compact! in various places and I was struggling a lot, so I asked Frames for advice. She agreed that insertion should usually be avoided.
If we have to insert something for GotoIfNot, I think it will still be easier, because we're not defining a new SSAValue, so we don't have to adapt future statements that refer to it.

@willtebbutt (Member) commented Nov 27, 2024

Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot of interest is

GotoIfNot(%5, #3)

i.e. jump to block 3 if not %5. In the forwards-mode IR this would become

%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)

Does this not cause the same kind of problems?

@gdalle (Collaborator, Author) commented Nov 27, 2024

Oh yes, you're probably right. Although it might be slightly less of a hassle, because the new SSA is only used in one spot, right after it. I'll take a look.

@gdalle (Collaborator, Author) commented Nov 27, 2024

Do you know what I should do about expressions of type :code_coverage_effect? I assume they're inserted automatically and they're alone on their lines?

@willtebbutt (Member) commented Nov 27, 2024

Yup -- I just strip them out of the IR entirely in reverse-mode. See

elseif Meta.isexpr(stmt, :code_coverage_effect)

The way to remove an instruction from an IRCode is just to replace the instruction with nothing.
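For illustration, here is a hedged sketch of that removal (not the actual Mooncake code, just the technique described above applied statement by statement):

const CC = Core.Compiler

# Hedged sketch: walk an IRCode and replace every :code_coverage_effect statement with
# `nothing`, which removes it from the IR.
function strip_code_coverage!(ir::CC.IRCode)
    for k in 1:length(ir.stmts)
        inst = ir[CC.SSAValue(k)]
        if Meta.isexpr(inst[:stmt], :code_coverage_effect)
            inst[:stmt] = nothing
        end
    end
    return ir
end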

@gdalle (Collaborator, Author) commented Nov 27, 2024

I think this works for GotoIfNot:

  1. make all the insertions necessary
  2. compact! once to make sure they applied
  3. shift the conditions of all GotoIfNot nodes to refer to the node right before them (where we get the primal value of the condition)

MWE (requires this branch of Mooncake):

const CC = Core.Compiler
using Mooncake
using MistyClosures

f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any)  # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k - 1), inst.dest))
    end
end
ir
julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool
  │   %2 = Base.or_int(%1, false)::Bool
  └──      goto #3 if not %2
  2 ─ %4 = Base.mul_float(2.0, _2)::Float64
  └──      return %4
  3 ─ %6 = Base.add_float(3.0, _2)::Float64
  └──      return %6

julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool
  │        Base.or_int(%1, false)::Bool
  │   %3 = (+)(1, 2)::Any
  └──      goto #3 if not %3
  2 ─ %5 = Base.mul_float(2.0, _2)::Float64
  └──      return %5
  3 ─ %7 = Base.add_float(3.0, _2)::Float64
  └──      return %7

@willtebbutt (Member):

Merge conflict resolved. I've also modified @zero_derivative so that it also applies @zero_adjoint, which keeps precompilation happy. I'll take a look at the error now.

@gdalle (Collaborator, Author) commented Jan 27, 2025

Do we need to keep both "zero" macros or would one be enough?

@willtebbutt (Member):

I think one would probably be enough. I think I've messed something up though -- going to fix and push.

@willtebbutt (Member):

Okay. When you get a minute, could you take a look at my latest commit, and see if you're seeing the same problem as before?

@gdalle (Collaborator, Author) commented Jan 27, 2025

The test failure on case 10 looks the same, but there's a new one on case 6 which wasn't there before the merge. Is that what you were asking?

@willtebbutt (Member) commented Jan 27, 2025

I think it must be that something is going wrong with getfield -- you should never see a PossiblyUninitTangent unless it's a field of a struct of some kind. The other errors also suggest that this kind of thing is what's going on -- the Expression: Mooncake.verify_dual_type(y_ẏ) suggests that the wrong type is being returned.

The only way that I can imagine that this has happened is that something has gone wrong with the frule!! for lgetfield. I'd suggest accessing the field via _get_fdata_field, which will do the correct thing with possibly-undefined fields.

edit: looking at the frule!! for lgetfield, this is probably what's going on. The frule!! accesses the field of the fdata directly, rather than going via _get_fdata_field etc.

@gdalle (Collaborator, Author) commented Feb 4, 2025

Hey @willtebbutt, making slow progress on the test cases. Here's where I need some help:

Failing cases:
- 4: fails an allocation test on the first run only (rule compilation?)
- 15,16,17,18: I don't understand why `DataType` doesn't dispatch on `Type{P}`
- 23: segfault on `GlobalRef`

The segfault is especially puzzling. I chose to do nothing to GlobalRefs in the IR (because DualArguments would wrap them anyway) and I didn't expect such a brutal reaction from the Julia compiler ^^

@willtebbutt (Member) commented Feb 5, 2025

  • 4: fails an allocation test on the first run only (rule compilation?)

I've been struggling with this in other places. I think that there's something going on in type inference, but I'm not 100% sure what exactly it is. As you've observed, it is often (somehow) fixed by advancing the world age, but I don't understand why at all. Sorry not to be of more help. Perhaps this is something that is best left to one side for now?

  • 15,16,17,18: I don't understand why DataType doesn't dispatch on Type{P}

I think what's going on here is that I've got an outer constructor for CoDual which always sharpens the type of the CoDual produced if the argument is a Type. The equivalent thing for a Dual would be something like:

# Always sharpen the first thing if it's a type so static dispatch remains possible.
function Dual(x::Type{P}, dx::NoTangent) where {P}
    return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(P, dx)
end
  • 23: segfault on GlobalRef

This is a fun one. Let's look at the optimised IR for the rule:

julia> rule.fwd_oc.ir[]
231 1 ─ %1 = (Mooncake._primal)(_3)::Bool
    └──      goto #3 if not %1
    2 ─      goto #4
    3 ─      nothing::Nothing
    4 ┄ %5 = φ (#2 => true, #3 => false)::Bool
    │   %6 = φ (#2 => Mooncake.TestResources.__x_for_gref_test, #3 => 1)::Union{Float64, Int64}
232 │   %7 = (Mooncake._primal)(%5)::Bool
231 └──      goto #6 if not %7
232 5 ─      π (%6, Dual{Float64, Float64})
    └──      unreachable
    6 ─      nothing::Nothing
    7 ─      π (%6, Dual{Int64, NoTangent})
    └──      unreachable

Observe that %6 is inferred to have type Union{Float64, Int64}. However, the PiNodes both tell the compiler to assume that %6 is either a Dual{Float64, Float64} or a Dual{Int, NoTangent}, depending on which branch was taken. In both cases, what the PiNodes say doesn't align with what the PhiNode produces. This looks to me like maybe the PhiNode is not being "dual-ed" properly? Do you think this is true?

As to the segfault: I'm pretty sure this is a result of the PiNode asserting the wrong type. This is one of those places that things go really wrong if we make a mistake.

Segfaults scare me, so I ensure that improperly-implemented rules don't ever cause segfaults in hard-to-locate places by shoving a typeassert in the IR after a rule call in reverse-mode, but I think for internal things that users don't extend (like PiNodes and PhiNodes) we just have to get it right.
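For concreteness, a hedged sketch of what such a typeassert insertion could look like (insert_typeassert_after! is a made-up name, and the details of Mooncake's reverse-mode implementation may differ):

const CC = Core.Compiler

# Hedged sketch: insert `typeassert(%k, T)` immediately after the statement at SSA position
# k, so a rule returning the wrong type raises a TypeError instead of corrupting IR that
# assumes the asserted type.
function insert_typeassert_after!(ir::CC.IRCode, k::Int, T::Type)
    ssa = CC.SSAValue(k)
    inst = CC.NewInstruction(Expr(:call, Core.typeassert, ssa, T), T)
    CC.insert_node!(ir, ssa, inst, true)  # `true` attaches the new node after `ssa`
    return ir
end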

@gdalle (Collaborator, Author) commented Feb 5, 2025

This looks to me like maybe the PhiNode is not being "dual-ed" properly? Do you think this is true?

Yes, that is indeed the source (or at least a source) of the error, well spotted.

I'm running into a deeper issue when I try to fix it: when we get the signature of a method to replace an Expr statement with the correct rule, we base our retrieval of the primal type on the primal_ir. But since we now insert statements into the dual_ir, they go out of sync so this line is wrong. I wish there was a way to abandon the primal_ir entirely, or at least to avoid maintaining a mapping from old to new SSAValues, but it's not looking good. Any suggestions?

@willtebbutt (Member):

Ah, right, yes. I handle this in reverse-mode by making the "name" of a statement invariant under insertion / deletion of other statements, meaning that this problem never really happens (this is one aspect of what the BBCode type does). The reason I did this is because I ran into the same kind of problem that you're currently trying to deal with (amongst others), and was unable to find a solution that I liked.

I wonder whether someone else who works with IRCode regularly (such as @vchuravy ) has thoughts on / has developed strategies to make this kind of work convenient? I'd very much like to know about them if they exist.

@gdalle (Collaborator, Author) commented Feb 5, 2025

I think last time we chatted about this with Valentin he recommended using the native Julia IR whenever possible to increase compatibility and readability, which I rather agree with. So I would love to stick with that in forward mode because it's much easier for me than it must have been for you in reverse mode.
However, the deeper problem is that the compiler is not documented. Our limited understanding of what things like IncrementalCompact do is obtained through trial and error, and by looking at nondescript fields of internal data structures. I think spending a couple of days putting the draft of #382 inside the newly-excised Compiler stdlib would add tremendous value.

@vchuravy commented Feb 5, 2025

I'd very much like to know about them if they exist.

I think in Enzyme we handle this by never touching the primal code and then creating a new IR where we copy over things and maintain a mapping from statement to statement.

@yebai (Contributor) commented Feb 6, 2025

I have one curious question for @gdalle and @willtebbutt: Can we use the same forward-mode infrastructure to perform Jax-style symbolic tracing?

Briefly, I think this involves writing a different set of frules that emit code (e.g. Julia/Jaxprs) rather than computing gradients.

@gdalle (Collaborator, Author) commented Feb 6, 2025

@willtebbutt more progress, and more bugs for you to look at:

- 17: rule for `new` with uninitialized tangents for some fields
- 28: stackoverflow (probably in method recognition and rule insertion?)
- 32: rule for `new` with namedtuple is incorrect

The stack overflow is especially puzzling to me. Here's a reproducer:

julia> using Mooncake

julia> f6 = Mooncake.TestResources.globalref_tester_6
globalref_tester_6 (generic function with 1 method)

julia> build_frule(f6)
ERROR: StackOverflowError:
Stacktrace:
    [1] _methods_by_ftype(t::Any, mt::Nothing, lim::Int64, world::UInt64, ambig::Bool, min::Core.Compiler.RefValue{…}, max::Core.Compiler.RefValue{…}, has_ambig::Core.Compiler.RefValue{…})
      @ Core.Compiler ./reflection.jl:1182
    [2] _findall
      @ ./compiler/methodtable.jl:95 [inlined]
    [3] #findall#310
      @ ./compiler/methodtable.jl:80 [inlined]
    [4] findall
      @ ./compiler/methodtable.jl:71 [inlined]
    [5] find_matching_methods(𝕃::Core.Compiler.InferenceLattice{…}, argtypes::Vector{…}, atype::Any, method_table::Core.Compiler.OverlayMethodTable, max_union_splitting::Int64, max_methods::Int64)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:312
    [6] abstract_call_gf_by_type(interp::Mooncake.MooncakeInterpreter{…}, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:25
    [7] abstract_call_gf_by_type(interp::Mooncake.MooncakeInterpreter{…}, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)
      @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/abstract_interpretation.jl:128
    [8] abstract_call_known(interp::Mooncake.MooncakeInterpreter{…}, f::Any, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:2200
    [9] abstract_call(interp::Mooncake.MooncakeInterpreter{…}, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState, max_methods::Int64)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:2282
   [10] abstract_call(interp::Mooncake.MooncakeInterpreter{…}, arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:2275
   [11] abstract_call(interp::Mooncake.MooncakeInterpreter{…}, arginfo::Core.Compiler.ArgInfo, sv::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:2420
   [12] abstract_eval_call(interp::Mooncake.MooncakeInterpreter{…}, e::Expr, vtypes::Vector{…}, sv::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:2435
   [13] abstract_eval_statement_expr(interp::Mooncake.MooncakeInterpreter{…}, e::Expr, vtypes::Vector{…}, sv::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:2451
   [14] abstract_eval_statement(interp::Mooncake.MooncakeInterpreter{…}, e::Any, vtypes::Vector{…}, sv::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:2749
   [15] abstract_eval_basic_statement(interp::Mooncake.MooncakeInterpreter{…}, stmt::Any, pc_vartable::Vector{…}, frame::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:3065
   [16] typeinf_local(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, frame::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:3319
   [17] typeinf_nocycle(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, frame::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/abstractinterpretation.jl:3401
   [18] _typeinf(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, frame::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/typeinfer.jl:244
   [19] typeinf(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, frame::Core.Compiler.InferenceState)
      @ Core.Compiler ./compiler/typeinfer.jl:215
   [20] typeinf_frame(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, mi::Core.MethodInstance, run_optimizer::Bool)
      @ Core.Compiler ./compiler/typeinfer.jl:1061
   [21] typeinf_ircode(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, mi::Core.MethodInstance, optimize_until::Nothing)
      @ Core.Compiler ./compiler/typeinfer.jl:1036
   [22] typeinf_ircode(interp::Mooncake.MooncakeInterpreter{…}, method::Method, atype::Any, sparams::Core.SimpleVector, optimize_until::Nothing)
      @ Core.Compiler ./compiler/typeinfer.jl:1031
   [23] lookup_ir(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, mi::Core.MethodInstance; optimize_until::Nothing)
      @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/ir_utils.jl:269
   [24] lookup_ir
      @ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/ir_utils.jl:266 [inlined]
   [25] generate_dual_ir(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, do_inline::Bool)
      @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:70
   [26] generate_dual_ir
      @ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:62 [inlined]
   [27] build_frule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, silence_debug_messages::Bool)
      @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:40
   [28] build_frule
      @ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:7 [inlined]
   [29] modify_fwd_ad_stmts!(dual_stmt::Expr, primal_stmt::Expr, dual_ir::Core.Compiler.IncrementalCompact, primal_ir::Core.Compiler.IncrementalCompact, dual_ssa::Int64, primal_ssa::Int64, interp::Mooncake.MooncakeInterpreter{…}; debug_mode::Bool)
      @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:317
   [30] generate_dual_ir(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, do_inline::Bool)
      @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:97
--- the above 5 lines are repeated 962 more times ---
 [4841] generate_dual_ir
      @ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:62 [inlined]
 [4842] build_frule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
      @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:40
 [4843] build_frule
      @ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:7 [inlined]
 [4844] build_frule(args::Function; debug_mode::Bool)
      @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:4
Some type information was truncated. Use `show(err)` to see complete types.

@willtebbutt (Member):

I've figured out 28, and will comment below once I've got a sense of what's going on with the others.

28

So the problem here is that sum calls Base.mapreduce_impl(::typeof(identity), ::typeof(Base.add_sum), ::Vector{Float64}, ::Int64, ::Int64, ::Int64), which calls itself. With the way things are currently implemented, the fact that the MethodInstance for Base.mapreduce_impl(::typeof(identity), ::typeof(Base.add_sum), ::Vector{Float64}, ::Int64, ::Int64, ::Int64) calls itself leads to a stack overflow.

This is probably a good opportunity to discuss the purpose of

mutable struct LazyDerivedRule{primal_sig,Trule}

If you ctrl+f through that file, you'll find that I construct one of these when I encounter an invoke statement, rather than a DerivedRRule directly. LazyDerivedRule knows its input and output types, which means that inference still works in the rule that's being constructed using it, but it defers constructing the OpaqueClosure which actually runs it until it is first called (hence the use of the term "lazy"). Doing this prevents infinite recursion when constructing rules.

It also has the advantage of meaning that the IR to do AD for a particular MethodInstance / signature is never generated unless it is actually used (dynamically). This means that lots of code never gets generated, which has the great advantage that, if there's some function that Mooncake can't handle properly due to e.g. the presence of PhiCNodes and UpsilonNodes, but which never actually gets run, the whole thing doesn't fall over. I don't know how important this will be for forwards mode, because I think you can probably support these nodes more straightforwardly. I think it probably also improves compile times, because less codegen happens.

I think I probably skipped over this when we first spoke about doing forwards mode in order to try to limit the amount of information I was trying to convey to a manageable quantity haha.

The upshot of all of this is that you should be able to employ the same strategy in forwards mode.
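To make the idea concrete, here is a hedged sketch of what a lazily-derived forward rule could look like (LazyFRuleSketch and the single-argument build_frule call are illustrative assumptions, not the PR's actual API):

# Hedged sketch of the lazy-derivation idea for forward mode, mirroring LazyDerivedRule.
mutable struct LazyFRuleSketch{Tsig,Trule}
    mi::Core.MethodInstance
    debug_mode::Bool
    rule::Trule
    function LazyFRuleSketch{Tsig,Trule}(mi, debug_mode) where {Tsig,Trule}
        return new{Tsig,Trule}(mi, debug_mode)  # `rule` is left #undef until first call
    end
end

function (l::LazyFRuleSketch)(args::Vararg{Any,N}) where {N}
    # Derive the frule on first use only, so self-calling methods (like mapreduce_impl)
    # don't trigger unbounded rule derivation while the outer rule is being built.
    isdefined(l, :rule) || (l.rule = build_frule(l.mi; debug_mode=l.debug_mode))
    return l.rule(args...)
end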

If you look a few lines up from the LazyDerivedRule stuff you'll find DynamicDerivedRule, which is how I go about handling dynamic dispatch -- I imagine you're going to hit dynamic stuff soon, so it will be relevant for that.

@willtebbutt (Member):

17

I think this is probably the first example of a :new instruction for a struct with potentially un-initialised fields.

Background: you can find out (statically) how many fields of a type must always be initialised by calling Nmin = Core.Compiler.datatype_min_ninitialized(P) on your primal type P. I make use of this function inside Mooncake.always_initialised, Mooncake.is_always_initialised, and Mooncake.is_always_fully_initialised.
If the number returned by this function is equal to N = fieldcount(P), then all fields must always be initialised. If it is less, then the final N - Nmin fields may be uninitialised. If it is ever the case that Nmin > N, then there's a bug in Julia.

So, in the current implementation of frule!! for _new_, the call to _new_ to construct the primal is fine, because _new_ just calls out to :new, which will not complain provided that you provide at least Nmin arguments.

Tangents are different. Recall that they just wrap a NamedTuple, which must have field names equal to those of the primal type. Since NamedTuples must always be fully initialised (N == Nmin, always), we handle partial initialisation by making the fieldtype of any field of the NamedTuple whose corresponding primal could be uninitialised a PossiblyUninitTangent{tangent_type(primal_field_type)}. PossiblyUninitTangent can be uninitialised, so this gives us the flexibility to construct a Tangent with as few as Nmin fields.

For example, the primal type in question here is Mooncake.TestResources.StructFoo, for which

julia> fieldcount(Mooncake.TestResources.StructFoo)
2

julia> Core.Compiler.datatype_min_ninitialized(Mooncake.TestResources.StructFoo)
1

julia> Mooncake.always_initialised(Mooncake.TestResources.StructFoo)
(true, false)

julia> tangent_type(Mooncake.TestResources.StructFoo)
Tangent{@NamedTuple{a, b::PossiblyUninitTangent{Vector{Float64}}}}

So it has 2 fields, of which 1 must always be initialised. Mooncake.always_initialised reflects this (see docstring for semantics).
The tangent type is a Tangent whose underlying NamedTuple has a field a, which can be anything, and a field b, which is a PossiblyUninitTangent.
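As a quick cross-check of the above, here is a hedged helper (not part of Mooncake) that lists which trailing fields :new may leave uninitialised:

using Mooncake

# Hedged sketch: list the names of the trailing fields of P that :new may leave
# uninitialised, using the Core.Compiler function referenced above.
function maybe_uninitialised_fields(P::DataType)
    Nmin = Core.Compiler.datatype_min_ninitialized(P)
    # Fields 1:Nmin must always be initialised; fields Nmin+1:fieldcount(P) may not be.
    return [fieldname(P, n) for n in (Nmin + 1):fieldcount(P)]
end

maybe_uninitialised_fields(Mooncake.TestResources.StructFoo)  # expected to give [:b]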

In reverse-mode, we construct FData rather than a Tangent, but the functionality is essentially identical. If you take a look at

@generated function build_fdata(::Type{P}, x::Tuple, fdata::Tuple) where {P}
    names = fieldnames(P)
    fdata_exprs = map(eachindex(names)) do n
        F = fdata_field_type(P, n)
        if n <= length(fdata.parameters)
            data_expr = Expr(:call, __get_data, P, :x, :fdata, n)
            return F <: PossiblyUninitTangent ? Expr(:call, F, data_expr) : data_expr
        else
            return :($F())
        end
    end
    F_out = fdata_type(tangent_type(P))
    return :($F_out(NamedTuple{$names}($(Expr(:call, tuple, fdata_exprs...)))))
end

which is called from the rrule!! for _new_, you'll see how this works in reverse-mode.
Let me know if the behaviour isn't clear from the above context.

@willtebbutt (Member):

(I'm going to make adding these notes to the dev docs a todo item for tomorrow morning).

@willtebbutt (Member) commented Feb 6, 2025

I have one curious question for @gdalle and @willtebbutt: Can we use the same forward-mode infrastructure to perform Jax-style symbolic tracing?

@yebai I imagine something similar would work. Perhaps we could have a proper discussion in a separate issue?

@willtebbutt (Member):

I'll take a look at 32 in the morning -- @gdalle hopefully this will keep you occupied until then!

@yebai (Contributor) commented Feb 6, 2025

@yebai I imagine something similar would work. Perhaps we could have a proper discussion in a separate issue?

No problem -- please create an issue for it to continue the discussion.

It's not meant to be a distraction, but rather to get some symbolic-tracing design considerations into the forward-mode autograd infrastructure so we can build upon it later.

@willtebbutt (Member):

I've opened a discussion re tracing: #461

@willtebbutt (Member):

32

I think this is just a case of the constructor for Tangent requiring that you give it a NamedTuple, rather than a Tuple? Presumably this would also apply to 17.

@gdalle (Collaborator, Author) commented Feb 7, 2025

Okay, 17 and 32 are fixed, and 28 no longer recurses infinitely. However, I think 28 unearthed another issue with statement insertion in the IncrementalCompact: the basic blocks are not recomputed correctly (regardless of how I set reverse_affinity in Core.Compiler.insert_node_here!).

Here's an example:

using Mooncake
f6_inner = Mooncake.TestResources.globalref_tester_6_inner
Base.code_ircode(f6_inner, (Vector{Float64},))
rule = build_frule(f6_inner, rand(10))

Which gives you the following primal IR:

julia> Base.code_ircode(f6_inner, (Vector{Float64},))
1-element Vector{Any}:
246 1 ── %1  = Base.identity::typeof(identity)
    │    %2  = Base.add_sum::typeof(Base.add_sum)
    │    %3  = Base.getfield(_2, :size)::Tuple{Int64}
    │    %4  = $(Expr(:boundscheck, true))::Bool
    │    %5  = Base.getfield(%3, 1, %4)::Int64
    │    %6  = (%5 === 0)::Bool
    └───       goto #3 if not %6
    2 ──       goto #31
    3 ── %9  = (%5 === 1)::Bool
    └───       goto #10 if not %9
    4 ── %11 = $(Expr(:boundscheck, false))::Bool
    └───       goto #8 if not %11
    5 ── %13 = Base.sub_int(1, 1)::Int64
    │    %14 = Base.bitcast(Base.UInt, %13)::UInt64
    │    %15 = Base.getfield(_2, :size)::Tuple{Int64}
    │    %16 = $(Expr(:boundscheck, true))::Bool
    │    %17 = Base.getfield(%15, 1, %16)::Int64
    │    %18 = Base.bitcast(Base.UInt, %17)::UInt64
    │    %19 = Base.ult_int(%14, %18)::Bool
    └───       goto #7 if not %19
    6 ──       goto #8
    7 ── %22 = Core.tuple(1)::Tuple{Int64}
    │          invoke Base.throw_boundserror(_2::Vector{Float64}, %22::Tuple{Int64})::Union{}
    └───       unreachable
    8 ┄─ %25 = Base.getfield(_2, :ref)::MemoryRef{Float64}
    │    %26 = Base.memoryrefnew(%25, 1, false)::MemoryRef{Float64}
    │    %27 = Base.memoryrefget(%26, :not_atomic, false)::Float64
    └───       goto #9
    9 ──       goto #31
    10 ─ %30 = Base.slt_int(%5, 16)::Bool
    └───       goto #30 if not %30
    11 ─ %32 = $(Expr(:boundscheck, false))::Bool
    └───       goto #15 if not %32
    12 ─ %34 = Base.sub_int(1, 1)::Int64
    │    %35 = Base.bitcast(Base.UInt, %34)::UInt64
    │    %36 = Base.getfield(_2, :size)::Tuple{Int64}
    │    %37 = $(Expr(:boundscheck, true))::Bool
    │    %38 = Base.getfield(%36, 1, %37)::Int64
    │    %39 = Base.bitcast(Base.UInt, %38)::UInt64
    │    %40 = Base.ult_int(%35, %39)::Bool
    └───       goto #14 if not %40
    13 ─       goto #15
    14 ─ %43 = Core.tuple(1)::Tuple{Int64}
    │          invoke Base.throw_boundserror(_2::Vector{Float64}, %43::Tuple{Int64})::Union{}
    └───       unreachable
    15 ┄ %46 = Base.getfield(_2, :ref)::MemoryRef{Float64}
    │    %47 = Base.memoryrefnew(%46, 1, false)::MemoryRef{Float64}
    │    %48 = Base.memoryrefget(%47, :not_atomic, false)::Float64
    └───       goto #16
    16 ─ %50 = $(Expr(:boundscheck, false))::Bool
    └───       goto #20 if not %50
    17 ─ %52 = Base.sub_int(2, 1)::Int64
    │    %53 = Base.bitcast(Base.UInt, %52)::UInt64
    │    %54 = Base.getfield(_2, :size)::Tuple{Int64}
    │    %55 = $(Expr(:boundscheck, true))::Bool
    │    %56 = Base.getfield(%54, 1, %55)::Int64
    │    %57 = Base.bitcast(Base.UInt, %56)::UInt64
    │    %58 = Base.ult_int(%53, %57)::Bool
    └───       goto #19 if not %58
    18 ─       goto #20
    19 ─ %61 = Core.tuple(2)::Tuple{Int64}
    │          invoke Base.throw_boundserror(_2::Vector{Float64}, %61::Tuple{Int64})::Union{}
    └───       unreachable
    20 ┄ %64 = Base.getfield(_2, :ref)::MemoryRef{Float64}
    │    %65 = Base.memoryrefnew(%64, 2, false)::MemoryRef{Float64}
    │    %66 = Base.memoryrefget(%65, :not_atomic, false)::Float64
    └───       goto #21
    21 ─ %68 = Base.add_float(%48, %66)::Float64
    22 ┄ %69 = φ (#21 => %68, #28 => %92)::Float64
    │    %70 = φ (#21 => 2, #28 => %73)::Int64
    │    %71 = Base.slt_int(%70, %5)::Bool
    └───       goto #29 if not %71
    23 ─ %73 = Base.add_int(%70, 1)::Int64
    │    %74 = $(Expr(:boundscheck, false))::Bool
    └───       goto #27 if not %74
    24 ─ %76 = Base.sub_int(%73, 1)::Int64
    │    %77 = Base.bitcast(Base.UInt, %76)::UInt64
    │    %78 = Base.getfield(_2, :size)::Tuple{Int64}
    │    %79 = $(Expr(:boundscheck, true))::Bool
    │    %80 = Base.getfield(%78, 1, %79)::Int64
    │    %81 = Base.bitcast(Base.UInt, %80)::UInt64
    │    %82 = Base.ult_int(%77, %81)::Bool
    └───       goto #26 if not %82
    25 ─       goto #27
    26 ─ %85 = Core.tuple(%73)::Tuple{Int64}
    │          invoke Base.throw_boundserror(_2::Vector{Float64}, %85::Tuple{Int64})::Union{}
    └───       unreachable
    27 ┄ %88 = Base.getfield(_2, :ref)::MemoryRef{Float64}
    │    %89 = Base.memoryrefnew(%88, %73, false)::MemoryRef{Float64}
    │    %90 = Base.memoryrefget(%89, :not_atomic, false)::Float64
    └───       goto #28
    28 ─ %92 = Base.add_float(%69, %90)::Float64
    └───       goto #22
    29 ─       goto #31
    30 ─ %95 = invoke Base.mapreduce_impl(%1::typeof(identity), %2::typeof(Base.add_sum), _2::Vector{Float64}, 1::Int64, %5::Int64, 1024::Int64)::Float64
    └───       goto #31
    31 ┄ %97 = φ (#2 => 0.0, #9 => %27, #29 => %69, #30 => %95)::Float64
    └───       goto #32
    32 ─       goto #33
    33 ─       goto #34
    34 ─       goto #35
    35 ─       goto #36
    36 ─       goto #37
    37 ─       goto #38
    38 ─       goto #39
    39 ─       goto #40
    40 ─       return %97
     => Float64

and the following error when building the dual IR:

julia> rule = build_frule(f6_inner, rand(10))
Block 4 successors (Array{Int64, 1}(dims=(2,), mem=Memory{Int64}(2, 0x33adf1010)[8, 5])), does not match GotoNode terminator (8)
246 1 ── %1   = Base.identity::typeof(identity)
    │    %2   = Base.add_sum::typeof(Base.add_sum)
    │    %3   = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
    │    %4   = (DualArguments(frule!!))(Mooncake.lgetfield, %3, Val{1}(), Val{true}())::Any
    │    %5   = (DualArguments(frule!!))(===, %4, 0)::Any
    │    %6   = (Mooncake._primal)(%5)::Any
    └───        goto #3 if not %6
    2 ──        goto #31
    3 ── %9   = (DualArguments(frule!!))(===, %4, 1)::Any
    │    %10  = (Mooncake._primal)(%9)::Any
    └───        goto #10 if not %10
    4 ──        goto #8
    5 ── %13  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.sub_int, 1, 1)::Any
    │    %14  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %13)::Any
    │    %15  = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
    │    %16  = (DualArguments(frule!!))(Mooncake.lgetfield, %15, Val{1}(), Val{true}())::Any
    │    %17  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %16)::Any
    │    %18  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.ult_int, %14, %17)::Any
    │    %19  = (Mooncake._primal)(%18)::Any
    └───        goto #7 if not %19
    6 ──        goto #8
    7 ── %22  = (DualArguments(frule!!))(tuple, 1)::Any
    │           (DualArguments(frule!!))(Base.throw_boundserror, _3, %22)::Any
    │    %24  = unreachable
    └───        return %24
    8 ┄─ %26  = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:ref}())::Any
    │    %27  = (DualArguments(frule!!))(Core.memoryrefnew, %26, 1, false)::Any
    │    %28  = (DualArguments(frule!!))(Mooncake.lmemoryrefget, %27, Val{:not_atomic}(), Val{false}())::Any
    └───        goto #9
    9 ──        goto #31
    10 ─ %31  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.slt_int, %4, 16)::Any
    │    %32  = (Mooncake._primal)(%31)::Any
    └───        goto #30 if not %32
    11 ─        goto #15
    12 ─ %35  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.sub_int, 1, 1)::Any
    │    %36  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %35)::Any
    │    %37  = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
    │    %38  = (DualArguments(frule!!))(Mooncake.lgetfield, %37, Val{1}(), Val{true}())::Any
    │    %39  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %38)::Any
    │    %40  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.ult_int, %36, %39)::Any
    │    %41  = (Mooncake._primal)(%40)::Any
    └───        goto #14 if not %41
    13 ─        goto #15
    14 ─ %44  = (DualArguments(frule!!))(tuple, 1)::Any
    │           (DualArguments(frule!!))(Base.throw_boundserror, _3, %44)::Any
    │    %46  = unreachable
    └───        return %46
    15 ┄ %48  = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:ref}())::Any
    │    %49  = (DualArguments(frule!!))(Core.memoryrefnew, %48, 1, false)::Any
    │    %50  = (DualArguments(frule!!))(Mooncake.lmemoryrefget, %49, Val{:not_atomic}(), Val{false}())::Any
    └───        goto #16
    16 ─        goto #20
    17 ─ %53  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.sub_int, 2, 1)::Any
    │    %54  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %53)::Any
    │    %55  = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
    │    %56  = (DualArguments(frule!!))(Mooncake.lgetfield, %55, Val{1}(), Val{true}())::Any
    │    %57  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %56)::Any
    │    %58  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.ult_int, %54, %57)::Any
    │    %59  = (Mooncake._primal)(%58)::Any
    └───        goto #19 if not %59
    18 ─        goto #20
    19 ─ %62  = (DualArguments(frule!!))(tuple, 2)::Any
    │           (DualArguments(frule!!))(Base.throw_boundserror, _3, %62)::Any
    │    %64  = unreachable
    └───        return %64
    20 ┄ %66  = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:ref}())::Any
    │    %67  = (DualArguments(frule!!))(Core.memoryrefnew, %66, 2, false)::Any
    │    %68  = (DualArguments(frule!!))(Mooncake.lmemoryrefget, %67, Val{:not_atomic}(), Val{false}())::Any
    └───        goto #21
    21 ─ %70  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.add_float, %50, %68)::Any
    22 ┄ %71  = φ (#21 => %70, #28 => %95)::Any
    │    %72  = φ (#21 => 2, #28 => %76)::Any
    │    %73  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.slt_int, %72, %4)::Any
    │    %74  = (Mooncake._primal)(%73)::Any
    └───        goto #29 if not %74
    23 ─ %76  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.add_int, %72, 1)::Any
    └───        goto #27
    24 ─ %78  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.sub_int, %76, 1)::Any
    │    %79  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %78)::Any
    │    %80  = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:size}())::Any
    │    %81  = (DualArguments(frule!!))(Mooncake.lgetfield, %80, Val{1}(), Val{true}())::Any
    │    %82  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.bitcast, Base.UInt, %81)::Any
    │    %83  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.ult_int, %79, %82)::Any
    │    %84  = (Mooncake._primal)(%83)::Any
    └───        goto #26 if not %84
    25 ─        goto #27
    26 ─ %87  = (DualArguments(frule!!))(tuple, %76)::Any
    │           (DualArguments(frule!!))(Base.throw_boundserror, _3, %87)::Any
    │    %89  = unreachable
    └───        return %89
    27 ┄ %91  = (DualArguments(frule!!))(Mooncake.lgetfield, _3, Val{:ref}())::Any
    │    %92  = (DualArguments(frule!!))(Core.memoryrefnew, %91, %76, false)::Any
    │    %93  = (DualArguments(frule!!))(Mooncake.lmemoryrefget, %92, Val{:not_atomic}(), Val{false}())::Any
    └───        goto #28
    28 ─ %95  = (DualArguments(frule!!))(Mooncake.IntrinsicsWrappers.add_float, %71, %93)::Any
    └───        goto #22
    29 ─        goto #31
    30 ─ %98  = (DualArguments(Mooncake.LazyFRule{Tuple{typeof(Base.mapreduce_impl), typeof(identity), typeof(Base.add_sum), Vector{Float64}, Int64, Int64, Int64}, Mooncake.DerivedFRule{MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Dual{typeof(Base.mapreduce_impl), NoTangent}, Dual{typeof(identity), NoTangent}, Dual{typeof(Base.add_sum), NoTangent}, Dual{Vector{Float64}, Vector{Float64}}, Dual{Int64, NoTangent}, Dual{Int64, NoTangent}, Dual{Int64, NoTangent}}, Dual{Float64, Float64}}}}}(false, MethodInstance for Base.mapreduce_impl(::typeof(identity), ::typeof(Base.add_sum), ::Vector{Float64}, ::Int64, ::Int64, ::Int64), #undef)))(Base.mapreduce_impl, %1, %2, _3, 1, %4, 1024)::Any
    └───        goto #31
    31 ┄ %100 = φ (#2 => 0.0, #9 => %28, #29 => %71, #30 => %98)::Any
    └───        goto #32
    32 ─        goto #33
    33 ─        goto #34
    34 ─        goto #35
    35 ─        goto #36
    36 ─        goto #37
    37 ─        goto #38
    38 ─        goto #39
    39 ─        goto #40
    40 ─ %110 = (Mooncake._dual)(%100)::Any
    └───        return %110
    ERROR: 
Stacktrace:
  [1] error(s::String)
    @ Core.Compiler ./error.jl:35
  [2] verify_ir(ir::Core.Compiler.IRCode, print::Bool, allow_frontend_forms::Bool, 𝕃ₒ::Core.Compiler.PartialsLattice{…})
    @ Core.Compiler ./compiler/ssair/verify.jl:191
  [3] verify_ir
    @ ./compiler/ssair/verify.jl:98 [inlined]
  [4] generate_dual_ir(interp::Mooncake.MooncakeInterpreter{DefaultCtx}, sig_or_mi::Type; debug_mode::Bool, do_inline::Bool)
    @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:113
  [5] generate_dual_ir
    @ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:62 [inlined]
  [6] build_frule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
    @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:40
  [7] build_frule
    @ ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:7 [inlined]
  [8] build_frule(::Function, ::Vararg{Any}; debug_mode::Bool)
    @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:4
  [9] build_frule(::Function, ::Vector{Float64})
    @ Mooncake ~/Documents/GitHub/Julia/Mooncake.jl/src/interpreter/s2s_forward_mode_ad.jl:1
 [10] top-level scope
    @ ~/Documents/GitHub/Julia/Mooncake.jl/test/playground.jl:4
Some type information was truncated. Use `show(err)` to see complete types.

The error I get:

Block 4 successors (Array{Int64, 1}(dims=(2,), mem=Memory{Int64}(2, 0x33adf1010)[8, 5])), does not match GotoNode terminator (8)

As you can see, in block #4, the goto #8 if not ... in the primal IR becomes a simple goto #8 in the dual IR. I think this is because the previous boundscheck instruction was somehow removed during your IR normalization, before I even touched the IR to dualize it. As a result, block #4 can no longer point to block #5 where I assume the bounds check handling is located.
In other words, this line seems like the culprit to me:

statements[n] = nothing
