Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support forwardiff of states #4

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GraphDynamics"
uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
version = "0.1.1"
version = "0.1.2"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
27 changes: 20 additions & 7 deletions src/GraphDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export
GraphSystem,
ODEGraphSystem,
SDEGraphSystem,
get_tag,
get_states,
get_params,
ODEProblem,
Expand Down Expand Up @@ -69,7 +70,8 @@ using SciMLBase:
CallbackSet,
VectorContinuousCallback,
ContinuousCallback,
DiscreteCallback
DiscreteCallback,
remake

using RecursiveArrayTools: ArrayPartition

Expand Down Expand Up @@ -101,20 +103,31 @@ include("utils.jl")
#----------------------------------------------------------
# API functions to be implemented by new Systems

struct SubsystemStates{Name, Eltype, States <: NamedTuple} <: AbstractVector{Eltype}
struct SubsystemStates{T, Eltype, States <: NamedTuple} <: AbstractVector{Eltype}
states::States
end

struct SubsystemParams{Name, Params <: NamedTuple}
struct SubsystemParams{T, Params <: NamedTuple}
params::Params
end

struct Subsystem{Name, Eltype, States, Params}
states::SubsystemStates{Name, Eltype, States}
params::SubsystemParams{Name, Params}
"""
Subsystem{T, Eltype, StateNT, ParamNT}

A `Subsystem` struct describes a complete subcomponent to an `GraphSystem`. This stores a `SubsystemStates` to describe the continuous dynamical state of the subsystem, and a `GraphSystemParams` which describes various non-dynamical parameters of the subsystem. The type parameter `T` is the subsystem's \"tag\" which labels what sort of subsystem it is.

See also `subsystem_differential`, `SubsystemStates`, `SubsystemParams`.

For example, if we wanted to describe a system where one sub-component is a billiard ball,


"""
struct Subsystem{T, Eltype, States, Params}
states::SubsystemStates{T, Eltype, States}
params::SubsystemParams{T, Params}
end

function get_name end
function get_tag end
function get_params end
function get_states end

Expand Down
10 changes: 0 additions & 10 deletions src/graph_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...)
@nexprs $Len k -> begin
M = connection_matrices[nc][k, i]
if has_discrete_events(eltype(M))
#tany(foo(Val(k), Val(i), Val(NConn), M, t), eachindex(states_partitioned[i])) && return true
for j ∈ eachindex(states_partitioned[i])
for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j)
discrete_event_condition(Mlj, t) && return true
Expand All @@ -313,15 +312,6 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...)
end
end

function foo(::Val{k}, ::Val{i}, ::Val{NConn}, M, t) where {i, NConn, k}
function f(j)
for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j)
discrete_event_condition(Mlj, t) && return true
end
false
end
end

function discrete_affect!(integrator)
(;params_partitioned, state_types_val, connection_matrices) = integrator.p
state_data = integrator.u.x
Expand Down
34 changes: 30 additions & 4 deletions src/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ end
function SciMLBase.ODEProblem(g::ODEGraphSystem, u0map, tspan, param_map=[];
scheduler=SerialScheduler(), tstops=Float64[],
allow_nonconcrete=false, kwargs...)
nt = _problem(g, tspan; scheduler, allow_nonconcrete)
nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; f, u, tspan, p, callback) = nt
tstops = vcat(tstops, nt.tstops)
prob = ODEProblem(f, u, tspan, p; callback, tstops, kwargs...)
Expand All @@ -21,11 +21,18 @@ end
function SciMLBase.SDEProblem(g::SDEGraphSystem, u0map, tspan, param_map=[];
scheduler=SerialScheduler(), tstops=Float64[],
allow_nonconcrete=false, kwargs...)
nt = _problem(g, tspan; scheduler, allow_nonconcrete)
nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; f, u, tspan, p, callback) = nt

noise_rate_prototype = nothing # zeros(length(u)) # this'll need to change once we support correlated noise
SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...)
prob = SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...)
for (k, v) ∈ u0map
setu(prob, k)(prob, v)
end
for (k, v) ∈ param_map
setp(prob, k)(prob, v)
end
prob
end

Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV}
Expand All @@ -35,14 +42,33 @@ Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV}
state_types_val::STV
end

