Skip to content

Commit

Permalink
Add sample_backward_noise_terms_with_state (#742)
Browse files Browse the repository at this point in the history
  • Loading branch information
arthur-brigatto authored Jun 4, 2024
1 parent 8ed2e30 commit fc2d2a8
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,11 @@ function solve_all_children(
continue
end
child_node = model[child.term]
for noise in
sample_backward_noise_terms(backward_sampling_scheme, child_node)
for noise in sample_backward_noise_terms_with_state(
backward_sampling_scheme,
child_node,
outgoing_state,
)
if length(scenario_path) == length_scenario_path
push!(scenario_path, (child.term, noise.term))
else
Expand Down
20 changes: 19 additions & 1 deletion src/plugins/headers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,28 @@ abstract type AbstractBackwardSamplingScheme end
)::Vector{Noise}
Returns a `Vector{Noise}` of noises sampled from `node.noise_terms` using
`backward_sampling_scheme`
`backward_sampling_scheme`.
"""
function sample_backward_noise_terms end

"""
sample_backward_noise_terms_with_state(
sampler::AbstractBackwardSamplingScheme,
node::Node,
state::Dict{Symbol,Float64},
)::Vector{Noise}
Returns a `Vector{Noise}` of noises sampled conditionally on the `state` using
`sampler`.
"""
function sample_backward_noise_terms_with_state(
sampler::AbstractBackwardSamplingScheme,
node::Node,
::Dict{Symbol,Float64},
)
return sample_backward_noise_terms(sampler, node)
end

# =========================== duality_handlers =========================== #

"""
Expand Down
60 changes: 60 additions & 0 deletions test/plugins/backward_sampling_schemes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module TestBackwardPassSamplingSchemes

using SDDP
using Test
import HiGHS

function runtests()
for name in names(@__MODULE__; all = true)
Expand Down Expand Up @@ -89,6 +90,65 @@ function test_MonteCarloSampler_100()
return
end

mutable struct WithStateSampler <: SDDP.AbstractBackwardSamplingScheme
number_of_samples::Int
end

function test_WithStateSampler()
function sample_backward_noise_terms_with_state(
sampler::WithStateSampler,
node::SDDP.Node,
state::Dict{Symbol,Float64},
)
if state[:x] / node.index == 1.0
return [
SDDP.Noise((ϵ = 3.0,), 1 / sampler.number_of_samples) for
i in 1:sampler.number_of_samples
]
elseif state[:x] / node.index == 3.0
return [
SDDP.Noise((ϵ = 1.0,), 1 / sampler.number_of_samples) for
i in 1:sampler.number_of_samples
]
end
end
model = SDDP.LinearPolicyGraph(
stages = 5,
lower_bound = 0.0,
direct_mode = false,
optimizer = HiGHS.Optimizer,
) do node, stage
@variable(node, x, SDDP.State, initial_value = 0.0)
@variable(node, ϵ)
SDDP.parameterize(node, stage * [1, 3], [0.9, 0.1]) do ω
return JuMP.fix(ϵ, ω)
end
@constraint(node, x.out == ϵ)
end
forward_trajectory = SDDP.forward_pass(
model,
SDDP.Options(model, Dict(:x => 1.0)),
SDDP.DefaultForwardPass(),
)
for node_index in 1:length(forward_trajectory.scenario_path)
state = forward_trajectory.sampled_states[node_index]
terms = sample_backward_noise_terms_with_state(
WithStateSampler(100),
model[node_index],
state,
)
for term in terms
@test term.probability == 0.01
if state[:x] / node_index == 1.0
@test term.term.ϵ == 3.0
elseif state[:x] / node_index == 3.0
@test term.term.ϵ == 1.0
end
end
end
return
end

end # module

TestBackwardPassSamplingSchemes.runtests()

0 comments on commit fc2d2a8

Please sign in to comment.