UDE Training does not work with AutoZygote #20

Open
sathvikbhagavan opened this issue Apr 11, 2024 · 19 comments
Labels
bug Something isn't working

Comments

@sathvikbhagavan
Member

Describe the bug 🐞

UDE Training does not work with AutoZygote

Expected behavior

UDE training should succeed with AutoZygote.

Minimal Reproducible Example 👇

The Lotka-Volterra test in tests/lotka_volterra.jl.

Error & Stacktrace ⚠️

julia> res = solve(op, Adam(), maxiters = 10)
ERROR: MethodError: no method matching length(::ModelingToolkit.MTKParameters{Tuple{Vector{Vector{Float64}}}, Tuple{}, Tuple{Vector{Float64}}, Tuple{}, Tuple{Vector{DataType}}, Nothing, Nothing})

Closest candidates are:
  length(::LaTeXStrings.LaTeXString)
   @ LaTeXStrings ~/.julia/packages/LaTeXStrings/ZtSdh/src/LaTeXStrings.jl:115
  length(::SymbolicUtils.Code.Assignment)
   @ SymbolicUtils ~/.julia/packages/SymbolicUtils/c0xQb/src/utils.jl:225
  length(::CSTParser.EXPR)
   @ CSTParser ~/.julia/packages/CSTParser/mVfZt/src/spec.jl:278
  ...

Stacktrace:
  [1] automatic_sensealg_choice(prob::ODEProblem{…}, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, verbose::Bool)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/rXkM4/src/concrete_solve.jl:84
  [2] _concrete_solve_adjoint(::ODEProblem{…}, ::Rodas4{…}, ::Nothing, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::SciMLBase.ChainRulesOriginator; verbose::Bool, kwargs::@Kwargs{})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/rXkM4/src/concrete_solve.jl:218
  [3] _solve_adjoint(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, originator::SciMLBase.ChainRulesOriginator, args::Rodas4{…}; merge_callbacks::Bool, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1537
  [4] rrule(::typeof(DiffEqBase.solve_up), prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Rodas4{…}; kwargs::@Kwargs{})
    @ DiffEqBaseChainRulesCoreExt ~/.julia/packages/DiffEqBase/O8cUq/ext/DiffEqBaseChainRulesCoreExt.jl:26
  [5] kwcall(::@NamedTuple{}, ::typeof(ChainRulesCore.rrule), ::Zygote.ZygoteRuleConfig{…}, ::Function, ::ODEProblem{…}, ::Nothing, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::Rodas4{…})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zgT0R/src/rules.jl:140
  [6] chain_rrule_kw
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:235 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0 [inlined]
  [8] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::Nothing, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::Rodas4{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:81
  [9] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [10] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [11] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [12] #solve#51
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1003 [inlined]
 [13] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#51", ::Nothing, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{}, ::typeof(solve), ::ODEProblem{…}, ::Rodas4{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [14] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [15] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [16] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [17] solve
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:993 [inlined]
 [18] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{}, ::typeof(solve), ::ODEProblem{…}, ::Rodas4{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [19] loss
    @ ./REPL[30]:5 [inlined]
 [20] _pullback(::Zygote.Context{…}, ::typeof(loss), ::Vector{…}, ::Tuple{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [21] _apply
    @ ./boot.jl:838 [inlined]
 [22] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [23] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [24] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/NjslX/src/scimlfunctions.jl:3649 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::Vector{…}, ::Tuple{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [26] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [27] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [28] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [29] #37
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:90 [inlined]
 [30] _pullback(ctx::Zygote.Context{false}, f::OptimizationZygoteExt.var"#37#55"{OptimizationFunction{}, OptimizationBase.ReInitCache{}}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [31] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [32] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [33] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [34] #39
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93 [inlined]
 [35] _pullback(ctx::Zygote.Context{false}, f::OptimizationZygoteExt.var"#39#57"{Tuple{}, OptimizationZygoteExt.var"#37#55"{}}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [36] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
 [37] pullback
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
 [38] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
 [39] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{OptimizationFunction{}, OptimizationBase.ReInitCache{}}})(::Vector{Float64}, ::Vector{Float64})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93
 [40] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [41] macro expansion
    @ ~/.julia/packages/Optimization/5DEdF/src/utils.jl:32 [inlined]
 [42] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [43] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:180
 [44] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{})
    @ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:96
 [45] top-level scope
    @ REPL[35]:1
Some type information was truncated. Use `show(err)` to see complete types.
@sathvikbhagavan sathvikbhagavan added the bug Something isn't working label Apr 11, 2024
@ChrisRackauckas
Member

This is SciML/SciMLSensitivity.jl#1010

@hstrey

hstrey commented Jul 3, 2024

After updating SciMLSensitivity.jl:
pkg> st
  [336ed68f] CSV v0.10.14
  [13f3f980] CairoMakie v0.12.4
  [b0b7db55] ComponentArrays v0.15.14
  [a93c6f00] DataFrames v1.6.1
  [0c46a032] DifferentialEquations v7.13.0
  [a98d9a8b] Interpolations v0.15.1
  [033835bb] JLD2 v0.4.48
  [b2108857] Lux v0.5.59
  [961ee093] ModelingToolkit v9.23.0
  [f162e290] ModelingToolkitNeuralNets v1.0.2
  [16a59e39] ModelingToolkitStandardLibrary v2.7.2
  [7f7a1694] Optimization v3.26.3
  [36348300] OptimizationOptimJL v0.3.2
  [42dfb2eb] OptimizationOptimisers v0.2.1
  [1dea7af3] OrdinaryDiffEq v6.85.0
  [18e31ff7] Peaks v0.5.2
  [1ed8b502] SciMLSensitivity v7.62.0
  [53ae85a6] SciMLStructures v1.4.1
  [860ef19b] StableRNGs v1.0.2
  [2efcf032] SymbolicIndexingInterface v0.3.22
  [e88e6eb3] Zygote v0.6.70
  [9e88b42a] SerializationI

I reran the Friction example with AutoZygote and I got the same error as before:
ERROR: MethodError: no method matching length(::ModelingToolkit.MTKParameters{…})

Closest candidates are:
length(::Combinatorics.FixedPartitions)
@ Combinatorics ~/.julia/packages/Combinatorics/Udg6X/src/partitions.jl:96
length(::Core.Compiler.InstructionStream)
@ Base show.jl:2777
length(::SymbolicUtils.Code.AtIndex)
@ SymbolicUtils ~/.julia/packages/SymbolicUtils/dtCid/src/utils.jl:227
...

Stacktrace:
[1] _empty(x::ModelingToolkit.MTKParameters{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:171
[2] map
@ ./tuple.jl:291 [inlined]
[3] adjoint
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:204 [inlined]
[4] adjoint(context::Zygote.Context{…}, 463::typeof(Core._apply_iterate), 464::typeof(iterate), f::Function, args::ModelingToolkit.MTKParameters{…})
@ Zygote ./none:0
[5] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[6] replace
@ ~/.julia/packages/ModelingToolkit/nadwo/src/systems/parameter_buffer.jl:288 [inlined]
[7] _pullback(::Zygote.Context{…}, ::typeof(SciMLStructures.replace), ::Tunable, ::ModelingToolkit.MTKParameters{…}, ::Vector{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[8] loss
@ ~/Documents/programming/NeurobloxSciML/friction.jl:73 [inlined]
[9] _pullback(::Zygote.Context{…}, ::typeof(loss), ::Vector{…}, ::Tuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[10] _apply
@ ./boot.jl:838 [inlined]
[11] adjoint
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
[12] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[13] OptimizationFunction
@ ~/.julia/packages/SciMLBase/rR75x/src/scimlfunctions.jl:3763 [inlined]
[14] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::Vector{…}, ::Tuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[15] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[16] adjoint
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
[17] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[18] #37
@ ~/.julia/packages/OptimizationBase/32Mb0/ext/OptimizationZygoteExt.jl:90 [inlined]
[19] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#37#55"{…}, args::Vector{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[20] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[21] adjoint
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
[22] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[23] #39
@ ~/.julia/packages/OptimizationBase/32Mb0/ext/OptimizationZygoteExt.jl:93 [inlined]
[24] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#39#57"{…}, args::Vector{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[25] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
[26] pullback
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
[27] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:147
[28] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{…}})(::Vector{Float64}, ::Vector{Float64})
@ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/32Mb0/ext/OptimizationZygoteExt.jl:93
[29] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
[30] macro expansion
@ ~/.julia/packages/Optimization/EmxXu/src/utils.jl:32 [inlined]
[31] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
[32] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:188
[33] solve(::OptimizationProblem{…}, ::Adam; kwargs::@kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:96
[34] top-level scope
@ ~/Documents/programming/NeurobloxSciML/friction.jl:101
Some type information was truncated. Use show(err) to see complete types.

@ChrisRackauckas
Member

@DhairyaLGandhi can you work on updating the example here to use AutoZygote?

@DhairyaLGandhi
Member

Yeah, this is the UDEs piece.

As a workaround, it is possible to define something like:

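# Type-piracy workaround. If `p.tunable` is a Tuple of buffers, this counts buffers, not scalar parameters.
function Base.length(p::ModelingToolkit.MTKParameters)
    return length(p.tunable)
end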

The basic issue is that MTKParameters defines getindex but not length.
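To illustrate why a `length` method built on `p.tunable` may not return the number of scalar parameters, here is a minimal sketch in plain Julia. It does not load ModelingToolkit; the hypothetical `tunable` value only mimics the `Tuple{Vector{Vector{Float64}}}` layout that `prob.p.tunable` is reported to have further down in this thread.

```julia
# Hypothetical stand-in for `prob.p.tunable`: a Tuple holding one Vector
# whose elements are Float64 vectors (3 + 2 = 5 scalar parameters).
tunable = ([[0.1, -0.2, 0.3], [0.4, 0.5]],)

# `length` on the Tuple counts its entries (buffers), not the scalars.
n_entries = length(tunable)

# Counting scalars means descending through every level of nesting:
# flattening the Tuple yields the inner Float64 vectors, whose lengths are summed.
n_scalars = sum(length, Iterators.flatten(tunable))

println("entries = ", n_entries, ", scalars = ", n_scalars)
```

Under this assumed layout the two counts differ (1 versus 5), so code that sizes arrays from `length(p)` would allocate the wrong number of entries.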

@hstrey

hstrey commented Jul 5, 2024

@DhairyaLGandhi I tried your suggestion, but then it failed with an analogous error for similar(::ModelingToolkit.MTKParameters{…}).

So I added another method for similar:

function Base.similar(p::ModelingToolkit.MTKParameters)
    return similar(p.tunable)
end

but now the run ends with an error:
ERROR: MethodError: no method matching similar(::Tuple{Vector{Vector{Float64}}})

Closest candidates are:
similar(::VSCodeServer.JuliaInterpreter.Compiled, ::Any)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.79.2/scripts/packages/JuliaInterpreter/src/types.jl:7
similar(::Type{SA}, ::Type{T}, ::StaticArraysCore.Size{S}) where {SA<:StaticArraysCore.SizedArray, T, S}
@ StaticArrays ~/.julia/packages/StaticArrays/MSJcA/src/abstractarray.jl:135
similar(::Type{A}, ::Type{T}, ::StaticArraysCore.Size{S}) where {A<:Array, T, S}
@ StaticArrays ~/.julia/packages/StaticArrays/MSJcA/src/abstractarray.jl:136
...

Stacktrace:
[1] similar(p::ModelingToolkit.MTKParameters{…})
@ Main ~/Documents/programming/NeurobloxSciML/friction.jl:23
[2] build_param_jac_config(alg::GaussAdjoint{…}, pf::Function, u::Vector{…}, p::ModelingToolkit.MTKParameters{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4YtYh/src/derivative_wrappers.jl:1069
[3] SciMLSensitivity.GaussIntegrand(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, checkpoints::Vector{…}, dgdp::Nothing)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4YtYh/src/gauss_adjoint.jl:440

p.tunable is a Tuple, which is not compatible with similar.
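The failure can be reproduced without any SciML packages. This is a minimal sketch, assuming a hypothetical `tunable` value with the same nested Tuple layout as in the error above: `similar` has no method for a Tuple, while it works once the scalars are collected into a flat `Vector{Float64}`.

```julia
# Hypothetical stand-in for `p.tunable` with the nested Tuple layout from the error.
tunable = ([[0.1, -0.2, 0.3], [0.4, 0.5]],)

# `similar` has no method for a Tuple, which reproduces the MethodError above.
similar_on_tuple_fails = try
    similar(tunable)
    false
catch err
    err isa MethodError
end

# Collecting the scalars into one Vector{Float64} gives an AbstractArray,
# and `similar` then returns an uninitialized buffer of the same size.
flat = reduce(vcat, collect(Iterators.flatten(tunable)))
buf = similar(flat)
```

This is why forwarding `similar` to `p.tunable` only moves the error: the forwarded call still receives a Tuple rather than an array.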

@DhairyaLGandhi
Member

@hstrey

hstrey commented Jul 5, 2024

I tried your branch and got the same error. It seems that all the SciMLSensitivity functions expect a flat vector for p, but p.tunable looks like this:
prob.p.tunable ([[-0.42492082715034485, -0.6530026793479919, 0.6973646879196167, 0.058748241513967514, 0.1943797767162323, -0.05171017348766327, -0.08400869369506836, 0.2643197178840637, -0.5101186037063599, -0.7106459736824036 … -0.6610037684440613, 0.1330592930316925, 0.03395846113562584, 0.6738154292106628, -0.41331109404563904, 0.14921322464942932, 0.6819839477539062, 0.3220810294151306, -0.3412023186683655, 0.21613739430904388]],)
which is a Tuple{Vector{Vector{Float64}}}
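Bridging a flat vector and this nested layout takes a flatten/rebuild pair. The sketch below is illustrative only: `flatten_tunable` and `rebuild_tunable` are hypothetical helpers, not ModelingToolkit or SciMLStructures API, and the layout is assumed to be exactly `Tuple{Vector{Vector{Float64}}}` as printed above. Conceptually this is the job that `SciMLStructures.replace(Tunable(), p, values)` (visible in the friction stack trace) performs for MTKParameters.

```julia
# Hypothetical helper: concatenate every inner vector into one Vector{Float64}.
flatten_tunable(t::Tuple) = reduce(vcat, collect(Iterators.flatten(t)))

# Hypothetical helper: rebuild the nested layout of `template` from the flat
# vector `x`, consuming `x` in the same order that `flatten_tunable` produced it.
function rebuild_tunable(template::Tuple, x::AbstractVector{<:Real})
    offset = 0
    rebuilt = Vector{Vector{Vector{Float64}}}()
    for buffers in template
        group = Vector{Vector{Float64}}()
        for v in buffers
            push!(group, x[offset+1:offset+length(v)])
            offset += length(v)
        end
        push!(rebuilt, group)
    end
    @assert offset == length(x) "flat vector length does not match the template"
    return Tuple(rebuilt)
end

tunable = ([[0.1, -0.2, 0.3], [0.4, 0.5]],)
x = flatten_tunable(tunable)                     # flat 5-element Vector{Float64}
new_tunable = rebuild_tunable(tunable, 2 .* x)   # same nested layout, doubled values
```

An optimizer or sensitivity algorithm can then work on the flat `x`, with the rebuild step applied only where the nested structure is needed.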

@hstrey

hstrey commented Jul 5, 2024

I even tried to get around this with:

function Base.similar(p::ModelingToolkit.MTKParameters)
    return similar(p.tunable[1][1])
end

but then it fails at the next function:

ERROR: MethodError: no method matching jacobian!(::Matrix{…}, ::SciMLBase.ParamJacobianWrapper{…}, ::ModelingToolkit.MTKParameters{…}, ::Vector{…}, ::GaussAdjoint{…}, ::FiniteDiff.JacobianCache{…})

Closest candidates are:
jacobian!(::AbstractMatrix{<:Number}, ::Any, ::AbstractArray{<:Number}, ::Union{Nothing, AbstractArray{<:Number}}, ::SciMLBase.AbstractOverloadingSensitivityAlgorithm, ::Any)
@ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/derivative_wrappers.jl:147

Stacktrace:
[1] vec_pjac!(out::RecursiveArrayTools.ArrayPartition{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.GaussIntegrand{…})
@ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/gauss_adjoint.jl:469
[2] GaussIntegrand
@ ~/Documents/programming/SciMLSensitivity.jl/src/gauss_adjoint.jl:519 [inlined]
[3] (::SciMLSensitivity.var"#262#263"{…})(out::RecursiveArrayTools.ArrayPartition{…}, u::Vector{…}, t::Float64, integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/gauss_adjoint.jl:560
[4] (::DiffEqCallbacks.SavingIntegrandSumAffect{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})

@DhairyaLGandhi
Member

Looking at the stacktrace, the culprit is in the default sensealg selection with SDEProblem. https://github.com/SciML/SciMLSensitivity.jl/tree/dg/length should fix that.

Running tests/lotka_volterra.jl as in the OP:

julia> res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
retcode: Default
u: 57-element Vector{Float64}:
  0.17041097372357306
  0.06721610894572924
  1.5036621719846985
  0.6223167933686936
  2.406843634370448
 -1.2392234160442255
  1.0211360153466225
 -0.26749421630122894
  1.2072691961676691
 -0.3280020513366658
 -0.3199580480281966
  0.15962631397576563
 -0.12814730955995285
  0.25890635293263664
 -0.9340225939644783
 -0.7582786846498972
   ⋮
  0.33400507029273274
  0.37703694871113425
 -0.38265915595848343
  0.2532281365300615
 -1.0286697496982982
  1.9856277316142081
 -1.2304071436862438
  2.7505848772204997
 -1.663046900479545
  2.3058247613525467
  1.4008403661001196
 -2.0467838051849423
 -0.6461702206670253
  2.6279530122972967
 -0.5972921571545664
  1.723921054838542

@ChrisRackauckas
Member

Why SDEProblem? it's an ODE?

@hstrey

hstrey commented Jul 5, 2024

@DhairyaLGandhi I tried your branch "dg/length" on the test/lotka_volterra.jl using AutoZygote, and this is what I got:

ERROR: ForwardDiffSensitivity assumes the AbstractArray interface for p. Thus while
DifferentialEquations.jl can support any parameter struct type, usage
with ForwardDiffSensitivity requires that p could be a valid
type for being the initial condition u0 of an array. This means that
many simple types, such as Tuples and NamedTuples, will work as
parameters in normal contexts but will fail during ForwardDiffSensitivity
construction. To work around this issue for complicated cases like nested structs,
look into defining p using AbstractArray libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl.

Stacktrace:
[1] _concrete_solve_adjoint(::ODEProblem{…}, ::Rodas4{…}, ::ForwardDiffSensitivity{…}, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::SciMLBase.ChainRulesOriginator; saveat::Vector{…}, kwargs::@kwargs{…})
@ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/concrete_solve.jl:788
[2] _concrete_solve_adjoint(::ODEProblem{…}, ::Rodas4{…}, ::Nothing, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::SciMLBase.ChainRulesOriginator; verbose::Bool, kwargs::@kwargs{…})
@ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/concrete_solve.jl:278
[3] _concrete_solve_adjoint
@ ~/Documents/programming/SciMLSensitivity.jl/src/concrete_solve.jl:245 [inlined]
[4] #_solve_adjoint#75
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1537 [inlined]
[5] _solve_adjoint
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1510 [inlined]

@hstrey

hstrey commented Jul 5, 2024

@DhairyaLGandhi In my previous attempt with my own similar function, I omitted these messages, which appeared before the final failure:

┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/concrete_solve.jl:24

┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/concrete_solve.jl:67

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/concrete_solve.jl:207
ERROR: MethodError: no method matching jacobian!(::Matrix{…}, ::SciMLBase.ParamJacobianWrapper{…}, ::ModelingToolkit.MTKParameters{…}, ::Vector{…}, ::GaussAdjoint{…}, ::FiniteDiff.JacobianCache{…})

Closest candidates are:
jacobian!(::AbstractMatrix{<:Number}, ::Any, ::AbstractArray{<:Number}, ::Union{Nothing, AbstractArray{<:Number}}, ::SciMLBase.AbstractOverloadingSensitivityAlgorithm, ::Any)
@ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/derivative_wrappers.jl:147

Stacktrace:
[1] vec_pjac!(out::RecursiveArrayTools.ArrayPartition{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.GaussIntegrand{…})
@ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/gauss_adjoint.jl:469
[2] GaussIntegrand
@ ~/Documents/programming/SciMLSensitivity.jl/src/gauss_adjoint.jl:519 [inlined]
[3] (::SciMLSensitivity.var"#262#263"{…})(out::RecursiveArrayTools.ArrayPartition{…}, u::Vector{…}, t::Float64, integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ SciMLSensitivity ~/Documents/programming/SciMLSensitivity.jl/src/gauss_adjoint.jl:560
[4] (::DiffEqCallbacks.SavingIntegrandSumAffect{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ DiffEqCallbacks ~/.julia/packages/DiffEqCallbacks/9fKPq/src/integrating_sum.jl:50

@DhairyaLGandhi
Member

> Why SDEProblem? it's an ODE?

It's the same dispatch for both; this is the automatic sensealg selection. Apologies for the confusion. Yes, I think I understand what is happening there, and I'll add a fix for this.

@acertain

I'm getting

ERROR: ForwardDiffSensitivity assumes the `AbstractArray` interface for `p`. Thus while
DifferentialEquations.jl can support any parameter struct type, usage
with ForwardDiffSensitivity requires that `p` could be a valid
type for being the initial condition `u0` of an array. This means that
many simple types, such as `Tuple`s and `NamedTuple`s, will work as
parameters in normal contexts but will fail during ForwardDiffSensitivity
construction. To work around this issue for complicated cases like nested structs,
look into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl.

with Zygote and the default automatic sensealg.

If I specify sensealg = ReverseDiffAdjoint(), I get
ERROR: MethodError: no method matching length(::ModelingToolkit.MTKParameters{…})

@DhairyaLGandhi I think ReverseDiffAdjoint also needs to be updated to use tunables a la your PR?

If I use sensealg = ZygoteAdjoint(), I get:

ERROR: ArgumentError: new: too few arguments (expected 49)
Stacktrace:
   [1] __new__(::Type, ::SciMLBase.ODESolution{…}, ::Vararg{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/tools/builtins.jl:9
   [2] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:296 [inlined]
   [3] adjoint(::Zygote.Context{…}, ::typeof(Zygote.__new__), ::Type, ::SciMLBase.ODESolution{…}, ::Vector{…}, ::Nothing, ::Vector{…}, ::Vararg{…})
     @ Zygote ./none:0
   [4] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
   [5] ODEIntegrator
     @ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/integrators/type.jl:168 [inlined]
   [6] _pullback(::Zygote.Context{…}, ::Type{…}, ::SciMLBase.ODESolution{…}, ::Vector{…}, ::Nothing, ::Vector{…}, ::Float64, ::Float64, ::SciMLBase.ODEFunction{…}, ::ModelingToolkit.MTKParameters{…}, ::Vector{…}, ::Vector{…}, ::Nothing, ::Float64, ::OrdinaryDiffEq.Tsit5{…}, ::Float64, ::Bool, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Int64, ::Int64, ::Int64, ::Int64, ::OrdinaryDiffEq.Tsit5Cache{…}, ::Nothing, ::Int64, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::Int64, ::Float64, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DEOptions{…}, ::SciMLBase.DEStats, ::OrdinaryDiffEq.DefaultInit, ::Nothing)
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
   [7] #__init#434
     @ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:474 [inlined]
   [8] _pullback(::Zygote.Context{…}, ::OrdinaryDiffEq.var"##__init#434", ::StepRangeLen{…}, ::Tuple{}, ::Tuple{}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::SciMLBase.CallbackSet{…}, ::Bool, ::Bool, ::Float64, ::Float64, ::Float64, ::Bool, ::Bool, ::Rational{…}, ::Float64, ::Float64, ::Rational{…}, ::Int64, ::Int64, ::Int64, ::Nothing, ::Nothing, ::Rational{…}, ::Nothing, ::Bool, ::Int64, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Symbol, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DefaultInit, ::@Kwargs{}, ::typeof(SciMLBase.__init), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
   [9] __init
     @ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:11 [inlined]
  [10] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.__init), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [11] __init (repeats 4 times)
     @ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:11 [inlined]
  [12] _apply
     @ ./boot.jl:838 [inlined]
  [13] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [14] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [15] #__solve#433
     @ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:6 [inlined]
  [16] _pullback(::Zygote.Context{…}, ::OrdinaryDiffEq.var"##__solve#433", ::@Kwargs{…}, ::typeof(SciMLBase.__solve), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [17] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [18] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [19] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [20] __solve
     @ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:1 [inlined]
  [21] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.__solve), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [22] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [23] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [24] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [25] #solve_call#44
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:612 [inlined]
  [26] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve_call#44", ::Bool, ::Nothing, ::@Kwargs{…}, ::typeof(DiffEqBase.solve_call), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [27] _apply
     @ ./boot.jl:838 [inlined]
  [28] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [29] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [30] solve_call
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:569 [inlined]
  [31] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_call), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [32] #solve_up#53
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1080 [inlined]
  [33] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve_up#53", ::@Kwargs{…}, ::typeof(DiffEqBase.solve_up), ::SciMLBase.ODEProblem{…}, ::DiffEqBase.SensitivityADPassThrough, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [34] _apply
     @ ./boot.jl:838 [inlined]
  [35] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [36] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [37] solve_up
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1066 [inlined]
  [38] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_up), ::SciMLBase.ODEProblem{…}, ::DiffEqBase.SensitivityADPassThrough, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [39] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [40] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [41] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [42] #solve#51
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
  [43] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#51", ::DiffEqBase.SensitivityADPassThrough, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::Val{…}, ::@Kwargs{…}, ::typeof(CommonSolve.solve), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [44] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [45] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [46] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [47] solve
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
  [48] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(CommonSolve.solve), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [49] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [50] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [51] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [52] #346
     @ ~/Sync/Code/scripts/qself/old/julia/dev/SciMLSensitivity/src/concrete_solve.jl:1124 [inlined]
  [53] _pullback(::Zygote.Context{…}, ::SciMLSensitivity.var"#346#348"{…}, ::Vector{…}, ::ModelingToolkit.MTKParameters{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [54] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Any})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
  [55] pullback
     @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
  [56] #_concrete_solve_adjoint#344
     @ ~/Sync/Code/scripts/qself/old/julia/dev/SciMLSensitivity/src/concrete_solve.jl:1123 [inlined]
  [57] _concrete_solve_adjoint
     @ ~/Sync/Code/scripts/qself/old/julia/dev/SciMLSensitivity/src/concrete_solve.jl:1110 [inlined]
  [58] #_solve_adjoint#75
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1537 [inlined]
  [59] _solve_adjoint
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1510 [inlined]
  [60] #rrule#4
     @ ~/.julia/packages/DiffEqBase/c8MAQ/ext/DiffEqBaseChainRulesCoreExt.jl:26 [inlined]
  [61] rrule
     @ ~/.julia/packages/DiffEqBase/c8MAQ/ext/DiffEqBaseChainRulesCoreExt.jl:22 [inlined]
  [62] rrule
     @ ~/.julia/packages/ChainRulesCore/I1EbV/src/rules.jl:140 [inlined]
  [63] chain_rrule_kw
     @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:235 [inlined]
  [64] macro expansion
     @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
  [65] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_up), ::SciMLBase.ODEProblem{…}, ::SciMLSensitivity.ZygoteAdjoint, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87
  [66] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [67] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [68] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [69] #solve#51
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
  [70] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#51", ::SciMLSensitivity.ZygoteAdjoint, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(CommonSolve.solve), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [71] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [72] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [73] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [74] solve
     @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
  [75] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(CommonSolve.solve), ::SciMLBase.ODEProblem{…}, ::OrdinaryDiffEq.Tsit5{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [76] simulate
     @ ~/Sync/Code/scripts/qself/old/julia/src/QSelf.jl:222 [inlined]
  [77] _pullback(::Zygote.Context{…}, ::typeof(QSelf.simulate), ::DynamicPPL.Model{…}, ::DynamicPPL.ThreadSafeVarInfo{…}, ::DynamicPPL.DefaultContext, ::QSelf.Params, ::Bool)
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [78] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [79] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [80] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [81] _evaluate!!
     @ ~/.julia/packages/DynamicPPL/i2EbF/src/model.jl:968 [inlined]
  [82] _pullback(::Zygote.Context{…}, ::typeof(DynamicPPL._evaluate!!), ::DynamicPPL.Model{…}, ::DynamicPPL.ThreadSafeVarInfo{…}, ::DynamicPPL.DefaultContext)
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [83] evaluate_threadsafe!!
     @ ~/.julia/packages/DynamicPPL/i2EbF/src/model.jl:957 [inlined]
  [84] _pullback(::Zygote.Context{…}, ::typeof(DynamicPPL.evaluate_threadsafe!!), ::DynamicPPL.Model{…}, ::DynamicPPL.TypedVarInfo{…}, ::DynamicPPL.DefaultContext)
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [85] evaluate!!
     @ ~/.julia/packages/DynamicPPL/i2EbF/src/model.jl:892 [inlined]
  [86] _pullback(::Zygote.Context{…}, ::typeof(AbstractPPL.evaluate!!), ::DynamicPPL.Model{…}, ::DynamicPPL.TypedVarInfo{…}, ::DynamicPPL.DefaultContext)
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [87] obj_func
     @ ~/Sync/Code/scripts/qself/old/julia/src/sample.jl:91 [inlined]
  [88] _pullback(::Zygote.Context{…}, ::QSelf.var"#obj_func#41"{…}, ::Vector{…}, ::SciMLBase.NullParameters)
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [89] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [90] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [91] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [92] OptimizationFunction
     @ ~/.julia/packages/SciMLBase/hq1ku/src/scimlfunctions.jl:3775 [inlined]
  [93] _pullback(::Zygote.Context{…}, ::SciMLBase.OptimizationFunction{…}, ::Vector{…}, ::SciMLBase.NullParameters)
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [94] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
  [95] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [96] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [97] #37
     @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:94 [inlined]
  [98] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#37#55"{…}, args::Vector{…})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [99] _apply(::Function, ::Vararg{Any})
     @ Core ./boot.jl:838
 [100] adjoint
     @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [101] _pullback
     @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [102] #39
     @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97 [inlined]
 [103] _pullback(ctx::Zygote.Context{false}, f::OptimizationZygoteExt.var"#39#57"{Tuple{}, OptimizationZygoteExt.var"#37#55"{…}}, args::Vector{Float64})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [104] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
 [105] pullback
     @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
 [106] gradient(f::Function, args::Vector{Float64})
     @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:147
 [107] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{…}})(::Vector{Float64}, ::Vector{Float64})
     @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97
 [108] (::OptimizationOptimJL.var"#8#14"{OptimizationBase.OptimizationCache{…}, OptimizationOptimJL.var"#7#13"{…}})(G::Vector{Float64}, θ::Vector{Float64})
     @ OptimizationOptimJL ~/.julia/packages/OptimizationOptimJL/hDX5k/src/OptimizationOptimJL.jl:175
 [109] value_gradient!!(obj::NLSolversBase.TwiceDifferentiable{Float64, Vector{Float64}, Matrix{Float64}, Vector{Float64}}, x::Vector{Float64})
     @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/interface.jl:82
 [110] initial_state(method::Optim.LBFGS{…}, options::Optim.Options{…}, d::NLSolversBase.TwiceDifferentiable{…}, initial_x::Vector{…})
     @ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/solvers/first_order/l_bfgs.jl:164
 [111] optimize(d::NLSolversBase.TwiceDifferentiable{…}, initial_x::Vector{…}, method::Optim.LBFGS{…}, options::Optim.Options{…})
     @ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/optimize/optimize.jl:36
 [112] __solve(cache::OptimizationBase.OptimizationCache{…})
     @ OptimizationOptimJL ~/.julia/packages/OptimizationOptimJL/hDX5k/src/OptimizationOptimJL.jl:224
 [113] solve!(cache::OptimizationBase.OptimizationCache{…})
     @ SciMLBase ~/.julia/packages/SciMLBase/hq1ku/src/solve.jl:188
 [114] solve(::SciMLBase.OptimizationProblem{…}, ::Optim.LBFGS{…}; kwargs::@Kwargs{…})
     @ SciMLBase ~/.julia/packages/SciMLBase/hq1ku/src/solve.jl:96
 [115] optimize_with_trace(prob::SciMLBase.OptimizationProblem{…}, optimizer::Optim.LBFGS{…}; progress_name::String, progress_id::Base.UUID, maxiters::Int64, callback::Nothing, fail_on_nonfinite::Bool, kwargs::@Kwargs{})
     @ Pathfinder ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/optimize.jl:69
 [116] optimize_with_trace
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/optimize.jl:50 [inlined]
 [117] _pathfinder(rng::Random._GLOBAL_RNG, prob::SciMLBase.OptimizationProblem{…}, logp::Pathfinder.var"#logp#26"{…}; history_length::Int64, optimizer::Optim.LBFGS{…}, ndraws_elbo::Int64, executor::Transducers.SequentialEx{…}, kwargs::@Kwargs{…})
     @ Pathfinder ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/singlepath.jl:299
 [118] _pathfinder
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/singlepath.jl:288 [inlined]
 [119] _pathfinder_try_until_succeed(rng::Random._GLOBAL_RNG, prob::SciMLBase.OptimizationProblem{…}, logp::Pathfinder.var"#logp#26"{…}; ntries::Int64, init_scale::Int64, init_sampler::QSelf.var"#init_sampler#39"{…}, allow_mutating_init::Bool, kwargs::@Kwargs{…})
     @ Pathfinder ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/singlepath.jl:274
 [120] _pathfinder_try_until_succeed
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/singlepath.jl:262 [inlined]
 [121] #25
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/singlepath.jl:191 [inlined]
 [122] progress(f::Pathfinder.var"#25#27"{…}; name::String)
     @ ProgressLogging ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:262
 [123] progress
     @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:258 [inlined]
 [124] pathfinder(prob::SciMLBase.OptimizationProblem{…}; rng::Random._GLOBAL_RNG, history_length::Int64, optimizer::Optim.LBFGS{…}, ndraws_elbo::Int64, ndraws::Int64, input::Function, kwargs::@Kwargs{…})
     @ Pathfinder ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/singlepath.jl:190
 [125] pathfinder
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/singlepath.jl:179 [inlined]
 [126] #pathfinder#23
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/singlepath.jl:177 [inlined]
 [127] pathfinder
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/singlepath.jl:156 [inlined]
 [128] #32
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/multipath.jl:163 [inlined]
 [129] next
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/library.jl:54 [inlined]
 [130] (::Transducers.var"#50#51"{Transducers.Reduction{…}, Tuple{…}})(u0::Tuple{Int64, Float64}, iresult::BangBang.SafeCollector{BangBang.NoBang.Empty{…}})
     @ Transducers ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/library.jl:1302
 [131] wrapping
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/core.jl:734 [inlined]
 [132] next(rf::Transducers.Reduction{…}, result::Transducers.PrivateState{…}, input::Tuple{…})
     @ Transducers ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/library.jl:1300
 [133] macro expansion
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/core.jl:181 [inlined]
 [134] __foldl__
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/processes.jl:237 [inlined]
 [135] (::Transducers.var"#232#234"{Transducers.Reduction{…}, BangBang.SafeCollector{…}, Transducers.ProgressLoggingFoldable{…}})(id::Symbol)
     @ Transducers ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/progress.jl:96
 [136] __progress(f::Transducers.var"#232#234"{Transducers.Reduction{…}, BangBang.SafeCollector{…}, Transducers.ProgressLoggingFoldable{…}}; name::String)
     @ Transducers ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/progress.jl:75
 [137] __progress
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/progress.jl:71 [inlined]
 [138] __foldl__
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/progress.jl:82 [inlined]
 [139] #transduce#141
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/processes.jl:519 [inlined]
 [140] transduce
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/processes.jl:508 [inlined]
 [141] transduce(xform::Transducers.Composition{…}, f::Transducers.AdHocRF{…}, init::BangBang.SafeCollector{…}, coll::Transducers.ProgressLoggingFoldable{…}; kwargs::@Kwargs{})
     @ Transducers ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/processes.jl:502
 [142] transduce
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/processes.jl:500 [inlined]
 [143] _collect(xf::Transducers.Map{…}, coll::Transducers.ProgressLoggingFoldable{…}, ::Transducers.SizeStable, ::Base.HasShape{…})
     @ Transducers ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/processes.jl:806
 [144] collect
     @ ~/Sync/Code/scripts/qself/old/julia/dev/Transducers/src/processes.jl:802 [inlined]
 [145] collect(itr::Transducers.Eduction{Transducers.Reduction{…}, Transducers.ProgressLoggingFoldable{…}}, ex::Transducers.SequentialEx{@NamedTuple{}})
     @ Folds.Implementations ~/.julia/packages/Folds/qbSal/src/collect.jl:10
 [146] multipathfinder(optim_fun::SciMLBase.OptimizationFunction{…}, ndraws::Int64; init::Nothing, input::Function, nruns::Int64, ndraws_elbo::Int64, ndraws_per_run::Int64, rng::Random._GLOBAL_RNG, history_length::Int64, optimizer::Optim.LBFGS{…}, executor::Transducers.SequentialEx{…}, executor_per_run::Transducers.SequentialEx{…}, importance::Bool, kwargs::@Kwargs{…})
     @ Pathfinder ~/Sync/Code/scripts/qself/old/julia/dev/Pathfinder/src/multipath.jl:185
 [147] run_pathfinder(model::DynamicPPL.Model{…}; ndraws::Int64, ntries::Int64, nruns::Int64, adtype::ADTypes.AutoZygote)
     @ QSelf ~/Sync/Code/scripts/qself/old/julia/src/sample.jl:98
 [148] run_pathfinder
     @ ~/Sync/Code/scripts/qself/old/julia/src/sample.jl:67 [inlined]
 [149] run_inference(m::DynamicPPL.Model{…}, do_mcmc::Bool; n_iterations::Int64, n_chains::Int64)
     @ QSelf ~/Sync/Code/scripts/qself/old/julia/src/QSelf.jl:304
 [150] run_inference
     @ ~/Sync/Code/scripts/qself/old/julia/src/QSelf.jl:284 [inlined]
 [151] main(; do_mcmc::Bool)
     @ QSelf ~/Sync/Code/scripts/qself/old/julia/src/QSelf.jl:351
Some type information was truncated. Use `show(err)` to see complete types.

@acertain

acertain commented Aug 11, 2024

One Zygote issue is FluxML/Zygote.jl#1517

@DhairyaLGandhi
Member

I have updated SciML/SciMLSensitivity.jl#1085 so that this example actually runs without differentiating through the integrator interface (that is what causes the Zygote failure in the previous comments).

It required adding an extra rrule for MTKParameters, made necessary by an update in a different SciML package (one of MTK, SII or SciMLStructures).

@acertain

acertain commented Aug 12, 2024

Issues:

  1. ForwardDiffSensitivity was broken with MTK, which I think you fixed
  2. ReverseDiffAdjoint is broken with MTK
  3. ZygoteAdjoint is completely broken, due at least to FluxML/Zygote.jl#1517 ("adjoint of `new` with fewer arguments than fields is broken") and to try/catch blocks inside the ODE solver

ForwardDiffSensitivity is slow with many parameters, though, so it would be really nice if ZygoteAdjoint and/or ReverseDiffAdjoint worked.
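For reference, switching between these sensitivity algorithms is a single `sensealg` keyword on `solve`. The sketch below is not the MTK reproducer from this issue: it assumes a hand-written Lotka-Volterra ODE with a plain `Vector{Float64}` of parameters (the names `lotka_volterra!` and `loss` are illustrative), and compares the Zygote gradient obtained with `ForwardDiffSensitivity()` against `GaussAdjoint()`, a continuous adjoint whose cost grows more slowly with the number of parameters.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Illustrative Lotka-Volterra right-hand side with a plain Vector of parameters
function lotka_volterra!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[1] * u[2]
    du[2] = -γ * u[2] + δ * u[1] * u[2]
    return nothing
end

u0 = [1.0, 1.0]
p0 = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, (0.0, 5.0), p0)

# Same loss for every method; only the `sensealg` keyword changes
function loss(p, sensealg)
    sol = solve(prob, Tsit5(); p = p, saveat = 0.1, sensealg = sensealg,
                abstol = 1e-10, reltol = 1e-10)
    return sum(abs2, Array(sol))
end

# Forward-mode sensitivities: cost grows with the number of parameters
g_fwd = Zygote.gradient(p -> loss(p, ForwardDiffSensitivity()), p0)[1]
# Continuous adjoint: one backward solve regardless of the parameter count
g_adj = Zygote.gradient(p -> loss(p, GaussAdjoint()), p0)[1]
```

With tight solver tolerances the two gradients should agree closely; the difference between the methods is cost, not the result.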

@ChrisRackauckas
Member

Why not use one of the adjoint methods, like GaussAdjoint? This is really orthogonal to the purpose of this thread: if ForwardDiffSensitivity and GaussAdjoint work, then the standard AutoZygote will work in every default or recommended case, so ReverseDiffAdjoint and ZygoteAdjoint don't really have a place in the story.
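A minimal sketch of the setup this reply recommends, assuming plain vector parameters rather than `MTKParameters`: the loss passes `sensealg = GaussAdjoint()` explicitly and `Optimization.AutoZygote()` drives the outer gradient, mirroring the `solve(op, Adam(), maxiters = 10)` call from the original report. The synthetic data, initial guess, learning rate, and the helper names `lotka_volterra!` and `loss` are made up for illustration.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote
using Optimization, OptimizationOptimisers

# Illustrative Lotka-Volterra right-hand side with a plain Vector of parameters
function lotka_volterra!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[1] * u[2]
    du[2] = -γ * u[2] + δ * u[1] * u[2]
    return nothing
end

u0 = [1.0, 1.0]
p_true = [1.5, 1.0, 3.0, 1.0]
tsteps = 0.0:0.1:5.0
prob = ODEProblem(lotka_volterra!, u0, (0.0, 5.0), p_true)

# Synthetic "measurements" generated from the true parameters
data = Array(solve(prob, Tsit5(); saveat = tsteps))

# Optimization.jl calls the objective as f(u, p); the second argument is unused here
function loss(θ, _)
    sol = solve(prob, Tsit5(); p = θ, saveat = tsteps, sensealg = GaussAdjoint())
    return sum(abs2, Array(sol) .- data)
end

optf = OptimizationFunction(loss, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, [1.2, 0.8, 2.8, 0.9])
res = solve(optprob, OptimizationOptimisers.Adam(0.01); maxiters = 10)
```

Choosing the sensealg explicitly keeps Zygote on the `solve` rrule path instead of tracing into the solver internals, which is where the ZygoteAdjoint failures in this thread come from.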

@ChrisRackauckas
Member

> ReverseDiffAdjoint is broken with MTK

ReverseDiffAdjoint should do better now, since it now always canonicalizes the tunables to a vector, so handling MTK problems is at least possible in principle. Its concrete_solve dispatch might still need a SciMLStructures step added to extract that vector, which isn't a hard PR: basically the same thing as SciML/SciMLSensitivity.jl#1085.
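The canonicalization mentioned here is exposed through the SciMLStructures interface (`canonicalize` and `replace` with the `Tunable()` portion). A minimal sketch, assuming the MTK v9-era API (`@mtkbuild`, `ODESystem`, `t_nounits`/`D_nounits`) and assuming that Zygote can differentiate `SciMLStructures.replace` on `MTKParameters` in the installed versions, which is exactly what was being fixed around the time of this thread: extract the tunables as a plain vector, rebuild the parameter object inside the loss, and differentiate with respect to the vector only. The system and the `loss` function are illustrative, and the order of entries in `tunables` is chosen by MTK, so it is not guaranteed to match the `@parameters` declaration order.

```julia
using ModelingToolkit, OrdinaryDiffEq, SciMLSensitivity, Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
using SymbolicIndexingInterface: parameter_values
import SciMLStructures

# Illustrative Lotka-Volterra system with defaults for states and parameters
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
@variables x(t)=1.0 y(t)=1.0
eqs = [D(x) ~ α * x - β * x * y,
       D(y) ~ -γ * y + δ * x * y]
@mtkbuild sys = ODESystem(eqs, t)
prob = ODEProblem(sys, [], (0.0, 5.0))

# Pull the tunable parameters out of MTKParameters as a plain vector
p_mtk = parameter_values(prob)
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p_mtk)

function loss(θ)
    # Rebuild an MTKParameters object whose tunable portion is θ
    newp = SciMLStructures.replace(SciMLStructures.Tunable(), p_mtk, θ)
    newprob = remake(prob; p = newp)
    sol = solve(newprob, Tsit5(); saveat = 0.1, sensealg = GaussAdjoint())
    return sum(abs2, Array(sol))
end

# Zygote only ever sees the plain vector, never the MTKParameters struct itself
g = Zygote.gradient(loss, tunables)[1]
```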

Labels
bug Something isn't working

5 participants