Skip to content

Commit

Permalink
refactor: switch back to llama batch interface
Browse files Browse the repository at this point in the history
  • Loading branch information
wsxiaoys committed Oct 28, 2023
1 parent 43cc5f3 commit 2a23daf
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions crates/llama-cpp-bindings/src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
// Takes ownership of the llama model and context, and pre-allocates the
// reusable decode batch (capacity N_BATCH tokens, no embeddings, 1 sequence).
// The batch is released in the destructor via llama_batch_free.
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx)
    : model_(std::move(model)),
      ctx_(std::move(ctx)),
      batch_(llama_batch_init(N_BATCH, 0, 1)) {}

// Frees the llama_batch allocated by the constructor's llama_batch_init.
// model_/ctx_ are owned<> RAII handles and clean themselves up.
~TextInferenceEngineImpl() override {
llama_batch_free(batch_);
}

void start(rust::Slice<const uint32_t> input_token_ids) override {
Expand Down Expand Up @@ -52,7 +57,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
uint32_t sample() const {
auto* ctx = ctx_.get();

auto logits = llama_get_logits_ith(ctx, 0);
auto logits = llama_get_logits_ith(ctx, batch_.n_tokens - 1);
auto n_vocab = llama_n_vocab(llama_get_model(ctx));

// Greedy sampling (always select the highest logit).
Expand All @@ -64,9 +69,19 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
n_past_ = 0;
}

batch_.n_tokens = size;
for (size_t i = 0; i < size; ++i) {
batch_.token[i] = data[i];
batch_.pos[i] = n_past_ + i;
batch_.n_seq_id[i] = 1;
batch_.seq_id[i][0] = 0;
batch_.logits[i] = false;
}
batch_.logits[size - 1] = true;

auto* ctx = ctx_.get();
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
if (llama_decode(ctx, llama_batch_get_one(data, size, n_past_, 0))) {
if (llama_decode(ctx, batch_)) {
throw std::runtime_error("Failed to eval");
}

Expand All @@ -76,6 +91,8 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
size_t n_past_;
owned<llama_model> model_;
owned<llama_context> ctx_;

llama_batch batch_;
};

static int g_llama_cpp_log_level = 0;
Expand Down

0 comments on commit 2a23daf

Please sign in to comment.