diff --git a/src/wllama.ts b/src/wllama.ts index 60c5cb8..09a39f2 100644 --- a/src/wllama.ts +++ b/src/wllama.ts @@ -606,7 +606,7 @@ export class Wllama { } // maybe reuse KV cache if (options.useCache) { - tokens = await this.getNonCachedTokens(tokens); + tokens = await this.computeNonCachedTokens(tokens); } else { await this.kvClear(); } @@ -741,6 +741,12 @@ export class Wllama { 'embeddings is enabled. Use wllama.setOptions({ embeddings: false }) to disable it.' ); } + if (tokens.length === 0) { + // do not call llama_decode if list of tokens is empty + return { + nPast: (await this.getCachedTokens()).length, + }; + } const req: any = { tokens }; if (options.skipLogits) { req.skip_logits = true; @@ -775,6 +781,12 @@ export class Wllama { 'embeddings is enabled. Use wllama.setOptions({ embeddings: false }) to disable it.' ); } + if (tokens.length === 0) { + // do not call llama_encode if list of tokens is empty + return { + nPast: (await this.getCachedTokens()).length, + }; + } const req: any = { tokens }; const result = await this.proxy.wllamaAction('encode', req); if (result.error) { @@ -931,7 +943,7 @@ export class Wllama { } ///// Prompt cache utils ///// - private async getCachedToken(): Promise { + private async getCachedTokens(): Promise { this.checkModelLoaded(); const result = await this.proxy.wllamaAction('current_status', {}); return result.tokens; @@ -941,8 +953,8 @@ export class Wllama { * Compare the input sequence and cachedToken, then return the part that is not in cache. * This function also remove mismatch part in cache (via kvRemove) */ - private async getNonCachedTokens(seq: number[]) { - const cachedTokens = await this.getCachedToken(); + private async computeNonCachedTokens(seq: number[]) { + const cachedTokens = await this.getCachedTokens(); let nKeep = 0; for (; nKeep < Math.min(cachedTokens.length, seq.length); nKeep++) { if (cachedTokens[nKeep] !== seq[nKeep]) { @@ -951,7 +963,9 @@ export class Wllama { } const nDiscard = cachedTokens.length - nKeep; this.logger().debug(`Cache nKeep=${nKeep} nDiscard=${nDiscard}`); - await this.kvRemove(nKeep, nDiscard); + if (nDiscard > 0) { + await this.kvRemove(nKeep, nDiscard); + } return seq.slice(nKeep, seq.length); }