Commit

refactor: extract LlamaService
wsxiaoys committed Nov 7, 2023
1 parent f532254 commit ece5f3c
Showing 3 changed files with 184 additions and 138 deletions.
165 changes: 28 additions & 137 deletions crates/llama-cpp-bindings/src/lib.rs
@@ -1,20 +1,14 @@
+mod llama;
 mod utils;
 
-use std::{collections::HashMap, sync::Arc};
-
 use async_stream::stream;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use derive_builder::Builder;
 use ffi::create_engine;
-use futures::{lock::Mutex, stream::BoxStream};
+use futures::stream::BoxStream;
+use llama::LlamaService;
 use tabby_inference::{
-    decoding::{StopCondition, StopConditionFactory},
-    helpers, TextGeneration, TextGenerationOptions,
-};
-use tokio::{
-    sync::mpsc::{channel, Sender},
-    task::yield_now,
+    decoding::StopConditionFactory, helpers, TextGeneration, TextGenerationOptions,
 };

#[cxx::bridge(namespace = "llama")]
@@ -45,66 +39,36 @@ mod ffi {
 unsafe impl Send for ffi::TextInferenceEngine {}
 unsafe impl Sync for ffi::TextInferenceEngine {}
 
-struct InferenceRequest {
-    tx: Sender<String>,
-    stop_condition: StopCondition,
+#[derive(Builder, Debug)]
+pub struct LlamaTextGenerationOptions {
+    model_path: String,
+    use_gpu: bool,
 }
 
-struct AsyncTextInferenceEngine {
-    engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
+pub struct LlamaTextGeneration {
+    service: LlamaService,
     stop_condition_factory: StopConditionFactory,
-    requests: Mutex<HashMap<u32, InferenceRequest>>,
-
-    next_request_id: Mutex<u32>,
 }
 
-impl AsyncTextInferenceEngine {
-    fn create(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
+impl LlamaTextGeneration {
+    pub fn new(options: LlamaTextGenerationOptions) -> Self {
+        let engine = create_engine(options.use_gpu, &options.model_path);
+        if engine.is_null() {
+            fatal!("Unable to load model: {}", options.model_path);
+        }
+
         Self {
-            engine: Mutex::new(engine),
+            service: LlamaService::new(engine),
             stop_condition_factory: StopConditionFactory::default(),
-            requests: Mutex::new(HashMap::new()),
-            next_request_id: Mutex::new(0),
         }
     }
+}
 
-    async fn background_job(&self) {
-        let mut requests = self.requests.lock().await;
-        if requests.len() == 0 {
-            return;
-        }
-
-        let mut engine = self.engine.lock().await;
-
-        let result = match engine.as_mut().unwrap().step() {
-            Ok(result) => result,
-            Err(err) => {
-                fatal!("Failed to step: {}", err)
-            }
-        };
-
-        for ffi::StepOutput { request_id, text } in result {
-            let mut stopped = false;
-            let InferenceRequest { tx, stop_condition } = requests.get_mut(&request_id).unwrap();
-
-            if tx.is_closed() || text.is_empty() {
-                // Cancelled by client side or hit eos.
-                stopped = true;
-            } else if !stop_condition.should_stop(&text) {
-                match tx.send(text).await {
-                    Ok(_) => (),
-                    Err(_) => stopped = true,
-                }
-            } else {
-                // Stoop words stopped
-                stopped = true;
-            }
-
-            if stopped {
-                requests.remove(&request_id);
-                engine.as_mut().unwrap().stop_request(request_id);
-            }
-        }
+#[async_trait]
+impl TextGeneration for LlamaTextGeneration {
+    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
+        let s = self.generate_stream(prompt, options).await;
+        helpers::stream_to_string(s).await
     }
 
     async fn generate_stream(
@@ -114,23 +78,10 @@ impl AsyncTextInferenceEngine {
     ) -> BoxStream<String> {
         let stop_condition = self.stop_condition_factory.create(prompt, options.language);
 
-        let request_id = self.alloc_request_id().await;
-        let mut rx = {
-            let (tx, rx) = channel::<String>(4);
-            self.requests
-                .lock()
-                .await
-                .insert(request_id, InferenceRequest { tx, stop_condition });
-            rx
-        };
-
-        {
-            let mut engine = self.engine.lock().await;
-            engine
-                .as_mut()
-                .unwrap()
-                .add_request(request_id, prompt, options.max_input_length);
-        }
+        let mut rx = self
+            .service
+            .add_request(prompt, options.max_input_length, stop_condition)
+            .await;
 
         let s = stream! {
             let mut length = 0;
@@ -147,64 +98,4 @@ impl AsyncTextInferenceEngine {
 
         Box::pin(s)
     }
-
-    async fn alloc_request_id(&self) -> u32 {
-        let mut request_id = self.next_request_id.lock().await;
-        let ret: u32 = *request_id;
-
-        // 2048 should be large enough to avoid collision.
-        *request_id = (*request_id + 1) % 2048;
-
-        ret
-    }
-}
-
-#[derive(Builder, Debug)]
-pub struct LlamaTextGenerationOptions {
-    model_path: String,
-    use_gpu: bool,
-}
-
-pub struct LlamaTextGeneration {
-    engine: Arc<AsyncTextInferenceEngine>,
-}
-
-impl LlamaTextGeneration {
-    pub fn create(options: LlamaTextGenerationOptions) -> Self {
-        let engine = create_engine(options.use_gpu, &options.model_path);
-        if engine.is_null() {
-            fatal!("Unable to load model: {}", options.model_path);
-        }
-        let ret = LlamaTextGeneration {
-            engine: Arc::new(AsyncTextInferenceEngine::create(engine)),
-        };
-        ret.start_background_job();
-        ret
-    }
-
-    pub fn start_background_job(&self) {
-        let engine = self.engine.clone();
-        tokio::spawn(async move {
-            loop {
-                engine.background_job().await;
-                yield_now().await;
-            }
-        });
-    }
-}
-
-#[async_trait]
-impl TextGeneration for LlamaTextGeneration {
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let s = self.generate_stream(prompt, options).await;
-        helpers::stream_to_string(s).await
-    }
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        options: TextGenerationOptions,
-    ) -> BoxStream<String> {
-        self.engine.generate_stream(prompt, options).await
-    }
-}
 }
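
For orientation, here is a rough caller-side sketch of the refactored API. The option fields, LlamaTextGeneration::new, and generate_stream come straight from this diff; the builder name (LlamaTextGenerationOptionsBuilder, as derive_builder would generate it), the helper function, and the model path are illustrative assumptions, not code from the commit.

// Caller-side sketch (not part of this commit); assumes the derive_builder-generated
// LlamaTextGenerationOptionsBuilder and a caller-supplied TextGenerationOptions.
use futures::StreamExt;
use llama_cpp_bindings::{LlamaTextGeneration, LlamaTextGenerationOptionsBuilder};
use tabby_inference::{TextGeneration, TextGenerationOptions};

async fn complete(model_path: &str, prompt: &str, options: TextGenerationOptions) -> String {
    let engine_options = LlamaTextGenerationOptionsBuilder::default()
        .model_path(model_path.to_owned())
        .use_gpu(false)
        .build()
        .unwrap();

    // new() loads the model via ffi::create_engine and starts the LlamaService worker.
    let engine = LlamaTextGeneration::new(engine_options);

    // generate() is just stream_to_string over this same stream.
    let mut stream = engine.generate_stream(prompt, options).await;
    let mut out = String::new();
    while let Some(text) = stream.next().await {
        out.push_str(&text);
    }
    out
}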
155 changes: 155 additions & 0 deletions crates/llama-cpp-bindings/src/llama.rs
@@ -0,0 +1,155 @@
+use std::{collections::HashMap, thread::JoinHandle};
+
+use cxx::UniquePtr;
+use tabby_inference::decoding::StopCondition;
+use tokio::sync::mpsc::{channel, Receiver, Sender};
+
+use crate::ffi;
+
+struct LlamaInitRequest {
+    prompt: String,
+    max_input_length: usize,
+
+    tx: Sender<String>,
+    stop_condition: StopCondition,
+}
+
+struct LlamaRunningRequest {
+    tx: Sender<String>,
+    stop_condition: StopCondition,
+}
+
+struct LlamaServiceImpl {
+    next_request_id: u32,
+    engine: cxx::UniquePtr<ffi::TextInferenceEngine>,
+    rx: Receiver<LlamaInitRequest>,
+    requests: HashMap<u32, LlamaRunningRequest>,
+}
+
+impl LlamaServiceImpl {
+    fn new(engine: UniquePtr<ffi::TextInferenceEngine>, rx: Receiver<LlamaInitRequest>) -> Self {
+        Self {
+            next_request_id: 0,
+            engine,
+            rx,
+            requests: HashMap::new(),
+        }
+    }
+
+    fn alloc_request_id(&mut self) -> u32 {
+        let ret = self.next_request_id;
+        self.next_request_id += 1;
+        ret
+    }
+
+    async fn next_request(&mut self) -> Option<LlamaInitRequest> {
+        if self.requests.is_empty() {
+            self.rx.recv().await
+        } else {
+            self.rx.try_recv().ok()
+        }
+    }
+
+    async fn background_job(&mut self) {
+        while let Some(LlamaInitRequest {
+            prompt,
+            tx,
+            max_input_length,
+            stop_condition,
+        }) = self.next_request().await
+        {
+            let request_id = self.alloc_request_id();
+            self.requests
+                .insert(request_id, LlamaRunningRequest { tx, stop_condition });
+            self.engine
+                .as_mut()
+                .unwrap()
+                .add_request(request_id, &prompt, max_input_length);
+        }
+
+        let result = match self.engine.as_mut().unwrap().step() {
+            Ok(result) => result,
+            Err(err) => {
+                crate::fatal!("Failed to step: {}", err)
+            }
+        };
+
+        for ffi::StepOutput { request_id, text } in result {
+            let mut stopped = false;
+            let LlamaRunningRequest { tx, stop_condition } =
+                self.requests.get_mut(&request_id).unwrap();
+
+            if tx.is_closed() || text.is_empty() {
+                // Cancelled by client side or hit eos.
+                stopped = true;
+            } else if !stop_condition.should_stop(&text) {
+                match tx.send(text).await {
+                    Ok(_) => (),
+                    Err(_) => stopped = true,
+                }
+            } else {
+                // Stopped by a stop word.
+                stopped = true;
+            }
+
+            if stopped {
+                self.requests.remove(&request_id);
+                self.engine.as_mut().unwrap().stop_request(request_id);
+            }
+        }
+    }
+}
+
+fn start_llama_service_impl(
+    engine: UniquePtr<ffi::TextInferenceEngine>,
+    rx: Receiver<LlamaInitRequest>,
+) -> JoinHandle<()> {
+    let mut service = LlamaServiceImpl::new(engine, rx);
+    let rt = tokio::runtime::Builder::new_current_thread()
+        .enable_all()
+        .build()
+        .unwrap();
+
+    std::thread::spawn(move || {
+        let local = tokio::task::LocalSet::new();
+        local.spawn_local(async move {
+            loop {
+                service.background_job().await;
+            }
+        });
+
+        rt.block_on(local);
+    })
+}
+
+pub struct LlamaService {
+    tx: Sender<LlamaInitRequest>,
+}
+
+impl LlamaService {
+    pub fn new(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
+        let (tx, rx) = channel(20);
+        start_llama_service_impl(engine, rx);
+        Self { tx }
+    }
+
+    pub async fn add_request(
+        &self,
+        prompt: &str,
+        max_input_length: usize,
+        stop_condition: StopCondition,
+    ) -> Receiver<String> {
+        let (tx, rx) = channel(8);
+        self.tx
+            .send(LlamaInitRequest {
+                prompt: prompt.to_owned(),
+                tx,
+                max_input_length,
+                stop_condition,
+            })
+            .await
+            .expect("Failed to add request");
+
+        rx
+    }
+}
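
The core idea of LlamaService, moving the non-Send cxx engine onto a dedicated worker thread with its own current-thread runtime and LocalSet, and talking to it only through bounded mpsc channels, can be reduced to the standalone sketch below. Everything here (Engine, Request, start_worker, ask) is an illustrative stand-in; only the threading and runtime shape mirrors start_llama_service_impl.

// Standalone sketch of the LlamaService threading pattern (illustrative names only).
use std::cell::RefCell;
use std::rc::Rc;
use tokio::sync::mpsc::{channel, Receiver, Sender};

// Stand-in for cxx::UniquePtr<ffi::TextInferenceEngine>: !Send, so it must stay
// on the thread that created it.
struct Engine {
    history: Rc<RefCell<Vec<String>>>,
}

impl Engine {
    fn generate(&mut self, prompt: &str) -> String {
        self.history.borrow_mut().push(prompt.to_owned());
        format!("echo: {prompt}")
    }
}

struct Request {
    prompt: String,
    tx: Sender<String>,
}

// One OS thread owns the engine; a current_thread runtime plus LocalSet drives the
// !Send future. No Mutex is needed, unlike the old AsyncTextInferenceEngine,
// because only this thread ever touches the engine.
fn start_worker(mut rx: Receiver<Request>) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || {
        let mut engine = Engine {
            history: Rc::new(RefCell::new(Vec::new())),
        };
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let local = tokio::task::LocalSet::new();
        local.spawn_local(async move {
            while let Some(req) = rx.recv().await {
                let text = engine.generate(&req.prompt);
                let _ = req.tx.send(text).await; // receiver dropped => client cancelled
            }
        });
        rt.block_on(local);
    })
}

// Client side: the same shape as LlamaService::add_request.
async fn ask(worker_tx: &Sender<Request>, prompt: &str) -> Option<String> {
    let (tx, mut rx) = channel(8);
    worker_tx
        .send(Request {
            prompt: prompt.to_owned(),
            tx,
        })
        .await
        .ok()?;
    rx.recv().await
}

Note that both channels in the real code are bounded (capacity 20 for incoming requests, 8 for streamed text), so a slow consumer applies backpressure to the worker rather than growing an unbounded queue.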
2 changes: 1 addition & 1 deletion crates/tabby/src/serve/engine.rs
@@ -64,5 +64,5 @@ fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextG
         .build()
         .unwrap();
 
-    Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options))
+    Box::new(llama_cpp_bindings::LlamaTextGeneration::new(options))
 }
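
Since create_ggml_engine only needs something that implements tabby_inference::TextGeneration, the trait surface the boxed engine must provide is small. The following hypothetical no-op backend illustrates the two required methods, using the same signatures and helpers visible in this diff; it is not part of the commit.

// Hypothetical backend illustrating the TextGeneration trait surface (not in this commit).
use async_stream::stream;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};

struct EchoGeneration;

#[async_trait]
impl TextGeneration for EchoGeneration {
    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
        // Same pattern as LlamaTextGeneration: collect the stream into one String.
        let s = self.generate_stream(prompt, options).await;
        helpers::stream_to_string(s).await
    }

    async fn generate_stream(
        &self,
        prompt: &str,
        _options: TextGenerationOptions,
    ) -> BoxStream<String> {
        let prompt = prompt.to_owned();
        Box::pin(stream! {
            // A real backend would yield model output incrementally.
            yield prompt;
        })
    }
}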
