diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d083b962eb282..d735bd60ea9ea 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -124,8 +124,8 @@ def __init__(
         # List of (timestamp, num_tokens)
         self.num_generation_tokens: List[Tuple[float, int]] = []
 
-        self.speculative_decoding = True
-        if self.speculative_decoding:
+        self.spec_handler = None
+        if spec_decoding_config is not None:
             self.spec_handler = SpeculativeHandler(spec_decoding_config)
 
     def _init_workers(self, distributed_init_method: str):
@@ -565,7 +565,7 @@ def step(self) -> List[RequestOutput]:
             return ignored
 
         draft_tokens = None
-        if self.speculative_decoding:
+        if self.spec_handler:
             draft_tokens = self.spec_handler.propose()
 
         # Execute the model.
@@ -578,7 +578,7 @@ def step(self) -> List[RequestOutput]:
             draft_tokens=draft_tokens
         )
 
-        if self.speculative_decoding:
+        if self.spec_handler:
             self.spec_handler.accept(output)
             self.spec_handler.invalidate_draft_kv()
             self.spec_handler.invalidate_target_kv()
diff --git a/vllm/sequence.py b/vllm/sequence.py
index ecfaee6e8c3d6..f2e808d9ed926 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -372,6 +372,10 @@ def __init__(
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
         self.logprobs = logprobs
+
+        # only for speculative decoding
+        self.sd_output_tokens = None
+        self.sd_logprobs_list = None
 
     def __repr__(self) -> str:
         return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
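
Note for reviewers: the engine changes above only exercise four methods on the handler (`propose`, `accept`, `invalidate_draft_kv`, `invalidate_target_kv`), and its implementation is not part of this diff. Below is a minimal sketch of the interface those call sites assume; the class body, the `num_draft_tokens` config field, and the docstring descriptions are assumptions for illustration, not the actual `SpeculativeHandler` code.

```python
# Hypothetical interface sketch inferred from the call sites in llm_engine.step().
# Only the four method names come from the diff; everything else is assumed.
from typing import List, Optional


class SpeculativeHandler:
    def __init__(self, spec_decoding_config) -> None:
        # Assumed field; the real config layout is not shown in this diff.
        self.num_draft_tokens: int = getattr(spec_decoding_config,
                                             "num_draft_tokens", 4)

    def propose(self) -> Optional[List[int]]:
        """Run the draft model and return candidate tokens for the target
        model to verify in the next engine step."""
        raise NotImplementedError

    def accept(self, output) -> None:
        """Compare the target-model output against the proposed draft tokens
        and keep only the verified prefix."""
        raise NotImplementedError

    def invalidate_draft_kv(self) -> None:
        """Discard draft-model KV-cache entries for rejected draft tokens."""
        raise NotImplementedError

    def invalidate_target_kv(self) -> None:
        """Discard target-model KV-cache entries for rejected draft tokens."""
        raise NotImplementedError
```

With an interface like this, the `step()` loop stays agnostic to the speculative-decoding strategy: when `spec_decoding_config` is absent, `self.spec_handler` remains `None` and every speculative branch is skipped.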