[Core] V1: Use multiprocessing by default (#11074)
Signed-off-by: Russell Bryant <[email protected]>
russellb authored Dec 14, 2024
1 parent 0d8451c commit 4863e5f
Showing 10 changed files with 299 additions and 17 deletions.
195 changes: 195 additions & 0 deletions docs/source/design/multiprocessing.md
@@ -0,0 +1,195 @@
# Python Multiprocessing

## Debugging

Please see the [Debugging
Tips](https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing)
page for information on known issues and how to solve them.

## Introduction

*Note that source code references are to the state of the code at the time of writing, December 2024.*

The use of Python multiprocessing in vLLM is complicated by:

- The use of vLLM as a library and the inability to control the code using vLLM
- Varying levels of incompatibilities between multiprocessing methods and vLLM
dependencies

This document describes how vLLM deals with these challenges.

## Multiprocessing Methods

[Python multiprocessing methods](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods) include:

- `spawn` - spawn a new Python process. This will be the default as of Python
3.14.

- `fork` - Use `os.fork()` to fork the Python interpreter. This is the default
in Python versions prior to 3.14.

- `forkserver` - Spawn a server process that will fork a new process on request.

### Tradeoffs

`fork` is the fastest method, but is incompatible with dependencies that use
threads.

`spawn` is more compatible with dependencies, but can be problematic when vLLM
is used as a library. If the consuming code does not use a `__main__` guard (`if
__name__ == "__main__":`), the code will be inadvertently re-executed when vLLM
spawns a new process. This can lead to infinite recursion, among other problems.

`forkserver` will spawn a new server process that will fork new processes on
demand. This unfortunately has the same problem as `spawn` when vLLM is used as
a library. The server process is created as a spawned new process, which will
re-execute code not protected by a `__main__` guard.

For both `spawn` and `forkserver`, the process must not depend on inheriting any
global state as would be the case with `fork`.
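
As an illustration, here is a minimal sketch (the file and function names are not from vLLM) of why module-level code breaks under `spawn` and `forkserver` without a `__main__` guard:

```python
# my_script.py -- illustrative example, not part of vLLM
import multiprocessing


def double(x: int) -> int:
    return x * 2


# Everything at module level is re-executed when a `spawn` or `forkserver`
# child re-imports this file. Without the guard below, the Process(...) call
# would run again in the child and recurse indefinitely.
if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=double, args=(21,))
    proc.start()
    proc.join()
```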

## Compatibility with Dependencies

Multiple vLLM dependencies indicate either a preference or requirement for using
`spawn`:

- <https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing>
- <https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors>
- <https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders>

It is perhaps more accurate to say that there are known problems with using
`fork` after initializing these dependencies.

## Current State (v0)

The environment variable `VLLM_WORKER_MULTIPROC_METHOD` can be used to control which method is used by vLLM. The current default is `fork`.

- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/envs.py#L339-L342>
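
As a usage sketch (only the environment variable name comes from the code above; the model name and the rest are illustrative), the method can be overridden before vLLM creates its workers:

```python
# Illustrative: force the worker start method before vLLM creates its workers.
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

if __name__ == "__main__":
    # A __main__ guard is required once `spawn` is in use (see above).
    from vllm import LLM

    llm = LLM(model="facebook/opt-125m")  # model name is illustrative
    print(llm.generate(["Hello, my name is"]))
```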

When we know we own the process because the `vllm` command was used, we use
`spawn` because it's the most widely compatible.

- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/scripts.py#L123-L140>

The `multiproc_xpu_executor` forces the use of `spawn`.

- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/executor/multiproc_xpu_executor.py#L14-L18>

There are other miscellaneous places hard-coding the use of `spawn`:

- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>

Related PRs:

- <https://github.com/vllm-project/vllm/pull/8823>

## Prior State in v1

An environment variable, `VLLM_ENABLE_V1_MULTIPROCESSING`, controlled whether
multiprocessing was used in the v1 engine core. It defaulted to off.

- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/envs.py#L452-L454>

When it was enabled, the v1 `LLMEngine` would create a new process to run the
engine core.

- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/v1/engine/llm_engine.py#L93-L95>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/v1/engine/llm_engine.py#L70-L77>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/v1/engine/core_client.py#L44-L45>

It was off by default for the reasons mentioned above: compatibility with
dependencies and with code that uses vLLM as a library.
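
For illustration, the prior opt-in could be expressed as below (a sketch; `VLLM_USE_V1` comes from `vllm/envs.py` in this commit, and treating both flags as simple "0"/"1" switches is an assumption):

```python
# Illustrative: opt in to the prior (off-by-default) behavior before vLLM
# creates the engine. Both flags are read via vllm/envs.py (see the diff below).
import os

os.environ["VLLM_USE_V1"] = "1"  # assumption: selects the v1 code path
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "1"

import vllm.envs as envs

if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
    print("v1 LLMEngine will run the engine core in a separate process")
```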

### Changes Made in v1

There is no easy solution with Python's `multiprocessing` that will work
everywhere. As a first step, we can get v1 into a state where it makes a "best
effort" choice of multiprocessing method to maximize compatibility (a condensed
sketch follows the list below):

- Default to `fork`.
- Use `spawn` when we know we control the main process (`vllm` was executed).
- If we detect `cuda` was previously initialized, force `spawn` and emit a
warning. We know `fork` will break, so this is the best we can do.
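
A condensed sketch of that selection logic (simplified; the real changes live in `get_mp_context()` and `_check_multiproc_method()` later in this commit, and the CUDA check here is a stand-in for vLLM's own `cuda_is_initialized()` helper):

```python
import multiprocessing
import os


def _cuda_is_initialized() -> bool:
    # Simplified stand-in for vLLM's cuda_is_initialized() helper.
    try:
        import torch
        return torch.cuda.is_initialized()
    except ImportError:
        return False


def choose_mp_context():
    # Default to `fork`; the `vllm` CLI entrypoint overrides this to `spawn`.
    method = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", "fork")
    if _cuda_is_initialized() and method != "spawn":
        # `fork` after CUDA initialization is known to break, so force `spawn`
        # and emit the warning shown below.
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
        method = "spawn"
    return multiprocessing.get_context(method)
```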

The case that is known to still break in this scenario is code using vLLM as a
library that initializes `cuda` before calling vLLM. The warning we emit should
instruct users to either add a `__main__` guard or to disable multiprocessing.

If that known-failure case occurs, the user will see two messages that explain
what is happening. First, a log message from vLLM:

```
WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously
initialized. We must use the `spawn` multiprocessing start method. Setting
VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See
https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing
for more information.
```

Second, Python itself will raise an exception with a nice explanation:

```
RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.
        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:
            if __name__ == '__main__':
                freeze_support()
                ...
        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
        To fix this issue, refer to the "Safe importing of main module"
        section in https://docs.python.org/3/library/multiprocessing.html
```

## Alternatives Considered

### Detect if a `__main__` guard is present

It has been suggested that we could behave better if we could detect whether
code using vLLM as a library has a `__main__` guard in place. This [post on
stackoverflow](https://stackoverflow.com/questions/77220442/multiprocessing-pool-in-a-python-class-without-name-main-guard)
was from a library author facing the same question.

It is possible to detect whether we are in the original `__main__` process or
in a subsequently spawned process. However, it does not appear to be
straightforward to detect whether a `__main__` guard is present in the code.
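
For reference, here is a minimal sketch of the part that is detectable (the helper name is illustrative):

```python
import multiprocessing


def in_spawned_child() -> bool:
    # multiprocessing.parent_process() returns None in the original process
    # started by the user, and a handle to the parent in any child it created.
    # This tells us which process we are in, but not whether the user's module
    # has a `__main__` guard.
    return multiprocessing.parent_process() is not None
```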

This option has been discarded as impractical.

### Use `forkserver`

At first it appears that `forkserver` is a nice solution to the problem.
However, the way it works presents the same challenges that `spawn` does when
vLLM is used as a library.

### Force `spawn` all the time

One way to clean this up would be to force the use of `spawn` all the time and
document that a `__main__` guard is required when using vLLM as a library. This
would unfortunately break existing code and make vLLM harder to use, violating
the goal of keeping the `LLM` class as easy as possible to use.

Instead of pushing this on our users, we will retain the complexity to do our
best to make things work.

## Future Work

We may want to consider a different worker management approach in the future
that works around these challenges.

1. We could implement something `forkserver`-like, but launch the process
   manager ourselves by running our own subprocess with a custom entrypoint
   for worker management (i.e., launch a `vllm-manager` process).

2. We can explore other libraries that may better suit our needs. Examples to
   consider:

   - <https://github.com/joblib/loky>
56 changes: 56 additions & 0 deletions docs/source/getting_started/debugging.rst
@@ -136,6 +136,62 @@ If the test script hangs or crashes, usually it means the hardware/drivers are b

Adjust ``--nproc-per-node``, ``--nnodes``, and ``--node-rank`` according to your setup, being sure to execute different commands (with different ``--node-rank``) on different nodes.

Python multiprocessing
----------------------

`RuntimeError` Exception
^^^^^^^^^^^^^^^^^^^^^^^^

If you have seen a warning in your logs like this:

.. code-block:: console

    WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously
    initialized. We must use the `spawn` multiprocessing start method. Setting
    VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See
    https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing
    for more information.

or an error from Python that looks like this:

.. code-block:: console

    RuntimeError:
            An attempt has been made to start a new process before the
            current process has finished its bootstrapping phase.
            This probably means that you are not using fork to start your
            child processes and you have forgotten to use the proper idiom
            in the main module:
                if __name__ == '__main__':
                    freeze_support()
                    ...
            The "freeze_support()" line can be omitted if the program
            is not going to be frozen to produce an executable.
            To fix this issue, refer to the "Safe importing of main module"
            section in https://docs.python.org/3/library/multiprocessing.html

then you must update your Python code to guard usage of ``vllm`` behind a ``if
__name__ == '__main__':`` block. For example, instead of this:

.. code-block:: python

    import vllm
    llm = vllm.LLM(...)

try this instead:

.. code-block:: python

    if __name__ == '__main__':
        import vllm
        llm = vllm.LLM(...)

Known Issues
----------------------------------------
- In ``v0.5.2``, ``v0.5.3``, and ``v0.5.3.post1``, there is a bug caused by `zmq <https://github.com/zeromq/pyzmq/issues/2000>`_ , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of ``vllm`` to include the `fix <https://github.com/vllm-project/vllm/pull/6759>`_.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -173,6 +173,7 @@ Documentation
design/input_processing/model_inputs_index
design/kernel/paged_attention
design/multimodal/multimodal_index
design/multiprocessing

.. For Developers: contributing to the vLLM project
4 changes: 4 additions & 0 deletions vllm/entrypoints/llm.py
@@ -232,6 +232,10 @@ def __init__(

        self.request_counter = Counter()

    def __del__(self):
        if self.llm_engine and hasattr(self.llm_engine, "shutdown"):
            self.llm_engine.shutdown()

    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        if envs.VLLM_USE_V1:
4 changes: 2 additions & 2 deletions vllm/envs.py
@@ -69,7 +69,7 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1


@@ -460,7 +460,7 @@ def get_default_config_root():

    # If set, enable multiprocessing in LLM for the V1 code path.
    "VLLM_ENABLE_V1_MULTIPROCESSING":
    lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
    lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
    "VLLM_LOG_BATCHSIZE_INTERVAL":
    lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
}
20 changes: 14 additions & 6 deletions vllm/executor/multiproc_worker_utils.py
@@ -274,7 +274,20 @@ def write_with_prefix(s: str):
    file.write = write_with_prefix  # type: ignore[method-assign]


def _check_multiproc_method():
    if (cuda_is_initialized()
            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
        logger.warning("CUDA was previously initialized. We must use "
                       "the `spawn` multiprocessing start method. Setting "
                       "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
                       "See https://docs.vllm.ai/en/latest/getting_started/"
                       "debugging.html#python-multiprocessing "
                       "for more information.")
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def get_mp_context():
    _check_multiproc_method()
    mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
    return multiprocessing.get_context(mp_method)

@@ -284,12 +297,7 @@ def set_multiprocessing_worker_envs(parallel_config):
    in a multiprocessing environment. This should be called by the parent
    process before worker processes are created"""

    if (cuda_is_initialized()
            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
        logger.warning("CUDA was previously initialized. We must use "
                       "the `spawn` multiprocessing start method. Setting "
                       "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    _check_multiproc_method()

    # Configure thread parallelism if OMP_NUM_THREADS isn't set
    #
8 changes: 2 additions & 6 deletions vllm/v1/engine/core.py
@@ -1,4 +1,3 @@
import multiprocessing
import pickle
import queue
import signal
@@ -13,6 +12,7 @@
from msgspec import msgpack

from vllm.config import CacheConfig, VllmConfig
from vllm.executor.multiproc_worker_utils import get_mp_context
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
@@ -210,11 +210,7 @@ def make_engine_core_process(
        output_path: str,
        ready_path: str,
    ) -> EngineCoreProcHandle:
        # The current process might have CUDA context,
        # so we need to spawn a new process.
        # NOTE(rob): this is a problem for using EngineCoreProc w/
        # LLM, since we need a if __name__ == "__main__" guard.
        context = multiprocessing.get_context("spawn")
        context = get_mp_context()

        process_kwargs = {
            "input_path": input_path,
11 changes: 9 additions & 2 deletions vllm/v1/engine/core_client.py
@@ -159,10 +159,16 @@ def __init__(
        atexit.register(self.shutdown)

    def shutdown(self):
        # During final garbage collection in process shutdown, atexit may be
        # None.
        if atexit:
            # in case shutdown gets called via __del__ first
            atexit.unregister(self.shutdown)

        # Shut down the zmq context.
        self.ctx.destroy(linger=0)

        if hasattr(self, "proc_handle"):
        if hasattr(self, "proc_handle") and self.proc_handle:
            # Shutdown the process if needed.
            if self.proc_handle.proc.is_alive():
                self.proc_handle.proc.terminate()
@@ -178,8 +184,9 @@ def shutdown(self):
            ]
            for ipc_socket in ipc_sockets:
                socket_file = ipc_socket.replace("ipc://", "")
                if os.path.exists(socket_file):
                if os and os.path.exists(socket_file):
                    os.remove(socket_file)
            self.proc_handle = None

    def __del__(self):
        self.shutdown()
7 changes: 7 additions & 0 deletions vllm/v1/engine/llm_engine.py
@@ -196,3 +196,10 @@ def get_tokenizer_group(
f"found type: {type(tokenizer_group)}")

return tokenizer_group

def __del__(self):
self.shutdown()

def shutdown(self):
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()