
Commit

fix bug
yangw1234 authored and xuechendi committed Feb 3, 2025
1 parent c7c55ef commit de8ca17
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions vllm/attention/backends/hpu_attn.py
@@ -185,9 +185,6 @@ def forward(
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
input_positions = attn_metadata.input_positions.view(-1)
print("q_pe", q_pe.shape)
print("k_pe", k_pe.shape)
print("input_positions", attn_metadata.input_positions.shape)
q_pe, k_pe = \
self.rotary_emb(input_positions, q_pe, k_pe)
else:
@@ -196,9 +193,6 @@ def forward(

q_pe = q[..., self.qk_nope_head_dim:]

# print("q_pe shape", q_pe.shape)
# print("k_pe shape", k_pe.shape)
# print("input_positions shape", attn_metadata.input_positions.shape)
input_positions = attn_metadata.input_positions.view(-1)
# TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \
@@ -210,7 +204,11 @@ def forward(
latent_vec = torch.concat(
(k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)), dim=-1)
# assert layer._k_scale == 0, f"got _k_scale={layer._k_scale}"
# print(f"layer._k_scale={layer._k_scale}")
latent_vec = latent_vec.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
if is_prefill:
latent_vec = latent_vec.unflatten(0, (block_indices.size(0), -1))
# print("latent_vec", latent_vec.shape)


# write the latent and rope to kv cache
if kv_cache is not None:
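For context on the cache-write path these hunks touch: latent_vec packs the compressed KV latent together with the rotary part of K, flattens it to one row per token, and, for prefill, regroups the rows by cache block. A minimal standalone sketch of that shape handling follows; all sizes are illustrative assumptions, not values from the repository.

# Illustrative sketch only: batch_size, seq_len, kv_lora_rank, qk_rope_head_dim
# and num_blocks are assumed values chosen for demonstration.
import torch

batch_size, seq_len = 2, 8
kv_lora_rank, qk_rope_head_dim = 512, 64

k_c_normed = torch.randn(batch_size, seq_len, kv_lora_rank)    # compressed KV latent
k_pe = torch.randn(batch_size * seq_len, 1, qk_rope_head_dim)  # rotary ("rope") part of K

# Concatenate latent and rope parts along the feature dim, mirroring the diff above.
latent_vec = torch.concat(
    (k_c_normed, k_pe.view(batch_size, -1, qk_rope_head_dim)), dim=-1)

# Flatten to one row per token before the cache write.
latent_vec = latent_vec.view(-1, qk_rope_head_dim + kv_lora_rank)

# Prefill path: regroup rows by block (num_blocks stands in for block_indices.size(0)).
num_blocks = 4
latent_vec = latent_vec.unflatten(0, (num_blocks, -1))
print(latent_vec.shape)  # torch.Size([4, 4, 576])
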
@@ -271,16 +269,10 @@ def _forward_decode(
attn_metadata: HPUAttentionMetadata,
batch_size: int
) -> torch.Tensor:
print(f"q_nope shape: {q_nope.shape}")
print(f"q_pe shape: {q_pe.shape}")

q = torch.cat([q_nope, q_pe], dim=-1)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]

print(f"q shape: {q.shape}")
print(f"kv_c_and_k_pe_cache shape: {kv_c_and_k_pe_cache.shape}")
print(f"kv_c_cache shape: {kv_c_cache.shape}")
output = HPUPagedAttention.forward_decode(
query=q,
key_cache=kv_c_and_k_pe_cache,
@@ -298,10 +290,8 @@ def _forward_decode(
keys_fetch_func=self.latent_cache.fetch_from_cache,
values_fetch_func=self.latent_cache.fetch_from_cache)
output = output.view(batch_size, 1, -1)
print("output", output.shape)
result = self._v_up_proj_and_o_proj(output)
result = result.view(batch_size, 1, -1)
print("result", result.shape)
return result
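
The prints deleted from _forward_decode were inspecting the shapes of q, the full latent cache, and its latent-only slice. The standalone sketch below reconstructs that shape flow; every size here is an assumed value, the cache layout is an assumption, and HPUPagedAttention itself is not called.

# Illustrative sketch only: shapes and the (blocks, block_size, 1, head_dim)
# cache layout are assumptions for demonstration.
import torch

batch_size, num_heads = 2, 16
kv_lora_rank, qk_rope_head_dim = 512, 64
num_blocks, block_size = 8, 128

q_nope = torch.randn(batch_size, num_heads, kv_lora_rank)
q_pe = torch.randn(batch_size, num_heads, qk_rope_head_dim)

# The decode query carries both the latent ("nope") and rotary ("pe") parts.
q = torch.cat([q_nope, q_pe], dim=-1)

# The cache stores latent+rope rows; a singleton head dim is inserted, and the
# value view is just the latent slice.
kv_c_and_k_pe_cache = torch.randn(
    num_blocks, block_size, kv_lora_rank + qk_rope_head_dim).unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :kv_lora_rank]

print(q.shape)                    # torch.Size([2, 16, 576])
print(kv_c_and_k_pe_cache.shape)  # torch.Size([8, 128, 1, 576])
print(kv_c_cache.shape)           # torch.Size([8, 128, 1, 512])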


