From 15605f26772fc17719a083002143b4485b726b2a Mon Sep 17 00:00:00 2001 From: Baptiste Pasquier Date: Wed, 17 Apr 2024 09:46:30 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20add=20label=20in=20sidebar?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/api_plugins/sessions/sessions.py | 60 ++++++++++++++++++------ frontend/lib/sidebar.py | 45 ++++++++++++++++-- 2 files changed, 86 insertions(+), 19 deletions(-) diff --git a/backend/api_plugins/sessions/sessions.py b/backend/api_plugins/sessions/sessions.py index 2649bd4..55be16f 100644 --- a/backend/api_plugins/sessions/sessions.py +++ b/backend/api_plugins/sessions/sessions.py @@ -1,7 +1,7 @@ import json from datetime import datetime from pathlib import Path -from typing import List, Optional, Sequence +from typing import Optional, Sequence from uuid import uuid4 from fastapi import APIRouter, Depends, FastAPI, Response @@ -22,7 +22,9 @@ def session_routes( connection.run_script(Path(__file__).parent / "sessions_tables.sql") @app.post("/session/new") - async def chat_new(current_user: User=authentication, dependencies=dependencies) -> dict: + async def chat_new( + current_user: User = authentication, dependencies=dependencies + ) -> dict: chat_id = str(uuid4()) timestamp = datetime.utcnow().isoformat() user_id = current_user.email if current_user else "unauthenticated" @@ -33,26 +35,51 @@ async def chat_new(current_user: User=authentication, dependencies=dependencies) ) return {"session_id": chat_id} - @app.get("/session/list") - async def chat_list(current_user: User=authentication, dependencies=dependencies) -> List[dict]: + async def chat_list( + current_user: User = authentication, dependencies=dependencies + ) -> list[dict]: user_email = current_user.email if current_user else "unauthenticated" chats = [] with Database() as connection: - result = connection.execute( - "SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp DESC", - (user_email,), + # Check if message_history table exists (first time running the app will not + # have this table created yet) + message_history_exists = connection.fetchone( + "SELECT name FROM sqlite_master WHERE type='table' AND" + " name='message_history'" ) - chats = [{"id": row[0], "timestamp": row[1]} for row in result] + if message_history_exists: + # Join session with message_history and get the first message + result = connection.execute( + "SELECT s.id, s.timestamp, mh.message FROM session s LEFT JOIN" + " (SELECT *, ROW_NUMBER() OVER (PARTITION BY session_id ORDER BY" + " timestamp ASC) as rn FROM message_history) mh ON s.id =" + " mh.session_id AND mh.rn = 1 WHERE s.user_id = ? ORDER BY" + " s.timestamp DESC", + (user_email,), + ) + for row in result: + # Extract the first message content if available + first_message_content = ( + json.loads(row[2])["data"]["content"] if row[2] else "" + ) + chat = { + "id": row[0], + "timestamp": row[1], + "first_message": first_message_content, + } + chats.append(chat) return chats - @app.get("/session/{session_id}") - async def chat(session_id: str, current_user: User=authentication, dependencies=dependencies) -> dict: - messages: List[Message] = [] + async def chat( + session_id: str, current_user: User = authentication, dependencies=dependencies + ) -> dict: + messages: list[Message] = [] with Database() as connection: result = connection.execute( - "SELECT id, timestamp, session_id, message FROM message_history WHERE session_id = ? ORDER BY timestamp ASC", + "SELECT id, timestamp, session_id, message FROM message_history WHERE" + " session_id = ? ORDER BY timestamp ASC", (session_id,), ) for row in result: @@ -66,8 +93,13 @@ async def chat(session_id: str, current_user: User=authentication, dependencies= content=content, ) messages.append(message) - return {"chat_id": session_id, "messages": [message.dict() for message in messages]} + return { + "chat_id": session_id, + "messages": [message.dict() for message in messages], + } @app.get("/session") - async def session_root(current_user: User=authentication, dependencies=dependencies) -> dict: + async def session_root( + current_user: User = authentication, dependencies=dependencies + ) -> dict: return Response("Sessions management routes are enabled.", status_code=200) diff --git a/frontend/lib/sidebar.py b/frontend/lib/sidebar.py index d9a3c9e..94b6c30 100644 --- a/frontend/lib/sidebar.py +++ b/frontend/lib/sidebar.py @@ -10,9 +10,14 @@ def sidebar(): with st.sidebar: st.sidebar.title("RAG Industrialization Kit", anchor="top") - st.sidebar.markdown(f"

Logged in as {st.session_state['email']}

", unsafe_allow_html=True) + st.sidebar.markdown( + f"

Logged in as {st.session_state['email']}

", + unsafe_allow_html=True, + ) - if st.sidebar.button("New Chat", use_container_width=True, key="new_chat_button"): + if st.sidebar.button( + "New Chat", use_container_width=True, key="new_chat_button" + ): st.session_state["messages"] = [] with st.empty(): @@ -20,7 +25,9 @@ def sidebar(): chats_by_time_ago = {} for chat in chat_list: chat_id, timestamp = chat["id"], chat["timestamp"] - time_ago = humanize.naturaltime(datetime.utcnow() - datetime.fromisoformat(timestamp)) + time_ago = humanize.naturaltime( + datetime.utcnow() - datetime.fromisoformat(timestamp) + ) if time_ago not in chats_by_time_ago: chats_by_time_ago[time_ago] = [] chats_by_time_ago[time_ago].append(chat) @@ -29,9 +36,22 @@ def sidebar(): st.sidebar.markdown(time_ago) for chat in chats: chat_id = chat["id"] - if st.sidebar.button(chat_id, key=chat_id, use_container_width=True): + chat_first_message = chat["first_message"] + label = ( + truncate_label(chat_first_message, 100) + if chat_first_message + else "*No content*" + ) + if st.sidebar.button( + label=label, + key=chat_id, + use_container_width=True, + ): st.session_state["chat_id"] = chat_id - messages = [Message(**message) for message in get_session(chat_id)["messages"]] + messages = [ + Message(**message) + for message in get_session(chat_id)["messages"] + ] st.session_state["messages"] = messages @@ -42,3 +62,18 @@ def list_sessions(): def get_session(session_id: str): session = query("get", f"/session/{session_id}").json() return session + + +def truncate_label(string: str, max_len: int) -> str: + """Truncate a string to a maximum length, appending ellipsis if necessary. + + Args: + string (str): String to be truncated. + max_len (int): Maximum allowed length of the string after truncation. + + Returns: + str: Truncated string. + """ + if string and len(string) > max_len: + string = string[: max_len - 3] + "..." + return string