Skip to content

Commit

Permalink
refactor: update implementation of discrete save interface
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jul 3, 2024
1 parent dbf697c commit 08ea550
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 94 deletions.
60 changes: 40 additions & 20 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
sym = unwrap(sym)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return sym isa ParameterIndex || is_parameter(ic, sym) ||
iscall(sym) && operation(sym) === getindex &&
iscall(sym) &&
operation(sym) === getindex &&
is_parameter(ic, first(arguments(sym)))
end
if unwrap(sym) isa Int
Expand Down Expand Up @@ -519,34 +520,19 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
end

function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym)
is_time_dependent(sys) || return false
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
is_timeseries_parameter(ic, sym)
end

function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym)
is_time_dependent(sys) || return nothing
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing
timeseries_parameter_index(ic, sym)
end

function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
allvars = vars(sym; op = Symbolics.Operator)
ts_idxs = Set{Int}()
for var in allvars
var = unwrap(var)
# FIXME: Shouldn't have to shift systems
if istree(var) && (op = operation(var)) isa Shift && op.steps == 1
var = only(arguments(var))
end
ts_idx = check_index_map(ic.discrete_idx, unwrap(var))
ts_idx === nothing && continue
push!(ts_idxs, ts_idx[1])
end
if length(ts_idxs) == 1
ts_idx = only(ts_idxs)
else
ts_idx = nothing
end
rawobs = build_explicit_observed_function(
sys, sym; param_only = true, return_inplace = true)
if rawobs isa Tuple
Expand All @@ -573,10 +559,44 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
end
end
else
ts_idx = nothing
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
end
return ParameterObservedFunction(ts_idx, obsfn)
return obsfn
end

function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym)
if is_variable(sys, sym)
push!(ts_idxs, ContinuousTimeseries())
elseif is_timeseries_parameter(sys, sym)
push!(ts_idxs, timeseries_parameter_index(sys, sym).timeseries_idx)
end
end
# Need this to avoid ambiguity with the array case
for traitT in [
ScalarSymbolic,
ArraySymbolic
]
@eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym)
allsyms = vars(sym; op = Symbolics.Operator)
foreach(allsyms) do s
_all_ts_idxs!(ts_idxs, sys, s)
end
end
end
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray)
foreach(sym) do s
_all_ts_idxs!(ts_idxs, sys, s)
end
end
_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, NotSymbolic(), sys, sym)

function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSystem, sym)
if !is_time_dependent(sys)
return Set()
end
ts_idxs = Set()
_all_ts_idxs!(ts_idxs, sys, sym)
return ts_idxs
end

function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
Expand Down
18 changes: 9 additions & 9 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function IndexCache(sys::AbstractSystem)
error("Discrete subsystem $i input $inp is not a parameter")
disc_clocks[inp] = i
disc_clocks[default_toterm(inp)] = i
if hasname(inp) && (!istree(inp) || operation(inp) !== getindex)
if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex)
disc_clocks[getname(inp)] = i
disc_clocks[default_toterm(inp)] = i
end
Expand All @@ -126,7 +126,7 @@ function IndexCache(sys::AbstractSystem)
error("Discrete subsystem $i unknown $sym is not a parameter")
disc_clocks[sym] = i
disc_clocks[default_toterm(sym)] = i
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
disc_clocks[getname(sym)] = i
disc_clocks[getname(default_toterm(sym))] = i
end
Expand All @@ -138,13 +138,13 @@ function IndexCache(sys::AbstractSystem)
# FIXME: This shouldn't be necessary
eq.rhs === -0.0 && continue
sym = eq.lhs
if istree(sym) && operation(sym) == Shift(t, 1)
if iscall(sym) && operation(sym) == Shift(t, 1)
sym = only(arguments(sym))
end
disc_clocks[sym] = i
disc_clocks[sym] = i
disc_clocks[default_toterm(sym)] = i
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
disc_clocks[getname(sym)] = i
disc_clocks[getname(default_toterm(sym))] = i
end
Expand All @@ -153,7 +153,7 @@ function IndexCache(sys::AbstractSystem)

