Skip to content

Commit

Permalink
Bug Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ctessum committed Nov 4, 2024
1 parent 67859fc commit e0788ed
Show file tree
Hide file tree
Showing 12 changed files with 122 additions and 109 deletions.
33 changes: 22 additions & 11 deletions src/coupled_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,15 @@ Get the ODE ModelingToolkit ODESystem representation of a [`CoupledSystem`](@ref
kwargs:
- name: The desired name for the resulting ODESystem
- simplify: if true, the observed variables that are not needed to specify the state variables
will be pruned and returned as a second return value after the ODESystem, which will be
structurally simplified.
- simplify: Whether to run `structural_simplify` on the resulting ODESystem
- prune: Whether to prune the extra observed equations to improve performance
Return values:
- The ODESystem representation of the CoupledSystem
- The extra observed equations which have been pruned to improve performance
"""
function Base.convert(::Type{<:ODESystem}, sys::CoupledSystem; name=:model, simplify=false, kwargs...)
function Base.convert(::Type{<:ODESystem}, sys::CoupledSystem; name=:model, simplify=true,
prune=true, kwargs...)
connector_eqs = Equation[]
systems = copy(sys.systems)
for (i, a) enumerate(systems)
Expand All @@ -165,12 +169,19 @@ function Base.convert(::Type{<:ODESystem}, sys::CoupledSystem; name=:model, simp

# Compose everything together.
o = compose(connectors, systems...)
o = remove_extra_defaults(o)
if simplify
return prune_observed(o)
o = namespace_events(o)
o_simplified = structural_simplify(o)
if prune
o, obs = prune_observed(o, o_simplified)
else
return o
obs = []
end
o_simplified = structural_simplify(o)
o = remove_extra_defaults(o, o_simplified)
if simplify
o = structural_simplify(o)
end
return o, obs
end

"""
Expand All @@ -179,7 +190,7 @@ end
Get the ModelingToolkit PDESystem representation of a [`CoupledSystem`](@ref).
"""
function Base.convert(::Type{<:PDESystem}, sys::CoupledSystem; name=:model, kwargs...)::ModelingToolkit.AbstractSystem
o = convert(ODESystem, sys; name, kwargs...)
o, _ = convert(ODESystem, sys; name=name, simplify=false, prune=false, kwargs...)

if sys.domaininfo !== nothing
o += sys.domaininfo
Expand Down Expand Up @@ -211,8 +222,8 @@ function nonstiff_ops(sys::CoupledSystem, sys_mtk, obs_eqs, domain, u0, p)
obs_funcs = obs_functions(obs_eqs, domain)
coord_trans_funcs = coord_trans_functions(obs_eqs, domain)
nonstiff_op = length(sys.ops) > 0 ?
sum([get_scimlop(op, sys, sys_mtk, domain, obs_funcs, coord_trans_funcs, u0, p) for op sys.ops]) :
NullOperator(length(u0))
sum([get_scimlop(op, sys, sys_mtk, domain, obs_funcs, coord_trans_funcs, u0, p) for op sys.ops]) :
NullOperator(length(u0))
nonstiff_op = cache_operator(nonstiff_op, u0)
end

Expand Down
59 changes: 30 additions & 29 deletions src/coupled_system_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,22 @@ Return the indexes of the system variables that the state variables of the final
simplified system depend on. This should be done before running `structural_simplify`
on the system.
"""
function get_needed_vars(sys::ODESystem)
function get_needed_vars(original_sys::ODESystem, simplified_sys)
varvardeps = ModelingToolkit.varvar_dependencies(
ModelingToolkit.asgraph(sys),
ModelingToolkit.variable_dependencies(sys),
ModelingToolkit.asgraph(original_sys),
ModelingToolkit.variable_dependencies(original_sys),
)
g = SimpleDiGraph(length(unknowns(sys)))
g = SimpleDiGraph(length(unknowns(original_sys)))
for (i, es) in enumerate(varvardeps.badjlist)
for e in es
add_edge!(g, i, e)
end
end
allst = unknowns(sys)
simpst = unknowns(structural_simplify(sys))
allst = unknowns(original_sys)
simpst = unknowns(simplified_sys)
stidx = [only(findall(isequal(s), allst)) for s in simpst]
collect(Graphs.DFSIterator(g, stidx))
idx = collect(Graphs.DFSIterator(g, stidx))
unknowns(original_sys)[idx]
end

"""
Expand Down Expand Up @@ -161,9 +162,9 @@ end
# Return the discrete events that affect variables that are
# needed to specify the state variables of the given system.
# This function should be run after running `structural_simplify`.
function filter_discrete_events(simplified_sys)
function filter_discrete_events(simplified_sys, obs_eqs)
de = ModelingToolkit.get_discrete_events(simplified_sys)
needed_eqs = equations(simplified_sys)
needed_eqs = vcat(equations(simplified_sys), obs_eqs)
keep = []
for e in de
evars = EarthSciMLBase.get_affected_vars(e)
Expand Down Expand Up @@ -198,27 +199,27 @@ Remove equations from an ODESystem where the variable in the LHS is not
present in any of the equations for the state variables. This can be used to
remove computationally intensive equations that are not used in the final model.
"""
function prune_observed(sys::ODESystem)
needed_var_idxs = get_needed_vars(sys)
needed_vars = Symbolics.tosymbol.(unknowns(sys)[needed_var_idxs]; escape=true)
sys = structural_simplify(sys)
function prune_observed(original_sys::ODESystem, simplified_sys)
needed_vars = var2symbol.(get_needed_vars(original_sys, simplified_sys))
deleteindex = []
for (i, eq) enumerate(observed(sys))
lhsvars = Symbolics.tosymbol.(Symbolics.get_variables(eq.lhs); escape=true)
obs = observed(simplified_sys)
for (i, eq) enumerate(obs)
lhsvars = var2symbol.(Symbolics.get_variables(eq.lhs))
# Only keep equations where all variables on the LHS are in at least one
# equation describing the system state.
if !all((var) -> var needed_vars, lhsvars)
push!(deleteindex, i)
end
end
discrete_events = filter_discrete_events(sys)
obs = observed(sys)
deleteat!(obs, deleteindex)
sys2 = structural_simplify(copy_with_change(sys;
eqs=[equations(sys); obs],
discrete_events = filter_discrete_events(simplified_sys, obs)
new_eqs = [equations(simplified_sys); obs]
sys2 = copy_with_change(simplified_sys;
eqs=new_eqs,
unknowns=get_unknowns(new_eqs),
discrete_events=discrete_events,
))
return sys2, observed(sys)
)
return sys2, observed(simplified_sys)
end

# Get the unknown variables in the system of equations.
Expand All @@ -229,14 +230,13 @@ function get_unknowns(eqs)
end

# Remove extra variable defaults that would cause a solver initialization error.
# This should be done before running `structural_simplify` on the system.
function remove_extra_defaults(sys)
all_vars = get_unknowns(equations(sys))
function remove_extra_defaults(original_sys, simplified_sys)
all_vars = unknowns(original_sys)

unk = Symbol.(unknowns(structural_simplify(sys)))
unk = var2symbol.(get_needed_vars(original_sys, simplified_sys))

# Check if v is not in the unknowns, is a variable, and has a default.
checkextra(v) = !(Symbol(v) in unk) &&
# Check if v is not in the unknowns and has a default.
checkextra(v) = !(var2symbol(v) in unk) &&
(Symbolics.VariableDefaultValue in keys(v.metadata))
extra_default_vars = all_vars[checkextra.(all_vars)]

Expand All @@ -247,9 +247,10 @@ function remove_extra_defaults(sys)
newv = @set v.metadata = newmeta
push!(replacements, v => newv)
end
new_eqs = substitute.(equations(sys), (Dict(replacements...),))
new_eqs = substitute.(equations(original_sys), (Dict(replacements...),))
new_unk = get_unknowns(new_eqs)
copy_with_change(sys; eqs=new_eqs, unknowns=new_unk, parameters=parameters(sys))
copy_with_change(original_sys; eqs=new_eqs, unknowns=new_unk,
parameters=parameters(original_sys))
end

"Initialize the state variables."
Expand Down
18 changes: 12 additions & 6 deletions src/domaininfo.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export DomainInfo, ICBCcomponent, constIC, constBC, zerogradBC, periodicBC, partialderivatives
export DomainInfo, ICBCcomponent, constIC, constBC, zerogradBC, periodicBC, partialderivatives,
get_tspan, get_tspan_datetime

"""
Initial and boundary condition components that can be combined to
Expand Down Expand Up @@ -169,9 +170,9 @@ end
"""
$(SIGNATURES)
Return the time range associated with this domain.
Return the time range associated with this domain, returning the values as Unix times.
"""
function tspan(d::DomainInfo{T})::Tuple{T,T} where {T<:AbstractFloat}
function get_tspan(d::DomainInfo{T})::Tuple{T,T} where {T<:AbstractFloat}
for icbc d.icbc
if icbc isa ICcomponent
return DomainSets.infimum(icbc.indepdomain.domain), DomainSets.supremum(icbc.indepdomain.domain)
Expand All @@ -180,8 +181,13 @@ function tspan(d::DomainInfo{T})::Tuple{T,T} where {T<:AbstractFloat}
throw(ArgumentError("Could not find a time range for this domain."))
end

function tspan_datetime(d::DomainInfo)
(Float64.(tspan(d)) .+ Float64(d.time_offset)) .|> unix2datetime
"""
$(SIGNATURES)
Return the time range associated with this domain, returning the values as DateTimes.
"""
function get_tspan_datetime(d::DomainInfo)
(Float64.(get_tspan(d)) .+ Float64(d.time_offset)) .|> unix2datetime
end

"""
Expand Down Expand Up @@ -437,7 +443,7 @@ domains(di::DomainInfo) = unique(vcat(domains.(di.icbc)...))
function Base.:(+)(sys::ModelingToolkit.ODESystem, di::DomainInfo)::ModelingToolkit.PDESystem
dimensions = dims(di)
allvars = unknowns(sys)
statevars = unknowns(structural_simplify(sys))
statevars = unknowns(sys)
ps = parameters(sys)
toreplace, replacements = replacement_params(ps, pvars(di))
dvs = add_dims(allvars, dimensions) # Add new dimensions to dependent variables.
Expand Down
4 changes: 2 additions & 2 deletions src/solver_strategies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end
function ODEProblem(sys::CoupledSystem, st::SolverIMEX; u0=nothing, p=nothing,
name=:model, kwargs...)

sys_mtk, obs_eqs = convert(ODESystem, sys; simplify=true, name=name)
sys_mtk, obs_eqs = convert(ODESystem, sys; name=name)
dom = domain(sys)

u0 = isnothing(u0) ? init_u(sys_mtk, dom) : u0
Expand All @@ -47,7 +47,7 @@ function ODEProblem(sys::CoupledSystem, st::SolverIMEX; u0=nothing, p=nothing,
f2 = nonstiff_ops(sys, sys_mtk, obs_eqs, dom, u0, p)

cb = get_callbacks(sys, sys_mtk, obs_eqs, dom)
start, finish = tspan(dom)
start, finish = get_tspan(dom)
SplitODEProblem(f1, f2, u0[:], (start, finish), p,
callback=CallbackSet(cb...); kwargs...)
end
7 changes: 4 additions & 3 deletions src/solver_strategy_strang.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,16 @@ nthreads(st::SolverStrangSerial) = 1
function ODEProblem(s::CoupledSystem, st::SolverStrang; u0=nothing, p=nothing,
nonstiff_params=nothing, name=:model, kwargs...)

sys_mtk, obs_eqs = convert(ODESystem, s; simplify=true, name=name)
sys_mtk, obs_eqs = convert(ODESystem, s; name=name)

dom = domain(s)
u0 = isnothing(u0) ? init_u(sys_mtk, dom) : u0
p = isnothing(p) ? default_params(sys_mtk) : p

II = CartesianIndices(tuple(size(dom)...))
IIchunks = collect(Iterators.partition(II, length(II) ÷ nthreads(st)))
start, finish = tspan(dom)
prob = ODEProblem(sys_mtk, [], (start, finish), []; st.stiff_kwargs...)
start, finish = get_tspan(dom)
prob = ODEProblem(sys_mtk, [], (start, start+st.timestep), []; st.stiff_kwargs...)
stiff_integrators = [init(remake(prob, u0=zeros(length(unknowns(sys_mtk))), p=deepcopy(p)),
st.stiffalg, save_on=false, save_start=false, save_end=false, initialize_save=false;
st.stiff_kwargs...) for _ in 1:length(IIchunks)]
Expand Down Expand Up @@ -147,6 +147,7 @@ function single_ode_step!(setp!, u, IIchunk, integrator, time, step_length)
setp!(integrator.p, ii)
solve!(integrator)
@assert length(integrator.sol.u) == 0
@assert integrator.t == time + step_length
uii .= integrator.u
end
end
5 changes: 3 additions & 2 deletions test/advection_test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using EarthSciMLBase
using Test
using DomainSets, MethodOfLines, ModelingToolkit, DifferentialEquations
using DomainSets, ModelingToolkit
using MethodOfLines, DifferentialEquations
using ModelingToolkit: t, D
import SciMLBase
using DynamicQuantities
Expand Down Expand Up @@ -41,7 +42,7 @@ using Dates, DomainSets
@test length(equations(combined_mtk)) == 6
@test length(combined_mtk.ivs) == 2
@test length(combined_mtk.dvs) == 6
@test length(combined_mtk.bcs) == 3
@test length(combined_mtk.bcs) == 18

eq = equations(combined_mtk)
eqstr = string(eq)
Expand Down
31 changes: 15 additions & 16 deletions test/coupled_system_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,16 @@ using OrdinaryDiffEq

sir = couple(seqn, ieqn, reqn)

sirfinal = convert(PDESystem, sir)

sir_simple = structural_simplify(sirfinal)
sirfinal, _ = convert(ODESystem, sir)

want_eqs = [
D(reqn.R) ~ reqn.γ * reqn.I,
D(seqn.S) ~ (-seqn.β * seqn.I * seqn.S) / (seqn.I + seqn.R + seqn.S),
D(ieqn.I) ~ (ieqn.β * ieqn.I * ieqn.S) / (ieqn.I + ieqn.R + ieqn.S) - ieqn.γ * ieqn.I,
]

have_eqs = equations(sir_simple)
have_eqs = equations(sirfinal)
obs = observed(sirfinal)
for eq in want_eqs
@test eq in have_eqs
end
Expand Down Expand Up @@ -157,8 +156,7 @@ end
]
for (i, model) enumerate(models)
@testset "permutation $i" begin
model_mtk = convert(ODESystem, model)
m = structural_simplify(model_mtk)
m, _ = convert(ODESystem, model)
eqstr = string(equations(m))
@test occursin("b₊c_NO2(t)", eqstr)
@test occursin("b₊jNO2(t)", eqstr)
Expand All @@ -173,9 +171,10 @@ end

@testset "Stable evaluation" begin
sys = couple(A(), B(), C())
eqs1 = string(equations(convert(ODESystem, sys)))
s, _ = convert(ODESystem, sys)
eqs1 = string(equations(s))
@test occursin("b₊c_NO2(t)", eqs1)
eqs2 = string(equations(convert(ODESystem, sys)))
eqs2 = string(equations(s))
@test occursin("b₊c_NO2(t)", eqs2)
@test eqs1 == eqs2
end
Expand Down Expand Up @@ -214,7 +213,7 @@ end
ModelingToolkit.t_nounits; name=:test,
discrete_events=[event1, event2, event3, event4],
)
sys = EarthSciMLBase.remove_extra_defaults(sys)
sys = EarthSciMLBase.remove_extra_defaults(sys, structural_simplify(sys))

