updates to fix issues associated with Flux 0.14
dylan-asmar committed Jan 13, 2025
1 parent bdaa6cb commit 4e37bc9
Showing 6 changed files with 72 additions and 23 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -6,4 +6,5 @@ log*
*.bson
events.out.tfevents*
.vscode
Manifest.toml
Manifest.toml
.DS_Store
24 changes: 12 additions & 12 deletions README.md
@@ -27,7 +27,6 @@ using DeepQLearning
using POMDPs
using Flux
using POMDPModels
using POMDPSimulators
using POMDPTools

# load MDP model from POMDPModels or define your own!
@@ -37,7 +36,7 @@ mdp = SimpleGridWorld();
# the gridworld state is represented by a 2 dimensional vector.
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2))
exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
exploration_policy = exploration,
@@ -99,39 +98,40 @@ mdp = SimpleGridWorld();
# the model weights will be sent to the gpu in the call to solve
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
learning_rate=0.005,log_freq=500,
recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
exploration_policy=exploration,
learning_rate=0.005,log_freq=500,
recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
policy = solve(solver, mdp)
```
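
If a quick sanity check of the returned policy is useful, the rollout from the new `test/README_examples.jl` further down in this commit works here as well; the 30-step cap is just an illustrative choice.

```julia
# assumes mdp and policy from the example above, plus `using POMDPs, POMDPTools`
sim = RolloutSimulator(max_steps=30)
r_tot = simulate(sim, mdp, policy)
println("Total discounted reward for 1 simulation: $r_tot")
```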

## Solver Options

**Fields of the Q Learning solver** (a usage sketch follows this list):
- `qnetwork::Any = nothing` Specify the architecture of the Q network
- `exploration_policy::ExplorationPolicy` exploration strategy (e.g. `EpsGreedyPolicy`)
- `learning_rate::Float64 = 1e-4` learning rate
- `max_steps::Int64` total number of training steps default = 1000
- `target_update_freq::Int64` frequency at which the target network is updated default = 500
- `batch_size::Int64` batch size sampled from the replay buffer default = 32
- `train_freq::Int64` frequency at which the active network is updated default = 4
- `log_freq::Int64` frequency at which to log info default = 100
- `eval_freq::Int64` frequency at which to eval the network default = 100
- `target_update_freq::Int64` frequency at which the target network is updated default = 500
- `num_ep_eval::Int64` number of episodes to evaluate the policy default = 100
- `eps_fraction::Float64` fraction of the training set used to explore default = 0.5
- `eps_end::Float64` value of epsilon at the end of the exploration phase default = 0.01
- `double_q::Bool` double q learning update default = true
- `dueling::Bool` dueling structure for the q network default = true
- `recurrence::Bool = false` set to true to use DRQN; an error is thrown if this is false and a recurrent model is passed.
- `evaluation_policy::Function = basic_evaluation` function used to evaluate the policy every `eval_freq` steps; the default is a rollout that returns the undiscounted average reward
- `prioritized_replay::Bool` enable prioritized experience replay default = true
- `prioritized_replay_alpha::Float64` default = 0.6
- `prioritized_replay_epsilon::Float64` default = 1e-6
- `prioritized_replay_beta::Float64` default = 0.4
- `buffer_size::Int64` size of the experience replay buffer default = 1000
- `max_episode_length::Int64` maximum length of a training episode default = 100
- `train_start::Int64` number of steps used to fill in the replay buffer initially default = 200
- `save_freq::Int64` save the model every `save_freq` steps, default = 1000
- `evaluation_policy::Function = basic_evaluation` function used to evaluate the policy every `eval_freq` steps; the default is a rollout that returns the undiscounted average reward
- `exploration_policy::Any = linear_epsilon_greedy(max_steps, eps_fraction, eps_end)` exploration strategy (default is epsilon greedy with linear decay)
- `rng::AbstractRNG` random number generator default = MersenneTwister(0)
- `logdir::String = ""` folder in which to save the model
- `save_freq::Int64` save the model every `save_freq` steps, default = 1000
- `log_freq::Int64` frequency at which to log info default = 100
- `verbose::Bool` default = true
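
As a rough illustration of how several of these fields fit together, a configuration might look like the sketch below; the numeric values and `logdir` are placeholders rather than recommended settings.

```julia
using DeepQLearning, Flux, POMDPModels, POMDPTools, Random

mdp = SimpleGridWorld()
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))
exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2))

solver = DeepQLearningSolver(qnetwork = model,
                             exploration_policy = exploration,
                             max_steps = 10000,
                             learning_rate = 1e-3,
                             batch_size = 32,
                             train_freq = 4,
                             target_update_freq = 500,
                             buffer_size = 10000,
                             train_start = 200,
                             max_episode_length = 100,
                             eval_freq = 1000,
                             num_ep_eval = 20,
                             double_q = true,
                             dueling = true,
                             prioritized_replay = true,
                             log_freq = 500,
                             save_freq = 2000,
                             logdir = "log/gridworld/",   # placeholder output folder
                             rng = MersenneTwister(1),
                             verbose = true)
policy = solve(solver, mdp)
```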
2 changes: 1 addition & 1 deletion src/dueling.jl
@@ -10,7 +10,7 @@ function (m::DuelingNetwork)(inpt)
return m.val(x) .+ m.adv(x) .- mean(m.adv(x), dims=1)
end

Flux.@functor DuelingNetwork
Flux.@layer DuelingNetwork

function Flux.reset!(m::DuelingNetwork)
Flux.reset!(m.base)
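For context on the one-line change above: Flux 0.14 introduces `Flux.@layer` as the preferred replacement for `Flux.@functor` when registering a custom struct, so that Flux can see its trainable fields and utilities like `Flux.state` and `gpu`/`cpu` recurse into it. A minimal sketch with a hypothetical two-field layer, not taken from this package:

```julia
using Flux

# hypothetical wrapper holding two sub-networks
struct TwoHead{B,H}
    base::B
    head::H
end

# Flux < 0.14:  Flux.@functor TwoHead
# Flux >= 0.14 (as in this commit for DuelingNetwork):
Flux.@layer TwoHead

(m::TwoHead)(x) = m.head(m.base(x))

m = TwoHead(Chain(Dense(4, 8, relu)), Dense(8, 2))
y = m(rand(Float32, 4))   # forward pass works as before
st = Flux.state(m)        # nested state, the format saved by the updated solver.jl
```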
19 changes: 10 additions & 9 deletions src/solver.jl
@@ -1,5 +1,6 @@
@with_kw mutable struct DeepQLearningSolver{E<:ExplorationPolicy} <: Solver
qnetwork::Any = nothing # intended to be a flux model
exploration_policy::E # No default since 9ac3ab
learning_rate::Float32 = 1f-4
max_steps::Int64 = 1000
batch_size::Int64 = 32
@@ -11,7 +12,6 @@
dueling::Bool = true
recurrence::Bool = false
evaluation_policy::Any = basic_evaluation
exploration_policy::E
trace_length::Int64 = 40
prioritized_replay::Bool = true
prioritized_replay_alpha::Float32 = 0.6f0
@@ -139,9 +139,8 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnv, policy::Abstr
sethiddenstates!(active_q, hs)
end

if t%solver.target_update_freq == 0
weights = Flux.params(active_q)
Flux.loadparams!(target_q, weights)
if t % solver.target_update_freq == 0
target_q = deepcopy(active_q)
end

if t % solver.eval_freq == 0
@@ -170,9 +169,9 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnv, policy::Abstr
if model_saved
if solver.verbose
@printf("Restore model with eval reward %1.3f \n", saved_mean_reward)
saved_model = BSON.load(joinpath(solver.logdir, "qnetwork.bson"))[:qnetwork]
Flux.loadparams!(getnetwork(policy), saved_model)
end
saved_model_state = BSON.load(joinpath(solver.logdir, "qnetwork_state.bson"))[:qnetwork_state]
Flux.loadmodel!(policy.qnetwork, saved_model_state)
end
return policy
end
@@ -289,7 +288,9 @@ end

function save_model(solver::DeepQLearningSolver, active_q, scores_eval::Float64, saved_mean_reward::Float64, model_saved::Bool)
if scores_eval >= saved_mean_reward
bson(joinpath(solver.logdir, "qnetwork.bson"), qnetwork=[w for w in Flux.params(active_q)])
copied_model = deepcopy(active_q)
Flux.reset!(copied_model)
bson(joinpath(solver.logdir, "qnetwork_state.bson"), qnetwork_state=Flux.state(copied_model))
if solver.verbose
@printf("Saving new model with eval reward %1.3f \n", scores_eval)
end
@@ -311,8 +312,8 @@ function restore_best_model(solver::DeepQLearningSolver, env::AbstractEnv)
active_q = solver.qnetwork
end
policy = NNPolicy(env, active_q, collect(actions(env)), length(obs_dimensions(env)))
weights = BSON.load(solver.logdir*"qnetwork.bson")[:qnetwork]
Flux.loadparams!(getnetwork(policy), weights)
saved_network_state = BSON.load(solver.logdir*"qnetwork_state.bson")[:qnetwork_state]
Flux.loadmodel!(getnetwork(policy), saved_network_state)
Flux.testmode!(getnetwork(policy))
return policy
end
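For reference, the checkpointing changes above follow the explicit-state style that Flux 0.14 encourages: `Flux.state(model)` extracts a nested structure of plain arrays that BSON can serialize, `Flux.loadmodel!(model, state)` copies a saved state back into a model of the same architecture, and the target network is now refreshed with `deepcopy` instead of `Flux.params`/`Flux.loadparams!`. A standalone sketch of the same pattern; the file name and network are illustrative only:

```julia
using Flux, BSON

qnet = Chain(Dense(2, 32, relu), Dense(32, 4))

# save: serialize the model's state rather than its implicit parameters
BSON.bson("qnetwork_state.bson", qnetwork_state = Flux.state(qnet))

# restore: rebuild the same architecture, then load the saved state into it
restored = Chain(Dense(2, 32, relu), Dense(32, 4))
saved_state = BSON.load("qnetwork_state.bson")[:qnetwork_state]
Flux.loadmodel!(restored, saved_state)

# target-network sync without implicit parameters
target_q = deepcopy(qnet)
```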
42 changes: 42 additions & 0 deletions test/README_examples.jl
@@ -0,0 +1,42 @@
using DeepQLearning
using POMDPs
using Flux
using POMDPModels
using POMDPTools

@testset "README Example 1" begin
# load MDP model from POMDPModels or define your own!
mdp = SimpleGridWorld();

# Define the Q network (see Flux.jl documentation)
# the gridworld state is represented by a 2 dimensional vector.
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
exploration_policy = exploration,
learning_rate=0.005,log_freq=500,
recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
policy = solve(solver, mdp)

sim = RolloutSimulator(max_steps=30)
r_tot = simulate(sim, mdp, policy)
println("Total discounted reward for 1 simulation: $r_tot")
end

@testset "README Example 2" begin
# Without using CuArrays
mdp = SimpleGridWorld();

# the model weights will be sent to the gpu in the call to solve
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
exploration_policy=exploration,
learning_rate=0.005,log_freq=500,
recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
policy = solve(solver, mdp)
end
5 changes: 5 additions & 0 deletions test/runtests.jl
@@ -232,3 +232,8 @@ end

@test evaluate(env, policy, GLOBAL_RNG) > 1.0
end


@testset "README Examples" begin
include("README_examples.jl")
end
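
To exercise the new test set locally, a sketch using standard Julia tooling (not something this diff adds), assuming a checkout of the package is the active project:

```julia
using Pkg
Pkg.activate(".")   # from the root of the DeepQLearning.jl checkout
Pkg.test()          # runs test/runtests.jl, which now includes README_examples.jl
```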
