
Commit

format
LiuXiaoxuanPKU committed Nov 12, 2023
1 parent 32267f6 commit 1aab040
Showing 6 changed files with 93 additions and 72 deletions.
4 changes: 4 additions & 0 deletions vllm/core/block_manager.py
@@ -159,6 +159,10 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
self.gpu_allocator.free(last_block)
return last_block.block_number, new_block.block_number


def delete_slot(self, seq: Sequence) -> None:
pass

def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
6 changes: 3 additions & 3 deletions vllm/engine/arg_utils.py
@@ -215,16 +215,16 @@ def create_engine_configs(
model_config.max_model_len,
self.max_paddings)

spec_decoding_config = None
spec_dec_config = None
if self.draft_model:
# assume the draft model and target model share the same tokenizer
# for now, share the same seed as the target
draft_model_config = ModelConfig(self.draft_model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format,
'auto', self.seed)
spec_decoding_config = SpecDecConfig(draft_model_config, self.propose_cnt)
return model_config, cache_config, parallel_config, scheduler_config, spec_decoding_config
spec_dec_config = SpecDecConfig(draft_model_config, self.propose_cnt)
return model_config, cache_config, parallel_config, scheduler_config, spec_dec_config


@dataclass
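
For reference, this SpecDecConfig is constructed with just a draft ModelConfig and a propose_cnt, and SpecDecWorker below reads exactly those two attributes. A minimal sketch of the shape such a config presumably has (the real definition is not part of this diff and may carry more fields):

from dataclasses import dataclass

from vllm.config import ModelConfig


@dataclass
class SpecDecConfig:
    """Assumed shape of the speculative-decoding config created above."""
    # Config of the small draft model that proposes tokens.
    draft_model_config: ModelConfig
    # Number of draft tokens proposed per decoding step.
    propose_cnt: int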
10 changes: 6 additions & 4 deletions vllm/engine/llm_engine.py
@@ -66,7 +66,7 @@ def __init__(
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
spec_decoding_config: Optional[SpecDecConfig],
spec_dec_config: Optional[SpecDecConfig],
distributed_init_method: str,
placement_group: Optional["PlacementGroup"],
log_stats: bool,
@@ -125,8 +125,8 @@ def __init__(
self.num_generation_tokens: List[Tuple[float, int]] = []

self.spec_worker = None
if spec_decoding_config:
self.spec_worker = SpecDecWorker(spec_decoding_config)
if spec_dec_config:
self.spec_worker = SpecDecWorker(spec_dec_config)

def _init_workers(self, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -577,7 +577,9 @@ def step(self) -> List[RequestOutput]:
)

if self.spec_worker:
self.spec_worker.accept(output)
# accept will read draft_token_ids and draft_token_probs from scheduler_outputs
# and set accepted_token_ids and accepted_token_probs in output
self.spec_worker.accept(output, scheduler_outputs)

return self._process_model_outputs(output, scheduler_outputs) + ignored

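To make the new hook easier to follow, here is a condensed sketch of the per-step flow with speculative decoding enabled. Only the accept() call is visible in this diff; where the draft proposals are generated (set_draft_tokens) relative to scheduling and target-model execution is an assumption, and step_sketch is an illustrative stand-in paraphrasing the surrounding engine code, not the engine's actual method.

def step_sketch(engine):
    """Rough per-step flow with a spec_worker attached (assumed ordering)."""
    seq_group_metadata_list, scheduler_outputs, ignored = engine._schedule()
    if engine.spec_worker:
        # Draft model proposes propose_cnt tokens per sequence and stores them
        # on each sequence's draft_token_ids / draft_token_probs.
        engine.spec_worker.set_draft_tokens(seq_group_metadata_list)
    # Target model runs one step; the sampler output carries the drafts through.
    output = engine._run_workers(
        "execute_model",
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
        blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
        blocks_to_copy=scheduler_outputs.blocks_to_copy,
    )
    if engine.spec_worker:
        # Rejection sampling: keep a prefix of the drafts, resample on first reject.
        engine.spec_worker.accept(output, scheduler_outputs)
    return engine._process_model_outputs(output, scheduler_outputs) + ignored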
134 changes: 78 additions & 56 deletions vllm/engine/spec_dec.py
@@ -5,6 +5,7 @@
import torch
from typing import List, Dict
from vllm.sequence import SamplerOutput, SequenceGroupOutputs, SequenceOutputs
from vllm.core.scheduler import SchedulerOutputs
from vllm.worker.worker import Worker
from vllm.logger import init_logger

@@ -13,62 +13,68 @@
# FIXME: we should get pad_token_id from tokenizer
PAD_TOKEN_ID = 0


class SpecDecWorker(Worker):
def __init__(self, config: SpecDecConfig) -> None:
self.propose_cnt = config.propose_cnt
self.draft_model_config = config.draft_model_config

# self.draft_model = get_model(self.draft_model_config)
logger.info(
"Initializing speculative decoding worker: "
f"model={self.draft_model_config.model!r}, "
f"tokenizer={self.draft_model_config.tokenizer!r}, "
f"propose_cnt={self.propose_cnt}, "
f"seed={self.draft_model_config.seed})")
self.draft_model = AutoModelForCausalLM.from_pretrained(self.draft_model_config.model).cuda()

##### values to be set
self.draft_probs = None
self.draft_kvs = None # if we use HF-style KVs
self.draft_model = AutoModelForCausalLM.from_pretrained(
self.draft_model_config.model).cuda()

##### values to be set #####
self.draft_kvs = None # if we use HF-style KVs

def _prepare_inputs(self,
def _prepare_inputs(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[torch.Tensor]:
input_ids_list = []
for seq_group_metadata in seq_group_metadata_list:
assert len(seq_group_metadata.seq_data) == 1, f"Speculative Decoding does not support beam search for now: {len(seq_group_metadata.seq_data)}"
assert len(
seq_group_metadata.seq_data) == 1, f"Speculative Decoding does not support beam search for now: {len(seq_group_metadata.seq_data)}"
seq_id = next(iter(seq_group_metadata.seq_data))
seq = seq_group_metadata.seq_data[seq_id]
input_ids_list.append(seq.get_token_ids())
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list = [_pad_left_to_max(input_ids, max_len, PAD_TOKEN_ID) for input_ids in input_ids_list]
input_ids_list = [_pad_left_to_max(
input_ids, max_len, PAD_TOKEN_ID) for input_ids in input_ids_list]
return torch.tensor(input_ids_list, dtype=torch.long, device='cuda')

# TODO: we need to align draft and target model's sampler
def _sample_method(self, logits):
temperature = 1.0
return torch.softmax(logits / temperature, dim=-1)


# propose draft tokens
# the function will run the draft model and set draft_token_ids and draft_token_probs of each seq
def set_draft_tokens(self,
seq_group_list: List[SequenceGroupMetadata]) -> None:
seq_group_list: List[SequenceGroupMetadata]) -> None:
logger.info(f"# of input requests: {len(seq_group_list)}")
input_tensor = self._prepare_inputs(seq_group_list)
draft_logits, draft_distributions, draft_tokens = [], [], []
# recompute for now
attention_mask=(input_tensor != PAD_TOKEN_ID)
attention_mask = (input_tensor != PAD_TOKEN_ID)
past_key_values = None
for i in range(self.propose_cnt):
with torch.no_grad():
outputs = self.draft_model(input_tensor,
past_key_values=past_key_values,
attention_mask=attention_mask,
use_cache=True)
past_key_values=past_key_values,
attention_mask=attention_mask,
use_cache=True)

past_key_values = outputs.past_key_values
next_token_logits = outputs.logits[:, -1, :]
distribution = self._sample_method(next_token_logits)
attention_mask = torch.cat([attention_mask, torch.ones(input_tensor.shape[0], 1, device='cuda')], dim=1)
attention_mask = torch.cat([attention_mask, torch.ones(
input_tensor.shape[0], 1, device='cuda')], dim=1)
input_tensor = torch.multinomial(distribution, num_samples=1)

draft_logits.append(next_token_logits)
draft_distributions.append(distribution)
draft_tokens.append(input_tensor)
@@ -78,61 +78,76 @@ def set_draft_tokens(self,
seq = seq_group_metadata.seq_data[seq_id]
for j in range(self.propose_cnt):
draft_token = draft_tokens[j][i].item()
seq.draft_token_probs.append({draft_token:draft_distributions[j][i]})
seq.draft_token_probs.append(
{draft_token: draft_distributions[j][i]})
seq.draft_token_ids.append(draft_token)
logger.info(f"Seq draft tokens: {seq.draft_token_ids}")
# logger.info(f"Seq draft prob: {seq.draft_token_probs}")


@staticmethod
def _extract_draft_prob_dis(sample: SequenceOutputs,
token_id: int,
index: int):
token_prob = sample.sd_draft_probs[index]
assert token_id in token_prob
return token_prob[token_id]

@staticmethod
def _extract_target_prob_dis(seq_group_output: SequenceGroupOutputs,
token_id: int):
# TODO: implement this
vocab_size = 50272
return torch.rand(1, vocab_size, device='cuda').squeeze(0)

# Accept draft tokens based on draft probabilities and target probabilities
# The implementation strictly follows rejection sampling:
# r = rand(0, 1)
# accept if r <= p/q
# reject and sample from a new distribution if r > p/q
# The function reads draft tokens/probs from scheduler_output and sets accepted token_ids
# in target_outputs
def accept(self,
target_outputs: List[SamplerOutput]):
def extract_draft_prob_dis(sample: SequenceOutputs,
token_id: int,
index: int):
token_prob = sample.sd_draft_probs[index]
assert token_id in token_prob
return token_prob[token_id]

def extract_target_prob_dis(seq_group_output: SequenceGroupOutputs,
token_id: int):
# TODO: implement this
vocab_size = 50272
return torch.rand(1, vocab_size, device='cuda').squeeze(0)

# Rejection Sampling
target_outputs: List[SamplerOutput],
scheduler_output: SchedulerOutputs):
for seq_group_output in target_outputs:
assert len(seq_group_output.samples) == 1
sample = seq_group_output.samples[0]
accept_token_ids = []

accepted_token_ids = []
for i, token_id in enumerate(sample.sd_draft_ids):
draft_prob_dis = extract_draft_prob_dis(sample, token_id, i)
target_prob_dis = extract_target_prob_dis(seq_group_output, token_id)
p, q = target_prob_dis[token_id].item(), draft_prob_dis[token_id].item()
draft_prob_dis = SpecDecWorker.extract_draft_prob_dis(
sample, token_id, i)
target_prob_dis = SpecDecWorker.extract_target_prob_dis(
seq_group_output, token_id)
p, q = target_prob_dis[token_id].item(
), draft_prob_dis[token_id].item()
r = torch.rand(1).item()
logger.info(f"p: {p}, q: {q}, r: {r}")
if r <= p/q: # accept
accept_token_ids.append(token_id)
else: # reject and resample
new_dis = torch.clamp(target_prob_dis - draft_prob_dis, min=0)
if r <= p/q: # accept
accepted_token_ids.append(token_id)
else: # reject and resample
new_dis = torch.clamp(
target_prob_dis - draft_prob_dis, min=0)
logger.info((draft_prob_dis - target_prob_dis).max())
new_dis = new_dis / new_dis.sum(dim=-1, keepdim=True)
next_token = torch.multinomial(new_dis, num_samples=1)
accept_token_ids.append(next_token.item())
accepted_token_ids.append(next_token.item())

# all proposed tokens are accepted
if len(accept_token_ids) == len(sample.sd_draft_ids):
accept_token_ids.append(sample.output_token)
logger.info(f"accept tokens: {accept_token_ids}")
if len(accepted_token_ids) == len(sample.sd_draft_ids):
accepted_token_ids.append(sample.output_token)
logger.info(f"accept tokens: {accepted_token_ids}")

self.invalidate_draft_kv()
self.invalidate_target_kv()
exit(0)
exit(0)

def invalidate_draft_kv(self):
pass

def invalidate_target_kv(self):
pass



def _pad_left_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
return [pad] * (max_len - len(x)) + x
return [pad] * (max_len - len(x)) + x
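
To summarize the acceptance rule the worker implements, here is a self-contained sketch that operates on plain tensors instead of vLLM's output objects, using p for the target-model distribution and q for the draft-model distribution. The function name and the toy example are illustrative only; the bonus-token behaviour when every draft is accepted mirrors the code above.

import torch
from typing import List


def rejection_sample(draft_tokens: List[int],
                     draft_dists: List[torch.Tensor],   # q_i over the vocab
                     target_dists: List[torch.Tensor],  # p_i over the vocab
                     bonus_token: int) -> List[int]:
    """Keep a prefix of the drafted tokens; resample once at the first rejection."""
    accepted: List[int] = []
    for token, q, p in zip(draft_tokens, draft_dists, target_dists):
        r = torch.rand(1).item()
        if r <= p[token].item() / q[token].item():  # accept with prob min(1, p/q)
            accepted.append(token)
            continue
        # First rejection: sample from the residual distribution max(p - q, 0).
        residual = torch.clamp(p - q, min=0)
        residual = residual / residual.sum()
        accepted.append(torch.multinomial(residual, num_samples=1).item())
        break
    else:
        # Every draft token was accepted: also keep the target model's own token.
        accepted.append(bonus_token)
    return accepted


# Toy example over a 4-token vocabulary with two drafted steps.
q0, p0 = torch.tensor([0.1, 0.6, 0.2, 0.1]), torch.tensor([0.1, 0.5, 0.3, 0.1])
q1, p1 = torch.tensor([0.25, 0.25, 0.25, 0.25]), torch.tensor([0.7, 0.1, 0.1, 0.1])
print(rejection_sample([1, 0], [q0, q1], [p0, p1], bonus_token=3))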
4 changes: 1 addition & 3 deletions vllm/model_executor/layers/sampler.py
@@ -598,9 +598,7 @@ def _build_sampler_output(
next_token_ids,
group_sample_logprobs):
seq_outputs.append(
SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs,
input_metadata.seq_data[seq_ids[0]].draft_token_ids,
input_metadata.seq_data[seq_ids[0]].draft_token_probs))
SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
return sampler_output
7 changes: 1 addition & 6 deletions vllm/sequence.py
@@ -380,16 +380,11 @@ def __init__(
self,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float],
sd_draft_ids: Optional[List[int]],
sd_draft_probs: Optional[Dict[int, float]]
logprobs: Dict[int, float]
) -> None:
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs

self.sd_draft_ids = sd_draft_ids
self.sd_draft_probs = sd_draft_probs

def __repr__(self) -> str:
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
