From daad26afd4d702d876fec93e0b631601c52de35d Mon Sep 17 00:00:00 2001 From: Piotr Frankowski Date: Fri, 18 Oct 2024 18:47:54 +0200 Subject: [PATCH 1/4] [SERVE][CPP][Android] add native executable program to benchmark models --- CMakeLists.txt | 14 +++ cpp/json_ffi/json_ffi_engine.cc | 63 ++++++----- cpp/json_ffi/json_ffi_engine.h | 34 ++++++ cpp/json_ffi/openai_api_protocol.cc | 120 ++++++++++++++++++++ cpp/json_ffi/openai_api_protocol.h | 2 + cpp/llm_benchmark.cpp | 164 ++++++++++++++++++++++++++++ cpp/serve/config.cc | 2 +- 7 files changed, 366 insertions(+), 33 deletions(-) create mode 100644 cpp/llm_benchmark.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e09728727c..887832dce1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) # tvm runtime config: minimize runtime components set(USE_RPC OFF) set(USE_MICRO OFF) +# set(USE_VULKAN ON) set(USE_GRAPH_EXECUTOR OFF) set(USE_GRAPH_EXECUTOR_DEBUG OFF) set(USE_AOT_EXECUTOR OFF) @@ -175,3 +176,16 @@ else() LIBRARY DESTINATION lib${LIB_SUFFIX} ) endif() + +add_executable(llm_benchmark cpp/llm_benchmark.cpp) + +target_include_directories(llm_benchmark PRIVATE + ${TVM_SOURCE_DIR}/include + ${TVM_SOURCE_DIR}/3rdparty/dlpack/include + ${TVM_SOURCE_DIR}/3rdparty/dmlc-core/include + ${TVM_SOURCE_DIR}/3rdparty/picojson + ${TOKENZIER_CPP_PATH}/include +) +target_link_libraries(llm_benchmark PUBLIC mlc_llm_module) + +# target_link_libraries(tvm PRIVATE log) diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index 9141a324d7..19a526fc93 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -150,24 +150,8 @@ void JSONFFIEngine::ExitBackgroundLoop() { this->engine_->ExitBackgroundLoop(); JSONFFIEngine::~JSONFFIEngine() { this->ExitBackgroundLoop(); } -class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { - public: - TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); - TVM_MODULE_VTABLE_ENTRY("init_background_engine", &JSONFFIEngineImpl::InitBackgroundEngine); - TVM_MODULE_VTABLE_ENTRY("reload", &JSONFFIEngineImpl::Reload); - TVM_MODULE_VTABLE_ENTRY("unload", &JSONFFIEngineImpl::Unload); - TVM_MODULE_VTABLE_ENTRY("reset", &JSONFFIEngineImpl::Reset); - TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); - TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); - TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError); - TVM_MODULE_VTABLE_ENTRY("run_background_loop", &JSONFFIEngineImpl::RunBackgroundLoop); - TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", - &JSONFFIEngineImpl::RunBackgroundStreamBackLoop); - TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); - TVM_MODULE_VTABLE_END(); - - void InitBackgroundEngine(int device_type, int device_id, - Optional request_stream_callback) { +void JSONFFIEngineImpl::InitBackgroundEngine(int device_type, int device_id, + Optional request_stream_callback) { DLDevice device{static_cast(device_type), device_id}; this->device_ = device; CHECK(request_stream_callback.defined()) @@ -175,17 +159,17 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { this->request_stream_callback_ = request_stream_callback.value(); auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size(), 1); - Array delta_outputs = args[0]; - std::string responses = this->GetResponseFromStreamOutput(delta_outputs); - this->request_stream_callback_(responses); + 
ICHECK_EQ(args.size(), 1); + Array delta_outputs = args[0]; + std::string responses = this->GetResponseFromStreamOutput(delta_outputs); + this->request_stream_callback_(responses); }; request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); this->engine_->InitThreadedEngine(device, std::move(request_stream_callback), NullOpt); - } +} - void Reload(String engine_config_json_str) { +void JSONFFIEngineImpl::Reload(String engine_config_json_str) { this->engine_->Reload(engine_config_json_str); this->default_generation_config_ = this->engine_->GetDefaultGenerationConfig(); auto engine_config = this->engine_->GetCompleteEngineConfig(); @@ -203,17 +187,26 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { this->model_config_ = ModelConfig::FromJSON( json::Lookup(model_config_json_unwrapped, "model_config")); this->tokenizer_ = Tokenizer::FromPath(engine_config->model); - } +} - void Unload() { this->engine_->Unload(); } +void JSONFFIEngineImpl::Unload() { + this->engine_->Unload(); +} - void Reset() { this->engine_->Reset(); } +void JSONFFIEngineImpl::Reset() { + this->engine_->Reset(); +} - void RunBackgroundLoop() { this->engine_->RunBackgroundLoop(); } +void JSONFFIEngineImpl::RunBackgroundLoop() { + this->engine_->RunBackgroundLoop(); +} - void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); } +void JSONFFIEngineImpl::RunBackgroundStreamBackLoop() { + this->engine_->RunBackgroundStreamBackLoop(); +} - String GetResponseFromStreamOutput(Array delta_outputs) { +String JSONFFIEngineImpl::GetResponseFromStreamOutput(Array delta_outputs) { + picojson::array json_response_arr; for (const auto& delta_output : delta_outputs) { std::string request_id = delta_output->request_id; @@ -292,8 +285,14 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { } } return picojson::value(json_response_arr).serialize(); - } -}; +} + +Module JSONFFIEngineImpl::Create() { + auto n = make_object(); + return Module(n); +} + +// TVM_REGISTER_GLOBAL("mlc.json_ffi.CreateJSONFFIEngine").set_body_typed(JSONFFIEngineImpl::Create); TVM_REGISTER_GLOBAL("mlc.json_ffi.CreateJSONFFIEngine").set_body_typed([]() { return Module(make_object()); diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index 616c3c12ac..10d841c90c 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -69,6 +69,40 @@ class JSONFFIEngine { std::unordered_map request_map_; }; + + +class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { +public: + static Module Create(); + + TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); + TVM_MODULE_VTABLE_ENTRY("init_background_engine", &JSONFFIEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("reload", &JSONFFIEngineImpl::Reload); + TVM_MODULE_VTABLE_ENTRY("unload", &JSONFFIEngineImpl::Unload); + TVM_MODULE_VTABLE_ENTRY("reset", &JSONFFIEngineImpl::Reset); + TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); + TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); + TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError); + TVM_MODULE_VTABLE_ENTRY("run_background_loop", &JSONFFIEngineImpl::RunBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", + &JSONFFIEngineImpl::RunBackgroundStreamBackLoop); + TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); + TVM_MODULE_VTABLE_END(); + + void InitBackgroundEngine(int device_type, int device_id, Optional 
request_stream_callback); + void Reload(String engine_config_json_str); + void Unload(); + void Reset(); + void RunBackgroundLoop(); + void RunBackgroundStreamBackLoop(); + + // Implement the TVM_MODULE_VTABLE + // TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineImpl, ModuleNode, JSONFFIEngineImplNode); + +private: + String GetResponseFromStreamOutput(Array delta_outputs); +}; + } // namespace json_ffi } // namespace llm } // namespace mlc diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 3e11af4d11..234685ec0f 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -500,6 +500,51 @@ picojson::object ChatCompletionStreamResponseChoice::AsJSON() const { return obj; } +Result ChatCompletionStreamResponseChoice::FromJSON(const picojson::object& json_obj) { + using TResult = Result; + ChatCompletionStreamResponseChoice choice; + + // // index + // Result index_res = json::LookupWithResultReturn(json_obj, "index"); + // if (index_res.IsErr()) { + // return TResult::Error(index_res.UnwrapErr()); + // } + // choice.index = index_res.Unwrap(); + + // delta + Result delta_obj_res = json::LookupWithResultReturn(json_obj, "delta"); + if (delta_obj_res.IsErr()) { + return TResult::Error(delta_obj_res.UnwrapErr()); + } + Result delta_res = ChatCompletionMessage::FromJSON(delta_obj_res.Unwrap()); + if (delta_res.IsErr()) { + return TResult::Error(delta_res.UnwrapErr()); + } + choice.delta = delta_res.Unwrap(); + + // // finish_reason (optional) + // Result> finish_reason_res = json::LookupOptionalWithResultReturn(json_obj, "finish_reason"); + // if (finish_reason_res.IsErr()) { + // return TResult::Error(finish_reason_res.UnwrapErr()); + // } + // std::optional finish_reason_str = finish_reason_res.Unwrap(); + // if (finish_reason_str.has_value()) { + // if (finish_reason_str.value() == "stop") { + // choice.finish_reason = FinishReason::stop; + // } else if (finish_reason_str.value() == "length") { + // choice.finish_reason = FinishReason::length; + // } else if (finish_reason_str.value() == "tool_calls") { + // choice.finish_reason = FinishReason::tool_calls; + // } else if (finish_reason_str.value() == "error") { + // choice.finish_reason = FinishReason::error; + // } else { + // return TResult::Error("Invalid finish_reason: " + finish_reason_str.value()); + // } + // } + + return TResult::Ok(choice); +} + picojson::object ChatCompletionResponse::AsJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); @@ -535,6 +580,81 @@ picojson::object ChatCompletionStreamResponse::AsJSON() const { return obj; } +Result ChatCompletionStreamResponse::FromJSON(const std::string& json_str) { + using TResult = Result; + Result json_obj_res = json::ParseToJSONObjectWithResultReturn(json_str); + if (json_obj_res.IsErr()) { + return TResult::Error(json_obj_res.UnwrapErr()); + } + picojson::object json_obj = json_obj_res.Unwrap(); + ChatCompletionStreamResponse response; + + // // id + // Result id_res = json::LookupWithResultReturn(json_obj, "id"); + // if (id_res.IsErr()) { + // return TResult::Error(id_res.UnwrapErr()); + // } + // response.id = id_res.Unwrap(); + + // // object + // Result object_res = json::LookupWithResultReturn(json_obj, "object"); + // if (object_res.IsErr()) { + // return TResult::Error(object_res.UnwrapErr()); + // } + // response.object = object_res.Unwrap(); + + // // created + // Result created_res = json::LookupWithResultReturn(json_obj, "created"); + // if (created_res.IsErr()) { + // return 
TResult::Error(created_res.UnwrapErr()); + // } + // response.created = created_res.Unwrap(); + + // // model + // Result model_res = json::LookupWithResultReturn(json_obj, "model"); + // if (model_res.IsErr()) { + // return TResult::Error(model_res.UnwrapErr()); + // } + // response.model = model_res.Unwrap(); + + // // system_fingerprint + // Result system_fingerprint_res = json::LookupWithResultReturn(json_obj, "system_fingerprint"); + // if (system_fingerprint_res.IsErr()) { + // return TResult::Error(system_fingerprint_res.UnwrapErr()); + // } + // response.system_fingerprint = system_fingerprint_res.Unwrap(); + + // choices + Result choices_arr_res = json::LookupWithResultReturn(json_obj, "choices"); + if (choices_arr_res.IsErr()) { + return TResult::Error(choices_arr_res.UnwrapErr()); + } + std::vector choices; + for (const auto& item : choices_arr_res.Unwrap()) { + if (!item.is()) { + return TResult::Error("A choice in chat completion stream response is not an object"); + } + Result choice = ChatCompletionStreamResponseChoice::FromJSON(item.get()); + if (choice.IsErr()) { + return TResult::Error(choice.UnwrapErr()); + } + choices.push_back(choice.Unwrap()); + } + response.choices = choices; + + // // usage (optional) + // Result> usage_res = json::LookupOptionalWithResultReturn(json_obj, "usage"); + // if (usage_res.IsErr()) { + // return TResult::Error(usage_res.UnwrapErr()); + // } + // std::optional usage_obj = usage_res.Unwrap(); + // if (usage_obj.has_value()) { + // response.usage = picojson::value(usage_obj.value()); + // } + + return TResult::Ok(response); +} + } // namespace json_ffi } // namespace llm } // namespace mlc diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index 61de01da1d..9d38f7e27d 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -172,6 +172,7 @@ class ChatCompletionStreamResponseChoice { // TODO: logprobs picojson::object AsJSON() const; + static Result FromJSON(const picojson::object& json_obj); }; class ChatCompletionResponse { @@ -198,6 +199,7 @@ class ChatCompletionStreamResponse { std::optional usage; picojson::object AsJSON() const; + static Result FromJSON(const std::string& json_str); }; } // namespace json_ffi diff --git a/cpp/llm_benchmark.cpp b/cpp/llm_benchmark.cpp new file mode 100644 index 0000000000..a274b26e7f --- /dev/null +++ b/cpp/llm_benchmark.cpp @@ -0,0 +1,164 @@ +#include +#include +#include +#include +#include // for exit() +#include "json_ffi/json_ffi_engine.h" +#include "support/result.h" + + +using namespace tvm::runtime; +using namespace mlc::llm::json_ffi; + +auto start = std::chrono::high_resolution_clock::now(); +auto end = std::chrono::high_resolution_clock::now(); +int counter = 0; +int glob_max_tokens = 0; +int iterations = 1; + +void exitProgram() { + std::this_thread::sleep_for(std::chrono::seconds(3)); // Sleep for 3 seconds + std::cout << "Exiting program after 3 seconds..." 
<< std::endl; + exit(0); // Exit the program +} + +// Define the callback function that processes the responses +void RequestStreamCallback(const std::string& response) { + mlc::llm::Result stream_response_result = + ChatCompletionStreamResponse::FromJSON( + response.substr(1, response.size() - 2) + ); + + if (stream_response_result.IsOk()) { + auto unwrp_res = stream_response_result.Unwrap(); + + if (unwrp_res.choices.size() > 1) { + std::cerr << response << "(!!!More choices!!!)\n"; + } else { + std::string chunk_text = "."; + if (iterations == 1) { + chunk_text = unwrp_res.choices[0].delta.content.Text(); + } + std::cout << chunk_text; + counter++; + if (counter >= (glob_max_tokens - 20)) { + end = std::chrono::high_resolution_clock::now(); + } + } + } else { + std::string chunk_text = stream_response_result.UnwrapErr(); + std::cerr << "Error parsing response." + chunk_text + "\n"; + std::cerr << response << "\n"; + } +} + + +void benchmark_llm( + const std::string& model_path, + const std::string& modellib_path, + const std::string& mode, + const int device_type, + const int timeout, + const std::string& input_text) { + + int device_id = 0; + tvm::runtime::PackedFunc request_stream_callback = tvm::runtime::PackedFunc( + [](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* ret) { + std::string response = args[0]; + RequestStreamCallback(response); + } + ); + + + const PackedFunc* create_engine_func = tvm::runtime::Registry::Get("mlc.json_ffi.CreateJSONFFIEngine"); + + if (create_engine_func == nullptr) { + throw std::runtime_error("Cannot find mlc.json_ffi.CreateJSONFFIEngine in the registry."); + } + + // Call the function and get the module (which holds JSONFFIEngineImpl) + Module engine_mod = (*create_engine_func)(); + + // Cast the module to JSONFFIEngineImpl + auto* engine = dynamic_cast(engine_mod.operator->()); + if (!engine) { + throw std::runtime_error("Failed to cast to JSONFFIEngineImpl."); + } + + engine->InitBackgroundEngine(device_type, device_id, request_stream_callback); + + std::thread background_stream_back_loop([&engine]() { + engine->RunBackgroundStreamBackLoop(); + }); + + std::thread background_loop([&engine]() { + engine->RunBackgroundLoop(); + }); + + + // Now call the Reload function + std::string engine_json = "{\"model\":\"" + model_path + "\", \"model_lib\":\"" + modellib_path + "\", \"mode\": \"" + mode + "\"}"; + std::cerr << engine_json << std::endl; + engine->Reload(engine_json); + std::cerr << "\engine->Reload\n"; + + + // Prepare input + std::string request_json = "{\"messages\":[{\"role\":\"user\",\"content\":\"" + input_text + "\"}], \"max_tokens\": " + std::to_string(glob_max_tokens) + "}"; + std::string request_id = "benchmark_request"; + + for(int i = 1; i <= iterations; i++) { + counter = 0; + // Measure inference time + start = std::chrono::high_resolution_clock::now(); + engine->ChatCompletion(request_json, request_id); + + // std::cerr << "\nRunning in background. Sleeping main thread " + timeout + "s... 
Wait for text response (3s - 1m).\n\n"; + std::this_thread::sleep_for(std::chrono::seconds(timeout)); + // std::cerr << "\nWakeup...\n"; + + // std::cerr << "\nAborting...\n"; + engine->Abort(request_id); + std::this_thread::sleep_for(std::chrono::seconds(3)); + + std::cerr << i << " Max tokens:" << glob_max_tokens << "; Counter:" << counter << "\n\n"; + + std::chrono::duration elapsed = end - start; + std::cerr << i << " Inference time: " << elapsed.count() << " seconds" << std::endl; + + std::cerr << i << " End-to-end decoded avg token/s: " << std::to_string(counter / elapsed.count()) << "\n"; + } + + + engine->ExitBackgroundLoop(); + std::this_thread::sleep_for(std::chrono::seconds(3)); + + background_stream_back_loop.join(); + background_loop.join(); + + std::cerr << "engine->Unload\n"; + std::thread(exitProgram).detach(); + engine->Unload(); + return; +} + +int main(int argc, char* argv[]) { + if (argc != 9) { + std::cerr << "Usage: " << argv[0] << " <1:model_path:str> <2:model_lib_path:str> <3:mode:str> <4:device_type:int> <5:timeout:int> <6:max_tokens:int> <7:input_text:str> <8:iterations>" << std::endl + << "Device types: kDLCPU = 1; kDLOpenCL = 4; kDLVulkan = 7;\n" << "Be carefull with number of iterations.\n 1 iteration gives you text outputs."; + return 1; + } + + std::string model_path = argv[1]; + std::string modellib_path = argv[2]; + std::string mode = argv[3]; + int device_type = std::stoi(argv[4]); + int timeout = std::stoi(argv[5]); + glob_max_tokens = std::stoi(argv[6]); + std::string input_text = argv[7]; + iterations = std::stoi(argv[8]); + + benchmark_llm(model_path, modellib_path, mode, device_type, timeout, input_text); + + return 0; +} \ No newline at end of file diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 9310b79028..a865dadfea 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -28,7 +28,7 @@ uint64_t TotalDetectGlobalMemory(DLDevice device) { // memory space, we set a best available space so that MLC LLM can run 7B or 8B models on Android // with OpenCL. 
   if (device.device_type == kDLOpenCL) {
-    int64_t min_size_bytes = 5LL * 1024 * 1024 * 1024;  // Minimum size is 5 GB
+    int64_t min_size_bytes = 10LL * 1024 * 1024 * 1024;  // Minimum size is 10 GB
     gpu_size_bytes = std::max(gpu_size_bytes, min_size_bytes);
   }
   return gpu_size_bytes;

From 828ac15892e5074d00fc6d703482f225ca36ba22 Mon Sep 17 00:00:00 2001
From: Piotr Frankowski
Date: Fri, 18 Oct 2024 18:57:40 +0200
Subject: [PATCH 2/4] [TYPO] remove leading spaces

---
 cpp/llm_benchmark.cpp | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/cpp/llm_benchmark.cpp b/cpp/llm_benchmark.cpp
index a274b26e7f..31b315ef39 100644
--- a/cpp/llm_benchmark.cpp
+++ b/cpp/llm_benchmark.cpp
@@ -24,7 +24,7 @@ void exitProgram() {
 
 // Define the callback function that processes the responses
 void RequestStreamCallback(const std::string& response) {
-    mlc::llm::Result stream_response_result = 
+    mlc::llm::Result stream_response_result =
         ChatCompletionStreamResponse::FromJSON(
             response.substr(1, response.size() - 2)
         );
@@ -54,9 +54,9 @@ void RequestStreamCallback(const std::string& response) {
 
 
 void benchmark_llm(
-    const std::string& model_path, 
-    const std::string& modellib_path, 
-    const std::string& mode, 
+    const std::string& model_path,
+    const std::string& modellib_path,
+    const std::string& mode,
     const int device_type,
     const int timeout,
     const std::string& input_text) {
@@ -101,8 +101,7 @@ void benchmark_llm(
     std::cerr << engine_json << std::endl;
     engine->Reload(engine_json);
     std::cerr << "\engine->Reload\n";
-    
-    
+
     // Prepare input
     std::string request_json = "{\"messages\":[{\"role\":\"user\",\"content\":\"" + input_text + "\"}], \"max_tokens\": " + std::to_string(glob_max_tokens) + "}";
     std::string request_id = "benchmark_request";
@@ -148,7 +147,7 @@ int main(int argc, char* argv[]) {
         << "Device types: kDLCPU = 1; kDLOpenCL = 4; kDLVulkan = 7;\n" << "Be carefull with number of iterations.\n 1 iteration gives you text outputs.";
         return 1;
     }
-    
+
     std::string model_path = argv[1];
     std::string modellib_path = argv[2];
     std::string mode = argv[3];
@@ -158,7 +157,7 @@ int main(int argc, char* argv[]) {
     std::string input_text = argv[7];
     iterations = std::stoi(argv[8]);
 
-        benchmark_llm(model_path, modellib_path, mode, device_type, timeout, input_text);
+    benchmark_llm(model_path, modellib_path, mode, device_type, timeout, input_text);
 
     return 0;
 }
\ No newline at end of file

From 0e9c77977c5a07ed1dcddad50cc2ccc7616278b5 Mon Sep 17 00:00:00 2001
From: Piotr Frankowski
Date: Fri, 18 Oct 2024 19:11:37 +0200
Subject: [PATCH 3/4] [TYPO] clang-format

---
 cpp/json_ffi/json_ffi_engine.cc     | 221 +++++++++++++---------------
 cpp/json_ffi/json_ffi_engine.h      |  61 ++++----
 cpp/json_ffi/openai_api_protocol.cc |  30 ++--
 3 files changed, 155 insertions(+), 157 deletions(-)

diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc
index 19a526fc93..66e9b8a7da 100644
--- a/cpp/json_ffi/json_ffi_engine.cc
+++ b/cpp/json_ffi/json_ffi_engine.cc
@@ -152,144 +152,135 @@ JSONFFIEngine::~JSONFFIEngine() { this->ExitBackgroundLoop(); }
 
 void JSONFFIEngineImpl::InitBackgroundEngine(int device_type, int device_id,
                                              Optional request_stream_callback) {
-    DLDevice device{static_cast(device_type), device_id};
-    this->device_ = device;
-    CHECK(request_stream_callback.defined())
-        << "JSONFFIEngine requires request stream callback function, but it is not given.";
-    this->request_stream_callback_ = request_stream_callback.value();
-
-    auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue*
ret) { - ICHECK_EQ(args.size(), 1); - Array delta_outputs = args[0]; - std::string responses = this->GetResponseFromStreamOutput(delta_outputs); - this->request_stream_callback_(responses); - }; - - request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - this->engine_->InitThreadedEngine(device, std::move(request_stream_callback), NullOpt); + DLDevice device{static_cast(device_type), device_id}; + this->device_ = device; + CHECK(request_stream_callback.defined()) + << "JSONFFIEngine requires request stream callback function, but it is not given."; + this->request_stream_callback_ = request_stream_callback.value(); + + auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size(), 1); + Array delta_outputs = args[0]; + std::string responses = this->GetResponseFromStreamOutput(delta_outputs); + this->request_stream_callback_(responses); + }; + + request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); + this->engine_->InitThreadedEngine(device, std::move(request_stream_callback), NullOpt); } void JSONFFIEngineImpl::Reload(String engine_config_json_str) { - this->engine_->Reload(engine_config_json_str); - this->default_generation_config_ = this->engine_->GetDefaultGenerationConfig(); - auto engine_config = this->engine_->GetCompleteEngineConfig(); - - // Load conversation template. - Result model_config_json = - serve::Model::LoadModelConfig(engine_config->model); - CHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr(); - const picojson::object& model_config_json_unwrapped = model_config_json.Unwrap(); - Result conv_template = Conversation::FromJSON( - json::Lookup(model_config_json_unwrapped, "conv_template")); - CHECK(!conv_template.IsErr()) << "Invalid conversation template JSON: " - << conv_template.UnwrapErr(); - this->conv_template_ = conv_template.Unwrap(); - this->model_config_ = ModelConfig::FromJSON( - json::Lookup(model_config_json_unwrapped, "model_config")); - this->tokenizer_ = Tokenizer::FromPath(engine_config->model); + this->engine_->Reload(engine_config_json_str); + this->default_generation_config_ = this->engine_->GetDefaultGenerationConfig(); + auto engine_config = this->engine_->GetCompleteEngineConfig(); + + // Load conversation template. 
+ Result model_config_json = serve::Model::LoadModelConfig(engine_config->model); + CHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr(); + const picojson::object& model_config_json_unwrapped = model_config_json.Unwrap(); + Result conv_template = Conversation::FromJSON( + json::Lookup(model_config_json_unwrapped, "conv_template")); + CHECK(!conv_template.IsErr()) << "Invalid conversation template JSON: " + << conv_template.UnwrapErr(); + this->conv_template_ = conv_template.Unwrap(); + this->model_config_ = ModelConfig::FromJSON( + json::Lookup(model_config_json_unwrapped, "model_config")); + this->tokenizer_ = Tokenizer::FromPath(engine_config->model); } -void JSONFFIEngineImpl::Unload() { - this->engine_->Unload(); -} +void JSONFFIEngineImpl::Unload() { this->engine_->Unload(); } -void JSONFFIEngineImpl::Reset() { - this->engine_->Reset(); -} +void JSONFFIEngineImpl::Reset() { this->engine_->Reset(); } -void JSONFFIEngineImpl::RunBackgroundLoop() { - this->engine_->RunBackgroundLoop(); -} +void JSONFFIEngineImpl::RunBackgroundLoop() { this->engine_->RunBackgroundLoop(); } void JSONFFIEngineImpl::RunBackgroundStreamBackLoop() { - this->engine_->RunBackgroundStreamBackLoop(); + this->engine_->RunBackgroundStreamBackLoop(); } String JSONFFIEngineImpl::GetResponseFromStreamOutput(Array delta_outputs) { - - picojson::array json_response_arr; - for (const auto& delta_output : delta_outputs) { - std::string request_id = delta_output->request_id; - auto request_state_it = request_map_.find(request_id); - if (request_state_it == request_map_.end()) continue; - RequestState& rstate = request_state_it->second; - - // build the final usage messages - // invariant, we can always let other messages to come first - // then the final usage messages, as final usage is always last - if (delta_output->request_final_usage_json_str.defined()) { - ChatCompletionStreamResponse response; - response.id = request_id; - response.model = rstate.model; - response.system_fingerprint = ""; - std::string usage_json_str = delta_output->request_final_usage_json_str.value(); - picojson::value usage_json; - std::string err = picojson::parse(usage_json, usage_json_str); - if (!err.empty()) { - err_ = err; - } else { - response.usage = usage_json; - } - json_response_arr.push_back(picojson::value(response.AsJSON())); - request_map_.erase(request_state_it); - continue; - } - ICHECK_NE(delta_output->group_finish_reason.size(), 0); - ICHECK_EQ(delta_output->group_delta_token_ids.size(), - delta_output->group_finish_reason.size()); - ICHECK_EQ(delta_output->group_delta_token_ids.size(), rstate.streamer.size()); - + picojson::array json_response_arr; + for (const auto& delta_output : delta_outputs) { + std::string request_id = delta_output->request_id; + auto request_state_it = request_map_.find(request_id); + if (request_state_it == request_map_.end()) continue; + RequestState& rstate = request_state_it->second; + + // build the final usage messages + // invariant, we can always let other messages to come first + // then the final usage messages, as final usage is always last + if (delta_output->request_final_usage_json_str.defined()) { ChatCompletionStreamResponse response; response.id = request_id; response.model = rstate.model; response.system_fingerprint = ""; - - for (size_t i = 0; i < delta_output->group_finish_reason.size(); ++i) { - // choice - ChatCompletionStreamResponseChoice choice; - Optional finish_reason = delta_output->group_finish_reason[i]; - if (finish_reason.defined()) { - if 
(finish_reason.value() == "stop") { - choice.finish_reason = FinishReason::stop; - } else if (finish_reason.value() == "length") { - choice.finish_reason = FinishReason::length; - } else if (finish_reason.value() == "tool_calls") { - choice.finish_reason = FinishReason::tool_calls; - } else if (finish_reason.value() == "error") { - choice.finish_reason = FinishReason::error; - } - } else { - choice.finish_reason = std::nullopt; - } - choice.index = static_cast(i); - ChatCompletionMessage delta; - // Size of delta_output->group_delta_token_ids Array should be 1 - const IntTuple& delta_token_ids = delta_output->group_delta_token_ids[i]; - std::vector delta_token_ids_vec(delta_token_ids.begin(), delta_token_ids.end()); - std::string content = rstate.streamer[i]->Put(delta_token_ids_vec); - if (finish_reason.defined()) { - content += rstate.streamer[i]->Finish(); - } - if (!content.empty()) { - delta.content = content; - } - delta.role = "assistant"; - choice.delta = delta; - if (!choice.delta.content.IsNull() || choice.finish_reason.has_value()) { - response.choices.push_back(choice); + std::string usage_json_str = delta_output->request_final_usage_json_str.value(); + picojson::value usage_json; + std::string err = picojson::parse(usage_json, usage_json_str); + if (!err.empty()) { + err_ = err; + } else { + response.usage = usage_json; + } + json_response_arr.push_back(picojson::value(response.AsJSON())); + request_map_.erase(request_state_it); + continue; + } + ICHECK_NE(delta_output->group_finish_reason.size(), 0); + ICHECK_EQ(delta_output->group_delta_token_ids.size(), delta_output->group_finish_reason.size()); + ICHECK_EQ(delta_output->group_delta_token_ids.size(), rstate.streamer.size()); + + ChatCompletionStreamResponse response; + response.id = request_id; + response.model = rstate.model; + response.system_fingerprint = ""; + + for (size_t i = 0; i < delta_output->group_finish_reason.size(); ++i) { + // choice + ChatCompletionStreamResponseChoice choice; + Optional finish_reason = delta_output->group_finish_reason[i]; + if (finish_reason.defined()) { + if (finish_reason.value() == "stop") { + choice.finish_reason = FinishReason::stop; + } else if (finish_reason.value() == "length") { + choice.finish_reason = FinishReason::length; + } else if (finish_reason.value() == "tool_calls") { + choice.finish_reason = FinishReason::tool_calls; + } else if (finish_reason.value() == "error") { + choice.finish_reason = FinishReason::error; } + } else { + choice.finish_reason = std::nullopt; } - // if it is not the usage block, choices cannot be empty - if (!response.choices.empty()) { - json_response_arr.push_back(picojson::value(response.AsJSON())); + choice.index = static_cast(i); + ChatCompletionMessage delta; + // Size of delta_output->group_delta_token_ids Array should be 1 + const IntTuple& delta_token_ids = delta_output->group_delta_token_ids[i]; + std::vector delta_token_ids_vec(delta_token_ids.begin(), delta_token_ids.end()); + std::string content = rstate.streamer[i]->Put(delta_token_ids_vec); + if (finish_reason.defined()) { + content += rstate.streamer[i]->Finish(); + } + if (!content.empty()) { + delta.content = content; + } + delta.role = "assistant"; + choice.delta = delta; + if (!choice.delta.content.IsNull() || choice.finish_reason.has_value()) { + response.choices.push_back(choice); } } - return picojson::value(json_response_arr).serialize(); + // if it is not the usage block, choices cannot be empty + if (!response.choices.empty()) { + 
json_response_arr.push_back(picojson::value(response.AsJSON())); + } + } + return picojson::value(json_response_arr).serialize(); } Module JSONFFIEngineImpl::Create() { - auto n = make_object(); - return Module(n); + auto n = make_object(); + return Module(n); } // TVM_REGISTER_GLOBAL("mlc.json_ffi.CreateJSONFFIEngine").set_body_typed(JSONFFIEngineImpl::Create); diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index 10d841c90c..07cecd3bd0 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -69,38 +69,37 @@ class JSONFFIEngine { std::unordered_map request_map_; }; - - class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { -public: - static Module Create(); - - TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); - TVM_MODULE_VTABLE_ENTRY("init_background_engine", &JSONFFIEngineImpl::InitBackgroundEngine); - TVM_MODULE_VTABLE_ENTRY("reload", &JSONFFIEngineImpl::Reload); - TVM_MODULE_VTABLE_ENTRY("unload", &JSONFFIEngineImpl::Unload); - TVM_MODULE_VTABLE_ENTRY("reset", &JSONFFIEngineImpl::Reset); - TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); - TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); - TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError); - TVM_MODULE_VTABLE_ENTRY("run_background_loop", &JSONFFIEngineImpl::RunBackgroundLoop); - TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", - &JSONFFIEngineImpl::RunBackgroundStreamBackLoop); - TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); - TVM_MODULE_VTABLE_END(); - - void InitBackgroundEngine(int device_type, int device_id, Optional request_stream_callback); - void Reload(String engine_config_json_str); - void Unload(); - void Reset(); - void RunBackgroundLoop(); - void RunBackgroundStreamBackLoop(); - - // Implement the TVM_MODULE_VTABLE - // TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineImpl, ModuleNode, JSONFFIEngineImplNode); - -private: - String GetResponseFromStreamOutput(Array delta_outputs); + public: + static Module Create(); + + TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); + TVM_MODULE_VTABLE_ENTRY("init_background_engine", &JSONFFIEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("reload", &JSONFFIEngineImpl::Reload); + TVM_MODULE_VTABLE_ENTRY("unload", &JSONFFIEngineImpl::Unload); + TVM_MODULE_VTABLE_ENTRY("reset", &JSONFFIEngineImpl::Reset); + TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); + TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); + TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError); + TVM_MODULE_VTABLE_ENTRY("run_background_loop", &JSONFFIEngineImpl::RunBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", + &JSONFFIEngineImpl::RunBackgroundStreamBackLoop); + TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); + TVM_MODULE_VTABLE_END(); + + void InitBackgroundEngine(int device_type, int device_id, + Optional request_stream_callback); + void Reload(String engine_config_json_str); + void Unload(); + void Reset(); + void RunBackgroundLoop(); + void RunBackgroundStreamBackLoop(); + + // Implement the TVM_MODULE_VTABLE + // TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineImpl, ModuleNode, JSONFFIEngineImplNode); + + private: + String GetResponseFromStreamOutput(Array delta_outputs); }; } // namespace json_ffi diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 
234685ec0f..9425ef4503 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -500,7 +500,8 @@ picojson::object ChatCompletionStreamResponseChoice::AsJSON() const { return obj; } -Result ChatCompletionStreamResponseChoice::FromJSON(const picojson::object& json_obj) { +Result ChatCompletionStreamResponseChoice::FromJSON( + const picojson::object& json_obj) { using TResult = Result; ChatCompletionStreamResponseChoice choice; @@ -512,7 +513,8 @@ Result ChatCompletionStreamResponseChoice::F // choice.index = index_res.Unwrap(); // delta - Result delta_obj_res = json::LookupWithResultReturn(json_obj, "delta"); + Result delta_obj_res = + json::LookupWithResultReturn(json_obj, "delta"); if (delta_obj_res.IsErr()) { return TResult::Error(delta_obj_res.UnwrapErr()); } @@ -523,8 +525,9 @@ Result ChatCompletionStreamResponseChoice::F choice.delta = delta_res.Unwrap(); // // finish_reason (optional) - // Result> finish_reason_res = json::LookupOptionalWithResultReturn(json_obj, "finish_reason"); - // if (finish_reason_res.IsErr()) { + // Result> finish_reason_res = + // json::LookupOptionalWithResultReturn(json_obj, "finish_reason"); if + // (finish_reason_res.IsErr()) { // return TResult::Error(finish_reason_res.UnwrapErr()); // } // std::optional finish_reason_str = finish_reason_res.Unwrap(); @@ -580,7 +583,8 @@ picojson::object ChatCompletionStreamResponse::AsJSON() const { return obj; } -Result ChatCompletionStreamResponse::FromJSON(const std::string& json_str) { +Result ChatCompletionStreamResponse::FromJSON( + const std::string& json_str) { using TResult = Result; Result json_obj_res = json::ParseToJSONObjectWithResultReturn(json_str); if (json_obj_res.IsErr()) { @@ -618,14 +622,16 @@ Result ChatCompletionStreamResponse::FromJSON(cons // response.model = model_res.Unwrap(); // // system_fingerprint - // Result system_fingerprint_res = json::LookupWithResultReturn(json_obj, "system_fingerprint"); - // if (system_fingerprint_res.IsErr()) { + // Result system_fingerprint_res = + // json::LookupWithResultReturn(json_obj, "system_fingerprint"); if + // (system_fingerprint_res.IsErr()) { // return TResult::Error(system_fingerprint_res.UnwrapErr()); // } // response.system_fingerprint = system_fingerprint_res.Unwrap(); // choices - Result choices_arr_res = json::LookupWithResultReturn(json_obj, "choices"); + Result choices_arr_res = + json::LookupWithResultReturn(json_obj, "choices"); if (choices_arr_res.IsErr()) { return TResult::Error(choices_arr_res.UnwrapErr()); } @@ -634,7 +640,8 @@ Result ChatCompletionStreamResponse::FromJSON(cons if (!item.is()) { return TResult::Error("A choice in chat completion stream response is not an object"); } - Result choice = ChatCompletionStreamResponseChoice::FromJSON(item.get()); + Result choice = + ChatCompletionStreamResponseChoice::FromJSON(item.get()); if (choice.IsErr()) { return TResult::Error(choice.UnwrapErr()); } @@ -643,8 +650,9 @@ Result ChatCompletionStreamResponse::FromJSON(cons response.choices = choices; // // usage (optional) - // Result> usage_res = json::LookupOptionalWithResultReturn(json_obj, "usage"); - // if (usage_res.IsErr()) { + // Result> usage_res = + // json::LookupOptionalWithResultReturn(json_obj, "usage"); if + // (usage_res.IsErr()) { // return TResult::Error(usage_res.UnwrapErr()); // } // std::optional usage_obj = usage_res.Unwrap(); From 8a30d62f03dbc0515067c4a1a9b79d389c7b9f2e Mon Sep 17 00:00:00 2001 From: Piotr Frankowski Date: Fri, 18 Oct 2024 19:15:55 +0200 Subject: [PATCH 4/4] [TYPO] 
clang-format of llm_benchmark --- cpp/llm_benchmark.cpp | 253 +++++++++++++++++++++--------------------- 1 file changed, 124 insertions(+), 129 deletions(-) diff --git a/cpp/llm_benchmark.cpp b/cpp/llm_benchmark.cpp index 31b315ef39..1320ebcda3 100644 --- a/cpp/llm_benchmark.cpp +++ b/cpp/llm_benchmark.cpp @@ -1,12 +1,13 @@ -#include #include +#include + +#include // for exit() #include #include -#include // for exit() + #include "json_ffi/json_ffi_engine.h" #include "support/result.h" - using namespace tvm::runtime; using namespace mlc::llm::json_ffi; @@ -17,147 +18,141 @@ int glob_max_tokens = 0; int iterations = 1; void exitProgram() { - std::this_thread::sleep_for(std::chrono::seconds(3)); // Sleep for 3 seconds - std::cout << "Exiting program after 3 seconds..." << std::endl; - exit(0); // Exit the program + std::this_thread::sleep_for(std::chrono::seconds(3)); // Sleep for 3 seconds + std::cout << "Exiting program after 3 seconds..." << std::endl; + exit(0); // Exit the program } // Define the callback function that processes the responses void RequestStreamCallback(const std::string& response) { - mlc::llm::Result stream_response_result = - ChatCompletionStreamResponse::FromJSON( - response.substr(1, response.size() - 2) - ); - - if (stream_response_result.IsOk()) { - auto unwrp_res = stream_response_result.Unwrap(); - - if (unwrp_res.choices.size() > 1) { - std::cerr << response << "(!!!More choices!!!)\n"; - } else { - std::string chunk_text = "."; - if (iterations == 1) { - chunk_text = unwrp_res.choices[0].delta.content.Text(); - } - std::cout << chunk_text; - counter++; - if (counter >= (glob_max_tokens - 20)) { - end = std::chrono::high_resolution_clock::now(); - } - } - } else { - std::string chunk_text = stream_response_result.UnwrapErr(); - std::cerr << "Error parsing response." 
+ chunk_text + "\n"; - std::cerr << response << "\n"; - } -} - + mlc::llm::Result stream_response_result = + ChatCompletionStreamResponse::FromJSON(response.substr(1, response.size() - 2)); -void benchmark_llm( - const std::string& model_path, - const std::string& modellib_path, - const std::string& mode, - const int device_type, - const int timeout, - const std::string& input_text) { + if (stream_response_result.IsOk()) { + auto unwrp_res = stream_response_result.Unwrap(); - int device_id = 0; - tvm::runtime::PackedFunc request_stream_callback = tvm::runtime::PackedFunc( - [](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* ret) { - std::string response = args[0]; - RequestStreamCallback(response); - } - ); - - - const PackedFunc* create_engine_func = tvm::runtime::Registry::Get("mlc.json_ffi.CreateJSONFFIEngine"); - - if (create_engine_func == nullptr) { - throw std::runtime_error("Cannot find mlc.json_ffi.CreateJSONFFIEngine in the registry."); - } - - // Call the function and get the module (which holds JSONFFIEngineImpl) - Module engine_mod = (*create_engine_func)(); - - // Cast the module to JSONFFIEngineImpl - auto* engine = dynamic_cast(engine_mod.operator->()); - if (!engine) { - throw std::runtime_error("Failed to cast to JSONFFIEngineImpl."); + if (unwrp_res.choices.size() > 1) { + std::cerr << response << "(!!!More choices!!!)\n"; + } else { + std::string chunk_text = "."; + if (iterations == 1) { + chunk_text = unwrp_res.choices[0].delta.content.Text(); + } + std::cout << chunk_text; + counter++; + if (counter >= (glob_max_tokens - 20)) { + end = std::chrono::high_resolution_clock::now(); + } } + } else { + std::string chunk_text = stream_response_result.UnwrapErr(); + std::cerr << "Error parsing response." + chunk_text + "\n"; + std::cerr << response << "\n"; + } +} - engine->InitBackgroundEngine(device_type, device_id, request_stream_callback); - - std::thread background_stream_back_loop([&engine]() { - engine->RunBackgroundStreamBackLoop(); - }); - - std::thread background_loop([&engine]() { - engine->RunBackgroundLoop(); - }); - - - // Now call the Reload function - std::string engine_json = "{\"model\":\"" + model_path + "\", \"model_lib\":\"" + modellib_path + "\", \"mode\": \"" + mode + "\"}"; - std::cerr << engine_json << std::endl; - engine->Reload(engine_json); - std::cerr << "\engine->Reload\n"; - - // Prepare input - std::string request_json = "{\"messages\":[{\"role\":\"user\",\"content\":\"" + input_text + "\"}], \"max_tokens\": " + std::to_string(glob_max_tokens) + "}"; - std::string request_id = "benchmark_request"; - - for(int i = 1; i <= iterations; i++) { - counter = 0; - // Measure inference time - start = std::chrono::high_resolution_clock::now(); - engine->ChatCompletion(request_json, request_id); - - // std::cerr << "\nRunning in background. Sleeping main thread " + timeout + "s... 
Wait for text response (3s - 1m).\n\n"; - std::this_thread::sleep_for(std::chrono::seconds(timeout)); - // std::cerr << "\nWakeup...\n"; - - // std::cerr << "\nAborting...\n"; - engine->Abort(request_id); - std::this_thread::sleep_for(std::chrono::seconds(3)); - - std::cerr << i << " Max tokens:" << glob_max_tokens << "; Counter:" << counter << "\n\n"; +void benchmark_llm(const std::string& model_path, const std::string& modellib_path, + const std::string& mode, const int device_type, const int timeout, + const std::string& input_text) { + int device_id = 0; + tvm::runtime::PackedFunc request_stream_callback = + tvm::runtime::PackedFunc([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* ret) { + std::string response = args[0]; + RequestStreamCallback(response); + }); + + const PackedFunc* create_engine_func = + tvm::runtime::Registry::Get("mlc.json_ffi.CreateJSONFFIEngine"); + + if (create_engine_func == nullptr) { + throw std::runtime_error("Cannot find mlc.json_ffi.CreateJSONFFIEngine in the registry."); + } + + // Call the function and get the module (which holds JSONFFIEngineImpl) + Module engine_mod = (*create_engine_func)(); + + // Cast the module to JSONFFIEngineImpl + auto* engine = dynamic_cast(engine_mod.operator->()); + if (!engine) { + throw std::runtime_error("Failed to cast to JSONFFIEngineImpl."); + } + + engine->InitBackgroundEngine(device_type, device_id, request_stream_callback); + + std::thread background_stream_back_loop([&engine]() { engine->RunBackgroundStreamBackLoop(); }); + + std::thread background_loop([&engine]() { engine->RunBackgroundLoop(); }); + + // Now call the Reload function + std::string engine_json = "{\"model\":\"" + model_path + "\", \"model_lib\":\"" + modellib_path + + "\", \"mode\": \"" + mode + "\"}"; + std::cerr << engine_json << std::endl; + engine->Reload(engine_json); + std::cerr << "\engine->Reload\n"; + + // Prepare input + std::string request_json = "{\"messages\":[{\"role\":\"user\",\"content\":\"" + input_text + + "\"}], \"max_tokens\": " + std::to_string(glob_max_tokens) + "}"; + std::string request_id = "benchmark_request"; + + for (int i = 1; i <= iterations; i++) { + counter = 0; + // Measure inference time + start = std::chrono::high_resolution_clock::now(); + engine->ChatCompletion(request_json, request_id); + + // std::cerr << "\nRunning in background. Sleeping main thread " + timeout + "s... 
Wait for text + // response (3s - 1m).\n\n"; + std::this_thread::sleep_for(std::chrono::seconds(timeout)); + // std::cerr << "\nWakeup...\n"; + + // std::cerr << "\nAborting...\n"; + engine->Abort(request_id); + std::this_thread::sleep_for(std::chrono::seconds(3)); - std::chrono::duration elapsed = end - start; - std::cerr << i << " Inference time: " << elapsed.count() << " seconds" << std::endl; + std::cerr << i << " Max tokens:" << glob_max_tokens << "; Counter:" << counter << "\n\n"; - std::cerr << i << " End-to-end decoded avg token/s: " << std::to_string(counter / elapsed.count()) << "\n"; - } + std::chrono::duration elapsed = end - start; + std::cerr << i << " Inference time: " << elapsed.count() << " seconds" << std::endl; + std::cerr << i + << " End-to-end decoded avg token/s: " << std::to_string(counter / elapsed.count()) + << "\n"; + } - engine->ExitBackgroundLoop(); - std::this_thread::sleep_for(std::chrono::seconds(3)); + engine->ExitBackgroundLoop(); + std::this_thread::sleep_for(std::chrono::seconds(3)); - background_stream_back_loop.join(); - background_loop.join(); + background_stream_back_loop.join(); + background_loop.join(); - std::cerr << "engine->Unload\n"; - std::thread(exitProgram).detach(); - engine->Unload(); - return; + std::cerr << "engine->Unload\n"; + std::thread(exitProgram).detach(); + engine->Unload(); + return; } int main(int argc, char* argv[]) { - if (argc != 9) { - std::cerr << "Usage: " << argv[0] << " <1:model_path:str> <2:model_lib_path:str> <3:mode:str> <4:device_type:int> <5:timeout:int> <6:max_tokens:int> <7:input_text:str> <8:iterations>" << std::endl - << "Device types: kDLCPU = 1; kDLOpenCL = 4; kDLVulkan = 7;\n" << "Be carefull with number of iterations.\n 1 iteration gives you text outputs."; - return 1; - } - - std::string model_path = argv[1]; - std::string modellib_path = argv[2]; - std::string mode = argv[3]; - int device_type = std::stoi(argv[4]); - int timeout = std::stoi(argv[5]); - glob_max_tokens = std::stoi(argv[6]); - std::string input_text = argv[7]; - iterations = std::stoi(argv[8]); - - benchmark_llm(model_path, modellib_path, mode, device_type, timeout, input_text); - - return 0; + if (argc != 9) { + std::cerr << "Usage: " << argv[0] + << " <1:model_path:str> <2:model_lib_path:str> <3:mode:str> <4:device_type:int> " + "<5:timeout:int> <6:max_tokens:int> <7:input_text:str> <8:iterations>" + << std::endl + << "Device types: kDLCPU = 1; kDLOpenCL = 4; kDLVulkan = 7;\n" + << "Be carefull with number of iterations.\n 1 iteration gives you text outputs."; + return 1; + } + + std::string model_path = argv[1]; + std::string modellib_path = argv[2]; + std::string mode = argv[3]; + int device_type = std::stoi(argv[4]); + int timeout = std::stoi(argv[5]); + glob_max_tokens = std::stoi(argv[6]); + std::string input_text = argv[7]; + iterations = std::stoi(argv[8]); + + benchmark_llm(model_path, modellib_path, mode, device_type, timeout, input_text); + + return 0; } \ No newline at end of file
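
For reference, the new llm_benchmark binary takes the eight positional arguments listed in its usage message (model path, model library path, engine mode, device type, timeout, max_tokens, input text, iterations). A hypothetical invocation on an Android device over adb, with placeholder model and library paths, could look like:

  ./llm_benchmark /data/local/tmp/model-MLC /data/local/tmp/model-android.so interactive 4 60 256 "What is machine learning?" 1

Here 4 selects kDLOpenCL per the usage text, 60 is the number of seconds the main thread sleeps before aborting the request, 256 is max_tokens, and a single iteration prints the decoded text itself rather than one dot per streamed chunk.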