Skip to content

Commit

Permalink
fix(core): fix ggml path loading for windows. fix user_agent field (m… (
Browse files Browse the repository at this point in the history
#3152)

* fix(core): fix ggml path loading for windows. fix user_agent field (mark as optional)

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored Sep 16, 2024
1 parent cc9f7ef commit 531062a
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 15 deletions.
2 changes: 1 addition & 1 deletion crates/llama-cpp-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
async fn resolve_model_path(model_id: &str) -> String {
let path = PathBuf::from(model_id);
let path = if path.exists() {
path.join(GGML_MODEL_RELATIVE_PATH)
path.join(GGML_MODEL_RELATIVE_PATH.as_str())
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
Expand Down
2 changes: 1 addition & 1 deletion crates/tabby-common/src/api/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub enum Event {
#[serde(skip_serializing_if = "Option::is_none")]
segments: Option<Segments>,
choices: Vec<Choice>,
user_agent: String,
user_agent: Option<String>,
},
ChatCompletion {
completion_id: String,
Expand Down
16 changes: 11 additions & 5 deletions crates/tabby-common/src/registry.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{fs, path::PathBuf};

use anyhow::{Context, Result};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};

use crate::path::models_dir;
Expand Down Expand Up @@ -76,7 +77,7 @@ impl ModelRegistry {
let model_path = self.get_model_path(name);
let old_model_path = self
.get_model_dir(name)
.join(LEGACY_GGML_MODEL_RELATIVE_PATH);
.join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str());

if !model_path.exists() && old_model_path.exists() {
std::fs::rename(&old_model_path, &model_path)?;
Expand All @@ -89,7 +90,8 @@ impl ModelRegistry {
}

pub fn get_model_path(&self, name: &str) -> PathBuf {
self.get_model_dir(name).join(GGML_MODEL_RELATIVE_PATH)
self.get_model_dir(name)
.join(GGML_MODEL_RELATIVE_PATH.as_str())
}

pub fn save_model_info(&self, name: &str) {
Expand Down Expand Up @@ -118,8 +120,12 @@ pub fn parse_model_id(model_id: &str) -> (&str, &str) {
}
}

pub static LEGACY_GGML_MODEL_RELATIVE_PATH: &str = "ggml/q8_0.v2.gguf";
pub static GGML_MODEL_RELATIVE_PATH: &str = "ggml/model.gguf";
lazy_static! {
pub static ref LEGACY_GGML_MODEL_RELATIVE_PATH: String =
format!("ggml{}q8_0.v2.gguf", std::path::MAIN_SEPARATOR_STR);
pub static ref GGML_MODEL_RELATIVE_PATH: String =
format!("ggml{}model.gguf", std::path::MAIN_SEPARATOR_STR);
}

#[cfg(test)]
mod tests {
Expand All @@ -136,7 +142,7 @@ mod tests {
let registry = ModelRegistry::new("TabbyML").await;
let dir = registry.get_model_dir("StarCoder-1B");

let old_model_path = dir.join(LEGACY_GGML_MODEL_RELATIVE_PATH);
let old_model_path = dir.join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str());
tokio::fs::create_dir_all(old_model_path.parent().unwrap())
.await
.unwrap();
Expand Down
6 changes: 4 additions & 2 deletions crates/tabby/src/routes/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ use crate::services::completion::{CompletionRequest, CompletionResponse, Complet
pub async fn completions(
State(state): State<Arc<CompletionService>>,
TypedHeader(MaybeUser(user)): TypedHeader<MaybeUser>,
TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Json(mut request): Json<CompletionRequest>,
) -> Result<Json<CompletionResponse>, StatusCode> {
if let Some(user) = user {
request.user.replace(user);
}

match state.generate(&request, &user_agent.to_string()).await {
let user_agent = user_agent.map(|x| x.0.to_string());

match state.generate(&request, user_agent.as_deref()).await {
Ok(resp) => Ok(Json(resp)),
Err(err) => {
warn!("{}", err);
Expand Down
6 changes: 3 additions & 3 deletions crates/tabby/src/services/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ impl CompletionService {
pub async fn generate(
&self,
request: &CompletionRequest,
user_agent: &str,
user_agent: Option<&str>,
) -> Result<CompletionResponse, CompletionError> {
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let language = request.language_or_unknown();
Expand Down Expand Up @@ -338,7 +338,7 @@ impl CompletionService {
index: 0,
text: text.clone(),
}],
user_agent: user_agent.to_string(),
user_agent: user_agent.map(|x| x.to_owned()),
},
);

Expand Down Expand Up @@ -462,7 +462,7 @@ mod tests {
};

let response = completion_service
.generate(&request, "test user agent")
.generate(&request, Some("test user agent"))
.await
.unwrap();
assert_eq!(response.choices[0].text, r#""Hello, world!""#);
Expand Down
10 changes: 10 additions & 0 deletions crates/tabby/tests/goldentests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ async fn golden_test(body: serde_json::Value) -> serde_json::Value {
}),
);

let resp = CLIENT
.post("http://127.0.0.1:9090/v1/completions")
.json(&body)
.send()
.await
.unwrap();

let info = resp.text().await.unwrap();
eprintln!("info {}", info);

let actual: serde_json::Value = CLIENT
.post("http://127.0.0.1:9090/v1/completions")
.json(&body)
Expand Down
6 changes: 3 additions & 3 deletions ee/tabby-webserver/src/service/event_logger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ mod tests {
prompt: "testprompt".into(),
segments: None,
choices: vec![],
user_agent: "ide: version test".into(),
user_agent: Some("ide: version test".into()),
},
);

Expand Down Expand Up @@ -242,7 +242,7 @@ mod tests {
prompt: "testprompt".into(),
segments: None,
choices: vec![],
user_agent: "ide: version unknown".into(),
user_agent: Some("ide: version unknown".into()),
},
);

Expand All @@ -257,7 +257,7 @@ mod tests {
prompt: "testprompt".into(),
segments: None,
choices: vec![],
user_agent: "ide: version unknown".into(),
user_agent: Some("ide: version unknown".into()),
},
);

Expand Down

0 comments on commit 531062a

Please sign in to comment.