Skip to content

Commit

Permalink
Merge pull request #1079 from SciML/dg/length
Browse files Browse the repository at this point in the history
Check `p` while choosing default sensealg
  • Loading branch information
ChrisRackauckas authored Jul 26, 2024
2 parents 4f90aca + 94e212a commit ef7a3b8
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,20 @@ function automatic_sensealg_choice(
prob::Union{SciMLBase.AbstractODEProblem,
SciMLBase.AbstractSDEProblem},
u0, p, verbose, repack)
if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
else
throw(SciMLStructuresCompatibilityError())
end

default_sensealg = if p !== DiffEqBase.NullParameters() &&
!(eltype(u0) <: ForwardDiff.Dual) &&
!(eltype(p) <: ForwardDiff.Dual) &&
!(eltype(u0) <: Complex) &&
!(eltype(p) <: Complex) &&
length(u0) + length(p) <= 100
length(u0) + length(tunables) <= 100
ForwardDiffSensitivity()
elseif u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob)
# only Zygote is GPU compatible and fast
Expand Down Expand Up @@ -124,7 +132,7 @@ function automatic_sensealg_choice(
ReverseDiff.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
else
ReverseDiff.gradient(
(u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, p)
(u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, tunables)
end
ReverseDiffVJP()
catch e
Expand All @@ -151,8 +159,8 @@ function automatic_sensealg_choice(
end
tmp1 = back(λ)
else
_dy, back = Tracker.forward(y, p) do u, p
vec(f(u, p, t))
_dy, back = Tracker.forward(y, tunables) do u, tunables
vec(f(u, repack(tunables), t))
end
tmp1, tmp2 = back(λ)
end
Expand Down

0 comments on commit ef7a3b8

Please sign in to comment.