Skip to content

Commit

Permalink
Merge pull request #3 from artefactory/av/db
Browse files Browse the repository at this point in the history
Av/db
  • Loading branch information
AlexisVLRT authored Dec 19, 2023
2 parents e895a98 + 0468e73 commit 9013bed
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 36 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,5 @@ secrets/*
# Mac OS
.DS_Store
data/

*.sqlite
46 changes: 33 additions & 13 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import timedelta
from datetime import datetime, timedelta
from typing import List
from uuid import uuid4

Expand All @@ -9,7 +9,7 @@
import backend.document_store as document_store
from database.database import Database
from backend.document_store import StorageBackend
from backend.model import Doc
from backend.model import Doc, Message
from backend.user_management import (
ALGORITHM,
SECRET_KEY,
Expand Down Expand Up @@ -106,21 +106,41 @@ async def user_me(current_user: User = Depends(get_current_user)) -> User:
############################################
### Chat ###
############################################
# P1

@app.post("/chat/new")
async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
"""Create a new chat session for the current user."""
pass
chat_id = str(uuid4())
timestamp = datetime.now().isoformat()
user_id = current_user.email
with Database() as connection:
connection.query(
"INSERT INTO chat (id, timestamp, user_id) VALUES (?, ?, ?)",
(chat_id, timestamp, user_id),
)
return {"chat_id": chat_id}

@app.post("/chat/{chat_id}/user_message")
async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict:
with Database() as connection:
connection.query(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(message.id, message.timestamp, message.chat_id, message.sender, message.content),
)

model_response = Message(
id=str(uuid4()),
timestamp=datetime.now().isoformat(),
chat_id=message.chat_id,
sender="assistant",
content=f"Unique response: {uuid4()}",
)

# P1
@app.post("/chat/user_message")
async def chat_prompt(current_user: User = Depends(get_current_user)) -> dict:
# TODO: Log message to db
# TODO: Get response from model
# TODO: Log response to db
# TODO: Return response
return {"message": f"Unique response: {uuid4()}"}
with Database() as connection:
connection.query(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(model_response.id, model_response.timestamp, model_response.chat_id, model_response.sender, model_response.content),
)
return {"message": model_response}


@app.post("/chat/regenerate")
Expand Down
16 changes: 8 additions & 8 deletions backend/model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from datetime import datetime
from uuid import uuid4
from langchain.docstore.document import Document
from pydantic import BaseModel


class ChatMessage(BaseModel):
"""Represents a chat message within a session."""

message: str
message_id: str
session_id: str

class Message(BaseModel):
id: str
timestamp: str
chat_id: str
sender: str
content: str

class Doc(BaseModel):
"""Represents a document with content and associated metadata."""
Expand Down
Binary file removed database/database.sqlite
Binary file not shown.
1 change: 1 addition & 0 deletions database/database_init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ CREATE TABLE IF NOT EXISTS "user" (

CREATE TABLE IF NOT EXISTS "chat" (
"id" TEXT PRIMARY KEY,
"timestamp" TEXT,
"user_id" TEXT,
FOREIGN KEY ("user_id") REFERENCES "user" ("email")
);
Expand Down
47 changes: 32 additions & 15 deletions frontend/lib/chat.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,56 @@
from uuid import uuid4
from datetime import datetime

import streamlit as st

from dataclasses import dataclass
from dataclasses import dataclass, asdict
from streamlit_feedback import streamlit_feedback

@dataclass
class Message:
user: str
text: str
sender: str
content: str
chat_id: str
id: str = None
timestamp: str = None

def __post_init__(self):
self.id = str(uuid4()) if self.id is None else self.id

messages = []
self.timestamp = datetime.now().isoformat() if self.timestamp is None else self.timestamp

def chat():
prompt = st.chat_input("Say something")

if prompt:
messages.append(Message("user", prompt))
response = send_prompt(prompt)
messages.append(Message("assistant", response))
if len(st.session_state.get("messages", [])) == 0:
chat_id = new_chat()
else:
chat_id = st.session_state.get("chat_id")

st.session_state.get("messages").append(Message("user", prompt, chat_id))
response = send_prompt(st.session_state.get("messages")[-1])
st.session_state.get("messages").append(Message(**response))

with st.container(border=True):
for message in messages:
with st.chat_message(message.user):
st.write(message.text)
if len(messages) > 0 and len(messages) % 2 == 0:
streamlit_feedback(key=str(len(messages)), feedback_type="thumbs", on_submit=lambda feedback: send_feedback(messages[-1].id, feedback))
for message in st.session_state.get("messages", []):
with st.chat_message(message.sender):
st.write(message.content)
if len(st.session_state.get("messages", [])) > 0 and len(st.session_state.get("messages")) % 2 == 0:
streamlit_feedback(key=str(len(st.session_state.get("messages"))), feedback_type="thumbs", on_submit=lambda feedback: send_feedback(st.session_state.get("messages")[-1].id, feedback))


def new_chat():
session = st.session_state.get("session")
response = session.post("/chat/new")
st.session_state["chat_id"] = response.json()["chat_id"]
st.session_state["messages"] = []
return response.json()["chat_id"]

def send_prompt(prompt: str):
def send_prompt(message: Message):
session = st.session_state.get("session")
response = session.post("/chat/user_message", json={"prompt": prompt})
response = session.post(f"/chat/{message.chat_id}/user_message", json=asdict(message))
print(response.headers)
print(response.text)
return response.json()["message"]

def send_feedback(message_id: str, feedback: str):
Expand Down

0 comments on commit 9013bed

Please sign in to comment.