Commit 7e6224a: minor
LiuXiaoxuanPKU committed Nov 9, 2023
1 parent 3c7397e commit 7e6224a
Showing 3 changed files with 16 additions and 9 deletions.
14 changes: 11 additions & 3 deletions vllm/engine/spec_dec.py
@@ -45,25 +45,33 @@ def _prepare_inputs(self,

     def set_draft_tokens(self,
                          seq_group_list: List[SequenceGroupMetadata]) -> torch.Tensor:
         logger.info(f"# of input request: {len(seq_group_list)}")
         input_tensor = self._prepare_inputs(seq_group_list)
         # recompute for now
         attention_mask=(input_tensor != PAD_TOKEN_ID)
         draft_tokens = self.draft_model.generate(input_ids=input_tensor,
                                                  attention_mask=attention_mask,
                                                  max_new_tokens=self.propose_cnt)[:, input_tensor.shape[1]:]
+        logger.info(f"Input tokens: {input_tensor}")
+        logger.info(f"Draft tokens: {draft_tokens}")
         for i, seq_group_metadata in enumerate(seq_group_list):
             seq_id = next(iter(seq_group_metadata.seq_data))
             seq = seq_group_metadata.seq_data[seq_id]
-            seq.draft_token_ids = draft_tokens[i]
+            seq.draft_token_ids = draft_tokens[i].tolist()
 
         return draft_tokens
 
     def accept(self,
                target_output: List[SamplerOutput]):
         def extract_probs(output: List[SamplerOutput]):
-            pass
-
+            logprobs = []
+            logger.info(f"# of output: {len(output)}")
+            for seq_group_output in output:
+                assert len(seq_group_output.samples) == 1
+                print(seq_group_output.prompt_logprobs)
+                sample = seq_group_output.samples[0]
+                print(sample.logprobs)
+                exit(0)
         target_probs = extract_probs(target_output)
         _prob_accept(self.draft_probs, target_probs)

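For context, accept() ends by calling _prob_accept(self.draft_probs, target_probs), whose body is not part of this diff. Below is a minimal sketch of the standard speculative-decoding acceptance rule it presumably applies (accept draft token t with probability min(1, p_target(t) / p_draft(t)) and stop at the first rejection); the helper name, signature, and tensor shapes are assumptions for illustration, not the repository's actual API.

import torch

def prob_accept_sketch(draft_probs: torch.Tensor,
                       target_probs: torch.Tensor) -> int:
    # draft_probs / target_probs: shape [propose_cnt], probability assigned to
    # each proposed draft token by the draft and target models respectively.
    # Returns how many leading draft tokens are accepted (hypothetical helper).
    for i in range(draft_probs.shape[0]):
        accept_prob = (target_probs[i] / draft_probs[i]).clamp(max=1.0)
        if torch.rand(()).item() >= accept_prob.item():
            return i  # first rejected position; resample from the residual here
    return draft_probs.shape[0]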
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/sampler.py
@@ -488,7 +488,8 @@ def _get_logprobs(
                                        sampling_params.prompt_logprobs)
             prompt_len = input_metadata.prompt_lens[i]
             prompt_tokens = input_metadata.seq_data[
-                seq_ids[0]].prompt_token_ids
+                seq_ids[0]].prompt_token_ids + input_metadata.seq_data[
+                    seq_ids[0]].draft_token_ids
             batched_logprobs_query_seq_indices.extend(
                 sample_idx + j for j in range(prompt_len - 1))
             batched_logprobs_query_token_indices.extend(
@@ -508,6 +509,7 @@ def _get_logprobs(
         batched_logprobs_query_seq_indices,
         batched_logprobs_query_token_indices
     ]].cpu()
+    print(batched_logprobs_query_result)
 
     # Batched query for logprobs of topk tokens
     if largest_num_logprobs > 0:
@@ -535,7 +537,8 @@ def _get_logprobs(
             num_logprobs = sampling_params.prompt_logprobs
             prompt_len = input_metadata.prompt_lens[i]
             prompt_tokens = input_metadata.seq_data[
-                seq_ids[0]].prompt_token_ids
+                seq_ids[0]].prompt_token_ids + input_metadata.seq_data[
+                    seq_ids[0]].draft_token_ids
             group_prompt_logprobs: PromptLogprobs = [None]
             for token_id in prompt_tokens[1:]:
                 prompt_logprobs_dict = {
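Both sampler.py hunks make the same change: the prompt_tokens list used for prompt-logprob bookkeeping is now the concatenation of the sequence's prompt_token_ids and its draft_token_ids, so the draft tokens set in spec_dec.py become visible to the logprob extraction code. A purely illustrative sketch of that concatenation (the literal values below are made up):

prompt_token_ids = [1, 15, 42, 7]    # tokens of the original prompt
draft_token_ids = [99, 100, 101]     # tokens proposed by the draft model
# After this commit, the sampler builds its per-sequence token list as:
prompt_tokens = prompt_token_ids + draft_token_ids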
4 changes: 0 additions & 4 deletions vllm/sequence.py
@@ -380,10 +380,6 @@ def __init__(
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
         self.logprobs = logprobs
-
-        # only for speculative decoding
-        self.sd_output_tokens = None
-        self.sd_logprobs_list = None
 
     def __repr__(self) -> str:
         return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