function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete)
function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; states_partitioned,
params_partitioned,
connection_matrices,
tstops,
composite_discrete_events_partitioned,
composite_continuous_events_partitioned,) = g

total_eltype = let
states_eltype = mapreduce(promote_type, states_partitioned) do v
eltype(eltype(v))
end
u0map_eltype = mapreduce(promote_type, u0map; init=Union{}) do (k, v)
typeof(v)
end
promote_type(states_eltype, u0map_eltype)
end

re_eltype(s::SubsystemStates{T}) where {T} = convert(SubsystemStates{T, total_eltype}, s)
states_partitioned = map(states_partitioned) do v
if eltype(eltype(v)) <: total_eltype && eltype(eltype(v)) !== Union{}
v
else
re_eltype.(v)
end
end

length(states_partitioned) == length(params_partitioned) ||
error("Incompatible state and parameter lengths")
for i ∈ eachindex(states_partitioned, params_partitioned)
Expand Down
58 changes: 50 additions & 8 deletions src/subsystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ function ConstructionBase.setproperties(s::SubsystemParams{T}, patch::NamedTuple
SubsystemParams{T}(props′)
end

get_name(::SubsystemParams{Name}) where {Name} = Name
get_name(::Type{<:SubsystemParams{Name}}) where {Name} = Name
get_tag(::SubsystemParams{Name}) where {Name} = Name
get_tag(::Type{<:SubsystemParams{Name}}) where {Name} = Name
Base.NamedTuple(p::SubsystemParams) = getfield(p, :params)
Base.Tuple(s::SubsystemParams) = Tuple(getfield(s, :params))
Base.getproperty(p::SubsystemParams, prop::Symbol) = getproperty(NamedTuple(p), prop)
Expand All @@ -43,6 +43,10 @@ end
function SubsystemStates{Name}(nt::NamedTuple{state_names, NTuple{N, Eltype}}) where {Name, state_names, N, Eltype}
SubsystemStates{Name, Eltype, typeof(nt)}(nt)
end
function SubsystemStates{Name}(nt::NamedTuple{state_names, <:NTuple{N, Any}}) where {Name, state_names, N}
nt_promoted = NamedTuple{state_names}(promote(nt...))
SubsystemStates{Name}(nt_promoted)
end
function SubsystemStates{Name}(nt::NamedTuple{(), Tuple{}}) where {Name}
SubsystemStates{Name, Union{}, NamedTuple{(), Tuple{}}}(nt)
end
Expand Down Expand Up @@ -81,8 +85,8 @@ function ConstructionBase.setproperties(s::SubsystemStates{T}, patch::NamedTuple
SubsystemStates{T}(props′)
end

get_name(::SubsystemStates{Name}) where {Name} = Name
get_name(::Type{<:SubsystemStates{Name}}) where {Name} = Name
get_tag(::SubsystemStates{Name}) where {Name} = Name
get_tag(::Type{<:SubsystemStates{Name}}) where {Name} = Name
Base.NamedTuple(s::SubsystemStates) = getfield(s, :states)
Base.Tuple(s::SubsystemStates) = Tuple(getfield(s, :states))
Base.getproperty(s::SubsystemStates, prop::Symbol) = getproperty(NamedTuple(s), prop)
Expand All @@ -93,12 +97,25 @@ function state_ind(::Type{SubsystemStates{Name, Eltype, NamedTuple{names, Tup}}}
i = findfirst(==(s), names)
end

function Base.convert(::Type{SubsystemStates{Name, Eltype, NT}}, s::SubsystemStates{Name}) where {Name, Eltype, NT}
SubsystemStates{Name}(convert(NT, NamedTuple(s)))
end
function Base.convert(::Type{SubsystemStates{Name, Eltype}},
s::SubsystemStates{Name, <:Any, <:NamedTuple{state_names}}) where {Name, Eltype, state_names}
nt = NamedTuple{state_names}(convert.(Eltype, Tuple(s)))
SubsystemStates{Name, Eltype, typeof(nt)}(nt)
end

#------------------------------------------------------------
# Subsystem
function Subsystem{T}(;states, params) where {T}
ET = eltype(states)
Subsystem{T, ET, typeof(states), typeof(params)}(SubsystemStates{T}(states), SubsystemParams{T}(params))
Subsystem{T}(SubsystemStates{T}(states), SubsystemParams{T}(params))
end
function Subsystem{T}(states::SubsystemStates{T, Eltype, States},
params::SubsystemParams{T, Params}) where {T, Eltype, States, Params}
Subsystem{T, Eltype, States, Params}(states, params)
end

function Base.show(io::IO, sys::Subsystem{Name, Eltype}) where {Name, Eltype}
print(io,
"$Subsystem{$Name, $Eltype}(states = ",
Expand All @@ -124,15 +141,40 @@ function ConstructionBase.setproperties(s::Subsystem{T, Eltype, States, Params},
Subsystem{T, Eltype, States, Params}(SubsystemStates{T}(states′), SubsystemParams{T}(params′))
end

function Base.convert(::Type{Subsystem{Name, Eltype, SNT, PNT}}, s::Subsystem{Name}) where {Name, Eltype, SNT, PNT}
Subsystem{Name}(convert(SubsystemStates{Name, Eltype, SNT}, get_states(s)),
convert(SubsystemParams{Name, PNT}, get_params(s)))
end
function Base.convert(::Type{Subsystem{Name, Eltype}}, s::Subsystem{Name}) where {Name, Eltype}
Subsystem{Name}(convert(SubsystemStates{Name, Eltype}, get_states(s)), get_params(s))
end

@generated function promote_nt_type(::Type{NamedTuple{names, T1}},
::Type{NamedTuple{names, T2}}) where {names, T1, T2}
NamedTuple{names, Tuple{(promote_type(T1.parameters[i], T2.parameters[i]) for i ∈ eachindex(names))...}}
end

function Base.promote_rule(::Type{SubsystemParams{Name, NT1}},
::Type{SubsystemParams{Name, NT2}}) where {Name, NT1, NT2}
SubsystemParams{Name, promote_nt_type(NT1, NT2)}
end
function Base.promote_rule(::Type{SubsystemStates{Name, ET1, NT1}},
::Type{SubsystemStates{Name, ET2, NT2}}) where {Name, ET1, ET2, NT1, NT2}
SubsystemStates{Name, promote_type(ET1, ET2), promote_nt_type(NT1, NT2)}
end

function Base.promote_rule(::Type{Subsystem{Name, ET1, SNT1, PNT1}},
::Type{Subsystem{Name, ET2, SNT2, PNT2}}) where {Name, ET1, SNT1, PNT1, ET2, SNT2, PNT2}
Subsystem{Name, promote_type(ET1, ET2), promote_nt_type(SNT1, SNT2), promote_nt_type(PNT1, PNT2)}
end

get_states(s::Subsystem) = getfield(s, :states)
get_params(s::Subsystem) = getfield(s, :params)
get_name(::Subsystem{Name}) where {Name} = Name
get_tag(::Subsystem{Name}) where {Name} = Name



get_name(::Type{<:Subsystem{Name}}) where {Name} = Name
get_tag(::Type{<:Subsystem{Name}}) where {Name} = Name


function Base.getproperty(s::Subsystem{<:Any, States, Params},
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GraphDynamics = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
125 changes: 71 additions & 54 deletions test/particle_osc_example.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using GraphDynamics, OrdinaryDiffEq, Test
using GraphDynamics, OrdinaryDiffEq, Test, ForwardDiff

struct Particle end
function GraphDynamics.subsystem_differential(sys::Subsystem{Particle}, F, t)
Expand Down Expand Up @@ -34,56 +34,73 @@ function ((;fac)::Coulomb)(a, b)
-fac * a.q * b.q * sign(a.x - b.x)/(abs(a.x - b.x) + 1e-10)^2
end

# put some garbage values in here for states and params, but we'll set them to reasonable values later with the
# u0map and param_map
subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)),
Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))],
[Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))])

