Skip to content

Commit

Permalink
chore: specify more completion vendors
Browse files: browse the repository at this point in the history
  • Loading branch information
zwpaper committed Oct 30, 2024
1 parent 0f8dbc6 commit 91d73ff
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 26 deletions.
2 changes: 1 addition & 1 deletion crates/http-api-bindings/src/completion/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ struct FIMResponseDelta {
#[async_trait]
impl CompletionStream for MistralFIMEngine {
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
let (prompt, suffix) = split_fim_prompt(prompt, true);
let (prompt, suffix) = split_fim_prompt(prompt);

Check warning on line 71 in crates/http-api-bindings/src/completion/mistral.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mistral.rs#L71

Added line #L71 was not covered by tests
let request = FIMRequest {
prompt: prompt.to_owned(),
suffix: suffix.map(|x| x.to_owned()),

Check warning on line 74 in crates/http-api-bindings/src/completion/mistral.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mistral.rs#L73-L74

Added lines #L73 - L74 were not covered by tests
Expand Down
28 changes: 4 additions & 24 deletions crates/http-api-bindings/src/completion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
);
Arc::new(engine)
}
"openai/completion" | "openai/legacy_completion" => {
"openai/legacy_completion" | "openai/completion" | "deepseek/completion" => {

Check warning on line 34 in crates/http-api-bindings/src/completion/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mod.rs#L34

Added line #L34 was not covered by tests
let engine = OpenAICompletionEngine::create(
model.model_name.clone(),
model
Expand All @@ -43,7 +43,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
);
Arc::new(engine)

Check warning on line 44 in crates/http-api-bindings/src/completion/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mod.rs#L42-L44

Added lines #L42 - L44 were not covered by tests
}
"openai/legacy_completion_no_fim" => {
"openai/legacy_completion_no_fim" | "vllm/completion" => {
let engine = OpenAICompletionEngine::create(
model.model_name.clone(),
model
Expand Down Expand Up @@ -73,11 +73,7 @@ pub fn build_completion_prompt(model: &HttpModelConfig) -> (Option<String>, Opti
}
}

fn split_fim_prompt(prompt: &str, support_fim: bool) -> (&str, Option<&str>) {
if support_fim {
return (prompt, None);
}

fn split_fim_prompt(prompt: &str) -> (&str, Option<&str>) {
let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
(parts[0], parts.get(1).copied())
}
Expand All @@ -88,22 +84,6 @@ mod tests {

use super::*;

#[test]
fn test_split_fim_prompt() {
let support_fim = vec![
"prefix<|FIM|>suffix",
"prefix<|FIM|>",
"<|FIM|>suffix",
"<|FIM|>",
"prefix",
];
for input in support_fim {
let (prompt, suffix) = split_fim_prompt(input, true);
assert_eq!(prompt, input);
assert!(suffix.is_none());
}
}

#[test]
fn test_split_fim_prompt_no_fim() {
let no_fim = vec![
Expand All @@ -114,7 +94,7 @@ mod tests {
("prefix", ("prefix", None)),
];
for (input, expected) in no_fim {
assert_eq!(split_fim_prompt(input, false), expected);
assert_eq!(split_fim_prompt(input), expected);
}
}
}
7 changes: 6 additions & 1 deletion crates/http-api-bindings/src/completion/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ struct CompletionResponseChoice {
#[async_trait]
impl CompletionStream for OpenAICompletionEngine {
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
let (prompt, suffix) = split_fim_prompt(prompt, self.support_fim);
let (prompt, suffix) = if self.support_fim {
split_fim_prompt(prompt)

Check warning on line 68 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L67-L68

Added lines #L67 - L68 were not covered by tests
} else {
(prompt, None)

Check warning on line 70 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L70

Added line #L70 was not covered by tests
};

let request = CompletionRequest {
model: self.model_name.clone(),
prompt: prompt.to_owned(),
Expand Down

0 comments on commit 91d73ff

Please sign in to comment.