Skip to content

Commit

Permalink
fix: change of prompt api (#21)
Browse files Browse the repository at this point in the history
* fix: change of prompt api

* test: coverage
  • Loading branch information
jeasonstudio authored Aug 8, 2024
1 parent e831ec2 commit a19e676
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 41 deletions.
10 changes: 8 additions & 2 deletions src/global.d.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
export type ChromeAISessionAvailable = 'no' | 'after-download' | 'readily';

export interface ChromeAISessionOptions {
export interface ChromeAIModelInfo {
defaultTemperature: number;
defaultTopK: number;
maxTopK: number;
}

export interface ChromeAISessionOptions extends Record<string, any> {
temperature?: number;
topK?: number;
}
Expand All @@ -13,7 +19,7 @@ export interface ChromeAISession {

export interface ChromePromptAPI {
canCreateTextSession: () => Promise<ChromeAISessionAvailable>;
defaultTextSessionOptions: () => Promise<ChromeAISessionOptions>;
textModelInfo: () => Promise<ChromeAIModelInfo>;
createTextSession: (
options?: ChromeAISessionOptions
) => Promise<ChromeAISession>;
Expand Down
13 changes: 8 additions & 5 deletions src/language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,15 @@ describe('language-model', () => {

it('should do generate text', async () => {
const canCreateSession = vi.fn(async () => 'readily');
const getOptions = vi.fn(async () => ({ temperature: 1, topK: 10 }));
const getOptions = vi.fn(async () => ({
defaultTemperature: 1,
defaultTopK: 10,
}));
const prompt = vi.fn(async (prompt: string) => prompt);
const createSession = vi.fn(async () => ({ prompt }));
vi.stubGlobal('ai', {
canCreateTextSession: canCreateSession,
defaultTextSessionOptions: getOptions,
textModelInfo: getOptions,
createTextSession: createSession,
});

Expand Down Expand Up @@ -104,7 +107,7 @@ describe('language-model', () => {
});
vi.stubGlobal('ai', {
canCreateTextSession: vi.fn(async () => 'readily'),
defaultTextSessionOptions: vi.fn(async () => ({})),
textModelInfo: vi.fn(async () => ({})),
createTextSession: vi.fn(async () => ({ promptStreaming })),
});

Expand All @@ -121,7 +124,7 @@ describe('language-model', () => {
const prompt = vi.fn(async (prompt: string) => '{"hello":"world"}');
vi.stubGlobal('ai', {
canCreateTextSession: vi.fn(async () => 'readily'),
defaultTextSessionOptions: vi.fn(async () => ({})),
textModelInfo: vi.fn(async () => ({})),
createTextSession: vi.fn(async () => ({ prompt })),
});

Expand All @@ -140,7 +143,7 @@ describe('language-model', () => {
const prompt = vi.fn(async (prompt: string) => prompt);
vi.stubGlobal('ai', {
canCreateTextSession: vi.fn(async () => 'readily'),
defaultTextSessionOptions: vi.fn(async () => ({})),
textModelInfo: vi.fn(async () => ({})),
createTextSession: vi.fn(async () => ({ prompt })),
});
await expect(() =>
Expand Down
35 changes: 7 additions & 28 deletions src/language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,7 @@ export type ChromeAIChatModelId = 'text';

export interface ChromeAIChatSettings extends Record<string, unknown> {
temperature?: number;
/**
* Optional. The maximum number of tokens to consider when sampling.
*
* Models use nucleus sampling or combined Top-k and nucleus sampling.
* Top-k sampling considers the set of topK most probable tokens.
* Models running with nucleus sampling don't allow topK setting.
*/
topK?: number;

/**
* Optional. A list of unique safety settings for blocking unsafe content.
* @note this is not working yet
*/
safetySettings?: Array<{
category:
| 'HARM_CATEGORY_HATE_SPEECH'
| 'HARM_CATEGORY_DANGEROUS_CONTENT'
| 'HARM_CATEGORY_HARASSMENT'
| 'HARM_CATEGORY_SEXUALLY_EXPLICIT';

threshold:
| 'HARM_BLOCK_THRESHOLD_UNSPECIFIED'
| 'BLOCK_LOW_AND_ABOVE'
| 'BLOCK_MEDIUM_AND_ABOVE'
| 'BLOCK_ONLY_HIGH'
| 'BLOCK_NONE';
}>;
}

function getStringContent(
Expand Down Expand Up @@ -105,8 +79,13 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
throw new LoadSettingError({ message: 'Built-in model not ready' });
}

const defaultOptions = await ai.defaultTextSessionOptions();
this.options = { ...defaultOptions, ...this.options, ...options };
const defaultOptions = await ai.textModelInfo();
this.options = {
temperature: defaultOptions.defaultTemperature,
topK: defaultOptions.defaultTopK,
...this.options,
...options,
};

this.session = await ai.createTextSession(this.options);

Expand Down
14 changes: 8 additions & 6 deletions src/polyfill/session.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { LlmInference, ProgressListener } from '@mediapipe/tasks-genai';
import {
ChromeAIModelInfo,
ChromeAISession,
ChromeAISessionAvailable,
ChromeAISessionOptions,
Expand Down Expand Up @@ -87,16 +88,17 @@ export class PolyfillChromeAI implements ChromePromptAPI {
return isModelAssetBufferReady ? 'readily' : 'after-download';
};

public defaultTextSessionOptions =
async (): Promise<ChromeAISessionOptions> => ({
temperature: 0.8,
topK: 3,
});
public textModelInfo = async (): Promise<ChromeAIModelInfo> => ({
defaultTemperature: 0.8,
defaultTopK: 3,
maxTopK: 128,
});

public createTextSession = async (
options?: ChromeAISessionOptions
): Promise<ChromeAISession> => {
const argv = options ?? (await this.defaultTextSessionOptions());
const defaultParams = await this.textModelInfo();
const argv = options ?? { temperature: 0.8, topK: 3 };
const llm = await LlmInference.createFromOptions(
{
wasmLoaderPath: this.aiOptions.wasmLoaderPath!,
Expand Down

0 comments on commit a19e676

Please sign in to comment.