Support llama_encode (WIP) (#91)
* add support for llama_encode

* add encoder to createCompletion
ngxson authored Jul 10, 2024
1 parent 8bf1b80 commit 97db6f5
Showing 8 changed files with 100 additions and 23 deletions.
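For context, a rough sketch of how the new encoder path is meant to be exercised from the public API after this commit. The package import, asset path keys, model URL, `loadModelFromUrl` call, and the `nPredict` option are placeholders or assumptions for illustration, not part of the diff; only `isEncoderDecoderArchitecture()` and the encoder-aware `createCompletion()` behaviour come from this change.

```ts
import { Wllama } from '@wllama/wllama';

// Asset paths and the model URL are placeholders; point them at your own
// wllama build and an encoder-decoder GGUF (e.g. a T5-family model).
const wllama = new Wllama({
  'single-thread/wllama.js': './esm/single-thread/wllama.js',
  'single-thread/wllama.wasm': './esm/single-thread/wllama.wasm',
});

async function main(): Promise<void> {
  await wllama.loadModelFromUrl('https://example.com/t5-small-q8_0.gguf');

  // New in this commit: the load result records whether the model has an encoder.
  if (wllama.isEncoderDecoderArchitecture()) {
    // createCompletion now runs the encoder over the prompt, then starts
    // decoding from the decoder start token instead of the prompt tokens.
    const output = await wllama.createCompletion(
      'translate English to German: Hello world',
      { nPredict: 32 },
    );
    console.log(output);
  }
}

main().catch(console.error);
```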
55 changes: 36 additions & 19 deletions actions.hpp
@@ -37,11 +37,6 @@ struct app_t
struct llama_sampling_context *ctx_sampling = nullptr;
llama_batch batch = llama_batch_init(512, 0, 1);
std::vector<llama_token> tokens;
// group attention
int32_t ga_i = 0; // group-attention state
int32_t ga_n = 0; // group-attention factor
int32_t ga_w = 0; // group-attention width
int32_t n_past_self_extension = 0;
};

inline void send_response(json data)
@@ -208,11 +203,6 @@ json action_load(app_t &app, json &body)
cparams.yarn_beta_slow = body["yarn_beta_slow"];
if (body.contains("yarn_orig_ctx"))
cparams.yarn_orig_ctx = body["yarn_orig_ctx"];
// group attention
if (body.contains("grp_attn_n"))
app.ga_n = body["grp_attn_n"];
if (body.contains("grp_attn_w"))
app.ga_w = body["grp_attn_w"];
// optimizations
if (body.contains("cache_type_k"))
cparams.type_k = kv_cache_type_from_str(body["cache_type_k"]);
@@ -249,6 +239,11 @@ json action_load(app_t &app, json &body)
}
llama_batch_free(app.batch);
app.batch = llama_batch_init(cparams.n_batch, 0, 1);
auto decoder_start_token = llama_model_decoder_start_token(app.model);
if (decoder_start_token < 0)
{
decoder_start_token = llama_token_bos(app.model);
}
return json{
{"success", true},
{"n_ctx", cparams.n_ctx},
@@ -260,6 +255,8 @@ json action_load(app_t &app, json &body)
{"token_bos", llama_token_bos(app.model)},
{"token_eos", llama_token_eos(app.model)},
{"token_eot", llama_token_eot(app.model)},
{"has_encoder", llama_model_has_encoder(app.model)},
{"token_decoder_start", llama_model_decoder_start_token(app.model)},
};
}

@@ -423,23 +420,15 @@ json action_decode(app_t &app, json &body)
{
std::vector<llama_token> tokens_list = body["tokens"];
bool skip_logits = body.contains("skip_logits");
/*bool grp_attn_enabled = app.ga_n > 1;
if (grp_attn_enabled)
{
group_attention_shift_context(app);
}*/
size_t i = 0;
llama_batch_clear(app.batch);
for (auto id : tokens_list)
{
bool grp_attn_enabled = false; // TODO: maybe remove grp_attn
int32_t n_past = grp_attn_enabled
? app.n_past_self_extension
: app.tokens.size();
int32_t n_past = app.tokens.size();
llama_batch_add(app.batch, id, n_past, {0}, false);
app.tokens.push_back(id);
i++;
app.n_past_self_extension++;
}
// llama_decode will output logits only for the last token of the prompt
if (!skip_logits)
@@ -459,6 +448,34 @@ json action_decode(app_t &app, json &body)
}
}

// encode an array of tokens
json action_encode(app_t &app, json &body)
{
std::vector<llama_token> tokens_list = body["tokens"];
if (!llama_model_has_encoder(app.model))
{
return json{{"error", "this model does not have an encoder"}};
}
size_t n_past = 0;
llama_batch_clear(app.batch);
for (auto id : tokens_list)
{
llama_batch_add(app.batch, id, n_past, {0}, false);
n_past++;
}
if (llama_encode(app.ctx, app.batch) != 0)
{
return json{{"error", "llama_encode failed, maybe n_batch is too small?"}};
}
else
{
return json{
{"success", true},
{"n_past", n_past},
};
}
}

// decode the current logits and sample the new token
json action_sampling_sample(app_t &app, json &body)
{
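For reference, the JSON shapes that the new `action_encode` handler above appears to exchange, plus the two fields added to the `load` response. This is a descriptive sketch inferred from the C++ code in this diff, written as TypeScript types; these type names are not part of the project.

```ts
// Request body for the new 'encode' action: the tokens to feed the encoder.
interface EncodeRequest {
  tokens: number[];
}

// Response: on success, n_past is the number of tokens placed in the sequence;
// on failure (no encoder, or llama_encode error) only `error` is present.
type EncodeResponse =
  | { success: true; n_past: number }
  | { error: string };

// Fields added to the 'load' action response by this commit.
interface LoadResponseEncoderFields {
  has_encoder: boolean;        // llama_model_has_encoder(model)
  token_decoder_start: number; // llama_model_decoder_start_token(model)
}
```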
2 changes: 1 addition & 1 deletion llama.cpp
Submodule llama.cpp updated 172 files
2 changes: 1 addition & 1 deletion src/multi-thread/wllama.js

Large diffs are not rendered by default.

Binary file modified src/multi-thread/wllama.wasm
Binary file not shown.
2 changes: 1 addition & 1 deletion src/single-thread/wllama.js

Large diffs are not rendered by default.

Binary file modified src/single-thread/wllama.wasm
Binary file not shown.
61 changes: 60 additions & 1 deletion src/wllama.ts
@@ -147,6 +147,8 @@ export class Wllama {
private eotToken: number = -1;
private metadata?: ModelMetadata;
private samplingConfig: SamplingConfig = {};
private hasEncoder: boolean = false;
private decoderStartToken: number = -1;

constructor(pathConfig: AssetsPathConfig, wllamaConfig: WllamaConfig = {}) {
checkEnvironmentCompatible();
@@ -205,6 +207,17 @@ export class Wllama {
return this.eotToken;
}

/**
* Get the token ID that the decoder uses to start generating the output sequence (only applicable to encoder-decoder architectures). In other words, the encoder uses the normal BOS token while the decoder starts from this token.
*
* NOTE: This can only be used after `loadModel` is called.
*
* @returns -1 if the model is not loaded.
*/
getDecoderStartToken(): number {
return this.decoderStartToken;
}

/**
* Get model hyper-parameters and metadata
*
@@ -229,6 +242,18 @@
return this.useMultiThread;
}

/**
* Check if the current model uses an encoder-decoder architecture
*
* NOTE: This can only be used after `loadModel` is called.
*
* @returns true if the model uses an encoder-decoder architecture.
*/
isEncoderDecoderArchitecture(): boolean {
this.checkModelLoaded();
return this.hasEncoder;
}

/**
* Parses a model URL and returns an array of URLs based on the following patterns:
* - If the input URL is an array, it returns the array itself.
@@ -344,6 +369,8 @@
token_bos: number,
token_eos: number,
token_eot: number,
has_encoder: boolean,
token_decoder_start: number,
} = await this.proxy.wllamaAction('load', {
...config,
use_mmap: true,
@@ -368,6 +395,8 @@
},
meta: loadResult.metadata,
};
this.hasEncoder = !!loadResult.has_encoder;
this.decoderStartToken = loadResult.token_decoder_start;
}

//////////////////////////////////////////////
@@ -421,7 +450,12 @@
// process prompt
const tokens = await this.tokenize(prompt, true);
await this.samplingAccept(tokens);
await this.decode(tokens, {});
if (this.isEncoderDecoderArchitecture()) {
await this.encode(tokens);
await this.decode([this.getDecoderStartToken()], {});
} else {
await this.decode(tokens, {});
}
let outBuf = new Uint8Array();
// abort signal
let abort = false;
@@ -548,6 +582,31 @@
}
}

/**
* Run llama_encode()
* @param tokens A list of tokens to be encoded
* @param options Unused for now
* @returns n_past (number of tokens so far in the sequence)
*/
async encode(tokens: number[], options?: Record<never, never>): Promise<{ nPast: number }> {
this.checkModelLoaded();
if (!this.hasEncoder) {
throw new Error('This model does not use encoder-decoder architecture.');
}
if (this.useEmbeddings) {
throw new Error('embeddings is enabled. Use wllama.setOptions({ embeddings: false }) to disable it.');
}
const req: any = { tokens };
const result = await this.proxy.wllamaAction('encode', req);
if (result.error) {
throw new Error(result.error);
} else if (!result.success) {
throw new Error('Cannot encode, unknown error');
} else {
return { nPast: result.n_past };
}
}

/**
* Sample a new token (remember to samplingInit() at least once before calling this function)
* @returns the token ID and its detokenized value (which may be an unfinished unicode sequence)
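And a minimal sketch of driving the new low-level methods by hand, mirroring what `createCompletion()` now does internally for encoder-decoder models. It assumes a model is already loaded and that a sampling context was set up via `samplingInit()`; the function name is illustrative only.

```ts
async function primeEncoderDecoder(wllama: Wllama, prompt: string): Promise<void> {
  const tokens = await wllama.tokenize(prompt, true);
  await wllama.samplingAccept(tokens);

  if (wllama.isEncoderDecoderArchitecture()) {
    // New path: the encoder consumes the prompt, then the decoder is seeded
    // with its dedicated start token rather than the prompt tokens.
    const { nPast } = await wllama.encode(tokens);
    console.log(`encoder processed ${nPast} tokens`);
    await wllama.decode([wllama.getDecoderStartToken()], {});
  } else {
    // Decoder-only models keep the previous behaviour.
    await wllama.decode(tokens, {});
  }
  // Generation then proceeds as before, e.g. via samplingSample() + decode().
}
```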
1 change: 1 addition & 0 deletions wllama.cpp
@@ -73,6 +73,7 @@ extern "C" const char *wllama_action(const char *name, const char *body)
WLLAMA_ACTION(tokenize);
WLLAMA_ACTION(detokenize);
WLLAMA_ACTION(decode);
WLLAMA_ACTION(encode);
WLLAMA_ACTION(get_logits);
WLLAMA_ACTION(embeddings);
WLLAMA_ACTION(kv_remove);
