diff --git a/llm_client/src/bin/code_llama_infill.rs b/llm_client/src/bin/code_llama_infill.rs
index 5dc45295d..406b8c243 100644
--- a/llm_client/src/bin/code_llama_infill.rs
+++ b/llm_client/src/bin/code_llama_infill.rs
@@ -19,7 +19,8 @@ async fn main() {
         prompt.to_owned(),
         0.2,
         None,
-    );
+    )
+    .set_max_tokens(100);
     let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
     let response = togetherai
         .stream_prompt_completion(api_key, request, sender)
diff --git a/llm_client/src/bin/deepseek_infill.rs b/llm_client/src/bin/deepseek_infill.rs
index 1e10d07e3..549a8b11f 100644
--- a/llm_client/src/bin/deepseek_infill.rs
+++ b/llm_client/src/bin/deepseek_infill.rs
@@ -1,3 +1,4 @@
+use llm_client::clients::codestory::CodeStoryClient;
 use llm_client::clients::types::LLMClient;
 use llm_client::{
     clients::{ollama::OllamaClient, types::LLMClientCompletionStringRequest},
@@ -16,9 +17,24 @@ async fn main() {
         0.2,
         None,
     );
+    // let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
+    // let response = client
+    //     .stream_prompt_completion(api_key, request, sender)
+    //     .await;
+    // println!("{}", response.expect("to work"));
+    let codestory_client =
+        CodeStoryClient::new("https://codestory-provider-dot-anton-390822.ue.r.appspot.com");
+    let codestory_api_key = LLMProviderAPIKeys::CodeStory;
     let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
-    let response = client
-        .stream_prompt_completion(api_key, request, sender)
+    let request = LLMClientCompletionStringRequest::new(
+        llm_client::clients::types::LLMType::DeepSeekCoder33BInstruct,
+        prompt.to_owned(),
+        0.2,
+        None,
+    )
+    .set_max_tokens(100);
+    let response = codestory_client
+        .stream_prompt_completion(codestory_api_key, request, sender)
         .await;
     println!("{}", response.expect("to work"));
 }
diff --git a/llm_client/src/broker.rs b/llm_client/src/broker.rs
index 7adda6cdf..79ac9ccc0 100644
--- a/llm_client/src/broker.rs
+++ b/llm_client/src/broker.rs
@@ -13,11 +13,12 @@ use crate::{
         lmstudio::LMStudioClient,
         ollama::OllamaClient,
         openai::OpenAIClient,
+        openai_compatible::OpenAICompatibleClient,
         togetherai::TogetherAIClient,
         types::{
             LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse,
             LLMClientCompletionStringRequest, LLMClientError,
-        }, openai_compatible::OpenAICompatibleClient,
+        },
     },
     config::LLMBrokerConfiguration,
     provider::{CodeStoryLLMTypes, LLMProvider, LLMProviderAPIKeys},
@@ -45,7 +46,10 @@ impl LLMBroker {
             .add_provider(LLMProvider::Ollama, Box::new(OllamaClient::new()))
             .add_provider(LLMProvider::TogetherAI, Box::new(TogetherAIClient::new()))
             .add_provider(LLMProvider::LMStudio, Box::new(LMStudioClient::new()))
-            .add_provider(LLMProvider::OpenAICompatible, Box::new(OpenAICompatibleClient::new()))
+            .add_provider(
+                LLMProvider::OpenAICompatible,
+                Box::new(OpenAICompatibleClient::new()),
+            )
             .add_provider(
                 LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None }),
                 Box::new(CodeStoryClient::new(
@@ -173,7 +177,7 @@ impl LLMBroker {
             LLMProviderAPIKeys::LMStudio(_) => LLMProvider::LMStudio,
             LLMProviderAPIKeys::CodeStory => {
                 LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None })
-            },
+            }
             LLMProviderAPIKeys::OpenAICompatible(_) => LLMProvider::OpenAICompatible,
         };
         let provider = self.providers.get(&provider_type);
diff --git a/llm_client/src/clients/codestory.rs b/llm_client/src/clients/codestory.rs
index 28d2d51a7..01e67cf88 100644
--- a/llm_client/src/clients/codestory.rs
+++ b/llm_client/src/clients/codestory.rs
@@ -6,9 +6,12 @@ use tokio::sync::mpsc::UnboundedSender;
 
 use crate::provider::{LLMProvider, LLMProviderAPIKeys};
 
-use super::types::{
-    LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse,
-    LLMClientCompletionStringRequest, LLMClientError, LLMClientRole, LLMType,
+use super::{
+    togetherai::TogetherAIClient,
+    types::{
+        LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse,
+        LLMClientCompletionStringRequest, LLMClientError, LLMClientRole, LLMType,
+    },
 };
 
 #[derive(serde::Serialize, serde::Deserialize, Debug)]
@@ -44,6 +47,46 @@ struct CodeStoryRequest {
     options: CodeStoryRequestOptions,
 }
 
+#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
+pub struct CodeStoryRequestPrompt {
+    prompt: String,
+    temperature: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    stop_tokens: Option<Vec<String>>,
+    model: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    max_tokens: Option<usize>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
+pub struct CodeStoryChoice {
+    pub text: String,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
+pub struct CodeStoryPromptResponse {
+    choices: Vec<CodeStoryChoice>,
+}
+
+impl CodeStoryRequestPrompt {
+    fn from_string_request(
+        request: LLMClientCompletionStringRequest,
+    ) -> Result<Self, LLMClientError> {
+        let model = TogetherAIClient::model_str(request.model());
+        println!("what is the model: {:?} {:?}", model, request.model());
+        match model {
+            Some(model) => Ok(Self {
+                prompt: request.prompt().to_owned(),
+                model,
+                temperature: request.temperature(),
+                stop_tokens: request.stop_words().map(|stop_tokens| stop_tokens.to_vec()),
+                max_tokens: request.get_max_tokens(),
+            }),
+            None => Err(LLMClientError::OpenAIDoesNotSupportCompletion),
+        }
+    }
+}
+
 impl CodeStoryRequest {
     fn from_chat_request(request: LLMClientCompletionRequest) -> Self {
         Self {
@@ -92,6 +135,10 @@
         format!("{api_base}/chat-4")
     }
 
+    pub fn together_api_endpoint(&self, api_base: &str) -> String {
+        format!("{api_base}/together-api")
+    }
+
     pub fn model_name(&self, model: &LLMType) -> Result<String, LLMClientError> {
         match model {
             LLMType::GPT3_5_16k => Ok("gpt-3.5-turbo-16k-0613".to_owned()),
@@ -107,6 +154,15 @@
             _ => Err(LLMClientError::UnSupportedModel),
         }
     }
+
+    pub fn model_prompt_endpoint(&self, model: &LLMType) -> Result<String, LLMClientError> {
+        match model {
+            LLMType::GPT3_5_16k | LLMType::Gpt4 | LLMType::Gpt4Turbo | LLMType::Gpt4_32k => {
+                Err(LLMClientError::UnSupportedModel)
+            }
+            _ => Ok(self.together_api_endpoint(&self.api_base)),
+        }
+    }
 }
 
 #[async_trait]
@@ -162,7 +218,6 @@ impl LLMClient for CodeStoryClient {
                         .flatten()
                         .unwrap_or("".to_owned());
                     buffered_stream.push_str(&delta);
-                    println!("{}", &buffered_stream);
                     sender.send(LLMClientCompletionResponse::new(
                         buffered_stream.to_owned(),
                         Some(delta),
@@ -188,6 +243,51 @@
         request: LLMClientCompletionStringRequest,
         sender: UnboundedSender<LLMClientCompletionResponse>,
     ) -> Result<String, LLMClientError> {
-        Err(LLMClientError::UnSupportedModel)
+        let llm_model = request.model();
+        let endpoint = self.model_prompt_endpoint(&llm_model)?;
+        let code_story_request = CodeStoryRequestPrompt::from_string_request(request)?;
+        let model = code_story_request.model.to_owned();
+        let mut response_stream = self
+            .client
+            .post(endpoint)
+            .json(&code_story_request)
+            .send()
+            .await?
+            .bytes_stream()
+            .eventsource();
+        let mut buffered_stream = "".to_owned();
+        while let Some(event) = response_stream.next().await {
+            match event {
+                Ok(event) => {
+                    if &event.data == "[DONE]" {
+                        continue;
+                    }
+                    // we just proxy back the openai response here
+                    let response = serde_json::from_str::<CodeStoryPromptResponse>(&event.data);
+                    match response {
+                        Ok(response) => {
+                            let delta = response
+                                .choices
+                                .get(0)
+                                .map(|choice| choice.text.to_owned())
+                                .unwrap_or("".to_owned());
+                            buffered_stream.push_str(&delta);
+                            sender.send(LLMClientCompletionResponse::new(
+                                buffered_stream.to_owned(),
+                                Some(delta),
+                                model.to_owned(),
+                            ))?;
+                        }
+                        Err(e) => {
+                            dbg!(e);
+                        }
+                    }
+                }
+                Err(e) => {
+                    dbg!(e);
+                }
+            }
+        }
+        Ok(buffered_stream)
     }
 }
diff --git a/llm_client/src/provider.rs b/llm_client/src/provider.rs
index c302665e8..edd4f422d 100644
--- a/llm_client/src/provider.rs
+++ b/llm_client/src/provider.rs
@@ -117,7 +117,9 @@
             LLMProvider::CodeStory(_) => Some(LLMProviderAPIKeys::CodeStory),
             LLMProvider::OpenAICompatible => {
                 if let LLMProviderAPIKeys::OpenAICompatible(openai_compatible) = self {
-                    Some(LLMProviderAPIKeys::OpenAICompatible(openai_compatible.clone()))
+                    Some(LLMProviderAPIKeys::OpenAICompatible(
+                        openai_compatible.clone(),
+                    ))
                 } else {
                     None
                 }
@@ -202,4 +204,11 @@ mod tests {
         let string_provider_keys = serde_json::to_string(&provider_keys).expect("to work");
         assert_eq!(string_provider_keys, "",);
     }
+
+    #[test]
+    fn test_reading_from_string_for_provider_keys() {
+        let provider_keys = LLMProviderAPIKeys::CodeStory;
+        let string_provider_keys = serde_json::to_string(&provider_keys).expect("to work");
+        assert_eq!(string_provider_keys, "\"CodeStory\"");
+    }
 }
diff --git a/sidecar/src/inline_completion/types.rs b/sidecar/src/inline_completion/types.rs
index 9a99d6763..98277f52c 100644
--- a/sidecar/src/inline_completion/types.rs
+++ b/sidecar/src/inline_completion/types.rs
@@ -205,7 +205,10 @@
                     )],
                     formatted_string.filled.to_owned(),
                 )),
-                _ => Err(InLineCompletionError::InlineCompletionTerminated),
+                either::Right(Err(e)) => {
+                    println!("{:?}", e);
+                    Err(InLineCompletionError::InlineCompletionTerminated)
+                }
             })
         }))
     }
