Skip to content

Commit

Permalink
Merge pull request #1637 from codestoryai/features/cleanup-test-agent…
Browse files Browse the repository at this point in the history
…-code

[sidecar] cleanup test agent code
theskcd authored Dec 16, 2024
2 parents 8c23ada + 08fbdac commit 59f4ea4
Showing 5 changed files with 2 additions and 793 deletions.
193 changes: 1 addition & 192 deletions sidecar/src/agentic/tool/session/service.rs
Original file line number Diff line number Diff line change
@@ -14,11 +14,9 @@ use crate::{
ui_event::UIEventWithID,
},
tool::{
input::ToolInput,
plan::service::PlanService,
r#type::{Tool, ToolType},
r#type::ToolType,
session::{session::AgentToolUseOutput, tool_use_agent::ToolUseAgent},
terminal::terminal::TerminalInput,
},
},
chunking::text_document::Range,
@@ -349,178 +347,6 @@ impl SessionService {
Ok(())
}

pub async fn tool_use_test_generation(
&self,
session_id: String,
storage_path: String,
repo_name: String,
user_message: String,
exchange_id: String,
all_files: Vec<String>,
open_files: Vec<String>,
shell: String,
project_labels: Vec<String>,
repo_ref: RepoRef,
root_directory: String,
tool_box: Arc<ToolBox>,
llm_broker: Arc<LLMBroker>,
mut message_properties: SymbolEventMessageProperties,
) -> Result<TestGenerateCompletion, SymbolError> {
println!("session_service::test::tool_use_agentic_swe_bench::start");
let mut session = if let Ok(session) = self.load_from_storage(storage_path.to_owned()).await
{
println!(
"session_service::test::load_from_storage_ok::session_id({})",
&session_id
);
session
} else {
self.create_new_session_with_tools(
&session_id,
project_labels.to_vec(),
repo_ref.clone(),
storage_path,
vec![
ToolType::ListFiles,
ToolType::SearchFileContentWithRegex,
ToolType::OpenFile,
ToolType::CodeEditing,
ToolType::AttemptCompletion,
ToolType::RepoMapGeneration,
ToolType::TestRunner,
],
UserContext::default(),
)
};

// os can be passed over here safely since we can assume the sidecar is running
// close to the vscode server
// we should ideally get this information from the vscode-server side setting
let tool_agent = ToolUseAgent::new(
llm_broker.clone(),
root_directory,
std::env::consts::OS.to_owned(),
shell.to_owned(),
Some(repo_name),
true, // this makes it a test generation agent
);

session = session.human_message_tool_use(
exchange_id.to_owned(),
user_message.to_owned(),
all_files,
open_files,
shell,
UserContext::default(),
);
println!("session_service::test_agent::save_to_storage");
let _ = self.save_to_storage(&session).await;

session = session.accept_open_exchanges_if_any(message_properties.clone());

let mut human_message_ticker = 0;
// now that we have saved it we can start the loop over here and look out for the cancellation
// token which will imply that we should end the current loop

let mut iteration_count = 0;
const MAX_ITERATIONS: usize = 10; // Prevent infinite loops

loop {
let _ = self.save_to_storage(&session).await;
let tool_exchange_id = self
.tool_box
.create_new_exchange(session_id.to_owned(), message_properties.clone())
.await?;

let cancellation_token = tokio_util::sync::CancellationToken::new();

message_properties = message_properties
.set_request_id(tool_exchange_id.to_owned())
.set_cancellation_token(cancellation_token.clone());

// track the new exchange over here
self.track_exchange(&session_id, &tool_exchange_id, cancellation_token.clone())
.await;

let tool_use_output = session
// the clone here is pretty bad but its the easiest and the sanest
// way to keep things on the happy path
.clone()
.get_tool_to_use(
tool_box.clone(),
tool_exchange_id,
exchange_id.to_owned(),
tool_agent.clone(),
message_properties.clone(),
)
.await?;

match tool_use_output {
AgentToolUseOutput::Success((tool_input_partial, new_session)) => {
session = new_session;
let _ = self.save_to_storage(&session).await;
let tool_type = tool_input_partial.to_tool_type();
let session_output = session
.invoke_tool(
tool_type.clone(),
tool_input_partial,
tool_box.clone(),
false,
tool_agent.clone(),
user_message.to_owned(),
true,
true,
message_properties.clone(),
)
.await;
// return here if the test case is passing
if matches!(session_output, Err(SymbolError::TestCaseIsPassing)) {
println!("session_service::tool_type::test_case_passing");
break;
}

session = session_output?;

let _ = self.save_to_storage(&session).await;
if matches!(tool_type, ToolType::AskFollowupQuestions)
|| matches!(tool_type, ToolType::AttemptCompletion)
{
// we break if it is any of these 2 events, since these
// require the user to intervene
println!("session_service::tool_use_agentic::reached_terminating_tool");
break;
}

iteration_count += 1;
if iteration_count >= MAX_ITERATIONS {
println!("session_service::tool_use_agentic::hit_iteration_limit");
let git_diff = self.get_git_diff(message_properties.editor_url()).await?;
return Ok(TestGenerateCompletion::HitIterationLimit(git_diff));
}
}
AgentToolUseOutput::Cancelled => {}
AgentToolUseOutput::Failed(failed_to_parse_output) => {
let human_message = format!(
r#"Your output was incorrect, please give me the output in the correct format:
{}"#,
failed_to_parse_output
);
human_message_ticker = human_message_ticker + 1;
session = session.human_message(
human_message_ticker.to_string(),
human_message,
UserContext::default(),
vec![],
repo_ref.clone(),
);
}
}
}

let git_diff = self.get_git_diff(message_properties.editor_url()).await?;
Ok(TestGenerateCompletion::LLMChoseToFinish(git_diff))
}

pub async fn tool_use_agentic_swe_bench(
&self,
session_id: String,
@@ -577,7 +403,6 @@ impl SessionService {
std::env::consts::OS.to_owned(),
shell.to_owned(),
Some(repo_name),
false,
);

session = session.human_message_tool_use(
@@ -743,7 +568,6 @@ impl SessionService {
std::env::consts::OS.to_owned(),
shell.to_owned(),
None,
false,
);

session = session.human_message_tool_use(
@@ -1152,21 +976,6 @@ impl SessionService {
.map_err(|e| SymbolError::IOError(e))?;
Ok(())
}

async fn get_git_diff(&self, editor_url: String) -> Result<String, SymbolError> {
let tool_input =
ToolInput::TerminalCommand(TerminalInput::new("git diff".to_owned(), editor_url));
let tool_output = self
.tool_box
.tools()
.invoke(tool_input)
.await
.map_err(|e| SymbolError::ToolError(e))?
.terminal_command()
.ok_or(SymbolError::WrongToolOutput)?;

Ok(tool_output.output().to_owned())
}
}

#[derive(Debug)]
Loading

0 comments on commit 59f4ea4

Please sign in to comment.