Skip to content

Commit

Permalink
wut
Browse files Browse the repository at this point in the history
  • Loading branch information
AtlantisPleb committed Feb 22, 2025
1 parent be6421d commit 5b47260
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 35 deletions.
116 changes: 82 additions & 34 deletions backend/src/server/handlers/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::server::{
handlers::oauth::session::SESSION_COOKIE_NAME,
models::chat::{CreateConversationRequest, CreateMessageRequest, Message},
services::chat_database::ChatDatabaseService,
ws::handlers::chat::ChatResponse,
ws::handlers::chat::{ChatDelta, ChatResponse},
};

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -135,8 +135,8 @@ pub async fn start_repo_chat(
// Broadcast updates through WebSocket
while let Some(update) = stream.next().await {
match update {
Ok(delta) => {
let delta: ChatResponse = serde_json::from_str(&delta).map_err(|e| {
Ok(delta_str) => {
let delta: ChatDelta = serde_json::from_str(&delta_str).map_err(|e| {
error!("Failed to parse delta: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Expand All @@ -145,13 +145,25 @@ pub async fn start_repo_chat(
})?;

// Send update through WebSocket
if let Some(ws_state) = state.ws_state.as_ref() {
ws_state.broadcast(json!({
"type": "Update",
"message_id": _message.id,
"delta": delta
})).await;
}
let update_response = ChatResponse::Update {
message_id: _message.id,
connection_id: None,
delta: delta.clone(),
};
let msg = serde_json::to_string(&update_response).map_err(|e| {
error!("Failed to serialize update: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to serialize update: {}", e),
)
})?;
state.ws_state.broadcast(&msg).await.map_err(|e| {
error!("Failed to broadcast update: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to broadcast update: {}", e),
)
})?;

// Accumulate content
if let Some(c) = delta.content {
Expand Down Expand Up @@ -200,13 +212,25 @@ pub async fn start_repo_chat(
info!("Created AI message with id: {}", ai_message.id);

// Send completion through WebSocket
if let Some(ws_state) = state.ws_state.as_ref() {
ws_state.broadcast(json!({
"type": "Complete",
"message_id": _message.id,
"conversation_id": conversation.id,
})).await;
}
let complete_response = ChatResponse::Complete {
message_id: _message.id,
connection_id: None,
conversation_id: conversation.id,
};
let msg = serde_json::to_string(&complete_response).map_err(|e| {
error!("Failed to serialize completion: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to serialize completion: {}", e),
)
})?;
state.ws_state.broadcast(&msg).await.map_err(|e| {
error!("Failed to broadcast completion: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to broadcast completion: {}", e),
)
})?;

Ok(Json(StartChatResponse {
id: conversation.id.to_string(),
Expand Down Expand Up @@ -330,8 +354,8 @@ pub async fn send_message(
// Broadcast updates through WebSocket
while let Some(update) = stream.next().await {
match update {
Ok(delta) => {
let delta: ChatResponse = serde_json::from_str(&delta).map_err(|e| {
Ok(delta_str) => {
let delta: ChatDelta = serde_json::from_str(&delta_str).map_err(|e| {
error!("Failed to parse delta: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Expand All @@ -340,13 +364,25 @@ pub async fn send_message(
})?;

// Send update through WebSocket
if let Some(ws_state) = state.ws_state.as_ref() {
ws_state.broadcast(json!({
"type": "Update",
"message_id": _message.id,
"delta": delta
})).await;
}
let update_response = ChatResponse::Update {
message_id: _message.id,
connection_id: None,
delta: delta.clone(),
};
let msg = serde_json::to_string(&update_response).map_err(|e| {
error!("Failed to serialize update: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to serialize update: {}", e),
)
})?;
state.ws_state.broadcast(&msg).await.map_err(|e| {
error!("Failed to broadcast update: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to broadcast update: {}", e),
)
})?;

// Accumulate content
if let Some(c) = delta.content {
Expand Down Expand Up @@ -391,13 +427,25 @@ pub async fn send_message(
})?;

// Send completion through WebSocket
if let Some(ws_state) = state.ws_state.as_ref() {
ws_state.broadcast(json!({
"type": "Complete",
"message_id": _message.id,
"conversation_id": request.conversation_id,
})).await;
}
let complete_response = ChatResponse::Complete {
message_id: ai_message.id,
connection_id: None,
conversation_id: request.conversation_id,
};
let msg = serde_json::to_string(&complete_response).map_err(|e| {
error!("Failed to serialize completion: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to serialize completion: {}", e),
)
})?;
state.ws_state.broadcast(&msg).await.map_err(|e| {
error!("Failed to broadcast completion: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to broadcast completion: {}", e),
)
})?;

Ok(Json(SendMessageResponse {
id: ai_message.id.to_string(),
Expand Down Expand Up @@ -463,4 +511,4 @@ pub async fn get_conversation_messages(
})?;

Ok(Json(messages))
}
}
8 changes: 7 additions & 1 deletion backend/src/server/ws/handlers/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub enum ChatMessage {
},
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ChatResponse {
Subscribed {
Expand Down Expand Up @@ -61,6 +61,12 @@ pub struct ChatDelta {
pub reasoning: Option<String>,
}

impl ChatDelta {
pub fn new(content: Option<String>, reasoning: Option<String>) -> Self {
Self { content, reasoning }
}
}

pub struct ChatHandler {
tx: mpsc::Sender<String>,
state: AppState,
Expand Down

0 comments on commit 5b47260

Please sign in to comment.