Skip to content

Commit

Permalink
Don't collect more than necessary in RaySampler (#1583)
Browse files Browse the repository at this point in the history
There is a poorly-handled case where we have more workers than we need.
At the moment, the fix for this is to only call rollout on workers if
we need samples from those workers. For example, if there are 2
workers, and they are each configured collect 16 samples per call, and
then a call to obtain samples is made asking for only 16 samples, then
only the first sampler worker should be started.

Fixes #1349
  • Loading branch information
avnishn authored Jun 22, 2020
1 parent a3ae9fd commit 888209b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
18 changes: 14 additions & 4 deletions src/garage/sampler/ray_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def obtain_samples(self, itr, num_samples):
completed_samples = 0
traj = []
updating_workers = []

samples_to_be_collected = 0
# update the policy params of each worker before sampling
# for the current iteration
curr_policy_params = self.algo.policy.get_param_values()
Expand All @@ -98,7 +98,6 @@ def obtain_samples(self, itr, num_samples):
worker_id = self._idle_worker_ids.pop()
worker = self._all_workers[worker_id]
updating_workers.append(worker.set_agent.remote(params_id))

while completed_samples < num_samples:
# if there are workers still being updated, check
# which ones are still updating and take the workers that
Expand All @@ -113,12 +112,20 @@ def obtain_samples(self, itr, num_samples):

# if there are idle workers, use them to collect trajectories
# mark the newly busy workers as active
while self._idle_worker_ids:
workers_to_use = int(
np.clip(
np.ceil(
(num_samples - completed_samples -
samples_to_be_collected) / self._max_path_length) -
len(self._active_worker_ids), 0, len(self._all_workers)))
workers_started = 0
while self._idle_worker_ids and workers_started < workers_to_use:
idle_worker_id = self._idle_worker_ids.pop()
workers_started += 1
self._active_worker_ids.append(idle_worker_id)
samples_to_be_collected += self._max_path_length
worker = self._all_workers[idle_worker_id]
_active_workers.append(worker.rollout.remote())

# check which workers are done/not done collecting a sample
# if any are done, send them to process the collected trajectory
# if they are not, keep checking if they are done
Expand All @@ -129,9 +136,12 @@ def obtain_samples(self, itr, num_samples):
for result in ready:
trajectory, num_returned_samples = self._process_trajectory(
result)
samples_to_be_collected -= self._max_path_length
completed_samples += num_returned_samples
pbar.inc(num_returned_samples)
traj.append(trajectory)
self._idle_worker_ids = list(range(self._num_workers))
self._active_worker_ids = []
pbar.stop()
return traj

Expand Down
32 changes: 27 additions & 5 deletions tests/garage/sampler/test_ray_batched_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,37 @@ def test_ray_batch_sampler(self):
assert (trajs1[0]['observations'].shape == np.array(
trajs2[0]['observations']).shape == (6, ))
traj2_action_shape = np.array(trajs2[0]['actions']).shape
assert (trajs1[0]['actions'].shape == traj2_action_shape == (6, ))
assert (sum(trajs1[0]['rewards']) == sum(trajs2[0]['rewards']) == 1)
assert trajs1[0]['actions'].shape == traj2_action_shape == (6, )
assert sum(trajs1[0]['rewards']) == sum(trajs2[0]['rewards']) == 1

true_obs = np.array([0, 1, 2, 6, 10, 14])
true_actions = np.array([2, 2, 1, 1, 1, 2])
true_rewards = np.array([0, 0, 0, 0, 0, 1])
for trajectory in trajs1:
assert (np.array_equal(trajectory['observations'], true_obs))
assert (np.array_equal(trajectory['actions'], true_actions))
assert (np.array_equal(trajectory['rewards'], true_rewards))
assert np.array_equal(trajectory['observations'], true_obs)
assert np.array_equal(trajectory['actions'], true_actions)
assert np.array_equal(trajectory['rewards'], true_rewards)
sampler1.shutdown_worker()
sampler2.shutdown_worker()

def test_ray_sampler_idle_workers(self):
"""Ensures we only call the necessary number of workers.
There is a poorly-handled case where we have more workers than we need.
At the moment, the fix for this is to only call rollout on workers if
we need samples from those workers. For example, if there are 2
workers, and they are each configured collect 16 samples per call, and
then a call to obtain samples is made asking for only 16 samples, then
only the first sampler worker should be started.
"""
sampler = RaySampler(self.algo,
self.env,
seed=100,
num_processors=2,
sampler_worker_cls=SamplerWorker)
sampler.start_worker()
assert len(sampler._idle_worker_ids) == 2
sampler.obtain_samples(0, 16)
assert len(sampler._idle_worker_ids) == 2
sampler.shutdown_worker()

0 comments on commit 888209b

Please sign in to comment.