From 2256cae8b9afbee1ffaab055d6f1ade577a9192f Mon Sep 17 00:00:00 2001
From: Sudeep Agarwal
Date: Thu, 25 May 2023 16:41:26 -0400
Subject: [PATCH] Implement Python chat module and REST API (#223)

This PR adds the following:

* A Python chat module with the same functionality as the CLI (note that this requires a module built without the tvm_runtime dependency; see the changes to CMakeLists.txt)
* A REST API with streaming support that exposes common endpoints for interacting with Vicuna and RedPajama
* A sample client that shows how to use the endpoints
* Documentation on how to run the server and the client
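As a rough illustration of how the new chat module can be driven directly from Python, here is a minimal sketch (not part of the diff below): it mirrors the loading logic in `server.py`, assumes a CUDA build, and uses `vicuna-v1-7b-q3f16_0` under the default `dist/` layout as a placeholder artifact.

```python
import sys

# Make chat_module importable when running from the repo root, the same way
# server.py relies on its own directory being on sys.path.
sys.path.append("python/mlc_chat")

import tvm
from chat_module import LLMChatModule

# Placeholder paths: build/libmlc_llm_module.so comes from the new CMake target,
# and dist/vicuna-v1-7b-q3f16_0/ is an example locally built artifact.
chat = LLMChatModule("build/libmlc_llm_module.so", target="cuda", device_id=0)
lib = tvm.runtime.load_module(
    "dist/vicuna-v1-7b-q3f16_0/vicuna-v1-7b-q3f16_0-cuda.so"
)
chat.reload(lib=lib, model_path="dist/vicuna-v1-7b-q3f16_0/params")

# Feed the prompt, then decode until the model reports it has stopped.
chat.prefill("Write a haiku")
while not chat.stopped():
    chat.decode()
print(chat.get_message())
print(chat.runtime_stats_text())
```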
---
 CMakeLists.txt                   |  10 +++
 python/README.md                 |  16 ++++
 python/mlc_chat/chat_module.py   |  73 +++++++++++++++
 python/mlc_chat/sample_client.py |  39 ++++++++
 python/mlc_chat/server.py        | 149 +++++++++++++++++++++++++++++++
 5 files changed, 287 insertions(+)
 create mode 100644 python/README.md
 create mode 100644 python/mlc_chat/chat_module.py
 create mode 100644 python/mlc_chat/sample_client.py
 create mode 100644 python/mlc_chat/server.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7bd730be9c..1518d5df20 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -102,6 +102,16 @@ else()
   target_link_libraries(mlc_chat_cli PUBLIC mlc_llm)
 endif()
 
+if (UNIX OR APPLE)
+  add_library(mlc_llm_module MODULE $<TARGET_OBJECTS:mlc_llm_objs>)
+  target_link_libraries(mlc_llm_module PRIVATE tokenizers_cpp)
+  if (APPLE)
+    set_property(TARGET mlc_llm_module APPEND PROPERTY LINK_OPTIONS -undefined dynamic_lookup)
+  else()
+    set_property(TARGET mlc_llm_module APPEND PROPERTY LINK_OPTIONS)
+  endif()
+endif()
+
 # when this option is on,
 # we install all static lib deps into lib
 if (MLC_LLM_INSTALL_STATIC_LIB)
diff --git a/python/README.md b/python/README.md
new file mode 100644
index 0000000000..6564ad01fb
--- /dev/null
+++ b/python/README.md
@@ -0,0 +1,16 @@
+# Instructions
+
+## REST API
+
+Using the [REST API](https://www.ibm.com/topics/rest-apis#:~:text=the%20next%20step-,What%20is%20a%20REST%20API%3F,representational%20state%20transfer%20architectural%20style.) currently requires building from source.
+
+1. Follow the instructions [here](https://github.com/mlc-ai/mlc-llm/tree/main/cpp) to build the CLI from source.
+2. Launch the server, which listens at [http://127.0.0.1:8000/](http://127.0.0.1:8000/):
+   ```shell
+   cd mlc-llm
+   python python/mlc_chat/server.py
+   ```
+3. Go to [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs) for the list of supported endpoints, or run the sample client script to see how to send queries:
+   ```
+   python python/mlc_chat/sample_client.py
+   ```
\ No newline at end of file
diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py
new file mode 100644
index 0000000000..0e8970827c
--- /dev/null
+++ b/python/mlc_chat/chat_module.py
@@ -0,0 +1,73 @@
+"""Python runtime for MLC chat."""
+
+import ctypes
+
+import tvm
+
+
+def load_llm_chat(mlc_lib_path):
+    return ctypes.CDLL(mlc_lib_path)
+
+
+def supported_models():
+    return set(["vicuna-v1-7b", "RedPajama-INCITE-Chat-3B-v1"])
+
+
+def quantization_keys():
+    return ["q3f16_0", "q4f16_0", "q4f32_0", "q0f32", "q0f16"]
+
+
+class LLMChatModule:
+    def __init__(self, mlc_lib_path, target="cuda", device_id=0):
+        load_llm_chat(mlc_lib_path)
+        fcreate = tvm.get_global_func("mlc.llm_chat_create")
+        assert fcreate is not None
+        if target == "cuda":
+            self.device = tvm.cuda(device_id)
+        elif target == "metal":
+            self.device = tvm.metal(device_id)
+        elif target == "vulkan":
+            self.device = tvm.vulkan(device_id)
+        else:
+            raise ValueError("device type not supported yet")
+        device_type = self.device.device_type
+        chat_mod = fcreate(device_type, device_id)
+
+        self.reload_func = chat_mod["reload"]
+        self.prefill_func = chat_mod["prefill"]
+        self.decode_func = chat_mod["decode"]
+        self.stopped_func = chat_mod["stopped"]
+        self.get_message_func = chat_mod["get_message"]
+        self.reset_chat_func = chat_mod["reset_chat"]
+        self.runtime_stats_text_func = chat_mod["runtime_stats_text"]
+        self.reset_runtime_stats_func = chat_mod["reset_runtime_stats"]
+        self.evaluate_func = chat_mod["evaluate"]
+        self.get_role0 = chat_mod["get_role0"]
+        self.get_role1 = chat_mod["get_role1"]
+
+    def reload(self, lib, model_path):
+        self.reload_func(lib, model_path)
+
+    def prefill(self, input):
+        self.prefill_func(input)
+
+    def decode(self):
+        self.decode_func()
+
+    def stopped(self):
+        return self.stopped_func() != 0
+
+    def get_message(self):
+        return self.get_message_func()
+
+    def reset_chat(self):
+        self.reset_chat_func()
+
+    def runtime_stats_text(self):
+        return self.runtime_stats_text_func()
+
+    def reset_runtime_stats(self):
+        self.reset_runtime_stats_func()
+
+    def evaluate(self):
+        self.evaluate_func()
diff --git a/python/mlc_chat/sample_client.py b/python/mlc_chat/sample_client.py
new file mode 100644
index 0000000000..ca5ee58f59
--- /dev/null
+++ b/python/mlc_chat/sample_client.py
@@ -0,0 +1,39 @@
+import requests
+import json
+
+# To launch the server, run
+# $ python python/mlc_chat/server.py
+
+# List the models that are currently supported
+r = requests.get("http://127.0.0.1:8000/models")
+print(f"Supported models: {r.json()}\n")
+
+# Get a response using a prompt without streaming
+payload = {
+    "prompt": "Write a haiku"
+}
+r = requests.post("http://127.0.0.1:8000/chat/completions", json=payload)
+print(f"Without streaming: {r.json()}\n")
+
+# Reset the chat
+r = requests.post("http://127.0.0.1:8000/chat/reset", json=payload)
+print(f"Reset chat: {str(r)}\n")
+
+# Get a response using a prompt with streaming
+payload = {
+    "prompt": "Write a haiku",
+    "stream": True
+}
+with requests.post("http://127.0.0.1:8000/chat/completions", json=payload, stream=True) as r:
+    print("With streaming: ")
+    try:
+        for data in r.iter_content(chunk_size=1024):
+            if data:
+                print(json.loads(data))
+    except requests.exceptions.ChunkedEncodingError as ex:
+        print(f"Invalid chunk encoding {str(ex)}")
+    print("\n")
+
+# Get the latest runtime stats
+r = requests.get("http://127.0.0.1:8000/stats")
+print(f"Runtime stats: {r.json()}\n")
diff --git a/python/mlc_chat/server.py b/python/mlc_chat/server.py
new file mode 100644
index 0000000000..818c25710d
--- /dev/null
+++ b/python/mlc_chat/server.py
@@ -0,0 +1,149 @@
+from chat_module import LLMChatModule, supported_models, quantization_keys
+
+from pydantic import BaseModel
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
+from contextlib import asynccontextmanager
+import uvicorn
+
+import tvm
+
+import argparse
+import os
+import json
+
+
+session = {}
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+
+    ARGS = _parse_args()
+
+    chat_mod = LLMChatModule(
+        ARGS.mlc_lib_path,
+        ARGS.device_name,
+        ARGS.device_id
+    )
+    model_path = os.path.join(
+        ARGS.artifact_path,
+        ARGS.model + "-" + ARGS.quantization
+    )
+    model_dir = ARGS.model + "-" + ARGS.quantization
+    model_lib = model_dir + "-" + ARGS.device_name + ".so"
+    lib_dir = os.path.join(model_path, model_lib)
+    prebuilt_lib_dir = os.path.join(ARGS.artifact_path, "prebuilt", "lib", model_lib)
+    if os.path.exists(lib_dir):
+        lib = tvm.runtime.load_module(lib_dir)
+    elif os.path.exists(prebuilt_lib_dir):
+        lib = tvm.runtime.load_module(prebuilt_lib_dir)
+    else:
+        raise ValueError(f"Unable to find {model_lib} at {lib_dir} or {prebuilt_lib_dir}.")
+
+    local_model_path = os.path.join(model_path, "params")
+    prebuilt_model_path = os.path.join(ARGS.artifact_path, "prebuilt", f"mlc-chat-{model_dir}")
+    if os.path.exists(local_model_path):
+        chat_mod.reload(lib=lib, model_path=local_model_path)
+    elif os.path.exists(prebuilt_model_path):
+        chat_mod.reload(lib=lib, model_path=prebuilt_model_path)
+    else:
+        raise ValueError(f"Unable to find model params at {local_model_path} or {prebuilt_model_path}.")
+    session["chat_mod"] = chat_mod
+
+    yield
+
+    session.clear()
+
+
+app = FastAPI(lifespan=lifespan)
+
+def _parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument(
+        "--model",
+        type=str,
+        choices=supported_models(),
+        default="vicuna-v1-7b"
+    )
+    args.add_argument("--artifact-path", type=str, default="dist")
+    args.add_argument(
+        "--quantization",
+        type=str,
+        choices=quantization_keys(),
+        default=quantization_keys()[0],
+    )
+    args.add_argument("--device-name", type=str, default="cuda")
+    args.add_argument("--device-id", type=int, default=0)
+    args.add_argument(
+        "--mlc-path", type=str, default="", help="path to the mlc-llm repo"
+    )
+    parsed = args.parse_args()
+    parsed.mlc_lib_path = os.path.join(parsed.mlc_path, "build/libmlc_llm_module.so")
+    return parsed
+
+
+"""
+List the currently supported models and provide basic information about each of them.
+"""
+@app.get("/models")
+async def read_models():
+    return {
+        "data": [{
+            "id": model,
+            "object": "model"
+        } for model in supported_models()]
+    }
+
+"""
+Retrieve a model instance with basic information about the model.
+"""
+@app.get("/models/{model}")
+async def read_model(model: str):
+    if model not in supported_models():
+        raise HTTPException(status_code=404, detail=f"Model {model} is not supported.")
+    return {
+        "id": model,
+        "object": "model"
+    }
+
+class ChatRequest(BaseModel):
+    prompt: str
+    stream: bool = False
+
+"""
+Creates a model response for the given chat conversation.
+""" +@app.post("/chat/completions") +def request_completion(request: ChatRequest): + session["chat_mod"].prefill(input=request.prompt) + if request.stream: + def iter_response(): + while not session["chat_mod"].stopped(): + session["chat_mod"].decode() + msg = session["chat_mod"].get_message() + yield json.dumps({"message": msg}) + return StreamingResponse(iter_response(), media_type='application/json') + else: + msg = None + while not session["chat_mod"].stopped(): + session["chat_mod"].decode() + msg = session["chat_mod"].get_message() + return {"message": msg} + +""" +Reset the chat for the currently initialized model. +""" +@app.post("/chat/reset") +def reset(): + session["chat_mod"].reset_chat() + +""" +Get the runtime stats. +""" +@app.get("/stats") +def read_stats(): + return session["chat_mod"].runtime_stats_text() + + +if __name__ == "__main__": + uvicorn.run("server:app", port=8000, reload=True, access_log=False)