From 58a8babaa396d864dc81e339c6f266a29105dfe4 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Fri, 15 Nov 2024 12:00:58 +1300 Subject: [PATCH] Add root_node_risk_measure kwarg to SDDP.train --- src/algorithm.jl | 16 +++++++++++++++- src/plugins/parallel_schemes.jl | 5 ++++- test/algorithm.jl | 30 ++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/algorithm.jl b/src/algorithm.jl index 1b84bc355..295346754 100644 --- a/src/algorithm.jl +++ b/src/algorithm.jl @@ -116,6 +116,7 @@ struct Options{T} last_log_iteration::Ref{Int} # For threading lock::ReentrantLock + root_node_risk_measure::AbstractRiskMeasure # Internal function: users should never construct this themselves. function Options( model::PolicyGraph{T}, @@ -136,6 +137,7 @@ struct Options{T} duality_handler::AbstractDualityHandler = ContinuousConicDuality(), forward_pass_callback = x -> nothing, post_iteration_callback = result -> nothing, + root_node_risk_measure::AbstractRiskMeasure = Expectation(), ) where {T} return new{T}( initial_state, @@ -160,6 +162,7 @@ struct Options{T} post_iteration_callback, Ref{Int}(0), # last_log_iteration ReentrantLock(), + root_node_risk_measure, ) end end @@ -946,7 +949,10 @@ function iteration(model::PolicyGraph{T}, options::Options) where {T} ) end @_timeit_threadsafe model.timer_output "calculate_bound" begin - bound = calculate_bound(model) + bound = calculate_bound( + model; + risk_measure = options.root_node_risk_measure, + ) end lock(options.lock) try @@ -1036,6 +1042,12 @@ Train the policy for `model`. - `risk_measure`: the risk measure to use at each node. Defaults to [`Expectation`](@ref). + - `root_node_risk_measure::AbstractRiskMeasure`: the risk measure to use at + the root node when computing the `Bound` column. Note that the choice of + this option does not change the primal policy, and it applies only if the + transition from the root node to the first stage is stochastic. Defaults to + [`Expectation`](@ref). + - `sampling_scheme`: a sampling scheme to use on the forward pass of the algorithm. Defaults to [`InSampleMonteCarlo`](@ref). @@ -1086,6 +1098,7 @@ function train( run_numerical_stability_report::Bool = true, stopping_rules = AbstractStoppingRule[], risk_measure = SDDP.Expectation(), + root_node_risk_measure::AbstractRiskMeasure = Expectation(), sampling_scheme = SDDP.InSampleMonteCarlo(), cut_type = SDDP.SINGLE_CUT, cycle_discretization_delta::Float64 = 0.0, @@ -1248,6 +1261,7 @@ function train( duality_handler, forward_pass_callback, post_iteration_callback, + root_node_risk_measure, ) status = :not_solved try diff --git a/src/plugins/parallel_schemes.jl b/src/plugins/parallel_schemes.jl index 4eee4a031..2e7c8d38b 100644 --- a/src/plugins/parallel_schemes.jl +++ b/src/plugins/parallel_schemes.jl @@ -284,7 +284,10 @@ function master_loop( end end slave_update(model, result) - bound = calculate_bound(model) + bound = calculate_bound( + model; + risk_measure = options.root_node_risk_measure, + ) push!( options.log, Log( diff --git a/test/algorithm.jl b/test/algorithm.jl index 8e555affb..b17d7e0de 100644 --- a/test/algorithm.jl +++ b/test/algorithm.jl @@ -387,6 +387,36 @@ function test_numerical_difficulty_callback() return end +function test_root_node_risk_measure() + model = SDDP.LinearPolicyGraph(; + stages = 3, + lower_bound = 0.0, + optimizer = HiGHS.Optimizer, + ) do sp, stage + @variable(sp, 0 <= x <= 100, SDDP.State, initial_value = 0) + @variable(sp, 0 <= u_p <= 200) + @variable(sp, u_o >= 0) + @variable(sp, w) + SDDP.parameterize(ω -> JuMP.fix(w, ω), sp, [100, 300]) + @constraint(sp, x.out == x.in + u_p + u_o - w) + @stageobjective(sp, 100 * u_p + 300 * u_o + 50 * x.out) + end + SDDP.train(model; root_node_risk_measure = SDDP.WorstCase()) + @test isapprox( + model.most_recent_training_results.log[end].bound, + 107500.0, + ) + @test isapprox( + SDDP.calculate_bound(model; risk_measure = SDDP.WorstCase()), + 107500.0, + ) + @test isapprox( + SDDP.calculate_bound(model; risk_measure = SDDP.Expectation()), + 85000.0, + ) + return +end + end # module TestAlgorithm.runtests()