Skip to content

Commit

Permalink
minor
Browse files (browse the repository at this point in the history)
  • Loading branch information
LiuXiaoxuanPKU committed Nov 6, 2023
1 parent 46cd4c3 commit edeaec0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 4 additions & 4 deletions vllm/engine/llm_engine.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -124,8 +124,8 @@ def __init__(
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []

self.speculative_decoding = True
if self.speculative_decoding:
self.spec_handler = None
if spec_decoding_config is not None:
self.spec_handler = SpeculativeHandler(spec_decoding_config)

def _init_workers(self, distributed_init_method: str):
Expand Down Expand Up @@ -565,7 +565,7 @@ def step(self) -> List[RequestOutput]:
return ignored

draft_tokens = None
if self.speculative_decoding:
if self.spec_handler:
draft_tokens = self.spec_handler.propose()

# Execute the model.
Expand All @@ -578,7 +578,7 @@ def step(self) -> List[RequestOutput]:
draft_tokens=draft_tokens
)

if self.speculative_decoding:
if self.spec_handler:
self.spec_handler.accept(output)
self.spec_handler.invalidate_draft_kv()
self.spec_handler.invalidate_target_kv()
Expand Down
4 changes: 4 additions & 0 deletions vllm/sequence.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -372,6 +372,10 @@ def __init__(
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs

# only for speculative decoding
self.sd_output_tokens = None
self.sd_logprobs_list = None

def __repr__(self) -> str:
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
Expand Down

0 comments on commit edeaec0

Please sign in to comment.