Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add tool support for anthropic streaming #14758

Merged
merged 5 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 150 additions & 27 deletions packages/ai-anthropic/src/node/anthropic-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,32 @@ import {
LanguageModelStreamResponsePart,
LanguageModelTextResponse
} from '@theia/ai-core';
import { CancellationToken } from '@theia/core';
import { CancellationToken, isArray } from '@theia/core';
import { Anthropic } from '@anthropic-ai/sdk';
import { MessageParam } from '@anthropic-ai/sdk/resources';

export const AnthropicModelIdentifier = Symbol('AnthropicModelIdentifier');
const DEFAULT_MAX_TOKENS_STREAMING = 4096;
const DEFAULT_MAX_TOKENS_NON_STREAMING = 2048;
const EMPTY_INPUT_SCHEMA = {
type: 'object',
properties: {},
required: []
} as const;

interface ToolCallback {
readonly name: string;
readonly id: string;
readonly index: number;
args: string;
}

/**
* Transforms Theia language model messages to Anthropic API format
* @param messages Array of LanguageModelRequestMessage to transform
* @returns Object containing transformed messages and optional system message
*/
function transformToAnthropicParams(
messages: LanguageModelRequestMessage[]
messages: readonly LanguageModelRequestMessage[]
): { messages: MessageParam[]; systemMessage?: string } {
// Extract the system message (if any), as it is a separate parameter in the Anthropic API.
const systemMessageObj = messages.find(message => message.actor === 'system');
Expand All @@ -49,6 +67,13 @@ function transformToAnthropicParams(
};
}

export const AnthropicModelIdentifier = Symbol('AnthropicModelIdentifier');

/**
* Converts Theia message actor to Anthropic role
* @param message The message to convert
* @returns Anthropic role ('user' or 'assistant')
*/
function toAnthropicRole(message: LanguageModelRequestMessage): 'user' | 'assistant' {
switch (message.actor) {
case 'ai':
Expand All @@ -58,43 +83,71 @@ function toAnthropicRole(message: LanguageModelRequestMessage): 'user' | 'assist
}
}

