implement KV cache reuse (#103)
ngxson authored Aug 2, 2024
1 parent 30df6c5 commit 3897d6c
Showing 3 changed files with 42 additions and 3 deletions.
1 change: 1 addition & 0 deletions examples/main/src/utils/wllama.context.tsx
@@ -202,6 +202,7 @@ export const WllamaProvider = ({ children }: any) => {
stopSignal = false;
const result = await wllamaInstance.createCompletion(input, {
nPredict: currParams.nPredict,
useCache: true,
sampling: {
temp: currParams.temperature,
},
2 changes: 1 addition & 1 deletion llama.cpp
42 changes: 40 additions & 2 deletions src/wllama.ts
@@ -120,6 +120,11 @@ export interface ChatCompletionOptions {
* Note: To convert from text to token ID, use lookupToken()
*/
stopTokens?: number[];
/**
* Equivalent to the `cache_prompt` option in the llama.cpp server.
* Useful for chat, because it skips re-evaluating the history part of the conversation.
*/
useCache?: boolean;
}

export interface ModelMetadata {
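
A minimal usage sketch of the new option (not part of this commit; `wllama` is assumed to be a Wllama instance with a model already loaded, and the prompt strings are placeholders):

// Sketch only: `wllama` is an already-initialized Wllama instance.
const history = 'Chat history rendered as a prompt...';

// The first call evaluates the full prompt and leaves its tokens in the KV cache.
const firstReply = await wllama.createCompletion(history, {
  nPredict: 128,
  useCache: true, // equivalent to `cache_prompt` in the llama.cpp server
});

// A follow-up prompt that starts with the same history only pays for the new
// suffix; the shared prefix is detected and reused from the cache.
const secondReply = await wllama.createCompletion(history + '\nUser: next question', {
  nPredict: 128,
  useCache: true,
});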
@@ -582,17 +587,23 @@ export class Wllama {
this.checkModelLoaded();
this.samplingConfig = options.sampling ?? {};
await this.samplingInit(this.samplingConfig);
await this.kvClear(); // TODO: maybe cache tokens?
const stopTokens = [
this.eosToken,
this.eotToken,
...(options.stopTokens ?? []),
];
// process prompt
const tokens = await this.tokenize(prompt, true);
let tokens = await this.tokenize(prompt, true);
if (this.addBosToken && tokens[0] !== this.bosToken) {
tokens.unshift(this.bosToken);
}
// maybe reuse KV cache
if (options.useCache) {
tokens = await this.getNonCachedTokens(tokens);
} else {
await this.kvClear();
}
// decode/encode tokens
await this.samplingAccept(tokens);
if (this.isEncoderDecoderArchitecture()) {
await this.encode(tokens);
@@ -713,6 +724,7 @@ export class Wllama {
async decode(
tokens: number[],
options: {
// when processing input prompt, we don't need to get output tokens
skipLogits?: boolean;
}
): Promise<{ nPast: number }> {
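
A hedged sketch of the lower-level call that this comment documents (illustrative only, not part of the commit; `wllama` is assumed to be a Wllama instance with a model already loaded):

// Tokenize a prompt the same way createCompletion does above, then feed the
// tokens into the KV cache without requesting output logits.
const promptTokens = await wllama.tokenize('System prompt plus chat history...', true);
const { nPast } = await wllama.decode(promptTokens, { skipLogits: true });
console.log('positions in the KV cache after decoding:', nPast);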
@@ -911,5 +923,31 @@ export class Wllama {
return await this.proxy.wllamaDebug();
}


///// Prompt cache utils /////
private async getCachedToken(): Promise<number[]> {
this.checkModelLoaded();
const result = await this.proxy.wllamaAction('current_status', {});
return result.tokens;
}

/**
* Compare the input sequence with the cached tokens, then return the part that is not in the cache.
* This function also removes the mismatched part from the cache (via kvRemove).
*/
private async getNonCachedTokens(seq: number[]): Promise<number[]> {
const cachedTokens = await this.getCachedToken();
let nKeep = 0;
for (; nKeep < Math.min(cachedTokens.length, seq.length); nKeep++) {
if (cachedTokens[nKeep] !== seq[nKeep]) {
break;
}
}
const nDiscard = cachedTokens.length - nKeep;
this.logger().debug(`Cache nKeep=${nKeep} nDiscard=${nDiscard}`);
await this.kvRemove(nKeep, nDiscard);
return seq.slice(nKeep, seq.length);
}

// TODO: add current_status
}
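
To make the prefix-matching rule in getNonCachedTokens concrete, here is a standalone sketch (not part of the commit) with made-up token IDs; splitCachedPrefix is a hypothetical helper mirroring the same loop:

// Standalone sketch of the prefix-matching rule used above (hypothetical helper).
function splitCachedPrefix(cachedTokens: number[], seq: number[]) {
  let nKeep = 0;
  while (
    nKeep < Math.min(cachedTokens.length, seq.length) &&
    cachedTokens[nKeep] === seq[nKeep]
  ) {
    nKeep++;
  }
  return {
    nKeep,                                 // prefix kept in the KV cache
    nDiscard: cachedTokens.length - nKeep, // stale cache entries removed via kvRemove()
    toDecode: seq.slice(nKeep),            // suffix that still has to be decoded
  };
}

// cached = [1, 2, 3, 4, 5], new prompt = [1, 2, 3, 9, 10]
// => nKeep = 3, nDiscard = 2, toDecode = [9, 10]
console.log(splitCachedPrefix([1, 2, 3, 4, 5], [1, 2, 3, 9, 10]));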

1 comment on commit 3897d6c

@flatsiedatsie
Contributor

Could you explain a little what this does? It seems to me that I should just keep this cache enabled all the time? Are there any downsides or other 'gotchas'?
