Skip to content

Commit

Permalink
feat: Pass logger through to chat service and log completion chunks a…
Browse files Browse the repository at this point in the history
…s they are generated (#1161)

* Pass logger through to chat service and log completion chunks as they are generated

* Log in the chat completion generator instead of the endpoint

* Create Message struct for log serialization

* Rename fields

* Create function to convert message list

* Remove unnecessary access modifier

* Use Message struct for output

* Use single message instead of list of messages

* Make requested changes

* Update Cargo.toml

* Update Cargo.toml

---------

Co-authored-by: Meng Zhang <[email protected]>
  • Loading branch information
boxbeam and wsxiaoys authored Jan 5, 2024
1 parent 22b8669 commit cb035a6
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 13 deletions.
11 changes: 11 additions & 0 deletions crates/tabby-common/src/api/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ pub enum Event {
#[serde(skip_serializing_if = "Option::is_none")]
user: Option<String>,
},
ChatCompletion {
completion_id: String,
input: Vec<Message>,
output: Message,
},
}

#[derive(Serialize)]
pub struct Message {
pub role: String,
pub content: String,
}

#[derive(Serialize)]
Expand Down
2 changes: 1 addition & 1 deletion crates/tabby/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ reqwest.workspace = true
serde-jsonlines = "0.5.0"

[package.metadata.cargo-machete]
ignored = ["openssl"]
ignored = ["openssl"]
2 changes: 1 addition & 1 deletion crates/tabby/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async fn api_router(

let chat_state = if let Some(chat_model) = &args.chat_model {
Some(Arc::new(
create_chat_service(chat_model, &args.device, args.parallelism).await,
create_chat_service(logger.clone(), chat_model, &args.device, args.parallelism).await,
))
} else {
None
Expand Down
44 changes: 40 additions & 4 deletions crates/tabby/src/services/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use async_stream::stream;
use chat_prompt::ChatPromptBuilder;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use tabby_common::api::event::{Event, EventLogger};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::debug;
use utoipa::ToSchema;
Expand Down Expand Up @@ -75,13 +76,19 @@ impl ChatCompletionChunk {

pub struct ChatService {
engine: Arc<dyn TextGeneration>,
logger: Arc<dyn EventLogger>,
prompt_builder: ChatPromptBuilder,
}

impl ChatService {
fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
fn new(
engine: Arc<dyn TextGeneration>,
logger: Arc<dyn EventLogger>,
chat_template: String,
) -> Self {
Self {
engine,
logger,
prompt_builder: ChatPromptBuilder::new(chat_template),
}
}
Expand All @@ -99,32 +106,61 @@ impl ChatService {
&self,
request: &ChatCompletionRequest,
) -> BoxStream<ChatCompletionChunk> {
let mut event_output = String::new();
let event_input = convert_messages(&request.messages);

let prompt = self.prompt_builder.build(&request.messages);
let options = Self::text_generation_options();
let created = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("Must be able to read system clock")
.as_secs();
let id = format!("chatcmpl-{}", Uuid::new_v4());

debug!("PROMPT: {}", prompt);
let s = stream! {
for await content in self.engine.generate_stream(&prompt, options).await {
event_output.push_str(&content);
yield ChatCompletionChunk::new(content, id.clone(), created, false)
}
yield ChatCompletionChunk::new("".into(), id, created, true)
yield ChatCompletionChunk::new("".into(), id.clone(), created, true);

self.logger.log(Event::ChatCompletion { completion_id: id, input: event_input, output: create_assistant_message(event_output) });
};

Box::pin(s)
}
}

pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
fn create_assistant_message(string: String) -> tabby_common::api::event::Message {
tabby_common::api::event::Message {
content: string,
role: "assistant".into(),
}
}

fn convert_messages(input: &Vec<Message>) -> Vec<tabby_common::api::event::Message> {
input
.iter()
.map(|m| tabby_common::api::event::Message {
content: m.content.clone(),
role: m.role.clone(),
})
.collect()
}

pub async fn create_chat_service(
logger: Arc<dyn EventLogger>,
model: &str,
device: &Device,
parallelism: u8,
) -> ChatService {
let (engine, model::PromptInfo { chat_template, .. }) =
model::load_text_generation(model, device, parallelism).await;

let Some(chat_template) = chat_template else {
fatal!("Chat model requires specifying prompt template");
};

ChatService::new(engine, chat_template)
ChatService::new(engine, logger, chat_template)
}
19 changes: 12 additions & 7 deletions crates/tabby/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{env::consts::ARCH, net::IpAddr, sync::Arc};

use axum::{routing, Router};
use clap::Args;
use tabby_common::api::{code::CodeSearch, event::EventLogger};
use tabby_webserver::public::{HubClient, RegisterWorkerRequest, WorkerKind};
use tracing::info;

Expand Down Expand Up @@ -46,19 +47,21 @@ pub struct WorkerArgs {
parallelism: u8,
}

async fn make_chat_route(args: &WorkerArgs) -> Router {
async fn make_chat_route(logger: Arc<dyn EventLogger>, args: &WorkerArgs) -> Router {
let chat_state =
Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await);
Arc::new(create_chat_service(logger, &args.model, &args.device, args.parallelism).await);

Router::new().route(
"/v1beta/chat/completions",
routing::post(routes::chat_completions).with_state(chat_state),
)
}

async fn make_completion_route(context: WorkerContext, args: &WorkerArgs) -> Router {
let code = Arc::new(context.client.clone());
let logger = Arc::new(context.client);
async fn make_completion_route(
code: Arc<dyn CodeSearch>,
logger: Arc<dyn EventLogger>,
args: &WorkerArgs,
) -> Router {
let completion_state = Arc::new(
create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await,
);
Expand All @@ -75,10 +78,12 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
info!("Starting worker, this might take a few minutes...");

let context = WorkerContext::new(kind.clone(), args).await;
let code = Arc::new(context.client);
let logger = code.clone();

let app = match kind {
WorkerKind::Completion => make_completion_route(context, args).await,
WorkerKind::Chat => make_chat_route(args).await,
WorkerKind::Completion => make_completion_route(code, logger.clone(), args).await,
WorkerKind::Chat => make_chat_route(logger.clone(), args).await,
};

run_app(app, None, args.host, args.port).await
Expand Down

0 comments on commit cb035a6

Please sign in to comment.