From c656da1ae12704379bd74fafc0b1f42498143015 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Thu, 10 Mar 2022 00:54:28 +0000 Subject: [PATCH] add an example --- environments/environment.jl | 5 +-- environments/rlenv.jl | 4 +-- examples/deeprl/Project.toml | 5 +++ examples/deeprl/ant_ppo.jl | 67 ++++++++++++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 examples/deeprl/Project.toml create mode 100644 examples/deeprl/ant_ppo.jl diff --git a/environments/environment.jl b/environments/environment.jl index 3d2f5b727..ae813ea79 100644 --- a/environments/environment.jl +++ b/environments/environment.jl @@ -179,8 +179,8 @@ function MeshCat.render(env::Environment, return nothing end -function seed(env::Environment; s=0) - env.rng[1] = MersenneTwister(seed) +function seed(env::Environment, s=0) + env.rng[1] = MersenneTwister(s) return nothing end @@ -227,6 +227,7 @@ function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N} end # For compat with RLBase +Base.length(s::BoxSpace) = s.n Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high) Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* (s.high .- s.low) .+ s.low diff --git a/environments/rlenv.jl b/environments/rlenv.jl index ac7887674..2a7882613 100644 --- a/environments/rlenv.jl +++ b/environments/rlenv.jl @@ -23,7 +23,7 @@ RLBase.is_terminated(env::DojoRLEnv) = env.done RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv) -RLBase.reward(env::DojoRLEnv) = error() +RLBase.reward(env::DojoRLEnv) = env.reward RLBase.state(env::DojoRLEnv) = env.state Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed) @@ -33,7 +33,7 @@ Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed) function (env::DojoRLEnv)(a) s, r, d, i = step(env.dojoenv, a) - env.state = s + env.state .= s env.reward = r env.done = d env.info = i diff --git a/examples/deeprl/Project.toml b/examples/deeprl/Project.toml new file mode 100644 index 000000000..2679eef35 --- /dev/null +++ b/examples/deeprl/Project.toml @@ -0,0 +1,5 @@ +[deps] +Dojo = "ac60b53e-8d92-4c83-b960-e78698fa1916" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReinforcementLearning = "158674fc-8238-5cab-b5ba-03dfc80d1318" diff --git a/examples/deeprl/ant_ppo.jl b/examples/deeprl/ant_ppo.jl new file mode 100644 index 000000000..bb8da3de3 --- /dev/null +++ b/examples/deeprl/ant_ppo.jl @@ -0,0 +1,67 @@ +using ReinforcementLearning +using Flux +using Flux.Losses + +using Random +using Dojo + +function RL.Experiment( + ::Val{:JuliaRL}, + ::Val{:PPO}, + ::Val{:DojoAnt}, + ::Nothing, + save_dir = nothing, + seed = 42 +) + rng = MersenneTwister(seed) + N_ENV = 6 + UPDATE_FREQ = 32 + env_vec = [Dojo.DojoRLEnv("ant") for i in 1:N_ENV] + for i in 1:N_ENV + Random.seed!(env_vec[i], hash(seed+i)) + end + env = MultiThreadEnv(env_vec; is_force=true) + + ns, na = length(state(env[1])), length(action_space(env[1])) + RLBase.reset!(env) + + agent = Agent( + policy = PPOPolicy( + approximator = ActorCritic( + actor = Chain( + Dense(ns, 256, relu; init = glorot_uniform(rng)), + Dense(256, na; init = glorot_uniform(rng)), + ), + critic = Chain( + Dense(ns, 256, relu; init = glorot_uniform(rng)), + Dense(256, 1; init = glorot_uniform(rng)), + ), + optimizer = ADAM(1e-3), + ), + γ = 0.99f0, + λ = 0.95f0, + clip_range = 0.1f0, + max_grad_norm = 0.5f0, + n_epochs = 4, + n_microbatches = 4, + actor_loss_weight = 1.0f0, + critic_loss_weight = 0.5f0, + entropy_loss_weight = 0.001f0, + update_freq = UPDATE_FREQ, + ), + trajectory = PPOTrajectory(; + capacity = UPDATE_FREQ, + state = Matrix{Float32} => (ns, N_ENV), + action = Vector{Int} => (N_ENV,), + action_log_prob = Vector{Float32} => (N_ENV,), + reward = Vector{Float32} => (N_ENV,), + terminal = Vector{Bool} => (N_ENV,), + ), + ) + stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + hook = TotalBatchRewardPerEpisode(N_ENV) + Experiment(agent, env, stop_condition, hook, "# PPO with Dojo Ant") +end + +ex = E`JuliaRL_PPO_DojoAnt` +run(ex) \ No newline at end of file