states_partitioned = map(v -> map(get_states, v), subsystems_partitioned)
params_partitioned = map(v -> map(get_params, v), subsystems_partitioned)
names_partitioned = ([:particle1, :particle2], [:osc])

spring_conns_par_par = NotConnected()
spring_conns_par_osc = [Spring(1)
Spring(1);;]
spring_conns_osc_par = [Spring(1) Spring(1)]
spring_conns_osc_osc = NotConnected()

spring_conns = ConnectionMatrix(((spring_conns_par_par, spring_conns_par_osc),
(spring_conns_osc_par, spring_conns_osc_osc)))

# Spring[⎡. .⎤ ⎡2⎤
# ⎣. .⎦ ⎣0⎦
# [2 0] [.]]

coulomb_conns_par_par = [Coulomb(0) Coulomb(.05)
Coulomb(.05) Coulomb(0)]
coulomb_conns_par_osc = [Coulomb(.05)
Coulomb(.05);;]
coulomb_conns_osc_par = [Coulomb(.05) Coulomb(.05)]
coulomb_conns_osc_osc = NotConnected()

coulomb_conns = ConnectionMatrix(((coulomb_conns_par_par, coulomb_conns_par_osc),
(coulomb_conns_osc_par, coulomb_conns_osc_osc)))

