Skip to content

Commit

Permalink
chore: adapte to chrome's prompt api changes (#16)
Browse files Browse the repository at this point in the history
* chore: adapte to chrome's prompt api changes

* chore: add changeset
  • Loading branch information
jeasonstudio authored Jul 15, 2024
1 parent 7670413 commit 773d9d5
Show file tree
Hide file tree
Showing 10 changed files with 40 additions and 61 deletions.
5 changes: 5 additions & 0 deletions .changeset/stupid-ghosts-smash.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"chrome-ai": minor
---

chore: adapte to chrome's prompt api changes
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ It automatically selects the correct model id. You can also pass additional sett
```ts
import { chromeai } from 'chrome-ai';

const model = chromeai('generic', {
const model = chromeai('text', {
// additional settings
temperature: 0.5,
topK: 5,
Expand All @@ -77,7 +77,7 @@ const model = chromeai('generic', {

You can use the following optional settings to customize:

- **modelId** `'text' | 'generic'` (default: `'generic'`)
- **modelId** `'text' (default: `'text'`)
- **temperature** `number` (default: `0.8`)
- **topK** `number` (default: `3`)

Expand Down
2 changes: 1 addition & 1 deletion app/components/outputs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export const Outputs = React.forwardRef<

setIsEnabledFlags(!!('ai' in globalThis));

globalThis.ai?.canCreateGenericSession().then((status) => {
globalThis.ai?.canCreateTextSession().then((status) => {
setIsEnabledFlags(status === 'readily');
});
}, []);
Expand Down
4 changes: 2 additions & 2 deletions app/components/settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import { Footer } from '../components/footer';
import { cn } from '../utils';

