vLLM-Base: Resolved ALiBI bias regression
- Works in lazy and eager mode

Co-authored-by: Tanner Voas <[email protected]>
Co-authored-by: Haihao Xiang <[email protected]>
Signed-off-by: Tanner Voas <[email protected]>
tannervoas742 and xhaihao committed Nov 26, 2024
1 parent 5eb8b1f commit b339767
Showing 8 changed files with 195 additions and 63 deletions.
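For context on the fix below: ALiBi (Attention with Linear Biases) replaces positional embeddings with a per-head bias that grows linearly with the distance between query and key positions, so the attention kernels must add this bias to the raw scores in both the prompt and decode paths. A minimal, self-contained sketch of the standard ALiBi construction (illustrative only, not code from this commit; it assumes a power-of-two head count):

    import torch

    def make_alibi_slopes(num_heads: int) -> torch.Tensor:
        # Standard ALiBi schedule: slope_i = 2^(-8*i/num_heads) for i = 1..num_heads.
        start = 2.0 ** (-8.0 / num_heads)
        return start ** torch.arange(1, num_heads + 1)

    def make_alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
        # bias[h, q, k] = slope_h * (k - q), clamped so future positions get no extra bias.
        distance = torch.arange(seq_len)[None, :] - torch.arange(seq_len)[:, None]
        distance = distance.clamp(max=0)
        return make_alibi_slopes(num_heads)[:, None, None] * distance  # (heads, seq, seq)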
82 changes: 69 additions & 13 deletions vllm/attention/backends/hpu_attn.py
@@ -134,6 +134,16 @@ def __init__(
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'
if self.use_contiguous_pa:
assert alibi_slopes is None, \
'Contiguous PA not supported with alibi slopes!'

if ops.ACTUAL_PA_SOFTMAX_IMPL != 'wsum_head_amax':
assert alibi_slopes is None, \
'Alibi slopes supports only "wsum_head_amax" softmax implementation!'

suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
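The asserts added above mean ALiBi on HPU still requires a specific paged-attention setup: contiguous PA has to be off and the "wsum_head_amax" softmax implementation has to be the active one. A hedged sketch of how a user might satisfy the first constraint before the engine initializes (VLLM_CONTIGUOUS_PA is the variable read in the __init__ above; the launch lines are illustrative, not from this commit):

    import os

    # Must be set before the HPU attention backend is constructed; 'false' avoids the
    # "Contiguous PA not supported with alibi slopes!" assertion for ALiBi models.
    os.environ.setdefault("VLLM_CONTIGUOUS_PA", "false")

    # from vllm import LLM                   # typical entry point (illustrative)
    # llm = LLM(model="mosaicml/mpt-7b")     # an ALiBi-based model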
@@ -202,11 +212,15 @@ def forward(
'attn_bias must be set before calling model.forward'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
-position_bias = _make_alibi_bias(
-    self.alibi_slopes, self.num_kv_heads,
-    attn_bias.dtype, attn_bias.shape[-1])
+position_bias = _make_prompt_alibi_bias(
+    self.alibi_slopes,
+    self.num_kv_heads,
+    self.alibi_slopes.dtype,
+    attn_bias.shape[-1],
+)
attn_bias = attn_bias.tile(
-    (1, self.num_kv_heads, 1, 1))
+    (1, self.num_kv_heads, 1, 1)
+)
attn_bias.add_(position_bias)
else:
attn_bias = None
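Shape-wise, the prompt path above starts from a mask-style attn_bias of shape (batch, 1, seq, seq), tiles it across the key/value heads, and then adds the per-head ALiBi bias in place. A small standalone sketch of that broadcasting (shapes and names are illustrative, not taken from the kernel):

    import torch

    batch, num_kv_heads, seq_len = 2, 8, 16
    attn_bias = torch.zeros(batch, 1, seq_len, seq_len)              # causal/padding mask
    position_bias = torch.randn(1, num_kv_heads, seq_len, seq_len)   # per-head ALiBi bias

    attn_bias = attn_bias.tile((1, num_kv_heads, 1, 1))   # -> (batch, heads, seq, seq)
    attn_bias.add_(position_bias)                          # broadcasts over the batch dim
    assert attn_bias.shape == (batch, num_kv_heads, seq_len, seq_len)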
@@ -242,6 +256,17 @@ def forward(
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
self.position_bias = None
attn_bias = attn_metadata.attn_bias
alibi_blocks = attn_metadata.alibi_blocks
if self.alibi_slopes is not None and alibi_blocks is not None:
self.position_bias = _make_decode_alibi_bias(
alibi_blocks,
self.alibi_slopes,
self.num_kv_heads,
self.alibi_slopes.dtype,
)

output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
@@ -252,17 +277,21 @@ def forward(
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
alibi_slopes=self.position_bias,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
-values_fetch_func=self.v_cache.fetch_from_cache)
+values_fetch_func=self.v_cache.fetch_from_cache,
+)

# Reshape the output tensor.
-return output.view(batch_size, seq_len, hidden_size)
+output = output.view(batch_size, seq_len, hidden_size)
+return output


-def _make_alibi_bias(
+def _make_prompt_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
@@ -280,15 +309,42 @@ def _make_alibi_bias(

padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
-bias = torch.empty(
-    1,  # batch size
+per_head_bias = torch.empty(
+    1,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
-)[:, :, :, :seq_len].copy_(bias)
-bias.mul_(alibi_slopes[:, None, None])
+)[:, :, :, :seq_len]
+# NOTE(Tanner):
+# .copy_ was not performing broadcasting of bias to all 32 heads in Eager mode.
+per_head_bias[:, :] = bias
+per_head_bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
per_head_bias = per_head_bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))

return per_head_bias


def _make_decode_alibi_bias(
alibi_blocks: torch.Tensor,
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
) -> torch.Tensor:
num_heads = alibi_slopes.shape[0]
per_head_bias = torch.empty(
alibi_blocks.size(0), # num blocks
num_heads,
alibi_blocks.size(-1),
device=alibi_slopes.device,
dtype=dtype,
)
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias to all 32 heads in Eager mode.
per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)
per_head_bias.mul_(alibi_slopes[None, :, None])
if num_heads != num_kv_heads:
-    bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
-return bias
+    per_head_bias = per_head_bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
+return per_head_bias
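The NOTE(Tanner) comments above point at the core of the regression fix: copying a shared bias into a sliced, multi-head destination with .copy_() did not broadcast across all heads in eager mode, so both helpers now use slice assignment followed by an in-place multiply with the slopes. A minimal standalone illustration of that pattern, independent of HPU (tensor names and sizes are made up):

    import torch

    num_heads, seq_len = 4, 8
    bias = torch.randn(1, 1, seq_len, seq_len)    # one bias to share across heads
    slopes = torch.arange(1, num_heads + 1, dtype=torch.float32)

    per_head_bias = torch.empty(1, num_heads, seq_len, seq_len)
    per_head_bias[:, :] = bias                    # slice assignment broadcasts to every head
    per_head_bias.mul_(slopes[:, None, None])     # per-head scaling, as in the helpers above

    assert torch.allclose(per_head_bias[0, 2], bias[0, 0] * slopes[2])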
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
@@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata:
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
alibi_blocks: Optional[torch.Tensor]


class HPUPagedAttention:
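The new alibi_blocks field extends the existing set of optional per-block tensors; presumably it stays None for non-ALiBi models and carries the per-block position information that _make_decode_alibi_bias consumes in the decode path above. A partial sketch limited to the fields visible in this diff (the real dataclass has additional fields not shown here):

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class HPUPagedAttentionMetadataSketch:  # illustrative subset, not the real class
        block_offsets: Optional[torch.Tensor]
        block_scales: Optional[torch.Tensor]
        block_groups: Optional[torch.Tensor]
        alibi_blocks: Optional[torch.Tensor]  # new in this commit; None without ALiBi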
7 changes: 4 additions & 3 deletions vllm/model_executor/models/baichuan.py
@@ -126,7 +126,7 @@ def __init__(
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
-self.postion_embedding = position_embedding
+self.position_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

@@ -146,7 +146,7 @@ def __init__(
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
-if self.postion_embedding == "ALIBI":
+if self.position_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
@@ -182,7 +182,7 @@ def forward(
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
-if self.postion_embedding != "ALIBI":
+if self.position_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
@@ -357,6 +357,7 @@ def __init__(
self.lora_config = lora_config

self.quant_config = quant_config
self.use_alibi = position_embedding == "ALIBI"
self.model = BaiChuanModel(vllm_config=vllm_config,
prefix=prefix,
position_embedding=position_embedding)
1 change: 1 addition & 0 deletions vllm/model_executor/models/bloom.py
@@ -293,6 +293,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.use_alibi = True
self.transformer = BloomModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
2 changes: 1 addition & 1 deletion vllm/model_executor/models/falcon.py
@@ -346,7 +346,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
-self.use_alibi = config.alibi

# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
@@ -417,6 +416,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
+self.use_alibi = config.alibi
self.transformer = FalconModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
1 change: 1 addition & 0 deletions vllm/model_executor/models/jais.py
@@ -297,6 +297,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.use_alibi = config.position_embedding_type == "alibi"
self.transformer = JAISModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
2 changes: 1 addition & 1 deletion vllm/model_executor/models/mpt.py
@@ -281,7 +281,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
assert config.tie_word_embeddings
self.quant_config = quant_config

self.use_alibi = config.attn_config['alibi']
self.transformer = MPTModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "transformer"))
self.lm_head = self.transformer.wte
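The model-file changes above all follow one pattern: each *ForCausalLM wrapper now exposes a use_alibi flag derived from its config (hard-coded True for BLOOM, config.alibi for Falcon, attn_config['alibi'] for MPT, and the position-embedding type for Baichuan and JAIS), so the HPU runner can decide whether to build alibi_blocks. A hedged sketch of how such a flag might be consumed; the runner-side changes live in the files GitHub did not expand in this view, so the function below is illustrative:

    import torch

    def needs_alibi_blocks(model: torch.nn.Module) -> bool:
        # Models touched by this commit set `use_alibi` on their ForCausalLM wrapper;
        # anything else defaults to the non-ALiBi path.
        return getattr(model, "use_alibi", False)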
