-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
120 changed files
with
11,224 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Contributing to Keras-RL | ||
|
||
New contributors are very welcomed! If you're interested, please post a message on the [Gitter](https://gitter.im/keras-rl/Lobby). | ||
|
||
Here is a list of ways you can contribute to this repository: | ||
- Tackle an open issue on [Github](https://github.com/keras-rl/keras-rl/issues) | ||
- Improve documentation | ||
- Improve test coverage | ||
- Add examples | ||
- Implement new algorithms on Keras-RL (please get in touch on Gitter) | ||
- Link to your personal projects built on top of Keras-RL | ||
|
||
|
||
## How to run the tests | ||
|
||
To run the tests locally, you'll first have to install the following dependencies: | ||
```bash | ||
pip install pytest pytest-xdist pep8 pytest-pep8 pytest-cov python-coveralls | ||
``` | ||
You can then run all tests using this command: | ||
```bash | ||
py.test tests/. | ||
``` | ||
If you want to check if the files conform to the PEP8 style guidelines, run the following command: | ||
```bash | ||
py.test --pep8 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
Please make sure that the boxes below are checked before you submit your issue. If your issue is an implementation question, please ask your question in the [Keras-RL Google group](https://groups.google.com/forum/#!forum/keras-rl-users) or [join the Keras-RL Gitter channel](https://gitter.im/keras-rl/Lobby) and ask there instead of filing a GitHub issue. | ||
|
||
Thank you! | ||
|
||
- [ ] Check that you are up-to-date with the master branch of Keras-RL. You can update with: | ||
`pip install git+git://github.com/keras-rl/keras-rl.git --upgrade --no-deps` | ||
|
||
- [ ] Check that you are up-to-date with the master branch of Keras. You can update with: | ||
`pip install git+git://github.com/fchollet/keras.git --upgrade --no-deps` | ||
|
||
- [ ] Provide a link to a GitHub Gist of a Python script that can reproduce your issue (or just copy the script here if it is short). If you report an error, please include the error message and the backtrace. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
The MIT License (MIT) | ||
|
||
Copyright (c) 2016 Matthias Plappert | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,102 @@ | ||
# blockchain_networking_DRL | ||
# Deep Reinforcement Learning for Keras | ||
[data:image/s3,"s3://crabby-images/8ebe2/8ebe263e746516ab7eb4d9de38a59d7ed3f63bb2" alt="Build Status"](https://travis-ci.org/keras-rl/keras-rl) | ||
[data:image/s3,"s3://crabby-images/8b443/8b443a7371d6d0a25fc95fca5bf359b5ecaa28cc" alt="Documentation"](http://keras-rl.readthedocs.io/) | ||
[data:image/s3,"s3://crabby-images/0285e/0285e4221aa4d153cf6fc44904072d05aff0db53" alt="License"](https://github.com/keras-rl/keras-rl/blob/master/LICENSE) | ||
[data:image/s3,"s3://crabby-images/74df4/74df41cd453f94eef39ae318d252cfb57cd60ce4" alt="Join the chat at https://gitter.im/keras-rl/Lobby"](https://gitter.im/keras-rl/Lobby) | ||
|
||
|
||
<table> | ||
<tr> | ||
<td><img src="/assets/breakout.gif?raw=true" width="200"></td> | ||
<td><img src="/assets/cartpole.gif?raw=true" width="200"></td> | ||
<td><img src="/assets/pendulum.gif?raw=true" width="200"></td> | ||
</tr> | ||
</table> | ||
|
||
|
||
## What is it? | ||
|
||
`keras-rl` implements some state-of-the art deep reinforcement learning algorithms in Python and seamlessly integrates with the deep learning library [Keras](http://keras.io). | ||
|
||
Furthermore, `keras-rl` works with [OpenAI Gym](https://gym.openai.com/) out of the box. This means that evaluating and playing around with different algorithms is easy. | ||
|
||
Of course you can extend `keras-rl` according to your own needs. You can use built-in Keras callbacks and metrics or define your own. | ||
Even more so, it is easy to implement your own environments and even algorithms by simply extending some simple abstract classes. Documentation is available [online](http://keras-rl.readthedocs.org). | ||
|
||
|
||
## What is included? | ||
As of today, the following algorithms have been implemented: | ||
|
||
- [x] Deep Q Learning (DQN) [[1]](http://arxiv.org/abs/1312.5602), [[2]](https://www.nature.com/articles/nature14236) | ||
- [x] Double DQN [[3]](http://arxiv.org/abs/1509.06461) | ||
- [x] Deep Deterministic Policy Gradient (DDPG) [[4]](http://arxiv.org/abs/1509.02971) | ||
- [x] Continuous DQN (CDQN or NAF) [[6]](http://arxiv.org/abs/1603.00748) | ||
- [x] Cross-Entropy Method (CEM) [[7]](http://learning.mpi-sws.org/mlss2016/slides/2016-MLSS-RL.pdf), [[8]](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.81.6579&rep=rep1&type=pdf) | ||
- [x] Dueling network DQN (Dueling DQN) [[9]](https://arxiv.org/abs/1511.06581) | ||
- [x] Deep SARSA [[10]](http://people.inf.elte.hu/lorincz/Files/RL_2006/SuttonBook.pdf) | ||
- [ ] Asynchronous Advantage Actor-Critic (A3C) [[5]](http://arxiv.org/abs/1602.01783) | ||
- [ ] Proximal Policy Optimization Algorithms (PPO) [[11]](https://arxiv.org/abs/1707.06347) | ||
|
||
You can find more information on each agent in the [doc](http://keras-rl.readthedocs.io/en/latest/agents/overview/). | ||
|
||
|
||
## Installation | ||
|
||
- Install Keras-RL from Pypi (recommended): | ||
|
||
``` | ||
pip install keras-rl | ||
``` | ||
|
||
- Install from Github source: | ||
|
||
``` | ||
git clone https://github.com/keras-rl/keras-rl.git | ||
cd keras-rl | ||
python setup.py install | ||
``` | ||
|
||
## Examples | ||
|
||
If you want to run the examples, you'll also have to install: | ||
- **gym** by OpenAI: [Installation instruction](https://github.com/openai/gym#installation) | ||
- **h5py**: simply run `pip install h5py` | ||
|
||
Once you have installed everything, you can try out a simple example: | ||
```bash | ||
python examples/dqn_cartpole.py | ||
``` | ||
This is a very simple example and it should converge relatively quickly, so it's a great way to get started! | ||
It also visualizes the game during training, so you can watch it learn. How cool is that? | ||
|
||
Some sample weights are available on [keras-rl-weights](https://github.com/matthiasplappert/keras-rl-weights). | ||
|
||
If you have questions or problems, please file an issue or, even better, fix the problem yourself and submit a pull request! | ||
|
||
## Citing | ||
|
||
If you use `keras-rl` in your research, you can cite it as follows: | ||
```bibtex | ||
@misc{plappert2016kerasrl, | ||
author = {Matthias Plappert}, | ||
title = {keras-rl}, | ||
year = {2016}, | ||
publisher = {GitHub}, | ||
journal = {GitHub repository}, | ||
howpublished = {\url{https://github.com/keras-rl/keras-rl}}, | ||
} | ||
``` | ||
|
||
## References | ||
|
||
1. *Playing Atari with Deep Reinforcement Learning*, Mnih et al., 2013 | ||
2. *Human-level control through deep reinforcement learning*, Mnih et al., 2015 | ||
3. *Deep Reinforcement Learning with Double Q-learning*, van Hasselt et al., 2015 | ||
4. *Continuous control with deep reinforcement learning*, Lillicrap et al., 2015 | ||
5. *Asynchronous Methods for Deep Reinforcement Learning*, Mnih et al., 2016 | ||
6. *Continuous Deep Q-Learning with Model-based Acceleration*, Gu et al., 2016 | ||
7. *Learning Tetris Using the Noisy Cross-Entropy Method*, Szita et al., 2006 | ||
8. *Deep Reinforcement Learning (MLSS lecture notes)*, Schulman, 2016 | ||
9. *Dueling Network Architectures for Deep Reinforcement Learning*, Wang et al., 2016 | ||
10. *Reinforcement learning: An introduction*, Sutton and Barto, 2011 | ||
11. *Proximal Policy Optimization Algorithms*, Schulman et al., 2017 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from __future__ import absolute_import | ||
from .dqn import DQNAgent, NAFAgent, ContinuousDQNAgent | ||
from .ddpg import DDPGAgent | ||
from .cem import CEMAgent | ||
from .sarsa import SarsaAgent, SARSAAgent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
from __future__ import division | ||
from collections import deque | ||
from copy import deepcopy | ||
|
||
import numpy as np | ||
import keras.backend as K | ||
from keras.models import Model | ||
|
||
from rl.core import Agent | ||
from rl.util import * | ||
|
||
class CEMAgent(Agent): | ||
"""Write me | ||
""" | ||
def __init__(self, model, nb_actions, memory, batch_size=50, nb_steps_warmup=1000, | ||
train_interval=50, elite_frac=0.05, memory_interval=1, theta_init=None, | ||
noise_decay_const=0.0, noise_ampl=0.0, **kwargs): | ||
super(CEMAgent, self).__init__(**kwargs) | ||
|
||
# Parameters. | ||
self.nb_actions = nb_actions | ||
self.batch_size = batch_size | ||
self.elite_frac = elite_frac | ||
self.num_best = int(self.batch_size * self.elite_frac) | ||
self.nb_steps_warmup = nb_steps_warmup | ||
self.train_interval = train_interval | ||
self.memory_interval = memory_interval | ||
|
||
# if using noisy CEM, the minimum standard deviation will be ampl * exp (- decay_const * step ) | ||
self.noise_decay_const = noise_decay_const | ||
self.noise_ampl = noise_ampl | ||
|
||
# default initial mean & cov, override this by passing an theta_init argument | ||
self.init_mean = 0.0 | ||
self.init_stdev = 1.0 | ||
|
||
# Related objects. | ||
self.memory = memory | ||
self.model = model | ||
self.shapes = [w.shape for w in model.get_weights()] | ||
self.sizes = [w.size for w in model.get_weights()] | ||
self.num_weights = sum(self.sizes) | ||
|
||
# store the best result seen during training, as a tuple (reward, flat_weights) | ||
self.best_seen = (-np.inf, np.zeros(self.num_weights)) | ||
|
||
self.theta = np.zeros(self.num_weights*2) | ||
self.update_theta(theta_init) | ||
|
||
# State. | ||
self.episode = 0 | ||
self.compiled = False | ||
self.reset_states() | ||
|
||
def compile(self): | ||
self.model.compile(optimizer='sgd', loss='mse') | ||
self.compiled = True | ||
|
||
def load_weights(self, filepath): | ||
self.model.load_weights(filepath) | ||
|
||
def save_weights(self, filepath, overwrite=False): | ||
self.model.save_weights(filepath, overwrite=overwrite) | ||
|
||
def get_weights_flat(self,weights): | ||
weights_flat = np.zeros(self.num_weights) | ||
|
||
pos = 0 | ||
for i_layer, size in enumerate(self.sizes): | ||
weights_flat[pos:pos+size] = weights[i_layer].flatten() | ||
pos += size | ||
return weights_flat | ||
|
||
def get_weights_list(self,weights_flat): | ||
weights = [] | ||
pos = 0 | ||
for i_layer, size in enumerate(self.sizes): | ||
arr = weights_flat[pos:pos+size].reshape(self.shapes[i_layer]) | ||
weights.append(arr) | ||
pos += size | ||
return weights | ||
|
||
def reset_states(self): | ||
self.recent_observation = None | ||
self.recent_action = None | ||
|
||
def select_action(self, state, stochastic=False): | ||
batch = np.array([state]) | ||
if self.processor is not None: | ||
batch = self.processor.process_state_batch(batch) | ||
|
||
action = self.model.predict_on_batch(batch).flatten() | ||
if stochastic or self.training: | ||
return np.random.choice(np.arange(self.nb_actions), p=np.exp(action) / np.sum(np.exp(action))) | ||
return np.argmax(action) | ||
|
||
def update_theta(self,theta): | ||
if (theta is not None): | ||
assert theta.shape == self.theta.shape, "Invalid theta, shape is {0} but should be {1}".format(theta.shape,self.theta.shape) | ||
assert (not np.isnan(theta).any()), "Invalid theta, NaN encountered" | ||
assert (theta[self.num_weights:] >= 0.).all(), "Invalid theta, standard deviations must be nonnegative" | ||
self.theta = theta | ||
else: | ||
means = np.ones(self.num_weights) * self.init_mean | ||
stdevs = np.ones(self.num_weights) * self.init_stdev | ||
self.theta = np.hstack((means,stdevs)) | ||
|
||
def choose_weights(self): | ||
mean = self.theta[:self.num_weights] | ||
std = self.theta[self.num_weights:] | ||
weights_flat = std * np.random.randn(self.num_weights) + mean | ||
|
||
sampled_weights = self.get_weights_list(weights_flat) | ||
self.model.set_weights(sampled_weights) | ||
|
||
def forward(self, observation): | ||
# Select an action. | ||
state = self.memory.get_recent_state(observation) | ||
action = self.select_action(state) | ||
|
||
# Book-keeping. | ||
self.recent_observation = observation | ||
self.recent_action = action | ||
|
||
return action | ||
|
||
@property | ||
def layers(self): | ||
return self.model.layers[:] | ||
|
||
def backward(self, reward, terminal): | ||
# Store most recent experience in memory. | ||
if self.step % self.memory_interval == 0: | ||
self.memory.append(self.recent_observation, self.recent_action, reward, terminal, | ||
training=self.training) | ||
|
||
metrics = [np.nan for _ in self.metrics_names] | ||
if not self.training: | ||
# We're done here. No need to update the experience memory since we only use the working | ||
# memory to obtain the state over the most recent observations. | ||
return metrics | ||
|
||
if terminal: | ||
params = self.get_weights_flat(self.model.get_weights()) | ||
self.memory.finalize_episode(params) | ||
|
||
if self.step > self.nb_steps_warmup and self.episode % self.train_interval == 0: | ||
params, reward_totals = self.memory.sample(self.batch_size) | ||
best_idx = np.argsort(np.array(reward_totals))[-self.num_best:] | ||
best = np.vstack([params[i] for i in best_idx]) | ||
|
||
if reward_totals[best_idx[-1]] > self.best_seen[0]: | ||
self.best_seen = (reward_totals[best_idx[-1]], params[best_idx[-1]]) | ||
|
||
metrics = [np.mean(np.array(reward_totals)[best_idx])] | ||
if self.processor is not None: | ||
metrics += self.processor.metrics | ||
min_std = self.noise_ampl * np.exp(-self.step * self.noise_decay_const) | ||
|
||
mean = np.mean(best, axis=0) | ||
std = np.std(best, axis=0) + min_std | ||
new_theta = np.hstack((mean, std)) | ||
self.update_theta(new_theta) | ||
self.choose_weights() | ||
self.episode += 1 | ||
return metrics | ||
|
||
def _on_train_end(self): | ||
self.model.set_weights(self.get_weights_list(self.best_seen[1])) | ||
|
||
@property | ||
def metrics_names(self): | ||
names = ['mean_best_reward'] | ||
if self.processor is not None: | ||
names += self.processor.metrics_names[:] | ||
return names |
Oops, something went wrong.