From 353ee6f1864b446ce076755adfb325300c450a55 Mon Sep 17 00:00:00 2001
From: Rafał Rzepecki
Date: Sat, 7 Sep 2024 16:49:32 -0700
Subject: [PATCH] feat: Choose models dynamically when using VSCode LM

---
 .vscode/launch.json                       |  5 +-
 src/extension.ts                          |  2 +-
 src/services/chatCompletion.ts            | 60 +++++++++++++++++------
 src/services/processWatcher.ts            | 12 ++---
 test/unit/services/chatCompletion.test.ts | 54 ++++++++++++++++++--
 5 files changed, 103 insertions(+), 30 deletions(-)
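
Notes (kept between the `---` marker and the first diff, where `git am`
ignores them):

ChatCompletion.refreshModels() ranks everything returned by
vscode.lm.selectChatModels() and takes the first entry as the preferred
model: largest context window first, family name as the tie-breaker. A
minimal sketch of that comparator, assuming only the documented
vscode.LanguageModelChat shape (`byPreference` is an illustrative name, not
an identifier this patch adds):

    import * as vscode from 'vscode';

    // Sort descending by context window; break ties by family name, descending.
    const byPreference = (a: vscode.LanguageModelChat, b: vscode.LanguageModelChat): number =>
      b.maxInputTokens - a.maxInputTokens || b.family.localeCompare(a.family);

Using `||` rather than `+` keeps the two criteria independent: localeCompare()
returns a small signed number that could otherwise cancel out a token-limit
difference of the same magnitude and flip the ranking.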
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 13f18ad7..7561ba80 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -9,8 +9,9 @@
       "name": "Run extension",
       "type": "extensionHost",
       "request": "launch",
-      "args": ["--extensionDevelopmentPath=${workspaceFolder}", "--disable-extensions"],
-
+      "args": [
+        "--extensionDevelopmentPath=${workspaceFolder}"
+      ],
       "env": {
         "APPMAP_TELEMETRY_DEBUG": "1",
         "APPMAP_DEV_EXTENSION": "1"
diff --git a/src/extension.ts b/src/extension.ts
index 5f3e5060..d0fc2591 100644
--- a/src/extension.ts
+++ b/src/extension.ts
@@ -248,7 +248,7 @@ export async function activate(context: vscode.ExtensionContext): Promise {
diff --git a/src/services/chatCompletion.ts b/src/services/chatCompletion.ts
--- a/src/services/chatCompletion.ts
+++ b/src/services/chatCompletion.ts
+  get env(): Record<string, string> {
+    const pref = ChatCompletion.preferredModel;
+    return {
+      OPENAI_API_KEY: this.key,
+      OPENAI_BASE_URL: this.url,
+      APPMAP_NAVIE_TOKEN_LIMIT: String(pref?.maxInputTokens ?? 3926),
+      APPMAP_NAVIE_MODEL: pref?.family ?? 'gpt-4o',
+    };
+  }
+
+  private static models: vscode.LanguageModelChat[] = [];
+
+  static get preferredModel(): vscode.LanguageModelChat | undefined {
+    return ChatCompletion.models[0];
+  }
+
+  static async refreshModels(): Promise<void> {
+    const previousBest = this.preferredModel?.id;
+    ChatCompletion.models = (await vscode.lm.selectChatModels()).sort(
+      (a, b) => b.maxInputTokens - a.maxInputTokens || b.family.localeCompare(a.family)
+    );
+    if (this.preferredModel?.id !== previousBest) this.settingsChanged.fire();
+  }
+
   static get instance(): Promise<ChatCompletion | undefined> {
     if (!instance) return Promise.resolve(undefined);
     return instance;
@@ -123,14 +146,14 @@ export default class ChatCompletion implements Disposable {
       return;
     }
 
-    const model =
-      (await lm.selectChatModels({ family: request.model }))[0] ??
-      (await lm.selectChatModels({ version: request.model }))[0];
+    const modelName = request.model;
+    const model = ChatCompletion.models.find((m) =>
+      [m.id, m.name, m.family, m.version].includes(modelName)
+    );
     if (!model) {
       res.writeHead(404);
-      const availableModels = await lm.selectChatModels();
-      const message = `Model ${request.model} not found. Available models: ${JSON.stringify(
-        availableModels
+      const message = `Model ${modelName} not found. Available models: ${JSON.stringify(
+        ChatCompletion.models
       )}`;
       warn(message);
       res.end(message);
@@ -166,7 +189,6 @@ export default class ChatCompletion implements Disposable {
 
   static initialize(context: ExtensionContext) {
     // TODO: make the messages and handling generic for all LM extensions
-
     const hasLM = 'lm' in vscode && 'selectChatModels' in vscode.lm;
 
     if (ExtensionSettings.useVsCodeLM && checkAvailability())
@@ -184,6 +206,15 @@ export default class ChatCompletion implements Disposable {
       })
     );
 
+    if (hasLM) {
+      void ChatCompletion.refreshModels();
+      vscode.lm.onDidChangeChatModels(
+        () => ChatCompletion.refreshModels(),
+        undefined,
+        context.subscriptions
+      );
+    }
+
     function checkAvailability() {
       if (!hasLM)
         vscode.window.showErrorMessage(
@@ -197,14 +228,11 @@ export default class ChatCompletion implements Disposable {
           )
          .then((selection) => {
            if (selection === 'Install Copilot') {
-              vscode.lm.onDidChangeChatModels(
-                () => {
-                  context.subscriptions.push(new ChatCompletion());
-                  ChatCompletion.settingsChanged.fire();
-                },
-                undefined,
-                context.subscriptions
-              );
+              const odc = vscode.lm.onDidChangeChatModels(() => {
+                context.subscriptions.push(new ChatCompletion());
+                ChatCompletion.settingsChanged.fire();
+                odc.dispose();
+              });
               vscode.commands.executeCommand(
                 'workbench.extensions.installExtension',
                 'github.copilot'
diff --git a/src/services/processWatcher.ts b/src/services/processWatcher.ts
index 14bff11e..5446c6b7 100644
--- a/src/services/processWatcher.ts
+++ b/src/services/processWatcher.ts
@@ -51,21 +51,17 @@ async function accessToken(): Promise<string> {
 export async function loadEnvironment(
   context: vscode.ExtensionContext
 ): Promise<Record<string, string>> {
+  const chat = await ChatCompletion.instance;
+
   const env: Record<string, string> = {
     APPMAP_API_URL: ExtensionSettings.apiUrl,
     APPMAP_API_KEY: await accessToken(),
     ...ExtensionSettings.appMapCommandLineEnvironment,
+    ...chat?.env,
   };
 
   const openAIApiKey = await getOpenAIApiKey(context);
-  const chat = await ChatCompletion.instance;
-  if (chat) {
-    env.OPENAI_API_KEY = chat.key;
-    env.OPENAI_BASE_URL = chat.url;
-    // TODO: set these dynamically based on the available models
-    env.APPMAP_NAVIE_TOKEN_LIMIT = '3925';
-    env.APPMAP_NAVIE_MODEL = 'gpt-4-turbo';
-  } else if (openAIApiKey) {
+  if (!chat && openAIApiKey) {
     if ('AZURE_OPENAI_API_VERSION' in env) env.AZURE_OPENAI_API_KEY = openAIApiKey;
     else env.OPENAI_API_KEY = openAIApiKey;
   }
diff --git a/test/unit/services/chatCompletion.test.ts b/test/unit/services/chatCompletion.test.ts
index 023782cb..16569945 100644
--- a/test/unit/services/chatCompletion.test.ts
+++ b/test/unit/services/chatCompletion.test.ts
@@ -11,7 +11,7 @@ import type {
 } from 'vscode';
 
 import ChatCompletion from '../../../src/services/chatCompletion';
-import { addMockChatModel } from '../mock/vscode/lm';
+import { addMockChatModel, resetModelMocks } from '../mock/vscode/lm';
 import assert from 'node:assert';
 
 const mockModel: LanguageModelChat = {
@@ -27,10 +27,15 @@
   vendor: 'Test Vendor',
 };
 
-addMockChatModel(mockModel);
-
 describe('ChatCompletion', () => {
   let chatCompletion: ChatCompletion;
+
+  beforeEach(async () => {
+    resetModelMocks();
+    addMockChatModel(mockModel);
+    await ChatCompletion.refreshModels();
+  });
+
   before(async () => {
     mockModel.sendRequest = sendRequestEcho;
     chatCompletion = new ChatCompletion(0, 'test-key');
@@ -42,6 +47,49 @@
     sinon.restore();
   });
 
+  it('should return the correct environment variables', () => {
+    const env = chatCompletion.env;
+    expect(env).to.deep.equal({
+      OPENAI_API_KEY: 'test-key',
+      OPENAI_BASE_URL: chatCompletion.url,
+      APPMAP_NAVIE_TOKEN_LIMIT: '325',
+      APPMAP_NAVIE_MODEL: 'test-family',
+    });
+  });
+
+  it('should refresh models and set the preferred model', async () => {
+    resetModelMocks();
+    const mockModel1 = {
+      ...mockModel,
+      id: 'model-1',
+      family: 'family-1',
+      version: 'version-1',
+      maxInputTokens: 100,
+      name: 'Model 1',
+      vendor: 'Vendor 1',
+    };
+    const mockModel2 = {
+      ...mockModel,
+      id: 'model-2',
+      family: 'family-2',
+      version: 'version-2',
+      maxInputTokens: 200,
+      name: 'Model 2',
+      vendor: 'Vendor 2',
+    };
+    addMockChatModel(mockModel1);
+    addMockChatModel(mockModel2);
+
+    await ChatCompletion.refreshModels();
+    expect(ChatCompletion.preferredModel).to.equal(mockModel2);
+  });
+
+  it('should return undefined if no models are available', async () => {
+    resetModelMocks();
+    await ChatCompletion.refreshModels();
+    expect(ChatCompletion.preferredModel).to.be.undefined;
+  });
+
   it('should create a server and listen on a random port', async () => {
     const instance = await ChatCompletion.instance;
     expect(instance).to.equal(chatCompletion);
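
A note on the one-shot subscription in checkAvailability() above: registering
the onDidChangeChatModels listener via context.subscriptions would keep it
firing on every later model change, so the handler instead holds the returned
Disposable and disposes it on first fire. The same pattern, generalized, as a
sketch (it assumes only the standard vscode.Event contract; `once` is an
illustrative helper name, not something this patch defines):

    import * as vscode from 'vscode';

    // Subscribe, then unsubscribe from inside the handler after the first event.
    function once<T>(event: vscode.Event<T>, handler: (e: T) => void): vscode.Disposable {
      const subscription = event((e) => {
        subscription.dispose(); // one-shot: stop listening before handling
        handler(e);
      });
      return subscription;
    }

    // Usage: once(vscode.lm.onDidChangeChatModels, () => { /* runs once */ });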