def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.

    Every input iterator is drained by its own task, and items are yielded
    in the order they are produced as tuples ``(i, item)``, where ``i`` is
    the positional index of the iterator that produced ``item``.

    This function handles the case where some iterators finish before
    others, including iterators that yield nothing at all: the merged
    iterator ends once every input iterator is exhausted.

    The producer tasks are created as soon as this function is called, so
    it must be called while an event loop is running.

    Args:
        *iterators: The asynchronous iterators to merge.

    Returns:
        An asynchronous iterator over ``(index, item)`` tuples.

    Raises:
        Exception: An exception raised by any input iterator is re-raised
            from the merged iterator; the remaining producer tasks are then
            cancelled.
    """
    queue: asyncio.Queue = asyncio.Queue()

    # Unique end-of-stream marker. Regular items are always wrapped in an
    # (i, item) tuple and failures are Exception instances, so this object
    # cannot be confused with either.
    done = object()

    async def producer(i: int, iterator: AsyncIterator[T]) -> None:
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        finally:
            # Always announce completion *through the queue* (even when the
            # task is cancelled). Setting a flag on the side would not wake
            # a consumer already blocked in queue.get(), which made the
            # merged iterator hang when the last producer finished without
            # enqueuing anything. The queue is unbounded, so put_nowait()
            # cannot raise QueueFull.
            queue.put_nowait(done)

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer() -> AsyncIterator[Tuple[int, T]]:
        remaining = len(_tasks)
        try:
            while remaining:
                item = await queue.get()
                if item is done:
                    remaining -= 1
                elif isinstance(item, Exception):
                    raise item
                else:
                    yield item
        finally:
            # On error or early close, do not leave producers running in
            # the background; on normal completion the tasks are already
            # done and cancel() is a no-op.
            for task in _tasks:
                task.cancel()
            await asyncio.gather(*_tasks, return_exceptions=True)

    return consumer()