diff --git a/test/odesystem.jl b/test/odesystem.jl index 85d135b338..76c47a8f1d 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1552,3 +1552,60 @@ end expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45))) @test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops) end + +@testset "dae_order_lowering basic test" begin + @parameters a + @variables x(t) y(t) z(t) + @named dae_sys = ODESystem([ + D(x) ~ y, + 0 ~ x + z, + 0 ~ x - y + z + ], t, [z, y, x], []) + + lowered_dae_sys = dae_order_lowering(dae_sys) + @variables x1(t) y1(t) z1(t) + expected_eqs = [ + 0 ~ x + z, + 0 ~ x - y + z, + Differential(t)(x) ~ y + ] + lowered_eqs = equations(lowered_dae_sys) + sorted_lowered_eqs = sort(lowered_eqs, by=string) + sorted_expected_eqs = sort(expected_eqs, by=string) + @test sorted_lowered_eqs == sorted_expected_eqs + + expected_vars = Set([z, y, x]) + lowered_vars = Set(unknowns(lowered_dae_sys)) + @test lowered_vars == expected_vars +end + +@testset "dae_order_lowering test with structural_simplify" begin + @variables x(t) y(t) z(t) + @parameters M b k + eqs = [ + D(D(x)) ~ -b / M * D(x) - k / M * x, + 0 ~ y - D(x), + 0 ~ z - x + ] + ps = [M, b, k] + default_u0 = [ + D(x) => 0.0, x => 10.0, y => 0.0, z => 10.0 + ] + default_p = [M => 1.0, b => 1.0, k => 1.0] + @named dae_sys = ODESystem(eqs, t, [x, y, z], ps; defaults = [default_u0; default_p]) + + simplified_dae_sys = structural_simplify(dae_sys) + + lowered_dae_sys = dae_order_lowering(simplified_dae_sys) + lowered_dae_sys = complete(lowered_dae_sys) + + tspan = (0.0, 10.0) + prob = ODEProblem(lowered_dae_sys, nothing, tspan) + sol = solve(prob, Tsit5()) + + @test sol.t[end] == tspan[end] + @test sum(abs, sol.u[end]) < 1 + + prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...)) + @test prob.u0 isa SVector +end \ No newline at end of file