Skip to content

Commit

Permalink
Remove extra defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
ctessum committed Oct 11, 2024
1 parent b52a726 commit 9a66e04
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 39 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "EarthSciMLBase"
uuid = "e53f1632-a13c-4728-9402-0c66d48804b0"
authors = ["EarthSciML Authors and Contributors"]
version = "0.16.1"
version = "0.16.2"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
Catalyst = "479239e8-5488-4da2-87a7-35f2df7eef83"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Expand All @@ -21,6 +22,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
Accessors = "0.1"
BlockBandedMatrices = "0.13"
Catalyst = "14"
DiffEqCallbacks = "2, 3"
Expand Down
1 change: 1 addition & 0 deletions src/EarthSciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using OrdinaryDiffEq, DomainSets
using SciMLBase: DECallback, CallbackSet
using DiffEqCallbacks
using LinearAlgebra, BlockBandedMatrices
using Accessors
using ProgressLogging
using Graphs

Expand Down
23 changes: 12 additions & 11 deletions src/coupled_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ A system for composing together other systems using the [`couple`](@ref) functio
$(FIELDS)
Things that can be added to a `CoupledSystem`:
* `ModelingToolkit.ODESystem`s. If the ODESystem has a field in the metadata called
`:coupletype` (e.g. `ModelingToolkit.get_metadata(sys)[:coupletype]` returns a struct type
* `ModelingToolkit.ODESystem`s. If the ODESystem has a field in the metadata called
`:coupletype` (e.g. `ModelingToolkit.get_metadata(sys)[:coupletype]` returns a struct type
with a single field called `sys`)
then that type will be used to check for methods of `EarthSciMLBase.couple` that use that type.
* [`Operator`](@ref)s
Expand All @@ -35,7 +35,7 @@ mutable struct CoupledSystem
"Initial and boundary conditions and other domain information"
domaininfo
"""
A vector of functions where each function takes as an argument the resulting PDESystem after DomainInfo is
A vector of functions where each function takes as an argument the resulting PDESystem after DomainInfo is
added to this system, and returns a transformed PDESystem.
"""
pdefunctions::AbstractVector
Expand All @@ -56,14 +56,14 @@ function Base.show(io::IO, cs::CoupledSystem)
print(io, "CoupledSystem containing $(length(cs.systems)) system(s), $(length(cs.ops)) operator(s), and $(length(cs.callbacks) + length(cs.init_callbacks)) callback(s).")
end

"""
"""
$(TYPEDSIGNATURES)
Couple multiple ModelingToolkit systems together.
The systems that are arguments to this system can be of type `ModelingToolkit.AbstractSystem`,
[`CoupledSystem`](@ref), [`DomainInfo`](@ref),
or any type `T` that has a method `couple(::CoupledSystem, ::T)::CoupledSystem` or a method
The systems that are arguments to this system can be of type `ModelingToolkit.AbstractSystem`,
[`CoupledSystem`](@ref), [`DomainInfo`](@ref),
or any type `T` that has a method `couple(::CoupledSystem, ::T)::CoupledSystem` or a method
`couple(::T, ::CoupledSystem)::CoupledSystem` defined for it.
"""
function couple(systems...)::CoupledSystem
Expand Down Expand Up @@ -122,10 +122,10 @@ end
"""
$(SIGNATURES)
Perform bi-directional coupling for two
Perform bi-directional coupling for two
equation systems.
To specify couplings for system pairs, create
To specify couplings for system pairs, create
methods for this function with the signature:
```julia
Expand Down Expand Up @@ -169,7 +169,8 @@ function Base.convert(::Type{<:ODESystem}, sys::CoupledSystem; name=:model, kwar
connectors = ODESystem(connector_eqs, iv; name=name, kwargs...)

# Compose everything together.
compose(connectors, systems...)
o = compose(connectors, systems...)
remove_extra_defaults(o)
end

"""
Expand Down Expand Up @@ -202,4 +203,4 @@ struct ConnectorSystem
eqs::Vector{Equation}
from::ModelingToolkit.AbstractSystem
to::ModelingToolkit.AbstractSystem
end
end
12 changes: 2 additions & 10 deletions src/operator_compose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,8 @@ function operator_compose(a::ModelingToolkit.ODESystem, b::ModelingToolkit.ODESy
end
end
end
aa = ODESystem(a_eqs, ModelingToolkit.get_iv(a);
name=nameof(a),
metadata=ModelingToolkit.get_metadata(a),
continuous_events=ModelingToolkit.get_continuous_events(a),
discrete_events=ModelingToolkit.get_discrete_events(a))

bb = ODESystem(b_eqs, ModelingToolkit.get_iv(b);
name=nameof(b), metadata=ModelingToolkit.get_metadata(b),
continuous_events=ModelingToolkit.get_continuous_events(b),
discrete_events=ModelingToolkit.get_discrete_events(b))
aa = copy_with_change(a; eqs=a_eqs)
bb = copy_with_change(b; eqs=b_eqs)
ConnectorSystem(connections, aa, bb)
end

Expand Down
11 changes: 5 additions & 6 deletions src/param_to_var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ function param_to_var(sys::ModelingToolkit.AbstractSystem, ps::Symbol...)
end

newsys = SymbolicUtils.substitute(sys, replace)
continuous_events = ModelingToolkit.get_continuous_events(sys)
discrete_events = ModelingToolkit.get_discrete_events(sys)
ODESystem(equations(newsys), ModelingToolkit.get_iv(newsys);
name=nameof(newsys), metadata=ModelingToolkit.get_metadata(sys),
continuous_events=continuous_events,
discrete_events=discrete_events)
copy_with_change(newsys;
metadata=ModelingToolkit.get_metadata(sys),
discrete_events=ModelingToolkit.get_discrete_events(sys),
continuous_events=ModelingToolkit.get_continuous_events(sys),
)
end
51 changes: 43 additions & 8 deletions src/simulator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,22 @@ end
"""
$(SIGNATURES)
Create a copy of an ODESystem with the given changes.
"""
function copy_with_change(sys::ODESystem;
eqs=equations(sys),
name=nameof(sys),
metadata=ModelingToolkit.get_metadata(sys),
continuous_events=ModelingToolkit.get_continuous_events(sys),
discrete_events=ModelingToolkit.get_discrete_events(sys),
)
ODESystem(eqs, ModelingToolkit.get_iv(sys), name=name, metadata=metadata,
continuous_events=continuous_events, discrete_events=discrete_events)
end

"""
$(SIGNATURES)
Remove equations from an ODESystem where the variable in the LHS is not
present in any of the equations for the state variables. This can be used to
remove computationally intensive equations that are not used in the final model.
Expand All @@ -179,14 +195,33 @@ function prune_observed(sys::ODESystem)
end
obs = observed(sys)
deleteat!(obs, deleteindex)
ce = ModelingToolkit.get_continuous_events(sys)
de = ModelingToolkit.get_discrete_events(sys)
sys2 = structural_simplify(ODESystem([equations(sys); obs],
ModelingToolkit.get_iv(sys), name=nameof(sys),
metadata=ModelingToolkit.get_metadata(sys),
continuous_events=ce,
discrete_events=de
), split=false,
sys2 = structural_simplify(
copy_with_change(sys; eqs=[equations(sys); obs]),
split=false,
)
return sys2, observed(sys)
end

# Remove extra variable defaults that would cause a solver initialization error.
# This should be done before running `structural_simplify` on the system.
function remove_extra_defaults(sys)
all_vars = unique(vcat(get_variables.(equations(sys))...))

unk = Symbol.(unknowns(structural_simplify(sys)))

# Check if v is not in the unknowns, is a variable, and has a default.
checkextra(v) = !(Symbol(v) in unk) &&
v.metadata[Symbolics.VariableSource][1] == :variables &&
(Symbolics.VariableDefaultValue in keys(v.metadata))
extra_default_vars = all_vars[checkextra.(all_vars)]

replacements = []
for v in extra_default_vars
newmeta = Base.ImmutableDict(filter(kv -> kv[1] != Symbolics.VariableDefaultValue,
Dict(v.metadata))...)
newv = @set v.metadata = newmeta
push!(replacements, v => newv)
end
new_eqs = substitute.(equations(sys), (Dict(replacements...),))
copy_with_change(sys, eqs=new_eqs)
end
88 changes: 88 additions & 0 deletions test/integrated_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using Test
using ModelingToolkit, Catalyst, Main.EarthSciMLBase
using ModelingToolkit: t_nounits, D_nounits
using OrdinaryDiffEq: ODEProblem, solve
using SciMLBase: ReturnCode
t = t_nounits
D = D_nounits

struct PhotolysisCoupler
sys
end
function Photolysis(; name=:Photolysis)
@variables j_NO2(t) = 1
eqs = [
j_NO2 ~ max(sin(t / 86400), 0)
]
ODESystem(eqs, t, [j_NO2], [], name=name,
metadata=Dict(:coupletype => PhotolysisCoupler))
end

struct ChemistryCoupler
sys
end
function Chemistry(; name=:Chemistry)
@parameters jNO2 = 1
@species NO2(t) = 2
rxs = [
Reaction(jNO2, [NO2], [], [1], [])
]
rsys = ReactionSystem(rxs, t, [NO2], [jNO2];
combinatoric_ratelaws=false, name=name)
convert(ODESystem, complete(rsys), metadata=Dict(:coupletype => ChemistryCoupler))
end

struct EmissionsCoupler
sys
end
function Emissions(; name=:Emissions)
@parameters emis = 1
@variables NO2(t) = 3
eqs = [D(NO2) ~ emis]
ODESystem(eqs, t; name=name,
metadata=Dict(:coupletype => EmissionsCoupler))
end

function EarthSciMLBase.couple2(c::ChemistryCoupler, p::PhotolysisCoupler)
c, p = c.sys, p.sys
c = param_to_var(convert(ODESystem, c), :jNO2)
ConnectorSystem([c.jNO2 ~ p.j_NO2], c, p)
end

function EarthSciMLBase.couple2(c::ChemistryCoupler, emis::EmissionsCoupler)
c, emis = c.sys, emis.sys
operator_compose(convert(ODESystem, c), emis, Dict(
c.NO2 => emis.NO2,
))
end

p = Photolysis()
@testset "Photolysis single" begin
prob = ODEProblem(structural_simplify(p), [], (0.0, 1.0))
sol = solve(prob)
@test sol.retcode == ReturnCode.Success
end

c = Chemistry()
@testset "Chemistry single" begin
prob = ODEProblem(structural_simplify(c), [], (0.0, 1.0))
sol = solve(prob)
@test sol.retcode == ReturnCode.Success
end

e = Emissions()
@testset "Emissions single" begin
prob = ODEProblem(structural_simplify(e), [], (0.0, 1.0))
sol = solve(prob)
@test sol.retcode == ReturnCode.Success
end

@testset "Coupled model" begin
model = couple(c, p, e)
sys = convert(ODESystem, model)
sys2, sys2obs = EarthSciMLBase.prune_observed(sys)

prob = ODEProblem(sys2, [], (0.0, 1.0))
sol = solve(prob, u0=[1.0])
@test sol.retcode == ReturnCode.Success
end
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ using Test, SafeTestsets
@safetestset "Parameter to Variable" begin include("param_to_var_test.jl") end
@safetestset "Simulator Utils" begin include("simulator_utils_test.jl") end
@safetestset "Simulator" begin include("simulator_test.jl") end
@safetestset "Docs" begin
@safetestset "Integrated Test" begin include("integrated_test.jl") end
@safetestset "Docs" begin
using Documenter
using EarthSciMLBase
doctest(EarthSciMLBase)
doctest(EarthSciMLBase)
end
end

0 comments on commit 9a66e04

Please sign in to comment.