prob = ODEProblem(structural_simplify(sys), [], (0, 100), [])
sol = solve(prob, abstol=1e-8, reltol=1e-8)
Expand All @@ -238,21 +237,21 @@ end
end

@testset "filter events" begin
kept_events = EarthSciMLBase.filter_discrete_events(structural_simplify(sys))
kept_events = EarthSciMLBase.filter_discrete_events(structural_simplify(sys), [])
@test length(kept_events) == 2
@test EarthSciMLBase.var2symbol(only(EarthSciMLBase.get_affected_vars(kept_events[1]))) == :p_1
@test EarthSciMLBase.var2symbol(only(EarthSciMLBase.get_affected_vars(kept_events[2]))) == :p_2
end

@testset "prune observed" begin
sys2, obs = EarthSciMLBase.prune_observed(sys)
sys2, obs = EarthSciMLBase.prune_observed(sys, structural_simplify(sys))
@test length(equations(sys2)) == 2
@test length(ModelingToolkit.get_discrete_events(sys2)) == 2
@test length(obs) == 1
end

sys2, _ = EarthSciMLBase.prune_observed(sys)
prob = ODEProblem(sys2, [], (0, 100), [])
sys2, _ = EarthSciMLBase.prune_observed(sys, structural_simplify(sys))
prob = ODEProblem(structural_simplify(sys2), [], (0, 100), [])
sol = solve(prob, abstol=1e-8, reltol=1e-8)
@test sol[x][end] 3
@test sol[x2][end] 3
Expand Down Expand Up @@ -287,9 +286,9 @@ end
sys_composed = compose(ODESystem(Equation[], ModelingToolkit.t_nounits; name=:coupled),
create_sys(name=:a), create_sys(name=:b))
sysc = EarthSciMLBase.namespace_events(sys_composed)
sysc2 = EarthSciMLBase.remove_extra_defaults(sysc)
sysc3, obs = EarthSciMLBase.prune_observed(sysc2)
prob = ODEProblem(sysc3, [], (0, 100), [])
sysc2 = EarthSciMLBase.remove_extra_defaults(sysc, structural_simplify(sysc))
sysc3, obs = EarthSciMLBase.prune_observed(sysc2, structural_simplify(sysc2))
prob = ODEProblem(structural_simplify(sysc3), [], (0, 100), [])
sol = solve(prob, abstol=1e-8, reltol=1e-8)
@test length(sol.u[end]) == 4
@test all(sol.u[end] .≈ 3)
Expand Down
Loading

0 comments on commit e0788ed

Please sign in to comment.