[FR] Support for optax.contrib.reduce_on_plateau #1955
Comments
It is a good idea to support this feature. Do you want to submit a PR? It is fine to incorporate the logic directly into the NumPyro Optim. For the transform, I guess we can use optax.with_extra_args_support: https://optax.readthedocs.io/en/latest/api/transformations.html#optax.with_extra_args_support
This would be fantastic to have!
@fehiepsi I wasn't sure how you would want it to work. It isn't enough to just say "this has extra args"; we also have to tell it what to pass into them. In my code above, for the plateau optimizer, it is the training loss that needs to be passed in. Should I make this a special case, or do I need to build a more general tool?
I just meant that under that utility, all optax optimizers have the same signature, so we can always pass the extra arguments. The new NumPyro optimizer would be your _NumPyroOptimValueArg implementation. The new optax_to_numpyro utility would be your optax_to_numpyro_value_args.
Agree with @juanitorduz, this would be awesome to have. It would also enable Newton methods and line search.
For SVI, learning rate is extremely influential, see e.g. this discussion post: https://forum.pyro.ai/t/does-svi-converges-towards-the-right-solution-4-parameters-mvn/3677/4
The guidance there is to just play around with learning rate until you get convergence, but this is both expensive and annoying to attempt programmatically (e.g. when fitting many models for cross-validation).
Optax contains a learning rate scheduler for this that works really well, but it isn't currently easy to use this in NumPyro because it takes the current loss as an extra argument.
Here is some code that does it, based on slight modifications to optax_to_numpyro and _NumPyroOptim:
Then you can run e.g. the SVI example from the docs with: