Implement Python chat module and REST API (octoml#223)
This PR adds the following:

* A Python chat module with the same functionality as the CLI (note that this requires a module without the tvm_runtime dependency; see the changes to CMakeLists.txt)
* A REST API that supports some common endpoints for interacting with Vicuna and RedPajama, with streaming support
* A sample client that shows how to use the endpoints
* Documentation on how to run the server and the client
Showing 5 changed files with 287 additions and 0 deletions.
@@ -0,0 +1,16 @@
# Instructions

## REST API

Using the [REST API](https://www.ibm.com/topics/rest-apis#:~:text=the%20next%20step-,What%20is%20a%20REST%20API%3F,representational%20state%20transfer%20architectural%20style.) currently requires building from source.

1. Follow the instructions [here](https://github.com/mlc-ai/mlc-llm/tree/main/cpp) to build the CLI from source.
2. Launch the server at [http://127.0.0.1:8000/](http://127.0.0.1:8000/).
```shell
cd mlc-llm
python python/mlc_chat/server.py
```
3. Go to [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs) to see the list of supported endpoints, or run the sample client script to see how to send queries.
```shell
python python/mlc_chat/sample_client.py
```
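As a quick check after step 2, the snippet below lists the supported models over HTTP. This is a minimal sketch, assuming the server from this commit is running on the default port; it uses the `/models` endpoint defined in the server code further down.

```python
import requests

# Assumes the server started in step 2 is listening on the default port.
r = requests.get("http://127.0.0.1:8000/models")
r.raise_for_status()
print(r.json())  # e.g. {"data": [{"id": "vicuna-v1-7b", "object": "model"}, ...]}
```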
@@ -0,0 +1,73 @@
"""Python runtime for MLC chat."""

import ctypes

import tvm


def load_llm_chat(mlc_lib_path):
    return ctypes.CDLL(mlc_lib_path)


def supported_models():
    return set(["vicuna-v1-7b", "RedPajama-INCITE-Chat-3B-v1"])


def quantization_keys():
    return ["q3f16_0", "q4f16_0", "q4f32_0", "q0f32", "q0f16"]


class LLMChatModule:
    def __init__(self, mlc_lib_path, target="cuda", device_id=0):
        load_llm_chat(mlc_lib_path)
        fcreate = tvm.get_global_func("mlc.llm_chat_create")
        assert fcreate is not None
        if target == "cuda":
            self.device = tvm.cuda(device_id)
        elif target == "metal":
            self.device = tvm.metal(device_id)
        elif target == "vulkan":
            self.device = tvm.vulkan(device_id)
        else:
            raise ValueError("device type not supported yet")
        device_type = self.device.device_type
        chat_mod = fcreate(device_type, device_id)

        self.reload_func = chat_mod["reload"]
        self.prefill_func = chat_mod["prefill"]
        self.decode_func = chat_mod["decode"]
        self.stopped_func = chat_mod["stopped"]
        self.get_message_func = chat_mod["get_message"]
        self.reset_chat_func = chat_mod["reset_chat"]
        self.runtime_stats_text_func = chat_mod["runtime_stats_text"]
        self.reset_runtime_stats_func = chat_mod["reset_runtime_stats"]
        self.evaluate_func = chat_mod["evaluate"]
        self.get_role0 = chat_mod["get_role0"]
        self.get_role1 = chat_mod["get_role1"]

    def reload(self, lib, model_path):
        self.reload_func(lib, model_path)

    def prefill(self, input):
        self.prefill_func(input)

    def decode(self):
        self.decode_func()

    def stopped(self):
        return self.stopped_func() != 0

    def get_message(self):
        return self.get_message_func()

    def reset_chat(self):
        self.reset_chat_func()

    def runtime_stats_text(self):
        return self.runtime_stats_text_func()

    def reset_runtime_stats(self):
        self.reset_runtime_stats_func()

    def evaluate(self):
        self.evaluate_func()
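Standalone usage of the chat module looks roughly like the following. This is a minimal sketch mirroring what the server below does at startup; the library and artifact paths are assumptions that depend on where the model library and parameters were built.

```python
import tvm
from chat_module import LLMChatModule

# Illustrative paths; adjust to your build layout (artifact_path, model, quantization).
chat_mod = LLMChatModule("build/libmlc_llm_module.so", target="cuda", device_id=0)
lib = tvm.runtime.load_module("dist/vicuna-v1-7b-q3f16_0/vicuna-v1-7b-q3f16_0-cuda.so")
chat_mod.reload(lib=lib, model_path="dist/vicuna-v1-7b-q3f16_0/params")

# Feed the prompt, then decode until the model signals it has stopped.
chat_mod.prefill(input="Write a haiku")
while not chat_mod.stopped():
    chat_mod.decode()
print(chat_mod.get_message())
print(chat_mod.runtime_stats_text())
```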
@@ -0,0 +1,39 @@
import requests
import json

# To launch the server, run
# $ python python/mlc_chat/server.py

# List the models that are currently supported
r = requests.get("http://127.0.0.1:8000/models")
print(f"Supported models: {r.json()}\n")

# Get a response using a prompt without streaming
payload = {
    "prompt": "Write a haiku"
}
r = requests.post("http://127.0.0.1:8000/chat/completions", json=payload)
print(f"Without streaming: {r.json()}\n")

# Reset the chat
r = requests.post("http://127.0.0.1:8000/chat/reset", json=payload)
print(f"Reset chat: {str(r)}\n")

# Get a response using a prompt with streaming
payload = {
    "prompt": "Write a haiku",
    "stream": True
}
with requests.post("http://127.0.0.1:8000/chat/completions", json=payload, stream=True) as r:
    print("With streaming: ")
    try:
        for data in r.iter_content(chunk_size=1024):
            if data:
                print(json.loads(data))
    except requests.exceptions.ChunkedEncodingError as ex:
        print(f"Invalid chunk encoding {str(ex)}")
    print("\n")

# Get the latest runtime stats
r = requests.get("http://127.0.0.1:8000/stats")
print(f"Runtime stats: {r.json()}\n")
@@ -0,0 +1,149 @@
from chat_module import LLMChatModule, supported_models, quantization_keys

from pydantic import BaseModel
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
import uvicorn

import tvm

import argparse
import os
import json


session = {}


@asynccontextmanager
async def lifespan(app: FastAPI):

    ARGS = _parse_args()

    chat_mod = LLMChatModule(
        ARGS.mlc_lib_path,
        ARGS.device_name,
        ARGS.device_id
    )
    model_path = os.path.join(
        ARGS.artifact_path,
        ARGS.model + "-" + ARGS.quantization
    )
    model_dir = ARGS.model + "-" + ARGS.quantization
    model_lib = model_dir + "-" + ARGS.device_name + ".so"
    lib_dir = os.path.join(model_path, model_lib)
    prebuilt_lib_dir = os.path.join(ARGS.artifact_path, "prebuilt", "lib", model_lib)
    if os.path.exists(lib_dir):
        lib = tvm.runtime.load_module(lib_dir)
    elif os.path.exists(prebuilt_lib_dir):
        lib = tvm.runtime.load_module(prebuilt_lib_dir)
    else:
        raise ValueError(f"Unable to find {model_lib} at {lib_dir} or {prebuilt_lib_dir}.")

    local_model_path = os.path.join(model_path, "params")
    prebuilt_model_path = os.path.join(ARGS.artifact_path, "prebuilt", f"mlc-chat-{model_dir}")
    if os.path.exists(local_model_path):
        chat_mod.reload(lib=lib, model_path=local_model_path)
    elif os.path.exists(prebuilt_model_path):
        chat_mod.reload(lib=lib, model_path=prebuilt_model_path)
    else:
        raise ValueError(f"Unable to find model params at {local_model_path} or {prebuilt_model_path}.")
    session["chat_mod"] = chat_mod

    yield

    session.clear()


app = FastAPI(lifespan=lifespan)


def _parse_args():
    args = argparse.ArgumentParser()
    args.add_argument(
        "--model",
        type=str,
        choices=supported_models(),
        default="vicuna-v1-7b"
    )
    args.add_argument("--artifact-path", type=str, default="dist")
    args.add_argument(
        "--quantization",
        type=str,
        choices=quantization_keys(),
        default=quantization_keys()[0],
    )
    args.add_argument("--device-name", type=str, default="cuda")
    args.add_argument("--device-id", type=int, default=0)
    args.add_argument(
        "--mlc-path", type=str, default="", help="path to the mlc-llm repo"
    )
    parsed = args.parse_args()
    parsed.mlc_lib_path = os.path.join(parsed.mlc_path, "build/libmlc_llm_module.so")
    return parsed


@app.get("/models")
async def read_models():
    """List the currently supported models and provide basic information about each of them."""
    return {
        "data": [{
            "id": model,
            "object": "model"
        } for model in supported_models()]
    }


@app.get("/models/{model}")
async def read_model(model: str):
    """Retrieve a model instance with basic information about the model."""
    if model not in supported_models():
        raise HTTPException(status_code=404, detail=f"Model {model} is not supported.")
    return {
        "id": model,
        "object": "model"
    }


class ChatRequest(BaseModel):
    prompt: str
    stream: bool = False


@app.post("/chat/completions")
def request_completion(request: ChatRequest):
    """Create a model response for the given chat conversation."""
    session["chat_mod"].prefill(input=request.prompt)
    if request.stream:
        def iter_response():
            while not session["chat_mod"].stopped():
                session["chat_mod"].decode()
                msg = session["chat_mod"].get_message()
                yield json.dumps({"message": msg})
        return StreamingResponse(iter_response(), media_type='application/json')
    else:
        msg = None
        while not session["chat_mod"].stopped():
            session["chat_mod"].decode()
            msg = session["chat_mod"].get_message()
        return {"message": msg}


@app.post("/chat/reset")
def reset():
    """Reset the chat for the currently initialized model."""
    session["chat_mod"].reset_chat()


@app.get("/stats")
def read_stats():
    """Get the runtime stats."""
    return session["chat_mod"].runtime_stats_text()


if __name__ == "__main__":
    uvicorn.run("server:app", port=8000, reload=True, access_log=False)
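Because the streaming branch yields the full message accumulated so far on every decode step, a client can print just the newly generated suffix. A rough sketch, sharing the sample client's assumption that each received chunk holds exactly one JSON object:

```python
import json
import requests

payload = {"prompt": "Write a haiku", "stream": True}
prev = ""
with requests.post("http://127.0.0.1:8000/chat/completions", json=payload, stream=True) as r:
    for chunk in r.iter_content(chunk_size=1024):
        if not chunk:
            continue
        # Each payload carries the whole message so far; print only the new part.
        msg = json.loads(chunk)["message"]
        print(msg[len(prev):], end="", flush=True)
        prev = msg
print()
```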