From 1c1bb388e0d35a2d10da5c5cda2edac57bf62591 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 16 Sep 2024 22:17:32 -0600 Subject: [PATCH] [Frontend] Improve Nullable kv Arg Parsing (#8525) Signed-off-by: Alex-Brooks --- tests/engine/test_arg_utils.py | 20 +++++++++++++++++++- vllm/engine/arg_utils.py | 28 +++++++++++++++++++++------- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 3208d6bb48bdc..8dd200b35d0f3 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -1,6 +1,8 @@ +from argparse import ArgumentTypeError + import pytest -from vllm.engine.arg_utils import EngineArgs +from vllm.engine.arg_utils import EngineArgs, nullable_kvs from vllm.utils import FlexibleArgumentParser @@ -13,6 +15,10 @@ "image": 16, "video": 2 }), + ("Image=16, Video=2", { + "image": 16, + "video": 2 + }), ]) def test_limit_mm_per_prompt_parser(arg, expected): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) @@ -22,3 +28,15 @@ def test_limit_mm_per_prompt_parser(arg, expected): args = parser.parse_args(["--limit-mm-per-prompt", arg]) assert args.limit_mm_per_prompt == expected + + +@pytest.mark.parametrize( + ("arg"), + [ + "image", # Missing = + "image=4,image=5", # Conflicting values + "image=video=4" # Too many = in tokenized arg + ]) +def test_bad_nullable_kvs(arg): + with pytest.raises(ArgumentTypeError): + nullable_kvs(arg) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b5eba9ca3727a..35013eedea9c6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,22 +44,36 @@ def nullable_str(val: str): def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: + """Parses a string containing comma separate key [str] to value [int] + pairs into a dictionary. + + Args: + val: String value to be parsed. + + Returns: + Dictionary with parsed values. + """ if len(val) == 0: return None out_dict: Dict[str, int] = {} for item in val.split(","): - try: - key, value = item.split("=") - except TypeError as exc: - msg = "Each item should be in the form KEY=VALUE" - raise ValueError(msg) from exc + kv_parts = [part.lower().strip() for part in item.split("=")] + if len(kv_parts) != 2: + raise argparse.ArgumentTypeError( + "Each item should be in the form KEY=VALUE") + key, value = kv_parts try: - out_dict[key] = int(value) + parsed_value = int(value) except ValueError as exc: msg = f"Failed to parse value of item {key}={value}" - raise ValueError(msg) from exc + raise argparse.ArgumentTypeError(msg) from exc + + if key in out_dict and out_dict[key] != parsed_value: + raise argparse.ArgumentTypeError( + f"Conflicting values specified for key: {key}") + out_dict[key] = parsed_value return out_dict