for par in inputs[continuous_id]
is_parameter(sys, par) || error("Discrete subsystem input is not a parameter")
istree(par) && operation(par) isa Hold ||
iscall(par) && operation(par) isa Hold ||
error("Continuous subsystem input is not a Hold")
if haskey(disc_clocks, par)
sym = par
Expand All @@ -176,7 +176,7 @@ function IndexCache(sys::AbstractSystem)
disc_clocks[affect.lhs] = user_affect_clock
disc_clocks[default_toterm(affect.lhs)] = user_affect_clock
if hasname(affect.lhs) &&
(!istree(affect.lhs) || operation(affect.lhs) !== getindex)
(!iscall(affect.lhs) || operation(affect.lhs) !== getindex)
disc_clocks[getname(affect.lhs)] = user_affect_clock
disc_clocks[getname(default_toterm(affect.lhs))] = user_affect_clock
end
Expand All @@ -190,7 +190,7 @@ function IndexCache(sys::AbstractSystem)
disc = unwrap(disc)
disc_clocks[disc] = user_affect_clock
disc_clocks[default_toterm(disc)] = user_affect_clock
if hasname(disc) && (!istree(disc) || operation(disc) !== getindex)
if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex)
disc_clocks[getname(disc)] = user_affect_clock
disc_clocks[getname(default_toterm(disc))] = user_affect_clock
end
Expand Down Expand Up @@ -245,7 +245,7 @@ function IndexCache(sys::AbstractSystem)
for (j, sym) in enumerate(buffer[btype])
disc_idxs[sym] = (clockidx, i, j)
disc_idxs[default_toterm(sym)] = (clockidx, i, j)
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
disc_idxs[getname(sym)] = (clockidx, i, j)
disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j)
end
Expand All @@ -256,7 +256,7 @@ function IndexCache(sys::AbstractSystem)
haskey(disc_idxs, sym) && continue
disc_idxs[sym] = (clockid, 0, 0)
disc_idxs[default_toterm(sym)] = (clockid, 0, 0)
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
disc_idxs[getname(sym)] = (clockid, 0, 0)
disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0)
end
Expand Down
5 changes: 3 additions & 2 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ function SymbolicIndexingInterface.set_parameter!(
if validate_size && size(val) !== size(p.discrete[i][j][k])
throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val)))
end
p.discrete[i][j][k][l...] = val
p.discrete[i][j][k] = val
else
p.discrete[i][j][k][l...] = val
end
Expand Down Expand Up @@ -563,7 +563,8 @@ end
Base.size(::NestedGetIndex) = ()

function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(
ps::MTKParameters, args::Pair{A, B}...) where {A, B <: NestedGetIndex}
::AbstractSystem, ps::MTKParameters, args::Pair{A, B}...) where {
A, B <: NestedGetIndex}
for (i, val) in args
ps.discrete[i] = val.x
end
Expand Down
52 changes: 8 additions & 44 deletions test/mtkparameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
using SymbolicIndexingInterface
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
using StaticArrays: SizedVector
using OrdinaryDiffEq
using ForwardDiff
using JET
Expand Down Expand Up @@ -292,29 +293,10 @@ end
end

# Parameter timeseries
# dt = 0.1
# dt2 = 0.2
# @variables x(t)=0 y(t)=0 u(t)=0 yd1(t)=0 ud1(t)=0 yd2(t)=0 ud2(t)=0
# @parameters kp=1 r=1

# eqs = [
# # controller (time discrete part `dt=0.1`)
# yd1 ~ Sample(t, dt)(y)
# ud1 ~ kp * (r - yd1)
# # controller (time discrete part `dt=0.2`)
# yd2 ~ Sample(t, dt2)(y)
# ud2 ~ kp * (r - yd2)

# # plant (time continuous part)
# u ~ Hold(ud1) + Hold(ud2)
# D(x) ~ -x + u
# y ~ x]

