Merge pull request #447 from codestoryai/features/fixes-for-openai-compatible

[sidecar] fixes for openai compatible
theskcd authored Feb 7, 2024
2 parents 0652b2e + afbb912 commit 438851f
Showing 4 changed files with 311 additions and 2 deletions.
7 changes: 5 additions & 2 deletions llm_client/src/broker.rs
@@ -17,7 +17,7 @@ use crate::{
types::{
LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse,
LLMClientCompletionStringRequest, LLMClientError,
},
}, openai_compatible::OpenAICompatibleClient,
},
config::LLMBrokerConfiguration,
provider::{CodeStoryLLMTypes, LLMProvider, LLMProviderAPIKeys},
@@ -45,6 +45,7 @@ impl LLMBroker {
.add_provider(LLMProvider::Ollama, Box::new(OllamaClient::new()))
.add_provider(LLMProvider::TogetherAI, Box::new(TogetherAIClient::new()))
.add_provider(LLMProvider::LMStudio, Box::new(LMStudioClient::new()))
.add_provider(LLMProvider::OpenAICompatible, Box::new(OpenAICompatibleClient::new()))
.add_provider(
LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None }),
Box::new(CodeStoryClient::new(
@@ -102,6 +103,7 @@ impl LLMBroker {
LLMProviderAPIKeys::CodeStory => {
LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None })
}
LLMProviderAPIKeys::OpenAICompatible(_) => LLMProvider::OpenAICompatible,
};
let provider = self.providers.get(&provider_type);
if let Some(provider) = provider {
@@ -171,7 +173,8 @@ impl LLMBroker {
LLMProviderAPIKeys::LMStudio(_) => LLMProvider::LMStudio,
LLMProviderAPIKeys::CodeStory => {
LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None })
}
},
LLMProviderAPIKeys::OpenAICompatible(_) => LLMProvider::OpenAICompatible,
};
let provider = self.providers.get(&provider_type);
if let Some(provider) = provider {
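With the broker now registering OpenAICompatibleClient and routing the new key variant to it, any endpoint that speaks the OpenAI API (a local vLLM or llama.cpp server, for instance) can be selected purely through the API key. A minimal sketch of building such a key; the crate paths follow the llm_client layout shown in this diff, and the endpoint/key values are placeholders:

use llm_client::provider::{LLMProviderAPIKeys, OpenAIComptaibleConfig};

// Placeholder values; any server exposing the OpenAI chat/completions API works.
fn openai_compatible_key() -> LLMProviderAPIKeys {
    LLMProviderAPIKeys::OpenAICompatible(OpenAIComptaibleConfig {
        api_key: "not-a-real-key".to_owned(),
        api_base: "http://localhost:8000/v1".to_owned(),
    })
    // The broker's match arms above resolve this variant to
    // LLMProvider::OpenAICompatible, which is registered against
    // OpenAICompatibleClient by the new add_provider call.
}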
1 change: 1 addition & 0 deletions llm_client/src/clients/mod.rs
@@ -7,3 +7,4 @@ pub mod ollama;
pub mod openai;
pub mod togetherai;
pub mod types;
pub mod openai_compatible;
289 changes: 289 additions & 0 deletions llm_client/src/clients/openai_compatible.rs
@@ -0,0 +1,289 @@
//! Client which can help us talk to openai compatible endpoints
use async_openai::{
config::{AzureConfig, OpenAIConfig},
types::{
ChatCompletionRequestMessage, ChatCompletionRequestMessageArgs,
CreateChatCompletionRequestArgs, FunctionCall, Role, CreateCompletionRequestArgs,
},
Client,
};
use async_trait::async_trait;
use futures::StreamExt;

use crate::provider::LLMProviderAPIKeys;

use super::types::{
LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse, LLMClientError,
LLMClientMessage, LLMClientRole, LLMType, LLMClientCompletionStringRequest,
};

enum OpenAIClientType {
AzureClient(Client<AzureConfig>),
OpenAIClient(Client<OpenAIConfig>),
}

pub struct OpenAICompatibleClient {}

impl OpenAICompatibleClient {
pub fn new() -> Self {
Self {}
}

pub fn model(&self, model: &LLMType) -> Option<String> {
match model {
LLMType::GPT3_5_16k => Some("gpt-3.5-turbo-16k-0613".to_owned()),
LLMType::Gpt4 => Some("gpt-4-0613".to_owned()),
LLMType::Gpt4Turbo => Some("gpt-4-1106-preview".to_owned()),
LLMType::Gpt4_32k => Some("gpt-4-32k-0613".to_owned()),
LLMType::DeepSeekCoder33BInstruct => Some("deepseek-coder-33b".to_owned()),
LLMType::DeepSeekCoder6BInstruct => Some("deepseek-coder-6b".to_owned()),
_ => None,
}
}

pub fn messages(
&self,
messages: &[LLMClientMessage],
) -> Result<Vec<ChatCompletionRequestMessage>, LLMClientError> {
let formatted_messages = messages
.into_iter()
.map(|message| {
let role = message.role();
match role {
LLMClientRole::User => ChatCompletionRequestMessageArgs::default()
.role(Role::User)
.content(message.content().to_owned())
.build()
.map_err(|e| LLMClientError::OpenAPIError(e)),
LLMClientRole::System => ChatCompletionRequestMessageArgs::default()
.role(Role::System)
.content(message.content().to_owned())
.build()
.map_err(|e| LLMClientError::OpenAPIError(e)),
// the assistant is the one which ends up calling the function, so we need to
// handle the case where the function is called by the assistant here
LLMClientRole::Assistant => match message.get_function_call() {
Some(function_call) => ChatCompletionRequestMessageArgs::default()
.role(Role::Function)
.function_call(FunctionCall {
name: function_call.name().to_owned(),
arguments: function_call.arguments().to_owned(),
})
.build()
.map_err(|e| LLMClientError::OpenAPIError(e)),
None => ChatCompletionRequestMessageArgs::default()
.role(Role::Assistant)
.content(message.content().to_owned())
.build()
.map_err(|e| LLMClientError::OpenAPIError(e)),
},
LLMClientRole::Function => match message.get_function_call() {
Some(function_call) => ChatCompletionRequestMessageArgs::default()
.role(Role::Function)
.content(message.content().to_owned())
.function_call(FunctionCall {
name: function_call.name().to_owned(),
arguments: function_call.arguments().to_owned(),
})
.build()
.map_err(|e| LLMClientError::OpenAPIError(e)),
None => Err(LLMClientError::FunctionCallNotPresent),
},
}
})
.collect::<Vec<_>>();
formatted_messages
.into_iter()
.collect::<Result<Vec<ChatCompletionRequestMessage>, LLMClientError>>()
}

fn generate_openai_client(
&self,
api_key: LLMProviderAPIKeys,
llm_model: &LLMType,
) -> Result<OpenAIClientType, LLMClientError> {
match api_key {
LLMProviderAPIKeys::OpenAICompatible(openai_compatible) => {
let config = OpenAIConfig::new().with_api_key(openai_compatible.api_key).with_api_base(openai_compatible.api_base);
Ok(OpenAIClientType::OpenAIClient(Client::with_config(config)))
}
_ => Err(LLMClientError::WrongAPIKeyType),
}
}

fn generate_completion_openai_client(
&self,
api_key: LLMProviderAPIKeys,
llm_model: &LLMType,
) -> Result<Client<OpenAIConfig>, LLMClientError> {
match api_key {
LLMProviderAPIKeys::OpenAICompatible(openai_compatible) => {
let config = OpenAIConfig::new().with_api_key(openai_compatible.api_key).with_api_base(openai_compatible.api_base);
Ok(Client::with_config(config))
}
_ => Err(LLMClientError::WrongAPIKeyType)
}
}
}

#[async_trait]
impl LLMClient for OpenAICompatibleClient {
fn client(&self) -> &crate::provider::LLMProvider {
&crate::provider::LLMProvider::OpenAICompatible
}

async fn stream_completion(
&self,
api_key: LLMProviderAPIKeys,
request: LLMClientCompletionRequest,
sender: tokio::sync::mpsc::UnboundedSender<LLMClientCompletionResponse>,
) -> Result<String, LLMClientError> {
let llm_model = request.model();
let model = self.model(llm_model);
if model.is_none() {
return Err(LLMClientError::UnSupportedModel);
}
let model = model.unwrap();
let messages = self.messages(request.messages())?;
let mut request_builder_args = CreateChatCompletionRequestArgs::default();
let mut request_builder = request_builder_args
.model(model.to_owned())
.messages(messages)
.temperature(request.temperature())
.stream(true);
if let Some(frequency_penalty) = request.frequency_penalty() {
request_builder = request_builder.frequency_penalty(frequency_penalty);
}
let request = request_builder.build()?;
let mut buffer = String::new();
let client = self.generate_openai_client(api_key, llm_model)?;

// TODO(skcd): Bad code :| we are repeating too many things but this
// just works and we need it right now
match client {
OpenAIClientType::AzureClient(client) => {
let stream_maybe = client.chat().create_stream(request).await;
if stream_maybe.is_err() {
return Err(LLMClientError::OpenAPIError(stream_maybe.err().unwrap()));
} else {
dbg!("no error here");
}
let mut stream = stream_maybe.unwrap();
while let Some(response) = stream.next().await {
match response {
Ok(response) => {
let delta = response
.choices
.get(0)
.map(|choice| choice.delta.content.to_owned())
.flatten()
.unwrap_or("".to_owned());
let _value = response
.choices
.get(0)
.map(|choice| choice.delta.content.as_ref())
.flatten();
buffer.push_str(&delta);
let _ = sender.send(LLMClientCompletionResponse::new(
buffer.to_owned(),
Some(delta),
model.to_owned(),
));
}
Err(err) => {
dbg!(err);
break;
}
}
}
}
OpenAIClientType::OpenAIClient(client) => {
let mut stream = client.chat().create_stream(request).await?;
while let Some(response) = stream.next().await {
match response {
Ok(response) => {
let response = response
.choices
.get(0)
.ok_or(LLMClientError::FailedToGetResponse)?;
let text = response.delta.content.to_owned();
if let Some(text) = text {
buffer.push_str(&text);
let _ = sender.send(LLMClientCompletionResponse::new(
buffer.to_owned(),
Some(text),
model.to_owned(),
));
}
}
Err(err) => {
dbg!(err);
break;
}
}
}
}
}
Ok(buffer)
}

async fn completion(
&self,
api_key: LLMProviderAPIKeys,
request: LLMClientCompletionRequest,
) -> Result<String, LLMClientError> {
let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel();
let result = self.stream_completion(api_key, request, sender).await?;
Ok(result)
}

async fn stream_prompt_completion(
&self,
api_key: LLMProviderAPIKeys,
request: LLMClientCompletionStringRequest,
sender: tokio::sync::mpsc::UnboundedSender<LLMClientCompletionResponse>,
) -> Result<String, LLMClientError> {
let llm_model = request.model();
let model = self.model(llm_model);
if model.is_none() {
return Err(LLMClientError::UnSupportedModel);
}
let model = model.unwrap();
let mut request_builder_args = CreateCompletionRequestArgs::default();
let mut request_builder = request_builder_args
.model(model.to_owned())
.prompt(request.prompt())
.temperature(request.temperature())
.stream(true);
if let Some(frequency_penalty) = request.frequency_penalty() {
request_builder = request_builder.frequency_penalty(frequency_penalty);
}
let request = request_builder.build()?;
let mut buffer = String::new();
let client = self.generate_completion_openai_client(api_key, llm_model)?;
let mut stream = client.completions().create_stream(request).await?;
while let Some(response) = stream.next().await {
match response {
Ok(response) => {
let response = response
.choices
.get(0)
.ok_or(LLMClientError::FailedToGetResponse)?;
let text = response.text.to_owned();
buffer.push_str(&text);
let _ = sender.send(LLMClientCompletionResponse::new(
buffer.to_owned(),
Some(text),
model.to_owned(),
));
}
Err(err) => {
dbg!(err);
break;
}
}
}
Ok(buffer)
}
}
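For orientation, one way a caller could drive the new client's streaming path is sketched below. This is not part of the commit: it assumes a tokio runtime, the llm_client crate paths shown in this diff, and that the API key and request value are built elsewhere (their constructors are outside this diff).

use llm_client::clients::openai_compatible::OpenAICompatibleClient;
use llm_client::clients::types::{LLMClient, LLMClientCompletionRequest, LLMClientError};
use llm_client::provider::LLMProviderAPIKeys;
use tokio::sync::mpsc;

// Streams one chat completion and drains the deltas as they arrive, returning the
// final buffered answer (the same String that stream_completion itself resolves to).
async fn stream_once(
    api_key: LLMProviderAPIKeys,
    request: LLMClientCompletionRequest,
) -> Result<String, LLMClientError> {
    let client = OpenAICompatibleClient::new();
    let (sender, mut receiver) = mpsc::unbounded_channel();
    let streaming = client.stream_completion(api_key, request, sender);
    let drain = async {
        while let Some(_chunk) = receiver.recv().await {
            // Each LLMClientCompletionResponse carries the accumulated buffer, the
            // newest delta and the model name; forward it to a UI or logger here.
        }
    };
    let (answer, _) = tokio::join!(streaming, drain);
    answer
}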
16 changes: 16 additions & 0 deletions llm_client/src/provider.rs
@@ -27,6 +27,7 @@ pub enum LLMProvider {
LMStudio,
CodeStory(CodeStoryLLMTypes),
Azure(AzureOpenAIDeploymentId),
OpenAICompatible,
}

#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
@@ -36,6 +37,7 @@ pub enum LLMProviderAPIKeys {
Ollama(OllamaProvider),
OpenAIAzureConfig(AzureConfig),
LMStudio(LMStudioConfig),
OpenAICompatible(OpenAIComptaibleConfig),
CodeStory,
}

@@ -58,6 +60,7 @@ impl LLMProviderAPIKeys {
LLMProviderAPIKeys::CodeStory => {
LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None })
}
LLMProviderAPIKeys::OpenAICompatible(_) => LLMProvider::OpenAICompatible,
}
}

@@ -112,6 +115,13 @@ impl LLMProviderAPIKeys {
}
}
LLMProvider::CodeStory(_) => Some(LLMProviderAPIKeys::CodeStory),
LLMProvider::OpenAICompatible => {
if let LLMProviderAPIKeys::OpenAICompatible(openai_compatible) = self {
Some(LLMProviderAPIKeys::OpenAICompatible(openai_compatible.clone()))
} else {
None
}
}
}
}
}
@@ -132,6 +142,12 @@ impl TogetherAIProvider {
}
}

#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct OpenAIComptaibleConfig {
pub api_key: String,
pub api_base: String,
}

#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct OllamaProvider {}

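Because OpenAIComptaibleConfig only derives the stock serde traits, its wire format should be the plain two-field JSON object shown below. A small parsing sketch, assuming serde_json is available and that no extra serde attributes are layered on elsewhere:

use llm_client::provider::OpenAIComptaibleConfig;

// Parses the endpoint configuration from JSON; the values here are placeholders and
// how this JSON reaches the sidecar is outside the scope of this diff.
fn parse_openai_compatible_config() -> Result<OpenAIComptaibleConfig, serde_json::Error> {
    let raw = r#"{"api_key":"not-a-real-key","api_base":"http://localhost:8000/v1"}"#;
    let config: OpenAIComptaibleConfig = serde_json::from_str(raw)?;
    assert_eq!(config.api_base, "http://localhost:8000/v1");
    Ok(config)
}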
