Merge pull request #326 from codestoryai/features/use-user-key-instead-of-azure

[sidecar] use user key instead of the openai one if present
theskcd authored Dec 5, 2023
2 parents 0112642 + 7fd5c06 commit 0cd747d
Showing 4 changed files with 160 additions and 77 deletions.
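In outline, the commit threads an optional user-supplied OpenAI API key through each request type and, wherever an `LlmClient` is built, prefers that key over the default Azure-backed CodeStory infrastructure. A minimal, self-contained sketch of that routing rule (hypothetical stub types, not the real `LlmClient`):

```rust
// Hypothetical sketch of the selection rule this commit applies at each
// call site; `Backend` stands in for the real LlmClient constructors.
#[derive(Debug)]
enum Backend {
    UserOpenAI(String), // user-supplied key -> LlmClient::user_key_openai
    CodestoryAzure,     // no key -> LlmClient::codestory_infra
}

fn choose_backend(openai_key: Option<String>) -> Backend {
    match openai_key {
        Some(key) => Backend::UserOpenAI(key),
        None => Backend::CodestoryAzure,
    }
}

fn main() {
    assert!(matches!(choose_backend(None), Backend::CodestoryAzure));
    assert!(matches!(
        choose_backend(Some("sk-user".into())),
        Backend::UserOpenAI(_)
    ));
}
```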
104 changes: 62 additions & 42 deletions sidecar/src/agent/llm_funcs.rs
@@ -285,19 +285,19 @@ impl From<&llm::Message> for tiktoken_rs::ChatCompletionRequestMessage {
}

pub struct LlmClient {
-    gpt4_client: Client<AzureConfig>,
-    gpt432k_client: Client<AzureConfig>,
-    gpt3_5_client: Client<AzureConfig>,
-    gpt3_5_turbo_instruct: Client<OpenAIConfig>,
-    gpt4_turbo_client: Client<AzureConfig>,
+    gpt4_client: ClientEndpoint,
+    gpt432k_client: ClientEndpoint,
+    gpt3_5_client: ClientEndpoint,
+    gpt3_5_turbo_instruct: ClientEndpoint,
+    gpt4_turbo_client: ClientEndpoint,
posthog_client: Arc<PosthogClient>,
sql_db: SqlDb,
user_id: String,
//TODO(skcd): We need a better toggle for this, because our prompt engine should also
// understand the kind of llm we are using and take that into account
// for now, we can keep using the same prompts but the burden of construction
// will fall on every place which constructs the prompt
-    custom_llm: Option<Client<OpenAIConfig>>,
+    custom_llm: Option<ClientEndpoint>,
custom_llm_type: LLMType,
}

@@ -308,14 +308,14 @@ enum OpenAIEventType {
}

#[derive(Debug, Clone)]
-pub enum ClientEndpoint<'a> {
-    OpenAI(&'a Client<OpenAIConfig>),
-    Azure(&'a Client<AzureConfig>),
+pub enum ClientEndpoint {
+    OpenAI(Client<OpenAIConfig>),
+    Azure(Client<AzureConfig>),
}

-impl<'a> ClientEndpoint<'a> {
+impl ClientEndpoint {
pub async fn create_stream(
-        &'a self,
+        &self,
request: CreateChatCompletionRequest,
) -> Result<ChatCompletionResponseStream, OpenAIError> {
match self {
@@ -325,7 +325,7 @@ impl<'a> ClientEndpoint<'a> {
}

pub async fn create_stream_completion(
-        &'a self,
+        &self,
request: CreateCompletionRequest,
) -> Result<CompletionResponseStream, OpenAIError> {
match self {
@@ -335,7 +335,7 @@ impl<'a> ClientEndpoint<'a> {
}

pub async fn create_chat_completion(
-        &'a self,
+        &self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
match self {
@@ -344,11 +344,11 @@ impl<'a> ClientEndpoint<'a> {
}
}

-    fn from_azure_client(azure_client: &'a Client<AzureConfig>) -> Self {
+    fn from_azure_client(azure_client: Client<AzureConfig>) -> Self {
ClientEndpoint::Azure(azure_client)
}

-    fn from_openai_client(openai_client: &'a Client<OpenAIConfig>) -> Self {
+    fn from_openai_client(openai_client: Client<OpenAIConfig>) -> Self {
ClientEndpoint::OpenAI(openai_client)
}
}
@@ -382,16 +382,16 @@ impl LlmClient {
let custom_llm = match llm_config.non_openai_endpoint() {
Some(endpoint) => {
let config = OpenAIConfig::new().with_api_base(endpoint);
-                Some(Client::with_config(config))
+                Some(ClientEndpoint::OpenAI(Client::with_config(config)))
}
None => None,
};
Self {
-            gpt4_client: Client::with_config(gpt4_config),
-            gpt432k_client: Client::with_config(gpt4_32k_config),
-            gpt3_5_client: Client::with_config(gpt3_5_config),
-            gpt3_5_turbo_instruct: Client::with_config(openai_config),
-            gpt4_turbo_client: Client::with_config(gpt4_turbo_128k_config),
+            gpt4_client: ClientEndpoint::Azure(Client::with_config(gpt4_config)),
+            gpt432k_client: ClientEndpoint::Azure(Client::with_config(gpt4_32k_config)),
+            gpt3_5_client: ClientEndpoint::Azure(Client::with_config(gpt3_5_config)),
+            gpt3_5_turbo_instruct: ClientEndpoint::OpenAI(Client::with_config(openai_config)),
+            gpt4_turbo_client: ClientEndpoint::Azure(Client::with_config(gpt4_turbo_128k_config)),
posthog_client,
sql_db,
user_id,
@@ -400,6 +400,36 @@ impl LlmClient {
}
}

+    pub fn user_key_openai(
+        posthog_client: Arc<PosthogClient>,
+        sql_db: SqlDb,
+        user_id: String,
+        llm_config: LLMCustomConfig,
+        api_key: String,
+    ) -> LlmClient {
+        let openai_config = OpenAIConfig::new().with_api_key(api_key);
+
+        let gpt4_client = ClientEndpoint::OpenAI(Client::with_config(openai_config.clone()));
+        let gpt432k_client = ClientEndpoint::OpenAI(Client::with_config(openai_config.clone()));
+        let gpt3_5_client = ClientEndpoint::OpenAI(Client::with_config(openai_config.clone()));
+        let gpt3_5_turbo_instruct =
+            ClientEndpoint::OpenAI(Client::with_config(openai_config.clone()));
+        let gpt4_turbo_client = ClientEndpoint::OpenAI(Client::with_config(openai_config));
+
+        LlmClient {
+            gpt4_client,
+            gpt432k_client,
+            gpt3_5_client,
+            gpt3_5_turbo_instruct,
+            gpt4_turbo_client,
+            posthog_client,
+            sql_db,
+            user_id,
+            custom_llm: None,
+            custom_llm_type: llm_config.llm.clone(),
+        }
+    }

pub async fn capture_openai_request_response<T: serde::Serialize, R: serde::Serialize>(
&self,
request: T,
@@ -968,48 +998,38 @@ impl LlmClient {
Ok(None)
}

-    fn get_model(&self, model: &llm::OpenAIModel) -> Option<ClientEndpoint> {
+    fn get_model(&self, model: &llm::OpenAIModel) -> Option<&ClientEndpoint> {
// If the user has provided a model for us we can use that instead of
// doing anything fancy over here
if let Some(custom_client) = self.custom_llm.as_ref() {
-            return Some(ClientEndpoint::OpenAI(&custom_client));
+            return Some(custom_client);
}
let client = match model {
-            llm::OpenAIModel::GPT4 => ClientEndpoint::from_azure_client(&self.gpt4_client),
-            llm::OpenAIModel::GPT4_32k => ClientEndpoint::from_azure_client(&self.gpt432k_client),
-            llm::OpenAIModel::GPT3_5_16k => ClientEndpoint::from_azure_client(&self.gpt3_5_client),
-            llm::OpenAIModel::GPT4_Turbo => {
-                ClientEndpoint::from_azure_client(&self.gpt4_turbo_client)
-            }
+            llm::OpenAIModel::GPT4 => &self.gpt4_client,
+            llm::OpenAIModel::GPT4_32k => &self.gpt432k_client,
+            llm::OpenAIModel::GPT3_5_16k => &self.gpt3_5_client,
+            llm::OpenAIModel::GPT4_Turbo => &self.gpt4_turbo_client,
llm::OpenAIModel::GPT3_5Instruct => return None,
llm::OpenAIModel::OpenHermes2_5Mistral7b => {
-                return self
-                    .custom_llm
-                    .as_ref()
-                    .map(|llm| ClientEndpoint::OpenAI(&llm));
+                return self.custom_llm.as_ref();
}
_ => return None,
};
Some(client)
}

-    fn get_model_openai(&self, model: &llm::OpenAIModel) -> Option<ClientEndpoint> {
+    fn get_model_openai(&self, model: &llm::OpenAIModel) -> Option<&ClientEndpoint> {
if let Some(custom_client) = self.custom_llm.as_ref() {
-            return Some(ClientEndpoint::OpenAI(&custom_client));
+            return Some(custom_client);
}
let client = match model {
llm::OpenAIModel::GPT4 => return None,
llm::OpenAIModel::GPT4_32k => return None,
llm::OpenAIModel::GPT3_5_16k => return None,
llm::OpenAIModel::GPT4_Turbo => return None,
-            llm::OpenAIModel::GPT3_5Instruct => {
-                ClientEndpoint::from_openai_client(&self.gpt3_5_turbo_instruct)
-            }
+            llm::OpenAIModel::GPT3_5Instruct => &self.gpt3_5_turbo_instruct,
llm::OpenAIModel::OpenHermes2_5Mistral7b => {
-                return self
-                    .custom_llm
-                    .as_ref()
-                    .map(|llm| ClientEndpoint::OpenAI(&llm))
+                return self.custom_llm.as_ref();
}
_ => return None,
};
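A note on the `ClientEndpoint` refactor above: the enum previously borrowed its clients (`ClientEndpoint<'a>`), so `LlmClient` had to store concrete `Client<AzureConfig>`/`Client<OpenAIConfig>` fields and re-wrap a borrow on every `get_model` call. Owning the clients lets the struct fields themselves be `ClientEndpoint` values and lets `get_model` hand out `&ClientEndpoint` directly. A reduced sketch of the pattern, with stub types in place of the real async-openai clients:

```rust
// Stub clients keep the sketch self-contained; the real code wraps
// async_openai's Client<AzureConfig> and Client<OpenAIConfig>.
struct AzureClient;
struct OpenAIClient;

enum ClientEndpoint {
    Azure(AzureClient),
    OpenAI(OpenAIClient),
}

struct LlmClient {
    // Owned endpoint: constructors decide Azure vs OpenAI once, up front.
    gpt4_client: ClientEndpoint,
}

impl LlmClient {
    // With the borrowed ClientEndpoint<'a>, this had to rebuild an enum
    // around &self.gpt4_client; owning the endpoint makes the borrow trivial.
    fn get_model(&self) -> Option<&ClientEndpoint> {
        Some(&self.gpt4_client)
    }
}

fn main() {
    let llm = LlmClient {
        gpt4_client: ClientEndpoint::Azure(AzureClient),
    };
    assert!(matches!(llm.get_model(), Some(ClientEndpoint::Azure(_))));
}
```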
82 changes: 59 additions & 23 deletions sidecar/src/webserver/agent.rs
@@ -471,6 +471,7 @@ pub struct FollowupChatRequest {
pub user_context: UserContext,
pub project_labels: Vec<String>,
pub active_window_data: Option<ActiveWindowData>,
+    pub openai_key: Option<String>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -550,6 +551,7 @@ pub async fn followup_chat(
user_context,
project_labels,
active_window_data,
+        openai_key,
}): Json<FollowupChatRequest>,
) -> Result<impl IntoResponse> {
let llm_config = app.llm_config.clone();
@@ -604,23 +606,45 @@ pub async fn followup_chat(
let action = AgentAction::Answer {
paths: (0..file_path_len).collect(),
};
-    let agent = Agent::prepare_for_followup(
-        app,
-        repo_ref,
-        session_id,
-        Arc::new(LlmClient::codestory_infra(
-            posthog_client,
-            sql_db.clone(),
-            user_id.to_owned(),
-            llm_config,
-        )),
-        sql_db,
-        previous_messages,
-        sender,
-        user_context,
-        project_labels,
-        Default::default(),
-    );
+
+    let agent = if let Some(openai_user_key) = openai_key {
+        Agent::prepare_for_followup(
+            app,
+            repo_ref,
+            session_id,
+            Arc::new(LlmClient::user_key_openai(
+                posthog_client,
+                sql_db.clone(),
+                user_id.to_owned(),
+                llm_config,
+                openai_user_key,
+            )),
+            sql_db,
+            previous_messages,
+            sender,
+            user_context,
+            project_labels,
+            Default::default(),
+        )
+    } else {
+        Agent::prepare_for_followup(
+            app,
+            repo_ref,
+            session_id,
+            Arc::new(LlmClient::codestory_infra(
+                posthog_client,
+                sql_db.clone(),
+                user_id.to_owned(),
+                llm_config,
+            )),
+            sql_db,
+            previous_messages,
+            sender,
+            user_context,
+            project_labels,
+            Default::default(),
+        )
+    };

generate_agent_stream(agent, action, receiver).await
}
@@ -631,6 +655,7 @@ pub struct GotoDefinitionSymbolsRequest {
pub language: String,
pub repo_ref: RepoRef,
pub thread_id: uuid::Uuid,
+    pub openai_key: Option<String>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -647,6 +672,7 @@ pub async fn go_to_definition_symbols(
language,
repo_ref,
thread_id,
+        openai_key,
}): Json<GotoDefinitionSymbolsRequest>,
) -> Result<impl IntoResponse> {
let posthog_client = app.posthog_client.clone();
@@ -659,12 +685,22 @@ pub async fn go_to_definition_symbols(
reporef: repo_ref,
session_id: uuid::Uuid::new_v4(),
conversation_messages: vec![],
-        llm_client: Arc::new(LlmClient::codestory_infra(
-            posthog_client,
-            sql_db.clone(),
-            user_id,
-            llm_config,
-        )),
+        llm_client: if let Some(user_key_openai) = &openai_key {
+            Arc::new(LlmClient::user_key_openai(
+                posthog_client,
+                sql_db.clone(),
+                user_id,
+                llm_config,
+                user_key_openai.to_owned(),
+            ))
+        } else {
+            Arc::new(LlmClient::codestory_infra(
+                posthog_client,
+                sql_db.clone(),
+                user_id,
+                llm_config,
+            ))
+        },
model: GPT_3_5_TURBO_16K,
sql_db,
sender: tokio::sync::mpsc::channel(100).0,
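The two branches in `followup_chat` (and in `go_to_definition_symbols` above) differ only in which `LlmClient` constructor they call; every other argument is repeated verbatim. A small helper along these lines could collapse each call site to one expression. This is a hypothetical refactor, not part of the commit, and it uses stub types so the sketch runs standalone:

```rust
use std::sync::Arc;

// Stub standing in for the real sidecar LlmClient.
struct LlmClient {
    backend: String,
}

impl LlmClient {
    fn user_key_openai(key: String) -> Self {
        Self { backend: format!("openai:{key}") }
    }

    fn codestory_infra() -> Self {
        Self { backend: "codestory-azure".to_string() }
    }

    // Hypothetical helper centralizing the user-key-vs-infra decision.
    fn for_request(openai_key: Option<String>) -> Arc<Self> {
        Arc::new(match openai_key {
            Some(key) => Self::user_key_openai(key),
            None => Self::codestory_infra(),
        })
    }
}

fn main() {
    assert_eq!(LlmClient::for_request(None).backend, "codestory-azure");
    assert_eq!(
        LlmClient::for_request(Some("sk-user".into())).backend,
        "openai:sk-user"
    );
}
```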
26 changes: 20 additions & 6 deletions sidecar/src/webserver/file_edit.rs
@@ -34,6 +34,7 @@ pub struct EditFileRequest {
pub user_query: String,
pub session_id: String,
pub code_block_index: usize,
+    pub openai_key: Option<String>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -138,18 +139,30 @@ pub async fn file_edit(
user_query,
session_id,
code_block_index,
+        openai_key,
}): Json<EditFileRequest>,
) -> Result<impl IntoResponse> {
// Here we have to first check if the new content is tree-sitter valid, if
// thats the case only then can we apply it to the file
// First we check if the output generated is valid by itself, if it is then
// we can think about applying the changes to the file
-    let llm_client = Arc::new(LlmClient::codestory_infra(
-        app.posthog_client.clone(),
-        app.sql.clone(),
-        app.user_id.to_owned(),
-        app.llm_config.clone(),
-    ));
+    let llm_client = if let Some(openai_key) = openai_key {
+        Arc::new(LlmClient::user_key_openai(
+            app.posthog_client.clone(),
+            app.sql.clone(),
+            app.user_id.to_owned(),
+            app.llm_config.clone(),
+            openai_key,
+        ))
+    } else {
+        Arc::new(LlmClient::codestory_infra(
+            app.posthog_client.clone(),
+            app.sql.clone(),
+            app.user_id.to_owned(),
+            app.llm_config.clone(),
+        ))
+    };

let file_diff_content = generate_file_diff(
&file_content,
&file_path,
Expand All @@ -158,6 +171,7 @@ pub async fn file_edit(
app.language_parsing.clone(),
)
.await;

if let None = file_diff_content {
let cloned_session_id = session_id.clone();
let init_stream = futures::stream::once(async move {
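One wire-format detail worth noting: with serde's derive, an `Option<String>` field deserializes to `None` when the key is simply absent from the JSON body, so existing clients that never send `openai_key` keep working unchanged. A minimal demonstration, assuming `serde` and `serde_json` as dependencies and a hypothetical struct mirroring the shape of the request types above:

```rust
use serde::Deserialize;

// Hypothetical request mirroring EditFileRequest's new optional field.
#[derive(Debug, Deserialize)]
struct EditRequest {
    session_id: String,
    openai_key: Option<String>, // absent in older clients -> None
}

fn main() {
    // Older clients omit the field entirely; serde fills in None.
    let old: EditRequest = serde_json::from_str(r#"{"session_id":"abc"}"#).unwrap();
    assert!(old.openai_key.is_none());

    // Newer clients send the user's key, which routes to user_key_openai.
    let new: EditRequest =
        serde_json::from_str(r#"{"session_id":"abc","openai_key":"sk-user"}"#).unwrap();
    assert_eq!(new.openai_key.as_deref(), Some("sk-user"));
}
```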