const formSchema = z.object({
model: z.enum(['generic', 'text']),
model: z.enum(['text']),
temperature: z.number().min(0).max(1),
topK: z.number().min(1),
role: z.enum(['system', 'user', 'assistant']),
Expand All @@ -52,7 +52,7 @@ export const useSettingsForm = (
resolver: zodResolver(formSchema),
defaultValues: Object.assign(
{
model: 'generic',
model: 'text',
temperature: 0.8,
topK: 3,
role: 'system',
Expand Down
2 changes: 1 addition & 1 deletion src/chromeai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { chromeai } from './chromeai';

describe('chromeai', () => {
it('should correctly create instance', async () => {
expect(chromeai().modelId).toBe('generic');
expect(chromeai().modelId).toBe('text');
expect(chromeai('text').modelId).toBe('text');
expect(chromeai('embedding').modelId).toBe('embedding');
expect(chromeai.embedding().modelId).toBe('embedding');
Expand Down
4 changes: 2 additions & 2 deletions src/chromeai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const debug = createDebug('chromeai');

/**
* Create a new ChromeAI model/embedding instance.
* @param modelId 'generic' | 'text' | 'embedding'
* @param modelId 'text' | 'embedding'
* @param settings Options for the model
*/
export function chromeai(
Expand All @@ -24,7 +24,7 @@ export function chromeai(
modelId?: 'embedding',
settings?: ChromeAIEmbeddingModelSettings
): ChromeAIEmbeddingModel;
export function chromeai(modelId: string = 'generic', settings: any = {}) {
export function chromeai(modelId: string = 'text', settings: any = {}) {
debug('create instance', modelId, settings);
if (modelId === 'embedding') {
return new ChromeAIEmbeddingModel(settings);
Expand Down
7 changes: 0 additions & 7 deletions src/global.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,11 @@ export interface ChromeAISession {
destroy: () => Promise<void>;
prompt: (prompt: string) => Promise<string>;
promptStreaming: (prompt: string) => ReadableStream<string>;
execute: (prompt: string) => Promise<string>;
executeStreaming: (prompt: string) => ReadableStream<string>;
}

export interface ChromePromptAPI {
canCreateGenericSession: () => Promise<ChromeAISessionAvailable>;
canCreateTextSession: () => Promise<ChromeAISessionAvailable>;
defaultGenericSessionOptions: () => Promise<ChromeAISessionOptions>;
defaultTextSessionOptions: () => Promise<ChromeAISessionOptions>;
createGenericSession: (
options?: ChromeAISessionOptions
) => Promise<ChromeAISession>;
createTextSession: (
options?: ChromeAISessionOptions
) => Promise<ChromeAISession>;
Expand Down
44 changes: 20 additions & 24 deletions src/language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ describe('language-model', () => {
});

it('should instantiation anyways', () => {
expect(new ChromeAIChatLanguageModel('generic')).toBeInstanceOf(
expect(new ChromeAIChatLanguageModel('text')).toBeInstanceOf(
ChromeAIChatLanguageModel
);
expect(new ChromeAIChatLanguageModel('text').modelId).toBe('text');
Expand All @@ -27,15 +27,14 @@ describe('language-model', () => {
it('should throw when not support', async () => {
await expect(() =>
generateText({
model: new ChromeAIChatLanguageModel('generic'),
model: new ChromeAIChatLanguageModel('text'),
prompt: 'empty',
})
).rejects.toThrowError(LoadSettingError);

const cannotCreateSession = vi.fn(async () => 'no');
vi.stubGlobal('ai', {
canCreateTextSession: cannotCreateSession,
canCreateGenericSession: cannotCreateSession,
});

await expect(() =>
Expand All @@ -48,7 +47,7 @@ describe('language-model', () => {

await expect(() =>
generateText({
model: new ChromeAIChatLanguageModel('generic'),
model: new ChromeAIChatLanguageModel('text'),
prompt: 'empty',
})
).rejects.toThrowError(LoadSettingError);
Expand All @@ -61,11 +60,8 @@ describe('language-model', () => {
const prompt = vi.fn(async (prompt: string) => prompt);
const createSession = vi.fn(async () => ({ prompt }));
vi.stubGlobal('ai', {
canCreateGenericSession: canCreateSession,
canCreateTextSession: canCreateSession,
defaultGenericSessionOptions: getOptions,
defaultTextSessionOptions: getOptions,
createGenericSession: createSession,
createTextSession: createSession,
});

Expand All @@ -76,7 +72,7 @@ describe('language-model', () => {
expect(getOptions).toHaveBeenCalledTimes(1);

const result = await generateText({
model: new ChromeAIChatLanguageModel('generic'),
model: new ChromeAIChatLanguageModel('text'),
prompt: 'test',
});
expect(result).toMatchObject({
Expand All @@ -85,7 +81,7 @@ describe('language-model', () => {
});

const resultForMessages = await generateText({
model: new ChromeAIChatLanguageModel('generic'),
model: new ChromeAIChatLanguageModel('text'),
messages: [
{ role: 'user', content: 'test' },
{ role: 'assistant', content: 'assistant' },
Expand All @@ -107,13 +103,13 @@ describe('language-model', () => {
return stream;
});
vi.stubGlobal('ai', {
canCreateGenericSession: vi.fn(async () => 'readily'),
defaultGenericSessionOptions: vi.fn(async () => ({})),
createGenericSession: vi.fn(async () => ({ promptStreaming })),
canCreateTextSession: vi.fn(async () => 'readily'),
defaultTextSessionOptions: vi.fn(async () => ({})),
createTextSession: vi.fn(async () => ({ promptStreaming })),
});

const result = await streamText({
model: new ChromeAIChatLanguageModel('generic'),
model: new ChromeAIChatLanguageModel('text'),
prompt: 'test',
});
for await (const textPart of result.textStream) {
Expand All @@ -124,13 +120,13 @@ describe('language-model', () => {
it('should do generate object', async () => {
const prompt = vi.fn(async (prompt: string) => '{"hello":"world"}');
vi.stubGlobal('ai', {
canCreateGenericSession: vi.fn(async () => 'readily'),
defaultGenericSessionOptions: vi.fn(async () => ({})),
createGenericSession: vi.fn(async () => ({ prompt })),
canCreateTextSession: vi.fn(async () => 'readily'),
defaultTextSessionOptions: vi.fn(async () => ({})),
createTextSession: vi.fn(async () => ({ prompt })),
});

const { object } = await generateObject({
model: new ChromeAIChatLanguageModel('generic'),
model: new ChromeAIChatLanguageModel('text'),
schema: z.object({
hello: z.string(),
}),
Expand All @@ -143,13 +139,13 @@ describe('language-model', () => {
it('should throw when tool call', async () => {
const prompt = vi.fn(async (prompt: string) => prompt);
vi.stubGlobal('ai', {
canCreateGenericSession: vi.fn(async () => 'readily'),
defaultGenericSessionOptions: vi.fn(async () => ({})),
createGenericSession: vi.fn(async () => ({ prompt })),
canCreateTextSession: vi.fn(async () => 'readily'),
defaultTextSessionOptions: vi.fn(async () => ({})),
createTextSession: vi.fn(async () => ({ prompt })),
});
await expect(() =>
generateText({
model: new ChromeAIChatLanguageModel('generic'),
model: new ChromeAIChatLanguageModel('text'),
messages: [
{
role: 'tool',
Expand All @@ -166,7 +162,7 @@ describe('language-model', () => {
})
).rejects.toThrowError(UnsupportedFunctionalityError);

const model = new ChromeAIChatLanguageModel('generic', { temperature: 1 });
const model = new ChromeAIChatLanguageModel('text', { temperature: 1 });
(model as any).session = { prompt };
const result = await generateText({
model,
Expand All @@ -181,7 +177,7 @@ describe('language-model', () => {

await expect(() =>
generateObject({
model: new ChromeAIChatLanguageModel('generic'),
model: new ChromeAIChatLanguageModel('text'),
mode: 'grammar',
schema: z.object({}),
prompt: 'test',
Expand All @@ -190,7 +186,7 @@ describe('language-model', () => {

await expect(() =>
streamObject({
model: new ChromeAIChatLanguageModel('generic'),
model: new ChromeAIChatLanguageModel('text'),
mode: 'grammar',
schema: z.object({}),
prompt: 'test',
Expand Down
25 changes: 6 additions & 19 deletions src/language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,10 @@ import {
import { ChromeAISession, ChromeAISessionOptions } from './global';
import createDebug from 'debug';
import { StreamAI } from './stream-ai';
import {
ChromeAIEmbeddingModel,
ChromeAIEmbeddingModelSettings,
} from './embedding-model';

const debug = createDebug('chromeai');

export type ChromeAIChatModelId = 'text' | 'generic';
export type ChromeAIChatModelId = 'text';

export interface ChromeAIChatSettings extends Record<string, unknown> {
temperature?: number;
Expand Down Expand Up @@ -80,7 +76,7 @@ function getStringContent(
export class ChromeAIChatLanguageModel implements LanguageModelV1 {
readonly specificationVersion = 'v1';
readonly defaultObjectGenerationMode = 'json';
readonly modelId: ChromeAIChatModelId = 'generic';
readonly modelId: ChromeAIChatModelId = 'text';
readonly provider = 'gemini-nano';
options: ChromeAIChatSettings;

Expand All @@ -97,31 +93,22 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
private getSession = async (
options?: ChromeAISessionOptions
): Promise<ChromeAISession> => {
if (!globalThis.ai?.canCreateGenericSession) {
if (!globalThis.ai?.canCreateTextSession) {
throw new LoadSettingError({ message: 'Browser not support' });
}

const available =
this.modelId === 'text'
? await ai.canCreateTextSession()
: await ai.canCreateGenericSession();
const available = await ai.canCreateTextSession();

if (this.session) return this.session;

if (available !== 'readily') {
throw new LoadSettingError({ message: 'Built-in model not ready' });
}

const defaultOptions =
this.modelId === 'text'
? await ai.defaultTextSessionOptions()
: await ai.defaultGenericSessionOptions();
const defaultOptions = await ai.defaultTextSessionOptions();
this.options = { ...defaultOptions, ...this.options, ...options };

this.session =
this.modelId === 'text'
? await ai.createTextSession(this.options)
: await ai.createGenericSession(this.options);
this.session = await ai.createTextSession(this.options);

debug('session created:', this.session, this.options);
return this.session;
Expand Down
4 changes: 1 addition & 3 deletions src/polyfill/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,9 @@ export class PolyfillChromeAI implements ChromePromptAPI {
return session;
};

public canCreateGenericSession = this.canCreateSession;

public canCreateTextSession = this.canCreateSession;
public defaultGenericSessionOptions = this.defaultSessionOptions;
public defaultTextSessionOptions = this.defaultSessionOptions;
public createGenericSession = this.createSession;
public createTextSession = this.createSession;
}

Expand Down

0 comments on commit 773d9d5

Please sign in to comment.