Hello, I wanted to use `ppo_rnn.py` and encountered an error when running the algorithm. The error comes from the input arguments of the `initialize_carry` function, which creates the carry for the `GRUCell`.
I think this is caused by an update to the Flax RNN API:

- the arguments of `initialize_carry` are now `(rng, input_shape)` instead of `(rng, batch_size, hidden_size)`;
- RNN cells (`GRUCell` here) now also require the number of features, passed as the `features` attribute when the cell is constructed.
Code to reproduce the error:

Error message:
If this is indeed the cause, would you like me to open a PR to fix it?