
[Bug] Error in initialize carry function for ppo_rnn #29

Open
corentinlger opened this issue Aug 25, 2024 · 2 comments

Comments

@corentinlger

Hello, I wanted to use ppo_rnn.py and hit an error when running the algorithm. The error comes from the input arguments of the initialize_carry function that creates the carry for the GRUCell.

I think this is due to an update of the Flax RNN API:

  • initialize_carry is now an instance method whose arguments are (rng, input_shape), instead of the (rng, (batch_size,), hidden_size) used in ppo_rnn.py
  • Additionally, you now have to pass features (the hidden size) as an argument when creating RNN cells (GRUCell here)

Code to reproduce the error:

import jax 
from purejaxrl.ppo_rnn import make_train

config = {
    "LR": 2.5e-4,
    "NUM_ENVS": 4,
    "NUM_STEPS": 128,
    "TOTAL_TIMESTEPS": 5e5,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 4,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.01,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "tanh",
    "ENV_NAME": "CartPole-v1",
    "ANNEAL_LR": True,
}

rng = jax.random.PRNGKey(42)
train_jit = jax.jit(make_train(config))
out = train_jit(rng)

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 24
     22 rng = jax.random.PRNGKey(42)
     23 train_jit = jax.jit(make_train(config))
---> 24 out = train_jit(rng)

    [... skipping hidden 11 frame]

File ~/Desktop/code/purejaxrl/purejaxrl/ppo_rnn.py:121, in make_train.<locals>.train(rng)
    114 rng, _rng = jax.random.split(rng)
    115 init_x = (
    116     jnp.zeros(
    117         (1, config["NUM_ENVS"], *env.observation_space(env_params).shape)
    118     ),
    119     jnp.zeros((1, config["NUM_ENVS"])),
    120 )
--> 121 init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
    122 network_params = network.init(_rng, init_hstate, init_x)
    123 if config["ANNEAL_LR"]:

File ~/Desktop/code/purejaxrl/purejaxrl/ppo_rnn.py:41, in ScannedRNN.initialize_carry(batch_size, hidden_size)
     38 @staticmethod
     39 def initialize_carry(batch_size, hidden_size):
     40     # Use a dummy key since the default state init fn is just zeros.
---> 41     return nn.GRUCell.initialize_carry(
     42         jax.random.PRNGKey(0), (batch_size,), hidden_size
     43     )

File ~/Desktop/code/purejaxrl/venv/lib/python3.10/site-packages/flax/linen/recurrent.py:614, in GRUCell.initialize_carry(self, rng, input_shape)
    603 @nowrap
    604 def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]):
    605  """Initialize the RNN cell carry.
    606 
    607  Args:
   (...)
    612    An initialized carry for the given RNN cell.
    613  """
--> 614   batch_dims = input_shape[:-1]
    615   mem_shape = batch_dims + (self.features,)
    616   return self.carry_init(rng, mem_shape, self.param_dtype)

TypeError: 'int' object is not subscriptable
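The traceback is consistent with initialize_carry having become an instance method: when ppo_rnn.py still calls it on the class (line 41), Python binds the PRNG key to self, (batch_size,) to rng and the int hidden_size to input_shape, so input_shape[:-1] fails. The Flax-free sketch below reproduces only that argument binding; Cell is a hypothetical stand-in for GRUCell and the key is a placeholder string.

```python
class Cell:
    """Hypothetical stand-in for a cell whose initialize_carry is an instance method."""

    def __init__(self, features):
        self.features = features

    def initialize_carry(self, rng, input_shape):
        batch_dims = input_shape[:-1]  # same slicing as recurrent.py:614
        return batch_dims + (self.features,)


key, batch_size, hidden_size = "dummy-key", 4, 128

# Old-style call on the class: key -> self, (batch_size,) -> rng,
# hidden_size -> input_shape, so slicing an int raises TypeError.
message = None
try:
    Cell.initialize_carry(key, (batch_size,), hidden_size)
except TypeError as err:
    message = str(err)
print(message)  # 'int' object is not subscriptable

# Calling on an instance binds the arguments as intended.
shape = Cell(features=hidden_size).initialize_carry(key, (batch_size, hidden_size))
print(shape)  # (4, 128)
```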

If this is indeed the cause, would you like me to open a PR to fix it?

@smokbel

smokbel commented Sep 30, 2024

After fixing this one, another error occurs:

AttributeError: DynamicJaxprTracer has no attribute features

It comes from GRUCell.initialize_carry in Flax, which tries to access the features attribute on a traced object (i.e. self is bound to a tracer rather than to a GRUCell instance).

@corentinlger (author)

Yes, I think that's the second point I mentioned in the issue.

You could try this fixed version of the file (it worked a month ago): https://github.com/corentinlger/purejaxrl/blob/fix_ppo_rnn/purejaxrl/ppo_rnn.py
