Error when mixing datatypes while differentiating using ForwardDiff and using FunctionWrapperSpecialize #994

Open
LilithHafner opened this issue Nov 2, 2023 · 2 comments

Comments

@LilithHafner
Member

This MWE:

```julia
using DifferentialEquations
using ForwardDiff: gradient

function ode_f(du, u, p, t)
    x = u[1]
    v = u[2]
    dx = v
    dv = -x
    du[1] = dx
    du[2] = dv
end

function f(initial)
    tspan = (0.0,1.0)
    prob = ODEProblem(ode_f, initial, tspan)
    sol = solve(prob)
    sol[end][1]
end

gradient(f, Float32[1.0, 1.0])
```

Gives a "No matching function wrapper was found!" error with a very long stacktrace.

Empirically, I have found several workarounds:

  • Using finite differencing avoids the error.
  • Using Float64 inputs avoids the error.
  • Using ODEProblem{true, DifferentialEquations.SciMLBase.FullSpecialize} avoids the error.
  • Using ODEProblem{true, DifferentialEquations.SciMLBase.NoSpecialize} avoids the error.
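For reference, the MWE integrates x′ = v, v′ = −x, whose solution is x(t) = x₀ cos t + v₀ sin t. So `f` is linear in the initial state and its exact gradient is [cos 1, sin 1] ≈ [0.5403, 0.8415] for any input. The stdlib-only sketch below mirrors the first workaround. `rk4_solve` (a hypothetical fixed-step RK4 stand-in for `solve`) and `central_diff_gradient` (a hand-written central difference, not FiniteDiff.jl's API) only ever pass plain floats to `ode_f`, so no Dual-typed call reaches any function wrapper.

```julia
# Same right-hand side as the MWE: x' = v, v' = -x.
function ode_f(du, u, p, t)
    du[1] = u[2]
    du[2] = -u[1]
    return nothing
end

# Hypothetical stand-in for `solve`: classic fixed-step RK4 on an in-place RHS.
function rk4_solve(f!, u0, tspan; nsteps = 1000)
    t0, t1 = tspan
    h = (t1 - t0) / nsteps
    u = Float64.(u0)                 # work in Float64 whatever the input eltype
    k1, k2, k3, k4 = similar(u), similar(u), similar(u), similar(u)
    t = t0
    for _ in 1:nsteps
        f!(k1, u, nothing, t)
        f!(k2, u .+ (h / 2) .* k1, nothing, t + h / 2)
        f!(k3, u .+ (h / 2) .* k2, nothing, t + h / 2)
        f!(k4, u .+ h .* k3, nothing, t + h)
        u .+= (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += h
    end
    return u
end

# Objective from the MWE: x at t = 1.
f(initial) = rk4_solve(ode_f, initial, (0.0, 1.0))[1]

# Hand-written central differences: only plain Float64 values reach `g`.
function central_diff_gradient(g, x; h = 1e-6)
    x64 = Float64.(x)
    grad = similar(x64)
    for i in eachindex(x64)
        xp = copy(x64); xp[i] += h
        xm = copy(x64); xm[i] -= h
        grad[i] = (g(xp) - g(xm)) / (2h)
    end
    return grad
end

central_diff_gradient(f, Float32[1.0, 1.0])   # ≈ [0.5403, 0.8415] = [cos(1), sin(1)]
```

Converting to Float64 up front also sidesteps the Float32/Float64 mix, in the spirit of the second workaround.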

I suspect that there is an inconsistency between the code that decides which input types to precompute when using FunctionWrappersWrappers and the code that calls the doubly wrapped function.
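A toy model may help make the suspected inconsistency concrete. The error message and frame [1] of the trace below suggest that the wrapper is selected by matching the call's argument types against a fixed list of precompiled signatures, and that the call fails when none matches. `ToyWrapperSet` is a hypothetical stand-in written in plain Julia, not the FunctionWrappersWrappers.jl implementation, and it uses plain `Float32`/`Float64` vectors in place of Dual numbers.

```julia
# Hypothetical model of a "wrapper of wrappers": a function plus a fixed list
# of exact argument-type signatures prepared ahead of time.
struct ToyWrapperSet{F}
    f::F
    signatures::Vector{Type}   # each entry is a Tuple{...} of argument types
end

function (w::ToyWrapperSet)(args...)
    argtypes = Tuple{map(typeof, args)...}
    for sig in w.signatures
        # Exact match only: no promotion or conversion is attempted.
        argtypes == sig && return w.f(args...)
    end
    error("No matching function wrapper was found!")
end

ode_f!(du, u, p, t) = (du[1] = -u[1]; nothing)

w = ToyWrapperSet(ode_f!, Type[
    Tuple{Vector{Float32}, Vector{Float32}, Nothing, Float64},
    Tuple{Vector{Float64}, Vector{Float64}, Nothing, Float64},
])

w(zeros(Float32, 1), Float32[1.0], nothing, 0.0)   # matches the first signature
w(zeros(Float64, 1), [1.0], nothing, 0.0)          # matches the second signature
# A Float64 `du` with a Float32 `u` matches neither and throws:
# w(zeros(Float64, 1), Float32[1.0], nothing, 0.0)
```

Consistently Float32 or consistently Float64 arguments both succeed, and only the mixed call fails. That is at least consistent with Float64 inputs avoiding the error.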

This issue stems from an investigation into SciML/juliatorch#10. If this issue were fixed, I expect that SciML/juliatorch#10 would also be fixed.

Manually truncated stacktrace:
ERROR: No matching function wrapper was found!
Stacktrace:
  [1] _call(#unused#::Tuple{}, arg::Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}, fww::FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, 
ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Flo
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:23
  [2] _call(fw::Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEq
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13
  [3] _call(fw::Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.Ordin
     @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13
  [4] _call(fw::Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDif 
     @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13
  [5] _call(fw::Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13
  [6] (::FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:10
  [7] (::ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float
    @ SciMLBase ~/.julia/packages/SciMLBase/VS2ST/src/scimlfunctions.jl:2394
  [8] ode_determine_initdt(u0::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, t::Float64, tdir::Float64, dtmax::Float64, abstol::Float32, reltol::ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Floa
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/initdt.jl:53
  [9] auto_dt_reset!
    @ ~/.julia/packages/OrdinaryDiffEq/yppG9/src/integrators/integrator_interface.jl:449 [inlined]
 [10] handle_dt!(integrator::OrdinaryDiffEq.ODEIntegrator{CompositeAlgorithm{Tuple{Vern7{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Rodas5P{1, false, LinearSolve.DefaultLinearSolver, typeof(OrdinaryDiffEq.DEFAULT_PRECS), Val{:forward}, true, nothing}}, OrdinaryDiffEq.AutoSwitchCache{Vern7{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Rodas5P{0, false, Nothing, typeof(OrdinaryDiffEq.DEFAULT_PRECS), Val{:forward}, true, nothing}, Rational{Int64}, Int64}}, true, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Nothing, Float64, SciMLBase.NullParameters, Float64, Float32, Float32, Float64, Vector{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}}, ODESolution{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 2, Vector{V
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:555
 [11] __init(prob::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Forwar
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:517
 [12] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:10 [inlined]
 [13] __solve(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{type
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:5
 [14] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:1 [inlined]
 [15] #solve_call#34
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:557 [inlined]
 [16] solve_call
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:523 [inlined]
 [17] #solve_up#42
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:1006 [inlined]
 [18] solve_up
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:992 [inlined]
 [19] #solve#40
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:929 [inlined]
 [20] __solve(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, 
ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float
    @ DifferentialEquations ~/.julia/packages/DifferentialEquations/Tu7HS/src/default_solve.jl:14
 [21] __solve
    @ ~/.julia/packages/DifferentialEquations/Tu7HS/src/default_solve.jl:1 [inlined]
 [22] #__solve#63
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:1285 [inlined]
 [23] __solve
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:1278 [inlined]
 [24] solve_call(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, 
ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), 
    @ DiffEqBase ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:557
 [25] solve_call
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:523 [inlined]
 [26] #solve_up#42
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:998 [inlined]
 [27] solve_up
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:992 [inlined]
 [28] solve(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(ode_f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{true}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:929
 [29] solve(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(ode_f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:919
 [30] f(initial::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}})
    @ Main ./REPL[19]:4
 [31] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
 [32] vector_mode_gradient(f::typeof(f), x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:89
 [33] gradient(f::Function, x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}}, ::Val{true})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:0
 [34] gradient(f::Function, x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [35] gradient(f::Function, x::Vector{Float32})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [36] top-level scope
    @ REPL[20]:1
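In frame [1] the failing call passes `du` as `Vector{Dual{Tag{typeof(f), Float32}, Float64, 2}}` while `u` is `Vector{Dual{Tag{typeof(f), Float32}, Float32, 2}}`, and none of the signatures visible in the truncated trace accepts a `du` of that type. We have not traced where that Float64-valued buffer is allocated. One plausible route is ordinary type promotion with the Float64 `tspan`. As far as we understand ForwardDiff's promotion rules, promoting `Dual{T, Float32, 2}` with `Float64` yields `Dual{T, Float64, 2}`. The stdlib-only sketch below shows the same effect with plain floats. `promoted_buffer` is a hypothetical helper, not OrdinaryDiffEq code.

```julia
# Element types as in the MWE: Float32 state, Float64 time span.
u0 = Float32[1.0, 1.0]
tspan = (0.0, 1.0)

# Hypothetical helper: allocate a derivative buffer whose element type is
# promoted with the type of t (illustration only).
promoted_buffer(u, t) = similar(u, promote_type(eltype(u), typeof(t)))

du = promoted_buffer(u0, first(tspan))
eltype(du)   # Float64, although eltype(u0) is Float32

# With Float64 inputs the promotion is a no-op, so du and u agree.
eltype(promoted_buffer(Float64[1.0, 1.0], first(tspan)))   # Float64
```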
@oscardssmith
Contributor

oscardssmith commented Nov 2, 2023

Slightly better reproducer:

```julia
using OrdinaryDiffEq
using ForwardDiff: gradient
ode_f(du, u, p, t) = du[1] = -u[1]
function f(initial)
    tspan = (0.0,1.0)
    prob = ODEProblem(ode_f, initial, tspan)
    solve(prob, Rodas5P())[end][1]
end
gradient(f, Float32[1.0])
```
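When checking a fix, note that the reduced reproducer also has a closed form. u′ = −u gives u(1) = u₀e⁻¹, so the expected gradient is [e⁻¹] ≈ [0.3679] regardless of u₀. The helper names below are hypothetical and only encode that formula.

```julia
# Reduced reproducer: u' = -u on t ∈ [0, 1], hence u(1) = u0 * exp(-1).
exact_final(initial) = initial[1] * exp(-1.0)

# The map is linear in u0, so its gradient is constant: [exp(-1)] ≈ [0.3679].
exact_gradient(initial) = [exp(-1.0)]
```

A working setup should return approximately this value from `gradient(f, Float32[1.0])`, up to solver tolerance and Float32 rounding.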

Specifically, if you use solve(prob, FBDF()) it works, so the problem seems to be in the tgrad (the time derivative ∂f/∂t, which Rosenbrock methods such as Rodas5P use and FBDF does not).
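Several of the precompiled signatures listed in frame [1] take `t` as a `ForwardDiff.Dual` tagged with `DiffEqBase.OrdinaryDiffEqTag`, which fits forward-mode differentiation with respect to time. The stdlib-only toy below sketches that mechanism. `ToyDual` and `toy_tgrad` are hypothetical and are not ForwardDiff's or OrdinaryDiffEq's code. Once `t` is a dual number, the `du` buffer has to hold duals as well, so each such call needs a matching precomputed signature.

```julia
# Hypothetical toy dual number: a value and its derivative w.r.t. t
# (a stand-in for ForwardDiff.Dual, not its implementation).
struct ToyDual
    val::Float64
    der::Float64
end

# Plain numbers written into a ToyDual buffer have zero time derivative.
Base.convert(::Type{ToyDual}, x::Real) = ToyDual(Float64(x), 0.0)
# The two operations the example right-hand sides need.
Base.:*(a::Real, b::ToyDual) = ToyDual(a * b.val, a * b.der)
Base.sin(a::ToyDual) = ToyDual(sin(a.val), cos(a.val) * a.der)

# ∂f/∂t of an in-place RHS f!(du, u, p, t), computed by seeding dt/dt = 1.
function toy_tgrad(f!, u::Vector{Float64}, p, t::Float64)
    du = Vector{ToyDual}(undef, length(u))  # t is a dual, so du must hold duals
    f!(du, u, p, ToyDual(t, 1.0))
    return [d.der for d in du]
end

ode_f(du, u, p, t) = du[1] = -u[1]               # reduced reproducer (autonomous)
forced_f(du, u, p, t) = du[1] = -u[1] * sin(t)   # hypothetical time-dependent RHS

toy_tgrad(ode_f, [2.0], nothing, 0.0)     # [0.0]: no explicit t dependence
toy_tgrad(forced_f, [2.0], nothing, 0.0)  # [-2.0], i.e. -u * cos(0)
```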

@oscardssmith
Contributor

Fixed by SciML/OrdinaryDiffEq.jl#2051.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants