Skip to content

Commit

Permalink
Merge pull request #15 from pythops/support_stream_response
Browse files Browse the repository at this point in the history
Support stream response
  • Loading branch information
pythops authored Aug 12, 2023
2 parents f90264b + bbb8b90 commit c007069
Show file tree
Hide file tree
Showing 9 changed files with 482 additions and 331 deletions.
673 changes: 380 additions & 293 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tenere"
version = "0.6.0"
version = "0.7.0"
authors = ["pythops <[email protected]>"]
license = "AGPLv3"
edition = "2021"
Expand All @@ -21,4 +21,6 @@ ansi-to-tui = "3.1.0"
clap = { version = "4", features = ["derive", "cargo"] }
toml = { version = "0.7" }
serde = { version = "1.0", features = ["derive"] }
dirs = "5.0.0"
dirs = "5.0.1"
regex = "1.9.3"
colored = "2.0.4"
2 changes: 2 additions & 0 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub struct App {
pub focused_block: FocusedBlock,
pub show_help_popup: bool,
pub llm_messages: Vec<HashMap<String, String>>,
pub answer: String,
pub history: Vec<Vec<String>>,
pub show_history_popup: bool,
pub history_thread_index: usize,
Expand All @@ -55,6 +56,7 @@ impl App {
focused_block: FocusedBlock::Prompt,
show_help_popup: false,
llm_messages: Vec::new(),
answer: String::new(),
history: Vec::new(),
show_history_popup: false,
history_thread_index: 0,
Expand Down
58 changes: 43 additions & 15 deletions src/chatgpt.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use crate::event::Event;
use regex::Regex;
use std::{thread, time};

use crate::config::ChatGPTConfig;
use crate::llm::LLM;
use crate::llm::{LLMAnswer, LLM};
use reqwest::header::HeaderMap;
use serde_json::{json, Value};
use std;
use std::collections::HashMap;
use std::io::Read;
use std::sync::mpsc::Sender;

#[derive(Clone, Debug)]
pub struct ChatGPT {
Expand Down Expand Up @@ -41,12 +47,13 @@ impl LLM for ChatGPT {
fn ask(
&self,
chat_messages: Vec<HashMap<String, String>>,
) -> Result<String, Box<dyn std::error::Error>> {
sender: &Sender<Event>,
) -> Result<(), Box<dyn std::error::Error>> {
let mut headers = HeaderMap::new();
headers.insert("Content-Type", "application/json".parse().unwrap());
headers.insert("Content-Type", "application/json".parse()?);
headers.insert(
"Authorization",
format!("Bearer {}", self.openai_api_key).parse().unwrap(),
format!("Bearer {}", self.openai_api_key).parse()?,
);

let mut messages: Vec<HashMap<String, String>> = vec![
Expand All @@ -63,9 +70,12 @@ impl LLM for ChatGPT {

let body: Value = json!({
"model": "gpt-3.5-turbo",
"messages": messages
"messages": messages,
"stream": true,
});

let mut buffer = String::new();

let response = self
.client
.post(&self.url)
Expand All @@ -74,17 +84,35 @@ impl LLM for ChatGPT {
.send()?;

match response.error_for_status() {
Ok(res) => {
let response_body: Value = res.json()?;
let answer = response_body["choices"][0]["message"]["content"]
.as_str()
.unwrap()
.trim_matches('"')
.to_string();

Ok(answer)
Ok(mut res) => {
let _answser = res.read_to_string(&mut buffer)?;

let re = Regex::new(r"data:\s(.*)").unwrap();

sender.send(Event::LLMEvent(LLMAnswer::StartAnswer))?;

for captures in re.captures_iter(&buffer) {
if let Some(data_json) = captures.get(1) {
if data_json.as_str() == "[DONE]" {
sender.send(Event::LLMEvent(LLMAnswer::EndAnswer)).unwrap();
break;
}
let x: Value = serde_json::from_str(data_json.as_str()).unwrap();

let msg = x["choices"][0]["delta"]["content"].as_str().unwrap_or("\n");

if msg != "null" {
sender
.send(Event::LLMEvent(LLMAnswer::Answer(msg.to_string())))
.unwrap();
}
thread::sleep(time::Duration::from_millis(100));
}
}
}
Err(e) => Err(Box::new(e)),
Err(e) => return Err(Box::new(e)),
}

Ok(())
}
}
3 changes: 2 additions & 1 deletion src/event.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::app::AppResult;
use crate::llm::LLMAnswer;
use crate::notification::Notification;
use crossterm::event::{self, Event as CrosstermEvent, KeyEvent, MouseEvent};
use std::sync::mpsc;
Expand All @@ -11,7 +12,7 @@ pub enum Event {
Key(KeyEvent),
Mouse(MouseEvent),
Resize(u16, u16),
LLMAnswer(String),
LLMEvent(LLMAnswer),
Notification(Notification),
}

Expand Down
25 changes: 16 additions & 9 deletions src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::llm::LLMAnswer;
use crate::{
app::{App, AppResult, FocusedBlock, Mode},
event::Event,
};
use colored::*;

use crate::llm::LLM;
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
Expand Down Expand Up @@ -46,17 +48,22 @@ pub fn handle_key_events(

let llm_messages = app.llm_messages.clone();

app.spinner.active = true;
app.chat.push("🤖: ".to_string());

thread::spawn(move || {
let response = llm.ask(llm_messages.to_vec());
sender
.send(Event::LLMAnswer(match response {
Ok(answer) => answer,
Err(e) => e.to_string(),
}))
.unwrap();
let res = llm.ask(llm_messages.to_vec(), &sender);
if let Err(e) = res {
sender
.send(Event::LLMEvent(LLMAnswer::StartAnswer))
.unwrap();
sender
.send(Event::LLMEvent(LLMAnswer::Answer(
e.to_string().red().to_string(),
)))
.unwrap();
}
});
app.spinner.active = true;
app.chat.push("🤖: Waiting ..".to_string());
}

// scroll down
Expand Down
12 changes: 11 additions & 1 deletion src/llm.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
use crate::chatgpt::ChatGPT;
use crate::config::Config;
use crate::event::Event;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::mpsc::Sender;

use std::sync::Arc;
pub trait LLM: Send + Sync {
fn ask(
&self,
chat_messages: Vec<HashMap<String, String>>,
) -> Result<String, Box<dyn std::error::Error>>;
sender: &Sender<Event>,
) -> Result<(), Box<dyn std::error::Error>>;
}

#[derive(Clone, Debug)]
pub enum LLMAnswer {
StartAnswer,
Answer(String),
EndAnswer,
}

#[derive(Deserialize, Debug)]
Expand Down
20 changes: 14 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use tenere::cli;
use tenere::config::Config;
use tenere::event::{Event, EventHandler};
use tenere::handler::handle_key_events;
use tenere::llm::LLMAnswer;
use tenere::tui::Tui;
use tui::backend::CrosstermBackend;
use tui::Terminal;
Expand Down Expand Up @@ -37,15 +38,22 @@ fn main() -> AppResult<()> {
}
Event::Mouse(_) => {}
Event::Resize(_, _) => {}
Event::LLMAnswer(answer) => {
app.chat.pop();
app.spinner.active = false;
app.chat.push(format!("🤖: {}\n", answer));
app.chat.push("\n".to_string());
Event::LLMEvent(LLMAnswer::Answer(answer)) => {
app.answer.push_str(answer.as_str());
}
Event::LLMEvent(LLMAnswer::EndAnswer) => {
let mut conv: HashMap<String, String> = HashMap::new();
conv.insert("role".to_string(), "user".to_string());
conv.insert("content".to_string(), answer);
conv.insert("content".to_string(), app.answer.clone());
app.llm_messages.push(conv);
app.chat.push(app.answer.clone());
app.chat.push("\n".to_string());
app.answer.clear();
}
Event::LLMEvent(LLMAnswer::StartAnswer) => {
app.spinner.active = false;
app.chat.pop();
app.chat.push("🤖: ".to_string());
}
Event::Notification(notification) => {
app.notifications.push(notification);
Expand Down
14 changes: 10 additions & 4 deletions src/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ pub fn render<B: Backend>(app: &mut App, frame: &mut Frame<'_, B>) {

// Chat block
let chat = {
let messages: String = app.chat.iter().map(|m| m.to_string()).collect();
let mut messages: String = app.chat.iter().map(|m| m.to_string()).collect();

messages.push_str(app.answer.as_str());

let messages_height = {
let mut height: u16 = 0;
Expand All @@ -203,6 +205,11 @@ pub fn render<B: Backend>(app: &mut App, frame: &mut Frame<'_, B>) {
height += line.width() as u16 / app_area.width;
}
}

for line in app.answer.lines() {
height += 1;
height += line.width() as u16 / app_area.width;
}
height
};

Expand All @@ -219,12 +226,11 @@ pub fn render<B: Backend>(app: &mut App, frame: &mut Frame<'_, B>) {
scroll = height_diff + app.scroll;
}

// // scroll up case
// scroll up case
if height_diff > 0 && -app.scroll > height_diff {
app.scroll = -height_diff;
}
//
// // Scroll down case
// Scroll down case
if height_diff > 0 && app.scroll > 0 {
app.scroll = 0;
}
Expand Down

0 comments on commit c007069

Please sign in to comment.