Commit
init
rickyyx committed Nov 15, 2024
1 parent ac49b59 commit b8889a2
Showing 10 changed files with 87 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -171,7 +171,7 @@ steps:
   - vllm/
   - tests/v1
   commands:
-    - pytest -v -s v1
+    - VLLM_USE_V1=1 pytest -v -s v1

 - label: Examples Test # 15min
   working_dir: "/vllm-workspace/examples"
3 changes: 3 additions & 0 deletions tests/v1/engine/test_async_llm.py
@@ -32,6 +32,9 @@ async def generate(engine: AsyncLLM, request_id: str,

 @pytest.mark.asyncio
 async def test_load(monkeypatch):
+    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
+    # so that in the future when we switch, we don't have to change all the
+    # tests.
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

16 changes: 16 additions & 0 deletions tests/v1/engine/test_engine_args.py
@@ -0,0 +1,16 @@
+import pytest
+
+from vllm import envs
+
+assert envs.VLLM_USE_V1, "VLLM_USE_V1 must be set to run this test"
+
+from vllm.engine.arg_utils import EngineArgs
+
+
+def test_v1_defaults():
+    engine_args = EngineArgs(model="facebook/opt-125m")
+
+    # Assert V1 defaults
+    assert engine_args.enable_prefix_caching
+    assert engine_args.max_num_seqs == 1024
+    assert engine_args.max_num_batched_tokens is None

GitHub Actions ruff (3.12) flagged two issues on this new file: F401 at line 1 (`pytest` imported but unused) and E402 at line 7 (module level import not at top of file).
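
Note that the `assert envs.VLLM_USE_V1` check sits between the imports deliberately: as shown in the `vllm/engine/arg_utils.py` diff below, that module rebinds `EngineArgs` to `V1EngineArgs` at import time when `VLLM_USE_V1` is set, so the flag must be in the environment before `arg_utils` is imported (hence the E402 above). A minimal standalone sketch of the same ordering (not part of this commit), assuming vLLM has not been imported earlier in the process and that `vllm.envs` reads the variable lazily:

# Minimal sketch (not part of this commit): set the flag before importing
# vllm.engine.arg_utils so that EngineArgs resolves to V1EngineArgs.
import os

os.environ["VLLM_USE_V1"] = "1"

from vllm.engine.arg_utils import EngineArgs  # rebound to V1EngineArgs

args = EngineArgs(model="facebook/opt-125m")
print(args.enable_prefix_caching)    # True under the V1 defaults
print(args.max_num_seqs)             # 1024
print(args.max_num_batched_tokens)   # None until create_engine_config() fills it in
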
4 changes: 3 additions & 1 deletion tests/v1/engine/test_engine_core.py
@@ -43,7 +43,9 @@ def test_engine_core(monkeypatch):
         m.setenv("VLLM_USE_V1", "1")
         """Setup the EngineCore."""
         engine_args = EngineArgs(model=MODEL_NAME)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT
+        )
         executor_class = AsyncLLM._get_executor_cls(vllm_config)

         engine_core = EngineCore(vllm_config=vllm_config,
4 changes: 3 additions & 1 deletion tests/v1/engine/test_engine_core_client.py
@@ -153,7 +153,9 @@ async def test_engine_core_client_asyncio(monkeypatch):
         m.setenv("VLLM_USE_V1", "1")

         engine_args = EngineArgs(model=MODEL_NAME)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT
+        )
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
         client = EngineCoreClient.make_client(
             vllm_config,
59 changes: 58 additions & 1 deletion vllm/engine/arg_utils.py
@@ -20,6 +20,7 @@
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.platforms import current_platform
 from vllm.transformers_utils.utils import check_gguf_file
+from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, StoreBoolean

 if TYPE_CHECKING:
@@ -985,7 +986,9 @@ def create_load_config(self) -> LoadConfig:
             ignore_patterns=self.ignore_patterns,
         )

-    def create_engine_config(self) -> VllmConfig:
+    def create_engine_config(
+        self, usage_context: Optional[UsageContext] = None
+    ) -> VllmConfig:
         # gguf file needs a specific model loader and doesn't use hf_repo
         if check_gguf_file(self.model):
             self.quantization = self.load_format = "gguf"
@@ -1210,6 +1213,60 @@ def create_engine_config(self) -> VllmConfig:
         )


+@dataclass
+class V1EngineArgs(EngineArgs):
+    """Arguments for vLLM engine v1."""
+
+    # V1's default values that differ from the default values in EngineArgs.
+    # This allows to switch between V1 and V0's default behaviour transparently.
+    enable_prefix_caching: bool = True
+    max_num_seqs: int = 1024
+    max_num_batched_tokens: Optional[int] = None
+
+    @staticmethod
+    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+        parser = EngineArgs.add_cli_args(parser)
+        return parser
+
+    def create_engine_config(
+        self, usage_context: Optional[UsageContext] = None
+    ) -> VllmConfig:
+        assert (
+            usage_context is not None
+        ), "usage_context must be provided for V1EngineArgs"
+
+        if self.max_num_batched_tokens is None:
+            if usage_context == UsageContext.LLM_CLASS:
+                logger.warning(
+                    "Setting max_num_batched_tokens to 8192 "
+                    "for LLM_CLASS usage context."
+                )
+                self.max_num_batched_tokens = 8192
+            elif usage_context == UsageContext.OPENAI_API_SERVER:
+                logger.warning(
+                    "Setting max_num_batched_tokens to 2048 "
+                    "for OPENAI_API_SERVER usage context."
+                )
+                self.max_num_batched_tokens = 2048
+
+        engine_config = super().create_engine_config(usage_context)
+
+        # TODO (ywang96): Enable APC by default when VLM supports it.
+        if engine_config.model_config.is_multimodal_model:
+            logger.warning(
+                "Prefix caching is currently not supported for multimodal "
+                "models and has been disabled."
+            )
+            engine_config.cache_config.enable_prefix_caching = False
+        return engine_config
+
+
+if envs.VLLM_USE_V1:
+    # Overwrite EngineArgs to use V1EngineArgs
+    # This has to be done before `AsyncEngineArgs` is imported.
+    EngineArgs = V1EngineArgs
+
+
 @dataclass
 class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous vLLM engine."""
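
The usage-context-dependent overrides that were previously patched onto the scheduler config inside `EngineCore.__init__` (removed from `vllm/v1/engine/core.py` below) now live in `V1EngineArgs.create_engine_config`. A rough sketch of the resulting behaviour (not part of this commit), assuming `VLLM_USE_V1=1` is set before vLLM is imported and that the `facebook/opt-125m` config can be fetched:

# Rough sketch (not part of this commit) of how V1EngineArgs resolves
# max_num_batched_tokens from the usage context.
import os

os.environ["VLLM_USE_V1"] = "1"

from vllm.engine.arg_utils import EngineArgs  # V1EngineArgs when VLLM_USE_V1 is set
from vllm.usage.usage_lib import UsageContext

args = EngineArgs(model="facebook/opt-125m")

# max_num_batched_tokens defaults to None; create_engine_config() warns and
# picks 8192 for LLM_CLASS or 2048 for OPENAI_API_SERVER.
config = args.create_engine_config(usage_context=UsageContext.OPENAI_API_SERVER)
print(config.scheduler_config.max_num_batched_tokens)  # 2048
print(config.scheduler_config.max_num_seqs)             # 1024, the V1 default

The updated V1 tests pass `UsageContext.UNKNOWN_CONTEXT` explicitly, which satisfies the assertion without hitting either branch, so max_num_batched_tokens is left for the base EngineArgs path to resolve.
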
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -579,7 +579,7 @@ def from_engine_args(
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(usage_context)
         executor_class = cls._get_executor_cls(engine_config)
         # Create the LLM engine.
         engine = cls(
2 changes: 1 addition & 1 deletion vllm/v1/engine/async_llm.py
@@ -94,7 +94,7 @@ def from_engine_args(

         # Create the engine configs.
         if engine_config is None:
-            vllm_config = engine_args.create_engine_config()
+            vllm_config = engine_args.create_engine_config(usage_context)
         else:
             vllm_config = engine_config

13 changes: 0 additions & 13 deletions vllm/v1/engine/core.py
@@ -39,19 +39,6 @@ def __init__(
         executor_class: Type[GPUExecutor],
         usage_context: UsageContext,
     ):
-        # Override the configs for V1.
-        # FIXME
-        if usage_context == UsageContext.LLM_CLASS:
-            vllm_config.scheduler_config.max_num_seqs = 1024
-            vllm_config.scheduler_config.max_num_batched_tokens = 8192
-        elif usage_context == UsageContext.OPENAI_API_SERVER:
-            vllm_config.scheduler_config.max_num_seqs = 1024
-            vllm_config.scheduler_config.max_num_batched_tokens = 2048
-
-        # TODO (ywang96): Enable APC by default when VLM supports it.
-        if not vllm_config.model_config.is_multimodal_model:
-            vllm_config.cache_config.enable_prefix_caching = True
-
         assert vllm_config.model_config.task != "embedding"

         logger.info("Initializing an LLM engine (v%s) with config: %s",
2 changes: 1 addition & 1 deletion vllm/v1/engine/llm_engine.py
@@ -82,7 +82,7 @@ def from_engine_args(
         """Creates an LLM engine from the engine arguments."""

         # Create the engine configs.
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(usage_context)
         executor_class = cls._get_executor_cls(vllm_config)

         if VLLM_ENABLE_V1_MULTIPROCESSING:
