Skip to content

Commit

Permalink
decode/encode : do not fail on empty batch (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxson authored Sep 23, 2024
1 parent 7beefeb commit 49e29ec
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/wllama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ export class Wllama {
}
// maybe reuse KV cache
if (options.useCache) {
tokens = await this.getNonCachedTokens(tokens);
tokens = await this.computeNonCachedTokens(tokens);
} else {
await this.kvClear();
}
Expand Down Expand Up @@ -741,6 +741,12 @@ export class Wllama {
'embeddings is enabled. Use wllama.setOptions({ embeddings: false }) to disable it.'
);
}
if (tokens.length === 0) {
// do not call llama_decode if list of tokens is empty
return {
nPast: (await this.getCachedTokens()).length,
};
}
const req: any = { tokens };
if (options.skipLogits) {
req.skip_logits = true;
Expand Down Expand Up @@ -775,6 +781,12 @@ export class Wllama {
'embeddings is enabled. Use wllama.setOptions({ embeddings: false }) to disable it.'
);
}
if (tokens.length === 0) {
// do not call llama_encode if list of tokens is empty
return {
nPast: (await this.getCachedTokens()).length,
};
}
const req: any = { tokens };
const result = await this.proxy.wllamaAction('encode', req);
if (result.error) {
Expand Down Expand Up @@ -931,7 +943,7 @@ export class Wllama {
}

///// Prompt cache utils /////
private async getCachedToken(): Promise<number[]> {
private async getCachedTokens(): Promise<number[]> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('current_status', {});
return result.tokens;
Expand All @@ -941,8 +953,8 @@ export class Wllama {
* Compare the input sequence with the cached tokens, then return the part of the input that is not in the cache.
* This function also removes the mismatched part of the cache (via kvRemove).
*/
private async getNonCachedTokens(seq: number[]) {
const cachedTokens = await this.getCachedToken();
private async computeNonCachedTokens(seq: number[]) {
const cachedTokens = await this.getCachedTokens();
let nKeep = 0;
for (; nKeep < Math.min(cachedTokens.length, seq.length); nKeep++) {
if (cachedTokens[nKeep] !== seq[nKeep]) {
Expand All @@ -951,7 +963,9 @@ export class Wllama {
}
const nDiscard = cachedTokens.length - nKeep;
this.logger().debug(`Cache nKeep=${nKeep} nDiscard=${nDiscard}`);
await this.kvRemove(nKeep, nDiscard);
if (nDiscard > 0) {
await this.kvRemove(nKeep, nDiscard);
}
return seq.slice(nKeep, seq.length);
}

Expand Down

0 comments on commit 49e29ec

Please sign in to comment.