diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 8b3212831e7e0..1c4ac905e02df 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -549,6 +549,7 @@ async def benchmark(
         # multi-modal benchmark is only available on OpenAI Chat backend.
         raise ValueError(
             "Multi-modal content is only supported on 'openai-chat' backend.")
+    test_output_len = 10
     test_input = RequestFuncInput(
         model=model_id,
         model_name=model_name,
diff --git a/scripts/run_static-online.sh b/scripts/run_static-online.sh
index 86d8fcfa8a111..16483f4172b2b 100644
--- a/scripts/run_static-online.sh
+++ b/scripts/run_static-online.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 tp_parrallel=8
-bs=32
+bs=96
 in_len=1024
 out_len=1024
 multi_step=1
@@ -10,12 +10,13 @@
 VLLM_DECODE_BLOCK_BUCKET_MAX=$((total_len * bs / 128 + 128))
 # model="/data/models/DeepSeek-R1/"
 # tokenizer="/data/models/DeepSeek-R1/"
-model="/software/data/DeepSeek-R1/"
-tokenizer="/software/data/DeepSeek-R1/"
+model="/data/models/DeepSeek-R1/"
+tokenizer="/data/models/DeepSeek-R1/"
 model_name="DeepSeek-R1"
 
 HABANA_VISIBLE_DEVICES="ALL" \
-VLLM_MOE_N_SLICE=8 \
+VLLM_MOE_N_SLICE=4 \
+VLLM_MLA_DISABLE_REQUANTIZATION=1 \
 PT_HPU_ENABLE_LAZY_COLLECTIVES="true" \
 VLLM_RAY_DISABLE_LOG_TO_DRIVER="1" \
 RAY_IGNORE_UNHANDLED_ERRORS="1" \
@@ -37,7 +38,6 @@ python -m vllm.entrypoints.openai.api_server \
     --use-v2-block-manager \
     --num_scheduler_steps ${multi_step}\
     --max-model-len 2048 \
-    --max-num-batched-tokens 2048 \
     --distributed_executor_backend ray \
     --gpu_memory_utilization 0.9 \
     --trust_remote_code 2>&1 | tee benchmark_logs/serving.log &
@@ -53,7 +53,7 @@
 done
 sleep 5s
 echo ${pid}
-num_prompts=32
+num_prompts=300
 request_rate=1
 start_time=$(date +%s)
 echo "Start to benchmark"
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index e2f4bce16273f..3828ddaf99a49 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -75,7 +75,7 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return (num_blocks, block_size, head_size), True
+        return (num_blocks, block_size, head_size), (num_blocks, block_size, head_size//9*8)
 
     @staticmethod
     def get_impl_cls() -> Type["HPUAttentionImpl"]:
@@ -136,7 +136,8 @@ def __init__(
         self.matmul_av = Matmul()
         self.batch2block_matmul = Matmul()
         self.block2batch_matmul = Matmul()
-        self.latent_cache = VLLMKVCache()
+        self.latent_cache_k = VLLMKVCache()
+        self.latent_cache_v = VLLMKVCache()
         HPUFusedSDPA = kernels.fsdpa()
         self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
             else ModuleFusedSDPA(HPUFusedSDPA)
@@ -201,19 +202,29 @@ def forward(
         block_indices = attn_metadata.block_indices
         block_offsets = attn_metadata.block_offsets
-        latent_vec = torch.concat(
+        latent_vec_k = torch.concat(
             (k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)),
             dim=-1)
         # assert layer._k_scale == 0, f"got _k_scale={layer._k_scale}"
-        latent_vec = latent_vec.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+        latent_vec_v = k_c_normed.view(-1, self.kv_lora_rank)
         if is_prefill:
-            latent_vec = latent_vec.unflatten(0, (block_indices.size(0), -1))
+            latent_vec_k = latent_vec_k.unflatten(0, (block_indices.size(0), -1))
+            latent_vec_v = latent_vec_v.unflatten(0, (block_indices.size(0), -1))
         # print("latent_vec", latent_vec.shape)
 
         # write the latent and rope to kv cache
-        if kv_cache is not None:
-            kv_cache = self.latent_cache(latent_vec, kv_cache, block_indices,
+        if kv_cache is not None and len(kv_cache) == 2:
+            # print(f"k cache shape: {kv_cache[0].shape}")
+            # print(f"v cache shape: {kv_cache[1].shape}")
+            # print(f"latent vec k shape: {latent_vec_k.shape}")
+            # print(f"latent vec v shape: {latent_vec_v.shape}")
+
+            k_cache = self.latent_cache_k(latent_vec_k, kv_cache[0], block_indices,
+                                          block_offsets)
+            v_cache = self.latent_cache_v(latent_vec_v, kv_cache[1], block_indices,
                                          block_offsets)
+            kv_cache = (k_cache, v_cache)
 
         if is_prefill:
             return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata,
                                          batch_size)
@@ -265,13 +276,13 @@ def _forward_decode(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
         batch_size: int
     ) -> torch.Tensor:
         q = torch.cat([q_nope, q_pe], dim=-1)
-        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
-        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
+        kv_c_and_k_pe_cache = kv_cache[0].unsqueeze(2)
+        kv_c_cache = kv_cache[1].unsqueeze(2)
 
         output = HPUPagedAttention.forward_decode(
             query=q,
@@ -287,8 +298,8 @@
             matmul_av_op=self.matmul_av,
             batch2block_matmul_op=self.batch2block_matmul,
             block2batch_matmul_op=self.block2batch_matmul,
-            keys_fetch_func=self.latent_cache.fetch_from_cache,
-            values_fetch_func=self.latent_cache.fetch_from_cache)
+            keys_fetch_func=self.latent_cache_k.fetch_from_cache,
+            values_fetch_func=self.latent_cache_v.fetch_from_cache)
         output = output.view(batch_size, 1, -1)
         result = self._v_up_proj_and_o_proj(output)
         result = result.view(batch_size, 1, -1)
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index c427b759b2e97..3a995831bf7b6 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -112,7 +112,8 @@ def get_cache_block_size(
         key_cache_block = cache_config.block_size * num_heads * head_size
         # For MLA there is no value cache, since the latent vector
         # is joint keys and values.
-        value_cache_block = key_cache_block if not model_config.use_mla else 0
+        # value_cache_block = key_cache_block if not model_config.use_mla else 0
+        value_cache_block = key_cache_block // 9 * 8
         total = num_attention_layers * (key_cache_block + value_cache_block)
         if cache_config.cache_dtype == "auto":
             dtype = model_config.dtype
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index f953b41f40e63..5bbea6f841191 100644
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -568,9 +568,13 @@ def _allocate_kv_cache(
             num_blocks, self.block_size, self.num_kv_heads, self.head_size)
 
         use_mla = False
-        if len(kv_cache_shape) == 2 and kv_cache_shape[1]:
+        if len(kv_cache_shape) == 2:
             use_mla = True
-            kv_cache_shape = kv_cache_shape[0]
+            k_cache_shape = kv_cache_shape[0]
+            v_cache_shape = kv_cache_shape[1]
+        else:
+            k_cache_shape = kv_cache_shape
+            v_cache_shape = kv_cache_shape
 
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         dtype = self.dtype
@@ -578,15 +582,10 @@
                 and self.dtype == torch.float8_e4m3fn:
             dtype = torch.uint8
         for _ in range(self.num_attention_layers):
-            if use_mla:
-                kv_layer = torch.zeros(kv_cache_shape,
-                                       dtype=dtype,
-                                       device=device)
-            else:
-                key_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=device)
-                value_cache = torch.zeros(kv_cache_shape,
-                                          dtype=dtype,
-                                          device=device)
-                kv_layer = (key_cache, value_cache)
+            key_cache = torch.zeros(k_cache_shape, dtype=dtype, device=device)
+            value_cache = torch.zeros(v_cache_shape,
+                                      dtype=dtype,
+                                      device=device)
+            kv_layer = (key_cache, value_cache)
             kv_cache.append(kv_layer)
         return kv_cache
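A minimal sketch (not part of the patch) of the cache-shape arithmetic behind get_kv_cache_shape and the key_cache_block // 9 * 8 change in get_cache_block_size, assuming DeepSeek-R1's MLA dimensions of kv_lora_rank=512 and qk_rope_head_dim=64: the "key" cache stores the 576-element latent-plus-RoPE vector per token, the "value" cache only the 512-element latent, and 576 // 9 * 8 == 512 is where the hard-coded 9/8 ratio comes from. The block counts below are illustrative only.

# Illustrative sketch, not part of the diff. Assumes DeepSeek-R1 MLA sizes
# (kv_lora_rank=512, qk_rope_head_dim=64); a model with a different
# latent/RoPE ratio would not match the hard-coded // 9 * 8.
kv_lora_rank = 512           # compressed KV latent stored per token
qk_rope_head_dim = 64        # rotary (RoPE) part, only needed on the key side

head_size = kv_lora_rank + qk_rope_head_dim       # 576, per-token "key" entry

num_blocks, block_size = 4, 128                   # arbitrary example values
k_cache_shape = (num_blocks, block_size, head_size)            # latent + RoPE
v_cache_shape = (num_blocks, block_size, head_size // 9 * 8)   # latent only

assert head_size // 9 * 8 == kv_lora_rank         # 576 // 9 * 8 == 512
print(k_cache_shape, v_cache_shape)

The same ratio is applied per cache block in cache_engine.py, so the value cache is sized at 8/9 of the key cache rather than being skipped entirely as before.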