You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# Progress marker printed before the package imports that follow.
println("Load libraries")
using OrdinaryDiffEq, DifferentialEquations, ComponentArrays
using Enzyme, Lux
using Zygote, SciMLSensitivity
using StableRNGs
using Optimization, OptimizationOptimisers
using Plots
rng = StableRNG(1111)
println("Setup NODE")

# Network input width: the 2 ODE states plus the 2 extra inputs appended in NODE!.
in_size = 4
# This script works with Zygote.gradient when layer_size=5, but julia crashes on type
# inference when layer_size=10 when calling Zygote.gradient()
layer_size = 10

# MLP: in_size -> layer_size (tanh) -> layer_size (tanh) -> 2 (linear output).
const sc = Lux.Chain(
    Lux.Dense(in_size => layer_size, tanh),
    Lux.Dense(layer_size => layer_size, tanh),
    Lux.Dense(layer_size => 2),
)

# Initial parameters and state of the model, drawn from the seeded `rng`.
# The state is re-bound to a `const` so NODE! reads it as a type-stable global.
p_nn, st = Lux.setup(rng, sc)
const _st = st
"""
    NODE!(du, u, p, t)

In-place right-hand side of the neural ODE.

The network `sc` is evaluated with parameters `p` and state `_st` on the
current state `u` concatenated with the extra inputs `vs`. Its output `NN`
drives the derivatives:

    du[1] = u[2] + NN[1]
    du[2] = NN[2]

`t` is not used. Writes into `du` and returns `nothing`, following the usual
in-place `f(du, u, p, t)` convention.
"""
function NODE!(du, u, p, t)
    # Lux models return `(output, state)`; only the output is needed here.
    NN = first(sc([u; vs], p, _st))
    du[1] = u[2] + NN[1]
    du[2] = NN[2]
    # Explicit `nothing`: otherwise the function would leak the value of its last
    # assignment (a Float64) as its return value.
    return nothing
end
println("Test NODE")
u0_test = [1.0, 2.0]
# Extra (non-state) network inputs read by NODE!. Declared `const` so the global
# read inside NODE! is type-stable: a non-const global is inferred as `Any` there,
# which slows the RHS down and hands type-unstable code to the AD backends.
# NOTE: being `const`, re-running the script with a different value will warn/error.
const vs = [3.0, 4.0]
theta = ComponentArray(p = p_nn, u0 = u0_test)
# Smoke-test one forward pass of the network on [u0; vs].
sc([theta.u0; vs], theta.p, _st)

println("Check if ODE solves")
ts = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
prob = ODEProblem(NODE!, theta.u0, (0.0, 1.0))
pred = solve(prob, Tsit5(), u0 = theta.u0, p = theta.p, saveat = ts)[1, :]
# Random target data for the loss. Draw it from the seeded `rng` (not the global
# RNG) so the whole script is reproducible.
R = rand(rng, length(ts))
"""
    predict_neuralode(theta, problem = prob, saveat = ts)

Solve `problem` with `Tsit5()`, overriding its initial state with `theta.u0`
and its parameters with `theta.p`, saving at the times `saveat`.

Returns the solution converted to a dense `Array` (one row per state
component, one column per save time).

`problem` and `saveat` default to the script-level globals `prob` and `ts`.
The defaults are looked up at call time, so `predict_neuralode(theta)`
behaves as before. Passing them as arguments makes the body type-stable,
because it no longer reads untyped globals directly.
"""
function predict_neuralode(theta, problem = prob, saveat = ts)
    sol = solve(problem, Tsit5(); u0 = theta.u0, p = theta.p, saveat = saveat)
    return Array(sol)
end
"""
    loss_neuralode(theta, target = R)

Sum of squared differences between `target` and the first state component
of `predict_neuralode(theta)` at the save times.

`target` defaults to the script-level global `R`, looked up at call time, so
`loss_neuralode(theta)` behaves as before. Passing it explicitly avoids
reading an untyped global inside the differentiated function.
"""
function loss_neuralode(theta, target = R)
    # First row of the prediction: trajectory of u[1] over the save times.
    pred = predict_neuralode(theta)[1, :]
    return sum(abs2, target .- pred)
end
# Evaluate the loss once (forward pass through the ODE solve) before differentiating it.
loss_neuralode(theta)
println("Check if works with Enzyme autodiff")
# Reverse-mode gradient of the full loss (including the ODE solve) w.r.t. the
# ComponentArray `theta`; the result is discarded, this only checks that it runs.
Enzyme.gradient(Reverse, loss_neuralode, theta)
# Testing Enzyme autodiff based on: https://docs.sciml.ai/SciMLSensitivity/dev/faq/
# The steps below differentiate only the ODE right-hand side (no solver involved),
# to check whether Enzyme can handle the RHS on its own.
# NOTE(review): this rebinds the global `prob`, which `predict_neuralode` also reads;
# the new problem differs from the earlier one only in carrying `theta.p` as its `p`.
prob = ODEProblem(NODE!, theta.u0, (0.,1.), theta.p)
u0 = prob.u0
p = prob.p
# Zero-initialised shadow of `p`; receives the reverse-mode adjoint w.r.t. p.
tmp2 = Enzyme.make_zero(p)
t = prob.tspan[1]
# NOTE(review): `du` is not used below.
du = zero(u0)
# Obtain an in-place RHS `_f(du, u, p, t)` regardless of how the problem was defined.
if DiffEqBase.isinplace(prob)
_f = prob.f
else
# Out-of-place f: wrap it so it writes into `du` and returns `nothing`.
_f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
end
# Shadow for the function object itself, since `_f` is passed as `Duplicated`
# (relevant if it carries differentiable data).
_tmp6 = Enzyme.make_zero(_f)
# tmp3: output buffer `du` for the primal call; tmp4: its shadow (adjoint seed).
tmp3 = zero(u0)
tmp4 = zero(u0)
# ytmp: state `u` at which the RHS is evaluated; tmp1: its shadow (adjoint w.r.t. u).
ytmp = zero(u0)
tmp1 = zero(u0)
# NOTE: every buffer above is zero, including the seed tmp4, so the resulting adjoints
# are zero; the call below only tests whether Enzyme can differentiate `_f` at all.
# Error here
# Reverse-mode call of `_f(tmp3, ytmp, p, t)`: each `Duplicated(x, dx)` pairs a primal
# with its shadow, the return value of `_f` is treated as `Const` (not differentiated),
# and `t` is held constant.
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
Enzyme.Duplicated(p, tmp2),
Enzyme.Const(t))
Full discussion is here: https://discourse.julialang.org/t/lux-enzyme-and-zygote-neuralode-segmentation-fault/123645/4
MWE (minimal working example), with the full error message below the code:
Error message:
The text was updated successfully, but these errors were encountered: