NoisyNets implementation issues #189
Comments
I also realized that this might be an issue. If we want to resample noise, we should either explicitly pass in a new rng every time or use self.make_rng to ensure that RNGs are split correctly.
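A minimal sketch of the self.make_rng approach (an illustrative module, not Dopamine's code): the caller supplies a 'noise' rng stream at apply() time, and different streams yield different noise samples:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class NoisyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        w = self.param('kernel', nn.initializers.lecun_normal(),
                       (x.shape[-1], self.features))
        # make_rng pulls a key from the 'noise' stream supplied at apply().
        noise = jax.random.normal(self.make_rng('noise'), w.shape)
        return x @ (w + 0.1 * noise)

x = jnp.ones((1, 3))
module = NoisyDense(features=4)
params = module.init({'params': jax.random.PRNGKey(0),
                      'noise': jax.random.PRNGKey(1)}, x)
# Passing a different 'noise' key on each apply() resamples the noise.
y1 = module.apply(params, x, rngs={'noise': jax.random.PRNGKey(2)})
y2 = module.apply(params, x, rngs={'noise': jax.random.PRNGKey(3)})
assert not jnp.allclose(y1, y2)
```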
Flax linen module variables are not able to be updated, so the only way to get "new" random noise is to pass the PRNG in as a parameter, like I have done in my example code.
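To illustrate the constraint being described (a hypothetical layer, not Dopamine's code): linen modules are frozen dataclasses, so a key stored as an attribute cannot be advanced in place:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class FixedKeyLayer(nn.Module):
    rng_key: jax.Array  # fixed at construction time

    @nn.compact
    def __call__(self, x):
        # self.rng_key = jax.random.split(self.rng_key)[0]  # would raise:
        # linen modules are frozen dataclasses, so attributes cannot be
        # reassigned, and noise drawn from self.rng_key never changes.
        return x + jax.random.normal(self.rng_key, x.shape)

layer = FixedKeyLayer(rng_key=jax.random.PRNGKey(0))
params = layer.init(jax.random.PRNGKey(1), jnp.ones((2,)))
# Every apply() adds the exact same "noise" vector.
```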
Edit: I understood the original comment incorrectly -- it was pointing out the correlated noise in lines 316 & 317. It's unclear how much impact this has on performance, but I will fix it -- thanks for pointing it out! Also, this should fix it:
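A guess at the shape of such a fix (sample_noise here is a stand-in for the helper in networks.py, whose exact signature is assumed): split the incoming key so the two factors come from independent streams:

```python
import jax

def sample_noise(key, shape):
    # Stand-in for the noise sampler used on lines 316-317.
    return jax.random.normal(key, shape)

rng_key = jax.random.PRNGKey(0)
# Before: p and q were both drawn from rng_key, making them correlated.
key_p, key_q = jax.random.split(rng_key)
p = sample_noise(key_p, (4, 1))  # cf. line 316
q = sample_noise(key_q, (1, 8))  # cf. line 317
```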
I am not sure this is a bug -- as @young-geng mentioned, if we want to resample noise, then we need to pass an explicit rng every time, as is done in the FullRainbowNetwork. Here's a simplified example to verify that explicitly passing an rng works:
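Something along these lines (a stand-in layer with the attribute-style interface discussed below, not the actual NoisyNetwork class): rebuilding the module with a fresh key, as FullRainbowNetwork does, changes the noise even though the params are unchanged:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class NoisyDense(nn.Module):
    features: int
    rng_key: jax.Array
    eval_mode: bool = False

    @nn.compact
    def __call__(self, x):
        w = self.param('kernel', nn.initializers.lecun_normal(),
                       (x.shape[-1], self.features))
        if self.eval_mode:
            return x @ w
        noise = jax.random.normal(self.rng_key, w.shape)
        return x @ (w + 0.1 * noise)

x = jnp.ones((1, 3))
net1 = NoisyDense(features=4, rng_key=jax.random.PRNGKey(0))
net2 = NoisyDense(features=4, rng_key=jax.random.PRNGKey(1))
params = net1.init(jax.random.PRNGKey(42), x)
# Same params, different rng_key attribute -> different noise samples.
assert not jnp.allclose(net1.apply(params, x), net2.apply(params, x))
```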
hi, thanks for raising this! i agree with what rishabh pointed out. i believe once the rngs used for […]
@agarwl Thanks, I hadn't spotted that the FullRainbowNetwork implementation passes a new rng key to the noisy network each time, so you are correct. With the modification that you propose, the noisy network works as expected. But since eval_mode and rng_key are attributes of the network, the interface is potentially misleading, as these are really values that need to be passed to the call function every time; conversely, features, use_bias and kernel_init should not be modified after init. @psc-g I may be wrong, but I think a new rng should be passed every time (when eval_mode = False): if new noise is not added on each call, then all that is happening is that a fixed linear transformation is applied to the inputs.
Now I see that it passes in a new RNG key every time, so I believe I was wrong about the noise not being resampled, and the implementation should be correct. Sorry for the confusion.
Original issue:

I'm implementing my own RL framework in JAX to better understand RL algorithms, and I found your code very helpful.
Looking at the NoisyNets implementation, on lines 316 and 317 of https://github.com/google/dopamine/blob/master/dopamine/jax/networks.py, the same rng_key is used each time noise is generated, meaning that no 'new' noise is generated each time an input is passed to the layer. In effect, I think the layer just applies a fixed linear transform.
This is a short testing example:
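The original snippet was not preserved; here is a reconstruction of the kind of test described, using a stand-in layer that reuses one stored key the way the code in question does:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class FixedKeyNoisyDense(nn.Module):
    features: int
    rng_key: jax.Array

    @nn.compact
    def __call__(self, x):
        w = self.param('kernel', nn.initializers.lecun_normal(),
                       (x.shape[-1], self.features))
        noise = jax.random.normal(self.rng_key, w.shape)  # same key every call
        return x @ (w + 0.1 * noise)

x = jnp.ones((1, 3))
net = FixedKeyNoisyDense(features=4, rng_key=jax.random.PRNGKey(0))
params = net.init(jax.random.PRNGKey(42), x)
# Two passes give identical outputs: the "noise" is a fixed matrix,
# so the layer is just a deterministic linear transform.
assert jnp.allclose(net.apply(params, x), net.apply(params, x))
```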
If this is an issue, then this is the code I implemented for my framework:
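The original code block was not preserved either; here is a sketch of the pattern being described, assuming factored NoisyNet noise (Fortunato et al.) with the rng passed explicitly on every call:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class NoisyDense(nn.Module):
    # Structural hyperparameters: fixed after init.
    features: int
    use_bias: bool = True

    @staticmethod
    def _f(x):
        # Factored-noise transform from the NoisyNets paper.
        return jnp.sign(x) * jnp.sqrt(jnp.abs(x))

    @nn.compact
    def __call__(self, x, rng_key, eval_mode=False):
        in_dim = x.shape[-1]
        w_mu = self.param('w_mu', nn.initializers.lecun_normal(),
                          (in_dim, self.features))
        w_sigma = self.param('w_sigma',
                             nn.initializers.constant(0.5 / jnp.sqrt(in_dim)),
                             (in_dim, self.features))
        if eval_mode:
            w = w_mu  # deterministic weights at evaluation time
        else:
            key_p, key_q = jax.random.split(rng_key)
            p = self._f(jax.random.normal(key_p, (in_dim, 1)))
            q = self._f(jax.random.normal(key_q, (1, self.features)))
            w = w_mu + w_sigma * (p * q)  # fresh factored noise per call
        out = x @ w
        if self.use_bias:
            # Noisy bias omitted for brevity.
            out = out + self.param('b_mu', nn.initializers.zeros,
                                   (self.features,))
        return out
```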
Here is some similar testing code:
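Again the original block is missing; a plausible reconstruction, continuing from the NoisyDense sketch above:

```python
x = jnp.ones((1, 3))
net = NoisyDense(features=4)
params = net.init(jax.random.PRNGKey(42), x, rng_key=jax.random.PRNGKey(0))
# Different keys per call -> noise is resampled each time.
y1 = net.apply(params, x, rng_key=jax.random.PRNGKey(1))
y2 = net.apply(params, x, rng_key=jax.random.PRNGKey(2))
assert not jnp.allclose(y1, y2)
# eval_mode disables the noise entirely, so outputs are deterministic.
e1 = net.apply(params, x, rng_key=jax.random.PRNGKey(3), eval_mode=True)
e2 = net.apply(params, x, rng_key=jax.random.PRNGKey(4), eval_mode=True)
assert jnp.allclose(e1, e2)
```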
I would have submitted this as a pull request, but I noticed that you are not accepting merges.