From 3b0dde5a20659073af5684e966a81981e614789e Mon Sep 17 00:00:00 2001 From: ilyasch2 <104485953+ilyasch2@users.noreply.github.com> Date: Fri, 27 Sep 2024 01:22:08 +0400 Subject: [PATCH] Fix exploding gradients when ngroups is larger than one (#547) Co-authored-by: Ilyas.Chahed --- mamba_ssm/ops/triton/ssd_combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 77d20715..bf699f0b 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -887,7 +887,7 @@ def backward(ctx, dout, *args): x_rms = rearrange(out, "b s h p -> (b s) (h p)") z_rms = rearrange(z, "b s h p -> (b s) (h p)") out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None - dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None) + dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, group_size=dim//ctx.ngroups, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None) out_for_linear = out_recompute if recompute_output else None dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim) dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(