Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Jun 27, 2024
1 parent e0c0d30 commit 7280cf6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion scripts/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def create_env_fn():
# Add KL loss term
with torch.no_grad():
prior_dist = prior.get_dist(batch)
kl_div = kl_divergence(actor_training.get_dist(batch), prior_dist)
kl_div = kl_divergence(actor_training.get_dist(batch), prior_dist)
kl_div = (kl_div * mask.squeeze()).sum(-1).mean(-1)
loss_sum += kl_div * kl_coef
losses[0] = TensorDict({"kl_div": kl_div.detach().item()}, batch_size=[])
Expand Down

0 comments on commit 7280cf6

Please sign in to comment.