# Coulomb[⎡0 1⎤ ⎡1⎤
# ⎣1 0⎦ ⎣1⎦
# [0 0] [.]]

connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns))

sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned)
tspan = (0.0, 20.0)

prob = ODEProblem(sys,
# Fix the garbage state values
[:particle1₊x => 1.0, :particle2₊v => 0.0, :osc₊x => 0.0],
tspan,
# fix the garbage param values
[:osc₊m => 3.0])
sol = solve(prob, Tsit5())

@test sol[:particle1₊x][end] ≈ 1.4923823131014389
@test sol[:particle2₊x][end] ≈ -0.11189010002787175
@test sol[:osc₊x][end] ≈ 1.3175449091469553

function solve_particle_osc(;x1, x2)
# put some garbage values in here for states and params, but we'll set them to reasonable values later with the
# u0map and param_map
subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)),
Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))],
[Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))])

states_partitioned = map(v -> map(get_states, v), subsystems_partitioned)
params_partitioned = map(v -> map(get_params, v), subsystems_partitioned)
names_partitioned = ([:particle1, :particle2], [:osc])

spring_conns_par_par = NotConnected()
spring_conns_par_osc = [Spring(1)
Spring(1);;]
spring_conns_osc_par = [Spring(1) Spring(1)]
spring_conns_osc_osc = NotConnected()

spring_conns = ConnectionMatrix(((spring_conns_par_par, spring_conns_par_osc),
(spring_conns_osc_par, spring_conns_osc_osc)))

# Spring[⎡. .⎤ ⎡2⎤
# ⎣. .⎦ ⎣0⎦
# [2 0] [.]]

coulomb_conns_par_par = [Coulomb(0) Coulomb(.05)
Coulomb(.05) Coulomb(0)]
coulomb_conns_par_osc = [Coulomb(.05)
Coulomb(.05);;]
coulomb_conns_osc_par = [Coulomb(.05) Coulomb(.05)]
coulomb_conns_osc_osc = NotConnected()

coulomb_conns = ConnectionMatrix(((coulomb_conns_par_par, coulomb_conns_par_osc),
(coulomb_conns_osc_par, coulomb_conns_osc_osc)))

# Coulomb[⎡0 1⎤ ⎡1⎤
# ⎣1 0⎦ ⎣1⎦
# [0 0] [.]]

connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns))

sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned)
tspan = (0.0, 20.0)

prob = ODEProblem(sys,
# Fix the garbage state values
[:particle1₊x => x1, :particle2₊x => x2, :particle2₊v => 0.0, :osc₊x => 0.0],
tspan,
# fix the garbage param values
[:osc₊m => 3.0])
sol = solve(prob, Tsit5())
end

@testset "solutions" begin
sol = solve_particle_osc(;x1=1.0, x2=-1.0)
@test sol[:particle1₊x][end] ≈ 1.4923823131014389 rtol=1e-7
@test sol[:particle2₊x][end] ≈ -0.11189010002787175 rtol=1e-7
@test sol[:osc₊x][end] ≈ 1.3175449091469553 rtol=1e-7
end

@testset "sensitivies" begin
jac = ForwardDiff.jacobian([1.0, -1.0]) do (x1, x2)
sol = solve_particle_osc(;x1, x2)
[sol[:particle1₊x][end], sol[:particle2₊x][end], sol[:osc₊x][end]]
end
@test jac ≈ [0.498565 -0.0161443
-1.92556 3.14649
-0.249007 0.808641] rtol=1e-5

end
Loading