
sync with upstream llama.cpp source code (#147)
* sync with upstream llama.cpp source code

* add isTokenEOG()

* v2.1.2

* sync
ngxson authored Jan 12, 2025
1 parent 86138c8 commit 30adc2a
Showing 11 changed files with 67 additions and 44 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -59,7 +59,7 @@ let wllamaInstance = new Wllama(WLLAMA_CONFIG_PATHS, ...);
// (the rest is the same with earlier example)
```

-For complete code example, see [examples/main/utils/wllama.context.tsx](./examples/main/utils/wllama.context.tsx)
+For complete code example, see [examples/main/src/utils/wllama.context.tsx](./examples/main/src/utils/wllama.context.tsx)

NOTE: this example only covers completions usage. For embeddings, please see [examples/embeddings/index.html](./examples/embeddings/index.html)

68 changes: 39 additions & 29 deletions actions.hpp
@@ -35,9 +35,10 @@ struct app_t
{
llama_model *model;
llama_context *ctx;
+const llama_vocab *vocab;
common_sampler *ctx_sampling = nullptr;
llama_batch batch = llama_batch_init(512, 0, 1);
-std::vector<llama_token> tokens;
+llama_tokens tokens;
int32_t seed = LLAMA_DEFAULT_SEED;
};

@@ -119,7 +120,7 @@ void free_all(app_t &app)
if (app.ctx != nullptr)
llama_free(app.ctx);
if (app.model != nullptr)
-llama_free_model(app.model);
+llama_model_free(app.model);
if (app.ctx_sampling != nullptr)
common_sampler_free(app.ctx_sampling);
}
@@ -210,15 +211,16 @@ json action_load(app_t &app, json &body)
cparams.type_k = kv_cache_type_from_str(body["cache_type_k"]);
if (body.contains("cache_type_v"))
cparams.type_k = kv_cache_type_from_str(body["cache_type_v"]);
-app.model = llama_load_model_from_file(model_path.c_str(), mparams);
+app.model = llama_model_load_from_file(model_path.c_str(), mparams);
if (app.model == nullptr)
{
free_all(app);
throw app_exception("Error while loading model");
}
+app.vocab = llama_model_get_vocab(app.model);
for (; cparams.n_ctx > 0; cparams.n_ctx -= 1024)
{
-app.ctx = llama_new_context_with_model(app.model, cparams);
+app.ctx = llama_init_from_model(app.model, cparams);
if (app.ctx != nullptr)
{
break; // OK
@@ -244,23 +246,31 @@
auto decoder_start_token = llama_model_decoder_start_token(app.model);
if (decoder_start_token < 0)
{
-decoder_start_token = llama_token_bos(app.model);
+decoder_start_token = llama_vocab_bos(app.vocab);
}
+int n_vocab = llama_vocab_n_tokens(app.vocab);
+llama_tokens list_tokens_eog;
+for (int i = 0; i < n_vocab; i++) {
+    if (llama_vocab_is_eog(app.vocab, i)) {
+        list_tokens_eog.push_back(i);
+    }
+}
return json{
{"success", true},
{"n_ctx", cparams.n_ctx},
{"n_batch", llama_n_batch(app.ctx)},
{"n_ubatch", llama_n_ubatch(app.ctx)},
{"n_vocab", llama_n_vocab(app.model)},
{"n_ctx_train", llama_n_ctx_train(app.model)},
{"n_embd", llama_n_embd(app.model)},
{"n_layer", llama_n_layer(app.model)},
{"n_vocab", n_vocab},
{"n_ctx_train", llama_model_n_ctx_train(app.model)},
{"n_embd", llama_model_n_embd(app.model)},
{"n_layer", llama_model_n_layer(app.model)},
{"metadata", dump_metadata(app)},
{"token_bos", llama_token_bos(app.model)},
{"token_eos", llama_token_eos(app.model)},
{"token_eot", llama_token_eot(app.model)},
{"add_bos_token", llama_add_bos_token(app.model) == 1},
{"add_eos_token", llama_add_eos_token(app.model) == 1},
{"token_bos", llama_vocab_bos(app.vocab)},
{"token_eos", llama_vocab_eos(app.vocab)},
{"token_eot", llama_vocab_eot(app.vocab)},
{"list_tokens_eog", list_tokens_eog},
{"add_bos_token", llama_vocab_get_add_bos(app.vocab) == 1},
{"add_eos_token", llama_vocab_get_add_eos(app.vocab) == 1},
{"has_encoder", llama_model_has_encoder(app.model)},
{"token_decoder_start", llama_model_decoder_start_token(app.model)},
};
@@ -348,7 +358,7 @@ json action_sampling_init(app_t &app, json &body)
app.ctx_sampling = common_sampler_init(app.model, sparams);
if (body.contains("tokens"))
{
-std::vector<llama_token> tokens = body["tokens"];
+llama_tokens tokens = body["tokens"];
for (auto id : tokens)
{
common_sampler_accept(app.ctx_sampling, id, false);
@@ -360,7 +370,7 @@
// get map token ID to vocab (be careful, it is slow!)
json action_get_vocab(app_t &app, json &body)
{
-int32_t max_tokens = llama_n_vocab(app.model);
+int32_t max_tokens = llama_vocab_n_tokens(app.vocab);
std::vector<std::vector<unsigned int>> vocab(max_tokens);
for (int32_t id = 0; id < max_tokens; id++)
{
@@ -377,7 +387,7 @@ json action_get_vocab(app_t &app, json &body)
json action_lookup_token(app_t &app, json &body)
{
std::string piece = body["piece"];
-int32_t max_tokens = llama_n_vocab(app.model);
+int32_t max_tokens = llama_vocab_n_tokens(app.vocab);
for (int32_t id = 0; id < max_tokens; id++)
{
std::string token_as_str = common_token_to_piece(app.ctx, id);
@@ -398,8 +408,8 @@ json action_tokenize(app_t &app, json &body)
{
std::string text = body["text"];
bool special = body.contains("special");
-std::vector<llama_token> tokens_list;
-tokens_list = common_tokenize(app.model, text, false, special);
+llama_tokens tokens_list;
+tokens_list = common_tokenize(app.vocab, text, false, special);
return json{
{"success", true},
{"tokens", tokens_list},
@@ -409,7 +419,7 @@
// detokenize a list of tokens
json action_detokenize(app_t &app, json &body)
{
-std::vector<llama_token> tokens = body["tokens"];
+llama_tokens tokens = body["tokens"];
std::stringstream output;
for (auto id : tokens)
{
@@ -425,7 +435,7 @@
// decode an array of tokens
json action_decode(app_t &app, json &body)
{
-std::vector<llama_token> tokens_list = body["tokens"];
+llama_tokens tokens_list = body["tokens"];
bool skip_logits = body.contains("skip_logits")
? body.at("skip_logits").get<bool>()
: false;
@@ -460,7 +470,7 @@ json action_decode(app_t &app, json &body)
// encode an array of tokens
json action_encode(app_t &app, json &body)
{
-std::vector<llama_token> tokens_list = body["tokens"];
+llama_tokens tokens_list = body["tokens"];
if (!llama_model_has_encoder(app.model))
{
return json{{"error", "this model does not have an encoder"}};
@@ -501,7 +511,7 @@ json action_sampling_sample(app_t &app, json &body)
// accept this token
json action_sampling_accept(app_t &app, json &body)
{
-std::vector<llama_token> tokens_list = body["tokens"];
+llama_tokens tokens_list = body["tokens"];
for (auto id : tokens_list)
{
common_sampler_accept(app.ctx_sampling, id, false);
@@ -515,7 +525,7 @@ json action_get_logits(app_t &app, json &body)
int top_k = body["top_k"]; // if is -1, we take all logits (will be slow!)
int32_t idx = app.batch.n_tokens - 1;
float *logits = llama_get_logits_ith(app.ctx, idx);
-int32_t n_vocab = llama_n_vocab(app.model);
+int32_t n_vocab = llama_vocab_n_tokens(app.vocab);
auto sort_fn = [](llama_token_data &a, llama_token_data &b) -> bool
{
return b.logit < a.logit;
@@ -555,9 +565,9 @@ json action_get_logits(app_t &app, json &body)
// get embeddings, this will call action_decode internally
json action_embeddings(app_t &app, json &body)
{
-std::vector<llama_token> tokens_list = body["tokens"];
+llama_tokens tokens_list = body["tokens"];
// allocate output
-const int n_embd = llama_n_embd(app.model);
+const int n_embd = llama_model_n_embd(app.model);
std::vector<float> embeddings(n_embd, 0); // single seq
float *out = embeddings.data();
// decode
@@ -645,7 +655,7 @@ json action_kv_clear(app_t &app, json &body)
json action_session_save(app_t &app, json &body)
{
std::string session_path = body["session_path"];
-std::vector<llama_token> dummy;
+llama_tokens dummy;
if (!llama_state_seq_save_file(
app.ctx,
session_path.c_str(),
@@ -666,10 +676,10 @@
json action_session_load(app_t &app, json &body)
{
std::string session_path = body["session_path"];
-std::vector<llama_token> saved_tokens = body["tokens"];
+llama_tokens saved_tokens = body["tokens"];
auto n_ctx = llama_n_ctx(app.ctx);
size_t n_token_count_out = 0;
-std::vector<llama_token> dummy;
+llama_tokens dummy;
if (!llama_state_seq_load_file(
app.ctx,
session_path.c_str(),
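For orientation, the JSON object built by action_load above now has roughly the following shape when it reaches the TypeScript side. This is a sketch only: the field names come from the C++ diff, the interface name is illustrative, and src/wllama.ts models the same data as LoadedContextInfo.

```typescript
// Illustrative sketch of the load result returned by action_load.
// Field names are taken from the JSON built in the C++ diff above;
// the interface name itself is hypothetical.
interface LoadResult {
  success: boolean;
  n_ctx: number;
  n_batch: number;
  n_ubatch: number;
  n_vocab: number;
  n_ctx_train: number;
  n_embd: number;
  n_layer: number;
  metadata: Record<string, string>; // assumed string-to-string map from dump_metadata
  token_bos: number;
  token_eos: number;
  token_eot: number;
  list_tokens_eog: number[]; // new in this commit: every end-of-generation token ID
  add_bos_token: boolean;
  add_eos_token: boolean;
  has_encoder: boolean;
  token_decoder_start: number;
}
```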
2 changes: 1 addition & 1 deletion llama.cpp
Submodule llama.cpp updated 220 files
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "@wllama/wllama",
"version": "2.1.1",
"version": "2.1.2",
"description": "WebAssembly binding for llama.cpp - Enabling on-browser LLM inference",
"main": "index.js",
"type": "module",
2 changes: 1 addition & 1 deletion src/multi-thread/wllama.js

Large diffs are not rendered by default.

Binary file modified src/multi-thread/wllama.wasm
Binary file not shown.
2 changes: 1 addition & 1 deletion src/single-thread/wllama.js

Large diffs are not rendered by default.

Binary file modified src/single-thread/wllama.wasm
Binary file not shown.
4 changes: 2 additions & 2 deletions src/wasm-from-cdn.ts
@@ -2,8 +2,8 @@
// Do not edit this file directly

const WasmFromCDN = {
-'single-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/wllama@2.1.1/src/single-thread/wllama.wasm',
-'multi-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/wllama@2.1.1/src/multi-thread/wllama.wasm',
+'single-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/wllama@2.1.2/src/single-thread/wllama.wasm',
+'multi-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/wllama@2.1.2/src/multi-thread/wllama.wasm',
};

export default WasmFromCDN;
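The WasmFromCDN map is meant to be passed to the Wllama constructor in place of local paths. A minimal sketch, assuming the import path shown below (it may differ depending on how the package is consumed):

```typescript
import { Wllama } from '@wllama/wllama';
// Assumed import path for the generated CDN map above; adjust to your build layout.
import WasmFromCDN from '@wllama/wllama/src/wasm-from-cdn';

// Same idea as `new Wllama(WLLAMA_CONFIG_PATHS, ...)` in the README,
// but the .wasm binaries are fetched from jsDelivr at the pinned version.
const wllama = new Wllama(WasmFromCDN);
```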
25 changes: 19 additions & 6 deletions src/wllama.ts
@@ -170,6 +170,7 @@ export interface LoadedContextInfo {
token_bos: number;
token_eos: number;
token_eot: number;
+list_tokens_eog: number[];
has_encoder: boolean;
token_decoder_start: number;
add_bos_token: boolean;
@@ -215,6 +216,7 @@ export class Wllama {
private bosToken: number = -1;
private eosToken: number = -1;
private eotToken: number = -1;
+private eogTokens: Set<number> = new Set();
private addBosToken: boolean = false;
private addEosToken: boolean = false;
private chatTemplate?: string;
@@ -293,6 +295,20 @@
return this.eotToken;
}

+/**
+ * Check if a given token is an end-of-generation token (e.g. EOS, EOT, etc.)
+ *
+ * @param token the token ID to be checked
+ * @returns true if the token is EOS, EOT, or any other end-of-generation token
+ */
+isTokenEOG(token: number): boolean {
+  return (
+    token === this.eosToken ||
+    token === this.eotToken ||
+    this.eogTokens.has(token)
+  );
+}

/**
* Get token ID associated to token used by decoder, to start generating output sequence (only usable for encoder-decoder architecture). In other words, encoder uses normal BOS and decoder uses this token.
*
@@ -531,6 +547,7 @@
this.addEosToken = loadResult.add_eos_token;
this.chatTemplate = loadResult.metadata['tokenizer.chat_template'];
this.loadedContextInfo = loadResult;
+this.eogTokens = new Set(loadResult.list_tokens_eog);
this.logger().debug({ loadResult });
}

@@ -608,11 +625,7 @@
this.checkModelLoaded();
this.samplingConfig = options.sampling ?? {};
await this.samplingInit(this.samplingConfig);
-const stopTokens = [
-  this.eosToken,
-  this.eotToken,
-  ...(options.stopTokens ?? []),
-];
+const stopTokens = new Set(options.stopTokens ?? []);
// process prompt
let tokens = await this.tokenize(prompt, true);
if (this.addBosToken && tokens[0] !== this.bosToken) {
@@ -641,7 +654,7 @@
// predict next tokens
for (let i = 0; i < (options.nPredict ?? Infinity); i++) {
const sampled = await this.samplingSample();
-if (stopTokens.includes(sampled.token)) {
+if (this.isTokenEOG(sampled.token) || stopTokens.has(sampled.token)) {
break; // stop token
}
// @ts-ignore Type 'Uint8Array<ArrayBufferLike>' is not assignable to type 'Uint8Array<ArrayBuffer>'
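Taken together, the src/wllama.ts changes make generation stop on any end-of-generation token reported by the model's vocab (via the new isTokenEOG() and the eogTokens set), while user-supplied stopTokens are still honored through a Set lookup. A minimal usage sketch, assuming a Wllama instance with a model already loaded and that the completion method shown in the diff is the public createCompletion(prompt, options); the extra stop-token ID is a placeholder:

```typescript
// `wllama` is assumed to be a Wllama instance with a model already loaded.
const answer = await wllama.createCompletion('Write a haiku about the sea:', {
  nPredict: 64,
  sampling: { temp: 0.7 },
  // Placeholder extra stop token; EOS, EOT, and other EOG tokens stop generation automatically.
  stopTokens: [128009],
});
console.log(answer);

// The same check used by the generation loop is exposed publicly:
// wllama.isTokenEOG(tokenId) === true for EOS, EOT, or any other EOG token.
```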
4 changes: 2 additions & 2 deletions src/workers-code/generated.ts

Large diffs are not rendered by default.
