Skip to content

Commit

Permalink
Merge pull request #1631 from codestoryai/features/add-cache-point-fo…
Browse files Browse the repository at this point in the history
…r-open-router

[sidecar] add cache point for open router
  • Loading branch information
theskcd authored Dec 10, 2024
2 parents 04f7551 + 1fe2c1f commit 59e5a8d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
46 changes: 43 additions & 3 deletions llm_client/src/clients/open_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ use super::types::{
use async_trait::async_trait;
use eventsource_stream::Eventsource;

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
enum OpenRouterCacheType {
#[serde(rename = "ephemeral")]
Ephemeral,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OpenRouterCacheControl {
r#type: OpenRouterCacheType,
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
#[serde(rename = "image_url")]
struct OpenRouterImageSource {
Expand All @@ -29,7 +40,10 @@ pub struct OpenRouterRequestMessageToolCall {
#[serde(tag = "type")]
enum OpenRouterRequestMessageType {
#[serde(rename = "text")]
Text { text: String },
Text {
text: String,
cache_control: Option<OpenRouterCacheControl>,
},
#[serde(rename = "image_url")]
Image { image_url: OpenRouterImageSource },
#[serde(rename = "tool_result")]
Expand All @@ -41,7 +55,10 @@ enum OpenRouterRequestMessageType {

impl OpenRouterRequestMessageType {
pub fn text(message: String) -> Self {
Self::Text { text: message }
Self::Text {
text: message,
cache_control: None,
}
}

pub fn tool_return(tool_use_id: String, content: String) -> Self {
Expand All @@ -63,6 +80,19 @@ impl OpenRouterRequestMessageType {
},
}
}

pub fn set_cache_control(mut self) -> Self {
if let Self::Text {
text: _,
ref mut cache_control,
} = self
{
*cache_control = Some(OpenRouterCacheControl {
r#type: OpenRouterCacheType::Ephemeral,
});
}
self
}
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -277,7 +307,17 @@ impl OpenRouterRequest {
} else {
let content = message.content();
let images = message.images();
vec![OpenRouterRequestMessageType::text(content.to_owned())]

// enable cache point if its set, open-router requires
// this for anthropic models, we would need to toggle it
// for openai-models later on
let is_cache_enabled = message.is_cache_point();
let mut content_messaage =
OpenRouterRequestMessageType::text(content.to_owned());
if is_cache_enabled {
content_messaage = content_messaage.set_cache_control();
}
vec![content_messaage]
.into_iter()
.chain(images.into_iter().map(|image| {
OpenRouterRequestMessageType::image(image)
Expand Down
2 changes: 1 addition & 1 deletion sidecar/src/bin/swe_bench_mcts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Instantiate the mcts tree over here and start the search
let mut search_tree = SearchTree::new(
expansions, // max_expansions
20, // max_depth of the tree
30, // max_depth of the tree
400, // max_iterations
Some(5), // max_finished_nodes
None, // reward_threshold
Expand Down

0 comments on commit 59e5a8d

Please sign in to comment.