fix: Write an OpenAI-compatible error when context overflows #1031

Merged 1 commit on Oct 7, 2024
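For orientation, here is a sketch (not part of the diff) of the OpenAI-compatible error body that the changes below write with a 422 status when the model reports a context overflow; the token figures shown are hypothetical.

// Sketch only, not part of this PR: the shape of the OpenAI-compatible error
// body written on context overflow. The 325/338 token counts are hypothetical.
const exampleContextOverflowError = {
  error: {
    message:
      "This model's maximum context length is 325 tokens. " +
      'However, your messages resulted in 338 tokens.',
    type: 'invalid_request_error',
    param: 'messages',
    code: 'context_length_exceeded',
  },
};

// The handler in the diff sends this as the body of a 422 response, e.g.:
// res.writeHead(422, { 'Content-Type': 'application/json' })
//    .end(JSON.stringify(exampleContextOverflowError));
console.log(JSON.stringify(exampleContextOverflowError));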
src/services/chatCompletion.ts (66 additions, 10 deletions)
@@ -163,20 +163,29 @@ export default class ChatCompletion implements Disposable {
debug(`Processing request: ${JSON.stringify(request)}`);

let result: LanguageModelChatResponse;
const messages = toVSCodeMessages(request.messages);
try {
const cancellation = new CancellationTokenSource();
req.on('close', () => cancellation.cancel());
res.on('close', () => cancellation.cancel());
result = await model.sendRequest(toVSCodeMessages(request.messages), {}, cancellation.token);
result = await model.sendRequest(messages, {}, cancellation.token);
} catch (e) {
res.writeHead(422);
res.end(isNativeError(e) && e.message);
warn(`Error processing request: ${e}`);
return;
}

if ('stream' in request && request.stream) streamChatCompletion(res, model, result);
else sendChatCompletionResponse(res, model, result);
const countTokens = async () => {
const tokenCounts = await Promise.all(
messages.map(({ content }) => model.countTokens(content))
);
return tokenCounts.reduce((sum, c) => sum + c, 0);
};

if ('stream' in request && request.stream)
streamChatCompletion(res, model, result, countTokens);
else sendChatCompletionResponse(res, model, result, countTokens);
}

async dispose(): Promise<void> {
@@ -247,7 +256,8 @@ export default class ChatCompletion implements Disposable {
async function sendChatCompletionResponse(
res: ServerResponse<IncomingMessage>,
model: LanguageModelChat,
result: LanguageModelChatResponse
result: LanguageModelChatResponse,
countTokens: () => Promise<number>
) {
try {
let content = '';
@@ -257,15 +267,59 @@ async function sendChatCompletionResponse(
} catch (e) {
warn(`Error streaming response: ${e}`);
if (isNativeError(e)) warn(e.stack);
res.writeHead(422);
res.end(isNativeError(e) && e.message);
const apiError = await convertToOpenAiApiError(e, model, countTokens);
res.writeHead(422).end(JSON.stringify(apiError));
}
}

interface OpenAiApiError {
error: {
message: string;
type?: string;
param?: string;
code?: string;
};
}

async function convertToOpenAiApiError(
e: unknown,
model: LanguageModelChat,
countTokens: () => Promise<number>
): Promise<OpenAiApiError> {
if (!isNativeError(e)) return { error: { message: String(e), type: 'server_error' } };

switch (e.message) {
case 'Message exceeds token limit.': {
const error = {
message: `This model's maximum context length is ${model.maxInputTokens} tokens.`,
type: 'invalid_request_error',
param: 'messages',
code: 'context_length_exceeded',
};
try {
const tokensUsed = await countTokens();
error.message += ` However, your messages resulted in ${tokensUsed} tokens.`;
} catch (e) {
warn(`Error counting tokens: ${e}`);
}
return { error };
}

default:
return {
error: {
message: e.message,
type: 'server_error',
},
};
}
}

async function streamChatCompletion(
res: ServerResponse,
model: LanguageModelChat,
result: LanguageModelChatResponse
result: LanguageModelChatResponse,
countTokens: () => Promise<number>
) {
try {
const chunk = prepareChatCompletionChunk(model);
@@ -282,10 +336,12 @@ async function streamChatCompletion(
} catch (e) {
warn(`Error streaming response: ${e}`);
if (isNativeError(e)) warn(e.stack);
const apiError = await convertToOpenAiApiError(e, model, countTokens);
if (!res.headersSent) {
res.writeHead(422);
res.end(isNativeError(e) && e.message);
} else res.end(`data: ${JSON.stringify({ error: e })}`);
res.writeHead(422, { 'Content-Type': 'application/json' }).end(JSON.stringify(apiError));
} else {
res.end(`data: ${JSON.stringify(apiError)}`);
}
}
}

test/unit/services/chatCompletion.test.ts (39 additions, 3 deletions)
@@ -18,8 +18,8 @@ const mockModel: LanguageModelChat = {
id: 'test-model',
family: 'test-family',
version: 'test-model',
async countTokens(): Promise<number> {
return 0;
async countTokens(content: string): Promise<number> {
return content.length;
},
maxInputTokens: 325,
name: 'Test Model',
@@ -163,12 +163,13 @@ describe('ChatCompletion', () => {
});

describe('when streaming errors', () => {
let error = new Error('test error');
beforeEach(() => {
mockModel.sendRequest = async () => {
return {
// eslint-disable-next-line require-yield
text: (async function* () {
throw new Error('test error');
throw error;
})(),
};
};
@@ -189,13 +190,48 @@
});

it('reports errors when not streaming', async () => {
const messages = [
{ content: 'Hello', role: 'user' },
{ content: 'How are you?', role: 'assistant' },
{ content: 'I am good, thank you!', role: 'user' },
];
const response = await postAuthorized(chatCompletion.url, {
model: 'test-model',
messages,
});
expect(response.statusCode).to.equal(422);
});

it('reports token usage when exceeded (streaming)', async () => {
error = new Error('Message exceeds token limit.');
const messages = [
{ content: 'Hello', role: 'user' },
{ content: 'How are you?', role: 'assistant' },
{ content: 'I am good, thank you!', role: 'user' },
];
const response = await postAuthorized(chatCompletion.url, {
model: 'test-model',
messages,
stream: true,
});
expect(response.statusCode).to.equal(422);
expect(response.data).to.equal(
'{"error":{"message":"This model\'s maximum context length is 325 tokens. However, your messages resulted in 38 tokens.","type":"invalid_request_error","param":"messages","code":"context_length_exceeded"}}'
);
});

it('reports token usage when exceeded (non-streaming)', async () => {
error = new Error('Message exceeds token limit.');
const messages = [
{ content: 'Hello', role: 'user' },
{ content: 'How are you?', role: 'assistant' },
{ content: 'I am good, thank you!', role: 'user' },
];
const response = await postAuthorized(chatCompletion.url, { model: 'test-model', messages });
expect(response.statusCode).to.equal(422);
expect(response.data).to.equal(
'{"error":{"message":"This model\'s maximum context length is 325 tokens. However, your messages resulted in 38 tokens.","type":"invalid_request_error","param":"messages","code":"context_length_exceeded"}}'
);
});
});
});