Spec Draft #3
Please take a look at the comments!
vllm/engine/spec_dec.py
Outdated
        token_id: int,
        index: int):
    token_prob = sample.sd_draft_probs[index]
    assert token_id in token_prob
I am not sure how vllm itself handles such an assertion. Typically, a failed assertion terminates the program, and you don't want to terminate the whole vllm server because of a single request. Please double-check how vllm handles this kind of case, and be careful when deciding whether to assert.
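One way to address this concern is to raise a request-scoped exception and let the caller abort only the offending request. The sketch below is an illustration under assumptions, not vllm code: `DraftTokenError`, `get_draft_prob`, `try_get_draft_prob`, and the list-of-dicts layout of `draft_probs` are hypothetical names chosen for the example.

```python
from typing import Dict, List, Optional


class DraftTokenError(ValueError):
    """Hypothetical request-scoped error; not part of vllm."""


def get_draft_prob(draft_probs: List[Dict[int, float]], token_id: int,
                   index: int) -> float:
    """Return the draft probability of token_id at position index.

    Raises DraftTokenError instead of asserting, so a caller can abort
    the single affected request and keep serving the others.
    """
    if not 0 <= index < len(draft_probs):
        raise DraftTokenError(f"draft position {index} is out of range")
    token_prob = draft_probs[index]
    if token_id not in token_prob:
        raise DraftTokenError(
            f"token {token_id} has no draft probability at position {index}")
    return token_prob[token_id]


def try_get_draft_prob(draft_probs: List[Dict[int, float]], token_id: int,
                       index: int) -> Optional[float]:
    """Example caller: report failure for one request instead of crashing."""
    try:
        return get_draft_prob(draft_probs, token_id, index)
    except DraftTokenError:
        return None  # the caller would abort only this request
```

An explicit check also keeps working when Python runs with `-O`, which strips `assert` statements.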
@@ -566,6 +593,10 @@ def step(self) -> List[RequestOutput]:
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )

        if self.spec_dec_worker:
            # accept will set accepted_token_ids and accepted_token_probs in output
Same here: the description of a method or function should sit next to its own definition rather than at the place where you call it.
PAD_TOKEN_ID = 0


class SpecDecWorker(Worker):
It would be better to add some comments to this quite complex and important class.
Explain why we need this class and how to use it. You could add something like "User should call function xxx to xxxx when xxx, then call function xxx to xxx..." (describe only the public methods). A high-level usage explanation would help others (or even yourself in the future) understand your code.
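A class-level docstring of the kind requested here might be laid out as in the sketch below. It is illustrative only and is not the `SpecDecWorker` from this PR: the class name `SpecDecWorkerSketch`, the `propose` method, the constructor argument, and the exact-match acceptance rule are all assumptions made for the example.

```python
from typing import List


class SpecDecWorkerSketch:
    """Illustrative skeleton only; not the SpecDecWorker from this PR.

    Why: speculative decoding lets a small draft model propose several
    tokens that the target model then checks in one step.

    How to use (public methods only):
      1. Call propose() once per step to get draft token ids.
      2. Run the target model on those tokens (outside this class).
      3. Call accept() with the target's token ids to keep the matching
         prefix of the draft.
    """

    def __init__(self, draft_token_ids: List[int]) -> None:
        # Fixed draft tokens stand in for a real draft model.
        self._draft_token_ids = list(draft_token_ids)

    def propose(self) -> List[int]:
        """Return the draft token ids for this step."""
        return list(self._draft_token_ids)

    def accept(self, target_token_ids: List[int]) -> List[int]:
        """Return the longest draft prefix that matches the target ids.

        Exact-match acceptance is a simplification chosen for this sketch.
        """
        accepted: List[int] = []
        for draft, target in zip(self._draft_token_ids, target_token_ids):
            if draft != target:
                break
            accepted.append(draft)
        return accepted
```

The point of the layout is that a reader learns the call order from the class docstring alone, while each method docstring stays next to its own definition.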
        draft_tokens = [list(tp.keys())[0] for tp in self.draft_token_probs]
        return draft_tokens

    def get_verified_token_ids(self) -> List[int]:
Perhaps get_accepted_token_ids would be better?
    def _delete_logical_block(self, block: LogicalTokenBlock) -> None:
        self.logical_token_blocks.remove(block)
Is there any reason to create this method instead of just calling self.logical_token_blocks.remove(block) in place?
        while n > 0:
            assert len(self.logical_token_blocks) > 0
            last_block = self.logical_token_blocks[-1]
            if last_block.num_tokens < n:
If last_block.num_tokens == n, I think you also need to call self._delete_logical_block(last_block); otherwise an empty block is left behind.
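The boundary case can be illustrated with a small standalone sketch. It is not the PR's code: `Block` is a hypothetical stand-in for `LogicalTokenBlock`, and `delete_last_tokens` is a made-up free function that uses `<=` so that a block emptied exactly is dropped as well.

```python
from typing import List


class Block:
    """Hypothetical stand-in for a logical token block."""

    def __init__(self, token_ids: List[int]) -> None:
        self.token_ids = list(token_ids)

    @property
    def num_tokens(self) -> int:
        return len(self.token_ids)


def delete_last_tokens(blocks: List[Block], n: int) -> None:
    """Remove the last n tokens, dropping any block that becomes empty."""
    total = sum(b.num_tokens for b in blocks)
    if n > total:
        raise ValueError(f"cannot remove {n} tokens, only {total} present")
    while n > 0:
        last_block = blocks[-1]
        if last_block.num_tokens <= n:
            # The whole block is consumed, including the == case.
            n -= last_block.num_tokens
            blocks.pop()
        else:
            # Only part of the block is removed; the block stays.
            del last_block.token_ids[-n:]
            n = 0
```

With a strict `<` comparison, removing exactly `num_tokens` tokens would fall into the partial-removal branch and leave a zero-length block at the end of the list.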
vllm/worker/worker.py
Outdated
verify_tokens = seq_data.get_verified_token_ids()
verify_len = len(verify_tokens)
I think you should use accept instead of verify everywhere, since using different words for the same concept introduces ambiguity.
vllm/worker/worker.py
Outdated
    sliding_window_blocks = (self.sliding_window //
                             self.block_size)
    block_table = block_table[-sliding_window_blocks:]
if draft_len > 0:
I think you'd better add a comment here, something like:
"Speculative decoding enabled case: <explain the main difference from the normal case, e.g. that it needs to pass in multiple tokens that were accepted in the previous run>"
@LiuXiaoxuanPKU Are you considering porting these changes to TGI (Hugging Face's inference project) at any point?
Thanks for the attention! The PR is still in progress, and the current version focuses mainly on correctness rather than performance. There is still a lot of room for performance improvement. We are aiming to make the PR runnable this week. The vllm team will start reviewing late this week and next week. We will not work on porting to TGI for now. Making it work in vllm is already complicated...
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
No description provided.