diff --git a/packages/ai-anthropic/src/node/anthropic-language-model.ts b/packages/ai-anthropic/src/node/anthropic-language-model.ts index 36c43095e08d8..4a99d96454b39 100644 --- a/packages/ai-anthropic/src/node/anthropic-language-model.ts +++ b/packages/ai-anthropic/src/node/anthropic-language-model.ts @@ -23,14 +23,32 @@ import { LanguageModelStreamResponsePart, LanguageModelTextResponse } from '@theia/ai-core'; -import { CancellationToken } from '@theia/core'; +import { CancellationToken, isArray } from '@theia/core'; import { Anthropic } from '@anthropic-ai/sdk'; import { MessageParam } from '@anthropic-ai/sdk/resources'; -export const AnthropicModelIdentifier = Symbol('AnthropicModelIdentifier'); +const DEFAULT_MAX_TOKENS_STREAMING = 4096; +const DEFAULT_MAX_TOKENS_NON_STREAMING = 2048; +const EMPTY_INPUT_SCHEMA = { + type: 'object', + properties: {}, + required: [] +} as const; + +interface ToolCallback { + readonly name: string; + readonly id: string; + readonly index: number; + args: string; +} +/** + * Transforms Theia language model messages to Anthropic API format + * @param messages Array of LanguageModelRequestMessage to transform + * @returns Object containing transformed messages and optional system message + */ function transformToAnthropicParams( - messages: LanguageModelRequestMessage[] + messages: readonly LanguageModelRequestMessage[] ): { messages: MessageParam[]; systemMessage?: string } { // Extract the system message (if any), as it is a separate parameter in the Anthropic API. 
const systemMessageObj = messages.find(message => message.actor === 'system'); @@ -49,6 +67,13 @@ function transformToAnthropicParams( }; } +export const AnthropicModelIdentifier = Symbol('AnthropicModelIdentifier'); + +/** + * Converts Theia message actor to Anthropic role + * @param message The message to convert + * @returns Anthropic role ('user' or 'assistant') + */ function toAnthropicRole(message: LanguageModelRequestMessage): 'user' | 'assistant' { switch (message.actor) { case 'ai': @@ -58,6 +83,9 @@ function toAnthropicRole(message: LanguageModelRequestMessage): 'user' | 'assist } } +/** + * Implements the Anthropic language model integration for Theia + */ export class AnthropicModel implements LanguageModel { constructor( @@ -65,36 +93,61 @@ export class AnthropicModel implements LanguageModel { public model: string, public enableStreaming: boolean, public apiKey: () => string | undefined, - public defaultRequestSettings?: { [key: string]: unknown } + public defaultRequestSettings?: Readonly<Record<string, unknown>> ) { } - protected getSettings(request: LanguageModelRequest): Record<string, unknown> { - const settings = request.settings ? request.settings : this.defaultRequestSettings; - if (!settings) { - return {}; - } - return settings; + protected getSettings(request: LanguageModelRequest): Readonly<Record<string, unknown>> { + return request.settings ?? this.defaultRequestSettings ?? {}; } async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> { + if (!request.messages?.length) { + throw new Error('Request must contain at least one message'); + } + const anthropic = this.initializeAnthropic(); - if (this.enableStreaming) { - return this.handleStreamingRequest(anthropic, request, cancellationToken); + + try { + if (this.enableStreaming) { + return this.handleStreamingRequest(anthropic, request, cancellationToken); + } + return this.handleNonStreamingRequest(anthropic, request); + } catch (error) { + const errorMessage = error instanceof Error ?
error.message : 'Unknown error occurred'; + throw new Error(`Anthropic API request failed: ${errorMessage}`); } - return this.handleNonStreamingRequest(anthropic, request); + } + + protected formatToolCallResult(result: unknown): string | Array<{ type: 'text', text: string }> { + if (typeof result === 'object' && result && 'content' in result && Array.isArray(result.content) && + result.content.every(item => typeof item === 'object' && item && 'type' in item && 'text' in item)) { + return result.content; + } + + if (isArray(result)) { + return result.map(r => ({ type: 'text', text: r as string })); + } + + if (typeof result === 'object') { + return JSON.stringify(result); + } + + return result as string; } protected async handleStreamingRequest( anthropic: Anthropic, request: LanguageModelRequest, - cancellationToken?: CancellationToken + cancellationToken?: CancellationToken, + toolMessages?: readonly Anthropic.Messages.MessageParam[] ): Promise<LanguageModelStreamResponse> { const settings = this.getSettings(request); const { messages, systemMessage } = transformToAnthropicParams(request.messages); - + const tools = this.createTools(request); const params: Anthropic.MessageCreateParams = { - max_tokens: 2048, // Setting max_tokens is mandatory for Anthropic, settings can override this default - messages, + max_tokens: DEFAULT_MAX_TOKENS_STREAMING, + messages: [...messages, ...(toolMessages ??
[])], + tools, model: this.model, ...(systemMessage && { system: systemMessage }), ...settings @@ -105,9 +158,14 @@ export class AnthropicModel implements LanguageModel { cancellationToken?.onCancellationRequested(() => { stream.abort(); }); + const that = this; const asyncIterator = { async *[Symbol.asyncIterator](): AsyncIterator<LanguageModelStreamResponsePart> { + + const toolCalls: ToolCallback[] = []; + let toolCall: ToolCallback | undefined; + for await (const event of stream) { if (event.type === 'content_block_start') { const contentBlock = event.content_block; @@ -115,12 +173,71 @@ export class AnthropicModel implements LanguageModel { if (contentBlock.type === 'text') { yield { content: contentBlock.text }; } + if (contentBlock.type === 'tool_use') { + toolCall = { name: contentBlock.name!, args: '', id: contentBlock.id!, index: event.index }; + yield { tool_calls: [{ finished: false, id: toolCall.id, function: { name: toolCall.name, arguments: toolCall.args } }] }; + } } else if (event.type === 'content_block_delta') { const delta = event.delta; if (delta.type === 'text_delta') { yield { content: delta.text }; } + if (toolCall && delta.type === 'input_json_delta') { + toolCall.args += delta.partial_json; + yield { tool_calls: [{ function: { arguments: delta.partial_json } }] }; + } + } else if (event.type === 'content_block_stop') { + if (toolCall && toolCall.index === event.index) { + toolCalls.push(toolCall); + toolCall = undefined; + } + } else if (event.type === 'message_delta') { + if (event.delta.stop_reason === 'max_tokens') { + if (toolCall) { + yield { tool_calls: [{ finished: true, id: toolCall.id }] }; + } + throw new Error(`The response was stopped because it exceeded the max token limit of ${event.usage.output_tokens}.`); + } + } + } + if (toolCalls.length > 0) { + const toolResult = await Promise.all(toolCalls.map(async tc => { + const tool = request.tools?.find(t => t.name === tc.name); + const argsObject = tc.args.length === 0 ?
'{}' : tc.args; + + return { name: tc.name, result: (await tool?.handler(argsObject)), id: tc.id, arguments: argsObject }; + + })); + + const calls = toolResult.map(tr => { + const resultAsString = typeof tr.result === 'string' ? tr.result : JSON.stringify(tr.result); + return { finished: true, id: tr.id, result: resultAsString, function: { name: tr.name, arguments: tr.arguments } }; + }); + yield { tool_calls: calls }; + + const toolRequestMessage: Anthropic.Messages.MessageParam = { + role: 'assistant', + content: toolResult.map(call => ({ + + type: 'tool_use', + id: call.id, + name: call.name, + input: JSON.parse(call.arguments) + })) + }; + + const toolResponseMessage: Anthropic.Messages.MessageParam = { + role: 'user', + content: toolResult.map(call => ({ + type: 'tool_result', + tool_use_id: call.id!, + content: that.formatToolCallResult(call.result) + })) + }; + const result = await that.handleStreamingRequest(anthropic, request, cancellationToken, [...(toolMessages ?? []), toolRequestMessage, toolResponseMessage]); + for await (const nestedEvent of result.stream) { + yield nestedEvent; } } }, @@ -133,6 +250,14 @@ export class AnthropicModel implements LanguageModel { return { stream: asyncIterator }; } + private createTools(request: LanguageModelRequest): Anthropic.Messages.Tool[] | undefined { + return request.tools?.map(tool => ({ + name: tool.name, + description: tool.description, + input_schema: tool.parameters ?? 
EMPTY_INPUT_SCHEMA + } as Anthropic.Messages.Tool)); + } + protected async handleNonStreamingRequest( anthropic: Anthropic, request: LanguageModelRequest @@ -142,23 +267,25 @@ export class AnthropicModel implements LanguageModel { const { messages, systemMessage } = transformToAnthropicParams(request.messages); const params: Anthropic.MessageCreateParams = { - max_tokens: 2048, + max_tokens: DEFAULT_MAX_TOKENS_NON_STREAMING, messages, model: this.model, ...(systemMessage && { system: systemMessage }), ...settings, }; - const response: Anthropic.Message = await anthropic.messages.create(params); + try { + const response = await anthropic.messages.create(params); + const textContent = response.content[0]; - if (response.content[0] && response.content[0].type === 'text') { - return { - text: response.content[0].text, - }; + if (textContent?.type === 'text') { + return { text: textContent.text }; + } + + return { text: '' }; + } catch (error) { + throw new Error(`Failed to get response from Anthropic API: ${error instanceof Error ? error.message : 'Unknown error'}`); } - return { - text: '', - }; } protected initializeAnthropic(): Anthropic { @@ -167,6 +294,6 @@ export class AnthropicModel implements LanguageModel { throw new Error('Please provide ANTHROPIC_API_KEY in preferences or via environment variable'); } - return new Anthropic({ apiKey: apiKey }); + return new Anthropic({ apiKey }); } } diff --git a/packages/ai-chat-ui/src/browser/chat-response-renderer/toolcall-part-renderer.tsx b/packages/ai-chat-ui/src/browser/chat-response-renderer/toolcall-part-renderer.tsx index a4915819bdf59..ffdc84dc40dc5 100644 --- a/packages/ai-chat-ui/src/browser/chat-response-renderer/toolcall-part-renderer.tsx +++ b/packages/ai-chat-ui/src/browser/chat-response-renderer/toolcall-part-renderer.tsx @@ -74,10 +74,16 @@ export class ToolCallPartRenderer implements ChatResponsePartRenderer