Commit
[sidecar] fixes for openai compatible
Showing 4 changed files with 311 additions and 2 deletions.
@@ -7,3 +7,4 @@ pub mod ollama;
 pub mod openai;
 pub mod togetherai;
 pub mod types;
+pub mod openai_compatible;
@@ -0,0 +1,289 @@
//! Client which can help us talk to openai
use async_openai::{
    config::{AzureConfig, OpenAIConfig},
    types::{
        ChatCompletionRequestMessage, ChatCompletionRequestMessageArgs,
        CreateChatCompletionRequestArgs, FunctionCall, Role, CreateCompletionRequestArgs,
    },
    Client,
};
use async_trait::async_trait;
use futures::StreamExt;

use crate::provider::LLMProviderAPIKeys;

use super::types::{
    LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse, LLMClientError,
    LLMClientMessage, LLMClientRole, LLMType, LLMClientCompletionStringRequest,
};

enum OpenAIClientType {
    AzureClient(Client<AzureConfig>),
    OpenAIClient(Client<OpenAIConfig>),
}

pub struct OpenAICompatibleClient {}

impl OpenAICompatibleClient {
    pub fn new() -> Self {
        Self {}
    }

    pub fn model(&self, model: &LLMType) -> Option<String> {
        match model {
            LLMType::GPT3_5_16k => Some("gpt-3.5-turbo-16k-0613".to_owned()),
            LLMType::Gpt4 => Some("gpt-4-0613".to_owned()),
            LLMType::Gpt4Turbo => Some("gpt-4-1106-preview".to_owned()),
            LLMType::Gpt4_32k => Some("gpt-4-32k-0613".to_owned()),
            LLMType::DeepSeekCoder33BInstruct => Some("deepseek-coder-33b".to_owned()),
            LLMType::DeepSeekCoder6BInstruct => Some("deepseek-coder-6b".to_owned()),
            _ => None,
        }
    }

    pub fn messages(
        &self,
        messages: &[LLMClientMessage],
    ) -> Result<Vec<ChatCompletionRequestMessage>, LLMClientError> {
        let formatted_messages = messages
            .into_iter()
            .map(|message| {
                let role = message.role();
                match role {
                    LLMClientRole::User => ChatCompletionRequestMessageArgs::default()
                        .role(Role::User)
                        .content(message.content().to_owned())
                        .build()
                        .map_err(|e| LLMClientError::OpenAPIError(e)),
                    LLMClientRole::System => ChatCompletionRequestMessageArgs::default()
                        .role(Role::System)
                        .content(message.content().to_owned())
                        .build()
                        .map_err(|e| LLMClientError::OpenAPIError(e)),
                    // the assistant is the one which ends up calling the function, so we need to
                    // handle the case where the function is called by the assistant here
                    LLMClientRole::Assistant => match message.get_function_call() {
                        Some(function_call) => ChatCompletionRequestMessageArgs::default()
                            .role(Role::Function)
                            .function_call(FunctionCall {
                                name: function_call.name().to_owned(),
                                arguments: function_call.arguments().to_owned(),
                            })
                            .build()
                            .map_err(|e| LLMClientError::OpenAPIError(e)),
                        None => ChatCompletionRequestMessageArgs::default()
                            .role(Role::Assistant)
                            .content(message.content().to_owned())
                            .build()
                            .map_err(|e| LLMClientError::OpenAPIError(e)),
                    },
                    LLMClientRole::Function => match message.get_function_call() {
                        Some(function_call) => ChatCompletionRequestMessageArgs::default()
                            .role(Role::Function)
                            .content(message.content().to_owned())
                            .function_call(FunctionCall {
                                name: function_call.name().to_owned(),
                                arguments: function_call.arguments().to_owned(),
                            })
                            .build()
                            .map_err(|e| LLMClientError::OpenAPIError(e)),
                        None => Err(LLMClientError::FunctionCallNotPresent),
                    },
                }
            })
            .collect::<Vec<_>>();
        formatted_messages
            .into_iter()
            .collect::<Result<Vec<ChatCompletionRequestMessage>, LLMClientError>>()
    }

    fn generate_openai_client(
        &self,
        api_key: LLMProviderAPIKeys,
        llm_model: &LLMType,
    ) -> Result<OpenAIClientType, LLMClientError> {
        match api_key {
            LLMProviderAPIKeys::OpenAICompatible(openai_compatible) => {
                let config = OpenAIConfig::new().with_api_key(openai_compatible.api_key).with_api_base(openai_compatible.api_base);
                Ok(OpenAIClientType::OpenAIClient(Client::with_config(config)))
            }
            _ => Err(LLMClientError::WrongAPIKeyType),
        }
    }

    fn generate_completion_openai_client(
        &self,
        api_key: LLMProviderAPIKeys,
        llm_model: &LLMType,
    ) -> Result<Client<OpenAIConfig>, LLMClientError> {
        match api_key {
            LLMProviderAPIKeys::OpenAICompatible(openai_compatible) => {
                let config = OpenAIConfig::new().with_api_key(openai_compatible.api_key).with_api_base(openai_compatible.api_base);
                Ok(Client::with_config(config))
            }
            _ => Err(LLMClientError::WrongAPIKeyType)
        }
    }
}

#[async_trait]
impl LLMClient for OpenAICompatibleClient {
    fn client(&self) -> &crate::provider::LLMProvider {
        &crate::provider::LLMProvider::OpenAICompatible
    }

    async fn stream_completion(
        &self,
        api_key: LLMProviderAPIKeys,
        request: LLMClientCompletionRequest,
        sender: tokio::sync::mpsc::UnboundedSender<LLMClientCompletionResponse>,
    ) -> Result<String, LLMClientError> {
        let llm_model = request.model();
        let model = self.model(llm_model);
        if model.is_none() {
            return Err(LLMClientError::UnSupportedModel);
        }
        let model = model.unwrap();
        let messages = self.messages(request.messages())?;
        let mut request_builder_args = CreateChatCompletionRequestArgs::default();
        let mut request_builder = request_builder_args
            .model(model.to_owned())
            .messages(messages)
            .temperature(request.temperature())
            .stream(true);
        if let Some(frequency_penalty) = request.frequency_penalty() {
            request_builder = request_builder.frequency_penalty(frequency_penalty);
        }
        let request = request_builder.build()?;
        let mut buffer = String::new();
        let client = self.generate_openai_client(api_key, llm_model)?;

        // TODO(skcd): Bad code :| we are repeating too many things but this
        // just works and we need it right now
        match client {
            OpenAIClientType::AzureClient(client) => {
                let stream_maybe = client.chat().create_stream(request).await;
                if stream_maybe.is_err() {
                    return Err(LLMClientError::OpenAPIError(stream_maybe.err().unwrap()));
                } else {
                    dbg!("no error here");
                }
                let mut stream = stream_maybe.unwrap();
                while let Some(response) = stream.next().await {
                    match response {
                        Ok(response) => {
                            let delta = response
                                .choices
                                .get(0)
                                .map(|choice| choice.delta.content.to_owned())
                                .flatten()
                                .unwrap_or("".to_owned());
                            let _value = response
                                .choices
                                .get(0)
                                .map(|choice| choice.delta.content.as_ref())
                                .flatten();
                            buffer.push_str(&delta);
                            let _ = sender.send(LLMClientCompletionResponse::new(
                                buffer.to_owned(),
                                Some(delta),
                                model.to_owned(),
                            ));
                        }
                        Err(err) => {
                            dbg!(err);
                            break;
                        }
                    }
                }
            }
            OpenAIClientType::OpenAIClient(client) => {
                let mut stream = client.chat().create_stream(request).await?;
                while let Some(response) = stream.next().await {
                    match response {
                        Ok(response) => {
                            let response = response
                                .choices
                                .get(0)
                                .ok_or(LLMClientError::FailedToGetResponse)?;
                            let text = response.delta.content.to_owned();
                            if let Some(text) = text {
                                buffer.push_str(&text);
                                let _ = sender.send(LLMClientCompletionResponse::new(
                                    buffer.to_owned(),
                                    Some(text),
                                    model.to_owned(),
                                ));
                            }
                        }
                        Err(err) => {
                            dbg!(err);
                            break;
                        }
                    }
                }
            }
        }
        Ok(buffer)
    }

    async fn completion(
        &self,
        api_key: LLMProviderAPIKeys,
        request: LLMClientCompletionRequest,
    ) -> Result<String, LLMClientError> {
        let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel();
        let result = self.stream_completion(api_key, request, sender).await?;
        Ok(result)
    }

    async fn stream_prompt_completion(
        &self,
        api_key: LLMProviderAPIKeys,
        request: LLMClientCompletionStringRequest,
        sender: tokio::sync::mpsc::UnboundedSender<LLMClientCompletionResponse>,
    ) -> Result<String, LLMClientError> {
        let llm_model = request.model();
        let model = self.model(llm_model);
        if model.is_none() {
            return Err(LLMClientError::UnSupportedModel);
        }
        let model = model.unwrap();
        let mut request_builder_args = CreateCompletionRequestArgs::default();
        let mut request_builder = request_builder_args
            .model(model.to_owned())
            .prompt(request.prompt())
            .temperature(request.temperature())
            .stream(true);
        if let Some(frequency_penalty) = request.frequency_penalty() {
            request_builder = request_builder.frequency_penalty(frequency_penalty);
        }
        let request = request_builder.build()?;
        let mut buffer = String::new();
        let client = self.generate_completion_openai_client(api_key, llm_model)?;
        let mut stream = client.completions().create_stream(request).await?;
        while let Some(response) = stream.next().await {
            match response {
                Ok(response) => {
                    let response = response
                        .choices
                        .get(0)
                        .ok_or(LLMClientError::FailedToGetResponse)?;
                    let text = response.text.to_owned();
                    buffer.push_str(&text);
                    let _ = sender.send(LLMClientCompletionResponse::new(
                        buffer.to_owned(),
                        Some(text),
                        model.to_owned(),
                    ));
                }
                Err(err) => {
                    dbg!(err);
                    break;
                }
            }
        }
        Ok(buffer)
    }
}
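For orientation, here is a minimal sketch of how a caller might drive the new `OpenAICompatibleClient` through the `LLMClient` trait. Only the trait method, the channel type, and the names visible in the diff above are taken from the commit; how the `LLMClientCompletionRequest` and the `LLMProviderAPIKeys::OpenAICompatible` key get constructed is not part of this change, so the sketch takes them as parameters.

```rust
use tokio::sync::mpsc;

// Illustrative sketch, not part of the commit. Assumes the LLMClient trait and the
// types from the new module are in scope.
async fn drive_stream(
    client: OpenAICompatibleClient,
    api_key: LLMProviderAPIKeys, // expected to be LLMProviderAPIKeys::OpenAICompatible(..)
    request: LLMClientCompletionRequest,
) -> Result<String, LLMClientError> {
    // stream_completion sends an LLMClientCompletionResponse for every delta it gets
    // from the provider, each one carrying the buffer accumulated so far.
    let (sender, mut receiver) = mpsc::unbounded_channel::<LLMClientCompletionResponse>();

    // Consume partial responses concurrently while the stream is still running.
    let consumer = tokio::spawn(async move {
        let mut deltas = 0usize;
        while let Some(_partial) = receiver.recv().await {
            // A real caller would surface the partial answer to the UI here.
            deltas += 1;
        }
        deltas
    });

    // The return value is the fully accumulated buffer; the sender is dropped when the
    // stream ends, which lets the consumer task finish.
    let answer = client.stream_completion(api_key, request, sender).await?;
    let _ = consumer.await;
    Ok(answer)
}
```

The design point worth noting is that `stream_completion` both publishes partial answers over the channel and returns the fully accumulated buffer, so callers can consume either; the non-streaming `completion` method simply reuses it with a throwaway receiver.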