Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BVProblem with constraints #3323

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,16 @@ MTKBifurcationKitExt = "BifurcationKit"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKLabelledArraysExt = "LabelledArrays"
MTKInfiniteOptExt = "InfiniteOpt"
MTKLabelledArraysExt = "LabelledArrays"

[compat]
AbstractTrees = "0.3, 0.4"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
BoundaryValueDiffEq = "5.12.0"
BoundaryValueDiffEqAscher = "1.1.0"
ChainRulesCore = "1"
Combinatorics = "1"
CommonSolve = "0.2.4"
Expand Down Expand Up @@ -139,8 +141,8 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
StochasticDiffEq = "6.72.1"
StochasticDelayDiffEq = "1.8.1"
StochasticDiffEq = "6.72.1"
SymbolicIndexingInterface = "0.3.36"
SymbolicUtils = "3.7"
Symbolics = "6.19"
Expand All @@ -152,6 +154,8 @@ julia = "1.9"
[extras]
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
Expand Down Expand Up @@ -183,4 +187,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEq", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
8 changes: 4 additions & 4 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ include("systems/imperative_affect.jl")
include("systems/callbacks.jl")
include("systems/problem_utils.jl")

include("systems/optimization/constraints_system.jl")
include("systems/optimization/optimizationsystem.jl")
include("systems/optimization/modelingtoolkitize.jl")

include("systems/nonlinear/nonlinearsystem.jl")
include("systems/nonlinear/homotopy_continuation.jl")
include("systems/diffeqs/odesystem.jl")
Expand All @@ -165,10 +169,6 @@ include("systems/discrete_system/discrete_system.jl")

include("systems/jumps/jumpsystem.jl")

include("systems/optimization/constraints_system.jl")
include("systems/optimization/optimizationsystem.jl")
include("systems/optimization/modelingtoolkitize.jl")

include("systems/pde/pdesystem.jl")

include("systems/sparsematrixclil.jl")
Expand Down
1 change: 1 addition & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ for prop in [:eqs
:structure
:op
:constraints
:constraintsystem
:controls
:loss
:bcs
Expand Down
175 changes: 175 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
end

if !isnothing(get_constraintsystem(sys))
error("An ODESystem with constraints cannot be used to construct a regular ODEProblem.
Consider a BVProblem instead.")
end

f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
Expand All @@ -849,6 +855,175 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

"""
```julia
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
constraints = nothing, guesses = nothing,
version = nothing, tgrad = false,
jac = true, sparse = true,
simplify = false,
kwargs...) where {iip}
```

Create a boundary value problem from the [`ODESystem`](@ref).

`u0map` is used to specify fixed initial values for the states. Every variable
must have either an initial guess supplied using `guesses` or a fixed initial
value specified using `u0map`.

Boundary value conditions are supplied to ODESystems
in the form of a ConstraintsSystem. These equations
should specify values that state variables should
take at specific points, as in `x(0.5) ~ 1`). More general constraints that
should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be
specified as one of the equations used to build the `ODESystem`.

If an ODESystem without `constraints` is specified, it will be treated as an initial value problem.

```julia
@parameters g t_c = 0.5
@variables x(..) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x(t))) ~ λ * x(t)
D(D(y)) ~ λ * y - g
x(t)^2 + y^2 ~ 1]
cstr = [x(0.5) ~ 1]
@named cstrs = ConstraintsSystem(cstr, t)
@mtkbuild pend = ODESystem(eqs, t)

tspan = (0.0, 1.5)
u0map = [x(t) => 0.6, y => 0.8]
parammap = [g => 1]
guesses = [λ => 1]
constraints = [x(0.5) ~ 1]

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
```

If the `ODESystem` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
"""
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem(sys::AbstractODESystem,
u0map::StaticArray,
args...;
kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
guesses = Dict(),
version = nothing, tgrad = false,
callback = nothing,
check_length = true,
warn_initialize_determined = true,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}

if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end
!isnothing(callback) && error("BVP solvers do not support callbacks.")

has_alg_eqs(sys) && error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.

sts = unknowns(sys)
ps = parameters(sys)
constraintsys = get_constraintsystem(sys)

if !isnothing(constraintsys)
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
end

# ODESystems without algebraic equations should use both fixed values + guesses
# for initialization.
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan, guesses,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)

stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]

bc = generate_function_bc(sys, u0, u0_idxs, tspan, iip)
return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
end

get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")

"""
generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)

Given an ODESystem with constraints, generate the boundary condition function to pass to boundary value problem solvers.
"""
function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
iv = get_iv(sys)
sts = get_unknowns(sys)
ps = get_ps(sys)
np = length(ps)
ns = length(sts)
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])

@variables sol(..)[1:ns] p[1:np]

conssys = get_constraintsystem(sys)
cons = Any[]
if !isnothing(conssys)
cons = [con.lhs - con.rhs for con in constraints(conssys)]

for st in get_unknowns(conssys)
x = operation(st)
t = only(arguments(st))
idx = stidxmap[x(iv)]

cons = map(c -> Symbolics.substitute(c, Dict(x(t) => sol(t)[idx])), cons)
end

for var in parameters(conssys)
if iscall(var)
x = operation(var)
t = only(arguments(var))
idx = pidxmap[x]

cons = map(c -> Symbolics.substitute(c, Dict(x(t) => p[idx])), cons)
else
idx = pidxmap[var]
cons = map(c -> Symbolics.substitute(c, Dict(var => p[idx])), cons)
end
end
end

init_conds = Any[]
for i in u0_idxs
expr = sol(tspan[1])[i] - u0[i]
push!(init_conds, expr)
end

exprs = vcat(init_conds, cons)
bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
if iip
return (resid, u, p, t) -> bcs[2](resid, u, p)
else
return (u, p, t) -> bcs[1](u, p)
end
end

"""
```julia
DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
Expand Down
Loading