From 2a23daf5cdafd2435763fdea68fb05307622ae6b Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Sat, 28 Oct 2023 13:37:32 -0700
Subject: [PATCH] refactor: switch back to llama batch interface

---
 crates/llama-cpp-bindings/src/engine.cc | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 3b5caaa939b..b92ad6a7d42 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -21,6 +21,11 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
     model_(std::move(model)),
     ctx_(std::move(ctx)) {
+    batch_ = llama_batch_init(N_BATCH, 0, 1);
+  }
+
+  ~TextInferenceEngineImpl() override {
+    llama_batch_free(batch_);
   }
 
   void start(rust::Slice<const uint32_t> input_token_ids) override {
@@ -52,7 +57,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
 
   uint32_t sample() const {
     auto* ctx = ctx_.get();
-    auto logits = llama_get_logits_ith(ctx, 0);
+    auto logits = llama_get_logits_ith(ctx, batch_.n_tokens - 1);
     auto n_vocab = llama_n_vocab(llama_get_model(ctx));
 
     // Greedy sampling (always select the highest logit).
@@ -64,9 +69,19 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       n_past_ = 0;
     }
 
+    batch_.n_tokens = size;
+    for (size_t i = 0; i < size; ++i) {
+      batch_.token[i] = data[i];
+      batch_.pos[i] = n_past_ + i;
+      batch_.n_seq_id[i] = 1;
+      batch_.seq_id[i][0] = 0;
+      batch_.logits[i] = false;
+    }
+    batch_.logits[size - 1] = true;
+
     auto* ctx = ctx_.get();
     llama_kv_cache_tokens_rm(ctx, n_past_, -1);
-    if (llama_decode(ctx, llama_batch_get_one(data, size, n_past_, 0))) {
+    if (llama_decode(ctx, batch_)) {
       throw std::runtime_error("Failed to eval");
     }
 
@@ -76,6 +91,8 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   size_t n_past_;
   owned<llama_model> model_;
   owned<llama_context> ctx_;
+
+  llama_batch batch_;
 };
 
 static int g_llama_cpp_log_level = 0;
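
Note (commentary, not part of the commit): the patch swaps the
llama_batch_get_one() convenience call for a llama_batch that is
allocated once in the constructor (llama_batch_init) and released in
the destructor (llama_batch_free). The sketch below restates the fill
pattern as a standalone helper, assuming llama.cpp's batch API from
around the time of this commit (late October 2023); decode_tokens is
a hypothetical name for illustration, not a function in this repo.

  #include <vector>
  #include "llama.h"

  // Fill a pre-allocated batch (created with llama_batch_init(cap, 0, 1),
  // cap >= tokens.size()) with `tokens` starting at KV-cache position
  // `n_past`, request logits only for the last token, then decode.
  bool decode_tokens(llama_context* ctx, llama_batch& batch,
                     const std::vector<llama_token>& tokens, int n_past) {
    batch.n_tokens = (int32_t)tokens.size();
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
      batch.token[i] = tokens[i];    // token id to evaluate
      batch.pos[i] = n_past + i;     // absolute position in the sequence
      batch.n_seq_id[i] = 1;         // each token belongs to one sequence...
      batch.seq_id[i][0] = 0;        // ...sequence 0
      batch.logits[i] = false;       // skip logits for intermediate tokens
    }
    // Only the final position needs logits; this is what makes the
    // llama_get_logits_ith(ctx, batch.n_tokens - 1) read valid.
    batch.logits[batch.n_tokens - 1] = true;
    return llama_decode(ctx, batch) == 0;  // 0 indicates success
  }

A caller would pair this with llama_batch_init at startup and
llama_batch_free at shutdown, mirroring the constructor/destructor
pair the patch adds. Requesting logits only for the final position
keeps the decode cheap and matches the updated
llama_get_logits_ith(ctx, batch_.n_tokens - 1) read in sample().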