/**
* Implements the Anthropic language model integration for Theia
*/
export class AnthropicModel implements LanguageModel {

constructor(
public readonly id: string,
public model: string,
public enableStreaming: boolean,
public apiKey: () => string | undefined,
public defaultRequestSettings?: { [key: string]: unknown }
public defaultRequestSettings?: Readonly<Record<string, unknown>>
) { }

protected getSettings(request: LanguageModelRequest): Record<string, unknown> {
const settings = request.settings ? request.settings : this.defaultRequestSettings;
if (!settings) {
return {};
}
return settings;
protected getSettings(request: LanguageModelRequest): Readonly<Record<string, unknown>> {
return request.settings ?? this.defaultRequestSettings ?? {};
}

async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
if (!request.messages?.length) {
throw new Error('Request must contain at least one message');
}

const anthropic = this.initializeAnthropic();
if (this.enableStreaming) {
return this.handleStreamingRequest(anthropic, request, cancellationToken);

try {
if (this.enableStreaming) {
return this.handleStreamingRequest(anthropic, request, cancellationToken);
}
return this.handleNonStreamingRequest(anthropic, request);
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred';
throw new Error(`Anthropic API request failed: ${errorMessage}`);
}
return this.handleNonStreamingRequest(anthropic, request);
}

protected formatToolCallResult(result: unknown): string | Array<{ type: 'text', text: string }> {
if (typeof result === 'object' && result && 'content' in result && Array.isArray(result.content) &&
result.content.every(item => typeof item === 'object' && item && 'type' in item && 'text' in item)) {
return result.content;
}

if (isArray(result)) {
return result.map(r => ({ type: 'text', text: r as string }));
}

if (typeof result === 'object') {
return JSON.stringify(result);
}

return result as string;
}

protected async handleStreamingRequest(
anthropic: Anthropic,
request: LanguageModelRequest,
cancellationToken?: CancellationToken
cancellationToken?: CancellationToken,
toolMessages?: readonly Anthropic.Messages.MessageParam[]
): Promise<LanguageModelStreamResponse> {
const settings = this.getSettings(request);
const { messages, systemMessage } = transformToAnthropicParams(request.messages);

const tools = this.createTools(request);
const params: Anthropic.MessageCreateParams = {
max_tokens: 2048, // Setting max_tokens is mandatory for Anthropic, settings can override this default
messages,
max_tokens: DEFAULT_MAX_TOKENS_STREAMING,
messages: [...messages, ...(toolMessages ?? [])],
tools,
model: this.model,
...(systemMessage && { system: systemMessage }),
...settings
Expand All @@ -105,22 +158,82 @@ export class AnthropicModel implements LanguageModel {
cancellationToken?.onCancellationRequested(() => {
stream.abort();
});
const that = this;

const asyncIterator = {
async *[Symbol.asyncIterator](): AsyncIterator<LanguageModelStreamResponsePart> {

const toolCalls: ToolCallback[] = [];
let toolCall: ToolCallback | undefined;

for await (const event of stream) {
if (event.type === 'content_block_start') {
const contentBlock = event.content_block;

if (contentBlock.type === 'text') {
yield { content: contentBlock.text };
}
if (contentBlock.type === 'tool_use') {
toolCall = { name: contentBlock.name!, args: '', id: contentBlock.id!, index: event.index };
yield { tool_calls: [{ finished: false, id: toolCall.id, function: { name: toolCall.name, arguments: toolCall.args } }] };
}
} else if (event.type === 'content_block_delta') {
const delta = event.delta;

if (delta.type === 'text_delta') {
yield { content: delta.text };
}
if (toolCall && delta.type === 'input_json_delta') {
toolCall.args += delta.partial_json;
yield { tool_calls: [{ function: { arguments: delta.partial_json } }] };
}
} else if (event.type === 'content_block_stop') {
if (toolCall && toolCall.index === event.index) {
toolCalls.push(toolCall);
toolCall = undefined;
}
} else if (event.type === 'message_delta') {
if (event.delta.stop_reason === 'max_tokens') {
if (toolCall) {
yield { tool_calls: [{ finished: true, id: toolCall.id }] };
}
throw new Error(`The response was stopped because it exceeded the max token limit of ${event.usage.output_tokens}.`);
}
}
}
if (toolCalls.length > 0) {
const toolResult = await Promise.all(toolCalls.map(async tc => {
const tool = request.tools?.find(t => t.name === tc.name);
const argsObject = tc.args.length === 0 ? '{}' : tc.args;

return { name: tc.name, result: (await tool?.handler(argsObject)), id: tc.id, arguments: argsObject };

}));
const calls = toolResult.map(tr => ({ finished: true, id: tr.id, result: tr.result as string, function: { name: tr.name, arguments: tr.arguments } }));
yield { tool_calls: calls };

const toolRequestMessage: Anthropic.Messages.MessageParam = {
role: 'assistant',
content: calls.map(call => ({

type: 'tool_use',
id: call.id,
name: call.function.name,
input: JSON.parse(call.function.arguments)
}))
};

const toolResponseMessage: Anthropic.Messages.MessageParam = {
role: 'user',
content: toolResult.map(call => ({
type: 'tool_result',
tool_use_id: call.id!,
content: that.formatToolCallResult(call.result)
}))
};
const result = await that.handleStreamingRequest(anthropic, request, cancellationToken, [...(toolMessages ?? []), toolRequestMessage, toolResponseMessage]);
for await (const nestedEvent of result.stream) {
yield nestedEvent;
}
}
},
Expand All @@ -133,6 +246,14 @@ export class AnthropicModel implements LanguageModel {
return { stream: asyncIterator };
}

private createTools(request: LanguageModelRequest): Anthropic.Messages.Tool[] | undefined {
return request.tools?.map(tool => ({
name: tool.name,
description: tool.description,
input_schema: tool.parameters ?? EMPTY_INPUT_SCHEMA
} as Anthropic.Messages.Tool));
}

protected async handleNonStreamingRequest(
anthropic: Anthropic,
request: LanguageModelRequest
Expand All @@ -142,23 +263,25 @@ export class AnthropicModel implements LanguageModel {
const { messages, systemMessage } = transformToAnthropicParams(request.messages);

const params: Anthropic.MessageCreateParams = {
max_tokens: 2048,
max_tokens: DEFAULT_MAX_TOKENS_NON_STREAMING,
messages,
model: this.model,
...(systemMessage && { system: systemMessage }),
...settings,
};

const response: Anthropic.Message = await anthropic.messages.create(params);
try {
const response = await anthropic.messages.create(params);
const textContent = response.content[0];

if (response.content[0] && response.content[0].type === 'text') {
return {
text: response.content[0].text,
};
if (textContent?.type === 'text') {
return { text: textContent.text };
}

return { text: '' };
} catch (error) {
throw new Error(`Failed to get response from Anthropic API: ${error instanceof Error ? error.message : 'Unknown error'}`);
}
return {
text: '',
};
}

protected initializeAnthropic(): Anthropic {
Expand All @@ -167,6 +290,6 @@ export class AnthropicModel implements LanguageModel {
throw new Error('Please provide ANTHROPIC_API_KEY in preferences or via environment variable');
}

return new Anthropic({ apiKey: apiKey });
return new Anthropic({ apiKey });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ export class ToolCallPartRenderer implements ChatResponsePartRenderer<ToolCallCh
let responseContent = response.result;
try {
if (response.result) {
responseContent = JSON.stringify(JSON.parse(response.result), undefined, 2);
let resultObject = response.result;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do this conversion in the model implementation? Is this a corner case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I found out, MCP tools behave differently from our tools.
Our tools usually return the result as a string. But MCP returns it in the format that Anthropic's API expects, which is an object with a content field that can be a string or an array of objects with content.
So the rendering completely exploded for me.
I can move the logic to the backend and make sure that the result written to the tool result object is a string.

if (typeof resultObject === 'string') {
resultObject = JSON.parse(resultObject);
}
responseContent = JSON.stringify(resultObject, undefined, 2);
}
} catch (e) {
// fall through
Expand Down
Loading