Skip to content

Commit

Permalink
fixing softplus bug with _chunk_cumsum_bwd_kernel() triton kernel (st…
Browse files Browse the repository at this point in the history
  • Loading branch information
stephen-youn authored Sep 26, 2024
1 parent 62db608 commit 9259852
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
4 changes: 2 additions & 2 deletions mamba_ssm/ops/triton/selective_state_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def _selective_scan_update_kernel(
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
dt = tl.where(dt <= 20.0, softplus(dt), dt)
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
dt = tl.where(dt <= 20.0, softplus(dt), dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix

Expand Down
6 changes: 2 additions & 4 deletions mamba_ssm/ops/triton/softplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
if TRITON3:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
return dt
return tl.math.log(tl.math.exp(dt) + 1)
else:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
return dt
return tl.math.log1p(tl.exp(dt))
4 changes: 2 additions & 2 deletions mamba_ssm/ops/triton/ssd_chunk_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _chunk_cumsum_fwd_kernel(
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt = softplus(dt)
dt = tl.where(dt <= 20.0, softplus(dt), dt)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
Expand Down Expand Up @@ -141,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt_presoftplus = dt
dt = softplus(dt)
dt = tl.where(dt <= 20.0, softplus(dt), ddt)
clamp_mask = (dt < dt_min) | (dt > dt_max)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
Expand Down

0 comments on commit 9259852

Please sign in to comment.