Update QueueLLM #97

Open
wants to merge 2 commits into base: MLPerf_4.1
vllm/entrypoints/queue_llm.py (78 changes: 46 additions & 32 deletions)
@@ -18,7 +18,6 @@
 from vllm.utils import Counter, deprecate_kwargs
 import multiprocessing as mp
 import queue
-import threading
 
 logger = init_logger(__name__)
 
@@ -167,28 +166,36 @@ def start(self, use_tqdm: bool = True):
         self._pull_tokens_from_input_queue(block=True)
         self._run_engine(use_tqdm=use_tqdm)
 
+    def _pull_all_tokens_from_input_queue(self, block: bool = True):
+        while self._pull_tokens_from_input_queue(block=False):
+            pass
+        if block:
+            self._pull_tokens_from_input_queue(block)
+
     def _pull_tokens_from_input_queue(self, block: bool = True):
         try:
             input = self.input_queue.get() if block else self.input_queue.get_nowait()
             if input is None:
                 self.finish = True
-            for sample_id, token_ids in input:
-                inputs = self._convert_v1_inputs(
-                    prompts=None,
-                    prompt_token_ids=token_ids,
-                    multi_modal_data=None,
-                )
-
-                self._validate_and_add_requests(
-                    inputs=inputs,
-                    params=self.sampling_params,
-                    request_id=sample_id,
-                )
+            else:
+                for sample_id, token_ids in input:
+                    inputs = self._convert_v1_inputs(
+                        prompts=None,
+                        prompt_token_ids=token_ids,
+                        multi_modal_data=None,
+                    )
+
+                    self._validate_and_add_requests(
+                        inputs=inputs,
+                        params=self.sampling_params,
+                        request_id=sample_id,
+                    )
         except queue.Empty:
-            pass
+            return False
         except Exception as e:
             logger.error(f"Unexpected exception during pulling tokens: {e}")
-
+            return False
+        return True
 
     def _convert_v1_inputs(
         self,
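The new `_pull_all_tokens_from_input_queue` helper first drains whatever is already queued with non-blocking gets, then performs at most one blocking get when `block=True`. The input protocol implied by `_pull_tokens_from_input_queue` is: each queue item is an iterable of `(sample_id, token_ids)` pairs, and a `None` item marks the end of input. A minimal producer sketch under those assumptions (the queue wiring and function name here are illustrative, not part of this diff):

```python
import multiprocessing as mp

# Illustrative wiring only; how QueueLLM receives these queues is not shown in this diff.
input_queue = mp.Queue()        # items: list of (sample_id, token_ids), or None to finish
first_token_queue = mp.Queue()  # items: (request_id, token_ids) at first-token time
result_queue = mp.Queue()       # items: (request_id, new token_ids slice) or (request_id, None)

def send_requests(tokenized_prompts):
    # Each batch is a list of (sample_id, token_ids) pairs, matching the
    # unpacking done in _pull_tokens_from_input_queue.
    batch = [(str(i), token_ids) for i, token_ids in enumerate(tokenized_prompts)]
    input_queue.put(batch)
    # A None item makes _pull_tokens_from_input_queue set self.finish = True.
    input_queue.put(None)
```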
@@ -291,24 +298,31 @@ def _run_engine(
             )
         # Run the engine.
         total_toks = 0
-        first_token_sent = set()
-        while not self.finish and self.llm_engine.has_unfinished_requests():
-            self._pull_tokens_from_input_queue(block=False)
+        request_stats = {}
+        while not self.finish or self.llm_engine.has_unfinished_requests():
+            block = not self.llm_engine.has_unfinished_requests() and not self.finish
+            self._pull_all_tokens_from_input_queue(block=block)
             step_outputs = self.llm_engine.step()
             for output in step_outputs:
-                if len(output.outputs) > 0 and (output.request_id not in first_token_sent):
-                    self.first_token_queue.put(output)
-                    first_token_sent.add(output.request_id)
-                if output.finished:
-                    self.result_queue.put_nowait(output)
-                    first_token_sent.remove(output.request_id)
-                    if use_tqdm:
-                        if isinstance(output, RequestOutput):
-                            # Calculate tokens only for RequestOutput
-                            total_toks += sum(
-                                len(stp.token_ids) for stp in output.outputs)
-                            spd = total_toks / pbar.format_dict["elapsed"]
-                            pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
-                        pbar.update(1)
+                output_len = len(output.outputs[0].token_ids)
+                if output_len > 0 and (output.request_id not in request_stats):
+                    self.first_token_queue.put((output.request_id, output.outputs[0].token_ids))
+                    request_stats[output.request_id] = output_len
+                if request_stats[output.request_id] < output_len:
+                    self.result_queue.put_nowait((output.request_id, output.outputs[0].token_ids[request_stats[output.request_id]: output_len]))
+                if output.finished:
+                    # signal end of stream with None
+                    self.result_queue.put_nowait((output.request_id, None))
+                    del request_stats[output.request_id]
+                    if use_tqdm:
+                        if isinstance(output, RequestOutput):
+                            # Calculate tokens only for RequestOutput
+                            total_toks += sum(
+                                len(stp.token_ids) for stp in output.outputs)
+                            spd = total_toks / pbar.format_dict["elapsed"]
+                            pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
+                        pbar.update(1)
+                else:
+                    request_stats[output.request_id] = output_len
         if use_tqdm:
             pbar.close()
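On the output side, `request_stats` now remembers how many tokens have already been emitted per request: the first non-empty output goes to `first_token_queue`, each later step puts only the newly generated slice of `output.outputs[0].token_ids` on `result_queue`, and `(request_id, None)` signals end of stream. A consumer for that protocol could look roughly like the sketch below (the function name and the way the number of requests is known are assumptions for illustration):

```python
def collect_streamed_outputs(result_queue, num_requests):
    """Reassemble per-request token streams produced by _run_engine."""
    outputs = {}   # request_id -> accumulated token ids
    finished = 0
    while finished < num_requests:
        request_id, token_ids = result_queue.get()
        if token_ids is None:
            # (request_id, None) is the end-of-stream marker.
            finished += 1
            continue
        outputs.setdefault(request_id, []).extend(token_ids)
    return outputs
```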