forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core] Support Lora lineage and base model metadata management (vllm-…
- Loading branch information
Showing
15 changed files
with
337 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import json | ||
import unittest | ||
|
||
from vllm.entrypoints.openai.cli_args import make_arg_parser | ||
from vllm.entrypoints.openai.serving_engine import LoRAModulePath | ||
from vllm.utils import FlexibleArgumentParser | ||
|
||
LORA_MODULE = { | ||
"name": "module2", | ||
"path": "/path/to/module2", | ||
"base_model_name": "llama" | ||
} | ||
|
||
|
||
class TestLoraParserAction(unittest.TestCase): | ||
|
||
def setUp(self): | ||
# Setting up argparse parser for tests | ||
parser = FlexibleArgumentParser( | ||
description="vLLM's remote OpenAI server.") | ||
self.parser = make_arg_parser(parser) | ||
|
||
def test_valid_key_value_format(self): | ||
# Test old format: name=path | ||
args = self.parser.parse_args([ | ||
'--lora-modules', | ||
'module1=/path/to/module1', | ||
]) | ||
expected = [LoRAModulePath(name='module1', path='/path/to/module1')] | ||
self.assertEqual(args.lora_modules, expected) | ||
|
||
def test_valid_json_format(self): | ||
# Test valid JSON format input | ||
args = self.parser.parse_args([ | ||
'--lora-modules', | ||
json.dumps(LORA_MODULE), | ||
]) | ||
expected = [ | ||
LoRAModulePath(name='module2', | ||
path='/path/to/module2', | ||
base_model_name='llama') | ||
] | ||
self.assertEqual(args.lora_modules, expected) | ||
|
||
def test_invalid_json_format(self): | ||
# Test invalid JSON format input, missing closing brace | ||
with self.assertRaises(SystemExit): | ||
self.parser.parse_args([ | ||
'--lora-modules', | ||
'{"name": "module3", "path": "/path/to/module3"' | ||
]) | ||
|
||
def test_invalid_type_error(self): | ||
# Test type error when values are not JSON or key=value | ||
with self.assertRaises(SystemExit): | ||
self.parser.parse_args([ | ||
'--lora-modules', | ||
'invalid_format' # This is not JSON or key=value format | ||
]) | ||
|
||
def test_invalid_json_field(self): | ||
# Test valid JSON format but missing required fields | ||
with self.assertRaises(SystemExit): | ||
self.parser.parse_args([ | ||
'--lora-modules', | ||
'{"name": "module4"}' # Missing required 'path' field | ||
]) | ||
|
||
def test_empty_values(self): | ||
# Test when no LoRA modules are provided | ||
args = self.parser.parse_args(['--lora-modules', '']) | ||
self.assertEqual(args.lora_modules, []) | ||
|
||
def test_multiple_valid_inputs(self): | ||
# Test multiple valid inputs (both old and JSON format) | ||
args = self.parser.parse_args([ | ||
'--lora-modules', | ||
'module1=/path/to/module1', | ||
json.dumps(LORA_MODULE), | ||
]) | ||
expected = [ | ||
LoRAModulePath(name='module1', path='/path/to/module1'), | ||
LoRAModulePath(name='module2', | ||
path='/path/to/module2', | ||
base_model_name='llama') | ||
] | ||
self.assertEqual(args.lora_modules, expected) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import json | ||
|
||
import openai # use the official client for correctness check | ||
import pytest | ||
import pytest_asyncio | ||
# downloading lora to test lora requests | ||
from huggingface_hub import snapshot_download | ||
|
||
from ...utils import RemoteOpenAIServer | ||
|
||
# any model with a chat template should work here | ||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" | ||
# technically this needs Mistral-7B-v0.1 as base, but we're not testing | ||
# generation quality here | ||
LORA_NAME = "typeof/zephyr-7b-beta-lora" | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def zephyr_lora_files(): | ||
return snapshot_download(repo_id=LORA_NAME) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def server_with_lora_modules_json(zephyr_lora_files): | ||
# Define the json format LoRA module configurations | ||
lora_module_1 = { | ||
"name": "zephyr-lora", | ||
"path": zephyr_lora_files, | ||
"base_model_name": MODEL_NAME | ||
} | ||
|
||
lora_module_2 = { | ||
"name": "zephyr-lora2", | ||
"path": zephyr_lora_files, | ||
"base_model_name": MODEL_NAME | ||
} | ||
|
||
args = [ | ||
# use half precision for speed and memory savings in CI environment | ||
"--dtype", | ||
"bfloat16", | ||
"--max-model-len", | ||
"8192", | ||
"--enforce-eager", | ||
# lora config below | ||
"--enable-lora", | ||
"--lora-modules", | ||
json.dumps(lora_module_1), | ||
json.dumps(lora_module_2), | ||
"--max-lora-rank", | ||
"64", | ||
"--max-cpu-loras", | ||
"2", | ||
"--max-num-seqs", | ||
"64", | ||
] | ||
|
||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: | ||
yield remote_server | ||
|
||
|
||
@pytest_asyncio.fixture | ||
async def client_for_lora_lineage(server_with_lora_modules_json): | ||
async with server_with_lora_modules_json.get_async_client( | ||
) as async_client: | ||
yield async_client | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, | ||
zephyr_lora_files): | ||
models = await client_for_lora_lineage.models.list() | ||
models = models.data | ||
served_model = models[0] | ||
lora_models = models[1:] | ||
assert served_model.id == MODEL_NAME | ||
assert served_model.root == MODEL_NAME | ||
assert served_model.parent is None | ||
assert all(lora_model.root == zephyr_lora_files | ||
for lora_model in lora_models) | ||
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) | ||
assert lora_models[0].id == "zephyr-lora" | ||
assert lora_models[1].id == "zephyr-lora2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.