[BUG] Actor updates too frequently in SAC with Policy Update Delay #1154

Open
LeoHink opened this issue Jan 13, 2025 · 1 comment
Labels
bug Something isn't working

Comments

LeoHink commented Jan 13, 2025

Describe the bug

The policy update delay in the SAC implementations can cause unexpected behaviour when num_envs * rollout_length and policy_update_delay share common factors. Because the step counter advances with t += num_envs * rollout_length, the product being a multiple of policy_update_delay makes t % policy_update_delay == 0 at every update step, rather than once every policy_update_delay update steps as intended.
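
A quick way to see this in plain Python (the hyperparameter values below are illustrative, not necessarily the defaults):

    # t advances by num_envs * rollout_length each training step; since that
    # product is a multiple of policy_update_delay, the modulo check passes
    # on every single step.
    num_envs, rollout_length, policy_update_delay = 16, 8, 4

    t = 0
    for step in range(4):
        t += num_envs * rollout_length
        print(t, t % policy_update_delay == 0)  # 128 True, 256 True, 384 True, 512 True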

To Reproduce

Steps to reproduce the behavior:

  1. Run a SAC implementation (e.g., masac) with default hyperparameters.
  2. Add debug statements, such as jax.debug.print("Q Update") in update_q and jax.debug.print("Actor Update") in update_actor_and_alpha (see the sketch after these steps).
  3. Let the program run for a while, then observe the printed output or track the gradient update steps for Q-functions and Actor updates.
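
For illustration, the debug prints would sit at the top of the respective update functions. These are bare stand-ins, not the actual signatures or bodies in the repo:

    import jax

    # Stand-ins for the repo's update functions, shown only to indicate where
    # the debug prints go; the real signatures and bodies are more involved.
    def update_q(params, opt_states, data, key):
        jax.debug.print("Q Update")
        return params, opt_states, {"q_loss": 0.0}

    def update_actor_and_alpha(params, opt_states, data, key):
        jax.debug.print("Actor Update")
        return params, opt_states, {"actor_loss": 0.0, "alpha_loss": 0.0}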

Expected behavior

We'd expect "Q Update" to be printed policy_update_delay times before "Actor Update" is printed policy_update_delay times, e.g. with policy_update_delay = 4:

Q Update
Q Update
Q Update
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
Q Update
...

But with the default hyperparameter settings, where num_envs * rollout_length and policy_update_delay share factors, we actually observe:

Q Update
Actor Update
Actor Update
Actor Update
Actor Update
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
...

And if we keep track of the gradient updates, we see that the actor is updated 4 times (policy_update_delay = 4) as often as the Q-networks, which I don't think is the intended behavior.

Possible Solution

Use a separate counter that specifically counts update steps (e.g. update_t rather than t) and is incremented by 1 every update step, then use it in the train function, e.g.:

    params, opt_states, act_loss_info = lax.cond(
        update_t % cfg.system.policy_update_delay == 0,  # TD3 delayed update support
        update_actor_and_alpha,
        # otherwise return the same params and opt_states, and 0 for the losses
        lambda params, opt_states, *_: (
            params,
            opt_states,
            {"actor_loss": 0.0, "alpha_loss": 0.0},
        ),
        params,
        opt_states,
        data,
        actor_key,
    )
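
For completeness, here is a minimal, self-contained sketch of gating the actor update on such a dedicated counter; the parameter shapes, function bodies and delay value are illustrative, not Mava's actual code. Running it prints "Q Update" on every step but "Actor Update" only on every policy_update_delay-th step:

    import jax
    import jax.numpy as jnp
    from jax import lax

    policy_update_delay = 4  # illustrative value

    def update_actor_and_alpha(params, update_t):
        jax.debug.print("Actor Update")
        return params

    def update_step(carry, _):
        params, update_t = carry
        jax.debug.print("Q Update")  # the critics are updated on every step
        # gate the actor update on the dedicated update counter
        params = lax.cond(
            update_t % policy_update_delay == 0,
            update_actor_and_alpha,
            lambda params, *_: params,
            params,
            update_t,
        )
        return (params, update_t + 1), None  # increment once per update step

    (final_params, _), _ = lax.scan(
        update_step, (jnp.zeros(3), jnp.array(0)), None, length=8
    )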
LeoHink added the bug label Jan 13, 2025

sash-a commented Jan 14, 2025

Hi @LeoHink great find, this is indeed a bug! Would you be able to put up a PR for this? Your solution looks good to me 😄
