Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(openai): support fim and no_fim for openai completion #3338

Merged
merged 3 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions crates/http-api-bindings/src/completion/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use serde::{Deserialize, Serialize};
use tabby_inference::{CompletionOptions, CompletionStream};

use super::FIM_TOKEN;
use super::split_fim_prompt;

pub struct MistralFIMEngine {
client: reqwest::Client,
Expand Down Expand Up @@ -68,13 +68,10 @@
#[async_trait]
impl CompletionStream for MistralFIMEngine {
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
let (prompt, suffix) = split_fim_prompt(prompt);

Check warning on line 71 in crates/http-api-bindings/src/completion/mistral.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mistral.rs#L71

Added line #L71 was not covered by tests
let request = FIMRequest {
prompt: parts[0].to_owned(),
suffix: parts
.get(1)
.map(|x| x.to_string())
.filter(|x| !x.is_empty()),
prompt: prompt.to_owned(),
suffix: suffix.map(|x| x.to_owned()),

Check warning on line 74 in crates/http-api-bindings/src/completion/mistral.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mistral.rs#L73-L74

Added lines #L73 - L74 were not covered by tests
model: self.model_name.clone(),
max_tokens: options.max_decoding_tokens,
temperature: options.sampling_temperature,
Expand Down
41 changes: 40 additions & 1 deletion crates/http-api-bindings/src/completion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,27 @@
);
Arc::new(engine)
}
"openai/completion" => {
"openai/legacy_completion" | "openai/completion" | "deepseek/completion" => {

Check warning on line 34 in crates/http-api-bindings/src/completion/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mod.rs#L34

Added line #L34 was not covered by tests
let engine = OpenAICompletionEngine::create(
model.model_name.clone(),
model
.api_endpoint
.as_deref()
.expect("api_endpoint is required"),
model.api_key.clone(),
true,
);
Arc::new(engine)

Check warning on line 44 in crates/http-api-bindings/src/completion/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mod.rs#L42-L44

Added lines #L42 - L44 were not covered by tests
}
"openai/legacy_completion_no_fim" | "vllm/completion" => {
let engine = OpenAICompletionEngine::create(
model.model_name.clone(),
model
.api_endpoint
.as_deref()
.expect("api_endpoint is required"),
model.api_key.clone(),
false,

Check warning on line 54 in crates/http-api-bindings/src/completion/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/mod.rs#L46-L54

Added lines #L46 - L54 were not covered by tests
);
Arc::new(engine)
}
Expand All @@ -59,3 +72,29 @@
(model.prompt_template.clone(), model.chat_template.clone())
}
}

/// Splits a prompt at the first occurrence of `FIM_TOKEN` into a
/// (prefix, suffix) pair for fill-in-the-middle completion requests.
///
/// Returns the text before the token and `Some(text_after)` when the token
/// is present (the suffix may be empty), or the whole prompt and `None`
/// when it is absent. Only the first token is significant: any further
/// `FIM_TOKEN` occurrences remain part of the suffix.
fn split_fim_prompt(prompt: &str) -> (&str, Option<&str>) {
    // `split_once` avoids the intermediate Vec that `splitn(2, …).collect()`
    // would allocate, while producing the exact same split.
    match prompt.split_once(FIM_TOKEN) {
        Some((prefix, suffix)) => (prefix, Some(suffix)),
        None => (prompt, None),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Covers both prompts containing a FIM token (including empty
    /// prefix/suffix sides) and plain prompts without one.
    #[test]
    fn test_split_fim_prompt() {
        // (input, (expected prefix, expected suffix)) pairs.
        let cases = vec![
            ("prefix<|FIM|>suffix", ("prefix", Some("suffix"))),
            ("prefix<|FIM|>", ("prefix", Some(""))),
            ("<|FIM|>suffix", ("", Some("suffix"))),
            ("<|FIM|>", ("", Some(""))),
            ("prefix", ("prefix", None)),
            // Only the first token splits; later tokens stay in the suffix.
            ("a<|FIM|>b<|FIM|>c", ("a", Some("b<|FIM|>c"))),
        ];
        for (input, expected) in cases {
            assert_eq!(split_fim_prompt(input), expected);
        }
    }
}
29 changes: 21 additions & 8 deletions crates/http-api-bindings/src/completion/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,27 @@
use serde::{Deserialize, Serialize};
use tabby_inference::{CompletionOptions, CompletionStream};

use super::FIM_TOKEN;
use super::split_fim_prompt;

/// Completion engine for OpenAI-style `/completions` HTTP endpoints.
pub struct OpenAICompletionEngine {
    client: reqwest::Client,
    model_name: String,
    // Full URL of the completions endpoint (constructed as
    // `"{api_endpoint}/completions"` in `create`).
    api_endpoint: String,
    // Optional bearer token sent with each request; `None` means
    // unauthenticated access.
    api_key: Option<String>,

    /// Whether the backend supports fill-in-the-middle (FIM) prompts.
    /// When `true`, the incoming prompt is split at `FIM_TOKEN` into the
    /// request's `prompt` and `suffix` fields; when `false`, the raw prompt
    /// is sent unchanged with no suffix. Backends without FIM support are
    /// registered under `openai/legacy_completion_no_fim` (and
    /// `vllm/completion`) so callers can opt out of the split.
    support_fim: bool,
}

impl OpenAICompletionEngine {
pub fn create(model_name: Option<String>, api_endpoint: &str, api_key: Option<String>) -> Self {
pub fn create(
model_name: Option<String>,
api_endpoint: &str,
api_key: Option<String>,
support_fim: bool,
) -> Self {

Check warning on line 28 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L23-L28

Added lines #L23 - L28 were not covered by tests
let model_name = model_name.expect("model_name is required for openai/completion");
let client = reqwest::Client::new();

Expand All @@ -24,6 +34,7 @@
model_name,
api_endpoint: format!("{}/completions", api_endpoint),
api_key,
support_fim,

Check warning on line 37 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L37

Added line #L37 was not covered by tests
}
}
}
Expand Down Expand Up @@ -53,14 +64,16 @@
#[async_trait]
impl CompletionStream for OpenAICompletionEngine {
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
let (prompt, suffix) = if self.support_fim {
split_fim_prompt(prompt)

Check warning on line 68 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L67-L68

Added lines #L67 - L68 were not covered by tests
} else {
(prompt, None)

Check warning on line 70 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L70

Added line #L70 was not covered by tests
};

let request = CompletionRequest {
model: self.model_name.clone(),
prompt: parts[0].to_owned(),
suffix: parts
.get(1)
.map(|x| x.to_string())
.filter(|x| !x.is_empty()),
prompt: prompt.to_owned(),
suffix: suffix.map(|x| x.to_owned()),

Check warning on line 76 in crates/http-api-bindings/src/completion/openai.rs

View check run for this annotation

Codecov / codecov/patch

crates/http-api-bindings/src/completion/openai.rs#L75-L76

Added lines #L75 - L76 were not covered by tests
max_tokens: options.max_decoding_tokens,
temperature: options.sampling_temperature,
stream: true,
Expand Down
Loading