Skip to content

Commit

Permalink
add a new env var for VLLM_MOE_N_SLICE (#769)
Browse files — browse the repository at this point in the history
Add VLLM_MOE_N_SLICE in test script and fix warmup bucket

Signed-off-by: Chendi Xue <[email protected]>
  • Loading branch information
xuechendi authored Jan 31, 2025
1 parent 0f6db60 commit 3d328b9
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
2 changes: 2 additions & 0 deletions scripts/run_example_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Parse the command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="/software/data/DeepSeek-R1/", help="The model path.")
#parser.add_argument("--model", type=str, default="/data/models/DeepSeek-R1/", help="The model path.")
# Fix: help text previously said "The model path." — this flag is the tokenizer path.
parser.add_argument("--tokenizer", type=str, default="deepseek-ai/DeepSeek-R1", help="The tokenizer path.")
#parser.add_argument("--model", type=str, default="/data/models/DeepSeek-R1-bf16-small/", help="The model path.")
#parser.add_argument("--tokenizer", type=str, default="opensourcerelease/DeepSeek-R1-bf16", help="The tokenizer path.")
Expand All @@ -17,6 +18,7 @@
# Environment configuration applied before the engine starts.
# Values are the exact strings the runtime expects; VLLM_MOE_N_SLICE
# presumably controls expert-weight slicing in the fused MoE path — see fp8.py.
_env_settings = {
    "PT_HPU_ENABLE_LAZY_COLLECTIVES": "true",
    "VLLM_RAY_DISABLE_LOG_TO_DRIVER": "1",
    "RAY_IGNORE_UNHANDLED_ERRORS": "1",
    "VLLM_MOE_N_SLICE": "8",
}
os.environ.update(_env_settings)


# Sample prompts.
Expand Down
7 changes: 4 additions & 3 deletions scripts/run_static-online.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@ total_len=$((in_len + out_len))
VLLM_DECODE_BLOCK_BUCKET_MIN=$((in_len * bs / 128))
VLLM_DECODE_BLOCK_BUCKET_MAX=$((total_len * bs / 128 + 128))

#model="/data/models/DeepSeek-R1/"
#tokenizer="/data/models/DeepSeek-R1/"
# model="/data/models/DeepSeek-R1/"
# tokenizer="/data/models/DeepSeek-R1/"
model="/software/data/DeepSeek-R1/"
tokenizer="/software/data/DeepSeek-R1/"
model_name="DeepSeek-R1"

HABANA_VISIBLE_DEVICES="ALL" \
VLLM_MOE_N_SLICE=8 \
PT_HPU_ENABLE_LAZY_COLLECTIVES="true" \
VLLM_RAY_DISABLE_LOG_TO_DRIVER="1" \
RAY_IGNORE_UNHANDLED_ERRORS="1" \
VLLM_PROMPT_BS_BUCKET_MIN=1 \
VLLM_PROMPT_BS_BUCKET_MAX=${bs} \
VLLM_PROMPT_SEQ_BUCKET_MIN=${in_len} \
VLLM_PROMPT_SEQ_BUCKET_MAX=${out_len} \
VLLM_PROMPT_SEQ_BUCKET_MAX=${total_len} \
VLLM_DECODE_BS_BUCKET_MIN=${bs} \
VLLM_DECODE_BS_BUCKET_MAX=${bs} \
VLLM_DECODE_BLOCK_BUCKET_MIN=${VLLM_DECODE_BLOCK_BUCKET_MIN} \
Expand Down
18 changes: 7 additions & 11 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.nn.parameter import Parameter

import vllm.envs as envs
import os
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
Expand Down Expand Up @@ -405,6 +406,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Fp8Config):
    """Initialize the FP8 fused-MoE method.

    Args:
        quant_config: FP8 quantization configuration; its
            ``weight_block_size`` decides whether block quantization is on.

    The ``VLLM_MOE_N_SLICE`` environment variable (default ``8``) sets how
    many slices the expert weights are split into on the HPU forward path.
    """
    self.quant_config = quant_config
    # Block quantization is enabled iff a weight block size was configured.
    self.block_quant = self.quant_config.weight_block_size is not None
    # NOTE(review): assumes the env var, when set, parses as an int.
    self.moe_n_slice = int(os.getenv("VLLM_MOE_N_SLICE", "8"))

def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size: int, params_dtype: torch.dtype,
Expand Down Expand Up @@ -736,17 +738,10 @@ def forward_hpu(
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
# final_hidden_states = layer.hpu_fused_moe.MoeOp(
# hidden_states=x,
# expert_routing_table=topk_ids,
# router_weights=topk_weights,
# permuted_weights=True,
# activation="silu",
# )
final_hidden_states = torch.zeros_like(x)
num_experts = layer.w13_weight.shape[0]
n_expert_slice = layer.w13_weight.shape[0] // 8
assert n_expert_slice * 8 == num_experts
n_expert_slice = layer.w13_weight.shape[0] // self.moe_n_slice
assert n_expert_slice * self.moe_n_slice == num_experts

# w13_list = layer.hpu_fused_moe.MoeOp.w13_list
# w2_list = layer.hpu_fused_moe.MoeOp.w2_list
Expand All @@ -768,14 +763,15 @@ def forward_hpu(
dtype=x.dtype,
original_M=orig_M_w2,
original_N=orig_N_w2)
for i in range(8):
for i in range(self.moe_n_slice):
min_expert = i * n_expert_slice
max_expert = (i + 1) * n_expert_slice

w13_list_slice = [w13_weight[j] for j in range(min_expert, max_expert)]
w2_list_slice = [w2_weight[j] for j in range(min_expert, max_expert)]

final_hidden_states += torch.ops.hpu.mixture_of_experts(hidden_states=x,
final_hidden_states += torch.ops.hpu.mixture_of_experts(
hidden_states=x,
expert_routing_table=topk_ids.to(torch.int64),
router_weights=topk_weights.to(x.dtype),
w12=w13_list_slice,
Expand Down

0 comments on commit 3d328b9

Please sign in to comment.