diff --git a/examples/torch/sac_half_cheetah_batch.py b/examples/torch/sac_half_cheetah_batch.py index 85da7879e6..79135dc5e9 100755 --- a/examples/torch/sac_half_cheetah_batch.py +++ b/examples/torch/sac_half_cheetah_batch.py @@ -56,6 +56,7 @@ def sac_half_cheetah_batch(ctxt=None, seed=1): qf2=qf2, gradient_steps_per_itr=1000, max_episode_length=500, + max_episode_length_eval=1000, replay_buffer=replay_buffer, min_buffer_size=1e4, target_update_tau=5e-3, diff --git a/src/garage/torch/algos/sac.py b/src/garage/torch/algos/sac.py index 7068379f83..a45b31df82 100644 --- a/src/garage/torch/algos/sac.py +++ b/src/garage/torch/algos/sac.py @@ -129,7 +129,8 @@ def __init__( self._discount = discount self._reward_scale = reward_scale self.max_episode_length = max_episode_length - self._max_episode_length_eval = max_episode_length_eval + self._max_episode_length_eval = (max_episode_length_eval + or max_episode_length) self.policy = policy self.env_spec = env_spec