Skip to content

Commit

Permalink
WIP: add Threaded parallel scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Jul 23, 2024
1 parent 2d7a193 commit f466f91
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 22 deletions.
65 changes: 45 additions & 20 deletions src/algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

# `@_timeit_threadsafe timer label block` evaluates `block`. When the code is
# running on the thread for which `Threads.threadid() == 1`, the evaluation is
# wrapped in `TimerOutputs.@timeit timer label ...`; on any other thread the
# block is evaluated with no timing at all.
#
# The complete expansion is passed through `esc`, so `timer`, `label` and
# `block` are all resolved in the caller's scope. Call sites in this file
# assign variables inside `block` and read them after the macro call.
#
# Coverage note: Codecov reported the untimed branch as not covered by tests.
macro _timeit_threadsafe(timer, label, block)
    on_first_thread = :(Threads.threadid() == 1)
    timed_block = quote
        TimerOutputs.@timeit $timer $label $block
    end
    return esc(Expr(:if, on_first_thread, timed_block, block))
end

# to_nodal_form is an internal helper function so users can pass arguments like:
# risk_measure = SDDP.Expectation(),
# risk_measure = Dict(1=>Expectation(), 2=>WorstCase())
Expand Down Expand Up @@ -101,6 +111,8 @@ struct Options{T}
forward_pass_callback::Any
post_iteration_callback::Any
last_log_iteration::Ref{Int}
# For threading
lock::ReentrantLock
# Internal function: users should never construct this themselves.
function Options(
model::PolicyGraph{T},
Expand Down Expand Up @@ -144,6 +156,7 @@ struct Options{T}
forward_pass_callback,
post_iteration_callback,
Ref{Int}(0), # last_log_iteration
ReentrantLock(),
)
end
end
Expand Down Expand Up @@ -423,7 +436,7 @@ function solve_subproblem(
end
state = get_outgoing_state(node)
stage_objective = stage_objective_value(node.stage_objective)
TimerOutputs.@timeit model.timer_output "get_dual_solution" begin
@_timeit_threadsafe model.timer_output "get_dual_solution" begin

Check warning on line 439 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L439

Added line #L439 was not covered by tests
objective, dual_values = get_dual_solution(node, duality_handler)
end
if node.post_optimize_hook !== nothing
Expand Down Expand Up @@ -505,7 +518,7 @@ function backward_pass(
objective_states::Vector{NTuple{N,Float64}},
belief_states::Vector{Tuple{Int,Dict{T,Float64}}},
) where {T,NoiseType,N}
TimerOutputs.@timeit model.timer_output "prepare_backward_pass" begin
@_timeit_threadsafe model.timer_output "prepare_backward_pass" begin
restore_duality =
prepare_backward_pass(model, options.duality_handler, options)
end
Expand Down Expand Up @@ -613,7 +626,7 @@ function backward_pass(
end
end
end
TimerOutputs.@timeit model.timer_output "prepare_backward_pass" begin
@_timeit_threadsafe model.timer_output "prepare_backward_pass" begin

Check warning on line 629 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L629

Added line #L629 was not covered by tests
restore_duality()
end
return cuts
Expand Down Expand Up @@ -695,7 +708,7 @@ function solve_all_children(
noise.term,
)
end
TimerOutputs.@timeit model.timer_output "solve_subproblem" begin
@_timeit_threadsafe model.timer_output "solve_subproblem" begin

Check warning on line 711 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L711

Added line #L711 was not covered by tests
subproblem_results = solve_subproblem(
model,
child_node,
Expand Down Expand Up @@ -812,11 +825,11 @@ end

function iteration(model::PolicyGraph{T}, options::Options) where {T}
model.ext[:numerical_issue] = false
TimerOutputs.@timeit model.timer_output "forward_pass" begin
@_timeit_threadsafe model.timer_output "forward_pass" begin

Check warning on line 828 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L828

Added line #L828 was not covered by tests
forward_trajectory = forward_pass(model, options, options.forward_pass)
options.forward_pass_callback(forward_trajectory)
end
TimerOutputs.@timeit model.timer_output "backward_pass" begin
@_timeit_threadsafe model.timer_output "backward_pass" begin

Check warning on line 832 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L832

Added line #L832 was not covered by tests
cuts = backward_pass(
model,
options,
Expand All @@ -826,22 +839,27 @@ function iteration(model::PolicyGraph{T}, options::Options) where {T}
forward_trajectory.belief_states,
)
end
TimerOutputs.@timeit model.timer_output "calculate_bound" begin
@_timeit_threadsafe model.timer_output "calculate_bound" begin

Check warning on line 842 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L842

Added line #L842 was not covered by tests
bound = calculate_bound(model)
end
push!(
options.log,
Log(
length(options.log) + 1,
bound,
forward_trajectory.cumulative_value,
time() - options.start_time,
Distributed.myid(),
model.ext[:total_solves],
duality_log_key(options.duality_handler),
model.ext[:numerical_issue],
),
)
lock(options.lock)
try
push!(
options.log,
Log(
length(options.log) + 1,
bound,
forward_trajectory.cumulative_value,
time() - options.start_time,
Distributed.myid(),
model.ext[:total_solves],
duality_log_key(options.duality_handler),
model.ext[:numerical_issue],
),
)
finally
unlock(options.lock)
end
has_converged, status =
convergence_test(model, options.log, options.stopping_rules)
return IterationResult(
Expand Down Expand Up @@ -1130,6 +1148,11 @@ function train(
finally
# And close the dashboard callback if necessary.
dashboard_callback(nothing, true)
for node in values(model.nodes)
if islocked(node.lock)
unlock(node.lock)
end
end
end
training_results = TrainingResults(status, log)
model.most_recent_training_results = training_results
Expand Down Expand Up @@ -1177,6 +1200,7 @@ function _simulate(
objective_states = NTuple{N,Float64}[]
for (depth, (node_index, noise)) in enumerate(scenario_path)
node = model[node_index]
lock(node.lock) # LOCK-ID-002
# Objective state interpolation.
objective_state_vector = update_objective_state(
node.objective_state,
Expand Down Expand Up @@ -1253,6 +1277,7 @@ function _simulate(
push!(simulation, store)
# Set outgoing state as the incoming state for the next node.
incoming_state = copy(subproblem_results.state)
unlock(node.lock) # LOCK-ID-002
end
return simulation
end
Expand Down
6 changes: 4 additions & 2 deletions src/plugins/forward_passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function forward_pass(
) where {T}
# First up, sample a scenario. Note that if a cycle is detected, this will
# return the cycle node as well.
TimerOutputs.@timeit model.timer_output "sample_scenario" begin
@_timeit_threadsafe model.timer_output "sample_scenario" begin
scenario_path, terminated_due_to_cycle =
sample_scenario(model, options.sampling_scheme)
end
Expand All @@ -51,6 +51,7 @@ function forward_pass(
# Iterate down the scenario.
for (depth, (node_index, noise)) in enumerate(scenario_path)
node = model[node_index]
lock(node.lock) # LOCK-ID-001
# Objective state interpolation.
objective_state_vector = update_objective_state(
node.objective_state,
Expand Down Expand Up @@ -94,7 +95,7 @@ function forward_pass(
end
# ===== End: starting state for infinite horizon =====
# Solve the subproblem, note that `duality_handler = nothing`.
TimerOutputs.@timeit model.timer_output "solve_subproblem" begin
@_timeit_threadsafe model.timer_output "solve_subproblem" begin
subproblem_results = solve_subproblem(
model,
node,
Expand All @@ -112,6 +113,7 @@ function forward_pass(
# Add the outgoing state variable to the list of states we have sampled
# on this forward pass.
push!(sampled_states, incoming_state_value)
unlock(node.lock) # LOCK-ID-001
end
if terminated_due_to_cycle
# We terminated due to a cycle. Here is the list of possible starting
Expand Down
49 changes: 49 additions & 0 deletions src/plugins/parallel_schemes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,52 @@ function _simulate(
end
return
end

"""
    Threaded()

Run SDDP in threaded mode.

`Threaded` is a field-less subtype of `AbstractParallelScheme`; it is used
purely as a dispatch tag for the parallel-scheme methods.
"""
struct Threaded <: AbstractParallelScheme
end

# Print the scheme in the same form as its constructor call.
function Base.show(io::IO, ::Threaded)
    return print(io, "Threaded()")
end

Check warning on line 338 in src/plugins/parallel_schemes.jl

View check run for this annotation

Codecov / codecov/patch

src/plugins/parallel_schemes.jl#L338

Added line #L338 was not covered by tests

# Interrupting the `Threaded` scheme does nothing and returns `nothing`.
function interrupt(::Threaded)
    return nothing
end

Check warning on line 340 in src/plugins/parallel_schemes.jl

View check run for this annotation

Codecov / codecov/patch

src/plugins/parallel_schemes.jl#L340

Added line #L340 was not covered by tests

# Training loop for the `Threaded` parallel scheme.
#
# The solver is initialized once, then `iteration(model, options)` is called
# repeatedly. After each iteration the post-iteration callback and the
# iteration logging are run while holding `options.lock`. As soon as an
# iteration reports convergence, its termination status is returned.
#
# This method does not spawn any tasks itself: the iterations are executed one
# after another by the caller.
#
# Coverage note: Codecov reported every line of this method as not covered by
# tests.
function master_loop(::Threaded, model::PolicyGraph{T}, options::Options) where {T}
    _initialize_solver(model; throw_error = false)
    while true
        iteration_result = iteration(model, options)
        # Same lock/try/finally pattern that `iteration` uses around its log
        # update; the lock is released even if the callback or logging throws.
        lock(options.lock)
        try
            options.post_iteration_callback(iteration_result)
            log_iteration(options)
        finally
            unlock(options.lock)
        end
        iteration_result.has_converged && return iteration_result.status
    end
    # Not reachable: the loop above has no `break` and only exits by returning.
    return
end

# Run `number_replications` simulations of `model` with the `Threaded` scheme.
#
# Arguments:
#  * `model`: the policy graph to simulate.
#  * `::Threaded`: dispatch tag selecting this method.
#  * `number_replications`: how many independent simulations to run.
#  * `variables`: forwarded to the per-replication `_simulate` method.
#  * `kwargs...`: forwarded unchanged to the per-replication `_simulate` method.
#
# Returns a `Vector{Vector{Dict{Symbol,Any}}}` with one entry per replication.
#
# The output is allocated up front and each loop iteration writes only to its
# own index. Compared with pushing onto a shared vector under a lock, this
# needs no synchronization between replications, and the position of a
# replication in the result no longer depends on which thread finishes first.
#
# Coverage note: Codecov reported this method as not covered by tests.
function _simulate(
    model::PolicyGraph,
    ::Threaded,
    number_replications::Int,
    variables::Vector{Symbol};
    kwargs...,
)
    _initialize_solver(model; throw_error = false)
    # `max(..., 0)` keeps the "empty result" behavior for a non-positive
    # replication count, for which `1:number_replications` is an empty range.
    simulations =
        Vector{Vector{Dict{Symbol,Any}}}(undef, max(number_replications, 0))
    Threads.@threads for replication in 1:number_replications
        simulations[replication] = _simulate(model, variables; kwargs...)
    end
    return simulations
end
3 changes: 3 additions & 0 deletions src/user_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ mutable struct Node{T}
# An extension dictionary. This is a useful place for packages that extend
# SDDP.jl to stash things.
ext::Dict{Symbol,Any}
# Lock for threading
lock::ReentrantLock
end

function Base.show(io::IO, node::Node)
Expand Down Expand Up @@ -990,6 +992,7 @@ function PolicyGraph(
direct_mode ? nothing : optimizer,
# The extension dictionary.
Dict{Symbol,Any}(),
ReentrantLock(),
)
subproblem.ext[:sddp_policy_graph] = policy_graph
policy_graph.nodes[node_index] = subproblem.ext[:sddp_node] = node
Expand Down

0 comments on commit f466f91

Please sign in to comment.