Construct sampler in the launcher (#2182)
* Construct samplers in launchers

* Change the constructor of samplers

* Fix use of sampler

* Fix tests

* Pickling ray sampler

* Fix raysampler test

* Fix multiprocessing sampler test
yeukfu authored Nov 30, 2020
1 parent f856f74 commit f8aaef2
Showing 145 changed files with 1,367 additions and 552 deletions.
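
The pattern repeated throughout the diffs below: each launcher now builds its sampler explicitly and passes it to the algorithm constructor through the new sampler argument, instead of relying on a sampler_cls attribute that the trainer resolved at setup time. As a reference, here is a minimal sketch of a new-style launcher, assembled from the TRPO cartpole example in this commit; the policy sizes and training schedule are illustrative fill-ins, not values taken verbatim from any one file.

from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import RaySampler
from garage.tf.algos import TRPO
from garage.tf.policies import CategoricalMLPPolicy
from garage.trainer import TFTrainer


@wrap_experiment
def trpo_cartpole(ctxt=None, seed=1):
    """Launcher sketch: the sampler is constructed here, in the launcher."""
    set_seed(seed)
    with TFTrainer(ctxt) as trainer:
        env = GymEnv('CartPole-v1')
        policy = CategoricalMLPPolicy(env_spec=env.spec,
                                      hidden_sizes=(32, 32))  # sizes assumed
        baseline = LinearFeatureBaseline(env_spec=env.spec)

        # New: the launcher owns the sampler and wires in the agents, the
        # envs, and the episode length limit taken from the env spec.
        sampler = RaySampler(agents=policy,
                             envs=env,
                             max_episode_length=env.spec.max_episode_length,
                             is_tf_worker=True)

        # New: the algorithm receives the ready-made sampler.
        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    sampler=sampler,
                    discount=0.99,
                    max_kl_step=0.01)

        trainer.setup(algo, env)
        trainer.train(n_epochs=100, batch_size=4000)  # schedule assumed


trpo_cartpole(seed=1)
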
7 changes: 7 additions & 0 deletions docs/user/experiments.rst
@@ -35,6 +35,7 @@ simple one, :code:`examples/tf/trpo_cartpole.py`, is also pasted below:
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import RaySampler
from garage.tf.algos import TRPO
from garage.tf.policies import CategoricalMLPPolicy
from garage.trainer import TFTrainer
@@ -61,9 +62,15 @@ simple one, :code:`examples/tf/trpo_cartpole.py`, is also pasted below:

baseline = LinearFeatureBaseline(env_spec=env.spec)

sampler = RaySampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)

algo = TRPO(env_spec=env.spec,
policy=policy,
baseline=baseline,
sampler=sampler,
discount=0.99,
max_kl_step=0.01)

7 changes: 7 additions & 0 deletions examples/np/cem_cartpole.py
@@ -11,6 +11,7 @@
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.algos import CEM
from garage.sampler import LocalSampler
from garage.tf.policies import CategoricalMLPPolicy
from garage.trainer import TFTrainer

@@ -36,8 +37,14 @@ def cem_cartpole(ctxt=None, seed=1):

n_samples = 20

sampler = LocalSampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)

algo = CEM(env_spec=env.spec,
policy=policy,
sampler=sampler,
best_frac=0.05,
n_samples=n_samples)

11 changes: 10 additions & 1 deletion examples/np/cma_es_cartpole.py
@@ -12,6 +12,7 @@
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.algos import CMAES
from garage.sampler import LocalSampler
from garage.tf.policies import CategoricalMLPPolicy
from garage.trainer import TFTrainer

@@ -37,7 +38,15 @@ def cma_es_cartpole(ctxt=None, seed=1):

n_samples = 20

algo = CMAES(env_spec=env.spec, policy=policy, n_samples=n_samples)
sampler = LocalSampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)

algo = CMAES(env_spec=env.spec,
policy=policy,
sampler=sampler,
n_samples=n_samples)

trainer.setup(algo, env)
trainer.train(n_epochs=100, batch_size=1000)
11 changes: 8 additions & 3 deletions examples/np/tutorial_cem.py
@@ -18,13 +18,14 @@ class SimpleCEM:
Args:
env_spec (EnvSpec): Environment specification.
policy (Policy): Action policy.
sampler (garage.sampler.Sampler): Sampler.
"""
sampler_cls = LocalSampler

def __init__(self, env_spec, policy):
def __init__(self, env_spec, policy, sampler):
self.env_spec = env_spec
self.policy = policy
self.sampler = sampler
self.max_episode_length = env_spec.max_episode_length
self._discount = 0.99
self._extra_std = 1
@@ -117,7 +118,11 @@ def tutorial_cem(ctxt=None):
with TFTrainer(ctxt) as trainer:
env = GymEnv('CartPole-v1')
policy = CategoricalMLPPolicy(env.spec)
algo = SimpleCEM(env.spec, policy)
sampler = LocalSampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)
algo = SimpleCEM(env.spec, policy, sampler)
trainer.setup(algo, env)
trainer.train(n_epochs=100, batch_size=1000)

8 changes: 8 additions & 0 deletions examples/tf/ddpg_pendulum.py
@@ -15,6 +15,7 @@
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import AddOrnsteinUhlenbeckNoise
from garage.replay_buffer import PathBuffer
from garage.sampler import FragmentWorker, LocalSampler
from garage.tf.algos import DDPG
from garage.tf.policies import ContinuousMLPPolicy
from garage.tf.q_functions import ContinuousMLPQFunction
@@ -51,12 +52,19 @@ def ddpg_pendulum(ctxt=None, seed=1):

replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))

sampler = LocalSampler(agents=exploration_policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True,
worker_class=FragmentWorker)

ddpg = DDPG(env_spec=env.spec,
policy=policy,
policy_lr=1e-4,
qf_lr=1e-3,
qf=qf,
replay_buffer=replay_buffer,
sampler=sampler,
steps_per_epoch=20,
target_update_tau=1e-2,
n_train_steps=50,
9 changes: 9 additions & 0 deletions examples/tf/dqn_cartpole.py
@@ -8,6 +8,7 @@
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.sampler import FragmentWorker, LocalSampler
from garage.tf.algos import DQN
from garage.tf.policies import DiscreteQFArgmaxPolicy
from garage.tf.q_functions import DiscreteMLPQFunction
@@ -41,11 +42,19 @@ def dqn_cartpole(ctxt=None, seed=1):
max_epsilon=1.0,
min_epsilon=0.02,
decay_ratio=0.1)

sampler = LocalSampler(agents=exploration_policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True,
worker_class=FragmentWorker)

algo = DQN(env_spec=env.spec,
policy=policy,
qf=qf,
exploration_policy=exploration_policy,
replay_buffer=replay_buffer,
sampler=sampler,
steps_per_epoch=steps_per_epoch,
qf_lr=1e-4,
discount=1.0,
8 changes: 8 additions & 0 deletions examples/tf/dqn_pong.py
@@ -19,6 +19,7 @@
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.sampler import FragmentWorker, LocalSampler
from garage.tf.algos import DQN
from garage.tf.policies import DiscreteQFArgmaxPolicy
from garage.tf.q_functions import DiscreteCNNQFunction
@@ -83,11 +84,18 @@ def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_episode_length=500):
min_epsilon=0.02,
decay_ratio=0.1)

sampler = LocalSampler(agents=exploration_policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True,
worker_class=FragmentWorker)

algo = DQN(env_spec=env.spec,
policy=policy,
qf=qf,
exploration_policy=exploration_policy,
replay_buffer=replay_buffer,
sampler=sampler,
qf_lr=1e-4,
discount=0.99,
min_buffer_size=int(1e4),
7 changes: 7 additions & 0 deletions examples/tf/erwr_cartpole.py
@@ -11,6 +11,7 @@
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import RaySampler
from garage.tf.algos import ERWR
from garage.tf.policies import CategoricalMLPPolicy
from garage.trainer import TFTrainer
@@ -37,9 +38,15 @@ def erwr_cartpole(ctxt=None, seed=1):

baseline = LinearFeatureBaseline(env_spec=env.spec)

sampler = RaySampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)

algo = ERWR(env_spec=env.spec,
policy=policy,
baseline=baseline,
sampler=sampler,
discount=0.99)

trainer.setup(algo=algo, env=env)
8 changes: 8 additions & 0 deletions examples/tf/her_ddpg_fetchreach.py
@@ -10,6 +10,7 @@
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import AddOrnsteinUhlenbeckNoise
from garage.replay_buffer import HERReplayBuffer
from garage.sampler import FragmentWorker, LocalSampler
from garage.tf.algos import DDPG
from garage.tf.policies import ContinuousMLPPolicy
from garage.tf.q_functions import ContinuousMLPQFunction
@@ -56,13 +57,20 @@ def her_ddpg_fetchreach(ctxt=None, seed=1):
reward_fn=env.compute_reward,
env_spec=env.spec)

sampler = LocalSampler(agents=exploration_policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True,
worker_class=FragmentWorker)

ddpg = DDPG(
env_spec=env.spec,
policy=policy,
policy_lr=1e-3,
qf_lr=1e-3,
qf=qf,
replay_buffer=replay_buffer,
sampler=sampler,
target_update_tau=0.01,
steps_per_epoch=50,
n_train_steps=40,
7 changes: 7 additions & 0 deletions examples/tf/multi_env_ppo.py
@@ -7,6 +7,7 @@
from garage.envs.multi_env_wrapper import MultiEnvWrapper
from garage.experiment.deterministic import set_seed
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import RaySampler
from garage.tf.algos import PPO
from garage.tf.policies import CategoricalMLPPolicy
from garage.trainer import TFTrainer
@@ -36,9 +37,15 @@ def multi_env_ppo(ctxt=None, seed=1):

baseline = LinearFeatureBaseline(env_spec=env.spec)

sampler = RaySampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)

algo = PPO(env_spec=env.spec,
policy=policy,
baseline=baseline,
sampler=sampler,
discount=0.99,
gae_lambda=0.95,
lr_clip_range=0.2,
7 changes: 7 additions & 0 deletions examples/tf/multi_env_trpo.py
@@ -5,6 +5,7 @@
from garage.envs.multi_env_wrapper import MultiEnvWrapper
from garage.experiment.deterministic import set_seed
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import RaySampler
from garage.tf.algos import TRPO
from garage.tf.policies import GaussianMLPPolicy
from garage.trainer import TFTrainer
@@ -31,9 +32,15 @@ def multi_env_trpo(ctxt=None, seed=1):

baseline = LinearFeatureBaseline(env_spec=env.spec)

sampler = RaySampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)

algo = TRPO(env_spec=env.spec,
policy=policy,
baseline=baseline,
sampler=sampler,
discount=0.99,
gae_lambda=0.95,
lr_clip_range=0.2,
7 changes: 7 additions & 0 deletions examples/tf/ppo_memorize_digits.py
@@ -8,6 +8,7 @@
from garage import wrap_experiment
from garage.envs import GymEnv, normalize
from garage.experiment.deterministic import set_seed
from garage.sampler import RaySampler
from garage.tf.algos import PPO
from garage.tf.baselines import GaussianCNNBaseline
from garage.tf.policies import CategoricalCNNPolicy
@@ -61,9 +62,15 @@ def ppo_memorize_digits(ctxt=None,
hidden_sizes=(256, ),
use_trust_region=True) # yapf: disable

sampler = RaySampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)

algo = PPO(env_spec=env.spec,
policy=policy,
baseline=baseline,
sampler=sampler,
discount=0.99,
gae_lambda=0.95,
lr_clip_range=0.2,
7 changes: 7 additions & 0 deletions examples/tf/ppo_pendulum.py
@@ -14,6 +14,7 @@
from garage import wrap_experiment
from garage.envs import GymEnv, normalize
from garage.experiment.deterministic import set_seed
from garage.sampler import RaySampler
from garage.tf.algos import PPO
from garage.tf.baselines import GaussianMLPBaseline
from garage.tf.policies import GaussianMLPPolicy
@@ -48,13 +49,19 @@ def ppo_pendulum(ctxt=None, seed=1):
use_trust_region=True,
)

sampler = RaySampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)

# NOTE: make sure when setting entropy_method to 'max', set
# center_adv to False and turn off policy gradient. See
# tf.algos.NPO for detailed documentation.
algo = PPO(
env_spec=env.spec,
policy=policy,
baseline=baseline,
sampler=sampler,
discount=0.99,
gae_lambda=0.95,
lr_clip_range=0.2,
7 changes: 7 additions & 0 deletions examples/tf/reps_gym_cartpole.py
@@ -13,6 +13,7 @@
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import RaySampler
from garage.tf.algos import REPS
from garage.tf.policies import CategoricalMLPPolicy
from garage.trainer import TFTrainer
@@ -37,9 +38,15 @@ def reps_gym_cartpole(ctxt=None, seed=1):

baseline = LinearFeatureBaseline(env_spec=env.spec)

sampler = RaySampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=True)

algo = REPS(env_spec=env.spec,
policy=policy,
baseline=baseline,
sampler=sampler,
discount=0.99)

trainer.setup(algo, env)
18 changes: 12 additions & 6 deletions examples/tf/rl2_ppo_halfcheetah.py
@@ -55,11 +55,22 @@ def rl2_ppo_halfcheetah(ctxt, seed, max_episode_length, meta_batch_size,

baseline = LinearFeatureBaseline(env_spec=env_spec)

envs = tasks.sample(meta_batch_size)
sampler = LocalSampler(
agents=policy,
envs=envs,
max_episode_length=env_spec.max_episode_length,
is_tf_worker=True,
n_workers=meta_batch_size,
worker_class=RL2Worker,
worker_args=dict(n_episodes_per_trial=episode_per_task))

algo = RL2PPO(meta_batch_size=meta_batch_size,
task_sampler=tasks,
env_spec=env_spec,
policy=policy,
baseline=baseline,
sampler=sampler,
episodes_per_trial=episode_per_task,
discount=0.99,
gae_lambda=0.95,
@@ -73,12 +84,7 @@
policy_ent_coeff=0.02,
center_adv=False)

trainer.setup(algo,
tasks.sample(meta_batch_size),
sampler_cls=LocalSampler,
n_workers=meta_batch_size,
worker_class=RL2Worker,
worker_args=dict(n_episodes_per_trial=episode_per_task))
trainer.setup(algo, envs)

trainer.train(n_epochs=n_epochs,
batch_size=episode_per_task * max_episode_length *
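
Because the rendered diff does not show add/remove markers, the two trainer.setup calls in the RL2 hunk above run together. For clarity, here is a condensed restatement of that change, using only names already defined in the example and omitting the remaining RL2PPO hyperparameters:

# Old API (removed in this commit): the trainer constructed the sampler at
# setup time from a class plus worker configuration.
# trainer.setup(algo,
#               tasks.sample(meta_batch_size),
#               sampler_cls=LocalSampler,
#               n_workers=meta_batch_size,
#               worker_class=RL2Worker,
#               worker_args=dict(n_episodes_per_trial=episode_per_task))

# New API: the launcher samples the task environments, builds the sampler
# with the same worker configuration, and hands it to the algorithm; setup
# now only needs the algorithm and the environments.
envs = tasks.sample(meta_batch_size)
sampler = LocalSampler(agents=policy,
                       envs=envs,
                       max_episode_length=env_spec.max_episode_length,
                       is_tf_worker=True,
                       n_workers=meta_batch_size,
                       worker_class=RL2Worker,
                       worker_args=dict(n_episodes_per_trial=episode_per_task))
algo = RL2PPO(meta_batch_size=meta_batch_size,
              task_sampler=tasks,
              env_spec=env_spec,
              policy=policy,
              baseline=baseline,
              sampler=sampler,
              episodes_per_trial=episode_per_task,
              discount=0.99,
              gae_lambda=0.95)
trainer.setup(algo, envs)
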