Skip to content

Commit

Permalink
Use of ProgressMeter.jl and added a reset_callback to DPWSolver (#69)
Browse files Browse the repository at this point in the history
* Added DPWSolver parameter 'enable_state_pw' for explicit enabling of state space progressive widening (defaults to true, i.e. previous behavior)

* Added 'reset_callback' option to DPWSolver

* ProgressMeter: added @showprogress around DPW simulate

* Added `show_progress=false` kwarg to DWP action_info

* Removed use of `isnothing` for Julia v1.0 compatibility

* Ensure no performance degradation when not using ProgressMeter

* Clarify use of `reset_callback`

* Added DWP test for ProgressMeter and `reset_callback`

* Removed left-over progress `dt`

* Optimized out reset_callback, moved show_progress to solver, and called finish! on progress meter after timeout

* Fixed `show_progress` test

* Relax Colors lower bound version requirement

* put reset_callback into planner

Co-authored-by: Zachary Sunberg <[email protected]>
  • Loading branch information
mossr and zsunberg authored Sep 7, 2020
1 parent 1434f7e commit 6782345
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 8 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
CPUTime = "1"
Colors = "0.12"
Colors = "0.11, 0.12"
D3Trees = "0.3"
POMDPLinter = "0.1"
POMDPModelTools = "0.3"
Expand Down
1 change: 1 addition & 0 deletions src/MCTS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using POMDPSimulators
using CPUTime
using Random
using Printf
using ProgressMeter
using POMDPLinter: @show_requirements, requirements_info, @POMDP_require, @req, @subreq
import POMDPLinter

Expand Down
5 changes: 5 additions & 0 deletions src/dpw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,19 @@ function POMDPModelTools.action_info(p::DPWPlanner, s; tree_in_info=false)
snode = insert_state_node!(tree, s, p.solver.check_repeat_state)
end

p.solver.show_progress ? progress = Progress(p.solver.n_iterations) : nothing
nquery = 0
start_us = CPUtime_us()
for i = 1:p.solver.n_iterations
nquery += 1
simulate(p, snode, p.solver.depth) # (not 100% sure we need to make a copy of the state here)
p.solver.show_progress ? next!(progress) : nothing
if CPUtime_us() - start_us >= p.solver.max_time * 1e6
p.solver.show_progress ? finish!(progress) : nothing
break
end
end
p.reset_callback(p.mdp, s) # Optional: leave the MDP in the current state.
info[:search_time_us] = CPUtime_us() - start_us
info[:tree_queries] = nquery
if p.solver.tree_in_info || tree_in_info
Expand Down Expand Up @@ -89,6 +93,7 @@ function simulate(dpw::DPWPlanner, snode::Int, d::Int)
sol = dpw.solver
tree = dpw.tree
s = tree.s_labels[snode]
dpw.reset_callback(dpw.mdp, s) # Optional: used to reset/reinitialize MDP to a given state.
if isterminal(dpw.mdp, s)
return 0.0
elseif d == 0
Expand Down
31 changes: 24 additions & 7 deletions src/dpw_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ MCTS solver with DPW
Fields:
depth::Int64:
depth::Int64
Maximum rollout horizon and tree depth.
default: 10
exploration_constant::Float64:
exploration_constant::Float64
Specified how much the solver should explore.
In the UCB equation, Q + c*sqrt(log(t/N)), c is the exploration constant.
default: 1.0
Expand Down Expand Up @@ -45,11 +45,11 @@ Fields:
When constructing the tree, check whether a state or action has been seen before (there is a computational cost to maintaining the dictionaries necessary for this)
default: true
tree_in_info::Bool:
tree_in_info::Bool
If true, return the tree in the info dict when action_info is called. False by default because it can use a lot of memory if histories are being saved.
default: false
rng::AbstractRNG:
rng::AbstractRNG
Random number generator
estimate_value::Any (rollout policy)
Expand Down Expand Up @@ -86,6 +86,16 @@ Fields:
If this is a Policy `p`, `action(p, belief)` will be called.
If it is an object `a`, `default_action(a, pomdp, belief, ex)` will be called, and if this method is not implemented, `a` will be returned directly.
default: `ExceptionRethrow()`
reset_callback::Function
Function used to reset/reinitialize the MDP to a given state `s`.
Useful when the simulator state is not truly separate from the MDP state.
`f(mdp, s)` will be called.
default: `(mdp, s)->false` (optimized out)
show_progress::Bool
Show progress bar during simulation.
default: false
"""
mutable struct DPWSolver <: AbstractMCTSSolver
depth::Int
Expand All @@ -108,6 +118,8 @@ mutable struct DPWSolver <: AbstractMCTSSolver
init_N::Any
next_action::Any
default_action::Any
reset_callback::Function
show_progress::Bool
end

"""
Expand All @@ -134,9 +146,11 @@ function DPWSolver(;depth::Int=10,
init_Q::Any = 0.0,
init_N::Any = 0,
next_action::Any = RandomActionGenerator(rng),
default_action::Any = ExceptionRethrow()
default_action::Any = ExceptionRethrow(),
reset_callback::Function = (mdp, s)->false,
show_progress::Bool = false,
)
DPWSolver(depth, exploration_constant, n_iterations, max_time, k_action, alpha_action, k_state, alpha_state, keep_tree, enable_action_pw, enable_state_pw, check_repeat_state, check_repeat_action, tree_in_info, rng, estimate_value, init_Q, init_N, next_action, default_action)
DPWSolver(depth, exploration_constant, n_iterations, max_time, k_action, alpha_action, k_state, alpha_state, keep_tree, enable_action_pw, enable_state_pw, check_repeat_state, check_repeat_action, tree_in_info, rng, estimate_value, init_Q, init_N, next_action, default_action, reset_callback, show_progress)
end

#=
Expand Down Expand Up @@ -234,12 +248,13 @@ children(n::DPWStateNode) = n.tree.children[n.index]
n_children(n::DPWStateNode) = length(children(n))
isroot(n::DPWStateNode) = n.index == 1

mutable struct DPWPlanner{P<:Union{MDP,POMDP}, S, A, SE, NA, RNG} <: AbstractMCTSPlanner{P}
mutable struct DPWPlanner{P<:Union{MDP,POMDP}, S, A, SE, NA, RCB, RNG} <: AbstractMCTSPlanner{P}
solver::DPWSolver
mdp::P
tree::Union{Nothing, DPWTree{S,A}}
solved_estimate::SE
next_action::NA
reset_callback::RCB
rng::RNG
end

Expand All @@ -250,11 +265,13 @@ function DPWPlanner(solver::DPWSolver, mdp::P) where P<:Union{POMDP,MDP}
actiontype(P),
typeof(se),
typeof(solver.next_action),
typeof(solver.reset_callback),
typeof(solver.rng)}(solver,
mdp,
nothing,
se,
solver.next_action,
solver.reset_callback,
solver.rng
)
end
Expand Down
11 changes: 11 additions & 0 deletions test/dpw_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,15 @@ let
state = GridWorldState(1,1)

a = @inferred action(policy, state)


# ProgressMeter and reset_callback test
solver = DPWSolver(n_iterations=n_iter, depth=depth, exploration_constant=ec, reset_callback=(mdp,s)->nothing)
mdp = LegacyGridWorld()

policy = solve(solver, mdp)

state = GridWorldState(1,1)

@inferred action_info(policy, state)
end

0 comments on commit 6782345

Please sign in to comment.