diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 80c132c0a8c05..f4f58a1595894 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -73,6 +73,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     is_prompt: bool
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
+    context_lens_tensor: Optional[torch.Tensor]
 
 
 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py
index 4c0fb2a628361..603d3959377c4 100644
--- a/vllm/attention/ops/hpu_paged_attn.py
+++ b/vllm/attention/ops/hpu_paged_attn.py
@@ -21,6 +21,7 @@ class HPUPagedAttentionMetadata:
     block_indices: Optional[torch.Tensor]
     block_offsets: Optional[torch.Tensor]
     block_scales: Optional[torch.Tensor]
+    block_groups: Optional[torch.Tensor]
 
 
 class HPUPagedAttention:
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index da0ce1885dbb2..95fe933c9e3e6 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -184,6 +184,26 @@ def forward_cpu(
             num_expert_group,
         )
 
+    def forward_hpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        **kwargs,
+    ):
+        assert not use_grouped_topk, "use_grouped_topk must be False on HPU"
+        assert num_expert_group is None, ('num_expert_group is '
+                                          'not supported on HPU')
+        assert topk_group is None, "topk_group is not supported on HPU"
+        if layer is not None:
+            return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
+                                       router_logits, top_k)
+
     def forward_tpu(
         self,
         layer: torch.nn.Module,
@@ -283,6 +303,9 @@ def __init__(
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
+        if current_platform.is_hpu():
+            from vllm_hpu_extension.ops import DynamicFusedMOE
+            self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index fbb3704fa080f..733c413ed68b9 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -44,6 +44,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -499,4 +500,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
+        if current_platform.is_hpu():
+            torch.hpu.synchronize()
         return loaded_params
diff --git a/vllm/utils.py b/vllm/utils.py
index 15481fb06e08e..2db1eda8b9f20 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -352,6 +352,11 @@ def reset(self):
         self._index = 0
 
 
+@lru_cache(maxsize=None)
+def is_fake_hpu() -> bool:
+    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'
+
+
 @cache
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
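Note on the fused MoE path: forward_hpu (fused_moe/layer.py above) delegates to DynamicFusedMOE from vllm_hpu_extension, whose kernel is outside this patch. As a rough sketch of the semantics it stands in for, assuming Mixtral-style SwiGLU experts and the w13/w2 weight layout the layer uses (an unfused reference for intuition, not the HPU implementation):

import torch
import torch.nn.functional as F

def moe_reference(x, w13, w2, router_logits, top_k, renormalize=True):
    """Unfused top-k MoE: x (T, H), w13 (E, 2I, H), w2 (E, H, I)."""
    weights = F.softmax(router_logits, dim=-1)             # (T, E)
    topk_w, topk_ids = torch.topk(weights, top_k, dim=-1)  # (T, K)
    if renormalize:
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for e in range(w13.size(0)):
        # Tokens routed to expert e, plus which of their k slots chose it.
        token_idx, k_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w13[e].t()                      # (t, 2I)
        gate, up = h.chunk(2, dim=-1)
        act = F.silu(gate) * up                            # SwiGLU
        out[token_idx] += topk_w[token_idx, k_idx].unsqueeze(-1) * (
            act @ w2[e].t())
    return out

The asserts in forward_hpu reflect that only this plain softmax/top-k routing is supported on HPU; grouped top-k (topk_group, num_expert_group) is not.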
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index a339c97a8383c..dc50190854bde 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -23,6 +23,7 @@
 import torch
 import torch.nn as nn
 from vllm_hpu_extension.ops import LoraMask as LoraMask
+from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
                                          HabanaMemoryProfiler, format_bytes)
 
@@ -42,7 +43,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (bind_kv_cache, is_pin_memory_available,
+from vllm.utils import (bind_kv_cache, is_fake_hpu, is_pin_memory_available,
                         make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
@@ -258,10 +259,19 @@ def setup_profiler():
     return profiler
 
 
-def pad_list(list, k, v):
-    target_len = round_up(len(list), k)
-    padding = target_len - len(list)
-    return list + [v] * padding
+def pad_list(input, k, v):
+    input_len = len(input)
+    target_len = round_up(input_len, k)
+    padding = target_len - input_len
+    return input + [v] * padding
+
+
+def gather_list(input, indices, v):
+    return [input[i] if i is not None else v for i in indices]
+
+
+def flatten(in_list):
+    return list(itertools.chain(*in_list))
 
 
 def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
@@ -332,23 +342,60 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
         mask = mask >= metadata.block_usage.unsqueeze(-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
-        block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
-                                                    num_classes=batch_size)
+        if not is_fake_hpu() and htorch.utils.internal.is_lazy():
+            block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
+                                                        num_classes=batch_size)
+        else:
+            # Unfortunately one_hot on CPU/torch.compile mode/eager mode
+            # doesn't handle out of bounds classes so we need to convert
+            # all negative values to 0 (block_mapping) or bs (block_groups)
+            block_groups = metadata.block_groups.to(torch.long)
+            block_mapping = torch.nn.functional.relu(block_groups)
+            block_mapping = torch.nn.functional.one_hot(block_mapping,
+                                                        num_classes=batch_size)
+            oob_values = block_groups.lt(0)
+            block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
+            block_groups.masked_fill_(oob_values, batch_size)
+            metadata = metadata._replace(block_groups=block_groups)
         block_mapping = block_mapping.to(dtype)
         metadata = metadata._replace(block_mapping=block_mapping,
                                      attn_bias=attn_bias)
         return metadata
 
+    def _set_block_scales(self, metadata, device):
+        block_mapping = metadata.block_mapping
+        ones = torch.ones((block_mapping.size(0), ),
+                          device=device,
+                          dtype=block_mapping.dtype)
+        sums = batch2block(block2batch(ones, block_mapping), block_mapping)
+        block_scales = torch.reciprocal(torch.maximum(ones, sums))
+        metadata = metadata._replace(block_scales=block_scales)
+        return metadata
+
+    def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
+        slot_mapping = metadata.slot_mapping.flatten()
+        indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+        if is_prompt:
+            indices = indices.unflatten(0, (-1, block_size))[:, 0]
+            offsets = None
+        else:
+            offsets = torch.fmod(slot_mapping, block_size)
+        metadata = metadata._replace(block_offsets=offsets,
+                                     block_indices=indices)
+        return metadata
+
     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                          dtype):
         if attn_metadata.is_prompt:
-            meta = attn_metadata
-            attn_metadata = self._set_attn_bias(meta, batch_size, seq_len,
-                                                device, dtype)
+            attn_metadata = self._set_attn_bias(attn_metadata, batch_size,
+                                                seq_len, device, dtype)
         else:
-            meta = attn_metadata
-            attn_metadata = self._set_block_mapping(meta, batch_size, device,
-                                                    dtype)
+            attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
+                                                    device, dtype)
+            attn_metadata = self._set_block_scales(attn_metadata, device)
+        attn_metadata = self._set_indices_and_offsets(attn_metadata,
+                                                      self.block_size,
+                                                      attn_metadata.is_prompt)
         return attn_metadata
 
     def forward(self, *args, **kwargs):
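Note on the non-lazy branch of _set_block_mapping: one_hot raises on out-of-bounds class ids on CPU and in eager/torch.compile mode, and padded blocks carry block_groups == -1, so the patch clamps pads to class 0 and then zeroes the corresponding one-hot rows. A standalone repro with toy values (not runner code):

import torch

batch_size = 4
block_groups = torch.tensor([0, 0, 1, -1, -1])  # -1 marks padded blocks

block_mapping = torch.nn.functional.relu(block_groups)       # -1 -> 0
block_mapping = torch.nn.functional.one_hot(block_mapping,
                                            num_classes=batch_size)
oob_values = block_groups.lt(0)
block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)  # zero pad rows
block_groups = block_groups.masked_fill(oob_values, batch_size)
print(block_mapping)  # rows 3 and 4 are all zeros

The lazy HPU path keeps the direct one_hot call, which (per the comment in the patch) tolerates the out-of-bounds ids there.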
@@ -584,6 +631,8 @@ def __init__(
         self.bucketing_global_state = HPUBucketingGlobalState()
         self._setup_buckets()
         self._set_gc_threshold()
+        self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
+                                                'true').lower() == 'true'
 
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -888,6 +937,11 @@ def _prepare_prompt(
         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, True)
 
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.long,
+                                           device='cpu')
+        context_lens_tensor = context_lens_tensor.to(self.device,
+                                                     non_blocking=True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
             block_list=None,
@@ -896,8 +950,10 @@ def _prepare_prompt(
             block_indices=block_indices,
             block_offsets=block_offsets,
             block_scales=None,
+            block_groups=None,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
+            context_lens_tensor=context_lens_tensor,
             num_prefills=real_num_seqs,
             num_prefill_tokens=sum_query_len,
             num_decode_tokens=0,
@@ -923,6 +979,7 @@ def _prepare_prompt(
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        output=None,
     ) -> PrepareDecodeMetadata:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -953,8 +1010,9 @@ def _prepare_decode(
 
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
-                generation_token = seq_data.get_last_token_id()
-                input_tokens.append([generation_token])
+                if output is None:
+                    generation_token = seq_data.get_last_token_id()
+                    input_tokens.append([generation_token])
 
                 seq_len = seq_data.get_len()
                 position = seq_len - 1
@@ -965,6 +1023,9 @@ def _prepare_decode(
                 seq_lens.append(seq_len)
 
                 block_table = seq_group_metadata.block_tables[seq_id]
+                num_fully_occupied_blocks = position // self.block_size
+                block_table = block_table[:num_fully_occupied_blocks + 1]
+
                 if len(block_table) == 0:
                     block_number = _PAD_BLOCK_ID
                 else:
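Worked example of the block-table trimming added to _prepare_decode above: only the blocks that can hold KV entries up to the current position are kept, so any tail blocks the allocator handed out ahead of time never leak into the padded block lists. Numbers below are made up, assuming block_size = 128:

# position = seq_len - 1, as in the runner code above
block_size = 128
seq_len = 200
position = seq_len - 1                               # 199
num_fully_occupied_blocks = position // block_size   # 1
block_table = [7, 12, 31, 44]   # allocator may have over-allocated
block_table = block_table[:num_fully_occupied_blocks + 1]
assert block_table == [7, 12]   # block 12 holds the slot for position 199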
@@ -984,76 +1045,85 @@ def _prepare_decode(
                     block_table = block_table[-sliding_window_blocks:]
                 block_tables.append(block_table)
 
-        input_tokens = torch.tensor(input_tokens,
-                                    dtype=torch.long,
-                                    device=self.device)
+        if output is None:
+            input_tokens = torch.tensor(input_tokens,
+                                        dtype=torch.long,
+                                        device=self.device)
+        else:
+            real_batch_size = len(seq_group_metadata_list)
+            input_tokens = output[:real_batch_size]
+
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.long,
                                        device=self.device)
 
         num_decode_tokens = sum(seq_lens)
 
-        blocks_used = [len(bt) for bt in block_tables if bt]
-        block_list = []
-        block_scales = []
-        for i, bt in enumerate(block_tables):
-            block_list.extend(bt)
-            blocks_in_group = len(bt)
-            if blocks_in_group > 0:
-                scale = 1.0 / blocks_in_group
-                block_scales.extend([scale] * blocks_in_group)
-
-        block_mapping_nested: List[List[int]] = [
-            [i] * b_u for i, b_u in enumerate(blocks_used)
+        last_block_usage = [
+            slot[0] % self.block_size + 1 for slot in slot_mapping
         ]
-        block_mapping: List[int] = list(
-            itertools.chain.from_iterable(block_mapping_nested))
+        block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
+        block_usage = [[self.block_size] * (len(bt) - 1) + [lbu]
+                       for bt, lbu in zip(block_tables, last_block_usage)
+                       if bt]
+
+        block_list = flatten(block_tables)
+        block_groups = flatten(block_groups)
+        block_usage = flatten(block_usage)
+
+        assert len(block_list) == len(block_groups)
+        assert len(block_list) == len(block_usage)
+
+        padding_fn = None
+        if self.use_contiguous_pa:
+            block_bucket_size = find_bucket(
+                max(block_list) + 1,
+                self.bucketing_global_state.decode_block_bucket_cfg)
+            indices: List[Any]
+            indices = [None] * block_bucket_size
+            for i, bid in enumerate(block_list):
+                indices[bid] = i
+            padding_fn = lambda tensor, pad_value: gather_list(
+                tensor, indices, pad_value)
+        else:
+            block_bucket_size = find_bucket(
+                len(block_list),
+                self.bucketing_global_state.decode_block_bucket_cfg)
+            padding_fn = lambda tensor, pad_value: pad_list(
+                tensor, block_bucket_size, pad_value)
 
-        last_block = [
-            sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)
-        ]
-        block_usage = [[self.block_size] * (b_u - 1) + [lb]
-                       for b_u, lb in zip(blocks_used, last_block)]
-        block_usage = list(itertools.chain(*block_usage))
-
-        block_bucket_size = find_bucket(
-            len(block_list),
-            self.bucketing_global_state.decode_block_bucket_cfg)
-        block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
-        block_mapping = pad_list(block_mapping, block_bucket_size, -1)
-        block_usage = pad_list(block_usage, block_bucket_size, 1)
-        block_scales = pad_list(block_scales, block_bucket_size, 0.0)
+        block_list = padding_fn(block_list, _PAD_BLOCK_ID)
+        block_groups = padding_fn(block_groups, -1)
+        block_usage = padding_fn(block_usage, 1)
 
         block_list = torch.tensor(block_list,
                                   dtype=torch.int,
                                   device=self.device)
-        block_mapping = torch.tensor(block_mapping,
-                                     dtype=torch.long,
-                                     device=self.device)
+        block_groups = torch.tensor(block_groups,
+                                    dtype=torch.int,
+                                    device=self.device)
         block_usage = torch.tensor(block_usage,
                                    dtype=self.model_config.dtype,
                                    device=self.device)
-
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
                                     device=self.device)
 
         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, False)
-        block_scales = torch.tensor(block_scales,
-                                    dtype=self.model_config.dtype,
-                                    device=self.device)
 
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
-            block_mapping=block_mapping,
+            block_mapping=None,
             block_usage=block_usage,
             block_indices=block_indices,
             block_offsets=block_offsets,
-            block_scales=block_scales,
+            block_scales=None,
+            block_groups=block_groups,
             attn_bias=None,
             seq_lens_tensor=None,
+            context_lens_tensor=None,
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
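Illustration of the contiguous-PA padding above: with VLLM_CONTIGUOUS_PA enabled, blocks are scattered so that entry j of every padded list describes physical block j of the KV cache, with holes filled by pad values. gather_list is the helper added earlier in this patch; the block ids and pad value below are made up:

block_list = [3, 0, 5]      # physical blocks used by this decode step
block_bucket_size = 8       # find_bucket(max(block_list) + 1, ...)

indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
    indices[bid] = i        # position of each block id in the lists

def gather_list(input, indices, v):
    return [input[i] if i is not None else v for i in indices]

_PAD_BLOCK_ID = 0           # stand-in for the runner's constant
padded = gather_list(block_list, indices, _PAD_BLOCK_ID)
assert padded == [0, 0, 0, 3, 0, 5, 0, 0]  # used ids sit at their own index

This is what lets the attention kernel address the cache contiguously by block id; the fallback branch keeps the old behavior of padding the lists to the bucketed length with pad_list.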
@@ -1263,9 +1333,10 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         #         input_hash(123) != input_hash(321)
         #         input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
-            'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
-            'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-            'block_offsets', 'block_scales'
+            'attn_bias', 'seq_lens_tensor', 'context_lens_tensor',
+            'block_list', 'block_mapping', 'block_usage', 'slot_mapping',
+            'is_prompt', 'block_indices', 'block_offsets', 'block_scales',
+            'block_groups'
         ])
         return attention_metadata
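Finally, a sketch of what _set_block_scales computes, under the assumption that batch2block/block2batch act like matmuls against the one-hot block_mapping (the actual vllm_hpu_extension implementations may differ): every block of a sequence that owns N blocks receives scale 1/N, matching the Python-side block_scales list this patch removes from _prepare_decode.

import torch

def block2batch(x, block_mapping):
    # Sum per-block values into their sequence: (B,) = (N, B)^T @ (N,)
    return block_mapping.t() @ x

def batch2block(x, block_mapping):
    # Broadcast per-sequence values back to each of its blocks.
    return block_mapping @ x

batch_size, num_blocks = 2, 5
block_groups = torch.tensor([0, 0, 0, 1, 1])
block_mapping = torch.nn.functional.one_hot(
    block_groups, num_classes=batch_size).float()

ones = torch.ones(num_blocks)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
print(block_scales)  # tensor([0.3333, 0.3333, 0.3333, 0.5000, 0.5000])

Padded blocks (all-zero rows in block_mapping) produce sums == 0, so torch.maximum(ones, sums) pins their scale to 1.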