-
Notifications
You must be signed in to change notification settings - Fork 0
/
ReplayBuffer.py
32 lines (28 loc) · 1.24 KB
/
ReplayBuffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import numpy as np
class ReplayBuffer(object):
    """Fixed-size circular experience replay buffer for off-policy RL.

    Transitions (s, a, r, s_, dw) are kept in pre-allocated numpy arrays
    and served as uniformly sampled torch.float mini-batches.
    """

    def __init__(self, state_dim, action_dim, args):
        """Pre-allocate storage for `args.buffer_size` transitions.

        Args:
            state_dim:  dimensionality of state vectors.
            action_dim: dimensionality of action vectors.
            args:       config object exposing `buffer_size` (capacity).
        """
        self.max_size = args.buffer_size
        self.count = 0  # next write index; wraps around at max_size
        self.size = 0   # number of valid transitions currently stored
        # float32 matches the torch.float tensors sample() returns,
        # halving memory vs. the numpy default float64 and avoiding a
        # dtype conversion on every sample() call.
        self.s = np.zeros((self.max_size, state_dim), dtype=np.float32)
        self.a = np.zeros((self.max_size, action_dim), dtype=np.float32)
        self.r = np.zeros((self.max_size, 1), dtype=np.float32)
        self.s_ = np.zeros((self.max_size, state_dim), dtype=np.float32)
        self.dw = np.zeros((self.max_size, 1), dtype=np.float32)

    def store(self, s, a, r, s_, dw):
        """Insert one transition, overwriting the oldest when full.

        Args:
            s:  current state.
            a:  action taken.
            r:  scalar reward.
            s_: next state.
            dw: termination flag — presumably 1.0 when the episode
                ended ("dead or win"), else 0.0; confirm with caller.
        """
        self.s[self.count] = s
        self.a[self.count] = a
        self.r[self.count] = r
        self.s_[self.count] = s_
        self.dw[self.count] = dw
        self.count = (self.count + 1) % self.max_size  # circular write pointer
        self.size = min(self.size + 1, self.max_size)  # saturates at capacity

    def sample(self, batch_size):
        """Return a uniform random mini-batch as five torch.float tensors.

        Sampling is with replacement, so `batch_size` may exceed
        `self.size`. Requires at least one stored transition.

        Returns:
            (batch_s, batch_a, batch_r, batch_s_, batch_dw)
        """
        index = np.random.choice(self.size, size=batch_size)  # Randomly sampling
        # Fancy indexing copies the rows, so from_numpy does not alias
        # the live buffer; it also skips the extra copy torch.tensor makes.
        batch_s = torch.from_numpy(self.s[index])
        batch_a = torch.from_numpy(self.a[index])
        batch_r = torch.from_numpy(self.r[index])
        batch_s_ = torch.from_numpy(self.s_[index])
        batch_dw = torch.from_numpy(self.dw[index])
        return batch_s, batch_a, batch_r, batch_s_, batch_dw