From 1fe2c1fe62cd517ec1214e87dbb650ff9d1a21f7 Mon Sep 17 00:00:00 2001 From: skcd Date: Tue, 10 Dec 2024 22:45:07 +0000 Subject: [PATCH] [sidecar] add cache point for open router --- llm_client/src/clients/open_router.rs | 46 +++++++++++++++++++++++++-- sidecar/src/bin/swe_bench_mcts.rs | 2 +- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/llm_client/src/clients/open_router.rs b/llm_client/src/clients/open_router.rs index a0dff12d..6677fab2 100644 --- a/llm_client/src/clients/open_router.rs +++ b/llm_client/src/clients/open_router.rs @@ -12,6 +12,17 @@ use super::types::{ use async_trait::async_trait; use eventsource_stream::Eventsource; +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +enum OpenRouterCacheType { + #[serde(rename = "ephemeral")] + Ephemeral, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct OpenRouterCacheControl { + r#type: OpenRouterCacheType, +} + #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] #[serde(rename = "image_url")] struct OpenRouterImageSource { @@ -29,7 +40,10 @@ pub struct OpenRouterRequestMessageToolCall { #[serde(tag = "type")] enum OpenRouterRequestMessageType { #[serde(rename = "text")] - Text { text: String }, + Text { + text: String, + cache_control: Option, + }, #[serde(rename = "image_url")] Image { image_url: OpenRouterImageSource }, #[serde(rename = "tool_result")] @@ -41,7 +55,10 @@ enum OpenRouterRequestMessageType { impl OpenRouterRequestMessageType { pub fn text(message: String) -> Self { - Self::Text { text: message } + Self::Text { + text: message, + cache_control: None, + } } pub fn tool_return(tool_use_id: String, content: String) -> Self { @@ -63,6 +80,19 @@ impl OpenRouterRequestMessageType { }, } } + + pub fn set_cache_control(mut self) -> Self { + if let Self::Text { + text: _, + ref mut cache_control, + } = self + { + *cache_control = Some(OpenRouterCacheControl { + r#type: OpenRouterCacheType::Ephemeral, + }); + } + self + } } 
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -277,7 +307,17 @@ impl OpenRouterRequest { } else { let content = message.content(); let images = message.images(); - vec![OpenRouterRequestMessageType::text(content.to_owned())] + + // enable the cache point if it is set; open-router requires + // this for anthropic models, we would need to toggle it + // for openai-models later on + let is_cache_enabled = message.is_cache_point(); + let mut content_messaage = + OpenRouterRequestMessageType::text(content.to_owned()); + if is_cache_enabled { + content_messaage = content_messaage.set_cache_control(); + } + vec![content_messaage] .into_iter() .chain(images.into_iter().map(|image| { OpenRouterRequestMessageType::image(image) diff --git a/sidecar/src/bin/swe_bench_mcts.rs b/sidecar/src/bin/swe_bench_mcts.rs index e772759e..bd9a4466 100644 --- a/sidecar/src/bin/swe_bench_mcts.rs +++ b/sidecar/src/bin/swe_bench_mcts.rs @@ -251,7 +251,7 @@ async fn main() -> Result<(), Box> { // Instantiate the mcts tree over here and start the search let mut search_tree = SearchTree::new( expansions, // max_expansions - 20, // max_depth of the tree + 30, // max_depth of the tree 400, // max_iterations Some(5), // max_finished_nodes None, // reward_threshold