diff --git a/sidecar/src/agentic/symbol/identifier.rs b/sidecar/src/agentic/symbol/identifier.rs index b38b0754e..ad4f83572 100644 --- a/sidecar/src/agentic/symbol/identifier.rs +++ b/sidecar/src/agentic/symbol/identifier.rs @@ -76,6 +76,12 @@ impl LLMProperties { self.llm = LLMType::GeminiPro; self } + + /// Only allow tool use when we are using anthropic since open-router does not + /// support the str_replace_editor tool natively + pub fn supports_midwit_and_tool_use(&self) -> bool { + self.llm() == &LLMType::ClaudeSonnet && matches!(&self.provider, &LLMProvider::Anthropic) + } } #[derive(Debug, Clone, Eq, PartialEq, std::hash::Hash, serde::Serialize)] diff --git a/sidecar/src/agentic/tool/session/chat.rs b/sidecar/src/agentic/tool/session/chat.rs index 9c893767c..35679ff3f 100644 --- a/sidecar/src/agentic/tool/session/chat.rs +++ b/sidecar/src/agentic/tool/session/chat.rs @@ -153,6 +153,16 @@ impl SessionChatMessage { } } + pub fn insert_tool_return_value(mut self, tool_return: SessionChatToolReturn) -> Self { + self.tool_return.push(tool_return); + self + } + + pub fn insert_tool_use(mut self, tool_use: SessionChatToolUse) -> Self { + self.tool_use.push(tool_use); + self + } + pub fn message(&self) -> &str { &self.message } diff --git a/sidecar/src/agentic/tool/session/service.rs b/sidecar/src/agentic/tool/session/service.rs index 8f6868747..3c1b1c940 100644 --- a/sidecar/src/agentic/tool/session/service.rs +++ b/sidecar/src/agentic/tool/session/service.rs @@ -523,6 +523,7 @@ impl SessionService { tool_box: Arc, llm_broker: Arc, user_context: UserContext, + is_midwit_tool_agent: bool, mut message_properties: SymbolEventMessageProperties, ) -> Result<(), SymbolError> { println!("session_service::tool_use_agentic::start"); @@ -682,6 +683,153 @@ impl SessionService { Ok(()) } + /// Use the tool based json mode over here + async fn tool_use_midwit_json_mode( + &self, + mut session: Session, + session_id: String, + storage_path: String, + user_message: String, + exchange_id: String, + all_files: Vec, + open_files: Vec, + shell: String, + project_labels: Vec, + repo_ref: RepoRef, + root_directory: String, + tool_box: Arc, + llm_broker: Arc, + user_context: UserContext, + is_midwit_tool_agent: bool, + mut message_properties: SymbolEventMessageProperties, + ) -> Result<(), SymbolError> { + // os can be passed over here safely since we can assume the sidecar is running + // close to the vscode server + // we should ideally get this information from the vscode-server side setting + let tool_agent = ToolUseAgent::new( + llm_broker.clone(), + root_directory.to_owned(), + std::env::consts::OS.to_owned(), + shell.to_owned(), + None, + ); + + session = session.human_message_tool_use( + exchange_id.to_owned(), + user_message.to_owned(), + all_files, + open_files, + shell, + user_context, + ); + let _ = self.save_to_storage(&session).await; + + session = session.accept_open_exchanges_if_any(message_properties.clone()); + let mut human_message_ticker = 0; + // now that we have saved it we can start the loop over here and look out for the cancellation + // token which will imply that we should end the current loop + loop { + let _ = self.save_to_storage(&session).await; + let tool_exchange_id = self + .tool_box + .create_new_exchange(session_id.to_owned(), message_properties.clone()) + .await?; + + println!("tool_exchange_id::({:?})", &tool_exchange_id); + + let cancellation_token = tokio_util::sync::CancellationToken::new(); + + message_properties = message_properties + .set_request_id(tool_exchange_id.to_owned()) + .set_cancellation_token(cancellation_token.clone()); + + // track the new exchange over here + self.track_exchange(&session_id, &tool_exchange_id, cancellation_token.clone()) + .await; + + let tool_use_output = session + // the clone here is pretty bad but its the easiest and the sanest + // way to keep things on the happy path + .clone() + .get_tool_to_use( + tool_box.clone(), + tool_exchange_id.to_owned(), + exchange_id.to_owned(), + tool_agent.clone(), + message_properties.clone(), + ) + .await; + + println!("tool_use_output::{:?}", tool_use_output); + + match tool_use_output { + Ok(AgentToolUseOutput::Success((tool_input_partial, new_session))) => { + // update our session + session = new_session; + // store to disk + let _ = self.save_to_storage(&session).await; + let tool_type = tool_input_partial.to_tool_type(); + session = session + .invoke_tool( + tool_type.clone(), + tool_input_partial, + tool_box.clone(), + root_directory.to_owned(), + message_properties.clone(), + ) + .await?; + + let _ = self.save_to_storage(&session).await; + if matches!(tool_type, ToolType::AskFollowupQuestions) + || matches!(tool_type, ToolType::AttemptCompletion) + { + // we break if it is any of these 2 events, since these + // require the user to intervene + println!("session_service::tool_use_agentic::reached_terminating_tool"); + break; + } + } + Ok(AgentToolUseOutput::Cancelled) => { + // if it is cancelled then we should break + break; + } + Ok(AgentToolUseOutput::Failed(failed_to_parse_output)) => { + let human_message = format!( + r#"Your output was incorrect, please give me the output in the correct format: +{}"#, + failed_to_parse_output.to_owned() + ); + human_message_ticker = human_message_ticker + 1; + session = session.human_message( + human_message_ticker.to_string(), + human_message, + UserContext::default(), + vec![], + repo_ref.clone(), + ); + let _ = message_properties + .ui_sender() + .send(UIEventWithID::tool_not_found( + session_id.to_owned(), + tool_exchange_id.to_owned(), + "Failed to get tool output".to_owned(), + )); + } + Err(e) => { + let _ = message_properties + .ui_sender() + .send(UIEventWithID::tool_not_found( + session_id.to_owned(), + tool_exchange_id.to_owned(), + e.to_string(), + )); + Err(e)? + } + } + } + Ok(()) + } + pub async fn code_edit_agentic( &self, session_id: String, diff --git a/sidecar/src/agentic/tool/session/session.rs b/sidecar/src/agentic/tool/session/session.rs index ef8ff2a9f..967dff44b 100644 --- a/sidecar/src/agentic/tool/session/session.rs +++ b/sidecar/src/agentic/tool/session/session.rs @@ -1,7 +1,7 @@ //! We can create a new session over here and its composed of exchanges //! The exchanges can be made by the human or the agent -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use futures::StreamExt; use tokio::io::AsyncWriteExt; @@ -24,7 +24,6 @@ use crate::{ ui_event::UIEventWithID, }, tool::{ - broker::ToolBroker, helpers::diff_recent_changes::DiffFileContent, input::{ToolInput, ToolInputPartial}, lsp::{ @@ -37,6 +36,7 @@ use crate::{ }, r#type::{Tool, ToolType}, repo_map::generator::RepoMapGeneratorRequest, + session::tool_use_agent::{ToolUseAgentInputOnlyTools, ToolUseAgentOutputWithTools}, terminal::terminal::TerminalInput, test_runner::runner::TestRunnerRequest, }, @@ -47,11 +47,22 @@ use crate::{ }; use super::{ - chat::{SessionChatClientRequest, SessionChatMessage, SessionChatMessageImage}, + chat::{ + SessionChatClientRequest, SessionChatMessage, SessionChatMessageImage, + SessionChatToolReturn, SessionChatToolUse, + }, hot_streak::SessionHotStreakRequest, tool_use_agent::{ToolUseAgent, ToolUseAgentInput, ToolUseAgentOutput}, }; +#[derive(Debug)] +struct ToolExecutionOutput { + message: String, + thinking: Option, + expect_correction: bool, + summary: Option, +} + #[derive(Debug)] pub enum AgentToolUseOutput { Success((ToolInputPartial, Session)), @@ -156,6 +167,7 @@ pub struct ExchangeTypeToolOutput { output: String, exchange_id: String, user_context: UserContext, + tool_use_id: String, } impl ExchangeTypeToolOutput { @@ -164,12 +176,14 @@ impl ExchangeTypeToolOutput { output: String, exchange_id: String, user_context: UserContext, + tool_use_id: String, ) -> Self { Self { tool_type, output, exchange_id, user_context, + tool_use_id, } } } @@ -435,6 +449,7 @@ impl Exchange { tool_type: ToolType, output: String, user_context: UserContext, + tool_use_id: String, ) -> Self { Self { exchange_id: exchange_id.to_owned(), @@ -443,6 +458,7 @@ impl Exchange { output, exchange_id.clone(), user_context, + tool_use_id, )), exchange_state: ExchangeState::Running, } @@ -487,7 +503,7 @@ impl Exchange { /// /// We can have consecutive human messages now on every API so this is no /// longer a big worry - async fn to_conversation_message(&self, _tool_broker: Arc) -> SessionChatMessage { + async fn to_conversation_message(&self, is_json_mode: bool) -> SessionChatMessage { match &self.exchange_type { ExchangeType::HumanChat(ref chat_message) => { // TODO(skcd): Figure out caching etc later on @@ -577,29 +593,69 @@ impl Exchange { } } ExchangeReplyAgent::Tool(tool_input) => { - let tool_input_parameters = &tool_input.tool_input_partial; - let thinking = &tool_input.thinking; - SessionChatMessage::assistant( - format!( - r#" -{thinking} - -{}"#, - tool_input_parameters.to_string() - ), - vec![], - ) + if is_json_mode { + let tool_input_parameters = + &tool_input.tool_input_partial.to_json_value(); + match tool_input_parameters { + Some(schema) => { + // TODO(skcd): Figure out if the thinking here is in proper tags + SessionChatMessage::assistant( + tool_input.thinking.to_owned(), + vec![], + ) + .insert_tool_use( + SessionChatToolUse::new( + tool_input + .tool_input_partial + .to_tool_type() + .to_string(), + tool_input.tool_use_id.to_owned(), + schema.clone(), + ), + ) + } + None => SessionChatMessage::assistant( + tool_input.tool_input_partial.to_string(), + vec![], + ), + } + } else { + let tool_input_parameters = &tool_input.tool_input_partial; + let thinking = &tool_input.thinking; + SessionChatMessage::assistant( + format!( + r#" + {thinking} + + {}"#, + tool_input_parameters.to_string() + ), + vec![], + ) + } } } } - ExchangeType::ToolOutput(ref tool_output) => SessionChatMessage::user( - format!( - "Tool Output ({}): {}", - tool_output.tool_type.to_string(), - tool_output.output, - ), - vec![], - ), + ExchangeType::ToolOutput(ref tool_output) => { + if is_json_mode { + SessionChatMessage::user("".to_owned(), vec![]).insert_tool_return_value( + SessionChatToolReturn::new( + tool_output.tool_use_id.to_owned(), + tool_output.tool_type.to_string(), + tool_output.output.to_owned(), + ), + ) + } else { + SessionChatMessage::user( + format!( + "Tool Output ({}): {}", + tool_output.tool_type.to_string(), + tool_output.output, + ), + vec![], + ) + } + } ExchangeType::Plan(ref plan) => { let user_query = &plan.query; SessionChatMessage::user( @@ -822,12 +878,18 @@ impl Session { tool_type: ToolType, output: String, user_context: UserContext, + tool_use_id: String, ) -> Self { self.global_running_user_context = self .global_running_user_context .merge_user_context(user_context.clone()); - let exchange = - Exchange::tool_output(exchange_id.to_owned(), tool_type, output, user_context); + let exchange = Exchange::tool_output( + exchange_id.to_owned(), + tool_type, + output, + user_context, + tool_use_id, + ); self.exchanges.push(exchange); self } @@ -1035,6 +1097,128 @@ impl Session { Ok(self) } + /// TODO(skcd): Figure out how to support this properly + pub async fn get_tool_to_use_json( + mut self, + tool_box: Arc, + excahnge_id: String, + parent_exchange_id: String, + tool_use_agent: ToolUseAgent, + message_properties: SymbolEventMessageProperties, + ) -> Result { + let mut convereted_messages = vec![]; + for previous_message in self.exchanges.iter() { + convereted_messages.push(previous_message.to_conversation_message(true).await); + } + + // grab the terminal output if anything is present and pass it as part of the + // agent input + let pending_spawned_process_output = tool_box + .grab_pending_subprocess_output(message_properties.clone()) + .await?; + + let tool_agent_input = ToolUseAgentInputOnlyTools::new( + convereted_messages, + self.tools + .into_iter() + .filter_map(|tool_type| tool_box.tools().get_tool_json(&tool_type)) + .collect(), + "".to_owned(), + true, + pending_spawned_process_output, + message_properties.clone(), + ); + + // have a retry logic here which tries hard to make sure there are no errors + // when creating the tool which needs to be used + let mut tool_retry_index = 0; + // we can try a max of 3 times before giving up + let max_tool_retry = 3; + + let mut tool_use_output: Result; + loop { + tool_use_output = tool_use_agent + .invoke_json_tool_use(tool_agent_input.clone()) + .await; + if tool_use_output.is_ok() { + // check if the result of running the tool use output is empty + if let Ok(tool_use_output) = tool_use_output.as_ref() { + match tool_use_output { + ToolUseAgentOutputWithTools::Success((tools, _input)) => { + if tools.is_empty() { + println!( + "{}", + format!("inference::enging::retrying_empty_tool_output") + ); + tokio::time::sleep(Duration::from_secs(1)).await; + tool_retry_index = tool_retry_index + 1; + if tool_retry_index >= max_tool_retry { + break; + } + continue; + } + } + _ => {} + } + } + break; + } else { + println!( + "{}", + format!("inference::engine::retrying_tool_call::erroredbefore") + ); + tokio::time::sleep(Duration::from_secs(1)).await; + // just give it a plain retry and call it a day + tool_retry_index = tool_retry_index + 1; + } + if tool_retry_index >= max_tool_retry { + break; + } + } + + // The real problem here is how do we show the feedback to the user over here + // we were previously sending back feedback about the tool parameters which we found + // what should we do now? + // TODO(skcd): Figure out how the data passing here will work + match tool_use_output { + Ok(tool_use_parameters) => match tool_use_parameters { + ToolUseAgentOutputWithTools::Success((tool_input_partial, _thinking)) => { + if tool_input_partial.is_empty() {} + } + ToolUseAgentOutputWithTools::Failure(_thinking) => {} + }, + Err(e) => {} + } + + todo!() + } + + /// Executes the tool and generates the output over here to use from the tool + fn execute_tool_and_generate_observation( + &self, + tool_input_partial: ToolInputPartial, + thinking: String, + tool_box: Arc, + message_properties: SymbolEventMessageProperties, + ) -> Result { + let tool_execution_output = match tool_input_partial { + ToolInputPartial::AskFollowupQuestions(followup_question) => {} + ToolInputPartial::AttemptCompletion(attempt_completion) => {} + ToolInputPartial::CodeEditing(code_editing) => {} + ToolInputPartial::CodeEditorParameters(code_editor_parameters) => {} + ToolInputPartial::LSPDiagnostics(lsp_diagnostics) => {} + ToolInputPartial::ListFiles(list_files) => {} + ToolInputPartial::OpenFile(open_files) => {} + ToolInputPartial::RepoMapGeneration(repo_map_generation) => {} + ToolInputPartial::SearchFileContentWithRegex(search_with_regex) => {} + ToolInputPartial::TerminalCommand(terminal_command) => {} + ToolInputPartial::TestRunner(_) => { + todo!("test runner command is not supported") + } + }; + todo!("figure out how to implement each of these") + } + pub async fn get_tool_to_use( mut self, tool_box: Arc, @@ -1046,11 +1230,7 @@ impl Session { // figure out what to do over here given the state of the session let mut converted_messages = vec![]; for previous_message in self.exchanges.iter() { - converted_messages.push( - previous_message - .to_conversation_message(tool_box.tools().clone()) - .await, - ); + converted_messages.push(previous_message.to_conversation_message(false).await); } // decay the content of the messages depending on the decay condition @@ -1189,11 +1369,7 @@ impl Session { // reply to let mut converted_messages = vec![]; for previous_message in self.exchanges.iter() { - converted_messages.push( - previous_message - .to_conversation_message(tool_box.tools().clone()) - .await, - ); + converted_messages.push(previous_message.to_conversation_message(false).await); } let exchange_id = message_properties.request_id_str().to_owned(); @@ -1428,11 +1604,7 @@ impl Session { // reply to let mut converted_messages = vec![]; for previous_message in self.exchanges.iter() { - converted_messages.push( - previous_message - .to_conversation_message(tool_box.tools().clone()) - .await, - ); + converted_messages.push(previous_message.to_conversation_message(false).await); } let (sender, receiver) = tokio::sync::mpsc::unbounded_channel(); let mut stream_receiver = @@ -1744,11 +1916,7 @@ impl Session { { let mut converted_messages = vec![]; for previous_message in self.exchanges.iter() { - converted_messages.push( - previous_message - .to_conversation_message(tool_box.tools().clone()) - .await, - ); + converted_messages.push(previous_message.to_conversation_message(false).await); } // send a message over that the inference will start in a bit let _ = message_properties @@ -1894,11 +2062,7 @@ impl Session { let mut converted_messages = vec![]; for previous_message in self.exchanges.iter() { - converted_messages.push( - previous_message - .to_conversation_message(tool_box.tools().clone()) - .await, - ); + converted_messages.push(previous_message.to_conversation_message(false).await); } let (diagnostics, mut extra_variables) = tool_box .grab_workspace_diagnostics(message_properties.clone()) @@ -2117,6 +2281,7 @@ impl Session { tool_type.clone(), formatted_output, // truncated UserContext::default(), + exchange_id.to_owned(), ); } ToolInputPartial::AskFollowupQuestions(_followup_question) => { @@ -2225,6 +2390,7 @@ impl Session { diff_changes.l1_changes() ), UserContext::default(), + exchange_id.to_owned(), ); } ToolInputPartial::LSPDiagnostics(diagnostics) => { @@ -2265,6 +2431,7 @@ impl Session { tool_type.clone(), formatted_diagnostics, UserContext::default(), + exchange_id.to_owned(), ); } ToolInputPartial::ListFiles(list_files) => { @@ -2298,6 +2465,7 @@ impl Session { tool_type.clone(), response, UserContext::default(), + exchange_id.to_owned(), ); } ToolInputPartial::OpenFile(open_file) => { @@ -2327,6 +2495,7 @@ impl Session { tool_type.clone(), response, UserContext::default(), + exchange_id.to_owned(), ); } ToolInputPartial::SearchFileContentWithRegex(search_file) => { @@ -2360,6 +2529,7 @@ impl Session { tool_type.clone(), response.to_owned(), UserContext::default(), + exchange_id.to_owned(), ); } ToolInputPartial::TerminalCommand(terminal_command) => { @@ -2389,6 +2559,7 @@ impl Session { tool_type.clone(), output, UserContext::default(), + exchange_id.to_owned(), ); } ToolInputPartial::RepoMapGeneration(repo_map_request) => { @@ -2422,6 +2593,7 @@ impl Session { tool_type.clone(), repo_map_str.to_owned(), UserContext::default(), + exchange_id.to_owned(), ); } ToolInputPartial::CodeEditorParameters(_code_editor_parameters) => { diff --git a/sidecar/src/agentic/tool/session/tool_use_agent.rs b/sidecar/src/agentic/tool/session/tool_use_agent.rs index 8f60de1a9..9f9080c7d 100644 --- a/sidecar/src/agentic/tool/session/tool_use_agent.rs +++ b/sidecar/src/agentic/tool/session/tool_use_agent.rs @@ -45,6 +45,7 @@ pub struct ToolUseAgentInputOnlyTools { tools: Vec, problem_statement: String, is_midwit_mode: bool, + pending_spawned_process_output: Option, symbol_event_message_properties: SymbolEventMessageProperties, } @@ -54,6 +55,7 @@ impl ToolUseAgentInputOnlyTools { tools: Vec, problem_statement: String, is_midwit_mode: bool, + pending_spawned_process_output: Option, symbol_event_message_properties: SymbolEventMessageProperties, ) -> Self { Self { @@ -61,6 +63,7 @@ impl ToolUseAgentInputOnlyTools { tools, problem_statement, is_midwit_mode, + pending_spawned_process_output, symbol_event_message_properties, } } @@ -133,6 +136,55 @@ impl ToolUseAgent { } } + fn system_message_midwit_json_with_notes(&self) -> String { + let working_directory = self.working_directory.to_owned(); + let operating_system = self.operating_system.to_owned(); + let shell = self.shell.to_owned(); + format!( + r#"You are an expert software engineer taked with helping the developer. +You know in detail everything about this repository and all the different code structures which are present in it source code for it. + + +{working_directory} + +I've uploaded a python code repository in the directory {working_directory} (not in /tmp/inputs). + +Can you help me implement the necessary changes to the repository so that the requirements specified by the user are met? +I've also setup the developer environment in {working_directory}. + +Your task is to make the minimal changes to files in the {working_directory} directory to ensure the developer is satisfied. + +Tool capabilities: +- You have access to tools that let you execute CLI commands on the local checkout, list files, view source code definitions, regex search, read and write files. These tools help you effectively accomplish a wide range of tasks, such as writing code, making edits or improvements to existing files, understanding the current state of a project, and much more. +- You can use search_files to perform regex searches across files in a specified directory, outputting context-rich results that include surrounding lines. This is particularly useful for understanding code patterns, finding specific implementations, or identifying areas that need refactoring. +- When using the search_files tool, craft your regex patterns carefully to balance specificity and flexibility. Based on the developer needs you may use it to find code patterns, function definitions, or any text-based information across the project. The results include context, so analyze the surrounding code to better understand the matches. Leverage the search_files tool in combination with other tools for more comprehensive analysis. +- Once a file has been created using `create` on `str_replace_editor` tool, you should not keep creating the same file again and again. Focus on editing the file after it has been created. +- You can run long running terminal commands which can run in the background, we will present you with the updated logs. This can be useful if the user wants you to start a debug server in the terminal and then look at the logs or other long running processes. + +==== + +SYSTEM INFORMATION + +Operating System: {operating_system} +Default Shell: {shell} +Current Working Directory: {working_directory} + +==== + +FOLLOW these steps to resolve the issue: +1. As a first step, it might be a good idea to explore the repo to familiarize yourself with its structure. +2. Open the file called notes.txt where you have previously taken notes about the repository. This will be useful for you to understand what is going on in the repository. You should reuse and make sure the knowledge here is upto date with the repository. Keep making changes to this to keep the notes up to date with your work as well. +3. Edit the sourcecode of the repo to resolve the issue, your job is to make minimal changes. + +Your thinking should be thorough and so it's fine if it's very long. +This is super important and before using any tool you have to output your thinking in section like this:' + +{{your thoughts about using the tool}} + +NEVER forget to include the section before using a tool. We will not be able to invoke the tool properly if you forget it"# + ) + } + fn system_message_midwit_json_mode(&self, repo_name: &str, problem_statement: &str) -> String { let working_directory = self.working_directory.to_owned(); format!( @@ -518,9 +570,221 @@ You accomplish a given task iteratively, breaking it down into clear steps and w ) } + /// Use this when invoking the agent for the normal tool use flow + pub async fn invoke_json_tool_use( + &self, + input: ToolUseAgentInputOnlyTools, + ) -> Result { + let system_message = LLMClientMessage::system(self.system_message_midwit_json_with_notes()) + .insert_tools(input.tools); + + // grab the previous messages as well + let llm_properties = input + .symbol_event_message_properties + .llm_properties() + .clone(); + let mut previous_messages = input + .session_messages + .into_iter() + .map(|session_message| { + let role = session_message.role(); + let tool_use = session_message.tool_use(); + match role { + SessionChatRole::User => { + LLMClientMessage::user(session_message.message().to_owned()) + .with_images( + session_message + .images() + .into_iter() + .map(|session_image| session_image.to_llm_image()) + .collect(), + ) + .insert_tool_return_values( + session_message + .tool_return() + .into_iter() + .map(|tool_return| tool_return.to_llm_tool_return()) + .collect(), + ) + } + SessionChatRole::Assistant => { + LLMClientMessage::assistant(session_message.message().to_owned()) + .insert_tool_use_values( + tool_use + .into_iter() + .map(|tool_use| tool_use.to_llm_tool_use()) + .collect(), + ) + } + } + }) + .collect::>(); + + let mut cache_points_set = 0; + let cache_points_allowed = 3; + previous_messages + .iter_mut() + .rev() + .into_iter() + .for_each(|message| { + if cache_points_set >= cache_points_allowed { + return; + } + if message.is_human_message() { + message.set_cache_point(); + cache_points_set = cache_points_set + 1; + } + }); + + // TODO(skcd): This will not work since we have to grab the pending spawned process output here properly + if previous_messages + .last() + .map(|last_message| last_message.is_human_message()) + .unwrap_or_default() + { + if let Some(pending_spawned_process_output) = input.pending_spawned_process_output { + previous_messages.push(LLMClientMessage::user(format!( + r#" +{} +"#, + pending_spawned_process_output + ))); + } + } + + let root_request_id = input + .symbol_event_message_properties + .root_request_id() + .to_owned(); + let final_messages: Vec<_> = vec![system_message] + .into_iter() + .chain(previous_messages) + .collect::>(); + + let cancellation_token = input.symbol_event_message_properties.cancellation_token(); + + let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); + let cloned_root_request_id = root_request_id.to_owned(); + let response = run_with_cancellation( + cancellation_token.clone(), + tokio::spawn(async move { + if llm_properties.provider().is_anthropic_api_key() { + AnthropicClient::new() + .stream_completion_with_tool( + llm_properties.api_key().clone(), + LLMClientCompletionRequest::new( + llm_properties.llm().clone(), + final_messages, + 0.2, + None, + ), + // llm_properties.provider().clone(), + vec![ + ("event_type".to_owned(), "tool_use".to_owned()), + ("root_id".to_owned(), cloned_root_request_id), + ] + .into_iter() + .collect(), + sender, + ) + .await + } else { + OpenRouterClient::new() + .stream_completion_with_tool( + llm_properties.api_key().clone(), + LLMClientCompletionRequest::new( + llm_properties.llm().clone(), + final_messages, + 0.2, + None, + ), + // llm_properties.provider().clone(), + vec![ + ("event_type".to_owned(), "tool_use".to_owned()), + ("root_id".to_owned(), cloned_root_request_id), + ] + .into_iter() + .collect(), + sender, + ) + .await + } + }), + ) + .await; + + println!("tool_use_agent::invoke_json_tool"); + if let Some(Ok(Ok(response))) = response { + println!("tool_use_agent::invoke_json_tool::reply({:?})", &response); + // we will have a string here representing the thinking and another with the various tool inputs and their json representation + let thinking = response.0; + let tool_inputs = response.1; + let mut tool_inputs_parsed = vec![]; + for (tool_type, tool_input) in tool_inputs.into_iter() { + let tool_use_id = tool_input.0; + let tool_input = tool_input.1; + let tool_input = match tool_type.as_ref() { + "list_files" => ToolInputPartial::ListFiles( + serde_json::from_str::(&tool_input).map_err(|_e| { + SymbolError::ToolError(ToolError::SerdeConversionFailed) + })?, + ), + "search_files" => ToolInputPartial::SearchFileContentWithRegex( + serde_json::from_str::(&tool_input) + .map_err(|_e| { + SymbolError::ToolError(ToolError::SerdeConversionFailed) + })?, + ), + "read_file" => ToolInputPartial::OpenFile( + serde_json::from_str::(&tool_input).map_err( + |_e| SymbolError::ToolError(ToolError::SerdeConversionFailed), + )?, + ), + "execute_command" => ToolInputPartial::TerminalCommand({ + serde_json::from_str::(&tool_input) + .map_err(|_e| SymbolError::ToolError(ToolError::SerdeConversionFailed))? + // well gotta do the hard things sometimes right? + // or the dumb things + .sanitise_for_repro_script() + }), + "attempt_completion" => ToolInputPartial::AttemptCompletion( + serde_json::from_str::(&tool_input) + .map_err(|_e| { + SymbolError::ToolError(ToolError::SerdeConversionFailed) + })?, + ), + "test_runner" => ToolInputPartial::TestRunner( + serde_json::from_str::(&tool_input).map_err( + |_e| SymbolError::ToolError(ToolError::SerdeConversionFailed), + )?, + ), + "str_replace_editor" => ToolInputPartial::CodeEditorParameters( + serde_json::from_str::(&tool_input).map_err(|e| { + println!("str_replace_editor::error::{:?}", e); + SymbolError::ToolError(ToolError::SerdeConversionFailed) + })?, + ), + _ => { + println!("unknow tool found: {}", tool_type); + return Err(SymbolError::WrongToolOutput); + } + }; + tool_inputs_parsed.push((tool_use_id, tool_input)); + } + + Ok(ToolUseAgentOutputWithTools::Success(( + tool_inputs_parsed, + // trim the string properly so we remove all the \n + thinking.trim().to_owned(), + ))) + } else { + Ok(ToolUseAgentOutputWithTools::Failure(None)) + } + } + /// TODO(skcd): This is a special call we are using only for anthropic and nothing /// else right now - pub async fn invoke_json_tool( + pub async fn invoke_json_tool_swe_bench( &self, input: ToolUseAgentInputOnlyTools, ) -> Result { diff --git a/sidecar/src/mcts/execution/inference.rs b/sidecar/src/mcts/execution/inference.rs index 93e08ac4c..b6673fa1f 100644 --- a/sidecar/src/mcts/execution/inference.rs +++ b/sidecar/src/mcts/execution/inference.rs @@ -281,6 +281,7 @@ impl InferenceEngine { .collect(), problem_statement, self.agent_settings.is_midwit(), + None, message_properties.clone(), ); @@ -293,7 +294,7 @@ impl InferenceEngine { let mut tool_use_output: Result; loop { tool_use_output = tool_use_agent - .invoke_json_tool(tool_agent_input.clone()) + .invoke_json_tool_swe_bench(tool_agent_input.clone()) .await; if tool_use_output.is_ok() { // check if the result of running the tool use output is empty diff --git a/sidecar/src/webserver/agentic.rs b/sidecar/src/webserver/agentic.rs index 76d0bea3d..16db0a3cc 100644 --- a/sidecar/src/webserver/agentic.rs +++ b/sidecar/src/webserver/agentic.rs @@ -1440,6 +1440,8 @@ pub async fn agent_tool_use( println!("user_context::({:?})", &user_context); let cancellation_token = tokio_util::sync::CancellationToken::new(); let (sender, receiver) = tokio::sync::mpsc::unbounded_channel(); + // check if model and provider combo supports tool use and midwit agent use + let is_midwit_tool_agent = llm_provider.supports_midwit_and_tool_use(); let message_properties = SymbolEventMessageProperties::new( SymbolEventRequestId::new(exchange_id.to_owned(), session_id.to_string()), sender.clone(), @@ -1472,6 +1474,7 @@ pub async fn agent_tool_use( tool_box, llm_broker, user_context, + is_midwit_tool_agent, message_properties, ) .await;