Implement Python chat module and REST API (octoml#223)
This PR adds the following:

* A Python chat module with the same functionality as the CLI (note that this requires building a module without the tvm_runtime dependency; see the changes to CMakeLists.txt). A usage sketch follows this list.
* A REST API with streaming support that exposes common endpoints for interacting with Vicuna and RedPajama
* A sample client that shows you how to use the endpoints
* Some documentation on how to run the server and client
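
For reference, here is a minimal sketch of driving the chat module directly, mirroring what the CLI and the REST server do under the hood. The paths are illustrative only (they assume a locally built `vicuna-v1-7b-q3f16_0` artifact under `dist/` and a `build/libmlc_llm_module.so`); see `server.py` below for how the real paths are resolved.

```python
import os
import tvm
from chat_module import LLMChatModule  # run from python/mlc_chat, as server.py does

# Illustrative artifact locations; adjust to your own build output.
artifact = "dist/vicuna-v1-7b-q3f16_0"
lib = tvm.runtime.load_module(os.path.join(artifact, "vicuna-v1-7b-q3f16_0-cuda.so"))

chat = LLMChatModule("build/libmlc_llm_module.so", target="cuda", device_id=0)
chat.reload(lib=lib, model_path=os.path.join(artifact, "params"))

chat.prefill(input="Write a haiku")
while not chat.stopped():  # decode token by token until the model stops
    chat.decode()
print(chat.get_message())
print(chat.runtime_stats_text())
```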
sudeepag authored May 25, 2023
1 parent 3f74b1c commit 2256cae
Showing 5 changed files with 287 additions and 0 deletions.
10 changes: 10 additions & 0 deletions CMakeLists.txt
@@ -102,6 +102,16 @@ else()
  target_link_libraries(mlc_chat_cli PUBLIC mlc_llm)
endif()

if (UNIX OR APPLE)
  add_library(mlc_llm_module MODULE $<TARGET_OBJECTS:mlc_llm_objs>)
  target_link_libraries(mlc_llm_module PRIVATE tokenizers_cpp)
  if (APPLE)
    set_property(TARGET mlc_llm_module APPEND PROPERTY LINK_OPTIONS -undefined dynamic_lookup)
  else()
    set_property(TARGET mlc_llm_module APPEND PROPERTY LINK_OPTIONS)
  endif()
endif()

# when this option is on,
# we install all static lib deps into lib
if (MLC_LLM_INSTALL_STATIC_LIB)
16 changes: 16 additions & 0 deletions python/README.md
@@ -0,0 +1,16 @@
# Instructions

## REST API

Using the [REST API](https://www.ibm.com/topics/rest-apis) currently requires building from source.

1. Follow the instructions [here](https://github.com/mlc-ai/mlc-llm/tree/main/cpp) to build the CLI from source.
2. Launch the server; it will be available at [http://127.0.0.1:8000/](http://127.0.0.1:8000/).
```shell
cd mlc-llm
python python/mlc_chat/server.py
```
3. Go to [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs) to see the list of supported endpoints, or run the sample client script to see how to send queries. A minimal request sketch also follows below.
```shell
python python/mlc_chat/sample_client.py
```
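
For a quick check that the server is up, a minimal non-streaming request can look like this (a sketch based on the endpoints defined in `server.py`; `sample_client.py` also demonstrates streaming):
```python
import requests

# The server replies with {"message": "..."} once generation finishes.
r = requests.post("http://127.0.0.1:8000/chat/completions", json={"prompt": "Write a haiku"})
print(r.json()["message"])

# Runtime statistics for the latest generation.
print(requests.get("http://127.0.0.1:8000/stats").json())
```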
73 changes: 73 additions & 0 deletions python/mlc_chat/chat_module.py
@@ -0,0 +1,73 @@
"""Python runtime for MLC chat."""

import ctypes

import tvm


def load_llm_chat(mlc_lib_path):
    return ctypes.CDLL(mlc_lib_path)


def supported_models():
    return set(["vicuna-v1-7b", "RedPajama-INCITE-Chat-3B-v1"])


def quantization_keys():
    return ["q3f16_0", "q4f16_0", "q4f32_0", "q0f32", "q0f16"]


class LLMChatModule:
    def __init__(self, mlc_lib_path, target="cuda", device_id=0):
        load_llm_chat(mlc_lib_path)
        fcreate = tvm.get_global_func("mlc.llm_chat_create")
        assert fcreate is not None
        if target == "cuda":
            self.device = tvm.cuda(device_id)
        elif target == "metal":
            self.device = tvm.metal(device_id)
        elif target == "vulkan":
            self.device = tvm.vulkan(device_id)
        else:
            raise ValueError("device type not supported yet")
        device_type = self.device.device_type
        chat_mod = fcreate(device_type, device_id)

        self.reload_func = chat_mod["reload"]
        self.prefill_func = chat_mod["prefill"]
        self.decode_func = chat_mod["decode"]
        self.stopped_func = chat_mod["stopped"]
        self.get_message_func = chat_mod["get_message"]
        self.reset_chat_func = chat_mod["reset_chat"]
        self.runtime_stats_text_func = chat_mod["runtime_stats_text"]
        self.reset_runtime_stats_func = chat_mod["reset_runtime_stats"]
        self.evaluate_func = chat_mod["evaluate"]
        self.get_role0 = chat_mod["get_role0"]
        self.get_role1 = chat_mod["get_role1"]

    def reload(self, lib, model_path):
        self.reload_func(lib, model_path)

    def prefill(self, input):
        self.prefill_func(input)

    def decode(self):
        self.decode_func()

    def stopped(self):
        return self.stopped_func() != 0

    def get_message(self):
        return self.get_message_func()

    def reset_chat(self):
        self.reset_chat_func()

    def runtime_stats_text(self):
        return self.runtime_stats_text_func()

    def reset_runtime_stats(self):
        self.reset_runtime_stats_func()

    def evaluate(self):
        self.evaluate_func()
39 changes: 39 additions & 0 deletions python/mlc_chat/sample_client.py
@@ -0,0 +1,39 @@
import requests
import json

# To launch the server, run
# $ python python/mlc_chat/server.py

# List the models that are currently supported
r = requests.get("http://127.0.0.1:8000/models")
print(f"Supported models: {r.json()}\n")

# Get a response using a prompt without streaming
payload = {
    "prompt": "Write a haiku"
}
r = requests.post("http://127.0.0.1:8000/chat/completions", json=payload)
print(f"Without streaming: {r.json()}\n")

# Reset the chat
r = requests.post("http://127.0.0.1:8000/chat/reset", json=payload)
print(f"Reset chat: {str(r)}\n")

# Get a response using a prompt with streaming
payload = {
    "prompt": "Write a haiku",
    "stream": True
}
with requests.post("http://127.0.0.1:8000/chat/completions", json=payload, stream=True) as r:
    print("With streaming: ")
    try:
        for data in r.iter_content(chunk_size=1024):
            if data:
                print(json.loads(data))
    except requests.exceptions.ChunkedEncodingError as ex:
        print(f"Invalid chunk encoding {str(ex)}")
    print("\n")

# Get the latest runtime stats
r = requests.get("http://127.0.0.1:8000/stats")
print(f"Runtime stats: {r.json()}\n")
149 changes: 149 additions & 0 deletions python/mlc_chat/server.py
@@ -0,0 +1,149 @@
from chat_module import LLMChatModule, supported_models, quantization_keys

from pydantic import BaseModel
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
import uvicorn

import tvm

import argparse
import os
import json


session = {}

@asynccontextmanager
async def lifespan(app: FastAPI):

    ARGS = _parse_args()

    chat_mod = LLMChatModule(
        ARGS.mlc_lib_path,
        ARGS.device_name,
        ARGS.device_id
    )
    model_path = os.path.join(
        ARGS.artifact_path,
        ARGS.model + "-" + ARGS.quantization
    )
    model_dir = ARGS.model + "-" + ARGS.quantization
    model_lib = model_dir + "-" + ARGS.device_name + ".so"
    lib_dir = os.path.join(model_path, model_lib)
    prebuilt_lib_dir = os.path.join(ARGS.artifact_path, "prebuilt", "lib", model_lib)
    if os.path.exists(lib_dir):
        lib = tvm.runtime.load_module(lib_dir)
    elif os.path.exists(prebuilt_lib_dir):
        lib = tvm.runtime.load_module(prebuilt_lib_dir)
    else:
        raise ValueError(f"Unable to find {model_lib} at {lib_dir} or {prebuilt_lib_dir}.")

    local_model_path = os.path.join(model_path, "params")
    prebuilt_model_path = os.path.join(ARGS.artifact_path, "prebuilt", f"mlc-chat-{model_dir}")
    if os.path.exists(local_model_path):
        chat_mod.reload(lib=lib, model_path=local_model_path)
    elif os.path.exists(prebuilt_model_path):
        chat_mod.reload(lib=lib, model_path=prebuilt_model_path)
    else:
        raise ValueError(f"Unable to find model params at {local_model_path} or {prebuilt_model_path}.")
    session["chat_mod"] = chat_mod

    yield

    session.clear()


app = FastAPI(lifespan=lifespan)

def _parse_args():
    args = argparse.ArgumentParser()
    args.add_argument(
        "--model",
        type=str,
        choices=supported_models(),
        default="vicuna-v1-7b"
    )
    args.add_argument("--artifact-path", type=str, default="dist")
    args.add_argument(
        "--quantization",
        type=str,
        choices=quantization_keys(),
        default=quantization_keys()[0],
    )
    args.add_argument("--device-name", type=str, default="cuda")
    args.add_argument("--device-id", type=int, default=0)
    args.add_argument(
        "--mlc-path", type=str, default="", help="path to the mlc-llm repo"
    )
    parsed = args.parse_args()
    parsed.mlc_lib_path = os.path.join(parsed.mlc_path, "build/libmlc_llm_module.so")
    return parsed


"""
List the currently supported models and provides basic information about each of them.
"""
@app.get("/models")
async def read_models():
return {
"data": [{
"id": model,
"object":"model"
} for model in supported_models()]
}

"""
Retrieve a model instance with basic information about the model.
"""
@app.get("/models/{model}")
async def read_model(model: str):
if model not in supported_models():
raise HTTPException(status_code=404, detail=f"Model {model} is not supported.")
return {
"id": model,
"object":"model"
}

class ChatRequest(BaseModel):
prompt: str
stream: bool = False

"""
Creates model response for the given chat conversation.
"""
@app.post("/chat/completions")
def request_completion(request: ChatRequest):
session["chat_mod"].prefill(input=request.prompt)
if request.stream:
def iter_response():
while not session["chat_mod"].stopped():
session["chat_mod"].decode()
msg = session["chat_mod"].get_message()
yield json.dumps({"message": msg})
return StreamingResponse(iter_response(), media_type='application/json')
else:
msg = None
while not session["chat_mod"].stopped():
session["chat_mod"].decode()
msg = session["chat_mod"].get_message()
return {"message": msg}

"""
Reset the chat for the currently initialized model.
"""
@app.post("/chat/reset")
def reset():
session["chat_mod"].reset_chat()

"""
Get the runtime stats.
"""
@app.get("/stats")
def read_stats():
return session["chat_mod"].runtime_stats_text()


if __name__ == "__main__":
uvicorn.run("server:app", port=8000, reload=True, access_log=False)
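
For orientation, with the default arguments (`--model vicuna-v1-7b --quantization q3f16_0 --device-name cuda --artifact-path dist`) the lifespan handler above resolves the model artifacts as in this sketch; the concrete directory names depend on how the model was built and are shown only as an illustration.

```python
import os

# Defaults from _parse_args(); override on the command line as needed.
artifact_path, model, quantization, device = "dist", "vicuna-v1-7b", "q3f16_0", "cuda"
model_dir = model + "-" + quantization        # vicuna-v1-7b-q3f16_0
model_lib = model_dir + "-" + device + ".so"  # vicuna-v1-7b-q3f16_0-cuda.so

# Compiled model library: a local build is preferred over the prebuilt package.
lib_candidates = [
    os.path.join(artifact_path, model_dir, model_lib),
    os.path.join(artifact_path, "prebuilt", "lib", model_lib),
]

# Model parameters: again local build first, then the prebuilt package.
param_candidates = [
    os.path.join(artifact_path, model_dir, "params"),
    os.path.join(artifact_path, "prebuilt", "mlc-chat-" + model_dir),
]
```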
