Support batch_size=None and use it in various scripts (#993)
Closes #986
MischaPanch authored Nov 24, 2023
1 parent f134bc2 commit 8d3d1f1
Showing 20 changed files with 63 additions and 36 deletions.
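
The common thread across the files below: a batch size of None now means "use all available data at once", replacing the old 99999 sentinel in the on-policy MuJoCo scripts. A minimal standalone sketch of the new convention (the argument name matches the scripts; everything else is illustrative):

    # Standalone sketch, not a full training script.
    import argparse

    parser = argparse.ArgumentParser()
    # Previously default=99999 stood in for "larger than any rollout"; None is now explicit.
    parser.add_argument("--batch-size", type=int, default=None)
    args = parser.parse_args([])  # no CLI override in this sketch

    # Downstream, None is resolved where the data is actually available:
    # len(buffer) when sampling, or "the whole rollout" when splitting minibatches.
    batch_size = args.batch_size   # None -> use everything
    split_size = batch_size or -1  # the `or -1` idiom used in the learn() methods below
    print(batch_size, split_size)  # None -1
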
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_a2c.py
@@ -34,7 +34,7 @@ def get_args():
parser.add_argument("--step-per-collect", type=int, default=80)
parser.add_argument("--repeat-per-collect", type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward.
parser.add_argument("--batch-size", type=int, default=99999)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--training-num", type=int, default=16)
parser.add_argument("--test-num", type=int, default=10)
# a2c special
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_a2c_hl.py
@@ -30,7 +30,7 @@ def main(
step_per_epoch: int = 30000,
step_per_collect: int = 80,
repeat_per_collect: int = 1,
batch_size: int = 99999,
batch_size: int | None = None,
training_num: int = 16,
test_num: int = 10,
rew_norm: bool = True,
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_npg.py
@@ -39,7 +39,7 @@ def get_args():
parser.add_argument("--step-per-collect", type=int, default=1024)
parser.add_argument("--repeat-per-collect", type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward.
parser.add_argument("--batch-size", type=int, default=99999)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--training-num", type=int, default=16)
parser.add_argument("--test-num", type=int, default=10)
# npg special
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_npg_hl.py
@@ -32,7 +32,7 @@ def main(
step_per_epoch: int = 30000,
step_per_collect: int = 1024,
repeat_per_collect: int = 1,
batch_size: int = 99999,
batch_size: int | None = None,
training_num: int = 16,
test_num: int = 10,
rew_norm: bool = True,
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_reinforce.py
@@ -34,7 +34,7 @@ def get_args():
parser.add_argument("--step-per-collect", type=int, default=2048)
parser.add_argument("--repeat-per-collect", type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward.
parser.add_argument("--batch-size", type=int, default=99999)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--training-num", type=int, default=64)
parser.add_argument("--test-num", type=int, default=10)
# reinforce special
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_reinforce_hl.py
@@ -29,7 +29,7 @@ def main(
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 1,
batch_size: int = 99999,
batch_size: int | None = None,
training_num: int = 64,
test_num: int = 10,
rew_norm: bool = True,
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_trpo.py
@@ -39,7 +39,7 @@ def get_args():
parser.add_argument("--step-per-collect", type=int, default=1024)
parser.add_argument("--repeat-per-collect", type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward.
parser.add_argument("--batch-size", type=int, default=99999)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--training-num", type=int, default=16)
parser.add_argument("--test-num", type=int, default=10)
# trpo special
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_trpo_hl.py
@@ -32,7 +32,7 @@ def main(
step_per_epoch: int = 30000,
step_per_collect: int = 1024,
repeat_per_collect: int = 1,
batch_size: int = 99999,
batch_size: int | None = None,
training_num: int = 16,
test_num: int = 10,
rew_norm: bool = True,
16 changes: 14 additions & 2 deletions tianshou/data/buffer/base.py
@@ -290,20 +290,32 @@ def add(
self._meta[ptr] = batch
return ptr, ep_rew, ep_len, ep_idx

def sample_indices(self, batch_size: int) -> np.ndarray:
def sample_indices(self, batch_size: int | None) -> np.ndarray:
"""Get a random sample of index with size = batch_size.
Return all available indices in the buffer if batch_size is 0; return an empty
numpy array if batch_size < 0 or no available index can be sampled.
:param batch_size: the number of indices to be sampled. If None, it will be set
to the length of the buffer (i.e. return all available indices in a
random order).
"""
if batch_size is None:
batch_size = len(self)
if self.stack_num == 1 or not self._sample_avail: # most often case
if batch_size > 0:
return np.random.choice(self._size, batch_size)
# TODO: is this behavior really desired?
if batch_size == 0: # construct current available indices
return np.concatenate([np.arange(self._index, self._size), np.arange(self._index)])
return np.array([], int)
# TODO: raise error on negative batch_size instead?
if batch_size < 0:
return np.array([], int)
# TODO: simplify this code - shouldn't have such a large if-else
# with many returns for handling different stack nums.
# It is also not clear whether this is really necessary - frame stacking usually is handled
# by environment wrappers (e.g. FrameStack) and not by the replay buffer.
all_indices = prev_indices = np.concatenate(
[np.arange(self._index, self._size), np.arange(self._index)],
)
@@ -314,7 +326,7 @@ def sample_indices(self, batch_size: int) -> np.ndarray:
return np.random.choice(all_indices, batch_size)
return all_indices

def sample(self, batch_size: int) -> tuple[RolloutBatchProtocol, np.ndarray]:
def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]:
"""Get a random sample from buffer with size = batch_size.
Return all the data in the buffer if batch_size is 0.
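
The heart of the change is the None handling at the top of ReplayBuffer.sample_indices. A standalone NumPy sketch of the resulting behaviour for the common stack_num == 1 path (illustrative, not the library code):

    import numpy as np

    def sample_indices_sketch(size: int, index: int, batch_size: int | None) -> np.ndarray:
        """Mimic ReplayBuffer.sample_indices after this commit (stack_num == 1 only).
        `size` is the number of stored transitions, `index` the next write position."""
        if batch_size is None:
            # None -> draw len(buffer) indices; note that np.random.choice samples
            # with replacement here, exactly as in the excerpt above.
            batch_size = size
        if batch_size > 0:
            return np.random.choice(size, batch_size)
        if batch_size == 0:
            # all currently available indices, oldest first
            return np.concatenate([np.arange(index, size), np.arange(index)])
        return np.array([], int)  # negative batch_size -> empty sample

    print(sample_indices_sketch(size=5, index=2, batch_size=None))  # five random indices
    print(sample_indices_sketch(size=5, index=2, batch_size=0))     # [2 3 4 0 1]
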
2 changes: 1 addition & 1 deletion tianshou/data/buffer/her.py
@@ -84,7 +84,7 @@ def add(
self._restore_cache()
return super().add(batch, buffer_ids)

def sample_indices(self, batch_size: int) -> np.ndarray:
def sample_indices(self, batch_size: int | None) -> np.ndarray:
"""Get a random sample of index with size = batch_size.
Return all available indices in the buffer if batch_size is 0; return an \
10 changes: 7 additions & 3 deletions tianshou/data/buffer/manager.py
@@ -169,8 +169,10 @@ def add(
self._meta[ptrs] = batch
return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)

def sample_indices(self, batch_size: int) -> np.ndarray:
if batch_size < 0:
def sample_indices(self, batch_size: int | None) -> np.ndarray:
# TODO: simplify this code
if batch_size is not None and batch_size < 0:
# TODO: raise error instead?
return np.array([], int)
if self._sample_avail and self.stack_num > 1:
all_indices = np.concatenate(
@@ -181,8 +183,10 @@ def sample_indices(self, batch_size: int) -> np.ndarray:
)
if batch_size == 0:
return all_indices
if batch_size is None:
batch_size = len(all_indices)
return np.random.choice(all_indices, batch_size)
if batch_size == 0: # get all available indices
if batch_size == 0 or batch_size is None: # get all available indices
sample_num = np.zeros(self.buffer_num, int)
else:
buffer_idx = np.random.choice(
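
For ReplayBufferManager (and hence VectorReplayBuffer), None has to be resolved against the union of all sub-buffers. A simplified sketch of the sample_num bookkeeping above (names follow the excerpt; the draw for a positive batch_size is shown here as a uniform choice, which is an assumption, not necessarily the library's exact weighting):

    import numpy as np

    buffer_num = 4
    batch_size: int | None = None

    if batch_size == 0 or batch_size is None:   # take everything from every sub-buffer
        sample_num = np.zeros(buffer_num, int)  # 0 acts as the "all available" marker per sub-buffer
    else:                                       # spread batch_size over the sub-buffers
        buffer_idx = np.random.choice(buffer_num, batch_size)
        sample_num = np.bincount(buffer_idx, minlength=buffer_num)

    print(sample_num)  # [0 0 0 0] for None; e.g. [3 1 4 2] for batch_size=10
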
4 changes: 2 additions & 2 deletions tianshou/data/buffer/prio.py
@@ -58,8 +58,8 @@ def add(
self.init_weight(ptr)
return ptr, ep_rew, ep_len, ep_idx

def sample_indices(self, batch_size: int) -> np.ndarray:
if batch_size > 0 and len(self) > 0:
def sample_indices(self, batch_size: int | None) -> np.ndarray:
if batch_size is not None and batch_size > 0 and len(self) > 0:
scalar = np.random.rand(batch_size) * self.weight.reduce()
return self.weight.get_prefix_sum_idx(scalar) # type: ignore
return super().sample_indices(batch_size)
6 changes: 4 additions & 2 deletions tianshou/policy/base.py
@@ -370,7 +370,7 @@ def post_process_fn(

def update(
self,
sample_size: int,
sample_size: int | None,
buffer: ReplayBuffer | None,
**kwargs: Any,
) -> dict[str, Any]:
@@ -382,7 +382,9 @@ def update(
Please refer to :ref:`policy_state` for more detailed explanation.
:param sample_size: 0 means it will extract all the data from the buffer,
otherwise it will sample a batch with given sample_size.
otherwise it will sample a batch with given sample_size. None also
means it will extract all the data from the buffer, but it will be shuffled
first. TODO: remove the option for 0?
:param buffer: the corresponding replay buffer.
:return: A dict, including the data needed to be logged (e.g., loss) from
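
In terms of how much data one update() call consumes, the contract after this change can be summarised in a few lines (standalone sketch with a made-up helper name, not library code):

    def transitions_per_update(sample_size: int | None, buffer_len: int) -> int:
        """How many transitions a single BasePolicy.update() call draws from the buffer."""
        if sample_size is None:
            return buffer_len   # whole buffer, shuffled via sample_indices(None)
        if sample_size == 0:
            return buffer_len   # legacy spelling of "everything" (see the TODO above)
        return sample_size      # an ordinary random sample

    assert transitions_per_update(None, buffer_len=2048) == 2048
    assert transitions_per_update(0, buffer_len=2048) == 2048
    assert transitions_per_update(256, buffer_len=2048) == 256
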
2 changes: 1 addition & 1 deletion tianshou/policy/imitation/gail.py
@@ -139,7 +139,7 @@ def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor:
def learn( # type: ignore
self,
batch: RolloutBatchProtocol,
batch_size: int,
batch_size: int | None,
repeat: int,
**kwargs: Any,
) -> dict[str, list[float]]:
5 changes: 3 additions & 2 deletions tianshou/policy/modelfree/a2c.py
@@ -142,14 +142,15 @@ def _compute_returns(
def learn( # type: ignore
self,
batch: RolloutBatchProtocol,
batch_size: int,
batch_size: int | None,
repeat: int,
*args: Any,
**kwargs: Any,
) -> dict[str, list[float]]:
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
split_batch_size = batch_size or -1
for _ in range(repeat):
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
# calculate loss for actor
dist = self(minibatch).dist
log_prob = dist.log_prob(minibatch.act)
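
The `split_batch_size = batch_size or -1` line recurs in the NPG, PG, PPO and TRPO learn() methods below; the commit relies on Batch.split treating -1 as "one minibatch containing the whole batch". A standalone generator that mirrors that splitting behaviour (an illustrative stand-in for Batch.split, with the merge_last=True option used here):

    import numpy as np

    def split_sketch(data: np.ndarray, size: int, merge_last: bool = True):
        """Yield shuffled minibatches of `size`; size == -1 means one minibatch with everything."""
        if size == -1:
            size = len(data)
        indices = np.random.permutation(len(data))
        for start in range(0, len(indices), size):
            if merge_last and 0 < len(indices) - (start + size) < size:
                yield data[indices[start:]]  # fold the short tail into the last minibatch
                return
            yield data[indices[start:start + size]]

    rollout = np.arange(10)
    batch_size = None                                                   # as set by the scripts above
    print([len(mb) for mb in split_sketch(rollout, batch_size or -1)])  # [10] -> a single full pass
    print([len(mb) for mb in split_sketch(rollout, 4)])                 # [4, 6] with merge_last=True
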
5 changes: 3 additions & 2 deletions tianshou/policy/modelfree/npg.py
@@ -109,13 +109,14 @@ def process_fn(
def learn( # type: ignore
self,
batch: Batch,
batch_size: int,
batch_size: int | None,
repeat: int,
**kwargs: Any,
) -> dict[str, list[float]]:
actor_losses, vf_losses, kls = [], [], []
split_batch_size = batch_size or -1
for _ in range(repeat):
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
dist = self(minibatch).dist
5 changes: 3 additions & 2 deletions tianshou/policy/modelfree/pg.py
@@ -193,14 +193,15 @@ def forward(
def learn( # type: ignore
self,
batch: RolloutBatchProtocol,
batch_size: int,
batch_size: int | None,
repeat: int,
*args: Any,
**kwargs: Any,
) -> dict[str, list[float]]:
losses = []
split_batch_size = batch_size or -1
for _ in range(repeat):
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
self.optim.zero_grad()
result = self(minibatch)
dist = result.dist
5 changes: 3 additions & 2 deletions tianshou/policy/modelfree/ppo.py
@@ -128,16 +128,17 @@ def process_fn(
def learn( # type: ignore
self,
batch: RolloutBatchProtocol,
batch_size: int,
batch_size: int | None,
repeat: int,
*args: Any,
**kwargs: Any,
) -> dict[str, list[float]]:
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
split_batch_size = batch_size or -1
for step in range(repeat):
if self.recompute_adv and step > 0:
batch = self._compute_returns(batch, self._buffer, self._indices)
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
# calculate loss for actor
dist = self(minibatch).dist
if self.norm_adv:
5 changes: 3 additions & 2 deletions tianshou/policy/modelfree/trpo.py
@@ -91,13 +91,14 @@ def __init__(
def learn( # type: ignore
self,
batch: Batch,
batch_size: int,
batch_size: int | None,
repeat: int,
**kwargs: Any,
) -> dict[str, list[float]]:
actor_losses, vf_losses, step_sizes, kls = [], [], [], []
split_batch_size = batch_size or -1
for _ in range(repeat):
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
dist = self(minibatch).dist # TODO could come from batch
18 changes: 11 additions & 7 deletions tianshou/trainer/base.py
@@ -31,7 +31,7 @@ class BaseTrainer(ABC):
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param batch_size: the batch size of sample data, which is going to feed in
the policy network.
the policy network. If None, will use the whole buffer in each gradient step.
:param train_collector: the collector used for training.
:param test_collector: the collector used for testing. If it's None,
then no testing will be performed.
@@ -141,7 +141,7 @@ def __init__(
self,
policy: BasePolicy,
max_epoch: int,
batch_size: int,
batch_size: int | None,
train_collector: Collector | None = None,
test_collector: Collector | None = None,
buffer: ReplayBuffer | None = None,
@@ -320,7 +320,9 @@ def __next__(self) -> None | tuple[int, dict[str, Any], dict[str, Any]]:

# for offline RL
if self.train_collector is None:
self.env_step = self.gradient_step * self.batch_size
assert self.buffer is not None
batch_size = self.batch_size or len(self.buffer)
self.env_step = self.gradient_step * batch_size

if not self.stop_fn_flag:
self.logger.save_data(
@@ -565,9 +567,9 @@ def policy_update_fn(
"""Perform one on-policy update."""
assert self.train_collector is not None
losses = self.policy.update(
0,
self.train_collector.buffer,
# Note: sample_size is 0, so the whole buffer is used for the update.
sample_size=0,
buffer=self.train_collector.buffer,
# Note: sample_size is None, so the whole buffer is used for the update.
# The kwargs are in the end passed to the .learn method, which uses
# batch_size to iterate through the buffer in mini-batches
# Off-policy algos typically don't use the batch_size kwarg at all
@@ -579,7 +581,9 @@
# TODO: remove the gradient step counting in trainers? Doesn't seem like
# it's important and it adds complexity
self.gradient_step += 1
if self.batch_size > 0:
if self.batch_size is None:
self.gradient_step += 1
elif self.batch_size > 0:
self.gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size)

# Note: this is the main difference to the off-policy trainer!
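
Two trainer-level consequences follow from the diff above: for offline RL the env_step accounting falls back to len(buffer) when batch_size is None, and the on-policy gradient-step counter treats a whole-buffer update as a single extra step. A standalone sketch of the counting (mirrors the excerpt; not the library code):

    def gradient_steps_per_update(buffer_len: int, batch_size: int | None) -> int:
        steps = 1                       # the unconditional increment after policy.update()
        if batch_size is None:
            steps += 1                  # one whole-buffer pass counts as one more step
        elif batch_size > 0:
            steps += int((buffer_len - 0.1) // batch_size)
        return steps

    print(gradient_steps_per_update(2048, None))  # 2
    print(gradient_steps_per_update(2048, 256))   # 8  (1 + (2048 - 0.1) // 256)
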
