
[FR] Support for optax.contrib.reduce_on_plateau #1955

Open
zmbc opened this issue Jan 22, 2025 · 5 comments
Labels
enhancement New feature or request

Comments

@zmbc commented Jan 22, 2025

For SVI, the learning rate is extremely influential; see e.g. this discussion post: https://forum.pyro.ai/t/does-svi-converges-towards-the-right-solution-4-parameters-mvn/3677/4

The guidance there is to tune the learning rate by hand until you get convergence, but this is both expensive and awkward to do programmatically (e.g. when fitting many models for cross-validation).

Optax has a learning rate scheduler for exactly this, optax.contrib.reduce_on_plateau, which works well, but it isn't currently easy to use in NumPyro because its update function takes the current loss as an extra argument.
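
For reference, here is roughly what using it directly with optax looks like (a minimal sketch with toy params/grads, just for illustration). Note the value= keyword that a NumPyro wrapper has to thread through:

import jax.numpy as jnp
import optax
from optax.contrib import reduce_on_plateau

tx = optax.chain(optax.adam(1e-2), reduce_on_plateau(patience=5))
params = {"w": jnp.zeros(3)}
opt_state = tx.init(params)
grads = {"w": jnp.ones(3)}
loss = jnp.asarray(1.0)
# reduce_on_plateau needs the current loss at every step; the stock
# optax_to_numpyro wrapper never passes it, so it can't be used as-is.
updates, opt_state = tx.update(grads, opt_state, params, value=loss)
params = optax.apply_updates(params, updates)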

Here is some code that does it, based on slight modifications to optax_to_numpyro and _NumPyroOptim:

from collections.abc import Callable
from typing import Any

import jax.numpy as jnp
import optax
from jax import lax
from jax.flatten_util import ravel_pytree
from jax.typing import ArrayLike
from optax.contrib import reduce_on_plateau

from numpyro.optim import _IterOptState, _NumPyroOptim, _Params, _value_and_grad

class _NumPyroOptimValueArg(_NumPyroOptim):
    def update(self, g: _Params, state: _IterOptState, value: ArrayLike) -> _IterOptState:
        """
        Gradient update for the optimizer.

        :param g: gradient information for parameters.
        :param state: current optimizer state.
        :param value: current value of the objective function (the loss),
            forwarded to the optax transformation as an extra argument.
        :return: new optimizer state after the update.
        """
        i, opt_state = state
        opt_state = self.update_fn(i, g, opt_state, value=value)
        return i + 1, opt_state

    def eval_and_update(
        self,
        fn: Callable[[Any], tuple],
        state: _IterOptState,
        forward_mode_differentiation: bool = False,
    ) -> tuple[tuple[Any, Any], _IterOptState]:
        """
        Performs an optimization step for the objective function `fn`.
        For most optimizers, the update is performed based on the gradient
        of the objective function w.r.t. the current state. However, for
        some optimizers such as :class:`Minimize`, the update is performed
        by reevaluating the function multiple times to get optimal
        parameters.

        :param fn: an objective function returning a pair where the first item
            is a scalar loss function to be differentiated and the second item
            is an auxiliary output.
        :param state: current optimizer state.
        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
        :return: a pair of the output of objective function and the new optimizer state.
        """
        params: _Params = self.get_params(state)
        (out, aux), grads = _value_and_grad(
            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
        )
        return (out, aux), self.update(grads, state, value=out)

    def eval_and_stable_update(
        self,
        fn: Callable[[Any], tuple],
        state: _IterOptState,
        forward_mode_differentiation: bool = False,
    ) -> tuple[tuple[Any, Any], _IterOptState]:
        """
        Like :meth:`eval_and_update` but when the value of the objective function
        or the gradients are not finite, we will not update the input `state`
        and will set the objective output to `nan`.

        :param fn: objective function.
        :param state: current optimizer state.
        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
        :return: a pair of the output of objective function and the new optimizer state.
        """
        params: _Params = self.get_params(state)
        (out, aux), grads = _value_and_grad(
            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
        )
        out, state = lax.cond(
            jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
            lambda _: (out, self.update(grads, state, value=out)),
            lambda _: (jnp.nan, state),
            None,
        )
        return (out, aux), state

def optax_to_numpyro_value_arg(transformation) -> _NumPyroOptimValueArg:
    """
    This function produces a ``numpyro.optim._NumPyroOptim`` instance from an
    ``optax.GradientTransformation`` so that it can be used with
    ``numpyro.infer.svi.SVI``. It is a lightweight wrapper that recreates the
    ``(init_fn, update_fn, get_params_fn)`` interface defined by
    :mod:`jax.example_libraries.optimizers`.

    :param transformation: An ``optax.GradientTransformation`` instance to wrap.
    :return: An instance of ``numpyro.optim._NumPyroOptim`` wrapping the supplied
        Optax optimizer.
    """

    def init_fn(params: _Params) -> tuple[_Params, Any]:
        opt_state = transformation.init(params)
        return params, opt_state

    def update_fn(
        step: ArrayLike, grads: ArrayLike, state: tuple[_Params, Any], value: ArrayLike
    ) -> tuple[_Params, Any]:
        params, opt_state = state
        updates, opt_state = transformation.update(grads, opt_state, params, value=value)
        updated_params = optax.apply_updates(params, updates)
        return updated_params, opt_state

    def get_params_fn(state: tuple[_Params, Any]) -> _Params:
        params, _ = state
        return params

    return _NumPyroOptimValueArg(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)
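
The lambda x, y, z: (x, y, z) mirrors the stock optax_to_numpyro: _NumPyroOptim expects a factory that returns the (init_fn, update_fn, get_params_fn) triple, so an identity function just passes the three closures through.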

Then you can run e.g. the SVI example from the docs with:

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import Predictive, SVI, Trace_ELBO

def model(data):
    f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
    with numpyro.plate("N", data.shape[0] if data is not None else 10):
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

def guide(data):
    alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
    beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
                           constraint=constraints.positive)
    numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
optimizer = optax_to_numpyro_value_arg(optax.chain(
    optax.adam(0.01),
    reduce_on_plateau(
        cooldown=100, accumulation_size=100, patience=200,
    ),
))
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, data)
params = svi_result.params
inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
# use guide to make predictive
predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
samples = predictive(random.PRNGKey(1), data=None)
# get posterior samples
predictive = Predictive(guide, params=params, num_samples=1000)
posterior_samples = predictive(random.PRNGKey(1), data=None)
# use posterior samples to make predictive
predictive = Predictive(model, posterior_samples, params=params, num_samples=1000)
samples = predictive(random.PRNGKey(1), data=None)
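
As a sanity check that the scheduler actually fired, the current learning-rate scale can be read back out of the optimizer state. A hedged sketch, assuming the state nesting produced by the wrapper and chain above (the exact pytree layout is an implementation detail of optax's chain state):

# step count, then the (params, opt_state) pair from init_fn above;
# opt_state is the chain's per-transform state tuple.
step, (_, (adam_state, plateau_state)) = svi_result.state.optim_state
print("final LR scale:", plateau_state.scale)  # ReduceLROnPlateauState.scale
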
@fehiepsi added the "enhancement" label Jan 25, 2025
@fehiepsi (Member)

It is a good idea to support this feature. Do you want to submit a PR? It is fine to incorporate the logic directly into the NumPyro optim module. For the transform, I guess we can use https://optax.readthedocs.io/en/latest/api/transformations.html#optax.with_extra_args_support.
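
(optax.with_extra_args_support normalizes a plain GradientTransformation so that its update accepts, and ignores, extra keyword arguments; a minimal sketch of the behavior being referred to:)

import jax.numpy as jnp
import optax

# Wrapping a plain transform lets `value=` be passed unconditionally,
# even though sgd itself never uses it.
tx = optax.with_extra_args_support(optax.sgd(1e-2))
params = {"w": jnp.zeros(2)}
opt_state = tx.init(params)
updates, opt_state = tx.update(
    {"w": jnp.ones(2)}, opt_state, params, value=jnp.asarray(0.5)
)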

@juanitorduz (Contributor)

This would be fantastic to have!

@zmbc (Author) commented Jan 29, 2025

@fehiepsi I wasn't sure how you would want it to work. It isn't enough to just declare that an optimizer takes extra args; we also have to tell it what to pass into them. In my code above, for the plateau optimizer, it is the training loss that needs to be passed in. Should I make this a special case, or do I need to build a more general tool?

@fehiepsi (Member) commented Jan 29, 2025

I just meant that under that utility, all optax optimizers have the same signature, so we can always pass value to the update call.

The new NumPyro optimizer would be your _NumPyroOptimValueArg implementation. The new optax_to_numpyro utility would be your optax_to_numpyro_value_arg.

@OlaRonning (Member)

Agree with @juanitorduz, this would be awesome to have. It would also enable Newton methods and line search.
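
(Line search has the same shape of problem: optax's backtracking line search needs the current value, the gradient, and the objective itself passed as extra arguments at update time. A rough sketch, following the optax docs example for scale_by_backtracking_linesearch:)

import jax
import jax.numpy as jnp
import optax

def f(params):  # toy objective
    return jnp.sum(params ** 2)

tx = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
params = jnp.ones(3)
opt_state = tx.init(params)
value, grad = jax.value_and_grad(f)(params)
# The line search consumes value, grad, and value_fn as extra args,
# so a value-passing hook like the one proposed above generalizes here.
updates, opt_state = tx.update(
    grad, opt_state, params, value=value, grad=grad, value_fn=f
)
params = optax.apply_updates(params, updates)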