# @mtkbuild cl = ODESystem(eqs, t)
ps = MTKParameters(([1.0, 1.0],), SizedArray{2}([([0.0, 0.0],), ([0.0, 0.0],)]), (), (), (), nothing, nothing)
# ps = MTKParameters(cl, [kp => 1.0])
ps = MTKParameters(([1.0, 1.0],), SizedVector{2}([([0.0, 0.0],), ([0.0, 0.0],)]),
(), (), (), nothing, nothing)
with_updated_parameter_timeseries_values(
ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],)))
sys, ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],)))
@test ps.discrete[1][1] == [5.0, 10.0]
with_updated_parameter_timeseries_values(
ps, 1 => ModelingToolkit.NestedGetIndex(([3.0, 30.0],)),
Expand All @@ -324,27 +306,9 @@ with_updated_parameter_timeseries_values(
@test SciMLBase.get_saveable_values(ps, 1).x == ps.discrete[1]

# With multiple types and clocks
# @variables x(t) xd1(t) xd2(t) flag(t)::Bool yd1(t) yd2(t) yc1(t) yc2(t)
# dt = 0.1
# k1 = ShiftIndex(t, dt)
# ssc = ModelingToolkit.SolverStepClock(t)
# k2 = ShiftIndex(ssc)

# eqs = [
# flag ~ ~flag(k1 - 1),
# xd1 ~ Sample(t, dt)(x),
# yd1 ~ ifelse(flag, xd1, yd1(k1 - 1)), xd2 ~ Sample(ssc)(x),
# yd2 ~ yd2(k2 - 1) + xd2, yc1 ~ Hold(yd1),
# yc2 ~ Hold(yd2),
# D(x) ~ yc1 + yc2
# ]
# @mtkbuild sys = ODESystem(eqs, t)
# ps = MTKParameters(sys,
# [flag => true, yd1 => ifelse(flag, Sample(t, dt)(x), 1.0),
# yd2 => 2.0 + Sample(ssc)(x), Sample(t, dt)(x) => x,
# Sample(ssc)(x) => x, Hold(yd1) => yd1, Hold(yd2) => yd2],
# [x => 3.0])
ps = MTKParameters((), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]), (), (), (), nothing, nothing)
ps = MTKParameters(
(), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]),
(), (), (), nothing, nothing)
@test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, BitVector}
# tsidx1 = timeseries_parameter_index(sys, flag).timeseries_idx
# tsidx2 = 3 - tsidx1
Expand All @@ -355,6 +319,6 @@ tsidx2 = 2
@test length(ps.discrete[tsidx2][1]) == 3
@test length(ps.discrete[tsidx2][2]) == 0
with_updated_parameter_timeseries_values(
ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false])))
sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false])))
@test ps.discrete[tsidx1][1] == [10.0, 11.0, 12.0]
@test ps.discrete[tsidx1][2][] == false
9 changes: 6 additions & 3 deletions test/parameter_dependencies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,21 @@ end
@test_skip begin
Tf = 1.0
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
yd(k - 2) => 2.0])
@test_nowarn solve(prob, Tsit5())

@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp],
discrete_events = [[0.5] => [kp ~ 2.0]])
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
yd(k - 2) => 2.0])
@test prob.ps[kp] == 1.0
@test prob.ps[kq] == 2.0
@test_nowarn solve(prob, Tsit5())
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
yd(k - 2) => 2.0])
integ = init(prob, Tsit5())
@test integ.ps[kp] == 1.0
@test integ.ps[kq] == 2.0
Expand Down
29 changes: 13 additions & 16 deletions test/symbolic_indexing_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ using SciMLStructures: Tunable
odesys = complete(odesys)
@test default_values(odesys)[xy] == 3.0
pobs = parameter_observed(odesys, a + b)
@test pobs.timeseries_idx === nothing
@test pobs.observed_fn(
@test isempty(get_all_timeseries_indexes(odesys, a + b))
@test pobs(
ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) 3.0
pobs = parameter_observed(odesys, [a + b, a - b])
@test pobs.timeseries_idx === nothing
@test pobs.observed_fn(
@test isempty(get_all_timeseries_indexes(odesys, [a + b, a - b]))
@test pobs(
ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) [3.0, -1.0]
end

Expand Down Expand Up @@ -102,11 +102,11 @@ end
@test !is_time_dependent(ns)
ps = ModelingToolkit.MTKParameters(ns, [σ => 1.0, ρ => 2.0, β => 3.0])
pobs = parameter_observed(ns, σ + ρ)
@test pobs.timeseries_idx === nothing
@test pobs.observed_fn(ps) == 3.0
@test isempty(get_all_timeseries_indexes(ns, σ + ρ))
@test pobs(ps) == 3.0
pobs = parameter_observed(ns, [σ + ρ, ρ + β])
@test pobs.timeseries_idx === nothing
@test pobs.observed_fn(ps) == [3.0, 5.0]
@test isempty(get_all_timeseries_indexes(ns, [σ + ρ, ρ + β]))
@test pobs(ps) == [3.0, 5.0]
end

@testset "PDESystem" begin
Expand Down Expand Up @@ -147,6 +147,11 @@ end
domains = [t (0.0, 1.0),
x (0.0, 1.0)]

analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)]
analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t)

@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic)

@test isequal(pdesys.ps, [h])
@test isequal(parameter_symbols(pdesys), [h])
@test isequal(parameters(pdesys), [h])
Expand Down Expand Up @@ -179,12 +184,4 @@ get_dep = @test_nowarn getu(prob, 2p1)
@test getu(prob, z)(prob) == getu(prob, :z)(prob)
@test getu(prob, p1)(prob) == getu(prob, :p1)(prob)
@test getu(prob, p2)(prob) == getu(prob, :p2)(prob)
analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)]
analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t)

@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic)

@test isequal(pdesys.ps, [h])
@test isequal(parameter_symbols(pdesys), [h])
@test isequal(parameters(pdesys), [h])
end

0 comments on commit 08ea550

Please sign in to comment.