replay_buffer.py (forked from tesslerc/TD3-JAX)

from dataclasses import dataclass
from typing import Iterable

import numpy as np

import jax
import jax.numpy as jnp


@dataclass
class Sample:
    # Shorthand alias for the array type shared by all fields.
    T = np.ndarray

    obs: T
    action: T
    next_obs: T
    reward: T
    done: T


@dataclass
class Step(Sample):
    """A single environment transition; identical to Sample, kept for call-site clarity."""
    pass


class ReplayBuffer(object):
    """A simple container for maintaining the history of the agent."""

    def __init__(self, obs_shape: Iterable, action_shape: Iterable, max_size: int):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0

        # Pre-allocate storage for the full buffer capacity.
        self.obs = np.zeros((max_size, *obs_shape))
        self.action = np.zeros((max_size, *action_shape))
        self.next_obs = np.zeros((max_size, *obs_shape))
        self.reward = np.zeros((max_size, 1))
        self.done = np.zeros((max_size, 1))

    def add(self, sample: Sample) -> None:
        """Add a single transition; the buffer is built for per-transition interaction
        and does not handle batch updates."""
        self.obs[self.ptr] = sample.obs
        self.action[self.ptr] = sample.action
        self.next_obs[self.ptr] = sample.next_obs
        self.reward[self.ptr] = sample.reward
        self.done[self.ptr] = sample.done

        # Circular buffer: once full, the oldest transition is overwritten.
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size: int, rng: jnp.ndarray) -> Sample:
        """Given a JAX PRNG key, sample a batch of transitions from memory."""
        # Indices are drawn (with replacement) from the filled portion of the buffer.
        ind = jax.random.randint(rng, (batch_size,), 0, self.size)
        return Sample(
            self.obs[ind],
            self.action[ind],
            self.next_obs[ind],
            self.reward[ind],
            self.done[ind],
        )
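

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of the intended interaction loop, assuming a 3-dimensional
# observation space and a 1-dimensional action space; the transitions here are
# random placeholder data rather than output from a real environment.
if __name__ == "__main__":
    buffer = ReplayBuffer(obs_shape=(3,), action_shape=(1,), max_size=1000)

    # Fill the buffer one transition at a time, as add() expects.
    for _ in range(10):
        buffer.add(Step(
            obs=np.random.randn(3),
            action=np.random.randn(1),
            next_obs=np.random.randn(3),
            reward=np.random.randn(1),
            done=np.zeros(1),
        ))

    # Sampling consumes a JAX PRNG key; split the key so each draw is independent.
    key = jax.random.PRNGKey(0)
    key, sample_key = jax.random.split(key)
    batch = buffer.sample(batch_size=4, rng=sample_key)
    print(batch.obs.shape, batch.action.shape)  # (4, 3) (4, 1)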