From e7bebaa5d96967c1c8fdb41799b714a1b6a76fb8 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Fri, 27 Dec 2024 20:45:08 -0500 Subject: [PATCH] [V1] [4/N] API Server: ZMQ/MP Utilities (#11541) Signed-off-by: Bowen Wang --- docs/requirements-docs.txt | 1 + tests/v1/engine/test_engine_core.py | 13 +-- tests/v1/engine/test_engine_core_client.py | 10 +- vllm/entrypoints/openai/api_server.py | 11 +- vllm/executor/multiproc_worker_utils.py | 22 +--- vllm/utils.py | 90 ++++++++++++++++- vllm/v1/engine/async_llm.py | 6 +- vllm/v1/engine/core.py | 111 ++++----------------- vllm/v1/engine/core_client.py | 92 ++++++++--------- vllm/v1/engine/llm_engine.py | 6 +- vllm/v1/executor/multiproc_executor.py | 11 +- vllm/v1/utils.py | 89 +++++++++++------ 12 files changed, 247 insertions(+), 215 deletions(-) diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 4859c8ac08bea..25a700033cc9e 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -19,3 +19,4 @@ openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entr fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args requests +zmq diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index c529cd21f384b..954cec734b956 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -7,7 +7,6 @@ from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform -from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core import EngineCore @@ -43,13 +42,11 @@ def test_engine_core(monkeypatch): m.setenv("VLLM_USE_V1", "1") """Setup the EngineCore.""" engine_args = EngineArgs(model=MODEL_NAME) - vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + vllm_config = engine_args.create_engine_config() executor_class = AsyncLLM._get_executor_cls(vllm_config) engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - usage_context=UsageContext.UNKNOWN_CONTEXT) + executor_class=executor_class) """Test basic request lifecycle.""" # First request. @@ -151,13 +148,11 @@ def test_engine_core_advanced_sampling(monkeypatch): m.setenv("VLLM_USE_V1", "1") """Setup the EngineCore.""" engine_args = EngineArgs(model=MODEL_NAME) - vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + vllm_config = engine_args.create_engine_config() executor_class = AsyncLLM._get_executor_cls(vllm_config) engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - usage_context=UsageContext.UNKNOWN_CONTEXT) + executor_class=executor_class) """Test basic request lifecycle.""" # First request. 
request: EngineCoreRequest = make_request() diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 2f1cbec607a91..729975e4ea8c4 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -86,11 +86,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): UsageContext.UNKNOWN_CONTEXT) executor_class = AsyncLLM._get_executor_cls(vllm_config) client = EngineCoreClient.make_client( - vllm_config, - executor_class, - UsageContext.UNKNOWN_CONTEXT, multiprocess_mode=multiprocessing_mode, asyncio_mode=False, + vllm_config=vllm_config, + executor_class=executor_class, ) MAX_TOKENS = 20 @@ -158,11 +157,10 @@ async def test_engine_core_client_asyncio(monkeypatch): usage_context=UsageContext.UNKNOWN_CONTEXT) executor_class = AsyncLLM._get_executor_cls(vllm_config) client = EngineCoreClient.make_client( - vllm_config, - executor_class, - UsageContext.UNKNOWN_CONTEXT, multiprocess_mode=True, asyncio_mode=True, + vllm_config=vllm_config, + executor_class=executor_class, ) MAX_TOKENS = 20 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2e45b474237f9..094cc15a317e9 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -68,7 +68,7 @@ from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, - is_valid_ipv6_address, set_ulimit) + is_valid_ipv6_address, kill_process_tree, set_ulimit) from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -737,6 +737,15 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) + # The child processes will send SIGQUIT to this process when + # any error happens. This process then cleans up the whole tree. + # TODO(rob): move this into AsyncLLM.__init__ once we remove + # the context manager below. + def sigquit_handler(signum, frame): + kill_process_tree(os.getpid()) + + signal.signal(signal.SIGQUIT, sigquit_handler) + async with build_async_engine_client(args) as engine_client: app = build_app(args) diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index c4d90f0856f86..bc32826529eef 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -1,5 +1,4 @@ import asyncio -import multiprocessing import os import sys import threading @@ -13,10 +12,9 @@ import torch -import vllm.envs as envs from vllm.logger import init_logger from vllm.triton_utils.importing import HAS_TRITON -from vllm.utils import cuda_is_initialized +from vllm.utils import _check_multiproc_method, get_mp_context if HAS_TRITON: from vllm.triton_utils import maybe_set_triton_cache_manager @@ -274,24 +272,6 @@ def write_with_prefix(s: str): file.write = write_with_prefix # type: ignore[method-assign] -def _check_multiproc_method(): - if (cuda_is_initialized() - and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"): - logger.warning("CUDA was previously initialized. We must use " - "the `spawn` multiprocessing start method. Setting " - "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. 
" "See https://docs.vllm.ai/en/latest/getting_started/" - "debugging.html#python-multiprocessing " - "for more information.") - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - - -def get_mp_context(): - _check_multiproc_method() - mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD - return multiprocessing.get_context(mp_method) - - def set_multiprocessing_worker_envs(parallel_config): """ Set up environment variables that should be used when there are workers in a multiprocessing environment. This should be called by the parent diff --git a/vllm/utils.py b/vllm/utils.py index 5eb4e8c4180c4..2b46c1fef0d09 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -10,6 +10,7 @@ import importlib.util import inspect import ipaddress +import multiprocessing import os import re import resource @@ -20,6 +21,7 @@ import tempfile import threading import time +import traceback import uuid import warnings import weakref @@ -29,8 +31,9 @@ from dataclasses import dataclass, field from functools import lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, - Dict, Generator, Generic, List, Literal, NamedTuple, - Optional, Tuple, Type, TypeVar, Union, overload) + Dict, Generator, Generic, Iterator, List, Literal, + NamedTuple, Optional, Tuple, Type, TypeVar, Union, + overload) from uuid import uuid4 import numpy as np @@ -39,6 +42,8 @@ import torch import torch.types import yaml +import zmq +import zmq.asyncio from packaging.version import Version from torch.library import Library from typing_extensions import ParamSpec, TypeIs, assert_never @@ -1844,7 +1849,7 @@ def memory_profiling( result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa -# Adapted from: https://github.com/sgl-project/sglang/blob/f46f394f4d4dbe4aae85403dec006199b34d2840/python/sglang/srt/utils.py#L630 # noqa: E501 +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 def set_ulimit(target_soft_limit=65535): resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) @@ -1859,3 +1864,82 @@ def set_ulimit(target_soft_limit=65535): "with error %s. This can cause fd limit errors like" "`OSError: [Errno 24] Too many open files`. 
Consider " "increasing with ulimit -n", current_soft, e) + + +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501 +def get_exception_traceback(): + etype, value, tb = sys.exc_info() + err_str = "".join(traceback.format_exception(etype, value, tb)) + return err_str + + +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501 +def make_zmq_socket( + ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] + path: str, + type: Any, +) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] + """Make a ZMQ socket with the proper bind/connect semantics.""" + + mem = psutil.virtual_memory() + socket = ctx.socket(type) + + # Calculate buffer size based on system memory + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 + # For systems with substantial memory (>32GB total, >16GB available): + # - Set a large 0.5GB buffer to improve throughput + # For systems with less memory: + # - Use system default (-1) to avoid excessive memory consumption + if total_mem > 32 and available_mem > 16: + buf_size = int(0.5 * 1024**3) # 0.5GB in bytes + else: + buf_size = -1 # Use system default buffer size + + if type == zmq.constants.PULL: + socket.setsockopt(zmq.constants.RCVHWM, 0) + socket.setsockopt(zmq.constants.RCVBUF, buf_size) + socket.connect(path) + elif type == zmq.constants.PUSH: + socket.setsockopt(zmq.constants.SNDHWM, 0) + socket.setsockopt(zmq.constants.SNDBUF, buf_size) + socket.bind(path) + else: + raise ValueError(f"Unknown Socket Type: {type}") + + return socket + + +@contextlib.contextmanager +def zmq_socket_ctx( + path: str, + type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + """Context manager for a ZMQ socket""" + + ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined] + try: + yield make_zmq_socket(ctx, path, type) + + except KeyboardInterrupt: + logger.debug("Got Keyboard Interrupt.") + + finally: + ctx.destroy(linger=0) + + +def _check_multiproc_method(): + if (cuda_is_initialized() + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"): + logger.warning("CUDA was previously initialized. We must use " + "the `spawn` multiprocessing start method. Setting " + "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " + "See https://docs.vllm.ai/en/latest/getting_started/" + "debugging.html#python-multiprocessing " + "for more information.") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +def get_mp_context(): + _check_multiproc_method() + mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD + return multiprocessing.get_context(mp_method) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ba2b8377759d6..da3da6dad6436 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -75,11 +75,11 @@ def __init__( # EngineCore (starts the engine in background process). 
self.engine_core = EngineCoreClient.make_client( - vllm_config=vllm_config, - executor_class=executor_class, - usage_context=usage_context, multiprocess_mode=True, asyncio_mode=True, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=self.log_stats, ) self.output_handler: Optional[asyncio.Task] = None diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0aef61fc7f680..5840541d774ba 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -3,20 +3,19 @@ import signal import threading import time -from dataclasses import dataclass -from multiprocessing.process import BaseProcess +from multiprocessing.connection import Connection from typing import List, Tuple, Type +import psutil import zmq import zmq.asyncio from msgspec import msgpack from vllm.config import CacheConfig, VllmConfig -from vllm.executor.multiproc_worker_utils import get_mp_context from vllm.logger import init_logger from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.usage.usage_lib import UsageContext +from vllm.utils import get_exception_traceback, zmq_socket_ctx from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, @@ -25,14 +24,13 @@ from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import PickleEncoder -from vllm.v1.utils import make_zmq_socket from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -LOGGING_TIME_S = POLLING_TIMEOUT_S +LOGGING_TIME_S = 5 class EngineCore: @@ -42,9 +40,10 @@ def __init__( self, vllm_config: VllmConfig, executor_class: Type[Executor], - usage_context: UsageContext, + log_stats: bool = False, ): assert vllm_config.model_config.runner_type != "pooling" + self.log_stats = log_stats logger.info("Initializing an LLM engine (v%s) with config: %s", VLLM_VERSION, vllm_config) @@ -134,29 +133,19 @@ def profile(self, is_start: bool = True): self.model_executor.profile(is_start) -@dataclass -class EngineCoreProcHandle: - proc: BaseProcess - ready_path: str - input_path: str - output_path: str - - class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" - READY_STR = "READY" - def __init__( self, - vllm_config: VllmConfig, - executor_class: Type[Executor], - usage_context: UsageContext, input_path: str, output_path: str, - ready_path: str, + ready_pipe: Connection, + vllm_config: VllmConfig, + executor_class: Type[Executor], + log_stats: bool = False, ): - super().__init__(vllm_config, executor_class, usage_context) + super().__init__(vllm_config, executor_class, log_stats) # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, @@ -173,68 +162,7 @@ def __init__( daemon=True).start() # Send Readiness signal to EngineClient. - with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket: - ready_socket.send_string(EngineCoreProc.READY_STR) - - @staticmethod - def wait_for_startup( - proc: BaseProcess, - ready_path: str, - ) -> None: - """Wait until the EngineCore is ready.""" - - try: - sync_ctx = zmq.Context() # type: ignore[attr-defined] - socket = sync_ctx.socket(zmq.constants.PULL) - socket.connect(ready_path) - - # Wait for EngineCore to send EngineCoreProc.READY_STR. 
- while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - logger.debug("Waiting for EngineCoreProc to startup.") - - if not proc.is_alive(): - raise RuntimeError("EngineCoreProc failed to start.") - - message = socket.recv_string() - assert message == EngineCoreProc.READY_STR - - except BaseException as e: - logger.exception(e) - raise e - - finally: - sync_ctx.destroy(linger=0) - - @staticmethod - def make_engine_core_process( - vllm_config: VllmConfig, - executor_class: Type[Executor], - usage_context: UsageContext, - input_path: str, - output_path: str, - ready_path: str, - ) -> EngineCoreProcHandle: - context = get_mp_context() - - process_kwargs = { - "input_path": input_path, - "output_path": output_path, - "ready_path": ready_path, - "vllm_config": vllm_config, - "executor_class": executor_class, - "usage_context": usage_context, - } - # Run EngineCore busy loop in background process. - proc = context.Process(target=EngineCoreProc.run_engine_core, - kwargs=process_kwargs) - proc.start() - - # Wait for startup - EngineCoreProc.wait_for_startup(proc, ready_path) - return EngineCoreProcHandle(proc=proc, - ready_path=ready_path, - input_path=input_path, - output_path=output_path) + ready_pipe.send({"status": "READY"}) @staticmethod def run_engine_core(*args, **kwargs): @@ -258,6 +186,7 @@ def signal_handler(signum, frame): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + parent_process = psutil.Process().parent() engine_core = None try: engine_core = EngineCoreProc(*args, **kwargs) @@ -266,9 +195,10 @@ def signal_handler(signum, frame): except SystemExit: logger.debug("EngineCore interrupted.") - except BaseException as e: - logger.exception(e) - raise e + except Exception: + traceback = get_exception_traceback() + logger.error("EngineCore hit an exception: %s", traceback) + parent_process.send_signal(signal.SIGQUIT) finally: if engine_core is not None: @@ -309,6 +239,9 @@ def run_busy_loop(self): def _log_stats(self): """Log basic stats every LOGGING_TIME_S""" + if not self.log_stats: + return + now = time.time() if now - self._last_logging_time > LOGGING_TIME_S: @@ -339,7 +272,7 @@ def process_input_socket(self, input_path: str): decoder_add_req = PickleEncoder() decoder_abort_req = PickleEncoder() - with make_zmq_socket(input_path, zmq.constants.PULL) as socket: + with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket: while True: # (RequestType, RequestData) type_frame, data_frame = socket.recv_multipart(copy=False) @@ -367,7 +300,7 @@ def process_output_socket(self, output_path: str): # Reuse send buffer. 
buffer = bytearray() - with make_zmq_socket(output_path, zmq.constants.PUSH) as socket: + with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: while True: engine_core_outputs = self.output_queue.get() outputs = EngineCoreOutputs(outputs=engine_core_outputs) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index d56fcbdb1e7c4..beb5d57c20c83 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,19 +1,19 @@ -import os -import weakref -from typing import List, Optional +from typing import List, Optional, Type import msgspec import zmq import zmq.asyncio +from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import get_open_zmq_ipc_path, kill_process_tree +from vllm.utils import get_open_zmq_ipc_path from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, EngineCoreRequestType, EngineCoreRequestUnion) -from vllm.v1.engine.core import (EngineCore, EngineCoreProc, - EngineCoreProcHandle) +from vllm.v1.engine.core import EngineCore, EngineCoreProc +from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.utils import BackgroundProcHandle logger = init_logger(__name__) @@ -31,10 +31,11 @@ class EngineCoreClient: @staticmethod def make_client( - *args, multiprocess_mode: bool, asyncio_mode: bool, - **kwargs, + vllm_config: VllmConfig, + executor_class: Type[Executor], + log_stats: bool = False, ) -> "EngineCoreClient": # TODO: support this for debugging purposes. @@ -44,12 +45,12 @@ def make_client( "is not currently supported.") if multiprocess_mode and asyncio_mode: - return AsyncMPClient(*args, **kwargs) + return AsyncMPClient(vllm_config, executor_class, log_stats) if multiprocess_mode and not asyncio_mode: - return SyncMPClient(*args, **kwargs) + return SyncMPClient(vllm_config, executor_class, log_stats) - return InprocClient(*args, **kwargs) + return InprocClient(vllm_config, executor_class, log_stats) def shutdown(self): pass @@ -128,9 +129,10 @@ class MPClient(EngineCoreClient): def __init__( self, - *args, asyncio_mode: bool, - **kwargs, + vllm_config: VllmConfig, + executor_class: Type[Executor], + log_stats: bool = False, ): # Serialization setup. self.encoder = PickleEncoder() @@ -143,7 +145,6 @@ def __init__( self.ctx = zmq.Context() # type: ignore[attr-defined] # Path for IPC. - ready_path = get_open_zmq_ipc_path() output_path = get_open_zmq_ipc_path() input_path = get_open_zmq_ipc_path() @@ -156,47 +157,40 @@ def __init__( self.input_socket.bind(input_path) # Start EngineCore in background process. - self.proc_handle: Optional[EngineCoreProcHandle] - self.proc_handle = EngineCoreProc.make_engine_core_process( - *args, - input_path= - input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords - output_path=output_path, # type: ignore[misc] - ready_path=ready_path, # type: ignore[misc] - **kwargs, - ) - self._finalizer = weakref.finalize(self, self.shutdown) + self.proc_handle: Optional[BackgroundProcHandle] + self.proc_handle = BackgroundProcHandle( + input_path=input_path, + output_path=output_path, + process_name="EngineCore", + target_fn=EngineCoreProc.run_engine_core, + process_kwargs={ + "vllm_config": vllm_config, + "executor_class": executor_class, + "log_stats": log_stats, + }) def shutdown(self): # Shut down the zmq context. self.ctx.destroy(linger=0) if hasattr(self, "proc_handle") and self.proc_handle: - # Shutdown the process if needed. 
- if self.proc_handle.proc.is_alive(): - self.proc_handle.proc.terminate() - self.proc_handle.proc.join(5) - - if self.proc_handle.proc.is_alive(): - kill_process_tree(self.proc_handle.proc.pid) - - # Remove zmq ipc socket files - ipc_sockets = [ - self.proc_handle.ready_path, self.proc_handle.output_path, - self.proc_handle.input_path - ] - for ipc_socket in ipc_sockets: - socket_file = ipc_socket.replace("ipc://", "") - if os and os.path.exists(socket_file): - os.remove(socket_file) + self.proc_handle.shutdown() self.proc_handle = None class SyncMPClient(MPClient): """Synchronous client for multi-proc EngineCore.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, asyncio_mode=False, **kwargs) + def __init__(self, + vllm_config: VllmConfig, + executor_class: Type[Executor], + log_stats: bool = False): + super().__init__( + asyncio_mode=False, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + ) def get_output(self) -> List[EngineCoreOutput]: @@ -225,8 +219,16 @@ def profile(self, is_start: bool = True) -> None: class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, asyncio_mode=True, **kwargs) + def __init__(self, + vllm_config: VllmConfig, + executor_class: Type[Executor], + log_stats: bool = False): + super().__init__( + asyncio_mode=True, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + ) async def get_output_async(self) -> List[EngineCoreOutput]: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index b58f62778ffe9..fc323184abc8f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -72,11 +72,11 @@ def __init__( # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) self.engine_core = EngineCoreClient.make_client( - vllm_config, - executor_class, - usage_context, multiprocess_mode=multiprocess_mode, asyncio_mode=False, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=False, ) @classmethod diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 128101aa6956d..ed64e7741390d 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -17,13 +17,12 @@ from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) from vllm.executor.multiproc_worker_utils import ( - _add_prefix, get_mp_context, set_multiprocessing_worker_envs) + _add_prefix, set_multiprocessing_worker_envs) from vllm.logger import init_logger -from vllm.utils import (get_distributed_init_method, get_open_port, - get_open_zmq_ipc_path) +from vllm.utils import (get_distributed_init_method, get_mp_context, + get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx) from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.utils import make_zmq_socket from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -250,7 +249,7 @@ def __init__( worker_response_mq_handle = self.worker_response_mq.export_handle() # Send Readiness signal to EngineCore process. 
- with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket: + with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket: payload = pickle.dumps(worker_response_mq_handle, protocol=pickle.HIGHEST_PROTOCOL) ready_socket.send_string(WorkerProc.READY_STR) @@ -352,7 +351,7 @@ def wait_for_startup( ready_path: str, ) -> Optional[Handle]: """Wait until the Worker is ready.""" - with make_zmq_socket(ready_path, zmq.constants.PULL) as socket: + with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket: # Wait for Worker to send READY. while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index e802c6439b740..19e0dd17237c9 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,11 +1,11 @@ +import os +import weakref from collections.abc import Sequence -from contextlib import contextmanager -from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union, - overload) - -import zmq +from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar, + Union, overload) from vllm.logger import init_logger +from vllm.utils import get_mp_context, kill_process_tree logger = init_logger(__name__) @@ -77,27 +77,58 @@ def __len__(self): return len(self._x) -@contextmanager -def make_zmq_socket( - path: str, - type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined] - """Context manager for a ZMQ socket""" - - ctx = zmq.Context() # type: ignore[attr-defined] - try: - socket = ctx.socket(type) - - if type == zmq.constants.PULL: - socket.connect(path) - elif type == zmq.constants.PUSH: - socket.bind(path) - else: - raise ValueError(f"Unknown Socket Type: {type}") - - yield socket - - except KeyboardInterrupt: - logger.debug("Worker had Keyboard Interrupt.") - - finally: - ctx.destroy(linger=0) +class BackgroundProcHandle: + """ + Utility class to handle creation, readiness, and shutdown + of background processes used by the AsyncLLM and LLMEngine. + """ + + def __init__( + self, + input_path: str, + output_path: str, + process_name: str, + target_fn: Callable, + process_kwargs: Dict[Any, Any], + ): + self._finalizer = weakref.finalize(self, self.shutdown) + + context = get_mp_context() + reader, writer = context.Pipe(duplex=False) + + assert ("ready_pipe" not in process_kwargs + and "input_path" not in process_kwargs + and "output_path" not in process_kwargs) + process_kwargs["ready_pipe"] = writer + process_kwargs["input_path"] = input_path + process_kwargs["output_path"] = output_path + self.input_path = input_path + self.output_path = output_path + + # Run the target_fn busy loop in a background process. + self.proc = context.Process(target=target_fn, kwargs=process_kwargs) + self.proc.start() + + # Wait for startup. + if reader.recv()["status"] != "READY": + raise RuntimeError(f"{process_name} initialization failed. " + "See root cause above.") + + def __del__(self): + self.shutdown() + + def shutdown(self): + # Shut down the process if needed. + if hasattr(self, "proc") and self.proc.is_alive(): + self.proc.terminate() + self.proc.join(5) + + if self.proc.is_alive(): + kill_process_tree(self.proc.pid) + + # Remove zmq ipc socket files + ipc_sockets = [self.output_path, self.input_path] + for ipc_socket in ipc_sockets: + socket_file = ipc_socket.replace("ipc://", "") + if os and os.path.exists(socket_file): + os.remove(socket_file)
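
---

Usage note (illustrative, not part of the patch): the sketch below shows, under stated assumptions, how the utilities introduced here compose. A `BackgroundProcHandle`-style parent spawns a child with a one-shot ready pipe; the child binds a ZMQ PUSH socket (matching the bind/connect split in `make_zmq_socket`, where PUSH binds and PULL connects) and signals readiness; the parent connects a PULL socket once READY arrives. The `worker_loop` function, the ipc path, and the demo payloads are hypothetical names for illustration; only the handshake pattern is taken from the patch.

# Minimal sketch of the ready-pipe + ZMQ PUSH/PULL startup handshake.
# worker_loop and the demo messages are hypothetical, not vLLM APIs.
import multiprocessing
import tempfile

import zmq


def worker_loop(ready_pipe, output_path: str):
    """Child: bind a PUSH socket, report READY, then stream outputs."""
    ctx = zmq.Context()
    try:
        # Same semantics as make_zmq_socket: PUSH binds, PULL connects.
        socket = ctx.socket(zmq.PUSH)
        socket.bind(output_path)
        # Readiness travels over the mp.Pipe, as in EngineCoreProc.__init__.
        ready_pipe.send({"status": "READY"})
        for i in range(3):
            socket.send_string(f"output-{i}")
    finally:
        # Default linger flushes queued messages before the child exits.
        ctx.destroy()


if __name__ == "__main__":
    output_path = f"ipc://{tempfile.gettempdir()}/handshake_demo.ipc"

    # Use spawn, matching what get_mp_context() selects once CUDA is live.
    mp_ctx = multiprocessing.get_context("spawn")
    reader, writer = mp_ctx.Pipe(duplex=False)
    proc = mp_ctx.Process(target=worker_loop, args=(writer, output_path))
    proc.start()

    # Block until the child reports READY, as BackgroundProcHandle does.
    if reader.recv()["status"] != "READY":
        raise RuntimeError("worker initialization failed")

    zmq_ctx = zmq.Context()
    pull = zmq_ctx.socket(zmq.PULL)
    pull.connect(output_path)
    for _ in range(3):
        print(pull.recv_string())

    zmq_ctx.destroy(linger=0)
    proc.join()

The division of labor this patch settles on: the pipe carries only the startup handshake, the ZMQ sockets carry the high-volume request/output streams, and failures propagate the other way, with the child sending SIGQUIT so the parent's handler can call kill_process_tree.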