sampling_utils.py
import numpy as np
import rllab.misc.logger as logger


class SimpleReplayPool(object):
    """
    Fixed-size FIFO replay pool storing observations, actions, rewards, and
    terminal/initial flags for off-policy training.

    Adapted from
    https://raw.githubusercontent.com/shaneshixiang/rllabplusplus/master/rllab/pool/simple_pool.py
    """
    def __init__(
            self, max_pool_size, observation_dim, action_dim,
            replacement_policy='stochastic', replacement_prob=1.0,
            max_skip_episode=10):
        self._observation_dim = observation_dim
        self._action_dim = action_dim
        self._max_pool_size = max_pool_size
        self._replacement_policy = replacement_policy
        self._replacement_prob = replacement_prob
        self._max_skip_episode = max_skip_episode
        self._observations = np.zeros(
            (max_pool_size, observation_dim),
        )
        self._actions = np.zeros(
            (max_pool_size, action_dim),
        )
        self._rewards = np.zeros(max_pool_size)
        self._terminals = np.zeros(max_pool_size, dtype='uint8')
        self._initials = np.zeros(max_pool_size, dtype='uint8')
        # Ring-buffer bookkeeping: _bottom is the index of the oldest sample,
        # _top is the next write position, _size is the current fill level.
        self._bottom = 0
        self._top = 0
        self._size = 0

    def add_sample(self, observation, action, reward, terminal, initial):
        # Possibly skip ahead (episode-wise) before overwriting old data, then
        # write the new transition at _top.
        self.check_replacement()
        self._observations[self._top] = observation
        self._actions[self._top] = action
        self._rewards[self._top] = reward
        self._terminals[self._top] = terminal
        self._initials[self._top] = initial
        self.advance()

    def check_replacement(self):
        # When stochastic replacement is enabled (replacement_prob < 1) and the
        # pool is full with _top sitting at the start of an episode, decide
        # whether to skip overwriting that episode.
        if self._replacement_prob < 1.0:
            if self._size < self._max_pool_size or \
                    not self._initials[self._top]:
                return
            self.advance_until_terminate()

    def get_skip_flag(self):
        # Decide whether the episode about to be overwritten should be kept
        # (skipped) rather than replaced.
        if self._replacement_policy == 'full':
            skip = False
        elif self._replacement_policy == 'stochastic':
            skip = np.random.uniform() > self._replacement_prob
        else:
            raise NotImplementedError
        return skip

    def advance_until_terminate(self):
        # Advance _top past whole episodes while the skip flag says to keep
        # them, up to _max_skip_episode episodes.
        skip = self.get_skip_flag()
        n_skips = 0
        old_top = self._top
        new_top = (old_top + 1) % self._max_pool_size
        while skip and old_top != new_top and n_skips < self._max_skip_episode:
            n_skips += 1
            self.advance()
            while not self._initials[self._top]:
                self.advance()
            skip = self.get_skip_flag()
            new_top = self._top
        logger.log("add_sample, skipped %d episodes, top=%d->%d" % (
            n_skips, old_top, new_top))

    def advance(self):
        # Move the write pointer forward; once the pool is full, the oldest
        # sample (_bottom) is dropped instead of growing _size.
        self._top = (self._top + 1) % self._max_pool_size
        if self._size >= self._max_pool_size:
            self._bottom = (self._bottom + 1) % self._max_pool_size
        else:
            self._size += 1

    def random_batch(self, batch_size):
        # Sample transitions (s, a, r, terminal, initial, s') uniformly from
        # the pool, rejecting indices whose successor entry is invalid.
        assert self._size > batch_size
        indices = np.zeros(batch_size, dtype='uint64')
        transition_indices = np.zeros(batch_size, dtype='uint64')
        count = 0
        while count < batch_size:
            index = np.random.randint(
                self._bottom, self._bottom + self._size) % self._max_pool_size
            # Make sure that the transition is valid: if we are at the end of
            # the pool, we need to discard this sample.
            if index == self._size - 1 and self._size <= self._max_pool_size:
                continue
            # if self._terminals[index]:
            #     continue
            transition_index = (index + 1) % self._max_pool_size
            # Make sure that the transition is valid: discard it if it crosses
            # a horizon-triggered (non-terminal) reset.
            if not self._terminals[index] and self._initials[transition_index]:
                continue
            indices[count] = index
            transition_indices[count] = transition_index
            count += 1
        return dict(
            observations=self._observations[indices],
            actions=self._actions[indices],
            rewards=self._rewards[indices],
            terminals=self._terminals[indices],
            initials=self._initials[indices],
            next_observations=self._observations[transition_indices],
        )

    @property
    def size(self):
        return self._size
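

# Usage sketch (illustrative, not part of the original module): a minimal
# example of the add_sample / random_batch cycle using synthetic data. The
# pool size, dimensions, and episode length below are arbitrary assumptions
# chosen for demonstration; running it requires rllab to be importable for
# the logger used above.
if __name__ == '__main__':
    pool = SimpleReplayPool(
        max_pool_size=1000, observation_dim=4, action_dim=2)
    horizon = 50
    observation = np.random.randn(4)
    for step in range(200):
        action = np.random.randn(2)
        reward = float(np.random.randn())
        terminal = (step % horizon) == horizon - 1  # episode end every `horizon` steps
        initial = (step % horizon) == 0             # episode start flag
        pool.add_sample(observation, action, reward, terminal, initial)
        observation = np.random.randn(4)
    batch = pool.random_batch(32)
    print({key: value.shape for key, value in batch.items()})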