From 789365b152ae88202ffc5e174ffb6e620a490bf1 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Fri, 15 Nov 2024 22:57:58 +0000 Subject: [PATCH] Fix segv occurring in fused_moe.py kernel Signed-off-by: Randall Smith --- vllm/model_executor/layers/fused_moe/fused_moe.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 340da32263c1c..5e56e862ba1e7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -105,16 +105,18 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) - off_experts = tl.load(expert_ids_ptr + pid_m) + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) if use_int8_w8a16: @@ -167,7 +169,7 @@ def fused_moe_kernel( accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) c_ptrs = c_ptr + 
stride_cm * offs_token[:, None] + stride_cn * offs_cn[ None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N)