feat: Choose models dynamically when using VSCode LM
dividedmind authored and dustinbyrne committed Sep 11, 2024
1 parent 3e7c90a commit 353ee6f
Showing 5 changed files with 103 additions and 30 deletions.
.vscode/launch.json: 5 changes (3 additions, 2 deletions)
@@ -9,8 +9,9 @@
"name": "Run extension",
"type": "extensionHost",
"request": "launch",
"args": ["--extensionDevelopmentPath=${workspaceFolder}", "--disable-extensions"],

"args": [
"--extensionDevelopmentPath=${workspaceFolder}"
],
"env": {
"APPMAP_TELEMETRY_DEBUG": "1",
"APPMAP_DEV_EXTENSION": "1"
src/extension.ts: 2 changes (1 addition, 1 deletion)
@@ -248,7 +248,7 @@ export async function activate(context: vscode.ExtensionContext): Promise<AppMap
      await rpcService.restartServer();
      vscode.window.showInformationMessage('Navie restarted successfully.');
    }),
-   ChatCompletion.onSettingsChanged(rpcService.restartServer, rpcService)
+   ChatCompletion.onSettingsChanged(rpcService.scheduleRestart, rpcService)
  );

  const webview = ChatSearchWebview.register(
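
The settings-changed handler now calls `rpcService.scheduleRestart` instead of `restartServer`, so a burst of model-change events can coalesce into a single restart. `RpcProcessService` itself is not part of this diff; the following is only a minimal debounce sketch of what `scheduleRestart` might look like, with the timer field and delay as assumptions:

```ts
// Hypothetical sketch; the real RpcProcessService is not shown in this commit.
class RpcProcessService {
  private restartTimer?: NodeJS.Timeout;

  async restartServer(): Promise<void> {
    // ...stop and relaunch the Navie RPC process with a fresh environment...
  }

  // Debounce: several rapid calls produce a single restart after the delay.
  scheduleRestart(delayMs = 1000): void {
    if (this.restartTimer) clearTimeout(this.restartTimer);
    this.restartTimer = setTimeout(() => {
      this.restartTimer = undefined;
      void this.restartServer();
    }, delayMs);
  }
}
```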
src/services/chatCompletion.ts: 60 changes (44 additions, 16 deletions)
@@ -13,7 +13,6 @@ import vscode, {
  LanguageModelChat,
  LanguageModelChatMessage,
  LanguageModelChatResponse,
-  lm,
} from 'vscode';

import ExtensionSettings from '../configuration/extensionSettings';
@@ -75,6 +74,30 @@ export default class ChatCompletion implements Disposable {
    return `http://localhost:${this.port}/vscode/copilot`;
  }

+  get env(): Record<string, string> {
+    const pref = ChatCompletion.preferredModel;
+    return {
+      OPENAI_API_KEY: this.key,
+      OPENAI_BASE_URL: this.url,
+      APPMAP_NAVIE_TOKEN_LIMIT: String(pref?.maxInputTokens ?? 3926),
+      APPMAP_NAVIE_MODEL: pref?.family ?? 'gpt-4o',
+    };
+  }
+
+  private static models: vscode.LanguageModelChat[] = [];
+
+  static get preferredModel(): vscode.LanguageModelChat | undefined {
+    return ChatCompletion.models[0];
+  }
+
+  static async refreshModels(): Promise<void> {
+    const previousBest = this.preferredModel?.id;
+    ChatCompletion.models = (await vscode.lm.selectChatModels()).sort(
+      (a, b) => b.maxInputTokens - a.maxInputTokens + b.family.localeCompare(a.family)
+    );
+    if (this.preferredModel?.id !== previousBest) this.settingsChanged.fire();
+  }
+
  static get instance(): Promise<ChatCompletion | undefined> {
    if (!instance) return Promise.resolve(undefined);
    return instance;
@@ -123,14 +146,14 @@ export default class ChatCompletion implements Disposable {
        return;
      }

-      const model =
-        (await lm.selectChatModels({ family: request.model }))[0] ??
-        (await lm.selectChatModels({ version: request.model }))[0];
+      const modelName = request.model;
+      const model = ChatCompletion.models.find((m) =>
+        [m.id, m.name, m.family, m.version].includes(modelName)
+      );
      if (!model) {
        res.writeHead(404);
-        const availableModels = await lm.selectChatModels();
-        const message = `Model ${request.model} not found. Available models: ${JSON.stringify(
-          availableModels
+        const message = `Model ${modelName} not found. Available models: ${JSON.stringify(
+          ChatCompletion.models
        )}`;
        warn(message);
        res.end(message);
@@ -166,7 +189,6 @@ export default class ChatCompletion implements Disposable {

  static initialize(context: ExtensionContext) {
    // TODO: make the messages and handling generic for all LM extensions
-
    const hasLM = 'lm' in vscode && 'selectChatModels' in vscode.lm;

    if (ExtensionSettings.useVsCodeLM && checkAvailability())
@@ -184,6 +206,15 @@
      })
    );

+    if (hasLM) {
+      ChatCompletion.refreshModels();
+      vscode.lm.onDidChangeChatModels(
+        ChatCompletion.refreshModels,
+        undefined,
+        context.subscriptions
+      );
+    }
+
    function checkAvailability() {
      if (!hasLM)
        vscode.window.showErrorMessage(
@@ -197,14 +228,11 @@
          )
          .then((selection) => {
            if (selection === 'Install Copilot') {
-              vscode.lm.onDidChangeChatModels(
-                () => {
-                  context.subscriptions.push(new ChatCompletion());
-                  ChatCompletion.settingsChanged.fire();
-                },
-                undefined,
-                context.subscriptions
-              );
+              const odc = vscode.lm.onDidChangeChatModels(() => {
+                context.subscriptions.push(new ChatCompletion());
+                ChatCompletion.settingsChanged.fire();
+                odc.dispose();
+              });
              vscode.commands.executeCommand(
                'workbench.extensions.installExtension',
                'github.copilot'
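
Net effect in this file: `refreshModels` caches a preference-ordered snapshot of whatever chat models VS Code exposes, and the request handler resolves an incoming model name against a model's `id`, `name`, `family`, or `version` instead of querying `lm.selectChatModels` per request. A standalone sketch of the ranking behavior, using hypothetical stand-in objects rather than real `LanguageModelChat` instances:

```ts
// Stand-in model shape; real models come from vscode.lm.selectChatModels().
interface ModelStub {
  id: string;
  family: string;
  maxInputTokens: number;
}

const available: ModelStub[] = [
  { id: 'model-1', family: 'family-1', maxInputTokens: 100 },
  { id: 'model-2', family: 'family-2', maxInputTokens: 200 },
];

// Same comparator as refreshModels: a larger context window sorts first;
// when token limits tie, the family name (in descending order) breaks the tie.
const ranked = [...available].sort(
  (a, b) => b.maxInputTokens - a.maxInputTokens + b.family.localeCompare(a.family)
);

console.log(ranked[0].id); // "model-2", matching the unit test's expectation
```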
src/services/processWatcher.ts: 12 changes (4 additions, 8 deletions)
@@ -51,21 +51,17 @@ async function accessToken(): Promise<string | undefined> {
export async function loadEnvironment(
  context: vscode.ExtensionContext
): Promise<NodeJS.ProcessEnv> {
+  const chat = await ChatCompletion.instance;
+
  const env: Record<string, string | undefined> = {
    APPMAP_API_URL: ExtensionSettings.apiUrl,
    APPMAP_API_KEY: await accessToken(),
    ...ExtensionSettings.appMapCommandLineEnvironment,
+    ...chat?.env,
  };

  const openAIApiKey = await getOpenAIApiKey(context);
-  const chat = await ChatCompletion.instance;
-  if (chat) {
-    env.OPENAI_API_KEY = chat.key;
-    env.OPENAI_BASE_URL = chat.url;
-    // TODO: set these dynamically based on the available models
-    env.APPMAP_NAVIE_TOKEN_LIMIT = '3925';
-    env.APPMAP_NAVIE_MODEL = 'gpt-4-turbo';
-  } else if (openAIApiKey) {
+  if (!chat && openAIApiKey) {
    if ('AZURE_OPENAI_API_VERSION' in env) env.AZURE_OPENAI_API_KEY = openAIApiKey;
    else env.OPENAI_API_KEY = openAIApiKey;
  }
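
With `chat?.env` spread into the environment, the Navie token limit and model name now track the preferred VS Code model instead of the old hard-coded `gpt-4-turbo` values, and the stored OpenAI key only applies when no ChatCompletion proxy exists. Roughly, the two resulting environments look like this (all values below are illustrative, not taken from the diff):

```ts
// Illustrative shapes only; keys come from loadEnvironment and ChatCompletion.env.
// When the VS Code LM proxy is active:
const withProxy: Record<string, string> = {
  OPENAI_API_KEY: '<per-session proxy key>',
  OPENAI_BASE_URL: 'http://localhost:12345/vscode/copilot', // random port
  APPMAP_NAVIE_TOKEN_LIMIT: '128000', // preferredModel.maxInputTokens
  APPMAP_NAVIE_MODEL: 'gpt-4o', // preferredModel.family
};

// When there is no proxy but the user has stored an OpenAI key:
const withDirectKey: Record<string, string> = {
  OPENAI_API_KEY: '<user OpenAI API key>',
};
```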
test/unit/services/chatCompletion.test.ts: 54 changes (51 additions, 3 deletions)
@@ -11,7 +11,7 @@ import type {
} from 'vscode';

import ChatCompletion from '../../../src/services/chatCompletion';
-import { addMockChatModel } from '../mock/vscode/lm';
+import { addMockChatModel, resetModelMocks } from '../mock/vscode/lm';
import assert from 'node:assert';

const mockModel: LanguageModelChat = {
@@ -27,10 +27,15 @@
  vendor: 'Test Vendor',
};

-addMockChatModel(mockModel);
-
describe('ChatCompletion', () => {
  let chatCompletion: ChatCompletion;

+  beforeEach(async () => {
+    resetModelMocks();
+    addMockChatModel(mockModel);
+    await ChatCompletion.refreshModels();
+  });
+
  before(async () => {
    mockModel.sendRequest = sendRequestEcho;
    chatCompletion = new ChatCompletion(0, 'test-key');
@@ -42,6 +47,49 @@
    sinon.restore();
  });

+  it('should return the correct environment variables', () => {
+    const env = chatCompletion.env;
+    expect(env).to.deep.equal({
+      OPENAI_API_KEY: 'test-key',
+      OPENAI_BASE_URL: chatCompletion.url,
+      APPMAP_NAVIE_TOKEN_LIMIT: '325',
+      APPMAP_NAVIE_MODEL: 'test-family',
+    });
+  });
+
+  it('should refresh models and set the preferred model', async () => {
+    resetModelMocks();
+    const mockModel1 = {
+      ...mockModel,
+      id: 'model-1',
+      family: 'family-1',
+      version: 'version-1',
+      maxInputTokens: 100,
+      name: 'Model 1',
+      vendor: 'Vendor 1',
+    };
+    const mockModel2 = {
+      ...mockModel,
+      id: 'model-2',
+      family: 'family-2',
+      version: 'version-2',
+      maxInputTokens: 200,
+      name: 'Model 2',
+      vendor: 'Vendor 2',
+    };
+    addMockChatModel(mockModel1);
+    addMockChatModel(mockModel2);
+
+    await ChatCompletion.refreshModels();
+    expect(ChatCompletion.preferredModel).to.equal(mockModel2);
+  });
+
+  it('should return undefined if no models are available', async () => {
+    resetModelMocks();
+    await ChatCompletion.refreshModels();
+    expect(ChatCompletion.preferredModel).to.be.undefined;
+  });
+
  it('should create a server and listen on a random port', async () => {
    const instance = await ChatCompletion.instance;
    expect(instance).to.equal(chatCompletion);
