Deepseek r1 (#765)
1. Move block_fp8 pad to load_weight
2. Move the MoE FP8 linear out of the loop
3. Remove permute and reshape

---------

Signed-off-by: Chendi Xue <[email protected]>
Co-authored-by: root <[email protected]>
xuechendi and root authored Jan 31, 2025
1 parent ea681df commit 0f6db60
Showing 3 changed files with 105 additions and 51 deletions.
9 changes: 6 additions & 3 deletions scripts/run_static-online.sh
@@ -6,8 +6,10 @@ out_len=1024
multi_step=1
total_len=$((in_len + out_len))
VLLM_DECODE_BLOCK_BUCKET_MIN=$((in_len * bs / 128))
VLLM_DECODE_BLOCK_BUCKET_MAX=$((total_len * bs / 128))
VLLM_DECODE_BLOCK_BUCKET_MAX=$((total_len * bs / 128 + 128))

#model="/data/models/DeepSeek-R1/"
#tokenizer="/data/models/DeepSeek-R1/"
model="/software/data/DeepSeek-R1/"
tokenizer="/software/data/DeepSeek-R1/"
model_name="DeepSeek-R1"
@@ -19,7 +21,7 @@ RAY_IGNORE_UNHANDLED_ERRORS="1" \
VLLM_PROMPT_BS_BUCKET_MIN=1 \
VLLM_PROMPT_BS_BUCKET_MAX=${bs} \
VLLM_PROMPT_SEQ_BUCKET_MIN=${in_len} \
VLLM_PROMPT_SEQ_BUCKET_MAX=${in_len} \
VLLM_PROMPT_SEQ_BUCKET_MAX=${out_len} \
VLLM_DECODE_BS_BUCKET_MIN=${bs} \
VLLM_DECODE_BS_BUCKET_MAX=${bs} \
VLLM_DECODE_BLOCK_BUCKET_MIN=${VLLM_DECODE_BLOCK_BUCKET_MIN} \
@@ -36,6 +38,7 @@ python -m vllm.entrypoints.openai.api_server \
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--distributed_executor_backend ray \
--gpu_memory_utilization 0.9 \
--trust_remote_code 2>&1 | tee benchmark_logs/serving.log &
pid=$(($!-1))

@@ -54,7 +57,7 @@ request_rate=1
start_time=$(date +%s)
echo "Start to benchmark"
python benchmarks/benchmark_serving.py --backend vllm --model ${model} --tokenizer ${tokenizer} --dataset-name sonnet --dataset-path benchmarks/sonnet.txt --request-rate ${request_rate} --num-prompts ${num_prompts} --port 8080 --sonnet-input-len ${in_len} --sonnet-output-len ${out_len} --sonnet-prefix-len 100 \
--save-result| tee benchmark_logs/static-online-gaudi3-TPparallel${tp_parrallel}-multistep${multi_step}_nprompt${num_prompts}_rrate${request_rate}_bs${bs}_i${in_len}_o${out_len}.log
--save-result 2>&1 | tee benchmark_logs/static-online-gaudi3-0.9util-TPparallel${tp_parrallel}-multistep${multi_step}_nprompt${num_prompts}_rrate${request_rate}_bs${bs}_i${in_len}_o${out_len}_prepad.log
end_time=$(date +%s)
echo "Time elapsed: $((end_time - start_time))s"

71 changes: 50 additions & 21 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -255,6 +255,16 @@ def create_weights(
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_hpu():
from vllm.model_executor.layers.quantization.utils.fp8_utils import pad_block_fp8_weight_naive
layer.weight, orig_M, orig_N = pad_block_fp8_weight_naive(
layer.weight,
layer.weight_scale_inv,
self.quant_config.weight_block_size)
orig_M = torch.nn.Parameter(torch.tensor(orig_M, dtype=torch.int32), requires_grad=False)
orig_N = torch.nn.Parameter(torch.tensor(orig_N, dtype=torch.int32), requires_grad=False)
layer.register_parameter("orig_M", orig_M)
layer.register_parameter("orig_N", orig_N)
return
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
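For readers skimming the hunk above: on HPU the block-quantized weight is padded once at load time, and the pre-padding shape is kept on the module so apply() can use it later. A minimal sketch of that pattern, with a toy padding helper standing in for pad_block_fp8_weight_naive:

import torch
from torch import nn

def toy_pad(w: torch.Tensor, block: int = 128):
    # Stand-in for pad_block_fp8_weight_naive: pad the last two dims to a block multiple.
    m, n = w.shape[-2:]
    return nn.functional.pad(w, (0, (-n) % block, 0, (-m) % block)), m, n

layer = nn.Module()
padded, orig_m, orig_n = toy_pad(torch.randn(300, 500))
layer.weight = nn.Parameter(padded, requires_grad=False)
# Keep the original dims as non-trainable int32 parameters, mirroring the diff.
layer.register_parameter("orig_M", nn.Parameter(torch.tensor(orig_m, dtype=torch.int32), requires_grad=False))
layer.register_parameter("orig_N", nn.Parameter(torch.tensor(orig_n, dtype=torch.int32), requires_grad=False))
print(layer.weight.shape, int(layer.orig_M), int(layer.orig_N))  # torch.Size([384, 512]) 300 500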
@@ -355,6 +365,8 @@ def apply(self,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
original_M=layer.orig_M,
original_N=layer.orig_N,
)
else:
return apply_w8a8_block_fp8_linear(
@@ -513,6 +525,24 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_hpu():
from vllm.model_executor.layers.quantization.utils.fp8_utils import pad_block_fp8_weight_naive
layer.w13_weight, orig_M_w13, orig_N_w13 = pad_block_fp8_weight_naive(
layer.w13_weight,
layer.w13_weight_scale_inv,
self.quant_config.weight_block_size)
layer.w2_weight, orig_M_w2, orig_N_w2 = pad_block_fp8_weight_naive(
layer.w2_weight,
layer.w2_weight_scale_inv,
self.quant_config.weight_block_size)
orig_M_w13 = torch.nn.Parameter(torch.tensor(orig_M_w13, dtype=torch.int32), requires_grad=False)
orig_N_w13 = torch.nn.Parameter(torch.tensor(orig_N_w13, dtype=torch.int32), requires_grad=False)
layer.register_parameter("orig_M_w13", orig_M_w13)
layer.register_parameter("orig_N_w13", orig_N_w13)
orig_M_w2 = torch.nn.Parameter(torch.tensor(orig_M_w2, dtype=torch.int32), requires_grad=False)
orig_N_w2 = torch.nn.Parameter(torch.tensor(orig_N_w2, dtype=torch.int32), requires_grad=False)
layer.register_parameter("orig_M_w2", orig_M_w2)
layer.register_parameter("orig_N_w2", orig_N_w2)
return
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
Expand Down Expand Up @@ -722,28 +752,29 @@ def forward_hpu(
# w2_list = layer.hpu_fused_moe.MoeOp.w2_list
from vllm.model_executor.layers.quantization.utils.fp8_utils import dequant_block_fp8_weight_naive

orig_M_w13 = layer.orig_M_w13.data
orig_N_w13 = layer.orig_N_w13.data
orig_M_w2 = layer.orig_M_w2.data
orig_N_w2 = layer.orig_N_w2.data
w13_weight = dequant_block_fp8_weight_naive(layer.w13_weight,
layer.w13_weight_scale_inv,
block_size=self.quant_config.weight_block_size,
dtype=x.dtype,
original_M=orig_M_w13,
original_N=orig_N_w13)
w2_weight = dequant_block_fp8_weight_naive(layer.w2_weight,
layer.w2_weight_scale_inv,
block_size=self.quant_config.weight_block_size,
dtype=x.dtype,
original_M=orig_M_w2,
original_N=orig_N_w2)
for i in range(8):
min_expert = i * n_expert_slice
max_expert = (i + 1) * n_expert_slice
# w13_list_slice = [w13_list[i].weight.squeeze() for i in range(min_expert, max_expert)]
# w2_list_slice = [w2_list[i].weight.squeeze() for i in range(min_expert, max_expert)]
# w13_list_slice = [layer.w13_weight[j].squeeze().clone() for j in range(min_expert, max_expert)]
# w2_list_slice = [layer.w2_weight[j].squeeze().clone() for j in range(min_expert, max_expert)]

w13_list_slice = [dequant_block_fp8_weight_naive(layer.w13_weight[j].squeeze(),
layer.w13_weight_scale_inv[j],
block_size=self.quant_config.weight_block_size,
dtype=x.dtype) for j in range(min_expert, max_expert)]
w2_list_slice = [dequant_block_fp8_weight_naive(layer.w2_weight[j].squeeze(),
layer.w2_weight_scale_inv[j],
block_size=self.quant_config.weight_block_size,
dtype=x.dtype) for j in range(min_expert, max_expert)]
# print(f"w13_list_slice[0].shape: {w13_list_slice[0].shape}, device: {w13_list_slice[0].device}, dtype: {w13_list_slice[0].dtype}")
# print(f"w2_list_slice[0].shape: {w2_list_slice[0].shape}, device: {w2_list_slice[0].device}, dtype: {w2_list_slice[0].dtype}")
# print(f"hidden_states.shape: {x.shape}, device: {x.device}, dtype: {x.dtype}")
# print(f"topk_ids.shape: {topk_ids.shape}, device: {topk_ids.device}, dtype: {topk_ids.dtype}")
# print(f"topk_weights.shape: {topk_weights.shape}, device: {topk_weights.device}, dtype: {topk_weights.dtype}")
# print(f"min_expert: {min_expert}, max_expert: {max_expert}")

w13_list_slice = [w13_weight[j] for j in range(min_expert, max_expert)]
w2_list_slice = [w2_weight[j] for j in range(min_expert, max_expert)]

final_hidden_states += torch.ops.hpu.mixture_of_experts(hidden_states=x,
expert_routing_table=topk_ids.to(torch.int64),
router_weights=topk_weights.to(x.dtype),
Expand All @@ -753,9 +784,7 @@ def forward_hpu(
activation="silu",
experts_min=min_expert,
experts_max=max_expert - 1)
# print(f"final_hidden_states.shape: {final_hidden_states.shape}, device: {final_hidden_states.device}, dtype: {final_hidden_states.dtype}")
htorch.core.mark_step()
# print(f"done mark step {i}")
return final_hidden_states.view(-1, x.shape[1])
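The forward_hpu rewrite above dequantizes the full w13/w2 tensors once and then only takes per-expert views inside the 8-way slicing loop, instead of re-running the block dequant for every slice. A simplified sketch of that restructuring, with a plain matmul standing in for the fused HPU mixture_of_experts op and toy_dequant standing in for dequant_block_fp8_weight_naive:

import torch

def toy_dequant(w_q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for dequant_block_fp8_weight_naive: one scale per expert weight.
    return w_q.float() * scale

num_experts, n_groups, d_in, d_out = 16, 8, 32, 64
w_q = torch.randn(num_experts, d_out, d_in)     # pretend-quantized expert weights
scale = torch.rand(num_experts, 1, 1)
x = torch.randn(4, d_in)

w_dq = toy_dequant(w_q, scale)                  # hoisted out of the loop (this commit)
out = torch.zeros(4, d_out)
per_group = num_experts // n_groups
for g in range(n_groups):
    lo, hi = g * per_group, (g + 1) * per_group
    group_w = [w_dq[j] for j in range(lo, hi)]  # cheap views, no per-slice dequant
    for wj in group_w:                          # stand-in for the fused MoE op
        out += x @ wj.t()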


76 changes: 49 additions & 27 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -34,66 +34,88 @@ def apply_w8a8_block_fp8_linear(

def pad_weight(weight, block_size):
"""Pads a matrix to make its dimensions multiples of block_size."""
M, N = weight.shape
M, N = weight.shape[-2:]
block_size_m, block_size_n = block_size
pad_M = (block_size_m - M % block_size_m) % block_size_m
pad_N = (block_size_n - N % block_size_n) % block_size_n

if pad_M == 0 and pad_N == 0:
return weight, M, N # No padding needed
padded_weight = torch.nn.functional.pad(weight, (0, pad_N, 0, pad_M), mode='constant', value=0)
padded_weight = torch.nn.Parameter(padded_weight, requires_grad=False)
return padded_weight, M, N # Return original dimensions for unpadding

def unpad_weight(weight, original_M, original_N):
def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
"""Removes padding from the matrix to restore its original shape."""
return weight[:original_M, :original_N].contiguous()

if keep_first_dim:
return weight[:, :original_M, :original_N]
else:
return weight[:original_M, :original_N]

def dequant_block_fp8_weight_naive(weight, weight_scale, block_size, dtype):
def pad_block_fp8_weight_naive(weight, weight_scale, block_size):

assert len(block_size) == 2
assert len(weight_scale.shape) == 2
assert len(weight.shape) == 2
# assert M % block_size_m == 0 and N % block_size_n == 0, \
# f"Matrix dimensions must be divisible by block_size, got M: {M}, N: {N}, block_size_m: {block_size_m}, block_size_n: {block_size_n}"

weight, original_M, original_N = pad_weight(weight, block_size)
M, N = weight.shape

block_size_m, block_size_n = block_size
weight_scale_m, weight_scale_n = weight_scale.shape[-2:]

weight, orig_M, orig_N = pad_weight(weight, block_size)
M, N = weight.shape[-2:]

# change weight to block format
weight = weight.view(M // block_size_m, block_size_m, N // block_size_n, block_size_n) # [0, 1, 2, 3]
weight = weight.permute(0, 2, 1, 3)
weight = weight.contiguous().view(M // block_size_m, N // block_size_n, -1)

# mul scale
weight_scale_m, weight_scale_n = weight_scale.shape
assert weight_scale_m == M // block_size_m
assert weight_scale_n == N // block_size_n
weight_scale = weight_scale.view(weight_scale_m, weight_scale_n, 1)
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)

# change block format back to normal
dequant_weight = dequant_weight.view(M // block_size_m, N // block_size_n, block_size_m, block_size_n)
dequant_weight = dequant_weight.permute(0, 2, 1, 3).contiguous().view(M, N)
return weight, orig_M, orig_N


dequant_weight = unpad_weight(dequant_weight, original_M, original_N)
def dequant_block_fp8_weight_naive(weight, weight_scale, block_size, dtype, original_M, original_N):

assert len(block_size) == 2

weight_shape_len = len(weight.shape)

block_size_m, block_size_n = block_size

# mul scale
if weight_shape_len == 2:
weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(weight_scale_m*block_size_m, weight_scale_n*block_size_n)
keep_first_dim = False
elif weight_shape_len == 3:
fd, weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n)
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(fd, weight_scale_m*block_size_m, weight_scale_n*block_size_n)
keep_first_dim = True
else:
raise ValueError("Only support original weight shape is either 2 or 3")

dequant_weight = unpad_weight(dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim)

return dequant_weight


def apply_block_fp8_linear_hpu(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
original_M: Optional[torch.Tensor] = None,
original_N: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
dequant_weight = dequant_block_fp8_weight_naive(weight, weight_scale, block_size, input_2d.dtype)
original_M = original_M.data
original_N = original_N.data
output_shape = [*input.shape[:-1], original_M]
dequant_weight = dequant_block_fp8_weight_naive(weight, weight_scale, block_size, input_2d.dtype, original_M, original_N)
output = torch.nn.functional.linear(input_2d, dequant_weight, bias=None)
if bias is not None:
output = output + bias
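The updated dequant_block_fp8_weight_naive above expands the per-block scales purely with view() and broadcasting (weight viewed as [Ms, bm, Ns, bn], scale as [Ms, 1, Ns, 1]) rather than permuting into a block layout and back. A small numeric check of that equivalence, using plain float tensors in place of real FP8 storage and an illustrative block size:

import torch

bm, bn = 2, 4                      # illustrative block size
Ms, Ns = 3, 2                      # number of scale blocks per dim
M, N = Ms * bm, Ns * bn
w = torch.randn(M, N)              # stand-in for the fp8 weight
scale = torch.rand(Ms, Ns)

# View/broadcast path, as in the updated helper (padding/unpadding omitted).
dq_fast = (w.view(Ms, bm, Ns, bn) * scale.view(Ms, 1, Ns, 1)).view(M, N)

# Reference: scale each block explicitly.
dq_ref = torch.empty_like(w)
for i in range(Ms):
    for j in range(Ns):
        dq_ref[i*bm:(i+1)*bm, j*bn:(j+1)*bn] = w[i*bm:(i+1)*bm, j*bn:(j+1)*bn] * scale[i, j]

assert torch.allclose(dq_fast, dq_ref)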
