From 2c277fbd6bdd961f815dcde57086cb67f3e17da9 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Wed, 9 Mar 2022 19:46:42 +0000 Subject: [PATCH 1/9] Towards ReinforcementLearning.jl integration --- Project.toml | 1 + environments/rlenv.jl | 38 ++++++++++++++++++++++++++++++++++++++ src/Dojo.jl | 1 + 3 files changed, 40 insertions(+) create mode 100644 environments/rlenv.jl diff --git a/Project.toml b/Project.toml index 3df4743cb..a2dd593ea 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Polyhedra = "67491407-f73d-577b-9b50-8179a7c68029" Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44" Scratch = "6c6a2e73-6563-6170-7368-637461726353" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/environments/rlenv.jl b/environments/rlenv.jl new file mode 100644 index 000000000..a1927e883 --- /dev/null +++ b/environments/rlenv.jl @@ -0,0 +1,38 @@ +using ReinforcementLearningBase: RLBase + +mutable struct DojoRLEnv <: RLBase.AbstractEnv + dojoenv + action_space + observation_space + state + reward + done::Bool + info::Dict +end + +function DojoRLEnv(dojoenv::Environment) + action_space = convert(RLBase.Space, dojoenv.input_space) + observation_space = convert(RLBase.Space, dojoenv.observation_space) + state = reset(dojoenv) + return DojoRLEnv(dojoenv, action_space, observation_space, state, 0.0, false, Dict()) +end + +RLBase.action_space(env::DojoRLEnv) = env.action_space +RLBase.state_space(env::DojoRLEnv) = env.observation_space +RLBase.is_terminated(env::DojoRLEnv) = env.done + +RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv) + +RLBase.reward(env::DojoRLEnv) = error() +RLBase.state(env::DojoRLEnv) = env.state + +Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed) + +function (env::DojoRLEnv)(a) + s, r, d, i = step(env.dojoenv, a) + env.state = s + env.reward = r + env.done = d + env.info = i + return nothing +end \ No newline at end of file diff --git a/src/Dojo.jl b/src/Dojo.jl index 4da0c62c0..778c15452 100644 --- a/src/Dojo.jl +++ b/src/Dojo.jl @@ -145,6 +145,7 @@ include(joinpath("..", "environments", "environment.jl")) include(joinpath("..", "environments", "dynamics.jl")) include(joinpath("..", "environments", "utilities.jl")) include(joinpath("..", "environments", "include.jl")) +include(joinpath("..", "environments", "rlenv.jl")) # Bodies export From 2dad1d9e1a5ae9ac1b061518a2ad10a9c4c9d671 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Thu, 10 Mar 2022 00:19:03 +0000 Subject: [PATCH 2/9] get basic interface working --- environments/environment.jl | 6 +++++- environments/rlenv.jl | 27 +++++++++++++++------------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/environments/environment.jl b/environments/environment.jl index 46500d27f..3d2f5b727 100644 --- a/environments/environment.jl +++ b/environments/environment.jl @@ -196,7 +196,7 @@ end abstract type Space{T,N} end """ - BoxSpace{T,N} <: Environment{T,N} + BoxSpace{T,N} <: Space{T,N} domain with lower and upper limits @@ -226,6 +226,10 @@ function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N} all(v .>= s.low) && all(v .<= s.high) end +# For compat with RLBase +Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high) +Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* 
(s.high .- s.low) .+ s.low + function clip(s::BoxSpace, u) clamp.(u, s.low, s.high) end diff --git a/environments/rlenv.jl b/environments/rlenv.jl index a1927e883..ac7887674 100644 --- a/environments/rlenv.jl +++ b/environments/rlenv.jl @@ -1,24 +1,24 @@ using ReinforcementLearningBase: RLBase -mutable struct DojoRLEnv <: RLBase.AbstractEnv - dojoenv - action_space - observation_space - state - reward +mutable struct DojoRLEnv{T} <: RLBase.AbstractEnv + dojoenv::Environment + state::Vector{T} + reward::T done::Bool info::Dict end -function DojoRLEnv(dojoenv::Environment) - action_space = convert(RLBase.Space, dojoenv.input_space) - observation_space = convert(RLBase.Space, dojoenv.observation_space) +function DojoRLEnv(dojoenv::Environment{X,T}) where {X,T} state = reset(dojoenv) - return DojoRLEnv(dojoenv, action_space, observation_space, state, 0.0, false, Dict()) + return DojoRLEnv{T}(dojoenv, state, convert(T, 0.0), false, Dict()) end -RLBase.action_space(env::DojoRLEnv) = env.action_space -RLBase.state_space(env::DojoRLEnv) = env.observation_space +function DojoRLEnv(name::String; kwargs...) + DojoRLEnv(Dojo.get_environment(name; kwargs...)) +end + +RLBase.action_space(env::DojoRLEnv) = env.dojoenv.input_space +RLBase.state_space(env::DojoRLEnv) = env.dojoenv.observation_space RLBase.is_terminated(env::DojoRLEnv) = env.done RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv) @@ -28,6 +28,9 @@ RLBase.state(env::DojoRLEnv) = env.state Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed) +# TODO: +# RLBase.ChanceStyle(env::DojoRLEnv) = RLBase.DETERMINISTIC + function (env::DojoRLEnv)(a) s, r, d, i = step(env.dojoenv, a) env.state = s From b12adf4c63c93151a00c1421c271039abd46f44d Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Thu, 10 Mar 2022 00:54:28 +0000 Subject: [PATCH 3/9] add an example --- environments/environment.jl | 5 +-- environments/rlenv.jl | 4 +-- examples/deeprl/Project.toml | 5 +++ examples/deeprl/ant_ppo.jl | 67 ++++++++++++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 examples/deeprl/Project.toml create mode 100644 examples/deeprl/ant_ppo.jl diff --git a/environments/environment.jl b/environments/environment.jl index 3d2f5b727..ae813ea79 100644 --- a/environments/environment.jl +++ b/environments/environment.jl @@ -179,8 +179,8 @@ function MeshCat.render(env::Environment, return nothing end -function seed(env::Environment; s=0) - env.rng[1] = MersenneTwister(seed) +function seed(env::Environment, s=0) + env.rng[1] = MersenneTwister(s) return nothing end @@ -227,6 +227,7 @@ function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N} end # For compat with RLBase +Base.length(s::BoxSpace) = s.n Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high) Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* (s.high .- s.low) .+ s.low diff --git a/environments/rlenv.jl b/environments/rlenv.jl index ac7887674..2a7882613 100644 --- a/environments/rlenv.jl +++ b/environments/rlenv.jl @@ -23,7 +23,7 @@ RLBase.is_terminated(env::DojoRLEnv) = env.done RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv) -RLBase.reward(env::DojoRLEnv) = error() +RLBase.reward(env::DojoRLEnv) = env.reward RLBase.state(env::DojoRLEnv) = env.state Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed) @@ -33,7 +33,7 @@ Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed) function (env::DojoRLEnv)(a) s, r, d, i = step(env.dojoenv, a) 
- env.state = s + env.state .= s env.reward = r env.done = d env.info = i diff --git a/examples/deeprl/Project.toml b/examples/deeprl/Project.toml new file mode 100644 index 000000000..2679eef35 --- /dev/null +++ b/examples/deeprl/Project.toml @@ -0,0 +1,5 @@ +[deps] +Dojo = "ac60b53e-8d92-4c83-b960-e78698fa1916" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReinforcementLearning = "158674fc-8238-5cab-b5ba-03dfc80d1318" diff --git a/examples/deeprl/ant_ppo.jl b/examples/deeprl/ant_ppo.jl new file mode 100644 index 000000000..5359ef5f4 --- /dev/null +++ b/examples/deeprl/ant_ppo.jl @@ -0,0 +1,67 @@ +using ReinforcementLearning +using Flux +using Flux.Losses + +using Random +using Dojo + +function RL.Experiment( + ::Val{:JuliaRL}, + ::Val{:PPO}, + ::Val{:DojoAnt}, + ::Nothing, + save_dir = nothing, + seed = 42 +) + rng = MersenneTwister(seed) + N_ENV = 6 + UPDATE_FREQ = 32 + env_vec = [Dojo.DojoRLEnv("ant") for i in 1:N_ENV] + for i in 1:N_ENV + Random.seed!(env_vec[i], hash(seed+i)) + end + env = MultiThreadEnv(env_vec) + + ns, na = length(state(env[1])), length(action_space(env[1])) + RLBase.reset!(env; is_force=true) + + agent = Agent( + policy = PPOPolicy( + approximator = ActorCritic( + actor = Chain( + Dense(ns, 256, relu; init = glorot_uniform(rng)), + Dense(256, na; init = glorot_uniform(rng)), + ), + critic = Chain( + Dense(ns, 256, relu; init = glorot_uniform(rng)), + Dense(256, 1; init = glorot_uniform(rng)), + ), + optimizer = ADAM(1e-3), + ), + γ = 0.99f0, + λ = 0.95f0, + clip_range = 0.1f0, + max_grad_norm = 0.5f0, + n_epochs = 4, + n_microbatches = 4, + actor_loss_weight = 1.0f0, + critic_loss_weight = 0.5f0, + entropy_loss_weight = 0.001f0, + update_freq = UPDATE_FREQ, + ), + trajectory = PPOTrajectory(; + capacity = UPDATE_FREQ, + state = Matrix{Float32} => (ns, N_ENV), + action = Vector{Int} => (N_ENV,), + action_log_prob = Vector{Float32} => (N_ENV,), + reward = Vector{Float32} => (N_ENV,), + terminal = Vector{Bool} => (N_ENV,), + ), + ) + stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + hook = TotalBatchRewardPerEpisode(N_ENV) + Experiment(agent, env, stop_condition, hook, "# PPO with Dojo Ant") +end + +ex = E`JuliaRL_PPO_DojoAnt` +run(ex) \ No newline at end of file From 6200f6c1e85af9714f5ccedae3ba8bf6241dee51 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Thu, 10 Mar 2022 01:48:05 +0000 Subject: [PATCH 4/9] add simple cartpole example; issue is with RL.jl --- environments/rlenv.jl | 3 +- examples/deeprl/cartpole_ddpg.jl | 88 ++++++++++++++++++++++++++++++++ examples/deeprl/cartpole_ppo.jl | 67 ++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 examples/deeprl/cartpole_ddpg.jl create mode 100644 examples/deeprl/cartpole_ppo.jl diff --git a/environments/rlenv.jl b/environments/rlenv.jl index 2a7882613..947026973 100644 --- a/environments/rlenv.jl +++ b/environments/rlenv.jl @@ -38,4 +38,5 @@ function (env::DojoRLEnv)(a) env.done = d env.info = i return nothing -end \ No newline at end of file +end +(env::DojoRLEnv)(a::AbstractFloat) = env([a]) diff --git a/examples/deeprl/cartpole_ddpg.jl b/examples/deeprl/cartpole_ddpg.jl new file mode 100644 index 000000000..4cf3f1a1a --- /dev/null +++ b/examples/deeprl/cartpole_ddpg.jl @@ -0,0 +1,88 @@ +using ReinforcementLearning +using Flux +using Flux.Losses + +using Random +using Dojo + +function RL.Experiment( + ::Val{:JuliaRL}, + ::Val{:DDPG}, + ::Val{:DojoCartpole}, + ::Nothing, + save_dir = nothing, + seed 
= 42 +) + + rng = MersenneTwister(seed) + inner_env = Dojo.DojoRLEnv("cartpole") + Random.seed!(inner_env, seed) + # TODO + low = -5.0 + high = 5.0 + ns, na = length(state(inner_env)), length(action_space(inner_env)) + @show na + A = Dojo.BoxSpace(na) + env = ActionTransformedEnv( + inner_env; + action_mapping = x -> low .+ (x .+ 1) .* 0.5 .* (high .- low), + action_space_mapping = _ -> A + ) + + init = glorot_uniform(rng) + + create_actor() = Chain( + Dense(ns, 30, relu; init = init), + Dense(30, 30, relu; init = init), + Dense(30, na, tanh; init = init), + ) + create_critic() = Chain( + Dense(ns + na, 30, relu; init = init), + Dense(30, 30, relu; init = init), + Dense(30, 1; init = init), + ) + + agent = Agent( + policy = DDPGPolicy( + behavior_actor = NeuralNetworkApproximator( + model = create_actor(), + optimizer = ADAM(), + ), + behavior_critic = NeuralNetworkApproximator( + model = create_critic(), + optimizer = ADAM(), + ), + target_actor = NeuralNetworkApproximator( + model = create_actor(), + optimizer = ADAM(), + ), + target_critic = NeuralNetworkApproximator( + model = create_critic(), + optimizer = ADAM(), + ), + γ = 0.99f0, + ρ = 0.995f0, + na = na, + batch_size = 64, + start_steps = 1000, + start_policy = RandomPolicy(A; rng = rng), + update_after = 1000, + update_freq = 1, + act_limit = 1.0, + act_noise = 0.1, + rng = rng, + ), + trajectory = CircularArraySARTTrajectory( + capacity = 10000, + state = Vector{Float32} => (ns,), + action = Float32 => (na, ), + ), + ) + + stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + hook = TotalRewardPerEpisode() + Experiment(agent, env, stop_condition, hook, "# Dojo Cartpole with DDPG") +end + +ex = E`JuliaRL_DDPG_DojoCartpole` +run(ex) \ No newline at end of file diff --git a/examples/deeprl/cartpole_ppo.jl b/examples/deeprl/cartpole_ppo.jl new file mode 100644 index 000000000..ebfd86516 --- /dev/null +++ b/examples/deeprl/cartpole_ppo.jl @@ -0,0 +1,67 @@ +using ReinforcementLearning +using Flux +using Flux.Losses + +using Random +using Dojo + +function RL.Experiment( + ::Val{:JuliaRL}, + ::Val{:PPO}, + ::Val{:DojoCartpole}, + ::Nothing, + save_dir = nothing, + seed = 42 +) + rng = MersenneTwister(seed) + N_ENV = 6 + UPDATE_FREQ = 32 + env_vec = [Dojo.DojoRLEnv("cartpole") for i in 1:N_ENV] + for i in 1:N_ENV + Random.seed!(env_vec[i], hash(seed+i)) + end + env = MultiThreadEnv(env_vec) + + ns, na = length(state(env[1])), length(action_space(env[1])) + RLBase.reset!(env; is_force=true) + + agent = Agent( + policy = PPOPolicy( + approximator = ActorCritic( + actor = Chain( + Dense(ns, 256, relu; init = glorot_uniform(rng)), + Dense(256, na; init = glorot_uniform(rng)), + ), + critic = Chain( + Dense(ns, 256, relu; init = glorot_uniform(rng)), + Dense(256, 1; init = glorot_uniform(rng)), + ), + optimizer = ADAM(1e-3), + ), + γ = 0.99f0, + λ = 0.95f0, + clip_range = 0.1f0, + max_grad_norm = 0.5f0, + n_epochs = 4, + n_microbatches = 4, + actor_loss_weight = 1.0f0, + critic_loss_weight = 0.5f0, + entropy_loss_weight = 0.001f0, + update_freq = UPDATE_FREQ, + ), + trajectory = PPOTrajectory(; + capacity = UPDATE_FREQ, + state = Matrix{Float32} => (ns, N_ENV), + action = Vector{Int} => (N_ENV,), + action_log_prob = Vector{Float32} => (N_ENV,), + reward = Vector{Float32} => (N_ENV,), + terminal = Vector{Bool} => (N_ENV,), + ), + ) + stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + hook = TotalBatchRewardPerEpisode(N_ENV) + Experiment(agent, env, stop_condition, hook, "# PPO with Dojo 
Cartpole") +end + +ex = E`JuliaRL_PPO_DojoCartpole` +run(ex) \ No newline at end of file From 4af3538ae459c238014173fee23b6abd1ac13f18 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Fri, 11 Mar 2022 01:13:22 +0000 Subject: [PATCH 5/9] cartpole is meaningless as no reward defined --- examples/deeprl/ant_ddpg.jl | 80 +++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 examples/deeprl/ant_ddpg.jl diff --git a/examples/deeprl/ant_ddpg.jl b/examples/deeprl/ant_ddpg.jl new file mode 100644 index 000000000..e71d2844c --- /dev/null +++ b/examples/deeprl/ant_ddpg.jl @@ -0,0 +1,80 @@ +using ReinforcementLearning +using Flux +using Flux.Losses + +using Random +using Dojo + +function RL.Experiment( + ::Val{:JuliaRL}, + ::Val{:DDPG}, + ::Val{:DojoAnt}, + ::Nothing, + save_dir = nothing, + seed = 42 +) + + rng = MersenneTwister(seed) + env = Dojo.DojoRLEnv("ant") + Random.seed!(env, seed) + A = action_space(env) + ns, na = length(state(env)), length(action_space(env)) + @show na + + init = glorot_uniform(rng) + + create_actor() = Chain( + Dense(ns, 30, relu; init = init), + Dense(30, 30, relu; init = init), + Dense(30, na, tanh; init = init), + ) + create_critic() = Chain( + Dense(ns + na, 30, relu; init = init), + Dense(30, 30, relu; init = init), + Dense(30, 1; init = init), + ) + + agent = Agent( + policy = DDPGPolicy( + behavior_actor = NeuralNetworkApproximator( + model = create_actor(), + optimizer = ADAM(), + ), + behavior_critic = NeuralNetworkApproximator( + model = create_critic(), + optimizer = ADAM(), + ), + target_actor = NeuralNetworkApproximator( + model = create_actor(), + optimizer = ADAM(), + ), + target_critic = NeuralNetworkApproximator( + model = create_critic(), + optimizer = ADAM(), + ), + γ = 0.99f0, + ρ = 0.995f0, + na = na, + batch_size = 64, + start_steps = 1000, + start_policy = RandomPolicy(A; rng = rng), + update_after = 1000, + update_freq = 1, + act_limit = 1.0, + act_noise = 0.1, + rng = rng, + ), + trajectory = CircularArraySARTTrajectory( + capacity = 10000, + state = Vector{Float32} => (ns,), + action = Float32 => (na, ), + ), + ) + + stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + hook = TotalRewardPerEpisode() + Experiment(agent, env, stop_condition, hook, "# Dojo Ant with DDPG") +end + +ex = E`JuliaRL_DDPG_DojoAnt` +run(ex) \ No newline at end of file From b7b11e06bdf3f9cfab8c6bd73f73f5ea7ae634e3 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sat, 12 Mar 2022 23:48:23 +0800 Subject: [PATCH 6/9] fix space related definitions --- environments/environment.jl | 52 ++++++++++++++++--------------------- environments/rlenv.jl | 10 ++++--- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/environments/environment.jl b/environments/environment.jl index ae813ea79..94d1ddb8e 100644 --- a/environments/environment.jl +++ b/environments/environment.jl @@ -31,7 +31,7 @@ mutable struct Environment{X,T,M,A,O,I} dynamics_jacobian_state::Matrix{T} dynamics_jacobian_input::Matrix{T} input_previous::Vector{T} - control_map::Matrix{T} + control_map::Matrix{T} num_states::Int num_inputs::Int num_observations::Int @@ -66,33 +66,33 @@ end attitude_decompress: flag for pre- and post-concatenating Jacobians with attitude Jacobians """ function Base.step(env::Environment, x, u; - gradients=false, - attitude_decompress=false) + gradients = false, + attitude_decompress = false) mechanism = env.mechanism - timestep= mechanism.timestep + timestep = mechanism.timestep x0 = x # u = clip(env.input_space, u) # control 
limits env.input_previous .= u # for rendering in Gym - u_scaled = env.control_map * u + u_scaled = env.control_map * u z0 = env.representation == :minimal ? minimal_to_maximal(mechanism, x0) : x0 - z1 = step!(mechanism, z0, u_scaled; opts=env.opts_step) + z1 = step!(mechanism, z0, u_scaled; opts = env.opts_step) env.state .= env.representation == :minimal ? maximal_to_minimal(mechanism, z1) : z1 # Compute cost costs = cost(env, x, u) - # Check termination - done = is_done(env, x) + # Check termination + done = is_done(env, x) # Gradients if gradients if env.representation == :minimal - fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad) + fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad) elseif env.representation == :maximal - fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad) + fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad) if attitude_decompress A0 = attitude_jacobian(z0, length(env.mechanism.bodies)) A1 = attitude_jacobian(z1, length(env.mechanism.bodies)) @@ -109,11 +109,11 @@ function Base.step(env::Environment, x, u; end function Base.step(env::Environment, u; - gradients=false, - attitude_decompress=false) - step(env, env.state, u; - gradients=gradients, - attitude_decompress=attitude_decompress) + gradients = false, + attitude_decompress = false) + step(env, env.state, u; + gradients = gradients, + attitude_decompress = attitude_decompress) end """ @@ -156,7 +156,7 @@ is_done(env::Environment, x) = false x: state """ function Base.reset(env::Environment{X}; - x=nothing) where X + x = nothing) where {X} initialize!(env.mechanism, type2symbol(X)) if x != nothing @@ -172,14 +172,14 @@ function Base.reset(env::Environment{X}; return get_observation(env) end -function MeshCat.render(env::Environment, - mode="human") +function MeshCat.render(env::Environment, + mode = "human") z = env.representation == :minimal ? 
minimal_to_maximal(env.mechanism, env.state) : env.state - set_robot(env.vis, env.mechanism, z, name=:robot) + set_robot(env.vis, env.mechanism, z, name = :robot) return nothing end -function seed(env::Environment, s=0) +function seed(env::Environment, s = 0) env.rng[1] = MersenneTwister(s) return nothing end @@ -214,26 +214,20 @@ mutable struct BoxSpace{T,N} <: Space{T,N} dtype::DataType # this is always T, it's needed to interface with Stable-Baselines end -function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where T +function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where {T} return BoxSpace{T,n}(n, low, high, (n,), T) end function sample(s::BoxSpace{T,N}) where {T,N} - return rand(T,N) .* (s.high .- s.low) .+ s.low + return rand(T, N) .* (s.high .- s.low) .+ s.low end function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N} all(v .>= s.low) && all(v .<= s.high) end -# For compat with RLBase -Base.length(s::BoxSpace) = s.n -Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high) -Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* (s.high .- s.low) .+ s.low - function clip(s::BoxSpace, u) clamp.(u, s.low, s.high) end - - +Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T, N) .* (s.high .- s.low) .+ s.low diff --git a/environments/rlenv.jl b/environments/rlenv.jl index 947026973..3baf47d12 100644 --- a/environments/rlenv.jl +++ b/environments/rlenv.jl @@ -17,8 +17,12 @@ function DojoRLEnv(name::String; kwargs...) DojoRLEnv(Dojo.get_environment(name; kwargs...)) end -RLBase.action_space(env::DojoRLEnv) = env.dojoenv.input_space -RLBase.state_space(env::DojoRLEnv) = env.dojoenv.observation_space +function Base.convert(::Type{RLBase.Space}, s::BoxSpace) + RLBase.Space([BoxSpace(1; low = s.low[i:i], high = s.high[i:i]) for i in 1:s.n]) +end + +RLBase.action_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.input_space) +RLBase.state_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.observation_space) RLBase.is_terminated(env::DojoRLEnv) = env.done RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv) @@ -39,4 +43,4 @@ function (env::DojoRLEnv)(a) env.info = i return nothing end -(env::DojoRLEnv)(a::AbstractFloat) = env([a]) +(env::DojoRLEnv)(a::Number) = env([a]) From 08188a602843ad95d978d515a7d4183de17bf6d4 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Sun, 13 Mar 2022 02:47:01 +0000 Subject: [PATCH 7/9] fix ppo policy --- examples/deeprl/ant_ppo.jl | 14 +++++++++----- examples/deeprl/cartpole_ppo.jl | 12 ++++++++---- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/examples/deeprl/ant_ppo.jl b/examples/deeprl/ant_ppo.jl index 5359ef5f4..e2a13f463 100644 --- a/examples/deeprl/ant_ppo.jl +++ b/examples/deeprl/ant_ppo.jl @@ -28,13 +28,17 @@ function RL.Experiment( agent = Agent( policy = PPOPolicy( approximator = ActorCritic( - actor = Chain( - Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, na; init = glorot_uniform(rng)), - ), + actor = GaussianNetwork( + pre = Chain( + Dense(ns, 64, relu; init = glorot_uniform(rng)), + Dense(64, 64, relu; init = glorot_uniform(rng)), + ), + μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec), + logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec), + ), critic = Chain( Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, 1; init = glorot_uniform(rng)), + 
Dense(256, na; init = glorot_uniform(rng)), ), optimizer = ADAM(1e-3), ), diff --git a/examples/deeprl/cartpole_ppo.jl b/examples/deeprl/cartpole_ppo.jl index ebfd86516..7ba8bb5c0 100644 --- a/examples/deeprl/cartpole_ppo.jl +++ b/examples/deeprl/cartpole_ppo.jl @@ -28,10 +28,14 @@ function RL.Experiment( agent = Agent( policy = PPOPolicy( approximator = ActorCritic( - actor = Chain( - Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, na; init = glorot_uniform(rng)), - ), + actor = GaussianNetwork( + pre = Chain( + Dense(ns, 64, relu; init = glorot_uniform(rng)), + Dense(64, 64, relu; init = glorot_uniform(rng)), + ), + μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec), + logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec), + ), critic = Chain( Dense(ns, 256, relu; init = glorot_uniform(rng)), Dense(256, 1; init = glorot_uniform(rng)), From e57b198bbf10a8d0f082b1610e8c114fa4bdca8f Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Sun, 13 Mar 2022 03:01:46 +0000 Subject: [PATCH 8/9] fixes for RL integration, still errors --- examples/deeprl/Project.toml | 1 + examples/deeprl/ant_ppo.jl | 4 +++- examples/deeprl/cartpole_ppo.jl | 15 +++++++++------ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/deeprl/Project.toml b/examples/deeprl/Project.toml index 2679eef35..81e370e7a 100644 --- a/examples/deeprl/Project.toml +++ b/examples/deeprl/Project.toml @@ -1,4 +1,5 @@ [deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Dojo = "ac60b53e-8d92-4c83-b960-e78698fa1916" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/examples/deeprl/ant_ppo.jl b/examples/deeprl/ant_ppo.jl index e2a13f463..8e731225f 100644 --- a/examples/deeprl/ant_ppo.jl +++ b/examples/deeprl/ant_ppo.jl @@ -3,6 +3,7 @@ using Flux using Flux.Losses using Random +using Distributions using Dojo function RL.Experiment( @@ -51,12 +52,13 @@ function RL.Experiment( actor_loss_weight = 1.0f0, critic_loss_weight = 0.5f0, entropy_loss_weight = 0.001f0, + dist = Normal, update_freq = UPDATE_FREQ, ), trajectory = PPOTrajectory(; capacity = UPDATE_FREQ, state = Matrix{Float32} => (ns, N_ENV), - action = Vector{Int} => (N_ENV,), + action = Vector{Float32} => (N_ENV,), action_log_prob = Vector{Float32} => (N_ENV,), reward = Vector{Float32} => (N_ENV,), terminal = Vector{Bool} => (N_ENV,), diff --git a/examples/deeprl/cartpole_ppo.jl b/examples/deeprl/cartpole_ppo.jl index 7ba8bb5c0..992b3b1d0 100644 --- a/examples/deeprl/cartpole_ppo.jl +++ b/examples/deeprl/cartpole_ppo.jl @@ -3,6 +3,7 @@ using Flux using Flux.Losses using Random +using Distributions using Dojo function RL.Experiment( @@ -37,26 +38,28 @@ function RL.Experiment( logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec), ), critic = Chain( - Dense(ns, 256, relu; init = glorot_uniform(rng)), - Dense(256, 1; init = glorot_uniform(rng)), + Dense(ns, 64, relu; init = glorot_uniform(rng)), + Dense(64, 64, relu; init = glorot_uniform(rng)), + Dense(64, 1; init = glorot_uniform(rng)), ), optimizer = ADAM(1e-3), ), γ = 0.99f0, λ = 0.95f0, - clip_range = 0.1f0, + clip_range = 0.2f0, max_grad_norm = 0.5f0, - n_epochs = 4, - n_microbatches = 4, + n_epochs = 10, + n_microbatches = 32, actor_loss_weight = 1.0f0, critic_loss_weight = 0.5f0, entropy_loss_weight = 0.001f0, + dist = Normal, update_freq = UPDATE_FREQ, ), trajectory = PPOTrajectory(; capacity = UPDATE_FREQ, state = Matrix{Float32} => (ns, N_ENV), - action = Vector{Int} => (N_ENV,), + action = Vector{Float32} 
=> (N_ENV,), action_log_prob = Vector{Float32} => (N_ENV,), reward = Vector{Float32} => (N_ENV,), terminal = Vector{Bool} => (N_ENV,), From eb379f6bb850d3e6cd6b48c42e654df1f52ca33f Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Sun, 13 Mar 2022 03:04:07 +0000 Subject: [PATCH 9/9] try some more fixes --- examples/deeprl/ant_ppo.jl | 2 +- examples/deeprl/cartpole_ppo.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/deeprl/ant_ppo.jl b/examples/deeprl/ant_ppo.jl index 8e731225f..ddb06ebf2 100644 --- a/examples/deeprl/ant_ppo.jl +++ b/examples/deeprl/ant_ppo.jl @@ -58,7 +58,7 @@ function RL.Experiment( trajectory = PPOTrajectory(; capacity = UPDATE_FREQ, state = Matrix{Float32} => (ns, N_ENV), - action = Vector{Float32} => (N_ENV,), + action = Matrix{Float32} => (na, N_ENV), action_log_prob = Vector{Float32} => (N_ENV,), reward = Vector{Float32} => (N_ENV,), terminal = Vector{Bool} => (N_ENV,), diff --git a/examples/deeprl/cartpole_ppo.jl b/examples/deeprl/cartpole_ppo.jl index 992b3b1d0..659d1e7b8 100644 --- a/examples/deeprl/cartpole_ppo.jl +++ b/examples/deeprl/cartpole_ppo.jl @@ -59,7 +59,7 @@ function RL.Experiment( trajectory = PPOTrajectory(; capacity = UPDATE_FREQ, state = Matrix{Float32} => (ns, N_ENV), - action = Vector{Float32} => (N_ENV,), + action = Matrix{Float32} => (na, N_ENV), action_log_prob = Vector{Float32} => (N_ENV,), reward = Vector{Float32} => (N_ENV,), terminal = Vector{Bool} => (N_ENV,),
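
A quick interface-level check of the DojoRLEnv wrapper introduced in these patches (a minimal sketch, not part of the patch series: it assumes the "ant" environment name used in the examples is available through Dojo.get_environment, and it only exercises the RLBase methods defined in environments/rlenv.jl, not the PPO/DDPG experiment scripts, which the later commit messages note are still erroring):

using Dojo
using ReinforcementLearningBase: RLBase

# Wrap a Dojo environment in the RLBase-compatible interface from rlenv.jl.
env = Dojo.DojoRLEnv("ant")

# Spaces are converted from Dojo's BoxSpace to RLBase.Space on access.
A = RLBase.action_space(env)
S = RLBase.state_space(env)

# Reset, then advance one step with a zero action of the expected size.
RLBase.reset!(env)
a = zeros(env.dojoenv.num_inputs)   # num_inputs is a field of Dojo's Environment struct
env(a)                              # functor call steps the wrapped Dojo environment

RLBase.state(env)          # observation cached by the wrapper after the step
RLBase.reward(env)         # reward stored by the wrapper from the last step
RLBase.is_terminated(env)  # done flag cached by the wrapper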