Describe the bug
The policy update delay in the SAC implementations can cause unexpected behaviour when num_envs * rollout_length and policy_update_delay share factors. In particular, because we advance the counter with t += num_envs * rollout_length, whenever policy_update_delay divides num_envs * rollout_length the check t % policy_update_delay == 0 passes at every update step rather than once every policy_update_delay steps as intended.
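As a minimal sketch of the arithmetic (plain Python, with illustrative values rather than the actual defaults), the modulo check passes on every iteration as soon as the increment is a multiple of the delay:

    # Illustrative values only; not the actual default hyperparameters.
    num_envs, rollout_length = 16, 8
    policy_update_delay = 4

    t = 0
    for _ in range(5):
        t += num_envs * rollout_length           # how the SAC systems advance t
        print(t, t % policy_update_delay == 0)   # prints True every iteration, since 4 divides 128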
To Reproduce
Steps to reproduce the behavior:
1. Run a SAC implementation (e.g., masac) with the default hyperparameters.
2. Add debug statements, such as jax.debug.print("Q Update") in update_q and jax.debug.print("Actor Update") in update_actor_and_alpha (see the placement sketch after these steps).
3. Let the program run for a while, then observe the printed output or track the number of gradient updates applied to the Q-functions and the actor.
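For placement, a rough sketch with hypothetical stand-ins for the two update functions (the real signatures in the codebase differ; only the position of the debug prints matters):

    import jax

    # Hypothetical stand-ins: only the placement of jax.debug.print is relevant here.
    def update_q(params, opt_states, data, key):
        jax.debug.print("Q Update")      # fires on every critic update
        ...                              # existing critic-update logic
        return params, opt_states

    def update_actor_and_alpha(params, opt_states, data, key):
        jax.debug.print("Actor Update")  # fires whenever the delayed-update branch is taken
        ...                              # existing actor / temperature update logic
        return params, opt_states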
Expected behavior
We'd expect "Q Update" be printed policy_update_delay number of times before "Actor Update" is printed policy_update_delay time e.g, if policy_update_delay = 4:
Q Update
Q Update
Q Update
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
Q Update
...
But with the default hyperparameter settings, where num_envs * rollout_length and policy_update_delay share factors, we actually observe:
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
...
If we keep track of the gradient updates, we can see that the actor is updated 4 times as often as the Q-networks (with policy_update_delay = 4), which I don't think is the intended behaviour.
Possible Solution
Use a separate counter that specifically counts updates (e.g., update_t rather than t) and is incremented by 1 every update step, then use it in the train function, e.g.:
params, opt_states, act_loss_info = lax.cond(
    update_t % cfg.system.policy_update_delay == 0,  # TD3-style delayed update support
    update_actor_and_alpha,
    # just return the same params and opt_states and 0 for the losses
    lambda params, opt_states, *_: (
        params,
        opt_states,
        {"actor_loss": 0.0, "alpha_loss": 0.0},
    ),
    params,
    opt_states,
    data,
    actor_key,
)
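For completeness, a rough sketch (the function and state names here are assumptions, not the actual code) of how update_t could be carried through the training loop and incremented once per update step, independently of t:

    # Hypothetical sketch: names are assumptions, not the actual training-state layout.
    def update_step(carry, _):
        params, opt_states, buffer_state, update_t, key = carry

        # ... sample a batch, run update_q, and run the lax.cond above with update_t ...
        # (t would still be advanced elsewhere by num_envs * rollout_length.)

        update_t = update_t + 1  # counts gradient-update steps, independent of environment steps
        return (params, opt_states, buffer_state, update_t, key), None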