Skip to content

Commit

Permalink
improve tool handling
Browse files Browse the repository at this point in the history
  • Loading branch information
eneufeld committed Jan 29, 2025
1 parent e74e749 commit e4ddde2
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 35 deletions.
125 changes: 91 additions & 34 deletions packages/ai-anthropic/src/node/anthropic-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,28 @@ import { CancellationToken, isArray } from '@theia/core';
import { Anthropic } from '@anthropic-ai/sdk';
import { MessageParam } from '@anthropic-ai/sdk/resources';

const emptyInputSchema = {
const DEFAULT_MAX_TOKENS_STREAMING = 4096;
const DEFAULT_MAX_TOKENS_NON_STREAMING = 2048;
const EMPTY_INPUT_SCHEMA = {
type: 'object',
properties: {},
required: []
};
interface ToolCallback { name: string, args: string, id: string, index: number }
} as const;

interface ToolCallback {
readonly name: string;
readonly id: string;
readonly index: number;
args: string;
}

/**
* Transforms Theia language model messages to Anthropic API format
* @param messages Array of LanguageModelRequestMessage to transform
* @returns Object containing transformed messages and optional system message
*/
function transformToAnthropicParams(
messages: LanguageModelRequestMessage[]
messages: readonly LanguageModelRequestMessage[]
): { messages: MessageParam[]; systemMessage?: string } {
// Extract the system message (if any), as it is a separate parameter in the Anthropic API.
const systemMessageObj = messages.find(message => message.actor === 'system');
Expand All @@ -55,6 +68,12 @@ function transformToAnthropicParams(
}

export const AnthropicModelIdentifier = Symbol('AnthropicModelIdentifier');

/**
* Converts Theia message actor to Anthropic role
* @param message The message to convert
* @returns Anthropic role ('user' or 'assistant')
*/
function toAnthropicRole(message: LanguageModelRequestMessage): 'user' | 'assistant' {
switch (message.actor) {
case 'ai':
Expand All @@ -64,43 +83,69 @@ function toAnthropicRole(message: LanguageModelRequestMessage): 'user' | 'assist
}
}

/**
* Implements the Anthropic language model integration for Theia
*/
export class AnthropicModel implements LanguageModel {

constructor(
public readonly id: string,
public model: string,
public enableStreaming: boolean,
public apiKey: () => string | undefined,
public defaultRequestSettings?: { [key: string]: unknown }
public defaultRequestSettings?: Readonly<Record<string, unknown>>
) { }

protected getSettings(request: LanguageModelRequest): Record<string, unknown> {
const settings = request.settings ? request.settings : this.defaultRequestSettings;
if (!settings) {
return {};
}
return settings;
protected getSettings(request: LanguageModelRequest): Readonly<Record<string, unknown>> {
return request.settings ?? this.defaultRequestSettings ?? {};
}

async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
if (!request.messages?.length) {
throw new Error('Request must contain at least one message');
}

const anthropic = this.initializeAnthropic();
if (this.enableStreaming) {
return this.handleStreamingRequest(anthropic, request, cancellationToken);

try {
if (this.enableStreaming) {
return this.handleStreamingRequest(anthropic, request, cancellationToken);
}
return this.handleNonStreamingRequest(anthropic, request);
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred';
throw new Error(`Anthropic API request failed: ${errorMessage}`);
}
return this.handleNonStreamingRequest(anthropic, request);
}

protected formatToolCallResult(result: unknown): string | Array<{ type: 'text', text: string }> {
if (typeof result === 'object' && result && 'content' in result && Array.isArray(result.content) &&
result.content.every(item => typeof item === 'object' && item && 'type' in item && 'text' in item)) {
return result.content;
}

if (isArray(result)) {
return result.map(r => ({ type: 'text', text: r as string }));
}

if (typeof result === 'object') {
return JSON.stringify(result);
}

return result as string;
}

protected async handleStreamingRequest(
anthropic: Anthropic,
request: LanguageModelRequest,
cancellationToken?: CancellationToken,
toolMessages?: Anthropic.Messages.MessageParam[]
toolMessages?: readonly Anthropic.Messages.MessageParam[]
): Promise<LanguageModelStreamResponse> {
const settings = this.getSettings(request);
const { messages, systemMessage } = transformToAnthropicParams(request.messages);
const tools = this.createTools(request);
const params: Anthropic.MessageCreateParams = {
max_tokens: 2048, // Setting max_tokens is mandatory for Anthropic, settings can override this default
max_tokens: DEFAULT_MAX_TOKENS_STREAMING,
messages: [...messages, ...(toolMessages ?? [])],
tools,
model: this.model,
Expand Down Expand Up @@ -130,6 +175,7 @@ export class AnthropicModel implements LanguageModel {
}
if (contentBlock.type === 'tool_use') {
toolCall = { name: contentBlock.name!, args: '', id: contentBlock.id!, index: event.index };
yield { tool_calls: [{ finished: false, id: toolCall.id, function: { name: toolCall.name, arguments: toolCall.args } }] };
}
} else if (event.type === 'content_block_delta') {
const delta = event.delta;
Expand All @@ -139,32 +185,41 @@ export class AnthropicModel implements LanguageModel {
}
if (toolCall && delta.type === 'input_json_delta') {
toolCall.args += delta.partial_json;
yield { tool_calls: [{ function: { arguments: delta.partial_json } }] };
}
} else if (event.type === 'content_block_stop') {
if (toolCall && toolCall.index === event.index) {
toolCalls.push(toolCall);
yield { tool_calls: [{ finished: false, id: toolCall.id, function: { name: toolCall.name, arguments: toolCall.args } }] };
toolCall = undefined;
}
} else if (event.type === 'message_delta') {
if (event.delta.stop_reason === 'max_tokens') {
if (toolCall) {
yield { tool_calls: [{ finished: true, id: toolCall.id }] };
}
throw new Error(`The response was stopped because it exceeded the max token limit of ${event.usage.output_tokens}.`);
}
}
}
if (toolCalls.length > 0) {
const toolResult = await Promise.all(toolCalls.map(async tc => {
const tool = request.tools?.find(t => t.name === tc.name);
return { name: tc.name, result: (await tool?.handler(tc.args)) as string, id: tc.id };
const argsObject = tc.args.length === 0 ? '{}' : tc.args;

return { name: tc.name, result: (await tool?.handler(argsObject)), id: tc.id, arguments: argsObject };

}));
const calls = toolResult.map(tr => ({ finished: true, id: tr.id, result: tr.result as string }));
const calls = toolResult.map(tr => ({ finished: true, id: tr.id, result: tr.result as string, function: { name: tr.name, arguments: tr.arguments } }));
yield { tool_calls: calls };

const toolRequestMessage: Anthropic.Messages.MessageParam = {
role: 'assistant',
content: toolCalls.map(call => ({
content: calls.map(call => ({

type: 'tool_use',
id: call.id,
name: call.name,
input: JSON.parse(call.args)
name: call.function.name,
input: JSON.parse(call.function.arguments)
}))
};

Expand All @@ -173,7 +228,7 @@ export class AnthropicModel implements LanguageModel {
content: toolResult.map(call => ({
type: 'tool_result',
tool_use_id: call.id!,
content: isArray(call.result!) ? call.result.map(r => ({ type: 'text', text: r as string })) : call.result
content: that.formatToolCallResult(call.result)
}))
};
const result = await that.handleStreamingRequest(anthropic, request, cancellationToken, [...(toolMessages ?? []), toolRequestMessage, toolResponseMessage]);
Expand All @@ -195,7 +250,7 @@ export class AnthropicModel implements LanguageModel {
return request.tools?.map(tool => ({
name: tool.name,
description: tool.description,
input_schema: tool.parameters ?? emptyInputSchema
input_schema: tool.parameters ?? EMPTY_INPUT_SCHEMA
} as Anthropic.Messages.Tool));
}

Expand All @@ -208,23 +263,25 @@ export class AnthropicModel implements LanguageModel {
const { messages, systemMessage } = transformToAnthropicParams(request.messages);

const params: Anthropic.MessageCreateParams = {
max_tokens: 2048,
max_tokens: DEFAULT_MAX_TOKENS_NON_STREAMING,
messages,
model: this.model,
...(systemMessage && { system: systemMessage }),
...settings,
};

const response: Anthropic.Message = await anthropic.messages.create(params);
try {
const response = await anthropic.messages.create(params);
const textContent = response.content[0];

if (response.content[0] && response.content[0].type === 'text') {
return {
text: response.content[0].text,
};
if (textContent?.type === 'text') {
return { text: textContent.text };
}

return { text: '' };
} catch (error) {
throw new Error(`Failed to get response from Anthropic API: ${error instanceof Error ? error.message : 'Unknown error'}`);
}
return {
text: '',
};
}

protected initializeAnthropic(): Anthropic {
Expand All @@ -233,6 +290,6 @@ export class AnthropicModel implements LanguageModel {
throw new Error('Please provide ANTHROPIC_API_KEY in preferences or via environment variable');
}

return new Anthropic({ apiKey: apiKey });
return new Anthropic({ apiKey });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ export class ToolCallPartRenderer implements ChatResponsePartRenderer<ToolCallCh
let responseContent = response.result;
try {
if (response.result) {
responseContent = JSON.stringify(JSON.parse(response.result), undefined, 2);
let resultObject = response.result;
if (typeof resultObject === 'string') {
resultObject = JSON.parse(resultObject);
}
responseContent = JSON.stringify(resultObject, undefined, 2);
}
} catch (e) {
// fall through
Expand Down

0 comments on commit e4ddde2

Please sign in to comment.