From 6c5496f797c766fb8a734424fa5cca274be7289a Mon Sep 17 00:00:00 2001
From: Stan Girard
Date: Thu, 25 Jan 2024 18:56:54 -0800
Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=8E=B8=20sources=20(#2092)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Added a slightly bigger metadata object for sources.

# Description

Adds a structured `Sources` object (name, type, source_url) to the metadata of
streamed chat responses, replacing the previous Markdown source list, and stops
creating crawl notifications and rendering the frontend `SourcesButton`.

## Checklist before requesting a review

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---
 backend/llm/knowledge_brain_qa.py             | 32 ++++++++++++-------
 backend/modules/chat/dto/chats.py             | 12 +++++++
 .../notification/repository/notifications.py  | 10 ++++--
 backend/routes/crawl_routes.py                | 14 +-------
 .../components/MessageRow/MessageRow.tsx      |  2 --
 5 files changed, 41 insertions(+), 29 deletions(-)

diff --git a/backend/llm/knowledge_brain_qa.py b/backend/llm/knowledge_brain_qa.py
index bede8353477c..bc33681d3219 100644
--- a/backend/llm/knowledge_brain_qa.py
+++ b/backend/llm/knowledge_brain_qa.py
@@ -14,7 +14,7 @@
 from logger import get_logger
 from models import BrainSettings
 from modules.brain.service.brain_service import BrainService
-from modules.chat.dto.chats import ChatQuestion
+from modules.chat.dto.chats import ChatQuestion, Sources
 from modules.chat.dto.inputs import CreateChatHistory
 from modules.chat.dto.outputs import GetChatHistoryOutput
 from modules.chat.service.chat_service import ChatService
@@ -296,21 +296,31 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
                 logger.error("Error during streaming tokens: %s", e)
         try:
             result = await run
-            source_documents = result.get("source_documents", [])
-            # Deduplicate source documents
-            source_documents = list(
-                {doc.metadata["file_name"]: doc for doc in source_documents}.values()
-            )
+            sources_list: List[Sources] = []
+            source_documents = result.get("source_documents", [])

             if source_documents:
-                # Formatting the source documents using Markdown without new lines for each source
-                sources_list = [
-                    f"[{doc.metadata['file_name']}])" for doc in source_documents
-                ]
+                serialized_sources_list = []
+                for doc in source_documents:
+                    sources_list.append(
+                        Sources(
+                            **{
+                                "name": doc.metadata["url"]
+                                if "url" in doc.metadata
+                                else doc.metadata["file_name"],
+                                "type": "url" if "url" in doc.metadata else "file",
+                                "source_url": doc.metadata["url"]
+                                if "url" in doc.metadata
+                                else "",
+                            }
+                        )
+                    )
                 # Create metadata if it doesn't exist
                 if not streamed_chat_history.metadata:
                     streamed_chat_history.metadata = {}
-                streamed_chat_history.metadata["sources"] = sources_list
+                # Serialize the sources list
+                serialized_sources_list = [source.dict() for source in sources_list]
+                streamed_chat_history.metadata["sources"] = serialized_sources_list
                 yield f"data: {json.dumps(streamed_chat_history.dict())}"
             else:
                 logger.info(
diff --git a/backend/modules/chat/dto/chats.py b/backend/modules/chat/dto/chats.py
index ce144906daaf..5a76de77f36f 100644
--- a/backend/modules/chat/dto/chats.py
+++ b/backend/modules/chat/dto/chats.py
@@ -28,6 +28,18 @@ class ChatQuestion(BaseModel):
     prompt_id: Optional[UUID]


+class Sources(BaseModel):
+    name: str
+    source_url: str
+    type: str
+
+    class Config:
+        json_encoders = {
+            **BaseModel.Config.json_encoders,
+            UUID: lambda v: str(v),
+        }
+
+
 class ChatItemType(Enum):
     MESSAGE = "MESSAGE"
     NOTIFICATION = "NOTIFICATION"
diff --git a/backend/modules/notification/repository/notifications.py b/backend/modules/notification/repository/notifications.py
index 7ededc7e757e..c8c0801bf9a8 100644
--- a/backend/modules/notification/repository/notifications.py
+++ b/backend/modules/notification/repository/notifications.py
@@ -1,12 +1,14 @@
 from datetime import datetime, timedelta

-from fastapi import HTTPException
+from logger import get_logger
 from modules.notification.dto.outputs import DeleteNotificationResponse
 from modules.notification.entity.notification import Notification
 from modules.notification.repository.notifications_interface import (
     NotificationInterface,
 )

+logger = get_logger(__name__)
+

 class Notifications(NotificationInterface):
     def __init__(self, supabase_client):
@@ -35,7 +37,8 @@ def update_notification_by_id(
         ).data

         if response == []:
-            raise HTTPException(404, "Notification not found")
+            logger.info(f"Notification with id {notification_id} not found")
+            return None

         return Notification(**response[0])

@@ -57,7 +60,8 @@ def remove_notification_by_id(self, notification_id):
         )

         if response == []:
-            raise HTTPException(404, "Notification not found")
+            logger.info(f"Notification with id {notification_id} not found")
+            return None

         return DeleteNotificationResponse(
             status="deleted", notification_id=notification_id
diff --git a/backend/routes/crawl_routes.py b/backend/routes/crawl_routes.py
index c48f34fda00d..d7e945859811 100644
--- a/backend/routes/crawl_routes.py
+++ b/backend/routes/crawl_routes.py
@@ -8,8 +8,6 @@
 from models import UserUsage
 from modules.knowledge.dto.inputs import CreateKnowledgeProperties
 from modules.knowledge.service.knowledge_service import KnowledgeService
-from modules.notification.dto.inputs import CreateNotificationProperties
-from modules.notification.entity.notification import NotificationsStatusEnum
 from modules.notification.service.notification_service import NotificationService
 from modules.user.entity.user_identity import UserIdentity
 from packages.files.crawl.crawler import CrawlWebsite
@@ -56,16 +54,6 @@ async def crawl_endpoint(
             "type": "error",
         }
     else:
-        crawl_notification_id = None
-        if chat_id:
-            crawl_notification_id = notification_service.add_notification(
-                CreateNotificationProperties(
-                    action="CRAWL",
-                    chat_id=chat_id,
-                    status=NotificationsStatusEnum.Pending,
-                ).id
-            )
-
         knowledge_to_add = CreateKnowledgeProperties(
             brain_id=brain_id,
             url=crawl_website.url,
@@ -78,7 +66,7 @@ async def crawl_endpoint(
         process_crawl_and_notify.delay(
             crawl_website_url=crawl_website.url,
             brain_id=brain_id,
-            notification_id=crawl_notification_id,
+            notification_id=None,
         )

         return {"message": "Crawl processing has started."}
diff --git a/frontend/app/chat/[chatId]/components/ChatDialogueArea/components/ChatDialogue/components/QADisplay/components/MessageRow/MessageRow.tsx b/frontend/app/chat/[chatId]/components/ChatDialogueArea/components/ChatDialogue/components/QADisplay/components/MessageRow/MessageRow.tsx
index 0678cd7f80f4..24b128ea5a6c 100644
--- a/frontend/app/chat/[chatId]/components/ChatDialogueArea/components/ChatDialogue/components/QADisplay/components/MessageRow/MessageRow.tsx
+++ b/frontend/app/chat/[chatId]/components/ChatDialogueArea/components/ChatDialogue/components/QADisplay/components/MessageRow/MessageRow.tsx
@@ -4,7 +4,6 @@ import { CopyButton } from "./components/CopyButton";
 import { MessageContent } from "./components/MessageContent";
 import { QuestionBrain } from "./components/QuestionBrain";
 import { QuestionPrompt } from "./components/QuestionPrompt";
-import { SourcesButton } from "./components/SourcesButton";
 import { useMessageRow } from "./hooks/useMessageRow";

 type MessageRowProps = {
@@ -60,7 +59,6 @@ export const MessageRow = React.forwardRef(
       {!isUserSpeaker && (
         <>
-          {hasSources && }
         </>
       )}
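
To see the backend side of this patch in one place, here is a minimal, self-contained sketch of the `Sources` model added to `backend/modules/chat/dto/chats.py` and the loop added to `generate_stream` that turns retrieved documents into serialized entries under `metadata["sources"]`. It assumes Pydantic v1 (`.dict()`); the `Doc` class is a hypothetical stand-in for the documents returned as `source_documents`, and the real model also configures a UUID JSON encoder that is omitted here.

```python
from typing import Any, Dict, List

from pydantic import BaseModel


class Sources(BaseModel):
    # Mirrors the model added in backend/modules/chat/dto/chats.py
    # (the real one also sets a UUID json_encoder in its Config).
    name: str
    source_url: str
    type: str


class Doc(BaseModel):
    # Hypothetical stand-in for a retrieved document: the real objects come
    # from the chain's "source_documents" and carry "url" or "file_name".
    metadata: Dict[str, Any]


def build_sources_metadata(source_documents: List[Doc]) -> List[dict]:
    # Same branching as the code added to generate_stream: crawled pages are
    # typed "url", uploaded documents are typed "file".
    sources_list: List[Sources] = []
    for doc in source_documents:
        sources_list.append(
            Sources(
                name=doc.metadata["url"]
                if "url" in doc.metadata
                else doc.metadata["file_name"],
                type="url" if "url" in doc.metadata else "file",
                source_url=doc.metadata["url"] if "url" in doc.metadata else "",
            )
        )
    # The patch serializes with .dict() so the list is JSON-safe before being
    # stored on streamed_chat_history.metadata["sources"].
    return [source.dict() for source in sources_list]


if __name__ == "__main__":
    docs = [
        Doc(metadata={"file_name": "report.pdf"}),
        Doc(metadata={"url": "https://example.com/post", "file_name": "post"}),
    ]
    print(build_sources_metadata(docs))
```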
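
On the consumer side, each streamed chunk keeps the framing shown in the diff (`yield f"data: {json.dumps(streamed_chat_history.dict())}"`), so a client can read the new sources out of `metadata["sources"]`. A rough sketch under those assumptions; the surrounding payload fields are illustrative, and only the `metadata["sources"]` shape comes from this patch.

```python
import json

# Illustrative chunk, framed as in generate_stream ("data: <json>"); only the
# metadata["sources"] entries reflect what this patch adds.
chunk = (
    'data: {"assistant": "Here is what I found.", "metadata": {"sources": ['
    '{"name": "report.pdf", "type": "file", "source_url": ""},'
    '{"name": "https://example.com/post", "type": "url",'
    ' "source_url": "https://example.com/post"}]}}'
)

payload = json.loads(chunk[len("data: "):])
for source in payload.get("metadata", {}).get("sources", []):
    # Files are shown by name; URLs can be rendered as links.
    if source["type"] == "url":
        print(f'{source["name"]} -> {source["source_url"]}')
    else:
        print(source["name"])
```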
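
The notification repository methods touched here stop raising `HTTPException(404)` and instead log and return `None`, so callers now have to treat a missing notification as a soft failure rather than an error. A hedged sketch of a caller-side guard, with a hypothetical helper name and a duck-typed `repository` argument:

```python
def delete_notification_if_present(repository, notification_id):
    # Hypothetical helper: after this patch, remove_notification_by_id returns
    # None for unknown ids instead of raising HTTPException(404).
    response = repository.remove_notification_by_id(notification_id)
    if response is None:
        # "Not found" is now expected rather than exceptional, matching the
        # crawl route, which no longer creates notifications and passes
        # notification_id=None to the background task.
        return None
    return response
```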