Skip to content

Commit

Permalink
Fix exploding gradients when ngroups larger than one (state-spaces#547)
Browse files Browse the repository at this point in the history
Co-authored-by: Ilyas.Chahed <[email protected]>
  • Loading branch information
ilyasch2 and Ilyas.Chahed authored Sep 26, 2024
1 parent 9259852 commit 3b0dde5
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ def backward(ctx, dout, *args):
x_rms = rearrange(out, "b s h p -> (b s) (h p)")
z_rms = rearrange(z, "b s h p -> (b s) (h p)")
out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None
dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, group_size=dim//ctx.ngroups, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
out_for_linear = out_recompute if recompute_output else None
dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
Expand Down

0 comments on commit 3b0dde5

Please sign in to comment.