Skip to content

Commit

Permalink
Handle events
Browse files Browse the repository at this point in the history
  • Loading branch information
ctessum committed Nov 3, 2024
1 parent 1467cf5 commit 67859fc
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 24 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -39,7 +38,7 @@ ModelingToolkit = "9"
OrdinaryDiffEq = "6"
SciMLBase = "2"
SciMLOperators = "0.3"
SymbolicIndexingInterface = "0.3.34"
SymbolicIndexingInterface = "0.3"
Symbolics = "5, 6"
julia = "1.6"

Expand All @@ -49,9 +48,10 @@ DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Documenter", "Test", "SafeTestsets", "MethodOfLines", "Dates", "DomainSets", "DifferentialEquations", "SciMLBase"]
test = ["Documenter", "Test", "SafeTestsets", "MethodOfLines", "Dates", "DomainSets", "DifferentialEquations", "OrdinaryDiffEq", "SciMLBase"]
6 changes: 4 additions & 2 deletions src/EarthSciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ using ModelingToolkit: AbstractSystem
using Graphs, MetaGraphsNext
using DocStringExtensions
using DynamicQuantities, Dates
using OrdinaryDiffEq, DomainSets
using SciMLBase: DECallback, CallbackSet, ODEProblem
using DomainSets
using SciMLBase: DECallback, CallbackSet, ODEProblem, SplitODEProblem, reinit!, solve!,
init, remake
using SciMLOperators: cache_operator, NullOperator
using Statistics
using DiffEqCallbacks
using LinearAlgebra, BlockBandedMatrices
Expand Down
88 changes: 82 additions & 6 deletions src/coupled_system_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,78 @@ Create a copy of an ODESystem with the given changes.
function copy_with_change(sys::ODESystem;
eqs=equations(sys),
name=nameof(sys),
unknowns=unknowns(sys),
parameters=parameters(sys),
metadata=ModelingToolkit.get_metadata(sys),
continuous_events=ModelingToolkit.get_continuous_events(sys),
discrete_events=ModelingToolkit.get_discrete_events(sys),
)
ODESystem(eqs, ModelingToolkit.get_iv(sys); name=name, metadata=metadata,
ODESystem(eqs, ModelingToolkit.get_iv(sys), unknowns, parameters;
name=name, metadata=metadata,
continuous_events=continuous_events, discrete_events=discrete_events)
end

# Get variables effected by this event.
function get_affected_vars(event)
vars = []
if event.affects isa AbstractVector
for aff in event.affects
push!(vars, aff.lhs)
end
else
push!(vars, event.affects.pars...)
push!(vars, event.affects.sts...)
push!(vars, event.affects.discretes...)
end
unique(vars)
end

function var2symbol(var)
if var isa Symbolics.CallWithMetadata
var = var.f
elseif iscall(var)
var = operation(var)
end
Symbolics.tosymbol(var; escape=false)
end

function var_in_eqs(var, eqs)
any([any(isequal.((var2symbol(var),), var2symbol.(get_variables(eq)))) for eq in eqs])
end

# Return the discrete events that affect variables that are
# needed to specify the state variables of the given system.
# This function should be run after running `structural_simplify`.
function filter_discrete_events(simplified_sys)
de = ModelingToolkit.get_discrete_events(simplified_sys)
needed_eqs = equations(simplified_sys)
keep = []
for e in de
evars = EarthSciMLBase.get_affected_vars(e)
if any(var_in_eqs.(evars, (needed_eqs,)))
push!(keep, e)
end
end
keep
end

# Add the namespace to the affects of the discrete events.
# This should be done before doing anything else to the system because the
# events get ignored by default.
# TODO(CT): This is probably only necessary due a bug in MTK. Remove this when fixed.
function namespace_events(sys)
events = ModelingToolkit.get_discrete_events(sys)
for sys2 in ModelingToolkit.get_systems(sys)
events2 = ModelingToolkit.get_discrete_events(sys2)
for ev in events2
af = ModelingToolkit.namespace_affects(ev.affects, sys2)
ev2 = @set ev.affects = af
push!(events, ev2)
end
end
copy_with_change(sys; discrete_events=events)
end

"""
$(SIGNATURES)
Expand All @@ -147,22 +211,32 @@ function prune_observed(sys::ODESystem)
push!(deleteindex, i)
end
end
discrete_events = filter_discrete_events(sys)
obs = observed(sys)
deleteat!(obs, deleteindex)
sys2 = structural_simplify(copy_with_change(sys; eqs=[equations(sys); obs]))
sys2 = structural_simplify(copy_with_change(sys;
eqs=[equations(sys); obs],
discrete_events=discrete_events,
))
return sys2, observed(sys)
end

# Get the unknown variables in the system of equations.
function get_unknowns(eqs)
all_vars = unique(vcat(get_variables.(eqs)...))
unk = [v.metadata[Symbolics.VariableSource][1] == :variables for v in all_vars]
all_vars[unk]
end

# Remove extra variable defaults that would cause a solver initialization error.
# This should be done before running `structural_simplify` on the system.
function remove_extra_defaults(sys)
all_vars = unique(vcat(get_variables.(equations(sys))...))
all_vars = get_unknowns(equations(sys))

unk = Symbol.(unknowns(structural_simplify(sys)))

# Check if v is not in the unknowns, is a variable, and has a default.
checkextra(v) = !(Symbol(v) in unk) &&
v.metadata[Symbolics.VariableSource][1] == :variables &&
(Symbolics.VariableDefaultValue in keys(v.metadata))
extra_default_vars = all_vars[checkextra.(all_vars)]

Expand All @@ -174,7 +248,8 @@ function remove_extra_defaults(sys)
push!(replacements, v => newv)
end
new_eqs = substitute.(equations(sys), (Dict(replacements...),))
copy_with_change(sys, eqs=new_eqs)
new_unk = get_unknowns(new_eqs)
copy_with_change(sys; eqs=new_eqs, unknowns=new_unk, parameters=parameters(sys))
end

"Initialize the state variables."
Expand Down Expand Up @@ -202,7 +277,8 @@ end
function coord_params(mtk_sys::AbstractSystem, domain::DomainInfo)
pv = pvars(domain)
params = parameters(mtk_sys)
_pvidx = [findall(v -> split(String(Symbol(v)), "")[end] == String(Symbol(p)), params) for p pv]

_pvidx = [findall(v -> split(String(var2symbol(v)), "")[end] == String(var2symbol(p)), params) for p pv]
for (i, idx) in enumerate(_pvidx)
if length(idx) > 1
error("Partial independent variable '$(pv[i])' has multiple matches in system parameters: [$(parameters(mtk_sys)[idx])].")
Expand Down
24 changes: 17 additions & 7 deletions src/solver_strategies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,39 @@ A solver strategy based on implicit-explicit (IMEX) time integration.
See [here](https://docs.sciml.ai/DiffEqDocs/stable/types/split_ode_types/)
for additional information.
kwargs for ODEProblem constructor:
kwargs:
- stiff_scimlop: Whether the stiff ODE function should be implemented as a SciMLOperator.
- stiff_sparse: Whether the stiff ODE function should use a sparse Jacobian.
- stiff_jac: Whether the stiff ODE function should use an analytical Jacobian.
- stiff_jac_scimlop: Whether the stiff ODE function Jacobian should be implemented as a SciMLOperator. (Ignored if `stiff_jac==false`.)
- stiff_tgrad: Whether the stiff ODE function should use an analytical time gradient.
Additional kwargs for ODEProblem constructor:
- u0: initial condtions; if "nothing", default values will be used.
- p: parameters; if "nothing", default values will be used.
- name: name of the model.
"""
struct SolverIMEX <: SolverStrategy end
struct SolverIMEX <: SolverStrategy
stiff_scimlop::Bool
stiff_jac::Bool
stiff_sparse::Bool
stiff_tgrad::Bool
function SolverIMEX(; stiff_scimlop=false, stiff_jac=true, stiff_sparse=true, stiff_tgrad=true)
new(stiff_scimlop, stiff_jac, stiff_sparse, stiff_tgrad)
end
end

function SciMLBase.ODEProblem(sys::CoupledSystem, st::SolverIMEX; u0=nothing, p=nothing,
stiff_scimlop=false, stiff_jac=true,
stiff_sparse=true, stiff_tgrad=true, name=:model, kwargs...)
function ODEProblem(sys::CoupledSystem, st::SolverIMEX; u0=nothing, p=nothing,
name=:model, kwargs...)

sys_mtk, obs_eqs = convert(ODESystem, sys; simplify=true, name=name)
dom = domain(sys)

u0 = isnothing(u0) ? init_u(sys_mtk, dom) : u0
p = isnothing(p) ? default_params(sys_mtk) : p

f1 = mtk_grid_func(sys_mtk, dom, u0, p; jac=stiff_jac,
sparse=stiff_sparse, scimlop=stiff_scimlop, tgrad=stiff_tgrad)
f1 = mtk_grid_func(sys_mtk, dom, u0, p; jac=st.stiff_jac,
sparse=st.stiff_sparse, scimlop=st.stiff_scimlop, tgrad=st.stiff_tgrad)

f2 = nonstiff_ops(sys, sys_mtk, obs_eqs, dom, u0, p)

Expand Down
7 changes: 6 additions & 1 deletion src/solver_strategy_strang.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ Perform a simulation using [Strang splitting](https://en.wikipedia.org/wiki/Stra
where the MTK system is assumed to be stiff and the operators are assumed to be non-stiff.
The solution will be calculated in serial.
Additional kwargs for ODEProblem constructor:
- u0: initial condtions; if "nothing", default values will be used.
- p: parameters; if "nothing", default values will be used.
- nonstiff_params: parameters for the [`Operator`](@ref)s.
$(FIELDS)
"""
struct SolverStrangSerial <: SolverStrang
Expand All @@ -71,7 +76,7 @@ end
nthreads(st::SolverStrangThreads) = st.threads
nthreads(st::SolverStrangSerial) = 1

function SciMLBase.ODEProblem(s::CoupledSystem, st::SolverStrang; u0=nothing, p=nothing,
function ODEProblem(s::CoupledSystem, st::SolverStrang; u0=nothing, p=nothing,
nonstiff_params=nothing, name=:model, kwargs...)

sys_mtk, obs_eqs = convert(ODESystem, s; simplify=true, name=name)
Expand Down
Loading

0 comments on commit 67859fc

Please sign in to comment.