* add decision transformer
* update README
* upgrade maximum column length
* increase line limit
* specify py3
* Update .pre-commit-config.yaml
* remove unrelated files
* Update README.md
* Update README.md
* Update README.md
* remove the argument mode
* fix bugs in data_loader
* fix bugs in data_loader
* address comments
Commit b081974 (parent 04439b5)
Showing 11 changed files with 1,438 additions and 0 deletions.
@@ -0,0 +1,46 @@

## Introduction
Based on PARL, we provide an implementation of Decision Transformer that matches the performance reported in the original paper.

> Paper: [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345)

### Dataset for RL
Follow the installation instructions in [D4RL](https://github.com/Farama-Foundation/D4RL) to install D4RL.
Then run the script in the `data` directory to download the datasets for training.
```shell
python download_d4rl_datasets.py
```
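To sanity-check a download, you can load one of the produced pickle files directly. This is a minimal sketch; it assumes the script above has already written `hopper-medium-v2.pkl` into the current directory (the file name follows the `{env}-{dataset}-v2.pkl` pattern used by the download script).
```python
import pickle

import numpy as np

# Each .pkl written by download_d4rl_datasets.py holds a list of trajectories;
# every trajectory is a dict of numpy arrays keyed by 'observations',
# 'next_observations', 'actions', 'rewards', 'terminals'.
with open('hopper-medium-v2.pkl', 'rb') as f:
    trajectories = pickle.load(f)

returns = np.array([traj['rewards'].sum() for traj in trajectories])
print(f'{len(trajectories)} trajectories, '
      f'mean return {returns.mean():.1f}, max return {returns.max():.1f}')
print('keys:', list(trajectories[0].keys()))
```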

### Benchmark result
#### 1. Mujoco results
<p align="center">
<img src="https://github.com/benchmarking-rl/PARL-experiments/blob/master/DT/torch/mujoco_result.png" alt="mujoco-result"/>
</p>

+ Each experiment was run three times with different random seeds.

## How to use
### Dependencies:
+ [D4RL](https://github.com/Farama-Foundation/D4RL)
+ [parl>=2.1](https://github.com/PaddlePaddle/PARL)
+ pytorch
+ gym==0.18.3
+ mujoco-py==2.0.2.13
+ transformers==4.5.1

### Training:

```shell
# To train an agent for the `hopper` environment with the `medium` dataset
python train.py --env hopper --dataset medium

# To train an agent for the `hopper` environment with the `expert` dataset
python train.py --env hopper --dataset expert
```

### Reference

https://github.com/kzl/decision-transformer
@@ -0,0 +1,43 @@

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import parl
import torch
import numpy as np
from data_loader import DataLoader


class DTAgent(parl.Agent):
    def __init__(self, algorithm, config):
        super(DTAgent, self).__init__(algorithm)
        self.dataset = None
        self.config = config

    def predict(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
        # Predict the next action conditioned on the history and the desired
        # return-to-go, and write it back into the running action buffer.
        action = self.alg.predict(states, actions, rewards, returns_to_go, timesteps)
        actions[-1] = action
        action = action.detach().cpu().numpy()
        return action

    def learn(self):
        # Sample a batch of sub-trajectories from the offline dataset and
        # take one gradient step on the underlying algorithm.
        batch_data = self.dataset.get_batch(self.config['batch_size'])
        loss = self.alg.learn(*batch_data)
        loss = loss.detach().cpu().item()
        return loss

    def load_data(self, dataset_path):
        # Build the offline dataset and return its state normalization stats.
        config = self.config
        self.dataset = DataLoader(dataset_path, config['pct_traj'], config['max_ep_len'],
                                  config['rew_scale'])
        return self.dataset.state_mean, self.dataset.state_std
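As a rough illustration of how this agent is driven, the sketch below wires `DTAgent` into a training loop. The `build_algorithm()` helper, the config values, and the dataset path are hypothetical placeholders for whatever `train.py` in this commit actually constructs; only the `DTAgent` methods and config keys come from the file above.
```python
# Hypothetical wiring; build_algorithm() stands in for however train.py
# creates the decision-transformer algorithm expected by parl.Agent.
config = {
    'batch_size': 64,     # sub-trajectories per gradient step (illustrative)
    'pct_traj': 1.0,      # fraction of top-return trajectories to train on
    'max_ep_len': 1000,   # episode length cap used for timestep clipping
    'rew_scale': 1000.,   # return-to-go scaling
}

agent = DTAgent(build_algorithm(), config)  # build_algorithm() is assumed
state_mean, state_std = agent.load_data('data/hopper-medium-v2.pkl')

for step in range(10000):
    loss = agent.learn()
    if step % 1000 == 0:
        print(f'step {step}, loss {loss:.4f}')
```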
@@ -0,0 +1,54 @@

# Third party code
# The following code is copied or modified from:
# https://github.com/kzl/decision-transformer/blob/master/gym/data/download_d4rl_datasets.py
import gym
import numpy as np

import collections
import pickle

import d4rl

datasets = []

for env_name in ['halfcheetah', 'hopper', 'walker2d']:
    for dataset_type in ['medium', 'medium-replay', 'expert']:
        name = f'{env_name}-{dataset_type}-v2'
        env = gym.make(name)
        dataset = env.get_dataset()

        N = dataset['rewards'].shape[0]
        data_ = collections.defaultdict(list)

        use_timeouts = False
        if 'timeouts' in dataset:
            use_timeouts = True

        episode_step = 0
        paths = []
        for i in range(N):
            done_bool = bool(dataset['terminals'][i])
            if use_timeouts:
                final_timestep = dataset['timeouts'][i]
            else:
                final_timestep = (episode_step == 1000 - 1)
            for k in ['observations', 'next_observations', 'actions', 'rewards', 'terminals']:
                data_[k].append(dataset[k][i])
            if done_bool or final_timestep:
                # End of an episode: pack the accumulated transitions into one
                # trajectory dict and start a fresh buffer.
                episode_step = 0
                episode_data = {}
                for k in data_:
                    episode_data[k] = np.array(data_[k])
                paths.append(episode_data)
                data_ = collections.defaultdict(list)
            episode_step += 1

        returns = np.array([np.sum(p['rewards']) for p in paths])
        num_samples = np.sum([p['rewards'].shape[0] for p in paths])
        print(f'Number of samples collected: {num_samples}')
        print(
            f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}'
        )

        with open(f'{name}.pkl', 'wb') as f:
            pickle.dump(paths, f)
@@ -0,0 +1,114 @@

# Third Party Code
# The following code is copied or modified from:
# https://github.com/kzl/decision-transformer/tree/master/gym/decision_transformer

import numpy as np
import pickle
import random
import torch
from parl.utils import logger


class DataLoader(object):
    def __init__(self, dataset_path, pct_traj, max_ep_len, scale):
        self.dataset_path = dataset_path
        self.pct_traj = pct_traj
        self.scale = scale
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.max_ep_len = max_ep_len
        with open(dataset_path, 'rb') as f:
            trajectories = pickle.load(f)

        states, traj_lens, returns = [], [], []
        for path in trajectories:
            states.append(path['observations'])
            traj_lens.append(len(path['observations']))
            returns.append(path['rewards'].sum())
        traj_lens, returns = np.array(traj_lens), np.array(returns)

        # used for input normalization
        states = np.concatenate(states, axis=0)
        self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

        num_timesteps = sum(traj_lens)

        # only train on top pct_traj trajectories (for %BC experiment)
        num_timesteps = max(int(pct_traj * num_timesteps), 1)
        sorted_inds = np.argsort(returns)  # lowest to highest
        num_trajectories = 1
        timesteps = traj_lens[sorted_inds[-1]]
        ind = len(trajectories) - 2
        while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
            timesteps += traj_lens[sorted_inds[ind]]
            num_trajectories += 1
            ind -= 1
        sorted_inds = sorted_inds[-num_trajectories:]

        # used to reweight sampling so we sample according to timesteps instead of trajectories
        p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])

        logger.info(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
        logger.info(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
        logger.info(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
        self.trajectories = trajectories
        self.num_trajectories = num_trajectories
        self.p_sample = p_sample
        self.sorted_inds = sorted_inds
        self.state_dim = trajectories[0]['observations'].shape[1]
        self.act_dim = trajectories[0]['actions'].shape[1]

    def get_batch(self, batch_size=256, max_len=20):
        batch_inds = np.random.choice(
            np.arange(self.num_trajectories),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )

        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
        for i in range(batch_size):
            traj = self.trajectories[int(self.sorted_inds[batch_inds[i]])]
            si = random.randint(0, traj['rewards'].shape[0] - 1)

            # get sequences from dataset
            s.append(traj['observations'][si:si + max_len].reshape(1, -1, self.state_dim))
            a.append(traj['actions'][si:si + max_len].reshape(1, -1, self.act_dim))
            r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1))
            if 'terminals' in traj:
                d.append(traj['terminals'][si:si + max_len].reshape(1, -1))
            else:
                d.append(traj['dones'][si:si + max_len].reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff
            rtg.append(self.discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1))
            if rtg[-1].shape[1] <= s[-1].shape[1]:
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # padding and state + reward normalization
            tlen = s[-1].shape[1]
            s[-1] = np.concatenate([np.zeros((1, max_len - tlen, self.state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - self.state_mean) / self.state_std
            a[-1] = np.concatenate([np.ones((1, max_len - tlen, self.act_dim)) * -10., a[-1]], axis=1)
            r[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1) / self.scale
            timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1))

        device = self.device
        s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
        a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
        r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device)
        d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device)
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)

        return s, a, r, d, rtg, timesteps, mask

    def discount_cumsum(self, x, gamma):
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum
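For reference, a minimal standalone sketch of driving this loader directly; the dataset path and the constructor arguments are illustrative values (they are not fixed by this file), and the pickle is assumed to come from `download_d4rl_datasets.py`.
```python
from data_loader import DataLoader

# Assumes data/hopper-medium-v2.pkl was produced by download_d4rl_datasets.py.
loader = DataLoader(
    dataset_path='data/hopper-medium-v2.pkl',
    pct_traj=1.0,     # keep all (i.e. the top 100% of) trajectories
    max_ep_len=1000,
    scale=1000.,      # divisor applied to the returns-to-go
)

s, a, r, d, rtg, timesteps, mask = loader.get_batch(batch_size=4, max_len=20)
print(s.shape)     # (4, 20, state_dim): left-padded, normalized states
print(a.shape)     # (4, 20, act_dim)
print(rtg.shape)   # (4, 21, 1): one extra return-to-go step beyond the window
print(mask.shape)  # (4, 20): 1 for real timesteps, 0 for padding
```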
@@ -0,0 +1,98 @@

# Third Party Code
# The following code is copied or modified from:
# https://github.com/kzl/decision-transformer/tree/master/gym/decision_transformer

import numpy as np
import torch
from tqdm import tqdm


def evaluate_episode_rtg(
        env,
        state_dim,
        act_dim,
        agent,
        max_ep_len=1000,
        scale=1000.,
        state_mean=0.,
        state_std=1.,
        device='cuda',
        target_return=None,
):

    state_mean = torch.from_numpy(state_mean).to(device=device)
    state_std = torch.from_numpy(state_std).to(device=device)

    state = env.reset()

    # we keep all the histories on the device
    # note that the latest action and reward will be "padding"
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    ep_return = target_return
    target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(1, 1)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    sim_states = []

    episode_return, episode_length = 0, 0
    for t in range(max_ep_len):

        # add padding
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = agent.predict(
            (states.to(dtype=torch.float32) - state_mean) / state_std,
            actions.to(dtype=torch.float32),
            rewards.to(dtype=torch.float32),
            target_return.to(dtype=torch.float32),
            timesteps.to(dtype=torch.long),
        )

        state, reward, done, _ = env.step(action)

        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        pred_return = target_return[0, -1] - (reward / scale)
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break

    return episode_return, episode_length


def eval_episodes(target_rew, env, state_dim, act_dim, agent, max_ep_len, scale, state_mean, state_std, device):
    returns, lengths = [], []
    num_eval_episodes = 100
    for _ in tqdm(range(num_eval_episodes), desc='eval'):
        with torch.no_grad():
            ret, length = evaluate_episode_rtg(
                env,
                state_dim,
                act_dim,
                agent,
                max_ep_len=max_ep_len,
                scale=scale,
                target_return=target_rew / scale,
                state_mean=state_mean,
                state_std=state_std,
                device=device,
            )
        returns.append(ret)
        lengths.append(length)
    return {
        f'target_{target_rew}_return_mean': np.mean(returns),
        f'target_{target_rew}_return_std': np.std(returns),
        f'target_{target_rew}_length_mean': np.mean(lengths),
        f'target_{target_rew}_length_std': np.std(lengths),
    }
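A rough sketch of how the evaluation helper might be called. The environment name, target return, and the `agent`, `state_mean`, and `state_std` values are placeholders (a trained agent such as `DTAgent` above and the stats returned by its `load_data`), not values fixed by this file.
```python
import gym
import torch

# Assumptions: `agent` is a trained agent whose predict() matches DTAgent,
# and state_mean / state_std are the arrays returned by DTAgent.load_data().
env = gym.make('Hopper-v3')        # illustrative environment choice
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

results = eval_episodes(
    target_rew=3600,               # illustrative target return
    env=env,
    state_dim=state_dim,
    act_dim=act_dim,
    agent=agent,                   # placeholder: trained agent
    max_ep_len=1000,
    scale=1000.,
    state_mean=state_mean,         # placeholder: from load_data
    state_std=state_std,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
print(results['target_3600_return_mean'])
```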