From 756485fafaf029b9682fe95f690038619e1b8b19 Mon Sep 17 00:00:00 2001
From: "Chendi.Xue"
Date: Thu, 28 Nov 2024 03:09:38 -0600
Subject: [PATCH] [BUG FIX] [SPEC DECODE] 0.6.4 rebase caused incorrectness
 in spec decode, fixed in this PR (#523)

Noticed that Spec Decode produced incorrect results after the rebase to 0.6.4.

Identified the root causes and fixed them in this PR:
1. incorrect return value position in batch_expansion.py
2. Contiguous PA generates faulty results in spec decode

CI added: https://github.com/HabanaAI/vllm-fork/pull/524

---------

Signed-off-by: Chendi.Xue
---
 .jenkins/test_config.yaml                  |  7 +-
 vllm/model_executor/model_loader/loader.py |  7 +-
 vllm/spec_decode/batch_expansion.py        | 95 ++++++++++++++++++----
 vllm/spec_decode/hpu_draft_model_runner.py | 16 +++-
 vllm/spec_decode/ngram_worker.py           | 20 ++++-
 vllm/worker/hpu_model_runner.py            |  5 ++
 6 files changed, 125 insertions(+), 25 deletions(-)

diff --git a/.jenkins/test_config.yaml b/.jenkins/test_config.yaml
index 3707725161576..c729be31b3a1e 100644
--- a/.jenkins/test_config.yaml
+++ b/.jenkins/test_config.yaml
@@ -43,4 +43,9 @@ stages:
         command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 2
       - name: gsm8k_small_g2_tp2_mss
         flavor: g2.s
-        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 2
\ No newline at end of file
+        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 2
+  - name: test_gsm8k_spec_decode
+    steps:
+      - name: gsm8k_small_g2_tp1_spec_decode
+        flavor: g2
+        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 1
\ No newline at end of file
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 2c9e42d5b613c..0e91de1587246 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -346,9 +346,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
             if model_config.quantization is None and loaded_weights is not None:
                 weights_not_loaded = weights_to_load - loaded_weights
                 if weights_not_loaded:
-                    raise ValueError(
-                        "Following weights were not initialized from "
-                        f"checkpoint: {weights_not_loaded}")
+                    warning_msg = f"Following weights were not initialized \
+                        from checkpoint: {weights_not_loaded}"
+
+                    logger.warning(warning_msg)
 
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index b0fc1f7616802..0747b3a8c968e 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -6,6 +6,7 @@
 
 from vllm import SamplingParams
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
                            ExecuteModelRequest, SequenceData,
                            SequenceGroupMetadata, get_all_seq_ids)
@@ -159,11 +160,18 @@ def _contract_batch(
             target_sampler_output will be contracted to.
""" contracted_bs = len(contracted_seq_group_metadata_list) - (target_token_ids, target_probs, target_logprobs, target_hidden_states, - non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs, - non_spec_target_hidden_states) = self._split_scoring_output( - target_sampler_output, num_scoring_tokens) + if current_platform.is_hpu(): + (target_token_ids, target_probs, target_logprobs, + target_hidden_states, non_spec_target_token_ids, + non_spec_target_probs, non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output_hpu( + target_sampler_output, num_scoring_tokens) + else: + (target_token_ids, target_probs, target_logprobs, + target_hidden_states, non_spec_target_token_ids, + non_spec_target_probs, non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( + target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token # of shape [batch_size * k + 1] back to [batch_size, k + 1]. @@ -239,18 +247,30 @@ def _contract_batch_all_spec( # Map distinct sequences used to score each token # of shape [batch_size * k + 1] back to [batch_size, k + 1]. contracted_bs, k = proposals.proposal_token_ids.shape - - ( - target_sampler_output.sampled_token_ids, - target_sampler_output.sampled_token_probs, - target_sampler_output.logprobs, - target_sampler_output.hidden_states, - _, - _, - _, - _, - ) = self._split_scoring_output(target_sampler_output, - num_scoring_tokens) + if current_platform.is_hpu(): + ( + target_sampler_output.sampled_token_ids, + target_sampler_output.sampled_token_probs, + target_sampler_output.logprobs, + target_sampler_output.hidden_states, + _, + _, + _, + _, + ) = self._split_scoring_output_hpu(target_sampler_output, + num_scoring_tokens) + else: + ( + target_sampler_output.sampled_token_ids, + target_sampler_output.sampled_token_probs, + target_sampler_output.logprobs, + target_sampler_output.hidden_states, + _, + _, + _, + _, + ) = self._split_scoring_output(target_sampler_output, + num_scoring_tokens) # Reshape tensors to original batch size target_token_ids = target_sampler_output.sampled_token_ids.reshape( @@ -397,6 +417,47 @@ def _create_single_target_seq_group_metadata( token_chunk_size=1, ) + @staticmethod + def _split_scoring_output_hpu( + sampler_output: SamplerOutput, num_scoring_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: + """Split the target model output into speculative and non-speculative + output. + """ + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. + # + # First samples are from speculative scoring, latter samples are non- + # speculative samples. 
+        split_sizes = (num_scoring_tokens,
+                       sampler_output.sampled_token_ids.numel() -
+                       num_scoring_tokens)
+        (spec_probs, non_spec_probs
+         ) = sampler_output.sampled_token_probs.split(split_sizes)
+        (spec_sampled_tokens, non_spec_sampled_tokens
+         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
+        (
+            spec_logprobs,
+            non_spec_logprobs,
+        ) = sampler_output.logprobs.split(split_sizes)
+
+        if sampler_output.hidden_states is not None:
+            (
+                spec_hidden_states,
+                non_spec_hidden_states,
+            ) = sampler_output.hidden_states.split(split_sizes)
+        else:
+            spec_hidden_states, non_spec_hidden_states = None, None
+
+        return (spec_sampled_tokens, spec_probs, spec_logprobs,
+                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
+                non_spec_logprobs, non_spec_hidden_states)
+
     @staticmethod
     def _split_scoring_output(
         sampler_output: SamplerOutput, num_scoring_tokens: int
diff --git a/vllm/spec_decode/hpu_draft_model_runner.py b/vllm/spec_decode/hpu_draft_model_runner.py
index a5943bdd7d804..e53a46f188282 100644
--- a/vllm/spec_decode/hpu_draft_model_runner.py
+++ b/vllm/spec_decode/hpu_draft_model_runner.py
@@ -50,9 +50,19 @@ def execute_model(
         num_steps: int = 1,
     ) -> Optional[List[SamplerOutput]]:
         if previous_hidden_states is not None:
-            _, block_size = model_input.input_tokens.shape
-            previous_hidden_states = previous_hidden_states.expand(
-                block_size, -1).unsqueeze(0)
+            batch_size, block_size = model_input.input_tokens.shape
+            previous_hidden_states = previous_hidden_states.unsqueeze(
+                dim=1).expand(-1, block_size, -1)
+            # because HPU will pad batch_size,
+            # we need to pad previous_hidden_states as well
+            batch_size_padding = batch_size - previous_hidden_states.shape[0]
+            if batch_size_padding > 0:
+                dummy_previous_hidden_states = torch.zeros_like(
+                    previous_hidden_states[1:2]).expand(
+                        batch_size_padding, -1, -1)
+                previous_hidden_states = torch.cat(
+                    [previous_hidden_states, dummy_previous_hidden_states],
+                    dim=0)
         return super().execute_model(
             model_input=model_input,
             kv_caches=kv_caches,
diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py
index debb3b2d5ec30..c759551ad1246 100644
--- a/vllm/spec_decode/ngram_worker.py
+++ b/vllm/spec_decode/ngram_worker.py
@@ -4,11 +4,29 @@
 import torch
 
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 
+if current_platform.is_cuda_alike():
+    DEVICE_TYPE = "cuda"
+elif current_platform.is_neuron():
+    DEVICE_TYPE = "neuron"
+elif current_platform.is_hpu():
+    DEVICE_TYPE = "hpu"
+elif current_platform.is_openvino():
+    DEVICE_TYPE = "openvino"
+elif current_platform.is_cpu():
+    DEVICE_TYPE = "cpu"
+elif current_platform.is_tpu():
+    DEVICE_TYPE = "tpu"
+elif current_platform.is_xpu():
+    DEVICE_TYPE = "xpu"
+else:
+    raise ValueError(f"Unsupported platform: {current_platform}")
+
 
 class NGramWorker(NonLLMProposerWorkerBase):
     """NGramWorker provides a light drafter without need for model.
@@ -34,7 +52,7 @@ def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
         self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
 
     def init_device(self):
-        self.device = torch.device(f"cuda:{self.local_rank}")
+        self.device = torch.device(f"{DEVICE_TYPE}:{self.local_rank}")
         self.load_model = lambda *args, **kwargs: None
 
         # Current NGramWorker only supports Top1Proposer
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 71bd0bb8f14d8..ded3aa91a0579 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -538,6 +538,11 @@ def __init__(
         self._set_gc_threshold()
         self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
                                                 'true').lower() == 'true'
+        if vllm_config.speculative_config is not None \
+            and self.use_contiguous_pa:
+            raise ValueError(
+                "Speculative decoding is not supported with "
+                "contiguous PA, please set VLLM_CONTIGUOUS_PA=false")
 
         # For multi-step scheduling
         self.cached_step_outputs: List[torch.Tensor] = []
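
Note (not part of the patch): the minimal sketch below illustrates the positional contract that the new _split_scoring_output_hpu helper keeps. The first num_scoring_tokens entries of the flattened sampler output are the speculative scores, and the remainder are non-speculative samples; returning them in the wrong positions is the kind of mismatch this PR fixes. The helper name and shapes here are illustrative only.

    import torch

    def split_scores(sampled_token_ids: torch.Tensor, num_scoring_tokens: int):
        # First num_scoring_tokens entries come from speculative scoring;
        # the rest are non-speculative samples. The caller relies on this
        # positional order when contracting the expanded batch.
        flat = sampled_token_ids.flatten()
        split_sizes = (num_scoring_tokens, flat.numel() - num_scoring_tokens)
        spec, non_spec = flat.split(split_sizes)
        return spec, non_spec

    # Example: 2 sequences scored with k=3 draft tokens each -> 8 scoring
    # tokens, plus one non-speculative sequence appended at the end.
    ids = torch.arange(9)
    spec, non_spec = split_scores(ids, num_scoring_tokens=8)
    assert spec.tolist() == list(range(8)) and non_spec.tolist() == [8]

Separately, with the new check in hpu_model_runner.py, speculative decoding on HPU requires contiguous PA to be disabled (for example, export VLLM_CONTIGUOUS_PA=false before launching); otherwise initialization raises the ValueError added above.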