From afbb9124b46e4b5752c9725ea41837f0862a295f Mon Sep 17 00:00:00 2001
From: skcd
Date: Wed, 7 Feb 2024 20:53:06 +0000
Subject: [PATCH] [sidecar] fixes for openai compatible

---
 llm_client/src/broker.rs                    |   7 +-
 llm_client/src/clients/mod.rs               |   1 +
 llm_client/src/clients/openai_compatible.rs | 289 ++++++++++++++++++++
 llm_client/src/provider.rs                  |  16 ++
 4 files changed, 311 insertions(+), 2 deletions(-)
 create mode 100644 llm_client/src/clients/openai_compatible.rs

diff --git a/llm_client/src/broker.rs b/llm_client/src/broker.rs
index b1a6dde1b..7adda6cdf 100644
--- a/llm_client/src/broker.rs
+++ b/llm_client/src/broker.rs
@@ -17,7 +17,7 @@ use crate::{
         types::{
             LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse,
             LLMClientCompletionStringRequest, LLMClientError,
-        },
+        }, openai_compatible::OpenAICompatibleClient,
     },
     config::LLMBrokerConfiguration,
     provider::{CodeStoryLLMTypes, LLMProvider, LLMProviderAPIKeys},
@@ -45,6 +45,7 @@ impl LLMBroker {
             .add_provider(LLMProvider::Ollama, Box::new(OllamaClient::new()))
             .add_provider(LLMProvider::TogetherAI, Box::new(TogetherAIClient::new()))
             .add_provider(LLMProvider::LMStudio, Box::new(LMStudioClient::new()))
+            .add_provider(LLMProvider::OpenAICompatible, Box::new(OpenAICompatibleClient::new()))
             .add_provider(
                 LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None }),
                 Box::new(CodeStoryClient::new(
@@ -102,6 +103,7 @@ impl LLMBroker {
             LLMProviderAPIKeys::CodeStory => {
                 LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None })
             }
+            LLMProviderAPIKeys::OpenAICompatible(_) => LLMProvider::OpenAICompatible,
         };
         let provider = self.providers.get(&provider_type);
         if let Some(provider) = provider {
@@ -171,7 +173,8 @@ impl LLMBroker {
             LLMProviderAPIKeys::LMStudio(_) => LLMProvider::LMStudio,
             LLMProviderAPIKeys::CodeStory => {
                 LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None })
-            }
+            },
+            LLMProviderAPIKeys::OpenAICompatible(_) => LLMProvider::OpenAICompatible,
         };
         let provider = self.providers.get(&provider_type);
         if let Some(provider) = provider {
diff --git a/llm_client/src/clients/mod.rs b/llm_client/src/clients/mod.rs
index 9b0ad1859..e692a94a1 100644
--- a/llm_client/src/clients/mod.rs
+++ b/llm_client/src/clients/mod.rs
@@ -7,3 +7,4 @@ pub mod ollama;
 pub mod openai;
 pub mod togetherai;
 pub mod types;
+pub mod openai_compatible;
\ No newline at end of file
diff --git a/llm_client/src/clients/openai_compatible.rs b/llm_client/src/clients/openai_compatible.rs
new file mode 100644
index 000000000..c0004b8b4
--- /dev/null
+++ b/llm_client/src/clients/openai_compatible.rs
@@ -0,0 +1,289 @@
+//! Client which can help us talk to openai
+
+use async_openai::{
+    config::{AzureConfig, OpenAIConfig},
+    types::{
+        ChatCompletionRequestMessage, ChatCompletionRequestMessageArgs,
+        CreateChatCompletionRequestArgs, FunctionCall, Role, CreateCompletionRequestArgs,
+    },
+    Client,
+};
+use async_trait::async_trait;
+use futures::StreamExt;
+
+use crate::provider::LLMProviderAPIKeys;
+
+use super::types::{
+    LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse, LLMClientError,
+    LLMClientMessage, LLMClientRole, LLMType, LLMClientCompletionStringRequest,
+};
+
+enum OpenAIClientType {
+    AzureClient(Client<AzureConfig>),
+    OpenAIClient(Client<OpenAIConfig>),
+}
+
+pub struct OpenAICompatibleClient {}
+
+impl OpenAICompatibleClient {
+    pub fn new() -> Self {
+        Self {}
+    }
+
+    pub fn model(&self, model: &LLMType) -> Option<String> {
+        match model {
+            LLMType::GPT3_5_16k => Some("gpt-3.5-turbo-16k-0613".to_owned()),
+            LLMType::Gpt4 => Some("gpt-4-0613".to_owned()),
+            LLMType::Gpt4Turbo => Some("gpt-4-1106-preview".to_owned()),
+            LLMType::Gpt4_32k => Some("gpt-4-32k-0613".to_owned()),
+            LLMType::DeepSeekCoder33BInstruct => Some("deepseek-coder-33b".to_owned()),
+            LLMType::DeepSeekCoder6BInstruct => Some("deepseek-coder-6b".to_owned()),
+            _ => None,
+        }
+    }
+
+    pub fn messages(
+        &self,
+        messages: &[LLMClientMessage],
+    ) -> Result<Vec<ChatCompletionRequestMessage>, LLMClientError> {
+        let formatted_messages = messages
+            .into_iter()
+            .map(|message| {
+                let role = message.role();
+                match role {
+                    LLMClientRole::User => ChatCompletionRequestMessageArgs::default()
+                        .role(Role::User)
+                        .content(message.content().to_owned())
+                        .build()
+                        .map_err(|e| LLMClientError::OpenAPIError(e)),
+                    LLMClientRole::System => ChatCompletionRequestMessageArgs::default()
+                        .role(Role::System)
+                        .content(message.content().to_owned())
+                        .build()
+                        .map_err(|e| LLMClientError::OpenAPIError(e)),
+                    // the assistant is the one which ends up calling the function, so we need to
+                    // handle the case where the function is called by the assistant here
+                    LLMClientRole::Assistant => match message.get_function_call() {
+                        Some(function_call) => ChatCompletionRequestMessageArgs::default()
+                            .role(Role::Function)
+                            .function_call(FunctionCall {
+                                name: function_call.name().to_owned(),
+                                arguments: function_call.arguments().to_owned(),
+                            })
+                            .build()
+                            .map_err(|e| LLMClientError::OpenAPIError(e)),
+                        None => ChatCompletionRequestMessageArgs::default()
+                            .role(Role::Assistant)
+                            .content(message.content().to_owned())
+                            .build()
+                            .map_err(|e| LLMClientError::OpenAPIError(e)),
+                    },
+                    LLMClientRole::Function => match message.get_function_call() {
+                        Some(function_call) => ChatCompletionRequestMessageArgs::default()
+                            .role(Role::Function)
+                            .content(message.content().to_owned())
+                            .function_call(FunctionCall {
+                                name: function_call.name().to_owned(),
+                                arguments: function_call.arguments().to_owned(),
+                            })
+                            .build()
+                            .map_err(|e| LLMClientError::OpenAPIError(e)),
+                        None => Err(LLMClientError::FunctionCallNotPresent),
+                    },
+                }
+            })
+            .collect::<Vec<_>>();
+        formatted_messages
+            .into_iter()
+            .collect::<Result<Vec<ChatCompletionRequestMessage>, LLMClientError>>()
+    }
+
+    fn generate_openai_client(
+        &self,
+        api_key: LLMProviderAPIKeys,
+        llm_model: &LLMType,
+    ) -> Result<OpenAIClientType, LLMClientError> {
+        match api_key {
+            LLMProviderAPIKeys::OpenAICompatible(openai_compatible) => {
+                let config = OpenAIConfig::new().with_api_key(openai_compatible.api_key).with_api_base(openai_compatible.api_base);
+                Ok(OpenAIClientType::OpenAIClient(Client::with_config(config)))
+            }
+            _ => Err(LLMClientError::WrongAPIKeyType),
+        }
+    }
+
+    fn generate_completion_openai_client(
+        &self,
+        api_key: LLMProviderAPIKeys,
+        llm_model: &LLMType,
+    ) -> Result<Client<OpenAIConfig>, LLMClientError> {
+        match api_key {
+            LLMProviderAPIKeys::OpenAICompatible(openai_compatible) => {
+                let config = OpenAIConfig::new().with_api_key(openai_compatible.api_key).with_api_base(openai_compatible.api_base);
+                Ok(Client::with_config(config))
+            }
+            _ => Err(LLMClientError::WrongAPIKeyType)
+        }
+    }
+}
+
+#[async_trait]
+impl LLMClient for OpenAICompatibleClient {
+    fn client(&self) -> &crate::provider::LLMProvider {
+        &crate::provider::LLMProvider::OpenAICompatible
+    }
+
+    async fn stream_completion(
+        &self,
+        api_key: LLMProviderAPIKeys,
+        request: LLMClientCompletionRequest,
+        sender: tokio::sync::mpsc::UnboundedSender<LLMClientCompletionResponse>,
+    ) -> Result<String, LLMClientError> {
+        let llm_model = request.model();
+        let model = self.model(llm_model);
+        if model.is_none() {
+            return Err(LLMClientError::UnSupportedModel);
+        }
+        let model = model.unwrap();
+        let messages = self.messages(request.messages())?;
+        let mut request_builder_args = CreateChatCompletionRequestArgs::default();
+        let mut request_builder = request_builder_args
+            .model(model.to_owned())
+            .messages(messages)
+            .temperature(request.temperature())
+            .stream(true);
+        if let Some(frequency_penalty) = request.frequency_penalty() {
+            request_builder = request_builder.frequency_penalty(frequency_penalty);
+        }
+        let request = request_builder.build()?;
+        let mut buffer = String::new();
+        let client = self.generate_openai_client(api_key, llm_model)?;
+
+        // TODO(skcd): Bad code :| we are repeating too many things but this
+        // just works and we need it right now
+        match client {
+            OpenAIClientType::AzureClient(client) => {
+                let stream_maybe = client.chat().create_stream(request).await;
+                if stream_maybe.is_err() {
+                    return Err(LLMClientError::OpenAPIError(stream_maybe.err().unwrap()));
+                } else {
+                    dbg!("no error here");
+                }
+                let mut stream = stream_maybe.unwrap();
+                while let Some(response) = stream.next().await {
+                    match response {
+                        Ok(response) => {
+                            let delta = response
+                                .choices
+                                .get(0)
+                                .map(|choice| choice.delta.content.to_owned())
+                                .flatten()
+                                .unwrap_or("".to_owned());
+                            let _value = response
+                                .choices
+                                .get(0)
+                                .map(|choice| choice.delta.content.as_ref())
+                                .flatten();
+                            buffer.push_str(&delta);
+                            let _ = sender.send(LLMClientCompletionResponse::new(
+                                buffer.to_owned(),
+                                Some(delta),
+                                model.to_owned(),
+                            ));
+                        }
+                        Err(err) => {
+                            dbg!(err);
+                            break;
+                        }
+                    }
+                }
+            }
+            OpenAIClientType::OpenAIClient(client) => {
+                let mut stream = client.chat().create_stream(request).await?;
+                while let Some(response) = stream.next().await {
+                    match response {
+                        Ok(response) => {
+                            let response = response
+                                .choices
+                                .get(0)
+                                .ok_or(LLMClientError::FailedToGetResponse)?;
+                            let text = response.delta.content.to_owned();
+                            if let Some(text) = text {
+                                buffer.push_str(&text);
+                                let _ = sender.send(LLMClientCompletionResponse::new(
+                                    buffer.to_owned(),
+                                    Some(text),
+                                    model.to_owned(),
+                                ));
+                            }
+                        }
+                        Err(err) => {
+                            dbg!(err);
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+        Ok(buffer)
+    }
+
+    async fn completion(
+        &self,
+        api_key: LLMProviderAPIKeys,
+        request: LLMClientCompletionRequest,
+    ) -> Result<String, LLMClientError> {
+        let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel();
+        let result = self.stream_completion(api_key, request, sender).await?;
+        Ok(result)
+    }
+
+    async fn stream_prompt_completion(
+        &self,
+        api_key: LLMProviderAPIKeys,
+        request: LLMClientCompletionStringRequest,
+        sender: tokio::sync::mpsc::UnboundedSender<LLMClientCompletionResponse>,
+    ) -> Result<String, LLMClientError> {
+        let llm_model = request.model();
+        let model = self.model(llm_model);
+        if model.is_none() {
+            return Err(LLMClientError::UnSupportedModel);
+        }
+        let model = model.unwrap();
+        let mut request_builder_args = CreateCompletionRequestArgs::default();
+        let mut request_builder = request_builder_args
+            .model(model.to_owned())
+            .prompt(request.prompt())
+            .temperature(request.temperature())
+            .stream(true);
+        if let Some(frequency_penalty) = request.frequency_penalty() {
+            request_builder = request_builder.frequency_penalty(frequency_penalty);
+        }
+        let request = request_builder.build()?;
+        let mut buffer = String::new();
+        let client = self.generate_completion_openai_client(api_key, llm_model)?;
+        let mut stream = client.completions().create_stream(request).await?;
+        while let Some(response) = stream.next().await {
+            match response {
+                Ok(response) => {
+                    let response = response
+                        .choices
+                        .get(0)
+                        .ok_or(LLMClientError::FailedToGetResponse)?;
+                    let text = response.text.to_owned();
+                    buffer.push_str(&text);
+                    let _ = sender.send(LLMClientCompletionResponse::new(
+                        buffer.to_owned(),
+                        Some(text),
+                        model.to_owned(),
+                    ));
+                }
+                Err(err) => {
+                    dbg!(err);
+                    break;
+                }
+            }
+        }
+        Ok(buffer)
+    }
+}
diff --git a/llm_client/src/provider.rs b/llm_client/src/provider.rs
index bbb25d385..c302665e8 100644
--- a/llm_client/src/provider.rs
+++ b/llm_client/src/provider.rs
@@ -27,6 +27,7 @@ pub enum LLMProvider {
     LMStudio,
     CodeStory(CodeStoryLLMTypes),
     Azure(AzureOpenAIDeploymentId),
+    OpenAICompatible,
 }
 
 #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
@@ -36,6 +37,7 @@ pub enum LLMProviderAPIKeys {
     Ollama(OllamaProvider),
     OpenAIAzureConfig(AzureConfig),
     LMStudio(LMStudioConfig),
+    OpenAICompatible(OpenAICompatibleConfig),
     CodeStory,
 }
 
@@ -58,6 +60,7 @@ impl LLMProviderAPIKeys {
             LLMProviderAPIKeys::CodeStory => {
                 LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None })
             }
+            LLMProviderAPIKeys::OpenAICompatible(_) => LLMProvider::OpenAICompatible,
         }
     }
 
@@ -112,6 +115,13 @@ impl LLMProviderAPIKeys {
                 }
             }
             LLMProvider::CodeStory(_) => Some(LLMProviderAPIKeys::CodeStory),
+            LLMProvider::OpenAICompatible => {
+                if let LLMProviderAPIKeys::OpenAICompatible(openai_compatible) = self {
+                    Some(LLMProviderAPIKeys::OpenAICompatible(openai_compatible.clone()))
+                } else {
+                    None
+                }
+            }
         }
     }
 }
@@ -132,6 +142,12 @@ impl TogetherAIProvider {
     }
 }
 
+#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
+pub struct OpenAICompatibleConfig {
+    pub api_key: String,
+    pub api_base: String,
+}
+
 #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
 pub struct OllamaProvider {}
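
Usage note (editor's sketch, not part of the patch): with this change an OpenAI-compatible endpoint is selected by building an OpenAICompatibleConfig whose api_base points at the server and wrapping it in LLMProviderAPIKeys::OpenAICompatible. A minimal sketch of that wiring follows; the LLMClientCompletionRequest::new and LLMClientMessage::user constructors, the base URL, and the model choice are assumptions about the surrounding llm_client API rather than anything this diff defines.

    use llm_client::clients::openai_compatible::OpenAICompatibleClient;
    use llm_client::clients::types::{LLMClient, LLMClientCompletionRequest, LLMClientMessage, LLMType};
    use llm_client::provider::{LLMProviderAPIKeys, OpenAICompatibleConfig};

    #[tokio::main]
    async fn main() {
        // Any OpenAI-compatible server works; the key and base URL are placeholders.
        let api_key = LLMProviderAPIKeys::OpenAICompatible(OpenAICompatibleConfig {
            api_key: "placeholder".to_owned(),
            api_base: "http://localhost:8080/v1".to_owned(),
        });

        // Assumed constructors: the patch itself only shows the accessors it reads
        // (model(), messages(), temperature(), frequency_penalty()).
        let request = LLMClientCompletionRequest::new(
            LLMType::DeepSeekCoder33BInstruct,
            vec![LLMClientMessage::user("Write a haiku about Rust".to_owned())],
            0.2,
            None,
        );

        // The client streams deltas over this channel while the call is in flight.
        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            while let Some(_chunk) = receiver.recv().await {
                // each chunk is an LLMClientCompletionResponse carrying the latest
                // delta plus the answer accumulated so far
            }
        });

        let client = OpenAICompatibleClient::new();
        match client.stream_completion(api_key, request, sender).await {
            Ok(answer) => println!("{}", answer),
            Err(_) => eprintln!("completion failed"),
        }
    }

The same config also round-trips through the broker, since LLMBroker now registers OpenAICompatibleClient under LLMProvider::OpenAICompatible.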