Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add label in sidebar #14

Merged
merged 2 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions backend/api_plugins/sessions/sessions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Sequence
from typing import Optional, Sequence
from uuid import uuid4

from fastapi import APIRouter, Depends, FastAPI, Response
Expand All @@ -22,7 +22,9 @@ def session_routes(
connection.run_script(Path(__file__).parent / "sessions_tables.sql")

@app.post("/session/new")
async def chat_new(current_user: User=authentication, dependencies=dependencies) -> dict:
async def chat_new(
current_user: User = authentication, dependencies=dependencies
) -> dict:
chat_id = str(uuid4())
timestamp = datetime.utcnow().isoformat()
user_id = current_user.email if current_user else "unauthenticated"
Expand All @@ -33,26 +35,51 @@ async def chat_new(current_user: User=authentication, dependencies=dependencies)
)
return {"session_id": chat_id}


@app.get("/session/list")
async def chat_list(current_user: User=authentication, dependencies=dependencies) -> List[dict]:
async def chat_list(
current_user: User = authentication, dependencies=dependencies
) -> list[dict]:
user_email = current_user.email if current_user else "unauthenticated"
chats = []
with Database() as connection:
result = connection.execute(
"SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp DESC",
(user_email,),
# Check if message_history table exists (first time running the app will not
# have this table created yet)
message_history_exists = connection.fetchone(
"SELECT name FROM sqlite_master WHERE type='table' AND"
" name='message_history'"
)
chats = [{"id": row[0], "timestamp": row[1]} for row in result]
if message_history_exists:
# Join session with message_history and get the first message
result = connection.execute(
"SELECT s.id, s.timestamp, mh.message FROM session s LEFT JOIN"
" (SELECT *, ROW_NUMBER() OVER (PARTITION BY session_id ORDER BY"
" timestamp ASC) as rn FROM message_history) mh ON s.id ="
" mh.session_id AND mh.rn = 1 WHERE s.user_id = ? ORDER BY"
" s.timestamp DESC",
(user_email,),
)
for row in result:
# Extract the first message content if available
first_message_content = (
json.loads(row[2])["data"]["content"] if row[2] else ""
)
chat = {
"id": row[0],
"timestamp": row[1],
"first_message": first_message_content,
}
chats.append(chat)
return chats


@app.get("/session/{session_id}")
async def chat(session_id: str, current_user: User=authentication, dependencies=dependencies) -> dict:
messages: List[Message] = []
async def chat(
session_id: str, current_user: User = authentication, dependencies=dependencies
) -> dict:
messages: list[Message] = []
with Database() as connection:
result = connection.execute(
"SELECT id, timestamp, session_id, message FROM message_history WHERE session_id = ? ORDER BY timestamp ASC",
"SELECT id, timestamp, session_id, message FROM message_history WHERE"
" session_id = ? ORDER BY timestamp ASC",
(session_id,),
)
for row in result:
Expand All @@ -66,8 +93,13 @@ async def chat(session_id: str, current_user: User=authentication, dependencies=
content=content,
)
messages.append(message)
return {"chat_id": session_id, "messages": [message.dict() for message in messages]}
return {
"chat_id": session_id,
"messages": [message.dict() for message in messages],
}

@app.get("/session")
async def session_root(current_user: User=authentication, dependencies=dependencies) -> dict:
async def session_root(
current_user: User = authentication, dependencies=dependencies
) -> dict:
return Response("Sessions management routes are enabled.", status_code=200)
45 changes: 40 additions & 5 deletions frontend/lib/sidebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@
def sidebar():
with st.sidebar:
st.sidebar.title("RAG Industrialization Kit", anchor="top")
st.sidebar.markdown(f"<p style='color:grey;'>Logged in as {st.session_state['email']}</p>", unsafe_allow_html=True)
st.sidebar.markdown(
f"<p style='color:grey;'>Logged in as {st.session_state['email']}</p>",
unsafe_allow_html=True,
)

if st.sidebar.button("New Chat", use_container_width=True, key="new_chat_button"):
if st.sidebar.button(
"New Chat", use_container_width=True, key="new_chat_button"
):
st.session_state["messages"] = []

with st.empty():
chat_list = list_sessions()
chats_by_time_ago = {}
for chat in chat_list:
chat_id, timestamp = chat["id"], chat["timestamp"]
time_ago = humanize.naturaltime(datetime.utcnow() - datetime.fromisoformat(timestamp))
time_ago = humanize.naturaltime(
datetime.utcnow() - datetime.fromisoformat(timestamp)
)
if time_ago not in chats_by_time_ago:
chats_by_time_ago[time_ago] = []
chats_by_time_ago[time_ago].append(chat)
Expand All @@ -29,9 +36,22 @@ def sidebar():
st.sidebar.markdown(time_ago)
for chat in chats:
chat_id = chat["id"]
if st.sidebar.button(chat_id, key=chat_id, use_container_width=True):
chat_first_message = chat["first_message"]
label = (
truncate_label(chat_first_message, 100)
if chat_first_message
else "*No content*"
)
if st.sidebar.button(
label=label,
key=chat_id,
use_container_width=True,
):
st.session_state["chat_id"] = chat_id
messages = [Message(**message) for message in get_session(chat_id)["messages"]]
messages = [
Message(**message)
for message in get_session(chat_id)["messages"]
]
st.session_state["messages"] = messages


Expand All @@ -42,3 +62,18 @@ def list_sessions():
def get_session(session_id: str):
session = query("get", f"/session/{session_id}").json()
return session


def truncate_label(string: str, max_len: int) -> str:
"""Truncate a string to a maximum length, appending ellipsis if necessary.

Args:
string (str): String to be truncated.
max_len (int): Maximum allowed length of the string after truncation.

Returns:
str: Truncated string.
"""
if string and len(string) > max_len:
string = string[: max_len - 3] + "..."
return string
Loading