Skip to content

Commit

Permalink
Merge pull request #102 from SciML/as/setp-oop-bi
Browse files Browse the repository at this point in the history
feat: add `setsym_oop` for `BatchedInterface`
  • Loading branch information
AayushSabharwal authored Oct 15, 2024
2 parents 5249fbb + a8a7d4f commit 386f701
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ with_updated_parameter_timeseries_values
```@docs
BatchedInterface
associated_systems
setsym_oop
```

## Container objects
Expand Down
2 changes: 1 addition & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ include("parameter_indexing.jl")
export getu, setu
include("state_indexing.jl")

export BatchedInterface, associated_systems
export BatchedInterface, setsym_oop, associated_systems
include("batched_interface.jl")

export ProblemState
Expand Down
115 changes: 112 additions & 3 deletions src/batched_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ See [`getu`](@ref) and [`setu`](@ref) for further details.
See also: [`associated_systems`](@ref).
"""
struct BatchedInterface{S <: AbstractVector, I, T}
struct BatchedInterface{S <: AbstractVector, I, T, P}
"Order of symbols in the union."
symbol_order::S
"Index of the index provider each symbol in the union is associated with."
Expand All @@ -36,6 +36,8 @@ struct BatchedInterface{S <: AbstractVector, I, T}
system_to_symbol_indexes::Vector{Vector{T}}
"Map from index provider to whether each of its symbols is a state in the index provider."
system_to_isstate::Vector{BitVector}
"Index providers, in order"
index_providers::Vector{P}
end

function BatchedInterface(syssyms::Tuple...)
Expand All @@ -46,11 +48,18 @@ function BatchedInterface(syssyms::Tuple...)
system_to_symbol_subset = Vector{Int}[]
system_to_symbol_indexes = []
system_to_isstate = BitVector[]
index_providers = []
for (i, (sys, syms)) in enumerate(syssyms)
symbol_subset = Int[]
symbol_indexes = []
system_isstate = BitVector()
allsyms = []
root_indp = sys
while applicable(symbolic_container, root_indp) &&
(sc = symbolic_container(root_indp)) != root_indp
root_indp = sc
end
push!(index_providers, root_indp)
for sym in syms
if symbolic_type(sym) === NotSymbolic()
error("Only symbolic variables allowed in BatchedInterface.")
Expand Down Expand Up @@ -89,9 +98,10 @@ function BatchedInterface(syssyms::Tuple...)
system_to_symbol_indexes = identity.(system_to_symbol_indexes)

return BatchedInterface{typeof(symbol_order), typeof(associated_indexes),
eltype(eltype(system_to_symbol_indexes))}(
eltype(eltype(system_to_symbol_indexes)), eltype(index_providers)}(
symbol_order, associated_systems, associated_indexes, isstate,
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate)
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate,
identity.(index_providers))
end

variable_symbols(bi::BatchedInterface) = bi.symbol_order
Expand Down Expand Up @@ -268,3 +278,102 @@ function setu(bi::BatchedInterface)
setter!
end
end

"""
setsym_oop(bi::BatchedInterface)
Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding
symbols), return a function which takes `n` corresponding value providers and an array of
values, and returns an `n`-tuple where each element is a 2-tuple consisting of the updated
state values and parameter values of the corresponding value provider. Requires that the
value provider implement [`state_values`](@ref), [`parameter_values`](@ref). The updates are
performed out-of-place using [`remake_buffer`](@ref).
Note that all of the value providers passed to the returned function must satisfy
`is_timeseries(prob) === NotTimeseries()`.
Note that if any subset of the `n` index providers share common symbols (among those passed
to `BatchedInterface`) then all of the corresponding value providers in the subset will be
updated with the values of the common symbols.
See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
"""
function setsym_oop(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]
arg = :vals
full_update = Expr(:block)

function get_update_expr(prob::Symbol, sys_i::Int)
union_idxs = bi.system_to_symbol_subset[sys_i]
indp_idxs = bi.system_to_symbol_indexes[sys_i]
isstate = bi.system_to_isstate[sys_i]
indp = bi.index_providers[sys_i]
curexpr = Expr(:block)

statessym = Symbol(:states_, sys_i)
if all(.!isstate)
push!(curexpr.args, :($statessym = $state_values($prob)))
else
state_idxssym = Symbol(:state_idxs_, sys_i)
state_idxs = indp_idxs[isstate]
state_valssym = Symbol(:state_vals_, sys_i)
vals_idxs = union_idxs[isstate]
push!(curexpr.args, :($state_idxssym = $state_idxs))
push!(curexpr.args, :($state_valssym = $view($arg, $vals_idxs)))
push!(curexpr.args,
:($statessym = $remake_buffer(
syss[$sys_i], $state_values($prob), $state_idxssym, $state_valssym)))
end

paramssym = Symbol(:params_, sys_i)
if all(isstate)
push!(curexpr.args, :($paramssym = $parameter_values($prob)))
else
param_idxssym = Symbol(:param_idxs_, sys_i)
param_idxs = indp_idxs[.!isstate]
param_valssym = Symbol(:param_vals, sys_i)
vals_idxs = union_idxs[.!isstate]
push!(curexpr.args, :($param_idxssym = $param_idxs))
push!(curexpr.args, :($param_valssym = $view($arg, $vals_idxs)))
push!(curexpr.args,
:($paramssym = $remake_buffer(
syss[$sys_i], $parameter_values($prob), $param_idxssym, $param_valssym)))
end

return curexpr, statessym, paramssym
end

full_update_expr = Expr(:block)
full_update_retval = Expr(:tuple)
partial_update_expr = Expr(:block)
cur_partial_update_expr = partial_update_expr
for i in 1:numprobs
update_expr, statesym, paramsym = get_update_expr(probnames[i], i)
push!(full_update_expr.args, update_expr)
push!(full_update_retval.args, Expr(:tuple, statesym, paramsym))

cur_ifexpr = Expr(i == 1 ? :if : :elseif, :(idx == $i))
update_expr, statesym, paramsym = get_update_expr(:prob, i)
push!(update_expr.args, :(return ($statesym, $paramsym)))
push!(cur_ifexpr.args, update_expr)
push!(cur_partial_update_expr.args, cur_ifexpr)
cur_partial_update_expr = cur_ifexpr
end
push!(full_update_expr.args, :(return $full_update_retval))
push!(cur_partial_update_expr.args, :(error("Invalid problem index $idx")))

full_update_fnexpr = Expr(
:function, Expr(:tuple, :syss, probnames..., arg), full_update_expr)
partial_update_fnexpr = Expr(
:function, Expr(:tuple, :syss, :prob, :idx, arg), partial_update_expr)

return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr),
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr),
syss = Tuple(bi.index_providers)

setter(args...) = full_update(syss, args...)
setter(prob, idx::Int, vals::AbstractVector) = partial_update(syss, prob, idx, vals)
setter
end
end
28 changes: 28 additions & 0 deletions test/batched_interface_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,31 @@ setter!(probs[1], 1, buf)
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3]

@test_throws ErrorException setter!(probs[1], 4, buf)

setter!(probs..., buf)

setter = setsym_oop(bi)

buf .*= 100
vals = setter(probs..., buf)
@test length(vals) == length(probs)
@test vals[1][1] == [100.0, 2.0, 300.0]
@test vals[1][2] == [0.1, 20.0, 30.0]
@test vals[2][1] == [300.0, 500.0, 6.0]
@test vals[2][2] == [30.0, 0.5, 60.0]
@test vals[3][1] == [500.0, 100.0, 9.0]
@test vals[3][2] == [70.0, 80.0, 0.9]

# out-of-place
for i in 1:3
@test vals[i][1] != state_values(probs[i])
@test vals[i][2] != parameter_values(probs[i])
end

buf ./= 10
vals = setter(probs[1], 1, buf)
@test length(vals) == 2
@test vals[1] == [10.0, 2.0, 30.0]
@test vals[2] == [0.1, 2.0, 3.0]

@test_throws ErrorException setter(probs[1], 4, buf)
33 changes: 33 additions & 0 deletions test/downstream/batchedinterface_arrayvars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,36 @@ buf ./= 10

setter!(probs[1], 1, buf)
@test state_values(probs[1]) == [1.0, 2.0, 3.0]

@variables a b[1:2] c

syss = [
SymbolCache([x..., y], [a, b...]),
SymbolCache([x[1], y, z], [a, b..., c])
]
syms = [
[x, y, a, b...],
[x[1], y, b[2], c]
]
probs = [
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]),
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.1, 0.4, 0.5, 0.6])
]

bi = BatchedInterface(zip(syss, syms)...)

buf = getu(bi)(probs...)
buf .*= 100
setter = setsym_oop(bi)
vals = setter(probs..., buf)
@test length(vals) == length(probs)
@test vals[1][1] == [100.0, 200.0, 300.0]
@test vals[1][2] == [10.0, 20.0, 30.0]
@test vals[2][1] == [100.0, 300.0, 6.0]
@test vals[2][2] == [0.1, 0.4, 30.0, 60.0]

buf ./= 10
vals = setter(probs[1], 1, buf)
@test length(vals) == 2
@test vals[1] == [10.0, 20.0, 30.0]
@test vals[2] == [1.0, 2.0, 3.0]

0 comments on commit 386f701

Please sign in to comment.