From 91d73ffe36cfe0a19415bc7037265fc7b0921e28 Mon Sep 17 00:00:00 2001
From: Wei Zhang
Date: Wed, 30 Oct 2024 10:43:44 +0800
Subject: [PATCH] chore: specify more completion vendors

---
 .../src/completion/mistral.rs                 |  2 +-
 .../http-api-bindings/src/completion/mod.rs   | 28 +++----------------
 .../src/completion/openai.rs                  |  7 ++++-
 3 files changed, 11 insertions(+), 26 deletions(-)

diff --git a/crates/http-api-bindings/src/completion/mistral.rs b/crates/http-api-bindings/src/completion/mistral.rs
index 9327c276567..fef89f49606 100644
--- a/crates/http-api-bindings/src/completion/mistral.rs
+++ b/crates/http-api-bindings/src/completion/mistral.rs
@@ -68,7 +68,7 @@ struct FIMResponseDelta {
 #[async_trait]
 impl CompletionStream for MistralFIMEngine {
     async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
-        let (prompt, suffix) = split_fim_prompt(prompt, true);
+        let (prompt, suffix) = split_fim_prompt(prompt);
         let request = FIMRequest {
             prompt: prompt.to_owned(),
             suffix: suffix.map(|x| x.to_owned()),
diff --git a/crates/http-api-bindings/src/completion/mod.rs b/crates/http-api-bindings/src/completion/mod.rs
index 1da4e8ea08f..00229b84572 100644
--- a/crates/http-api-bindings/src/completion/mod.rs
+++ b/crates/http-api-bindings/src/completion/mod.rs
@@ -31,7 +31,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
             );
             Arc::new(engine)
         }
-        "openai/completion" | "openai/legacy_completion" => {
+        "openai/legacy_completion" | "openai/completion" | "deepseek/completion" => {
             let engine = OpenAICompletionEngine::create(
                 model.model_name.clone(),
                 model
@@ -43,7 +43,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
             );
             Arc::new(engine)
         }
-        "openai/legacy_completion_no_fim" => {
+        "openai/legacy_completion_no_fim" | "vllm/completion" => {
             let engine = OpenAICompletionEngine::create(
                 model.model_name.clone(),
                 model
@@ -73,11 +73,7 @@ pub fn build_completion_prompt(model: &HttpModelConfig) -> (Option<String>, Option<String>) {
     }
 }
 
-fn split_fim_prompt(prompt: &str, support_fim: bool) -> (&str, Option<&str>) {
-    if support_fim {
-        return (prompt, None);
-    }
-
+fn split_fim_prompt(prompt: &str) -> (&str, Option<&str>) {
     let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
     (parts[0], parts.get(1).copied())
 }
@@ -88,22 +84,6 @@ mod tests {
 
     use super::*;
 
-    #[test]
-    fn test_split_fim_prompt() {
-        let support_fim = vec![
-            "prefix<|FIM|>suffix",
-            "prefix<|FIM|>",
-            "<|FIM|>suffix",
-            "<|FIM|>",
-            "prefix",
-        ];
-        for input in support_fim {
-            let (prompt, suffix) = split_fim_prompt(input, true);
-            assert_eq!(prompt, input);
-            assert!(suffix.is_none());
-        }
-    }
-
     #[test]
     fn test_split_fim_prompt_no_fim() {
         let no_fim = vec![
@@ -114,7 +94,7 @@ mod tests {
             ("prefix", ("prefix", None)),
         ];
         for (input, expected) in no_fim {
-            assert_eq!(split_fim_prompt(input, false), expected);
+            assert_eq!(split_fim_prompt(input), expected);
         }
     }
 }
diff --git a/crates/http-api-bindings/src/completion/openai.rs b/crates/http-api-bindings/src/completion/openai.rs
index 8350d7f43e2..d6e8fa572b8 100644
--- a/crates/http-api-bindings/src/completion/openai.rs
+++ b/crates/http-api-bindings/src/completion/openai.rs
@@ -64,7 +64,12 @@ struct CompletionResponseChoice {
 #[async_trait]
 impl CompletionStream for OpenAICompletionEngine {
     async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
-        let (prompt, suffix) = split_fim_prompt(prompt, self.support_fim);
+        let (prompt, suffix) = if self.support_fim {
+            split_fim_prompt(prompt)
+        } else {
+            (prompt, None)
+        };
+
         let request = CompletionRequest {
             model: self.model_name.clone(),
             prompt: prompt.to_owned(),