
Spec Draft #3

Open: wants to merge 47 commits into main

Conversation

LiuXiaoxuanPKU
Owner

No description provided.

Collaborator

@pian13131 pian13131 left a comment

Please take a look at the comments!

                  token_id: int,
                  index: int):
    token_prob = sample.sd_draft_probs[index]
    assert token_id in token_prob
Collaborator

I am not sure how vllm itself handles this kind of assertion. Typically, when an assertion fails, it terminates the program, and you don't want to terminate the whole vllm server because of a single request. Please double-check how vllm handles this kind of case, and be careful when deciding whether to assert or not.
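
A minimal sketch of the kind of alternative being suggested here, assuming a hypothetical helper and error type (DraftTokenMismatchError is invented for illustration; the way vllm actually propagates per-request failures may differ):

# Hypothetical sketch: fail only the offending request instead of the whole server.
class DraftTokenMismatchError(ValueError):
    """Raised when a draft token is missing from the recorded draft probabilities."""


def get_draft_prob(sample, token_id: int, index: int) -> float:
    token_prob = sample.sd_draft_probs[index]
    if token_id not in token_prob:
        # The engine loop could catch this and abort the single sequence
        # instead of crashing the process on an AssertionError.
        raise DraftTokenMismatchError(
            f"draft token {token_id} not found at position {index}")
    return token_prob[token_id]

Whether to raise, log, or silently reject the draft token is a policy decision for the engine; the point is only that a bare assert takes the whole server down.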

@@ -566,6 +593,10 @@ def step(self) -> List[RequestOutput]:
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)

if self.spec_dec_worker:
    # accept will set accepted_token_ids and accepted_token_probs in output
Collaborator

Same here: the description of a method/function should sit next to its own definition rather than at the place where you call it.
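
A minimal sketch of what that placement looks like, assuming the method in question is called accept (the signature below is a guess for illustration only):

class SpecDecWorker:
    def accept(self, output):
        """Verify the draft tokens for this step.

        Sets accepted_token_ids and accepted_token_probs on output.
        """
        ...

# The call site in the engine loop then needs no explanatory comment:
#     if self.spec_dec_worker:
#         self.spec_dec_worker.accept(output)

Only the location of the description matters; the wording of the docstring is a placeholder.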

PAD_TOKEN_ID = 0


class SpecDecWorker(Worker):
Collaborator

It would be better to add some comments to this quite complex and important class: why we need this class and how to use it. You could add something like "Users should call function xxx to do xxx when xxx, then call function xxx to xxx..." (describing only the public methods). A high-level usage explanation would help others (or even yourself in the future) understand your code.
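
A rough sketch of the kind of class-level docstring being asked for; the workflow and method names below are placeholders, not the PR's actual public API:

class SpecDecWorker:  # in the PR this subclasses Worker; the base class is omitted here
    """Worker that drives speculative decoding with a small draft model.

    Hypothetical usage sketch:
      1. Call propose() before the target-model step to generate draft tokens
         for each scheduled sequence.
      2. Run the target model on the proposed tokens as usual.
      3. Call accept(output) afterwards to verify the drafts and populate
         accepted_token_ids / accepted_token_probs on the output.

    Only the public methods should be described; internal helpers are
    implementation details.
    """

The exact wording is unimportant; what helps is stating who calls the worker, when, and in what order.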

    draft_tokens = [list(tp.keys())[0] for tp in self.draft_token_probs]
    return draft_tokens

def get_verified_token_ids(self) -> List[int]:
Collaborator

Perhaps get_accepted_token_ids would be better?

Comment on lines +160 to +161
def _delete_logical_block(self, block: LogicalTokenBlock) -> None:
    self.logical_token_blocks.remove(block)
Collaborator

Any reason to create this method instead of just using self.logical_token_blocks.remove(block) inline?

while n > 0:
    assert len(self.logical_token_blocks) > 0
    last_block = self.logical_token_blocks[-1]
    if last_block.num_tokens < n:
Collaborator

if last_block.num_tokens == n, I think you also need to call self._delete_logical_block(last_block)?
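
A minimal, self-contained sketch of the loop with the == n case handled, using a stand-in block type (the real LogicalTokenBlock API may differ; only the <= comparison is the point):

from dataclasses import dataclass
from typing import List

@dataclass
class _Block:
    num_tokens: int  # stand-in for LogicalTokenBlock


def truncate_last_tokens(blocks: List[_Block], n: int) -> None:
    """Remove the last n tokens, deleting blocks that become empty."""
    while n > 0:
        assert len(blocks) > 0
        last_block = blocks[-1]
        if last_block.num_tokens <= n:
            # Covers the == n case: the block is fully consumed, so it is
            # removed rather than left in the list with zero tokens.
            n -= last_block.num_tokens
            blocks.remove(last_block)
        else:
            last_block.num_tokens -= n
            n = 0

With < alone, a block whose token count exactly equals n would be left behind as an empty block.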

Comment on lines 272 to 273
verify_tokens = seq_data.get_verified_token_ids()
verify_len = len(verify_tokens)
Collaborator

I think you should use "accept" instead of "verify" everywhere, since using different words for the same concept introduces ambiguity.

sliding_window_blocks = (self.sliding_window //
                         self.block_size)
block_table = block_table[-sliding_window_blocks:]
if draft_len > 0:
Collaborator

I think you'd better add a comment here, something like:
"Speculative decoding enabled case: <explain the main difference from the normal case, e.g. that it needs to pass in the multiple tokens accepted in the previous run>"

@truenorth8

@LiuXiaoxuanPKU Are you also considering porting these changes to TGI (Hugging Face's inference project) at some point?
The current MR looks great, btw.

@LiuXiaoxuanPKU
Owner Author

@LiuXiaoxuanPKU Are you also considering porting these changes to TGI (Hugging Face's inference project) at some point? The current MR looks great, btw.

Thanks for the attention! The PR is still in progress, and the current version focuses mainly on correctness rather than performance, so there is still plenty of room for performance improvement. We are aiming to make the PR runnable this week; the vllm team will start reviewing late this week and next week.

We will not work on porting to TGI for now. Making it work in vllm is already complicated...


github-actions bot commented Nov 7, 2024

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale label Nov 7, 2024