Skip to content

Commit

Permalink
Update observed function
Browse files Browse the repository at this point in the history
  • Loading branch information
ctessum committed Nov 8, 2024
1 parent f508770 commit abd4d7b
Show file tree
Hide file tree
Showing 18 changed files with 260 additions and 199 deletions.
50 changes: 31 additions & 19 deletions docs/src/operator.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,41 +40,43 @@ There is also a variable `windspeed` which is "observed" based on the parameters
Next, we define an operator. To do so, first we create a new type that is a subtype of [`Operator`](@ref):

```@example sim
mutable struct ExampleOp <: Operator
α::Num # Multiplier from ODESystem
struct ExampleOp <: Operator
end
```
In the case above, we're setting up our operator so that it can hold a parameter from our ODE system.

Next, we need to define a method of `EarthSciMLBase.get_scimlop` for our operator. This method will be called to get a [`SciMLOperators.AbstractSciMLOperator`](https://docs.sciml.ai/SciMLOperators/stable/interface/) that will be used conjunction with the ModelingToolkit system above to integrate the simulation forward in time.

```@example sim
function EarthSciMLBase.get_scimlop(op::ExampleOp, csys::CoupledSystem, mtk_sys, domain::DomainInfo, obs_functions, coordinate_transform_functions, u0, p)
obs_f = obs_functions(op.α)
function EarthSciMLBase.get_scimlop(op::ExampleOp, csys::CoupledSystem, mtk_sys, domain::DomainInfo, u0, p)
α, trans1, trans2, trans3 = EarthSciMLBase.get_needed_vars(op, csys, mtk_sys, domain)
obs_f! = ModelingToolkit.build_explicit_observed_function(mtk_sys,
[α, trans1, trans2, trans3], checkbounds=false, return_inplace=true)[2]
setp! = EarthSciMLBase.coord_setter(mtk_sys, domain)
obscache = zeros(EarthSciMLBase.dtype(domain), 4)
grd = EarthSciMLBase.grid(domain)
function run(du, u, p, t)
u = reshape(u, size(u0)...)
du = reshape(du, size(u0)...)
II = CartesianIndices(size(u)[2:end])
for ix ∈ 1:size(u, 1)
for (i, c1) ∈ enumerate(grd[1])
for (j, c2) ∈ enumerate(grd[2])
for (k, c3) ∈ enumerate(grd[3])
# Demonstrate coordinate transforms
t1 = coordinate_transform_functions[1](t, c1, c2, c3)
t2 = coordinate_transform_functions[2](t, c1, c2, c3)
t3 = coordinate_transform_functions[3](t, c1, c2, c3)
# Demonstrate calculating observed value.
fv = obs_f(t, c1, c2, c3)
# Set derivative value.
du[ix, i, j, k] = (t1 + t2 + t3) * fv
end
end
for I in II
# Demonstrate coordinate transforms and observed values
uu = view(u, :, I)
setp!(p, I)
obs_f!(obscache, u, p, t)
t1, t2, t3, fv = obscache
# Set derivative value.
du[ix, I] = (t1 + t2 + t3) * fv
end
end
nothing
end
FunctionOperator(run, u0[:], p=p)
end
nothing
```
The function above also doesn't have any physical meaning, but it demonstrates some functionality of the `Operator` "`s`".
First, it retrieves a function to get the current value of an observed variable in our
Expand All @@ -83,6 +85,16 @@ function to get that value.
It also demonstrates how to get coordinate transforms using the `coordinate_transform_functions` argument.
Coordinate transforms are discussed in more detail in the documentation for the [`DomainInfo`](@ref) type.

We also need to define a method of `EarthSciMLBase.get_needed_vars`, which will return which variables are needed by the operator.

```@example sim
function EarthSciMLBase.get_needed_vars(::ExampleOp, csys, mtk_sys, domain::DomainInfo)
ts = EarthSciMLBase.partialderivative_transform_vars(mtk_sys, domain)
return [mtk_sys.sys₊windspeed, ts...]
end
nothing
```

## Domain

Once we have an ODE system and an operator, the final component we need is a domain to run the simulation on.
Expand Down Expand Up @@ -121,7 +133,7 @@ coordinates, which we set as 0.1π, 0.1π, and 1, respectively.
Next, initialize our operator, giving the the `windspeed` observed variable, and we can couple our ODESystem, Operator, and Domain together into a single model:

```@example sim
op = ExampleOp(sys.windspeed)
op = ExampleOp()
csys = couple(sys, op, domain)
```
Expand Down
2 changes: 1 addition & 1 deletion src/EarthSciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using DynamicQuantities, Dates
using DomainSets
using SciMLBase: DECallback, CallbackSet, ODEProblem, SplitODEProblem, reinit!, solve!,
init, remake
using SciMLOperators: cache_operator, NullOperator
using SciMLOperators: cache_operator, NullOperator, FunctionOperator, TensorProductOperator
using Statistics
using DiffEqCallbacks
using LinearAlgebra, BlockBandedMatrices
Expand Down
24 changes: 15 additions & 9 deletions src/coord_trans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,28 @@ varindex(pvars::AbstractVector, varname::Symbol) = findfirst(nameof.(pvars) .==
"""
$(SIGNATURES)
Return partial derivative operator transform factors corresponding
Return partial derivative operator transform factors corresponding
for the given partial-independent variables
after converting variables named `lon` and `lat` from degrees to x and y meters,
after converting variables named `lon` and `lat` from degrees to x and y meters,
assuming they represent longitude and latitude on a spherical Earth.
"""
function partialderivatives_δxyδlonlat(pvars::AbstractVector; default_lat=0.0)
latindex = varindex(pvars, :lat)
lonindex = varindex(pvars, :lon)
if !isnothing(latindex)
lat = pvars[latindex]
latindex = matching_suffix_idx(pvars, :lat)
lonindex = matching_suffix_idx(pvars, :lon)
if length(latindex) > 1
throw(error("Multiple variables with suffix :lat found in pvars: $(pvars[latindex])"))
end
if length(lonindex) > 1
throw(error("Multiple variables with suffix :lon found in pvars: $(pvars[lonindex])"))
end
if length(latindex) > 0
lat = pvars[only(latindex)]
else
lat = default_lat
end

Dict(
lonindex => 1.0 / lon2meters(lat),
latindex => 1.0 / lat2meters
only(lonindex) => 1.0 / lon2meters(lat),
only(latindex) => 1.0 / lat2meters
)
end
end
35 changes: 22 additions & 13 deletions src/coupled_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function couple(systems...)::CoupledSystem
push!(o.callbacks, sys)
elseif (sys isa Tuple) || (sys isa AbstractVector)
o = couple(o, sys...)
elseif hasmethod(init_callback, (typeof(sys), CoupledSystem, Any, Any, DomainInfo))
elseif hasmethod(init_callback, (typeof(sys), CoupledSystem, Any, DomainInfo))
push!(o.init_callbacks, sys)
elseif hasmethod(couple2, (CoupledSystem, typeof(sys)))
o = couple2(o, sys)
Expand Down Expand Up @@ -164,24 +164,31 @@ function Base.convert(::Type{<:ODESystem}, sys::CoupledSystem; name=:model, simp
end
end
end

iv = ModelingToolkit.get_iv(first(systems))
connectors = ODESystem(connector_eqs, iv; name=name, kwargs...)

# Compose everything together.
o = compose(connectors, systems...)
o = namespace_events(o)
if !isnothing(sys.domaininfo)
o = extend(o, partialderivative_transform_eqs(o, sys.domaininfo))
end
o_simplified = structural_simplify(o)
# Add coordinate transform equations.
if prune
o, obs = prune_observed(o, o_simplified)
else
obs = []
extra_vars = []
if !isnothing(sys.domaininfo)
extra_vars = operator_vars(sys, o_simplified, sys.domaininfo)
end
o = prune_observed(o, o_simplified, extra_vars)
end
o_simplified = structural_simplify(o)
o = remove_extra_defaults(o, o_simplified)
if simplify
o = structural_simplify(o)
end
return o, obs
return o
end

"""
Expand All @@ -190,7 +197,7 @@ end
Get the ModelingToolkit PDESystem representation of a [`CoupledSystem`](@ref).
"""
function Base.convert(::Type{<:PDESystem}, sys::CoupledSystem; name=:model, kwargs...)::ModelingToolkit.AbstractSystem
o, _ = convert(ODESystem, sys; name=name, simplify=false, prune=false, kwargs...)
o = convert(ODESystem, sys; name=name, simplify=false, prune=false, kwargs...)

if sys.domaininfo !== nothing
o += sys.domaininfo
Expand Down Expand Up @@ -218,28 +225,30 @@ end

# Combine the non-stiff operators into a single operator.
# This works because SciMLOperators can be added together.
function nonstiff_ops(sys::CoupledSystem, sys_mtk, obs_eqs, domain, u0, p)
obs_funcs = obs_functions(obs_eqs, domain)
coord_trans_funcs = coord_trans_functions(obs_eqs, domain)
function nonstiff_ops(sys::CoupledSystem, sys_mtk, domain, u0, p)
nonstiff_op = length(sys.ops) > 0 ?
sum([get_scimlop(op, sys, sys_mtk, domain, obs_funcs, coord_trans_funcs, u0, p) for op sys.ops]) :
sum([get_scimlop(op, sys, sys_mtk, domain, u0, p) for op sys.ops]) :
NullOperator(length(u0))
nonstiff_op = cache_operator(nonstiff_op, u0)
end

function operator_vars(sys::CoupledSystem, mtk_sys, domain::DomainInfo)
unique(vcat([get_needed_vars(op, sys, mtk_sys, domain) for op in sys.ops]...))
end

"""
Types that implement an:
`init_callback(x, sys::CoupledSystem, obs_eqs, domain::DomainInfo)::DECallback`
`init_callback(x, sys::CoupledSystem, sys_mtk, domain::DomainInfo)::DECallback`
method can also be coupled into a `CoupledSystem`.
The `init_callback` function will be run before the simulator is run
to get the callback.
"""
init_callback() = error("Not implemented")

function get_callbacks(sys::CoupledSystem, sys_mtk, obs_eqs, domain::DomainInfo)
extra_cb = [init_callback(c, sys, sys_mtk, obs_eqs, domain::DomainInfo) for c sys.init_callbacks]
function get_callbacks(sys::CoupledSystem, sys_mtk, domain::DomainInfo)
extra_cb = [init_callback(c, sys, sys_mtk, domain::DomainInfo) for c sys.init_callbacks]
[sys.callbacks; extra_cb]
end

Expand Down
56 changes: 18 additions & 38 deletions src/coupled_system_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ $(SIGNATURES)
Return the indexes of the system variables that the state variables of the final
simplified system depend on. This should be done before running `structural_simplify`
on the system.
`extra_vars` is a list of additional variables that need to be kept.
"""
function get_needed_vars(original_sys::ODESystem, simplified_sys)
function get_needed_vars(original_sys::ODESystem, simplified_sys::ODESystem, extra_vars=[])
varvardeps = ModelingToolkit.varvar_dependencies(
ModelingToolkit.asgraph(original_sys),
ModelingToolkit.variable_dependencies(original_sys),
Expand All @@ -112,7 +113,7 @@ function get_needed_vars(original_sys::ODESystem, simplified_sys)
return []
end
idx = collect(Graphs.DFSIterator(g, stidx))
unknowns(original_sys)[idx]
unique(vcat(unknowns(original_sys)[idx], extra_vars))
end

"""
Expand Down Expand Up @@ -202,8 +203,8 @@ Remove equations from an ODESystem where the variable in the LHS is not
present in any of the equations for the state variables. This can be used to
remove computationally intensive equations that are not used in the final model.
"""
function prune_observed(original_sys::ODESystem, simplified_sys)
needed_vars = var2symbol.(get_needed_vars(original_sys, simplified_sys))
function prune_observed(original_sys::ODESystem, simplified_sys, extra_vars)
needed_vars = var2symbol.(get_needed_vars(original_sys, simplified_sys, extra_vars))
deleteindex = []
obs = observed(simplified_sys)
for (i, eq) enumerate(obs)
Expand All @@ -222,7 +223,7 @@ function prune_observed(original_sys::ODESystem, simplified_sys)
unknowns=get_unknowns(new_eqs),
discrete_events=discrete_events,
)
return sys2, observed(simplified_sys)
return sys2
end

# Get the unknown variables in the system of equations.
Expand Down Expand Up @@ -277,12 +278,23 @@ function default_params(mtk_sys::AbstractSystem)
MTKParameters(mtk_sys, dflts)
end

# return whether the part of a after the last "₊" character matches b.
function is_matching_suffix(a, b)
is_matching_suffix(a, var2symbol(b))
end
function is_matching_suffix(a, b::Symbol)
split(String(var2symbol(a)), "")[end] == String(b)
end
function matching_suffix_idx(a::AbstractVector, b)
findall(v -> is_matching_suffix(v, b), a)
end

# Return the coordinate parameters from the parameter vector.
function coord_params(mtk_sys::AbstractSystem, domain::DomainInfo)
pv = pvars(domain)
params = parameters(mtk_sys)

_pvidx = [findall(v -> split(String(var2symbol(v)), "")[end] == String(var2symbol(p)), params) for p pv]
_pvidx = [matching_suffix_idx(params, p) for p pv]
for (i, idx) in enumerate(_pvidx)
if length(idx) > 1
error("Partial independent variable '$(pv[i])' has multiple matches in system parameters: [$(parameters(mtk_sys)[idx])].")
Expand All @@ -308,35 +320,3 @@ function coord_setter(sys_mtk::ODESystem, domain::DomainInfo)
end
return setp!
end

# Create functions to get concrete values for the observed variables.
# The return value is a function that returns the function when
# given a value.
function obs_functions(obs_eqs, domain::DomainInfo)
pv = pvars(domain)
iv = ivar(domain)

obs_fs_idx = Dict()
obs_fs = []
for (i, x) enumerate([eq.lhs for eq obs_eqs])
obs_fs_idx[x] = i
push!(obs_fs, observed_function(obs_eqs, x, [iv, pv...]))
end
obs_fs = Tuple(obs_fs)

(v) -> obs_fs[obs_fs_idx[v]]
end

# Return functions to perform coordinate transforms for each of the coordinates.
function coord_trans_functions(obs_eqs, domain::DomainInfo)
pv = pvars(domain)
iv = ivar(domain)

# Get functions for coordinate transforms
tf_fs = []
@variables 🌈🐉🏒 # Dummy variable.
for tf partialderivative_transforms(domain)
push!(tf_fs, observed_function([obs_eqs..., 🌈🐉🏒 ~ tf], 🌈🐉🏒, [iv, pv...]))
end
tf_fs = Tuple(tf_fs)
end
29 changes: 29 additions & 0 deletions src/domaininfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,15 @@ $(TYPEDSIGNATURES)
Return transform factor to multiply each partial derivative operator by,
for example to convert from degrees to meters.
"""
function partialderivative_transforms(mtk_sys::ODESystem, di::DomainInfo)
xs = coord_params(mtk_sys, di)
partialderivative_transforms(xs, di)
end
function partialderivative_transforms(di::DomainInfo)
xs = pvars(di)
partialderivative_transforms(xs, di)
end
function partialderivative_transforms(xs, di::DomainInfo)
fs = Dict()
for f in di.partial_derivative_funcs
for (k, v) f(xs)
Expand All @@ -260,6 +267,28 @@ function partialderivative_transforms(di::DomainInfo)
ts
end

function partialderivative_transform_vars(mtk_sys, di::DomainInfo)
xs = coord_params(mtk_sys, di)
iv = ivar(di)
ts = partialderivative_transforms(mtk_sys, di)
vs = []
for (i, x) in enumerate(xs)
n = Symbol("δ$(x)_transform")
v = only(@variables $n(iv) [unit=ModelingToolkit.get_unit(ts[i]),
description = "Transform factor for $(x)"])
push!(vs, v)
end
vs
end

function partialderivative_transform_eqs(mtk_sys, di::DomainInfo)
vs = partialderivative_transform_vars(mtk_sys, di)
ts = partialderivative_transforms(mtk_sys, di)
eqs = vs .~ ts
ODESystem(eqs, ivar(di); name=:transforms)
end



"""
$(TYPEDSIGNATURES)
Expand Down
Loading

0 comments on commit abd4d7b

Please sign in to comment.