Commit 5cd1689
Fix integer overflow causing gpu segfault
Signed-off-by: Randall Smith <[email protected]>
rasmith committed Nov 15, 2024
1 parent 79ee45b commit 5cd1689
Showing 1 changed file with 34 additions and 45 deletions.
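
Context for the fix: the fused-MoE Triton kernel builds pointer offsets from `tl.arange(...)`, which yields int32 values. Once a tensor is large enough that an offset-times-stride product exceeds 2**31 - 1, the 32-bit result wraps negative and the pointer dereference segfaults on the GPU. A minimal sketch of the wraparound, with hypothetical shapes and plain NumPy standing in for the kernel math:

```python
import numpy as np

# Hypothetical sizes: ~66k padded token slots, 32k-wide rows.
stride_am = 32_768                            # row stride, in elements
row = np.array([66_000], dtype=np.int32)      # token offset, int32 like tl.arange

offset32 = row * np.int32(stride_am)          # int32 multiply wraps silently
offset64 = row.astype(np.int64) * stride_am   # the fix: widen to int64 first

print(offset32[0])  # -2132279296 -> negative offset, out-of-bounds access
print(offset64[0])  # 2162688000  -> correct element offset
```

This is why the diff below casts every offset that feeds pointer arithmetic (`offs_token_id`, `offs_token`, `offs_bn`, `offs_k`, `off_experts`, `offs_cn`) to `tl.int64` before it is scaled by a stride.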
79 changes: 34 additions & 45 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -12,7 +12,6 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -105,16 +104,18 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
[CI annotation, GitHub Actions / ruff (3.12): vllm/model_executor/layers/fused_moe/fused_moe.py:107:81: E501 Line too long (82 > 80)]
+    offs_token_id = offs_token_id.to(tl.int64)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    offs_bn = offs_bn.to(tl.int64)
+    offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)
 
-    off_experts = tl.load(expert_ids_ptr + pid_m)
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                 offs_bn[None, :] * stride_bn)
     if use_int8_w8a16:
@@ -168,6 +169,7 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_cn = offs_cn.to(tl.int64)
     c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
         None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
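
The same widening pattern in isolation: a toy Triton copy kernel (not part of this diff) showing that the cast has to happen before the offsets are scaled and added to the base pointer, since that multiply is where int32 arithmetic would overflow:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Widen to int64 *before* the multiply; pid * BLOCK computed in
    # int32 is exactly where a large enough tensor would wrap around.
    pid = tl.program_id(0).to(tl.int64)
    offs = pid * BLOCK + tl.arange(0, BLOCK).to(tl.int64)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


# Usage sketch with a small tensor; the int64 widening only matters once
# the element count approaches 2**31, but it is harmless here.
x = torch.randn(1 << 20, device="cuda")
y = torch.empty_like(x)
copy_kernel[(triton.cdiv(x.numel(), 1024), )](x, y, x.numel(), BLOCK=1024)
```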
@@ -467,6 +469,8 @@ def get_config_dtype_str(dtype: torch.dtype,
     return None
 
 
+@torch.library.custom_op("vllm::inplace_fused_experts",
+                         mutates_args=["hidden_states"])
 def inplace_fused_experts(hidden_states: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
@@ -483,29 +487,22 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                        a1_scale, a2_scale)
 
 
-def inplace_fused_experts_fake(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        use_fp8_w8a8: bool = False,
-        use_int8_w8a16: bool = False,
-        w1_scale: Optional[torch.Tensor] = None,
-        w2_scale: Optional[torch.Tensor] = None,
-        a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None) -> None:
+@inplace_fused_experts.register_fake
+def _(hidden_states: torch.Tensor,
+      w1: torch.Tensor,
+      w2: torch.Tensor,
+      topk_weights: torch.Tensor,
+      topk_ids: torch.Tensor,
+      use_fp8_w8a8: bool = False,
+      use_int8_w8a16: bool = False,
+      w1_scale: Optional[torch.Tensor] = None,
+      w2_scale: Optional[torch.Tensor] = None,
+      a1_scale: Optional[torch.Tensor] = None,
+      a2_scale: Optional[torch.Tensor] = None) -> None:
     pass
 
 
-direct_register_custom_op(
-    op_name="inplace_fused_experts",
-    op_func=inplace_fused_experts,
-    mutates_args=["hidden_states"],
-    fake_impl=inplace_fused_experts_fake,
-)
-
-
+@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[])
 def outplace_fused_experts(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
@@ -523,29 +520,21 @@ def outplace_fused_experts(
                               w2_scale, a1_scale, a2_scale)
 
 
-def outplace_fused_experts_fake(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        use_fp8_w8a8: bool = False,
-        use_int8_w8a16: bool = False,
-        w1_scale: Optional[torch.Tensor] = None,
-        w2_scale: Optional[torch.Tensor] = None,
-        a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+@outplace_fused_experts.register_fake
+def _(hidden_states: torch.Tensor,
+      w1: torch.Tensor,
+      w2: torch.Tensor,
+      topk_weights: torch.Tensor,
+      topk_ids: torch.Tensor,
+      use_fp8_w8a8: bool = False,
+      use_int8_w8a16: bool = False,
+      w1_scale: Optional[torch.Tensor] = None,
+      w2_scale: Optional[torch.Tensor] = None,
+      a1_scale: Optional[torch.Tensor] = None,
+      a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
 
-direct_register_custom_op(
-    op_name="outplace_fused_experts",
-    op_func=outplace_fused_experts,
-    mutates_args=[],
-    fake_impl=outplace_fused_experts_fake,
-)
-
-
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
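
The registration change above swaps vLLM's `direct_register_custom_op` helper for the stock `torch.library.custom_op` API (PyTorch 2.4+). In miniature, the decorator pattern looks like this; the op name and body are hypothetical, not vLLM's:

```python
import torch


# Real implementation: registered under a namespaced op name.
@torch.library.custom_op("demo::scaled_add", mutates_args=[])
def scaled_add(x: torch.Tensor, scale: float) -> torch.Tensor:
    return x * scale + 1.0


# Fake (meta) implementation: describes only output shape/dtype so
# torch.compile can trace the op without running the real kernel.
# This mirrors the register_fake bodies above: `pass` for the in-place
# op (it returns nothing), `torch.empty_like(...)` for the out-of-place one.
@scaled_add.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.empty_like(x)


print(scaled_add(torch.ones(4), 2.0))  # tensor([3., 3., 3., 3.])
```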

0 comments on commit 5cd1689
