use redundant latent cache to better fit contiguous pa
yangw1234 authored and xuechendi committed Feb 3, 2025
1 parent de8ca17 commit 9702891
Showing 5 changed files with 43 additions and 31 deletions.
1 change: 1 addition & 0 deletions benchmarks/benchmark_serving.py
@@ -549,6 +549,7 @@ async def benchmark(
# multi-modal benchmark is only available on OpenAI Chat backend.
raise ValueError(
"Multi-modal content is only supported on 'openai-chat' backend.")
+ test_output_len = 10
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
12 changes: 6 additions & 6 deletions scripts/run_static-online.sh
@@ -1,6 +1,6 @@
#!/bin/bash
tp_parrallel=8
- bs=32
+ bs=96
in_len=1024
out_len=1024
multi_step=1
@@ -10,12 +10,13 @@ VLLM_DECODE_BLOCK_BUCKET_MAX=$((total_len * bs / 128 + 128))

# model="/data/models/DeepSeek-R1/"
# tokenizer="/data/models/DeepSeek-R1/"
model="/software/data/DeepSeek-R1/"
tokenizer="/software/data/DeepSeek-R1/"
model="/data/models/DeepSeek-R1/"
tokenizer="/data/models/DeepSeek-R1/"
model_name="DeepSeek-R1"

HABANA_VISIBLE_DEVICES="ALL" \
- VLLM_MOE_N_SLICE=8 \
+ VLLM_MOE_N_SLICE=4 \
+ VLLM_MLA_DISABLE_REQUANTIZATION=1 \
PT_HPU_ENABLE_LAZY_COLLECTIVES="true" \
VLLM_RAY_DISABLE_LOG_TO_DRIVER="1" \
RAY_IGNORE_UNHANDLED_ERRORS="1" \
@@ -37,7 +38,6 @@ python -m vllm.entrypoints.openai.api_server \
--use-v2-block-manager \
--num_scheduler_steps ${multi_step}\
--max-model-len 2048 \
- --max-num-batched-tokens 2048 \
--distributed_executor_backend ray \
--gpu_memory_utilization 0.9 \
--trust_remote_code 2>&1 | tee benchmark_logs/serving.log &
@@ -53,7 +53,7 @@ done
sleep 5s
echo ${pid}

- num_prompts=32
+ num_prompts=300
request_rate=1
start_time=$(date +%s)
echo "Start to benchmark"
35 changes: 23 additions & 12 deletions vllm/attention/backends/hpu_attn.py
@@ -75,7 +75,7 @@ def get_kv_cache_shape(
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
- return (num_blocks, block_size, head_size), True
+ return (num_blocks, block_size, head_size), (num_blocks, block_size, head_size//9*8)

@staticmethod
def get_impl_cls() -> Type["HPUAttentionImpl"]:
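
The second shape returned above is the core of the patch: instead of a bare True flag, the MLA backend now advertises a separate value-cache shape whose last dimension is head_size // 9 * 8. A small sketch of the arithmetic (the DeepSeek-style dimensions kv_lora_rank = 512 and qk_rope_head_dim = 64 are assumptions here, not read from the diff):

    # Sketch: why head_size // 9 * 8 equals the latent width under the
    # assumed MLA dimensions (kv_lora_rank=512, qk_rope_head_dim=64).
    kv_lora_rank = 512                       # compressed KV ("latent") width
    qk_rope_head_dim = 64                    # rotary part appended to each key
    head_size = kv_lora_rank + qk_rope_head_dim              # 576 per key slot
    assert head_size // 9 * 8 == kv_lora_rank                # 576 // 9 * 8 == 512

    num_blocks, block_size = 1024, 128                       # example values
    k_shape = (num_blocks, block_size, head_size)            # latent + rope
    v_shape = (num_blocks, block_size, head_size // 9 * 8)   # latent only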
@@ -136,7 +136,8 @@ def __init__(
self.matmul_av = Matmul()
self.batch2block_matmul = Matmul()
self.block2batch_matmul = Matmul()
- self.latent_cache = VLLMKVCache()
+ self.latent_cache_k = VLLMKVCache()
+ self.latent_cache_v = VLLMKVCache()
HPUFusedSDPA = kernels.fsdpa()
self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
else ModuleFusedSDPA(HPUFusedSDPA)
@@ -201,19 +202,29 @@ def forward(
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets

- latent_vec = torch.concat(
+ latent_vec_k = torch.concat(
(k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)), dim=-1)
# assert layer._k_scale == 0, f"got _k_scale={layer._k_scale}"
- latent_vec = latent_vec.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+ latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+ latent_vec_v = k_c_normed.view(-1, self.kv_lora_rank)
if is_prefill:
- latent_vec = latent_vec.unflatten(0, (block_indices.size(0), -1))
+ latent_vec_k = latent_vec_k.unflatten(0, (block_indices.size(0), -1))
+ latent_vec_v = latent_vec_v.unflatten(0, (block_indices.size(0), -1))
# print("latent_vec", latent_vec.shape)


# write the latent and rope to kv cache
- if kv_cache is not None:
- kv_cache = self.latent_cache(latent_vec, kv_cache, block_indices,
+ if kv_cache is not None and len(kv_cache) == 2:
+ # print(f"k cache shape: {kv_cache[0].shape}")
+ # print(f"v cache shape: {kv_cache[1].shape}")
+ # print(f"latent vec k shape: {latent_vec_k.shape}")
+ # print(f"latent vec v shape: {latent_vec_v.shape}")

+ k_cache = self.latent_cache_k(latent_vec_k, kv_cache[0], block_indices,
block_offsets)
+ v_cache = self.latent_cache_v(latent_vec_v, kv_cache[1], block_indices,
+ block_offsets)
+ kv_cache = (k_cache, v_cache)

if is_prefill:
return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata, batch_size)
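
What the rewritten write path does: latent_vec_k packs the compressed latent plus the rotary key part into the wide key cache, while latent_vec_v stores only the latent again in the separate value cache, at the same block indices and offsets. The value cache is therefore redundant data, kept so decode can read values contiguously. A toy illustration with plain tensors (the real code goes through VLLMKVCache and attn_metadata; the 512/64 widths are assumed):

    import torch

    kv_lora_rank, rope_dim, num_tokens = 512, 64, 4       # assumed MLA widths
    k_c_normed = torch.randn(num_tokens, kv_lora_rank)    # compressed latent
    k_pe = torch.randn(num_tokens, rope_dim)              # rotary key part

    latent_vec_k = torch.cat([k_c_normed, k_pe], dim=-1)  # written to kv_cache[0]
    latent_vec_v = k_c_normed                             # written to kv_cache[1]

    # The value entries duplicate the first kv_lora_rank features of the keys.
    assert torch.equal(latent_vec_k[:, :kv_lora_rank], latent_vec_v)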
@@ -265,13 +276,13 @@ def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
- kv_c_and_k_pe_cache: torch.Tensor,
+ kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
batch_size: int
) -> torch.Tensor:
q = torch.cat([q_nope, q_pe], dim=-1)
- kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
- kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
+ kv_c_and_k_pe_cache = kv_cache[0].unsqueeze(2)
+ kv_c_cache = kv_cache[1].unsqueeze(2)

output = HPUPagedAttention.forward_decode(
query=q,
@@ -287,8 +298,8 @@ def _forward_decode(
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
- keys_fetch_func=self.latent_cache.fetch_from_cache,
- values_fetch_func=self.latent_cache.fetch_from_cache)
+ keys_fetch_func=self.latent_cache_k.fetch_from_cache,
+ values_fetch_func=self.latent_cache_v.fetch_from_cache)
output = output.view(batch_size, 1, -1)
result = self._v_up_proj_and_o_proj(output)
result = result.view(batch_size, 1, -1)
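
The decode-side change mirrors the write: previously the latent was sliced back out of the single cache as kv_c_and_k_pe_cache[..., :kv_lora_rank], which yields a strided, non-contiguous view; now keys come from kv_cache[0] and values from the redundant kv_cache[1], so both fetch functions read contiguous storage, which is what the contiguous-PA path prefers. A small check of that contiguity claim (toy shapes, assumed widths):

    import torch

    num_blocks, block_size, head_size, latent = 8, 4, 576, 512
    k_cache = torch.randn(num_blocks, block_size, head_size)

    old_values = k_cache[..., :latent]             # old path: slice of the key cache
    print(old_values.is_contiguous())              # False, strided view

    v_cache = k_cache[..., :latent].contiguous()   # new path: separately stored copy
    print(v_cache.is_contiguous())                 # True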
3 changes: 2 additions & 1 deletion vllm/worker/cache_engine.py
@@ -112,7 +112,8 @@ def get_cache_block_size(
key_cache_block = cache_config.block_size * num_heads * head_size
# For MLA there is no value cache, since the latent vector
# is joint keys and values.
- value_cache_block = key_cache_block if not model_config.use_mla else 0
+ # value_cache_block = key_cache_block if not model_config.use_mla else 0
+ value_cache_block = key_cache_block // 9 * 8
total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
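
The accounting change matches the new shapes: the MLA value cache is no longer budgeted at zero but at 8/9 of the key cache block (note that, as written, the factor is applied whether or not model_config.use_mla is set). A worked example of the per-block numbers, with block_size = 128 and the 576-wide MLA head taken as assumptions rather than from the diff:

    # Sketch of the per-layer block accounting with assumed sizes.
    block_size, num_heads, head_size = 128, 1, 576
    key_cache_block = block_size * num_heads * head_size    # 73728 elements
    value_cache_block = key_cache_block // 9 * 8            # 65536 == 128 * 512
    total_per_layer = key_cache_block + value_cache_block   # 139264 elements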
23 changes: 11 additions & 12 deletions vllm/worker/hpu_worker.py
@@ -568,25 +568,24 @@ def _allocate_kv_cache(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)

use_mla = False
- if len(kv_cache_shape) == 2 and kv_cache_shape[1]:
+ if len(kv_cache_shape) == 2:
use_mla = True
- kv_cache_shape = kv_cache_shape[0]
+ k_cache_shape = kv_cache_shape[0]
+ v_cache_shape = kv_cache_shape[1]
+ else:
+ k_cache_shape = kv_cache_shape
+ v_cache_shape = kv_cache_shape

kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
dtype = self.dtype
if device != 'hpu' and not is_fake_hpu() \
and self.dtype == torch.float8_e4m3fn:
dtype = torch.uint8
for _ in range(self.num_attention_layers):
- if use_mla:
- kv_layer = torch.zeros(kv_cache_shape,
- dtype=dtype,
- device=device)
- else:
- key_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=device)
- value_cache = torch.zeros(kv_cache_shape,
- dtype=dtype,
- device=device)
- kv_layer = (key_cache, value_cache)
+ key_cache = torch.zeros(k_cache_shape, dtype=dtype, device=device)
+ value_cache = torch.zeros(v_cache_shape,
+ dtype=dtype,
+ device=device)
+ kv_layer = (key_cache, value_cache)
kv_cache.append(kv_layer)
return kv_cache
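
With the backend returning a (k_shape, v_shape) pair for MLA, the worker now always allocates a (key_cache, value_cache) tuple per layer; only the shapes differ between the MLA and non-MLA cases. A minimal standalone sketch of that allocation logic, with made-up sizes (the real code gets the shapes from the attention backend and also handles fp8 dtypes):

    from typing import List, Tuple
    import torch

    num_layers, num_blocks, block_size = 2, 16, 128
    # MLA backends now return two shapes: (latent + rope) keys, latent-only values.
    kv_cache_shape = ((num_blocks, block_size, 576), (num_blocks, block_size, 512))

    use_mla = len(kv_cache_shape) == 2
    if use_mla:
        k_shape, v_shape = kv_cache_shape
    else:
        k_shape = v_shape = kv_cache_shape

    kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
    for _ in range(num_layers):
        kv_cache.append((torch.zeros(k_shape), torch.zeros(v_shape)))
    # kv_cache[0][0].shape == (16, 128, 576); kv_cache[0][1].shape == (16, 128, 512)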