diff --git a/sidecar/src/webserver/inline_completion.rs b/sidecar/src/webserver/inline_completion.rs
index 3b69483c3..94115f833 100644
--- a/sidecar/src/webserver/inline_completion.rs
+++ b/sidecar/src/webserver/inline_completion.rs
@@ -103,8 +103,8 @@ pub async fn inline_completion(
     let stream = Abortable::new(completions, abort_request);
     Ok(Sse::new(Box::pin(stream.filter_map(
         |completion| async move {
-            dbg!("completion is coming along");
-            dbg!(&completion);
+            // dbg!("completion is coming along");
+            // dbg!(&completion);
             match completion {
                 Ok(completion) => Some(
                     sse::Event::default()
diff --git a/sidecar/src/webserver/model_selection.rs b/sidecar/src/webserver/model_selection.rs
index 635c92ea8..51ff47521 100644
--- a/sidecar/src/webserver/model_selection.rs
+++ b/sidecar/src/webserver/model_selection.rs
@@ -103,7 +103,9 @@ pub struct Model {
 
 #[cfg(test)]
 mod tests {
-    use llm_client::provider::{AzureConfig, LLMProviderAPIKeys, OllamaProvider};
+    use llm_client::provider::{
+        AzureConfig, CodeStoryLLMTypes, LLMProvider, LLMProviderAPIKeys, OllamaProvider,
+    };
 
     use super::LLMClientConfig;
 
@@ -115,6 +117,16 @@
         assert!(serde_json::from_str::<LLMClientConfig>(data).is_ok());
     }
 
+    #[test]
+    fn test_json_should_convert_properly_codestory() {
+        let provider = LLMProvider::CodeStory(CodeStoryLLMTypes { llm_type: None });
println!("{:?}", serde_json::to_string(&provider)); + let data = r#" + {"slow_model":"DeepSeekCoder33BInstruct","fast_model":"DeepSeekCoder33BInstruct","models":{"Gpt4":{"context_length":8192,"temperature":0.2,"provider":"CodeStory"}},"providers":["CodeStory",{"OpenAIAzureConfig":{"deployment_id":"","api_base":"https://codestory-gpt4.openai.azure.com","api_key":"89ca8a49a33344c9b794b3dabcbbc5d0","api_version":"2023-08-01-preview"}},{"TogetherAI":{"api_key":"cc10d6774e67efef2004b85efdb81a3c9ba0b7682cc33d59c30834183502208d"}},{"OpenAICompatible":{"api_key":"somethingelse","api_base":"testendpoint"}},{"Ollama":{}}]} + "#; + assert!(serde_json::from_str::(data).is_ok()); + } + #[test] fn test_custom_llm_type_json() { let llm_config = LLMClientConfig {