From fc2d2a875683a6389b09d0eeb544ccd58e63dc50 Mon Sep 17 00:00:00 2001 From: arthur-brigatto <103693830+arthur-brigatto@users.noreply.github.com> Date: Mon, 3 Jun 2024 23:53:21 -0300 Subject: [PATCH] Add sample_backward_noise_terms_with_state (#742) --- src/algorithm.jl | 7 ++- src/plugins/headers.jl | 20 +++++++- test/plugins/backward_sampling_schemes.jl | 60 +++++++++++++++++++++++ 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/src/algorithm.jl b/src/algorithm.jl index 92bce7d93..9f7b46555 100644 --- a/src/algorithm.jl +++ b/src/algorithm.jl @@ -659,8 +659,11 @@ function solve_all_children( continue end child_node = model[child.term] - for noise in - sample_backward_noise_terms(backward_sampling_scheme, child_node) + for noise in sample_backward_noise_terms_with_state( + backward_sampling_scheme, + child_node, + outgoing_state, + ) if length(scenario_path) == length_scenario_path push!(scenario_path, (child.term, noise.term)) else diff --git a/src/plugins/headers.jl b/src/plugins/headers.jl index 809b96750..38cfb179c 100644 --- a/src/plugins/headers.jl +++ b/src/plugins/headers.jl @@ -115,10 +115,28 @@ abstract type AbstractBackwardSamplingScheme end )::Vector{Noise} Returns a `Vector{Noise}` of noises sampled from `node.noise_terms` using -`backward_sampling_scheme` +`backward_sampling_scheme`. """ function sample_backward_noise_terms end +""" + sample_backward_noise_terms_with_state( + sampler::AbstractBackwardSamplingScheme, + node::Node, + state::Dict{Symbol,Float64}, + )::Vector{Noise} + +Returns a `Vector{Noise}` of noises sampled conditionally on the `state` using +`sampler`. +""" +function sample_backward_noise_terms_with_state( + sampler::AbstractBackwardSamplingScheme, + node::Node, + ::Dict{Symbol,Float64}, +) + return sample_backward_noise_terms(sampler, node) +end + # =========================== duality_handlers =========================== # """ diff --git a/test/plugins/backward_sampling_schemes.jl b/test/plugins/backward_sampling_schemes.jl index 8e2d2a9d9..8c172bd13 100644 --- a/test/plugins/backward_sampling_schemes.jl +++ b/test/plugins/backward_sampling_schemes.jl @@ -7,6 +7,7 @@ module TestBackwardPassSamplingSchemes using SDDP using Test +import HiGHS function runtests() for name in names(@__MODULE__; all = true) @@ -89,6 +90,65 @@ function test_MonteCarloSampler_100() return end +mutable struct WithStateSampler <: SDDP.AbstractBackwardSamplingScheme + number_of_samples::Int +end + +function test_WithStateSampler() + function sample_backward_noise_terms_with_state( + sampler::WithStateSampler, + node::SDDP.Node, + state::Dict{Symbol,Float64}, + ) + if state[:x] / node.index == 1.0 + return [ + SDDP.Noise((ϵ = 3.0,), 1 / sampler.number_of_samples) for + i in 1:sampler.number_of_samples + ] + elseif state[:x] / node.index == 3.0 + return [ + SDDP.Noise((ϵ = 1.0,), 1 / sampler.number_of_samples) for + i in 1:sampler.number_of_samples + ] + end + end + model = SDDP.LinearPolicyGraph( + stages = 5, + lower_bound = 0.0, + direct_mode = false, + optimizer = HiGHS.Optimizer, + ) do node, stage + @variable(node, x, SDDP.State, initial_value = 0.0) + @variable(node, ϵ) + SDDP.parameterize(node, stage * [1, 3], [0.9, 0.1]) do ω + return JuMP.fix(ϵ, ω) + end + @constraint(node, x.out == ϵ) + end + forward_trajectory = SDDP.forward_pass( + model, + SDDP.Options(model, Dict(:x => 1.0)), + SDDP.DefaultForwardPass(), + ) + for node_index in 1:length(forward_trajectory.scenario_path) + state = forward_trajectory.sampled_states[node_index] + terms = sample_backward_noise_terms_with_state( + WithStateSampler(100), + model[node_index], + state, + ) + for term in terms + @test term.probability == 0.01 + if state[:x] / node_index == 1.0 + @test term.term.ϵ == 3.0 + elseif state[:x] / node_index == 3.0 + @test term.term.ϵ == 1.0 + end + end + end + return +end + end # module TestBackwardPassSamplingSchemes.runtests()