[sidecar] support oss models using codestory provider
theskcd committed Feb 9, 2024
1 parent 3c3a647 commit 2c6bb31
Showing 8 changed files with 161 additions and 16 deletions.
3 changes: 2 additions & 1 deletion llm_client/src/bin/code_llama_infill.rs
@@ -19,7 +19,8 @@ async fn main() {
prompt.to_owned(),
0.2,
None,
)
.set_max_tokens(100);
let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
let response = togetherai
.stream_prompt_completion(api_key, request, sender)
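
The one functional change in this binary chains .set_max_tokens(100) directly onto the request constructor, which only works if the setter consumes the request and returns it. A minimal stand-alone sketch of that builder-style pattern (PromptRequest is a hypothetical stand-in, not the crate's LLMClientCompletionStringRequest):

// `PromptRequest` is a hypothetical stand-in used only to illustrate the chaining.
#[derive(Debug)]
struct PromptRequest {
    prompt: String,
    temperature: f32,
    max_tokens: Option<usize>,
}

impl PromptRequest {
    fn new(prompt: String, temperature: f32) -> Self {
        Self {
            prompt,
            temperature,
            max_tokens: None,
        }
    }

    // Consumes `self` and hands it back, so the call chains directly onto `new(...)`.
    fn set_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }
}

fn main() {
    let request = PromptRequest::new("fn main() {".to_owned(), 0.2).set_max_tokens(100);
    println!("{:?}", request);
}
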
20 changes: 18 additions & 2 deletions llm_client/src/bin/deepseek_infill.rs
@@ -1,3 +1,4 @@
use llm_client::clients::codestory::CodeStoryClient;
use llm_client::clients::types::LLMClient;
use llm_client::{
clients::{ollama::OllamaClient, types::LLMClientCompletionStringRequest},
@@ -16,9 +17,24 @@ async fn main() {
0.2,
None,
);
// let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
// let response = client
// .stream_prompt_completion(api_key, request, sender)
// .await;
// println!("{}", response.expect("to work"));
let codestory_client =
CodeStoryClient::new("https://codestory-provider-dot-anton-390822.ue.r.appspot.com");
let codestory_api_key = LLMProviderAPIKeys::CodeStory;
let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
let request = LLMClientCompletionStringRequest::new(
llm_client::clients::types::LLMType::DeepSeekCoder33BInstruct,
prompt.to_owned(),
0.2,
None,
)
.set_max_tokens(100);
let response = codestory_client
.stream_prompt_completion(codestory_api_key, request, sender)
.await;
println!("{}", response.expect("to work"));
}
10 changes: 7 additions & 3 deletions llm_client/src/broker.rs
@@ -13,11 +13,12 @@ use crate::{
lmstudio::LMStudioClient,
ollama::OllamaClient,
openai::OpenAIClient,
openai_compatible::OpenAICompatibleClient,
togetherai::TogetherAIClient,
types::{
LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse,
LLMClientCompletionStringRequest, LLMClientError,
},
},
config::LLMBrokerConfiguration,
provider::{CodeStoryLLMTypes, LLMProvider, LLMProviderAPIKeys},
@@ -45,7 +46,10 @@ impl LLMBroker {
.add_provider(LLMProvider::Ollama, Box::new(OllamaClient::new()))
.add_provider(LLMProvider::TogetherAI, Box::new(TogetherAIClient::new()))
.add_provider(LLMProvider::LMStudio, Box::new(LMStudioClient::new()))
.add_provider(
LLMProvider::OpenAICompatible,
Box::new(OpenAICompatibleClient::new()),
)
.add_provider(
LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None }),
Box::new(CodeStoryClient::new(
@@ -173,7 +177,7 @@ impl LLMBroker {
LLMProviderAPIKeys::LMStudio(_) => LLMProvider::LMStudio,
LLMProviderAPIKeys::CodeStory => {
LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None })
}
LLMProviderAPIKeys::OpenAICompatible(_) => LLMProvider::OpenAICompatible,
};
let provider = self.providers.get(&provider_type);
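
The broker wiring above follows a registry pattern: each provider enum value maps to a boxed client, and an API key is first resolved to its provider before the lookup. A rough, self-contained sketch of that pattern with simplified stand-in types (Provider, ApiKey, Client, Broker are placeholders, not the crate's real LLMProvider/LLMProviderAPIKeys/LLMClient/LLMBroker definitions):

use std::collections::HashMap;

// Simplified stand-ins; the real definitions carry more variants and async streaming methods.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Provider {
    CodeStory,
    TogetherAI,
}

enum ApiKey {
    CodeStory,
    TogetherAI(String),
}

trait Client {
    fn name(&self) -> &'static str;
}

struct CodeStoryClient;

impl Client for CodeStoryClient {
    fn name(&self) -> &'static str {
        "codestory"
    }
}

#[derive(Default)]
struct Broker {
    providers: HashMap<Provider, Box<dyn Client>>,
}

impl Broker {
    // Chainable registration, mirroring the `.add_provider(...)` chain in the diff.
    fn add_provider(mut self, provider: Provider, client: Box<dyn Client>) -> Self {
        self.providers.insert(provider, client);
        self
    }

    // An API key is resolved to its provider first, then the provider is looked up,
    // as in the match added to `LLMBroker` above.
    fn provider_for_key(&self, key: &ApiKey) -> Option<&dyn Client> {
        let provider = match key {
            ApiKey::CodeStory => Provider::CodeStory,
            ApiKey::TogetherAI(_) => Provider::TogetherAI,
        };
        self.providers.get(&provider).map(|client| client.as_ref())
    }
}

fn main() {
    let broker = Broker::default().add_provider(Provider::CodeStory, Box::new(CodeStoryClient));
    let client = broker.provider_for_key(&ApiKey::CodeStory).expect("registered");
    println!("{}", client.name());
    // A key whose provider was never registered simply resolves to nothing.
    assert!(broker
        .provider_for_key(&ApiKey::TogetherAI("<hypothetical-key>".to_owned()))
        .is_none());
}
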
110 changes: 105 additions & 5 deletions llm_client/src/clients/codestory.rs
@@ -6,9 +6,12 @@ use tokio::sync::mpsc::UnboundedSender;

use crate::provider::{LLMProvider, LLMProviderAPIKeys};

use super::{
togetherai::TogetherAIClient,
types::{
LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse,
LLMClientCompletionStringRequest, LLMClientError, LLMClientRole, LLMType,
},
};

#[derive(serde::Serialize, serde::Deserialize, Debug)]
@@ -44,6 +47,46 @@ struct CodeStoryRequest {
options: CodeStoryRequestOptions,
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct CodeStoryRequestPrompt {
prompt: String,
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
stop_tokens: Option<Vec<String>>,
model: String,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<usize>,
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct CodeStoryChoice {
pub text: String,
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct CodeStoryPromptResponse {
choices: Vec<CodeStoryChoice>,
}

impl CodeStoryRequestPrompt {
fn from_string_request(
request: LLMClientCompletionStringRequest,
) -> Result<Self, LLMClientError> {
let model = TogetherAIClient::model_str(request.model());
println!("what is the model: {:?} {:?}", model, request.model());
match model {
Some(model) => Ok(Self {
prompt: request.prompt().to_owned(),
model,
temperature: request.temperature(),
stop_tokens: request.stop_words().map(|stop_tokens| stop_tokens.to_vec()),
max_tokens: request.get_max_tokens(),
}),
None => Err(LLMClientError::OpenAIDoesNotSupportCompletion),
}
}
}

impl CodeStoryRequest {
fn from_chat_request(request: LLMClientCompletionRequest) -> Self {
Self {
@@ -92,6 +135,10 @@ impl CodeStoryClient {
format!("{api_base}/chat-4")
}

pub fn together_api_endpoint(&self, api_base: &str) -> String {
format!("{api_base}/together-api")
}

pub fn model_name(&self, model: &LLMType) -> Result<String, LLMClientError> {
match model {
LLMType::GPT3_5_16k => Ok("gpt-3.5-turbo-16k-0613".to_owned()),
@@ -107,6 +154,15 @@ impl CodeStoryClient {
_ => Err(LLMClientError::UnSupportedModel),
}
}

pub fn model_prompt_endpoint(&self, model: &LLMType) -> Result<String, LLMClientError> {
match model {
LLMType::GPT3_5_16k | LLMType::Gpt4 | LLMType::Gpt4Turbo | LLMType::Gpt4_32k => {
Err(LLMClientError::UnSupportedModel)
}
_ => Ok(self.together_api_endpoint(&self.api_base)),
}
}
}

#[async_trait]
@@ -162,7 +218,6 @@ impl LLMClient for CodeStoryClient {
.flatten()
.unwrap_or("".to_owned());
buffered_stream.push_str(&delta);
println!("{}", &buffered_stream);
sender.send(LLMClientCompletionResponse::new(
buffered_stream.to_owned(),
Some(delta),
@@ -188,6 +243,51 @@ impl LLMClient for CodeStoryClient {
request: LLMClientCompletionStringRequest,
sender: UnboundedSender<LLMClientCompletionResponse>,
) -> Result<String, LLMClientError> {
let llm_model = request.model();
let endpoint = self.model_prompt_endpoint(&llm_model)?;
let code_story_request = CodeStoryRequestPrompt::from_string_request(request)?;
let model = code_story_request.model.to_owned();
let mut response_stream = self
.client
.post(endpoint)
.json(&code_story_request)
.send()
.await?
.bytes_stream()
.eventsource();
let mut buffered_stream = "".to_owned();
while let Some(event) = response_stream.next().await {
match event {
Ok(event) => {
if &event.data == "[DONE]" {
continue;
}
// we just proxy back the openai response back here
let response = serde_json::from_str::<CodeStoryPromptResponse>(&event.data);
match response {
Ok(response) => {
let delta = response
.choices
.get(0)
.map(|choice| choice.text.to_owned())
.unwrap_or("".to_owned());
buffered_stream.push_str(&delta);
sender.send(LLMClientCompletionResponse::new(
buffered_stream.to_owned(),
Some(delta),
model.to_owned(),
))?;
}
Err(e) => {
dbg!(e);
}
}
}
Err(e) => {
dbg!(e);
}
}
}
Ok(buffered_stream)
}
}
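
The new stream_prompt_completion body proxies an OpenAI-style completion stream: each SSE data payload is parsed into choices, the first choice's text is the next delta, and deltas are accumulated until the [DONE] sentinel. A self-contained sketch of just the parsing and accumulation step, assuming serde and serde_json are available; the event payloads below are invented for illustration, and in the real client the events arrive through reqwest's bytes_stream() wrapped by eventsource():

use serde::Deserialize;

// Mirrors the response shapes added in codestory.rs: only each choice's `text` is used.
#[derive(Deserialize, Debug)]
struct Choice {
    text: String,
}

#[derive(Deserialize, Debug)]
struct PromptResponse {
    choices: Vec<Choice>,
}

fn main() {
    // Stand-ins for the `data:` payloads of the SSE stream.
    let events = [
        r#"{"choices":[{"text":"fn "}]}"#,
        r#"{"choices":[{"text":"main"}]}"#,
        "[DONE]",
    ];

    let mut buffered_stream = String::new();
    for event in events {
        // The terminating sentinel carries no payload.
        if event == "[DONE]" {
            continue;
        }
        match serde_json::from_str::<PromptResponse>(event) {
            Ok(response) => {
                let delta = response
                    .choices
                    .first()
                    .map(|choice| choice.text.to_owned())
                    .unwrap_or_default();
                // Append the delta; the real client forwards the growing buffer
                // to the caller over an unbounded channel.
                buffered_stream.push_str(&delta);
            }
            Err(e) => eprintln!("failed to parse event: {e}"),
        }
    }
    assert_eq!(buffered_stream, "fn main");
}
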
11 changes: 10 additions & 1 deletion llm_client/src/provider.rs
@@ -117,7 +117,9 @@ impl LLMProviderAPIKeys {
LLMProvider::CodeStory(_) => Some(LLMProviderAPIKeys::CodeStory),
LLMProvider::OpenAICompatible => {
if let LLMProviderAPIKeys::OpenAICompatible(openai_compatible) = self {
Some(LLMProviderAPIKeys::OpenAICompatible(
openai_compatible.clone(),
))
} else {
None
}
@@ -202,4 +204,11 @@ mod tests {
let string_provider_keys = serde_json::to_string(&provider_keys).expect("to work");
assert_eq!(string_provider_keys, "",);
}

#[test]
fn test_reading_from_string_for_provider_keys() {
let provider_keys = LLMProviderAPIKeys::CodeStory;
let string_provider_keys = serde_json::to_string(&provider_keys).expect("to work");
assert_eq!(string_provider_keys, "\"CodeStory\"");
}
}
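
The added test pins the serialized form of LLMProviderAPIKeys::CodeStory to the bare string "CodeStory". That is serde's default externally tagged behaviour: unit variants become plain strings, struct-like variants become single-key objects. A minimal sketch of that behaviour with a simplified stand-in enum (assuming serde with the derive feature and serde_json as dependencies):

use serde::{Deserialize, Serialize};

// Simplified version of the provider-key enum: one unit variant and one
// struct-like variant, enough to show the two JSON shapes.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
enum ProviderKey {
    CodeStory,
    OpenAICompatible { api_key: String, api_base: String },
}

fn main() {
    // A unit variant serializes to a bare string...
    let codestory = serde_json::to_string(&ProviderKey::CodeStory).unwrap();
    assert_eq!(codestory, "\"CodeStory\"");

    // ...while a struct variant serializes to an externally tagged object.
    let compatible = serde_json::to_string(&ProviderKey::OpenAICompatible {
        api_key: "somethingelse".to_owned(),
        api_base: "testendpoint".to_owned(),
    })
    .unwrap();
    assert_eq!(
        compatible,
        r#"{"OpenAICompatible":{"api_key":"somethingelse","api_base":"testendpoint"}}"#
    );

    // Round-tripping the bare string gives the unit variant back.
    let parsed: ProviderKey = serde_json::from_str("\"CodeStory\"").unwrap();
    assert_eq!(parsed, ProviderKey::CodeStory);
}
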
5 changes: 4 additions & 1 deletion sidecar/src/inline_completion/types.rs
@@ -205,7 +205,10 @@ impl FillInMiddleCompletionAgent {
)],
formatted_string.filled.to_owned(),
)),
_ => Err(InLineCompletionError::InlineCompletionTerminated),
either::Right(Err(e)) => {
println!("{:?}", e);
Err(InLineCompletionError::InlineCompletionTerminated)
}
})
}))
}
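
The new match arm separates a genuine stream error (either::Right(Err(e))) from the other cases, logging it before mapping to InlineCompletionTerminated. A compact stand-alone sketch of matching a nested Either/Result like this, assuming the either crate; the item and error types here are placeholders, not the sidecar's actual ones:

use either::Either;

// Placeholder error type standing in for the sidecar's InLineCompletionError.
#[derive(Debug)]
enum CompletionError {
    InlineCompletionTerminated,
}

// A merged-stream item is either a tick (Left) or a completion result (Right),
// loosely mirroring the shape matched in types.rs; the payload types are invented.
fn handle(item: Either<(), Result<String, String>>) -> Result<String, CompletionError> {
    match item {
        Either::Right(Ok(completion)) => Ok(completion),
        Either::Right(Err(e)) => {
            // Log the underlying error before reporting termination, as the new arm does.
            println!("{:?}", e);
            Err(CompletionError::InlineCompletionTerminated)
        }
        Either::Left(()) => Err(CompletionError::InlineCompletionTerminated),
    }
}

fn main() {
    println!("{:?}", handle(Either::Right(Ok("fn main() {}".to_owned()))));
    println!("{:?}", handle(Either::Right(Err("stream closed".to_owned()))));
}
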
4 changes: 2 additions & 2 deletions sidecar/src/webserver/inline_completion.rs
@@ -103,8 +103,8 @@ pub async fn inline_completion(
let stream = Abortable::new(completions, abort_request);
Ok(Sse::new(Box::pin(stream.filter_map(
|completion| async move {
dbg!("completion is coming along");
dbg!(&completion);
// dbg!("completion is coming along");
// dbg!(&completion);
match completion {
Ok(completion) => Some(
sse::Event::default()
14 changes: 13 additions & 1 deletion sidecar/src/webserver/model_selection.rs
@@ -103,7 +103,9 @@ pub struct Model {

#[cfg(test)]
mod tests {
use llm_client::provider::{
AzureConfig, CodeStoryLLMTypes, LLMProvider, LLMProviderAPIKeys, OllamaProvider,
};

use super::LLMClientConfig;

@@ -115,6 +117,16 @@ mod tests {
assert!(serde_json::from_str::<LLMClientConfig>(data).is_ok());
}

#[test]
fn test_json_should_convert_properly_codestory() {
let provider = LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None });
println!("{:?}", serde_json::to_string(&provider));
let data = r#"
{"slow_model":"DeepSeekCoder33BInstruct","fast_model":"DeepSeekCoder33BInstruct","models":{"Gpt4":{"context_length":8192,"temperature":0.2,"provider":"CodeStory"}},"providers":["CodeStory",{"OpenAIAzureConfig":{"deployment_id":"","api_base":"https://codestory-gpt4.openai.azure.com","api_key":"89ca8a49a33344c9b794b3dabcbbc5d0","api_version":"2023-08-01-preview"}},{"TogetherAI":{"api_key":"cc10d6774e67efef2004b85efdb81a3c9ba0b7682cc33d59c30834183502208d"}},{"OpenAICompatible":{"api_key":"somethingelse","api_base":"testendpoint"}},{"Ollama":{}}]}
"#;
assert!(serde_json::from_str::<LLMClientConfig>(data).is_ok());
}

#[test]
fn test_custom_llm_type_json() {
let llm_config = LLMClientConfig {
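
The new test's config JSON mixes a bare "CodeStory" string with tagged objects inside the same providers array. With serde's externally tagged enums that round-trips naturally: unit variants parse from bare strings, struct variants from single-key objects. A small sketch under those assumptions (reduced variant set, serde/serde_json assumed as dependencies, placeholder key value):

use serde::Deserialize;

// Reduced provider-key enum: just enough variants to parse a `providers` array that
// mixes a bare string with tagged objects, as in the test JSON above.
#[derive(Deserialize, Debug)]
enum ProviderKey {
    CodeStory,
    TogetherAI { api_key: String },
    Ollama {},
}

fn main() {
    let data = r#"["CodeStory",{"TogetherAI":{"api_key":"<hypothetical-key>"}},{"Ollama":{}}]"#;
    let providers: Vec<ProviderKey> = serde_json::from_str(data).expect("valid providers array");
    println!("{:?}", providers);
}
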
