Spec Draft #3 (Open)

LiuXiaoxuanPKU wants to merge 47 commits into vllm-project:main from LiuXiaoxuanPKU:spec.
Changes from 31 commits

Commits (47)
75ae5fd  spec draft (LiuXiaoxuanPKU, Nov 6, 2023)
46cd4c3  Merge branch 'vllm-project:main' into spec (LiuXiaoxuanPKU, Nov 6, 2023)
edeaec0  minor (LiuXiaoxuanPKU, Nov 6, 2023)
95a7e13  minor (LiuXiaoxuanPKU, Nov 8, 2023)
366fbb9  draft tokens (LiuXiaoxuanPKU, Nov 8, 2023)
3c7397e  minor (LiuXiaoxuanPKU, Nov 8, 2023)
9f35009  merge (LiuXiaoxuanPKU, Nov 8, 2023)
9b64276  Merge branch 'main' of github.com:LiuXiaoxuanPKU/vllm (LiuXiaoxuanPKU, Nov 8, 2023)
1525262  Merge branch 'main' into spec (LiuXiaoxuanPKU, Nov 8, 2023)
7e6224a  minor (LiuXiaoxuanPKU, Nov 9, 2023)
93901c8  Merge branch 'spec' of github.com:LiuXiaoxuanPKU/vllm into spec (LiuXiaoxuanPKU, Nov 9, 2023)
692328a  draft logits (LiuXiaoxuanPKU, Nov 9, 2023)
8b6d647  need to change draft token probs data structure (LiuXiaoxuanPKU, Nov 9, 2023)
675e1ae  rejection sampling (LiuXiaoxuanPKU, Nov 9, 2023)
32267f6  rejection sampling (LiuXiaoxuanPKU, Nov 10, 2023)
1aab040  format (LiuXiaoxuanPKU, Nov 12, 2023)
826b54a  get draft probs (LiuXiaoxuanPKU, Nov 12, 2023)
b2ec9aa  style (LiuXiaoxuanPKU, Nov 12, 2023)
6382396  combine draft_token_ids and output_token_ids in SequenceData (LiuXiaoxuanPKU, Nov 13, 2023)
89d8ba2  invalidate kv draft (LiuXiaoxuanPKU, Nov 13, 2023)
9594d08  fix (LiuXiaoxuanPKU, Nov 13, 2023)
6b1e94c  pass in multiple tokens for generation phase, kv_mqa (LiuXiaoxuanPKU, Nov 13, 2023)
2d5c379  pass scheduler to spec worker (LiuXiaoxuanPKU, Nov 13, 2023)
025bb89  mqa (LiuXiaoxuanPKU, Nov 15, 2023)
dd23ff7  separate sampler (LiuXiaoxuanPKU, Nov 15, 2023)
f1b3987  lots of fix, multi_qa_kv runnable (LiuXiaoxuanPKU, Nov 16, 2023)
9a85990  nan in hidden states (LiuXiaoxuanPKU, Nov 16, 2023)
54bfebd  lots of style fix, early break accepting tokens (LiuXiaoxuanPKU, Nov 17, 2023)
a904ac9  fix free bug (LiuXiaoxuanPKU, Nov 18, 2023)
0cb9326  bug fix (LiuXiaoxuanPKU, Nov 18, 2023)
4e9ae6c  minor fix get target probs in prefill phase (LiuXiaoxuanPKU, Nov 18, 2023)
0ff36e7  fix mismatch between logical and physical blocks!! (LiuXiaoxuanPKU, Nov 24, 2023)
d2d67f9  add alphas (LiuXiaoxuanPKU, Nov 27, 2023)
7d94cb2  tokenizer & bug fix (LiuXiaoxuanPKU, Nov 30, 2023)
b1a5a88  pass tests (LiuXiaoxuanPKU, Nov 30, 2023)
93c7956  add flag (LiuXiaoxuanPKU, Dec 3, 2023)
141da66  remove speculative decoding for prompt run (LiuXiaoxuanPKU, Dec 5, 2023)
439c88b  remove temperature, only support all greedy for now (LiuXiaoxuanPKU, Dec 6, 2023)
40ab8d4  clean (Dec 7, 2023)
bf2ebe9  minor (Dec 7, 2023)
179e968  merge (LiuXiaoxuanPKU, Dec 7, 2023)
664a256  fix & pass tests (LiuXiaoxuanPKU, Dec 7, 2023)
7f9a373  format (LiuXiaoxuanPKU, Dec 7, 2023)
0540142  remove old files (LiuXiaoxuanPKU, Dec 7, 2023)
993f2d4  remove untouched file (LiuXiaoxuanPKU, Dec 8, 2023)
c410cbe  format (LiuXiaoxuanPKU, Dec 8, 2023)
9f2d98b  format (LiuXiaoxuanPKU, Dec 8, 2023)

72 changes: 72 additions & 0 deletions examples/sd_llm_engine_example.py
@@ -0,0 +1,72 @@
import argparse
from typing import List, Tuple

from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput


def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
"""Create a list of test prompts with their sampling parameters."""
return [
("A robot may not injure a human being",
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
]


def process_requests(engine: LLMEngine,
test_prompts: List[Tuple[str, SamplingParams]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0

while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1

request_outputs: List[RequestOutput] = engine.step()

for request_output in request_outputs:
if request_output.finished:
print(request_output)


def initialize_engine(args: argparse.Namespace) -> LLMEngine:
"""Initialize the LLMEngine from the command line arguments."""
engine_args = EngineArgs.from_cli_args(args)
return LLMEngine.from_engine_args(engine_args)


def main(args: argparse.Namespace):
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine(args)
test_prompts = create_test_prompts()
process_requests(engine, test_prompts)


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)


# from transformers import AutoTokenizer, AutoModelForCausalLM
# import torch
# model = AutoModelForCausalLM.from_pretrained(
# "facebook/opt-125m").cuda()
# tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
# prompts = ["What is your name?", "Hello"]
# input_ids = tokenizer(prompts,
# return_tensors='pt',
# padding="max_length",
# max_length=30,
# truncation=True,
# ).input_ids.cuda()
# print(input_ids.shape)
# ref_generated = model.generate(
# input_ids=input_ids, max_new_tokens=5)[:, input_ids.shape[1]:]
# print(ref_generated)
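
For reference, a sketch of how this example might be launched with the speculative-decoding flags added in vllm/engine/arg_utils.py below; the target/draft model pairing is an illustrative assumption, not something the PR prescribes.

# Hypothetical invocation (flag names come from this PR's arg_utils.py changes;
# the model choices are placeholders):
#
#   python examples/sd_llm_engine_example.py \
#       --model facebook/opt-6.7b \
#       --draft-model facebook/opt-125m \
#       --propose-cnt 5
#
# EngineArgs.from_cli_args(args) picks up --draft-model and --propose-cnt, and
# create_engine_configs() turns them into the SpecDecConfig consumed by LLMEngine.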

8 changes: 8 additions & 0 deletions vllm/block.py
@@ -45,6 +45,14 @@ def get_token_ids(self) -> List[int]:
def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]

# Delete the last num tokens in this block and blank out the freed slots.
def delete_last_tokens(self, num: int) -> None:
assert num > 0
assert num <= self.num_tokens
self.num_tokens -= num
for i in range(self.num_tokens, len(self.token_ids)):
self.token_ids[i] = _BLANK_TOKEN_ID
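
A minimal usage sketch of the new helper; it assumes the existing LogicalTokenBlock(block_number, block_size) constructor and append_tokens()/get_token_ids() helpers from vllm/block.py, which are not shown in this diff.

from vllm.block import LogicalTokenBlock

# Assumed block size of 4 for illustration.
block = LogicalTokenBlock(block_number=0, block_size=4)
block.append_tokens([11, 22, 33])     # num_tokens == 3
block.delete_last_tokens(2)           # drop the two trailing (rejected) tokens
assert block.get_token_ids() == [11]  # remaining tokens; freed slots are blanked
assert block.get_last_token_id() == 11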


class PhysicalTokenBlock:
8 changes: 8 additions & 0 deletions vllm/config.py
@@ -309,6 +309,14 @@ def _verify_args(self) -> None:
f"({self.max_num_seqs}).")


class SpecDecConfig:
def __init__(self,
draft_model_config: ModelConfig,
propose_cnt: int) -> None:
self.draft_model_config = draft_model_config
self.propose_cnt = propose_cnt
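
For context, a sketch of how a SpecDecConfig is put together; it mirrors the positional-argument order used in EngineArgs.create_engine_configs further down, and the model names are placeholders.

from vllm.config import ModelConfig, SpecDecConfig

draft_model_config = ModelConfig(
    "facebook/opt-125m",   # draft model name or path (placeholder)
    "facebook/opt-125m",   # tokenizer, shared with the target model
    "auto",                # tokenizer_mode
    False,                 # trust_remote_code
    None,                  # download_dir
    "auto",                # load_format
    "auto",                # dtype
    0,                     # seed
)
spec_dec_config = SpecDecConfig(draft_model_config, propose_cnt=5)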


_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
9 changes: 9 additions & 0 deletions vllm/core/block_manager.py
@@ -159,6 +159,15 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
self.gpu_allocator.free(last_block)
return last_block.block_number, new_block.block_number


def free_tailing_blocks(self, seq: Sequence) -> None:
block_table = self.block_tables[seq.seq_id]
start_idx = len(seq.logical_token_blocks) - 1
for i in range(start_idx, len(block_table) - 1):
block = block_table[i]
self.gpu_allocator.free(block)
self.block_tables[seq.seq_id] = block_table[:start_idx + 1]

def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
16 changes: 15 additions & 1 deletion vllm/core/scheduler.py
@@ -7,7 +7,8 @@
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
SequenceGroupMetadata, SequenceStatus,
SequenceOutputs)

logger = init_logger(__name__)

@@ -298,6 +299,19 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
def free_seq(self, seq: Sequence) -> None:
self.block_manager.free(seq)

def free_invalid_kv(self, seq: Sequence, seq_out: SequenceOutputs):
invalid_token_cnt = len(seq.data.get_draft_token_ids()) + 1 - len(seq_out.accepted_tokens)
assert invalid_token_cnt >= 0

if invalid_token_cnt == 0:
return invalid_token_cnt

# delete from logical table
seq.delete_tailing_tokens(invalid_token_cnt)
# delete from physical table
self.block_manager.free_tailing_blocks(seq)
return invalid_token_cnt
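
A worked example of the count computed here, with illustrative numbers; accepted_tokens is filled in by SpecDecWorker.accept, whose implementation is not part of this diff.

# Suppose propose_cnt = 5 draft tokens were proposed for a sequence and the
# verification step kept 3 entries in accepted_tokens:
draft_token_ids = [101, 102, 103, 104, 105]   # from seq.data.get_draft_token_ids()
accepted_tokens = [101, 102, 103]             # from SequenceOutputs.accepted_tokens

# The +1 accounts for the extra token the target model produces beyond the
# drafts (see the invalid_cnt == 0 branch in llm_engine.py below).
invalid_token_cnt = len(draft_token_ids) + 1 - len(accepted_tokens)
assert invalid_token_cnt == 3

# free_invalid_kv then trims those 3 trailing positions: the logical blocks via
# seq.delete_tailing_tokens(3), the physical blocks via
# block_manager.free_tailing_blocks(seq); _process_sequence_group_outputs drops
# the same 3 entries from output_token_ids.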

def free_finished_seq_groups(self) -> None:
self.running = [
seq_group for seq_group in self.running
30 changes: 27 additions & 3 deletions vllm/engine/arg_utils.py
@@ -4,7 +4,7 @@
from typing import Optional, Tuple

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
SchedulerConfig, SpecDecConfig)


@dataclass
@@ -24,14 +24,17 @@ class EngineArgs:
tensor_parallel_size: int = 1
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
gpu_memory_utilization: float = 0.80
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
disable_log_stats: bool = False
revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None

draft_model: Optional[str] = None
propose_cnt: Optional[int] = None

def __post_init__(self):
if self.tokenizer is None:
@@ -171,6 +174,17 @@ def add_cli_args(
choices=['awq', 'squeezellm', None],
default=None,
help='Method used to quantize the weights')

# speculative decoding setting
parser.add_argument('--draft-model',
type=str,
default=None,
help='name or path of the huggingface model to use as the draft model')
parser.add_argument('--propose-cnt',
type=int,
default=5,
help='for speculative decoding, number of tokens to propose each step')

return parser

@classmethod
@@ -200,7 +214,17 @@ def create_engine_configs(
self.max_num_seqs,
model_config.max_model_len,
self.max_paddings)
return model_config, cache_config, parallel_config, scheduler_config

spec_dec_config: SpecDecConfig = None
if self.draft_model:
# assume the draft model and target model share the same tokenizer
# for now, share the same seed as the target
draft_model_config = ModelConfig(self.draft_model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format,
'auto', self.seed)
spec_dec_config = SpecDecConfig(draft_model_config, self.propose_cnt)
return model_config, cache_config, parallel_config, scheduler_config, spec_dec_config
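
Since create_engine_configs now returns a fifth element, call sites have to unpack spec_dec_config as well. A minimal sketch, with placeholder model names; spec_dec_config comes back as None unless draft_model is set.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-6.7b",        # placeholder target model
                         draft_model="facebook/opt-125m",  # placeholder draft model
                         propose_cnt=5)
(model_config, cache_config, parallel_config,
 scheduler_config, spec_dec_config) = engine_args.create_engine_configs()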


@dataclass
37 changes: 34 additions & 3 deletions vllm/engine/llm_engine.py
@@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
SchedulerConfig, SpecDecConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
@@ -17,6 +17,7 @@
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
from vllm.engine.spec_dec import SpecDecWorker

if ray:
from ray.air.util.torch_dist import init_torch_dist_process_group
@@ -65,6 +66,7 @@ def __init__(
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
spec_dec_config: Optional[SpecDecConfig],
distributed_init_method: str,
placement_group: Optional["PlacementGroup"],
log_stats: bool,
@@ -121,6 +123,10 @@ def __init__(
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []

self.spec_dec_worker: SpecDecWorker = None
if spec_dec_config:
self.spec_dec_worker = SpecDecWorker(spec_dec_config, self.scheduler)
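
The SpecDecWorker itself (vllm/engine/spec_dec.py) is not shown in this portion of the diff; from its call sites in this file, its interface looks roughly like the sketch below, with placeholder bodies rather than the PR's actual implementation.

# Rough interface inferred from the call sites in llm_engine.py; not the real
# implementation in vllm/engine/spec_dec.py.
class SpecDecWorker:

    def __init__(self, spec_dec_config, scheduler):
        self.draft_model_config = spec_dec_config.draft_model_config
        self.propose_cnt = spec_dec_config.propose_cnt
        self.scheduler = scheduler

    def set_draft_tokens(self, seq_group_metadata_list, scheduler_outputs):
        """Run the draft model and attach propose_cnt draft tokens (and their
        probabilities) to the running sequences before the target model runs."""

    def accept(self, output, scheduler_outputs):
        """Compare the target model's output against the draft tokens and fill
        in the accepted tokens on each SequenceOutputs."""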

def _init_workers(self, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -393,8 +399,25 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
if last_child_sample.accepted_tokens:
# Speculative decoding enabled: invalidate KV cache for non-accepted tokens
invalid_cnt = self.scheduler.free_invalid_kv(parent, last_child_sample)

if invalid_cnt == 0:
# if all the tokens are accepted
# add the last accepted token to the output_token_ids
# TODO: how to handle the kv cache of the last token?
# TODO: we need to get the logprob of the last token
last_token_id = last_child_sample.accepted_tokens[-1]
parent.append_token_id(last_token_id, {last_token_id: -1})
else:
# update the output_token_ids with only accepted tokens
parent.data.output_token_ids = parent.data.output_token_ids[:-invalid_cnt]
# always clear draft tokens
parent.data.draft_token_probs = []
else:
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
@@ -558,6 +581,10 @@ def step(self) -> List[RequestOutput]:
if scheduler_outputs.is_empty():
return ignored

if self.spec_dec_worker:
self.spec_dec_worker.set_draft_tokens(seq_group_metadata_list,
scheduler_outputs)

# Execute the model.
output = self._run_workers(
"execute_model",
@@ -566,6 +593,10 @@
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)

if self.spec_dec_worker:
# accept will set accepted_token_ids and accepted_token_probs in output
Collaborator comment: Same here, the description of the method/function should sit beside its own definition rather than at the place where it is called.
self.spec_dec_worker.accept(output, scheduler_outputs)

return self._process_model_outputs(output, scheduler_outputs) + ignored
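
Putting the pieces of step() together: draft tokens are attached before the target model runs, and accept() decides how many survive. Commit 439c88b restricts this draft to greedy sampling, so acceptance can be thought of as the exact-match check sketched below (the general rule from the speculative decoding literature instead accepts a draft token x with probability min(1, p_target(x) / p_draft(x)) and resamples on rejection). This is an illustrative sketch, not the PR's accept() implementation.

# Illustrative greedy acceptance, not the PR's actual accept() logic.
# draft_tokens: propose_cnt tokens from the draft model.
# target_tokens: the target model's greedy choice at each of those positions,
#                plus one extra "bonus" position (hence len == propose_cnt + 1).
def greedy_accept(draft_tokens, target_tokens):
    accepted = []
    for draft, target in zip(draft_tokens, target_tokens):
        if draft == target:
            accepted.append(draft)        # draft token matches: keep it
        else:
            accepted.append(target)       # first mismatch: take the target's token
            return accepted               # and stop (early break)
    accepted.append(target_tokens[-1])    # all drafts matched: keep the bonus token
    return accepted

# Example: 5 drafts, first 2 match.
assert greedy_accept([1, 2, 3, 4, 5], [1, 2, 9, 7, 8, 6]) == [1, 2, 9]
# With the scheduler bookkeeping above, invalid_token_cnt = 5 + 1 - 3 = 3.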
