From a19e676df8d54760f39edfb4d73f67d4653dd968 Mon Sep 17 00:00:00 2001 From: Jeason Date: Thu, 8 Aug 2024 13:58:42 +0800 Subject: [PATCH] fix: change of prompt api (#21) * fix: change of prompt api * test: coverage --- src/global.d.ts | 10 ++++++++-- src/language-model.test.ts | 13 ++++++++----- src/language-model.ts | 35 +++++++---------------------------- src/polyfill/session.ts | 14 ++++++++------ 4 files changed, 31 insertions(+), 41 deletions(-) diff --git a/src/global.d.ts b/src/global.d.ts index 6a84c80..33ab90a 100644 --- a/src/global.d.ts +++ b/src/global.d.ts @@ -1,6 +1,12 @@ export type ChromeAISessionAvailable = 'no' | 'after-download' | 'readily'; -export interface ChromeAISessionOptions { +export interface ChromeAIModelInfo { + defaultTemperature: number; + defaultTopK: number; + maxTopK: number; +} + +export interface ChromeAISessionOptions extends Record { temperature?: number; topK?: number; } @@ -13,7 +19,7 @@ export interface ChromeAISession { export interface ChromePromptAPI { canCreateTextSession: () => Promise; - defaultTextSessionOptions: () => Promise; + textModelInfo: () => Promise; createTextSession: ( options?: ChromeAISessionOptions ) => Promise; diff --git a/src/language-model.test.ts b/src/language-model.test.ts index 7d5c1f4..2a3bf37 100644 --- a/src/language-model.test.ts +++ b/src/language-model.test.ts @@ -56,12 +56,15 @@ describe('language-model', () => { it('should do generate text', async () => { const canCreateSession = vi.fn(async () => 'readily'); - const getOptions = vi.fn(async () => ({ temperature: 1, topK: 10 })); + const getOptions = vi.fn(async () => ({ + defaultTemperature: 1, + defaultTopK: 10, + })); const prompt = vi.fn(async (prompt: string) => prompt); const createSession = vi.fn(async () => ({ prompt })); vi.stubGlobal('ai', { canCreateTextSession: canCreateSession, - defaultTextSessionOptions: getOptions, + textModelInfo: getOptions, createTextSession: createSession, }); @@ -104,7 +107,7 @@ describe('language-model', () => { }); vi.stubGlobal('ai', { canCreateTextSession: vi.fn(async () => 'readily'), - defaultTextSessionOptions: vi.fn(async () => ({})), + textModelInfo: vi.fn(async () => ({})), createTextSession: vi.fn(async () => ({ promptStreaming })), }); @@ -121,7 +124,7 @@ describe('language-model', () => { const prompt = vi.fn(async (prompt: string) => '{"hello":"world"}'); vi.stubGlobal('ai', { canCreateTextSession: vi.fn(async () => 'readily'), - defaultTextSessionOptions: vi.fn(async () => ({})), + textModelInfo: vi.fn(async () => ({})), createTextSession: vi.fn(async () => ({ prompt })), }); @@ -140,7 +143,7 @@ describe('language-model', () => { const prompt = vi.fn(async (prompt: string) => prompt); vi.stubGlobal('ai', { canCreateTextSession: vi.fn(async () => 'readily'), - defaultTextSessionOptions: vi.fn(async () => ({})), + textModelInfo: vi.fn(async () => ({})), createTextSession: vi.fn(async () => ({ prompt })), }); await expect(() => diff --git a/src/language-model.ts b/src/language-model.ts index a768edc..dc94eeb 100644 --- a/src/language-model.ts +++ b/src/language-model.ts @@ -24,33 +24,7 @@ export type ChromeAIChatModelId = 'text'; export interface ChromeAIChatSettings extends Record { temperature?: number; - /** - * Optional. The maximum number of tokens to consider when sampling. - * - * Models use nucleus sampling or combined Top-k and nucleus sampling. - * Top-k sampling considers the set of topK most probable tokens. - * Models running with nucleus sampling don't allow topK setting. - */ topK?: number; - - /** - * Optional. A list of unique safety settings for blocking unsafe content. - * @note this is not working yet - */ - safetySettings?: Array<{ - category: - | 'HARM_CATEGORY_HATE_SPEECH' - | 'HARM_CATEGORY_DANGEROUS_CONTENT' - | 'HARM_CATEGORY_HARASSMENT' - | 'HARM_CATEGORY_SEXUALLY_EXPLICIT'; - - threshold: - | 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' - | 'BLOCK_LOW_AND_ABOVE' - | 'BLOCK_MEDIUM_AND_ABOVE' - | 'BLOCK_ONLY_HIGH' - | 'BLOCK_NONE'; - }>; } function getStringContent( @@ -105,8 +79,13 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 { throw new LoadSettingError({ message: 'Built-in model not ready' }); } - const defaultOptions = await ai.defaultTextSessionOptions(); - this.options = { ...defaultOptions, ...this.options, ...options }; + const defaultOptions = await ai.textModelInfo(); + this.options = { + temperature: defaultOptions.defaultTemperature, + topK: defaultOptions.defaultTopK, + ...this.options, + ...options, + }; this.session = await ai.createTextSession(this.options); diff --git a/src/polyfill/session.ts b/src/polyfill/session.ts index 42511e8..ce8bedb 100644 --- a/src/polyfill/session.ts +++ b/src/polyfill/session.ts @@ -1,5 +1,6 @@ import { LlmInference, ProgressListener } from '@mediapipe/tasks-genai'; import { + ChromeAIModelInfo, ChromeAISession, ChromeAISessionAvailable, ChromeAISessionOptions, @@ -87,16 +88,17 @@ export class PolyfillChromeAI implements ChromePromptAPI { return isModelAssetBufferReady ? 'readily' : 'after-download'; }; - public defaultTextSessionOptions = - async (): Promise => ({ - temperature: 0.8, - topK: 3, - }); + public textModelInfo = async (): Promise => ({ + defaultTemperature: 0.8, + defaultTopK: 3, + maxTopK: 128, + }); public createTextSession = async ( options?: ChromeAISessionOptions ): Promise => { - const argv = options ?? (await this.defaultTextSessionOptions()); + const defaultParams = await this.textModelInfo(); + const argv = options ?? { temperature: 0.8, topK: 3 }; const llm = await LlmInference.createFromOptions( { wasmLoaderPath: this.aiOptions.wasmLoaderPath!,