From d7f4b7990d49b4e739431248b9751b5b8780366f Mon Sep 17 00:00:00 2001 From: Christopher Tessum Date: Mon, 14 Oct 2024 22:16:25 -0500 Subject: [PATCH] Get IMEX methods working --- src/simulator_strategies.jl | 40 +++++++++++++++---------------------- test/simulator_test.jl | 16 +++++++-------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/simulator_strategies.jl b/src/simulator_strategies.jl index a9b1c415..855ff84d 100644 --- a/src/simulator_strategies.jl +++ b/src/simulator_strategies.jl @@ -21,47 +21,39 @@ function run!(st::SimulatorStrategy, simulator) error("Not implemented.") end -"Return a SciMLOperator to apply the MTK system to each column of s.u after reshaping to a matrix." -function mtk_op(s::Simulator) +"Return a function to apply the MTK system to each column of s.u after reshaping to a matrix." +function mtk_func(s::Simulator) mtkf = ODEFunction(s.sys_mtk) II = CartesianIndices(size(s)[2:4]) + nrows = size(s)[1] + grid = [g for g in s.grid] function setp!(p, j) # Set the parameters for the jth grid cell. ii = II[j] - for (jj, g) ∈ enumerate(s.grid) # Set the coordinates of this grid cell. + for (jj, g) ∈ enumerate(grid) # Set the coordinates of this grid cell. p[s.pvidx[jj]] = g[ii[jj]] end end - function f(du, u::AbstractMatrix, p, t) # In-place, matrix - @inbounds for j ∈ 1:size(u, 2) - col = view(u, :, j) - ddu = view(du, :, j) + function f(du, u, p, t) # In-place + umat = reshape(u, nrows, :) + dumat = reshape(du, nrows, :) + @inbounds for j ∈ 1:size(umat, 2) + col = view(umat, :, j) + ddu = view(dumat, :, j) setp!(p, j) @inline mtkf(ddu, col, p, t) end end - function f(u::AbstractMatrix, p, t) # Out-of-place, matrix + function f(u, p, t) # Out-of-place + umat = reshape(u, nrows, :) function ff(u, p, t, j) setp!(p, j) mtkf(u, p, t) end - @inbounds @views mapreduce(jcol -> ff(jcol[2], p, t, jcol[1]), hcat, enumerate(eachcol(u))) + @inbounds @views mapreduce(jcol -> ff(jcol[2], p, t, jcol[1]), hcat, enumerate(eachcol(umat))) end - - u = zeros(dtype(s.domaininfo), size(s)...) - indata = reshape(u, size(s)[1], :) - fo = FunctionOperator(f, indata, batch=true, p=s.p) - - ncols = size(indata, 2) - # Rehape the input vector to a matrix, then apply the FunctionOperator. - #op = ScalarOperator(1.0) * TensorProductOperator(I(ncols), fo) - op = TensorProductOperator(I(ncols), fo) - cache_operator(op, u[:]) -end - -function mtk_func(s::Simulator) b = repeat([length(unknowns(s.sys_mtk))], length(s) ÷ size(s)[1]) j = BlockBandedMatrix{Float64}(undef, b, b, (0,0)) # Jacobian prototype - ODEFunction(mtk_op(s); jac_prototype=j) + ODEFunction(f; jac_prototype=j) end """ @@ -99,6 +91,6 @@ function run!(s::Simulator, st::SimulatorIMEX, u=init_u(s); kwargs...) f2 = sum([get_scimlop(op, s) for op ∈ s.sys.ops]) start, finish = tspan(s.domaininfo) - prob = SplitODEProblem(f1, f2, u, (start, finish), s.p, callback=CallbackSet(get_callbacks(s)), kwargs...) + prob = SplitODEProblem(f1, f2, u[:], (start, finish), s.p, callback=CallbackSet(get_callbacks(s)...); kwargs...) solve(prob, st.alg; kwargs...) end diff --git a/test/simulator_test.jl b/test/simulator_test.jl index 85819706..7eecc168 100644 --- a/test/simulator_test.jl +++ b/test/simulator_test.jl @@ -113,14 +113,13 @@ EarthSciMLBase.threaded_ode_step!(sim, u, IIchunks, integrators, 0.0, 1.0) @test sum(abs.(u)) ≈ 212733.04492722102 -#@testset "mtk_func" begin -begin +@testset "mtk_func" begin ucopy = copy(u) f = EarthSciMLBase.mtk_func(sim) - u = EarthSciMLBase.init_u(sim) - du = similar(u) - prob = ODEProblem(f, u[:], (0.0, 1.0), sim.p) - sol = solve(prob, KenCarp47(linsolve=KrylovJL_GMRES(), autodiff=false)) + uu = EarthSciMLBase.init_u(sim) + du = similar(uu) + prob = ODEProblem(f, uu[:], (0.0, 1.0), sim.p) + sol = solve(prob, Tsit5()) uu = reshape(sol.u[end], size(ucopy)...) @test uu[:] ≈ ucopy[:] rtol = 0.01 end @@ -168,8 +167,9 @@ end sol = run!(sim, st; abstol=1e-12, reltol=1e-12) @test sum(abs.(sol.u[end])) ≈ 3.77224671877136e7 rtol = 1e-3 - st = SimulatorIMEX(KenCarp47(linsolve=KrylovJL_GMRES(), autodiff=false)) - @test_broken run!(sim, st) + st = SimulatorIMEX(Tsit5()) + sol = run!(sim, st) + @test sum(abs.(sol.u[end])) ≈ 3.3333500929324217e7 rtol = 1e-3 # No Splitting error in this one. end mutable struct cbt