From 7fd5c06cf3c513d2c76d40a4463b29e186303b0b Mon Sep 17 00:00:00 2001
From: skcd
Date: Tue, 5 Dec 2023 21:12:59 +0530
Subject: [PATCH] [sidecar] use user key instead of the openai one if present

---
 sidecar/src/agent/llm_funcs.rs         | 104 +++++++++++++++----------
 sidecar/src/webserver/agent.rs         |  82 +++++++++++++------
 sidecar/src/webserver/file_edit.rs     |  26 +++++--
 sidecar/src/webserver/in_line_agent.rs |  25 ++++--
 4 files changed, 160 insertions(+), 77 deletions(-)

diff --git a/sidecar/src/agent/llm_funcs.rs b/sidecar/src/agent/llm_funcs.rs
index aec6e2551..3687ecff5 100644
--- a/sidecar/src/agent/llm_funcs.rs
+++ b/sidecar/src/agent/llm_funcs.rs
@@ -285,11 +285,11 @@ impl From<&llm::Message> for tiktoken_rs::ChatCompletionRequestMessage {
 }
 
 pub struct LlmClient {
-    gpt4_client: Client<AzureConfig>,
-    gpt432k_client: Client<AzureConfig>,
-    gpt3_5_client: Client<AzureConfig>,
-    gpt3_5_turbo_instruct: Client<OpenAIConfig>,
-    gpt4_turbo_client: Client<AzureConfig>,
+    gpt4_client: ClientEndpoint,
+    gpt432k_client: ClientEndpoint,
+    gpt3_5_client: ClientEndpoint,
+    gpt3_5_turbo_instruct: ClientEndpoint,
+    gpt4_turbo_client: ClientEndpoint,
     posthog_client: Arc<PosthogClient>,
     sql_db: SqlDb,
     user_id: String,
@@ -297,7 +297,7 @@ pub struct LlmClient {
     // understand the kind of llm we are using and take that into account
    // for now, we can keep using the same prompts but the burden of construction
    // will fall on every place which constructs the prompt
-    custom_llm: Option<Client<OpenAIConfig>>,
+    custom_llm: Option<ClientEndpoint>,
     custom_llm_type: LLMType,
 }
 
@@ -308,14 +308,14 @@ enum OpenAIEventType {
 }
 
 #[derive(Debug, Clone)]
-pub enum ClientEndpoint<'a> {
-    OpenAI(&'a Client<OpenAIConfig>),
-    Azure(&'a Client<AzureConfig>),
+pub enum ClientEndpoint {
+    OpenAI(Client<OpenAIConfig>),
+    Azure(Client<AzureConfig>),
 }
 
-impl<'a> ClientEndpoint<'a> {
+impl ClientEndpoint {
     pub async fn create_stream(
-        &'a self,
+        &self,
         request: CreateChatCompletionRequest,
     ) -> Result<ChatCompletionResponseStream, OpenAIError> {
         match self {
@@ -325,7 +325,7 @@ impl<'a> ClientEndpoint<'a> {
     }
 
     pub async fn create_stream_completion(
-        &'a self,
+        &self,
         request: CreateCompletionRequest,
     ) -> Result<CompletionResponseStream, OpenAIError> {
         match self {
@@ -335,7 +335,7 @@ impl<'a> ClientEndpoint<'a> {
     }
 
     pub async fn create_chat_completion(
-        &'a self,
+        &self,
         request: CreateChatCompletionRequest,
     ) -> Result<CreateChatCompletionResponse, OpenAIError> {
         match self {
@@ -344,11 +344,11 @@ impl<'a> ClientEndpoint<'a> {
         }
     }
 
-    fn from_azure_client(azure_client: &'a Client<AzureConfig>) -> Self {
+    fn from_azure_client(azure_client: Client<AzureConfig>) -> Self {
         ClientEndpoint::Azure(azure_client)
     }
 
-    fn from_openai_client(openai_client: &'a Client<OpenAIConfig>) -> Self {
+    fn from_openai_client(openai_client: Client<OpenAIConfig>) -> Self {
         ClientEndpoint::OpenAI(openai_client)
     }
 }
@@ -382,16 +382,16 @@ impl LlmClient {
         let custom_llm = match llm_config.non_openai_endpoint() {
             Some(endpoint) => {
                 let config = OpenAIConfig::new().with_api_base(endpoint);
-                Some(Client::with_config(config))
+                Some(ClientEndpoint::OpenAI(Client::with_config(config)))
             }
             None => None,
         };
         Self {
-            gpt4_client: Client::with_config(gpt4_config),
-            gpt432k_client: Client::with_config(gpt4_32k_config),
-            gpt3_5_client: Client::with_config(gpt3_5_config),
-            gpt3_5_turbo_instruct: Client::with_config(openai_config),
-            gpt4_turbo_client: Client::with_config(gpt4_turbo_128k_config),
+            gpt4_client: ClientEndpoint::Azure(Client::with_config(gpt4_config)),
+            gpt432k_client: ClientEndpoint::Azure(Client::with_config(gpt4_32k_config)),
+            gpt3_5_client: ClientEndpoint::Azure(Client::with_config(gpt3_5_config)),
+            gpt3_5_turbo_instruct: ClientEndpoint::OpenAI(Client::with_config(openai_config)),
+            gpt4_turbo_client: ClientEndpoint::Azure(Client::with_config(gpt4_turbo_128k_config)),
             posthog_client,
             sql_db,
             user_id,
@@ -400,6 +400,36 @@ impl LlmClient {
         }
     }
 
+    pub fn user_key_openai(
+        posthog_client: Arc<PosthogClient>,
+        sql_db: SqlDb,
+        user_id: String,
+        llm_config: LLMCustomConfig,
+        api_key: String,
+    ) -> LlmClient {
+        let openai_config = OpenAIConfig::new().with_api_key(api_key);
+
+        let gpt4_client = ClientEndpoint::OpenAI(Client::with_config(openai_config.clone()));
+        let gpt432k_client = ClientEndpoint::OpenAI(Client::with_config(openai_config.clone()));
+        let gpt3_5_client = ClientEndpoint::OpenAI(Client::with_config(openai_config.clone()));
+        let gpt3_5_turbo_instruct =
+            ClientEndpoint::OpenAI(Client::with_config(openai_config.clone()));
+        let gpt4_turbo_client = ClientEndpoint::OpenAI(Client::with_config(openai_config));
+
+        LlmClient {
+            gpt4_client,
+            gpt432k_client,
+            gpt3_5_client,
+            gpt3_5_turbo_instruct,
+            gpt4_turbo_client,
+            posthog_client,
+            sql_db,
+            user_id,
+            custom_llm: None,
+            custom_llm_type: llm_config.llm.clone(),
+        }
+    }
+
     pub async fn capture_openai_request_response(
         &self,
         request: T,
@@ -968,48 +998,38 @@ impl LlmClient {
         Ok(None)
     }
 
-    fn get_model(&self, model: &llm::OpenAIModel) -> Option<ClientEndpoint> {
+    fn get_model(&self, model: &llm::OpenAIModel) -> Option<&ClientEndpoint> {
         // If the user has provided a model for us we can use that instead of
         // doing anything fancy over here
         if let Some(custom_client) = self.custom_llm.as_ref() {
-            return Some(ClientEndpoint::OpenAI(&custom_client));
+            return Some(custom_client);
         }
         let client = match model {
-            llm::OpenAIModel::GPT4 => ClientEndpoint::from_azure_client(&self.gpt4_client),
-            llm::OpenAIModel::GPT4_32k => ClientEndpoint::from_azure_client(&self.gpt432k_client),
-            llm::OpenAIModel::GPT3_5_16k => ClientEndpoint::from_azure_client(&self.gpt3_5_client),
-            llm::OpenAIModel::GPT4_Turbo => {
-                ClientEndpoint::from_azure_client(&self.gpt4_turbo_client)
-            }
+            llm::OpenAIModel::GPT4 => &self.gpt4_client,
+            llm::OpenAIModel::GPT4_32k => &self.gpt432k_client,
+            llm::OpenAIModel::GPT3_5_16k => &self.gpt3_5_client,
+            llm::OpenAIModel::GPT4_Turbo => &self.gpt4_turbo_client,
             llm::OpenAIModel::GPT3_5Instruct => return None,
             llm::OpenAIModel::OpenHermes2_5Mistral7b => {
-                return self
-                    .custom_llm
-                    .as_ref()
-                    .map(|llm| ClientEndpoint::OpenAI(&llm));
+                return self.custom_llm.as_ref();
             }
             _ => return None,
         };
         Some(client)
     }
 
-    fn get_model_openai(&self, model: &llm::OpenAIModel) -> Option<ClientEndpoint> {
+    fn get_model_openai(&self, model: &llm::OpenAIModel) -> Option<&ClientEndpoint> {
         if let Some(custom_client) = self.custom_llm.as_ref() {
-            return Some(ClientEndpoint::OpenAI(&custom_client));
+            return Some(custom_client);
         }
         let client = match model {
             llm::OpenAIModel::GPT4 => return None,
             llm::OpenAIModel::GPT4_32k => return None,
             llm::OpenAIModel::GPT3_5_16k => return None,
             llm::OpenAIModel::GPT4_Turbo => return None,
-            llm::OpenAIModel::GPT3_5Instruct => {
-                ClientEndpoint::from_openai_client(&self.gpt3_5_turbo_instruct)
-            }
+            llm::OpenAIModel::GPT3_5Instruct => &self.gpt3_5_turbo_instruct,
             llm::OpenAIModel::OpenHermes2_5Mistral7b => {
-                return self
-                    .custom_llm
-                    .as_ref()
-                    .map(|llm| ClientEndpoint::OpenAI(&llm))
+                return self.custom_llm.as_ref();
             }
             _ => return None,
         };
diff --git a/sidecar/src/webserver/agent.rs b/sidecar/src/webserver/agent.rs
index b37bada02..c00db0e47 100644
--- a/sidecar/src/webserver/agent.rs
+++ b/sidecar/src/webserver/agent.rs
@@ -471,6 +471,7 @@ pub struct FollowupChatRequest {
     pub user_context: UserContext,
     pub project_labels: Vec<String>,
     pub active_window_data: Option<ActiveWindowData>,
+    pub openai_key: Option<String>,
 }
 
 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -550,6 +551,7 @@ pub async fn followup_chat(
         user_context,
         project_labels,
         active_window_data,
+        openai_key,
     }): Json<FollowupChatRequest>,
 ) -> Result<impl IntoResponse> {
     let llm_config = app.llm_config.clone();
@@ -604,23 +606,45 @@ pub async fn followup_chat(
     let action = AgentAction::Answer {
         paths: (0..file_path_len).collect(),
     };
-    let agent = Agent::prepare_for_followup(
-        app,
-        repo_ref,
-        session_id,
-        Arc::new(LlmClient::codestory_infra(
-            posthog_client,
-            sql_db.clone(),
-            user_id.to_owned(),
-            llm_config,
-        )),
-        sql_db,
-        previous_messages,
-        sender,
-        user_context,
-        project_labels,
-        Default::default(),
-    );
+
+    let agent = if let Some(openai_user_key) = openai_key {
+        Agent::prepare_for_followup(
+            app,
+            repo_ref,
+            session_id,
+            Arc::new(LlmClient::user_key_openai(
+                posthog_client,
+                sql_db.clone(),
+                user_id.to_owned(),
+                llm_config,
+                openai_user_key,
+            )),
+            sql_db,
+            previous_messages,
+            sender,
+            user_context,
+            project_labels,
+            Default::default(),
+        )
+    } else {
+        Agent::prepare_for_followup(
+            app,
+            repo_ref,
+            session_id,
+            Arc::new(LlmClient::codestory_infra(
+                posthog_client,
+                sql_db.clone(),
+                user_id.to_owned(),
+                llm_config,
+            )),
+            sql_db,
+            previous_messages,
+            sender,
+            user_context,
+            project_labels,
+            Default::default(),
+        )
+    };
 
     generate_agent_stream(agent, action, receiver).await
 }
@@ -631,6 +655,7 @@ pub struct GotoDefinitionSymbolsRequest {
     pub language: String,
     pub repo_ref: RepoRef,
     pub thread_id: uuid::Uuid,
+    pub openai_key: Option<String>,
 }
 
 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -647,6 +672,7 @@ pub async fn go_to_definition_symbols(
         language,
         repo_ref,
         thread_id,
+        openai_key,
     }): Json<GotoDefinitionSymbolsRequest>,
 ) -> Result<impl IntoResponse> {
     let posthog_client = app.posthog_client.clone();
@@ -659,12 +685,22 @@ pub async fn go_to_definition_symbols(
         reporef: repo_ref,
         session_id: uuid::Uuid::new_v4(),
         conversation_messages: vec![],
-        llm_client: Arc::new(LlmClient::codestory_infra(
-            posthog_client,
-            sql_db.clone(),
-            user_id,
-            llm_config,
-        )),
+        llm_client: if let Some(user_key_openai) = &openai_key {
+            Arc::new(LlmClient::user_key_openai(
+                posthog_client,
+                sql_db.clone(),
+                user_id,
+                llm_config,
+                user_key_openai.to_owned(),
+            ))
+        } else {
+            Arc::new(LlmClient::codestory_infra(
+                posthog_client,
+                sql_db.clone(),
+                user_id,
+                llm_config,
+            ))
+        },
         model: GPT_3_5_TURBO_16K,
         sql_db,
         sender: tokio::sync::mpsc::channel(100).0,
diff --git a/sidecar/src/webserver/file_edit.rs b/sidecar/src/webserver/file_edit.rs
index ad68ae65b..ad7dccb23 100644
--- a/sidecar/src/webserver/file_edit.rs
+++ b/sidecar/src/webserver/file_edit.rs
@@ -34,6 +34,7 @@ pub struct EditFileRequest {
     pub user_query: String,
     pub session_id: String,
     pub code_block_index: usize,
+    pub openai_key: Option<String>,
 }
 
 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -138,18 +139,30 @@ pub async fn file_edit(
         user_query,
         session_id,
         code_block_index,
+        openai_key,
     }): Json<EditFileRequest>,
 ) -> Result<impl IntoResponse> {
     // Here we have to first check if the new content is tree-sitter valid, if
     // thats the case only then can we apply it to the file
     // First we check if the output generated is valid by itself, if it is then
     // we can think about applying the changes to the file
-    let llm_client = Arc::new(LlmClient::codestory_infra(
-        app.posthog_client.clone(),
-        app.sql.clone(),
-        app.user_id.to_owned(),
-        app.llm_config.clone(),
-    ));
+    let llm_client = if let Some(openai_key) = openai_key {
+        Arc::new(LlmClient::user_key_openai(
+            app.posthog_client.clone(),
+            app.sql.clone(),
+            app.user_id.to_owned(),
+            app.llm_config.clone(),
+            openai_key,
+        ))
+    } else {
+        Arc::new(LlmClient::codestory_infra(
+            app.posthog_client.clone(),
+            app.sql.clone(),
+            app.user_id.to_owned(),
+            app.llm_config.clone(),
+        ))
+    };
+
     let file_diff_content = generate_file_diff(
         &file_content,
         &file_path,
@@ -158,6 +171,7 @@ pub async fn file_edit(
         app.language_parsing.clone(),
     )
     .await;
+
     if let None = file_diff_content {
         let cloned_session_id = session_id.clone();
         let init_stream = futures::stream::once(async move {
diff --git a/sidecar/src/webserver/in_line_agent.rs b/sidecar/src/webserver/in_line_agent.rs
index 7e35b1cb9..28065a6e7 100644
--- a/sidecar/src/webserver/in_line_agent.rs
+++ b/sidecar/src/webserver/in_line_agent.rs
@@ -87,6 +87,7 @@ pub struct ProcessInEditorRequest {
     pub text_document_web: TextDocumentWeb,
     pub thread_id: uuid::Uuid,
     pub diagnostics_information: Option<DiagnosticsInformation>,
+    pub openai_key: Option<String>,
 }
 
 impl ProcessInEditorRequest {
@@ -129,6 +130,7 @@ pub async fn reply_to_user(
         thread_id,
         text_document_web,
         diagnostics_information,
+        openai_key,
     }): Json<ProcessInEditorRequest>,
 ) -> Result<impl IntoResponse> {
     let editor_parsing: EditorParsing = Default::default();
@@ -136,12 +138,22 @@ pub async fn reply_to_user(
     // the proper things
     // Here we will handle how the in-line agent will handle the work
     let sql_db = app.sql.clone();
-    let llm_client = LlmClient::codestory_infra(
-        app.posthog_client.clone(),
-        app.sql.clone(),
-        app.user_id.to_owned(),
-        app.llm_config.clone(),
-    );
+    let llm_client = if let Some(user_key_openai) = &openai_key {
+        LlmClient::user_key_openai(
+            app.posthog_client.clone(),
+            app.sql.clone(),
+            app.user_id.to_owned(),
+            app.llm_config.clone(),
+            user_key_openai.to_owned(),
+        )
+    } else {
+        LlmClient::codestory_infra(
+            app.posthog_client.clone(),
+            app.sql.clone(),
+            app.user_id.to_owned(),
+            app.llm_config.clone(),
+        )
+    };
     let (sender, receiver) = tokio::sync::mpsc::channel(100);
     let inline_agent_message = InLineAgentMessage::start_message(thread_id, query.to_owned());
     snippet_information =
@@ -161,6 +173,7 @@ pub async fn reply_to_user(
             text_document_web,
             thread_id,
             diagnostics_information,
+            openai_key,
         },
         vec![inline_agent_message],
         sender,