diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 3ced92f5a..1e9c4f457 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -79,12 +79,20 @@ function automatic_sensealg_choice( prob::Union{SciMLBase.AbstractODEProblem, SciMLBase.AbstractSDEProblem}, u0, p, verbose, repack) + if p === nothing || p isa SciMLBase.NullParameters + tunables, repack = p, identity + elseif isscimlstructure(p) + tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) + end + default_sensealg = if p !== DiffEqBase.NullParameters() && !(eltype(u0) <: ForwardDiff.Dual) && !(eltype(p) <: ForwardDiff.Dual) && !(eltype(u0) <: Complex) && !(eltype(p) <: Complex) && - length(u0) + length(p) <= 100 + length(u0) + length(tunables) <= 100 ForwardDiffSensitivity() elseif u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob) # only Zygote is GPU compatible and fast @@ -124,7 +132,7 @@ function automatic_sensealg_choice( ReverseDiff.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0) else ReverseDiff.gradient( - (u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, p) + (u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, tunables) end ReverseDiffVJP() catch e @@ -151,8 +159,8 @@ function automatic_sensealg_choice( end tmp1 = back(λ) else - _dy, back = Tracker.forward(y, p) do u, p - vec(f(u, p, t)) + _dy, back = Tracker.forward(y, tunables) do u, tunables + vec(f(u, repack(tunables), t)) end tmp1, tmp2 = back(λ) end