From 260d40b5ea48df9421325388abcc8d907a560fc5 Mon Sep 17 00:00:00 2001 From: Jiaxin Shan Date: Thu, 19 Sep 2024 23:20:56 -0700 Subject: [PATCH] [Core] Support Lora lineage and base model metadata management (#6315) --- docs/source/models/lora.rst | 64 +++++++++++++ tests/entrypoints/openai/test_cli_args.py | 91 +++++++++++++++++++ tests/entrypoints/openai/test_lora_lineage.py | 83 +++++++++++++++++ tests/entrypoints/openai/test_models.py | 6 +- tests/entrypoints/openai/test_serving_chat.py | 6 +- .../entrypoints/openai/test_serving_engine.py | 5 +- vllm/entrypoints/openai/api_server.py | 14 ++- vllm/entrypoints/openai/cli_args.py | 27 +++++- vllm/entrypoints/openai/run_batch.py | 9 +- vllm/entrypoints/openai/serving_chat.py | 11 ++- vllm/entrypoints/openai/serving_completion.py | 9 +- vllm/entrypoints/openai/serving_embedding.py | 6 +- vllm/entrypoints/openai/serving_engine.py | 43 ++++++--- .../openai/serving_tokenization.py | 7 +- vllm/lora/request.py | 1 + 15 files changed, 337 insertions(+), 45 deletions(-) create mode 100644 tests/entrypoints/openai/test_cli_args.py create mode 100644 tests/entrypoints/openai/test_lora_lineage.py diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst index b3821ebdfceca..ef0177eaf2162 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -159,3 +159,67 @@ Example request to unload a LoRA adapter: -d '{ "lora_name": "sql_adapter" }' + + +New format for `--lora-modules` +------------------------------- + +In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example: + +.. code-block:: bash + + --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/ + +This would only include the `name` and `path` for each LoRA module, but did not provide a way to specify a `base_model_name`. +Now, you can specify a base_model_name alongside the name and path using JSON format. For example: + +.. code-block:: bash + + --lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}' + +To provide the backward compatibility support, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case. + + +Lora model lineage in model card +-------------------------------- + +The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this: + +- The `parent` field of LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter. +- The `root` field points to the artifact location of the lora adapter. + +.. code-block:: bash + + $ curl http://localhost:8000/v1/models + + { + "object": "list", + "data": [ + { + "id": "meta-llama/Llama-2-7b-hf", + "object": "model", + "created": 1715644056, + "owned_by": "vllm", + "root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/", + "parent": null, + "permission": [ + { + ..... 
+ } + ] + }, + { + "id": "sql-lora", + "object": "model", + "created": 1715644056, + "owned_by": "vllm", + "root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/", + "parent": meta-llama/Llama-2-7b-hf, + "permission": [ + { + .... + } + ] + } + ] + } diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py new file mode 100644 index 0000000000000..8ee7fb8b2c6bf --- /dev/null +++ b/tests/entrypoints/openai/test_cli_args.py @@ -0,0 +1,91 @@ +import json +import unittest + +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.serving_engine import LoRAModulePath +from vllm.utils import FlexibleArgumentParser + +LORA_MODULE = { + "name": "module2", + "path": "/path/to/module2", + "base_model_name": "llama" +} + + +class TestLoraParserAction(unittest.TestCase): + + def setUp(self): + # Setting up argparse parser for tests + parser = FlexibleArgumentParser( + description="vLLM's remote OpenAI server.") + self.parser = make_arg_parser(parser) + + def test_valid_key_value_format(self): + # Test old format: name=path + args = self.parser.parse_args([ + '--lora-modules', + 'module1=/path/to/module1', + ]) + expected = [LoRAModulePath(name='module1', path='/path/to/module1')] + self.assertEqual(args.lora_modules, expected) + + def test_valid_json_format(self): + # Test valid JSON format input + args = self.parser.parse_args([ + '--lora-modules', + json.dumps(LORA_MODULE), + ]) + expected = [ + LoRAModulePath(name='module2', + path='/path/to/module2', + base_model_name='llama') + ] + self.assertEqual(args.lora_modules, expected) + + def test_invalid_json_format(self): + # Test invalid JSON format input, missing closing brace + with self.assertRaises(SystemExit): + self.parser.parse_args([ + '--lora-modules', + '{"name": "module3", "path": "/path/to/module3"' + ]) + + def test_invalid_type_error(self): + # Test type error when values are not JSON or key=value + with self.assertRaises(SystemExit): + self.parser.parse_args([ + '--lora-modules', + 'invalid_format' # This is not JSON or key=value format + ]) + + def test_invalid_json_field(self): + # Test valid JSON format but missing required fields + with self.assertRaises(SystemExit): + self.parser.parse_args([ + '--lora-modules', + '{"name": "module4"}' # Missing required 'path' field + ]) + + def test_empty_values(self): + # Test when no LoRA modules are provided + args = self.parser.parse_args(['--lora-modules', '']) + self.assertEqual(args.lora_modules, []) + + def test_multiple_valid_inputs(self): + # Test multiple valid inputs (both old and JSON format) + args = self.parser.parse_args([ + '--lora-modules', + 'module1=/path/to/module1', + json.dumps(LORA_MODULE), + ]) + expected = [ + LoRAModulePath(name='module1', path='/path/to/module1'), + LoRAModulePath(name='module2', + path='/path/to/module2', + base_model_name='llama') + ] + self.assertEqual(args.lora_modules, expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/entrypoints/openai/test_lora_lineage.py b/tests/entrypoints/openai/test_lora_lineage.py new file mode 100644 index 0000000000000..ab39684c2f31a --- /dev/null +++ b/tests/entrypoints/openai/test_lora_lineage.py @@ -0,0 +1,83 @@ +import json + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +# downloading lora to test lora requests +from huggingface_hub import snapshot_download + +from ...utils import 
RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +# technically this needs Mistral-7B-v0.1 as base, but we're not testing +# generation quality here +LORA_NAME = "typeof/zephyr-7b-beta-lora" + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def server_with_lora_modules_json(zephyr_lora_files): + # Define the json format LoRA module configurations + lora_module_1 = { + "name": "zephyr-lora", + "path": zephyr_lora_files, + "base_model_name": MODEL_NAME + } + + lora_module_2 = { + "name": "zephyr-lora2", + "path": zephyr_lora_files, + "base_model_name": MODEL_NAME + } + + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--enable-lora", + "--lora-modules", + json.dumps(lora_module_1), + json.dumps(lora_module_2), + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + "--max-num-seqs", + "64", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_for_lora_lineage(server_with_lora_modules_json): + async with server_with_lora_modules_json.get_async_client( + ) as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, + zephyr_lora_files): + models = await client_for_lora_lineage.models.list() + models = models.data + served_model = models[0] + lora_models = models[1:] + assert served_model.id == MODEL_NAME + assert served_model.root == MODEL_NAME + assert served_model.parent is None + assert all(lora_model.root == zephyr_lora_files + for lora_model in lora_models) + assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) + assert lora_models[0].id == "zephyr-lora" + assert lora_models[1].id == "zephyr-lora2" diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index 5cd570f43e1a7..ae5bf404d3d2b 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -51,12 +51,14 @@ async def client(server): @pytest.mark.asyncio -async def test_check_models(client: openai.AsyncOpenAI): +async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files): models = await client.models.list() models = models.data served_model = models[0] lora_models = models[1:] assert served_model.id == MODEL_NAME - assert all(model.root == MODEL_NAME for model in models) + assert served_model.root == MODEL_NAME + assert all(lora_model.root == zephyr_lora_files + for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" assert lora_models[1].id == "zephyr-lora2" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index de2a932199a01..db31745cc102e 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -7,10 +7,12 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" 
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] @dataclass @@ -37,7 +39,7 @@ async def _async_serving_chat_init(): serving_completion = OpenAIServingChat(engine, model_config, - served_model_names=[MODEL_NAME], + BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, lora_modules=None, @@ -58,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens(): serving_chat = OpenAIServingChat(mock_engine, MockModelConfig(), - served_model_names=[MODEL_NAME], + BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, lora_modules=None, diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 6d9e620b4af7d..6199a75b5b4f8 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -8,9 +8,10 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing MODEL_NAME = "meta-llama/Llama-2-7b" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] LORA_LOADING_SUCCESS_MESSAGE = ( "Success: LoRA adapter '{lora_name}' added successfully.") LORA_UNLOADING_SUCCESS_MESSAGE = ( @@ -25,7 +26,7 @@ async def _async_serving_engine_init(): serving_engine = OpenAIServing(mock_engine_client, mock_model_config, - served_model_names=[MODEL_NAME], + BASE_MODEL_PATHS, lora_modules=None, prompt_adapters=None, request_logger=None) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index fd6f36e8768dd..5078a2654eb22 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -50,6 +50,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.logger import init_logger @@ -476,13 +477,18 @@ def init_app_state( else: request_logger = RequestLogger(max_log_len=args.max_log_len) + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, - served_model_names, + base_model_paths, args.response_role, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, @@ -494,7 +500,7 @@ def init_app_state( state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, - served_model_names, + base_model_paths, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, @@ -503,13 +509,13 @@ def init_app_state( state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, - served_model_names, + base_model_paths, request_logger=request_logger, ) state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, - served_model_names, + base_model_paths, lora_modules=args.lora_modules, request_logger=request_logger, chat_template=args.chat_template, diff --git a/vllm/entrypoints/openai/cli_args.py 
b/vllm/entrypoints/openai/cli_args.py index bbb0823de9a51..9d3071a97fbe6 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -31,8 +31,23 @@ def __call__( lora_list: List[LoRAModulePath] = [] for item in values: - name, path = item.split('=') - lora_list.append(LoRAModulePath(name, path)) + if item in [None, '']: # Skip if item is None or empty string + continue + if '=' in item and ',' not in item: # Old format: name=path + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + else: # Assume JSON format + try: + lora_dict = json.loads(item) + lora = LoRAModulePath(**lora_dict) + lora_list.append(lora) + except json.JSONDecodeError: + parser.error( + f"Invalid JSON format for --lora-modules: {item}") + except TypeError as e: + parser.error( + f"Invalid fields for --lora-modules: {item} - {str(e)}" + ) setattr(namespace, self.dest, lora_list) @@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, nargs='+', action=LoRAParserAction, - help="LoRA module configurations in the format name=path. " - "Multiple modules can be specified.") + help="LoRA module configurations in either 'name=path' format" + "or JSON format. " + "Example (old format): 'name=path' " + "Example (new format): " + "'{\"name\": \"name\", \"local_path\": \"path\", " + "\"base_model_name\": \"id\"}'") parser.add_argument( "--prompt-adapters", type=nullable_str, diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index b745410fe6b3b..f5249a0c447b3 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -20,6 +20,7 @@ # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION @@ -196,6 +197,10 @@ async def main(args): engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) model_config = await engine.get_model_config() + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] if args.disable_log_requests: request_logger = None @@ -206,7 +211,7 @@ async def main(args): openai_serving_chat = OpenAIServingChat( engine, model_config, - served_model_names, + base_model_paths, args.response_role, lora_modules=None, prompt_adapters=None, @@ -216,7 +221,7 @@ async def main(args): openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, - served_model_names, + base_model_paths, request_logger=request_logger, ) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b84898dc39b0f..1ee4b3ce17cfa 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -23,7 +23,8 @@ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, OpenAIServing, PromptAdapterPath, TextTokensPrompt) @@ -47,7 +48,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, engine_client: 
EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], response_role: str, *, lora_modules: Optional[List[LoRAModulePath]], @@ -59,7 +60,7 @@ def __init__(self, tool_parser: Optional[str] = None): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=lora_modules, prompt_adapters=prompt_adapters, request_logger=request_logger, @@ -262,7 +263,7 @@ async def chat_completion_stream_generator( conversation: List[ConversationMessage], tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: - model_name = self.served_model_names[0] + model_name = self.base_model_paths[0].name created_time = int(time.time()) chunk_object_type: Final = "chat.completion.chunk" first_iteration = True @@ -596,7 +597,7 @@ async def chat_completion_full_generator( tokenizer: AnyTokenizer, ) -> Union[ErrorResponse, ChatCompletionResponse]: - model_name = self.served_model_names[0] + model_name = self.base_model_paths[0].name created_time = int(time.time()) final_res: Optional[RequestOutput] = None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 14fa60243c584..9abd74d0561d0 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -20,7 +20,8 @@ CompletionStreamResponse, ErrorResponse, UsageInfo) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, OpenAIServing, PromptAdapterPath) from vllm.logger import init_logger @@ -45,7 +46,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], @@ -54,7 +55,7 @@ def __init__( ): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=lora_modules, prompt_adapters=prompt_adapters, request_logger=request_logger, @@ -89,7 +90,7 @@ async def create_completion( return self.create_error_response( "suffix is not currently supported") - model_name = self.served_model_names[0] + model_name = self.base_model_paths[0].name request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index f111a3a8277b5..5d95e1369b884 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -14,7 +14,7 @@ EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.logger import init_logger from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput from vllm.utils import merge_async_iterators, random_uuid @@ -73,13 +73,13 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, request_logger: Optional[RequestLogger], ): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, 
lora_modules=None, prompt_adapters=None, request_logger=request_logger) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 72f9381abc7db..9c4e8d8bb671a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -39,6 +39,12 @@ logger = init_logger(__name__) +@dataclass +class BaseModelPath: + name: str + model_path: str + + @dataclass class PromptAdapterPath: name: str @@ -49,6 +55,7 @@ class PromptAdapterPath: class LoRAModulePath: name: str path: str + base_model_name: Optional[str] = None AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, @@ -66,7 +73,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], @@ -79,17 +86,20 @@ def __init__( self.model_config = model_config self.max_model_len = model_config.max_model_len - self.served_model_names = served_model_names + self.base_model_paths = base_model_paths self.lora_id_counter = AtomicCounter(0) self.lora_requests = [] if lora_modules is not None: self.lora_requests = [ - LoRARequest( - lora_name=lora.name, - lora_int_id=i, - lora_path=lora.path, - ) for i, lora in enumerate(lora_modules, start=1) + LoRARequest(lora_name=lora.name, + lora_int_id=i, + lora_path=lora.path, + base_model_name=lora.base_model_name + if lora.base_model_name + and self._is_model_supported(lora.base_model_name) + else self.base_model_paths[0].name) + for i, lora in enumerate(lora_modules, start=1) ] self.prompt_adapter_requests = [] @@ -112,21 +122,23 @@ def __init__( async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ - ModelCard(id=served_model_name, + ModelCard(id=base_model.name, max_model_len=self.max_model_len, - root=self.served_model_names[0], + root=base_model.model_path, permission=[ModelPermission()]) - for served_model_name in self.served_model_names + for base_model in self.base_model_paths ] lora_cards = [ ModelCard(id=lora.lora_name, - root=self.served_model_names[0], + root=lora.local_path, + parent=lora.base_model_name if lora.base_model_name else + self.base_model_paths[0].name, permission=[ModelPermission()]) for lora in self.lora_requests ] prompt_adapter_cards = [ ModelCard(id=prompt_adapter.prompt_adapter_name, - root=self.served_model_names[0], + root=self.base_model_paths[0].name, permission=[ModelPermission()]) for prompt_adapter in self.prompt_adapter_requests ] @@ -169,7 +181,7 @@ async def _check_model( self, request: AnyRequest, ) -> Optional[ErrorResponse]: - if request.model in self.served_model_names: + if self._is_model_supported(request.model): return None if request.model in [lora.lora_name for lora in self.lora_requests]: return None @@ -187,7 +199,7 @@ def _maybe_get_adapters( self, request: AnyRequest ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[ None, PromptAdapterRequest]]: - if request.model in self.served_model_names: + if self._is_model_supported(request.model): return None, None for lora in self.lora_requests: if request.model == lora.lora_name: @@ -480,3 +492,6 @@ async def unload_lora_adapter( if lora_request.lora_name != lora_name ] return f"Success: LoRA adapter '{lora_name}' removed successfully." 
+ + def _is_model_supported(self, model_name): + return any(model.name == model_name for model in self.base_model_paths) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 8f8862897fc4e..6d9a1ae088079 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -16,7 +16,8 @@ TokenizeRequest, TokenizeResponse) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, OpenAIServing) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import MistralTokenizer @@ -31,7 +32,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], request_logger: Optional[RequestLogger], @@ -39,7 +40,7 @@ def __init__( ): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=lora_modules, prompt_adapters=None, request_logger=request_logger) diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 47a59d80d3a45..c4b26dc92c6f4 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -28,6 +28,7 @@ class LoRARequest( lora_path: str = "" lora_local_path: Optional[str] = msgspec.field(default=None) long_lora_max_len: Optional[int] = None + base_model_name: Optional[str] = msgspec.field(default=None) def __post_init__(self): if 'lora_local_path' in self.__struct_fields__:
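
A quick client-side check of the new lineage fields, mirroring what
`tests/entrypoints/openai/test_lora_lineage.py` asserts. This is a minimal
sketch rather than part of the patch: it assumes a single base model, no
prompt adapters, a local vLLM server on port 8000 started with `--enable-lora`
and JSON-format `--lora-modules` entries (as in the docs above), and a
placeholder API key.

.. code-block:: python

    # Inspect the `root` and `parent` fields that this patch adds to /v1/models.
    # Assumptions: server at localhost:8000, one base model, no prompt adapters.
    import openai

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    models = client.models.list().data
    base_model, lora_models = models[0], models[1:]

    # The base model card reports its own artifact path as `root` and no `parent`.
    print(base_model.id, base_model.root, base_model.parent)

    # Each LoRA card's `root` is the adapter path; `parent` is the configured
    # base_model_name, falling back to the first base model when unspecified.
    for lora in lora_models:
        print(lora.id, "root:", lora.root, "parent:", lora.parent)

Note that when `base_model_name` is omitted (old `name=path` format) or does not
match a served base model, `OpenAIServing` falls back to the first entry in
`base_model_paths` when filling in `parent`, so the field is always populated for
LoRA cards.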